# A Conceptual, Practical Introduction to Trax Layers

This notebook introduces the core concepts and programming components of the Trax library through a series of code samples and explanations. The topics covered in the following sections are:

  1. **Layers**: the basic building blocks and how to combine them into networks
  1. **Data Streams**: how individual layers manage inputs and outputs
  1. **Data Stack**: how the Trax runtime manages data streams for the layers
  1. **Defining New Layer Classes**: how to define and test your own layer classes
  1. **Models**: how to train, evaluate, and run predictions with Trax models



## General Setup
Execute the following few cells (once) before running any of the code samples in this notebook.

In [0]:
# Copyright 2018 Google LLC.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# https://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as onp  # np used below for trax.backend.numpy



In [0]:
# Import Trax

! pip install -q -U trax
! pip install -q tensorflow

from trax import backend
from trax import layers as tl
from trax import shapes
from trax.backend import numpy as np  # For use in defining new layer types.
from trax.shapes import ShapeDtype
from trax.shapes import signature



In [0]:
# Settings and utilities for handling inputs, outputs, and object properties.

onp.set_printoptions(precision=3)  # Reduce visual noise from extra digits.

def show_layer_properties(layer_obj, layer_name):
  template = ('{}.n_in:  {}\n'
              '{}.n_out: {}\n'
              '{}.sublayers: {}\n'
              '{}.weights:    {}\n')
  print(template.format(layer_name, layer_obj.n_in,
                        layer_name, layer_obj.n_out,
                        layer_name, layer_obj.sublayers,
                        layer_name, layer_obj.weights))

def floats_range(start, end):
  return onp.arange(start, end).astype(onp.float32)

# 1. Layers

The `Layer` class represents Trax's concept of a layer, as summarized at the start of the class's docstring:
```
class Layer(object):
  """Base class for composable layers in a deep learning network.

  Layers are the basic building blocks for deep learning models. A Trax layer
  computes a function from zero or more inputs to zero or more outputs,
  optionally using trainable parameters (common) and non-parameter state (not
  common). Authors of new layer subclasses typically override at most two
  methods of the base `Layer` class:

    forward(inputs, params=(), state=(), **kwargs):
      Computes this layer's output as part of a forward pass through the model.

    new_params_and_state(self, input_signature, rng):
      Returns a (params, state) pair suitable for initializing this layer.
```

## A layer computes a function.

A layer computes a function from zero or more inputs to zero or more outputs. The inputs and outputs are NumPy arrays or JAX objects behaving as NumPy arrays.

The simplest layers, those with no parameters, state or sublayers, can be used without initialization. You can think of them (and test them) like simple mathematical functions. For ease of testing and interactive exploration, layer objects implement the `__call__` method, so you can call them directly on input data:
```
y = my_layer(x)
```

Layers are also objects, so you can inspect their properties. For example:
```
print('Number of inputs expected by this layer: {}'.format(my_layer.n_in))
```

### Example 1. tl.Relu $[n_{in} = 1, n_{out} = 1]$

In [0]:
x = floats_range(-7, 8).reshape(3, -1)

# Create a layer object (a Relu instance) and apply the layer to data x.
relu = tl.Relu()
y = relu(x)

# Show input, output, and two layer properties.
template = ('x:\n{}\n\n'
            'relu(x):\n{}\n\n'
            'number of inputs expected by this layer: {}\n'
            'number of outputs promised by this layer: {}')
print(template.format(x, y, relu.n_in, relu.n_out))

x:
[[-7. -6. -5. -4. -3.]
 [-2. -1.  0.  1.  2.]
 [ 3.  4.  5.  6.  7.]]

relu(x):
[[0. 0. 0. 0. 0.]
 [0. 0. 0. 1. 2.]
 [3. 4. 5. 6. 7.]]

number of inputs expected by this layer: 1
number of outputs promised by this layer: 1
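
As a cross-check on the output above, the same values can be reproduced with plain NumPy. This is only a reference sketch: `relu_reference` is a hypothetical helper written for this note, not part of Trax, and it assumes `tl.Relu` computes the usual elementwise `max(x, 0)`.

```python
import numpy as onp  # plain NumPy, following this notebook's `onp` alias


def relu_reference(x):
  """Elementwise max(x, 0): a NumPy stand-in for what tl.Relu is assumed to compute."""
  return onp.maximum(x, 0.0)


# Same input as Example 1: the floats -7..7 arranged in 3 rows of 5.
x_check = onp.arange(-7, 8).astype(onp.float32).reshape(3, -1)
y_check = relu_reference(x_check)
print(y_check)
```

Negative entries become zero and non-negative entries pass through unchanged, which matches the three rows printed by `relu(x)`.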


### Example 2. tl.Concatenate $[n_{in} = 2, n_{out} = 1]$

In [0]:
x1 = floats_range(-7, 8).reshape(3, -1)
x2 = 10 * x1

concat0 = tl.Concatenate(axis=0)
concat1 = tl.Concatenate(axis=1)

y0 = concat0([x1, x2])
y1 = concat1([x1, x2])

template = ('x1:\n{}\n\n'
            'x2:\n{}\n\n'
            'concat0([x1, x2]):\n{}\n\n'
            'concat1([x1, x2]):\n{}\n')
print(template.format(x1, x2, y0, y1))

# Print abbreviated object representations (useful for debugging).
print('concat0: {}'.format(concat0))
print('concat1: {}'.format(concat1))

x1:
[[-7. -6. -5. -4. -3.]
 [-2. -1.  0.  1.  2.]
 [ 3.  4.  5.  6.  7.]]

x2:
[[-70. -60. -50. -40. -30.]
 [-20. -10.   0.  10.  20.]
 [ 30.  40.  50.  60.  70.]]

concat0([x1, x2]):
[[ -7.  -6.  -5.  -4.  -3.]
 [ -2.  -1.   0.   1.   2.]
 [  3.   4.   5.   6.   7.]
 [-70. -60. -50. -40. -30.]
 [-20. -10.   0.  10.  20.]
 [ 30.  40.  50.  60.  70.]]

concat1([x1, x2]):
[[ -7.  -6.  -5.  -4.  -3. -70. -60. -50. -40. -30.]
 [ -2.  -1.   0.   1.   2. -20. -10.   0.  10.  20.]
 [  3.   4.   5.   6.   7.  30.  40.  50.  60.  70.]]

concat0: Concatenate{in=2,out=1}
concat1: Concatenate{in=2,out=1}
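
The effect of the `axis` argument is easiest to see in the output shapes. The following sketch uses plain NumPy's `concatenate` as a stand-in, on the assumption that `tl.Concatenate` follows the same axis convention (the printed results above are consistent with that).

```python
import numpy as onp  # plain NumPy, following this notebook's `onp` alias

# Same inputs as Example 2: two arrays of shape (3, 5).
x1_check = onp.arange(-7, 8).astype(onp.float32).reshape(3, -1)
x2_check = 10 * x1_check

# axis=0 stacks rows: (3, 5) and (3, 5) -> (6, 5).
y0_check = onp.concatenate([x1_check, x2_check], axis=0)
# axis=1 extends each row: (3, 5) and (3, 5) -> (3, 10).
y1_check = onp.concatenate([x1_check, x2_check], axis=1)

print(y0_check.shape, y1_check.shape)  # (6, 5) (3, 10)
```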


## Layers are trainable.

Most layer types are trainable: they include parameters that modify the computation of outputs from inputs, and they use back-propagated gradients to update those parameters.

Before use, trainable layers must have their parameters initialized, typically using a PRNG (pseudo-random number generator) key for random number generation. Trax's model trainers take care of this behind the scenes, but if you are using a layer in isolation, you have to do the initialization yourself. For this, use the `initialize_once` method:

```
  def initialize_once(self, input_signature):
    """Initializes this layer and its sublayers recursively.

    This method is designed to initialize each layer instance once, even if the
    same layer instance occurs in multiple places in the network. This enables
    weight sharing to be implemented as layer sharing.

    ...
```

### Example 3. tl.LayerNorm $[n_{in} = 1, n_{out} = 1]$

In [0]:
x = floats_range(-7, 8).reshape(3, -1)

layer_norm = tl.LayerNorm()
layer_norm.initialize_once(signature(x))
y = layer_norm(x)

template = ('x:\n{}\n\n'
            'layer_norm(x):\n{}\n')
print(template.format(x, y))
print('layer_norm.weights:\n{}'.format(layer_norm.weights))

x:
[[-7. -6. -5. -4. -3.]
 [-2. -1.  0.  1.  2.]
 [ 3.  4.  5.  6.  7.]]

layer_norm(x):
[[-1.414 -0.707  0.     0.707  1.414]
 [-1.414 -0.707  0.     0.707  1.414]
 [-1.414 -0.707  0.     0.707  1.414]]

layer_norm.weights:
(_FilledConstant([1., 1., 1., 1., 1.], dtype=float32), _FilledConstant([0., 0., 0., 0., 0.], dtype=float32))
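
To see where the values in `layer_norm(x)` come from, here is a NumPy sketch of the standard layer-normalization formula, applied with the initial weights printed above (scale of ones, bias of zeros). The helper `layer_norm_reference` and its `epsilon` default are assumptions made for illustration; they are not taken from the Trax source.

```python
import numpy as onp  # plain NumPy, following this notebook's `onp` alias


def layer_norm_reference(x, scale, bias, epsilon=1e-6):
  """Normalizes each row to zero mean and unit variance, then scales and shifts.

  `epsilon` guards against division by zero; its value here is an assumption.
  """
  mean = onp.mean(x, axis=-1, keepdims=True)
  variance = onp.mean((x - mean) ** 2, axis=-1, keepdims=True)
  return scale * (x - mean) / onp.sqrt(variance + epsilon) + bias


x_check = onp.arange(-7, 8).astype(onp.float32).reshape(3, -1)
scale = onp.ones(5, dtype=onp.float32)   # matches the first printed weight
bias = onp.zeros(5, dtype=onp.float32)   # matches the second printed weight
y_check = layer_norm_reference(x_check, scale, bias)
print(y_check)
```

Every row of `x` is an arithmetic sequence with step 1, so each row has the same variance (2) after its mean is subtracted. That is why all three output rows are identical: each is `[-2, -1, 0, 1, 2] / sqrt(2)`.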


## Layers combine into layers.

The Trax library authors encourage users, where possible, to build new layers as combinations of existing layers. The library provides a small set of _combinator_ layers for this: layer objects that make a list of layers behave as a single layer (a unit able to compute outputs from inputs, update parameters from gradients, and combine with yet more layers).



## Combine with Serial(...)

The most common way to combine layers is serially, using the `Serial` class:
```
class Serial(base.Layer):
  """Combinator that applies layers serially (by function composition).

  A Serial combinator uses stack semantics to manage data for its sublayers.
  Each sublayer sees only the inputs it needs and returns only the outputs it
  has generated. The sublayers interact via the data stack. For instance, a
  sublayer k, following sublayer j, gets called with the data stack in the
  state left after layer j has applied. The Serial combinator then:

    - takes n_in items off the top of the stack (n_in = k.n_in) and calls
      layer k, passing those items as arguments; and

    - takes layer k's n_out return values (n_out = k.n_out) and pushes
      them onto the data stack.

  ...
```
If one layer has the same number of outputs as the next layer has inputs (which is quite common), the successive layers behave like function composition:

```
#  h(.) = g(f(.))
layer_h = Serial(
    layer_f,
    layer_g,
)
```
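
The stack semantics described in the docstring can be sketched in a few lines of plain Python. This is a toy model for building intuition, not Trax's implementation: here a "layer" is a hypothetical `(fn, n_in, n_out)` tuple, and the data stack is a Python list whose first element is the top of the stack.

```python
def run_serial(layers, inputs):
  """Applies toy layers in sequence using a data stack.

  Args:
    layers: list of (fn, n_in, n_out) tuples (a stand-in for layer objects).
    inputs: list of input values; index 0 is the top of the stack.

  Returns:
    The data stack (a list) left after the last layer has run.
  """
  stack = list(inputs)
  for fn, n_in, n_out in layers:
    # Take n_in items off the top of the stack and pass them as arguments.
    args, stack = stack[:n_in], stack[n_in:]
    outputs = fn(*args)
    if n_out == 1:
      outputs = (outputs,)
    # Push the layer's n_out return values back onto the stack.
    stack = list(outputs) + stack
  return stack


layer_f = (lambda x: x + 1, 1, 1)
layer_g = (lambda x: x * 2, 1, 1)

# With matching n_out/n_in, Serial behaves like composition: g(f(3)) = (3 + 1) * 2.
result = run_serial([layer_f, layer_g], [3])
print(result)  # [8]
```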

### Example 4. y = layer_norm(relu(x)) $[n_{in} = 1, n_{out} = 1]$

In [0]:
x = floats_range(-7, 8).reshape(3, -1)

layer_block = tl.Serial(
    tl.Relu(),
    tl.LayerNorm(),
)
layer_block.initialize_once(signature(x))
y = layer_block(x)

template = ('x:\n{}\n\n'
            'layer_block(x):\n{}')
print(template.format(x, y,))

x:
[[-7. -6. -5. -4. -3.]
 [-2. -1.  0.  1.  2.]
 [ 3.  4.  5.  6.  7.]]

layer_block(x):
[[ 0.     0.     0.     0.     0.   ]
 [-0.75  -0.75  -0.75   0.5    1.75 ]
 [-1.414 -0.707  0.     0.707  1.414]]
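
The middle row of the output is the least obvious one, so it is worth working through by hand. After `Relu`, the row `[-2, -1, 0, 1, 2]` becomes `[0, 0, 0, 1, 2]`; normalizing that row gives the printed values. The sketch below uses plain NumPy and omits the small epsilon term that a real layer-norm implementation would typically add under the square root.

```python
import numpy as onp  # plain NumPy, following this notebook's `onp` alias

# Middle row of x after the Relu step.
row = onp.array([0., 0., 0., 1., 2.], dtype=onp.float32)

mean = row.mean()                            # (0 + 0 + 0 + 1 + 2) / 5 = 0.6
std = onp.sqrt(((row - mean) ** 2).mean())   # sqrt(3.2 / 5) = 0.8
normalized = (row - mean) / std              # approx. [-0.75, -0.75, -0.75, 0.5, 1.75]
print(normalized)
```

The first row of `x` is entirely negative, so `Relu` maps it to all zeros, and a constant row normalizes to zeros. The last row is unchanged by `Relu`, so it normalizes exactly as in Example 3.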


And we can inspect the block as a whole, as if it were just another layer:

### Example 4'. Inspecting a Serial layer.

In [0]:
print('layer_block:\n{}\n'.format(layer_block))

print('layer_block.weights:\n{}'.format(layer_block.weights))

layer_block:
Serial{in=1,out=1,sublayers=[Relu{in=1,out=1}, LayerNorm{in=1,out=1}]}

layer_block.weights:
[(), (_FilledConstant([1., 1., 1., 1., 1.], dtype=float32), _FilledConstant([0., 0., 0., 0., 0.], dtype=float32))]


## Combine with Parallel(...)

The `Parallel` combinator arranges layers into separate computational channels, each with its own inputs/outputs and gradient flows:
```
class Parallel(base.Layer):
  """Combinator that applies a list of layers in parallel to its inputs.

  Layers in the list apply to successive spans of inputs, where the spans are
  determined by how many inputs each layer takes. The resulting output is the
  (flattened) concatenation of the respective layer outputs.

  For example, suppose one has three layers:

    - F: 1 input, 1 output
    - G: 3 inputs, 1 output
    - H: 2 inputs, 2 outputs (h1, h2)

  Then Parallel(F, G, H) will take 6 inputs and give 4 outputs:

    - inputs: a, b, c, d, e, f
    - outputs: F(a), G(b, c, d), h1, h2
```
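
The span logic in this docstring can be illustrated with a toy model in plain Python. As with any sketch of this kind, the `(fn, n_in, n_out)` tuples and the function bodies chosen for `F`, `G`, and `H` are hypothetical; only the input and output counts come from the docstring.

```python
def run_parallel(layers, inputs):
  """Applies each toy layer to its own consecutive span of the inputs.

  Args:
    layers: list of (fn, n_in, n_out) tuples (a stand-in for layer objects).
    inputs: flat list of input values.

  Returns:
    Flat list that concatenates the outputs of all layers, in order.
  """
  outputs = []
  start = 0
  for fn, n_in, n_out in layers:
    span = inputs[start:start + n_in]  # this layer's share of the inputs
    start += n_in
    result = fn(*span)
    outputs.extend(result if n_out > 1 else [result])
  return outputs


F = (lambda a: -a, 1, 1)                  # 1 input, 1 output
G = (lambda b, c, d: b + c + d, 3, 1)     # 3 inputs, 1 output
H = (lambda e, f: (e * f, e - f), 2, 2)   # 2 inputs, 2 outputs (h1, h2)

# Six inputs (a..f) go in; four outputs come out: F(a), G(b, c, d), h1, h2.
parallel_outputs = run_parallel([F, G, H], [1, 2, 3, 4, 5, 6])
print(parallel_outputs)  # [-1, 9, 30, -1]
```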

Separate (parallel) computation channels make sense when each channel can do its work (computing outputs from inputs) independently of the inputs and outputs of the others.

As a simplistic example, consider writing a converter from three-digit octal (base 8) numerals to their corresponding values. For instance, to do conversions such as
```
123 (octal) = 1 * 8^2 + 2 * 8^1 + 3 * 8^0 =  83 (decimal)
345 (octal) = 3 * 8^2 + 4 * 8^1 + 5 * 8^0 = 229 (decimal)
567 (octal) = 5 * 8^2 + 6 * 8^1 + 7 * 8^0 = 375 (decimal)
701 (octal) = 7 * 8^2 + 0 * 8^1 + 1 * 8^0 = 449 (decimal)
```
the digits can first be converted independently, according to their place value (multiply by 64, multiply by 8, or multiply by 1). The following code runs the 64's-place digits ([1, 3, 5, 7]) through one layer, the 8's-place digits ([2, 4, 6, 0]) through a different layer, and the 1's-place digits ([3, 5, 7, 1]) through yet a different layer. These three layers are combined in a Parallel layer.
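
The place-value arithmetic in the table above can be checked with plain Python before involving any layers. The helper name `octal_value` is made up for this check; Python's built-in `int(s, 8)` serves as an independent reference.

```python
def octal_value(d64, d8, d1):
  """Value of a three-digit octal numeral from its digits, by place value."""
  return d64 * 8**2 + d8 * 8**1 + d1 * 8**0


numerals = ['123', '345', '567', '701']
values = [octal_value(*(int(ch) for ch in s)) for s in numerals]

# Cross-check against Python's built-in base-8 parser.
assert values == [int(s, 8) for s in numerals]
print(values)  # [83, 229, 375, 449]
```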

### Example 5. Processing octal digits in parallel.

In [0]:

# Set up three input channels, for digits with different place values.
place_64_digits = onp.array([1, 3, 5, 7])
place_8_digits = onp.array([2, 4, 6, 0])
place_1_digits = onp.array([3, 5, 7, 1])
inputs = (place_64_digits, place_8_digits, place_1_digits)
# Each input is a 1-D array of 4 digits, so each signature has shape (4,).
input_signature = (ShapeDtype((4,), onp.int32),) * 3

# Create three simple layers, each for computing a specific base 8 place value.
# Then create a combined layer to convert the respective digits in parallel.
# Initialize the combined layer and apply it.
sixty_fours = tl.MulConstant(constant=64.0)  # 8^2: '100' in base 8 digits
eights = tl.MulConstant(constant=8.0)        # 8^1:  '10' in base 8 digits
ones = tl.MulConstant(constant=1.0)          # 8^0:   '1' in base 8 digits
octal_place_values = tl.Parallel(sixty_fours, eights, ones)
octal_place_values.initialize_once(input_signature)
outputs = octal_place_values(inputs)

# Show inputs, outputs, and properties.
template = ('inputs:\n{}\n\n'
            'octal_place_values(inputs):\n{}\n')
print(template.format(inputs, outputs))
show_layer_properties(octal_place_values, 'octal_place_values')

inputs:
(array([1, 3, 5, 7]), array([2, 4, 6, 0]), array([3, 5, 7, 1]))

octal_place_values(inputs):
(array([ 64., 192., 320., 448.]), array([16., 32., 48.,  0.]), array([3., 5., 7., 1.]))

octal_place_values.n_in:  3
octal_place_values.n_out: 3
octal_place_values.sublayers: [MulConstant{in=1,out=1}, MulConstant{in=1,out=1}, MulConstant{in=1,out=1}]
octal_place_values.weights:    ((), (), ())



To complete the example, the three output streams for the different place values are combined by successive pairwise additions.

### Example 5'. Combining outputs from the parallel digit processors.

In [0]:
evaluate_octal = tl.Serial(
    tl.Parallel(sixty_fours, eights, ones),
    tl.Add(),  # Add the 64's-place values and the 8's-place values.
    tl.Add(),  # Add the 1's-place values to the sums from the previous Add.
)
evaluate_octal.initialize_once(input_signature)
y = evaluate_octal(inputs)

template = ('inputs:\n{}\n\n'
            'octal_place_values(inputs):\n{}\n\n'
            'evaluate_octal(inputs):\n{}\n')
print(template.format(inputs, outputs, y))
show_layer_properties(evaluate_octal, 'evaluate_octal')

inputs:
(array([1, 3, 5, 7]), array([2, 4, 6, 0]), array([3, 5, 7, 1]))

octal_place_values(inputs):
(array([ 64., 192., 320., 448.]), array([16., 32., 48.,  0.]), array([3., 5., 7., 1.]))

evaluate_octal(inputs):
[ 83. 229. 375. 449.]

evaluate_octal.n_in:  3
evaluate_octal.n_out: 1
evaluate_octal.sublayers: [Parallel{in=3,out=3,sublayers=[MulConstant{in=1,out=1}, MulConstant{in=1,out=1}, MulConstant{in=1,out=1}]}, Add{in=2,out=1}, Add{in=2,out=1}]
evaluate_octal.weights:    [((), (), ()), (), ()]
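
It can help to trace how the number of data streams shrinks from 3 to 1 as `evaluate_octal` runs. The following walk-through mimics the three steps with plain NumPy and a Python list standing in for the data stack (index 0 is the top). It is a hand simulation under those assumptions, not a call into Trax.

```python
import numpy as onp  # plain NumPy, following this notebook's `onp` alias

# Initial stack: the three digit arrays from Example 5.
stack = [onp.array([1, 3, 5, 7]), onp.array([2, 4, 6, 0]), onp.array([3, 5, 7, 1])]
depths = [len(stack)]

# Step 1, Parallel(sixty_fours, eights, ones): 3 items in, 3 items out.
stack = [64.0 * stack[0], 8.0 * stack[1], 1.0 * stack[2]]
depths.append(len(stack))

# Step 2, first Add: takes the top 2 items, pushes their sum.
stack = [stack[0] + stack[1]] + stack[2:]
depths.append(len(stack))

# Step 3, second Add: takes the remaining 2 items, pushes their sum.
stack = [stack[0] + stack[1]] + stack[2:]
depths.append(len(stack))

print(depths)    # [3, 3, 2, 1]
print(stack[0])  # the final values: 83, 229, 375, 449
```

The final depth of 1 corresponds to `evaluate_octal.n_out: 1`, and the initial depth of 3 to `evaluate_octal.n_in: 3`.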



# 2. Data Streams

The Trax runtime supports the concept of multiple data streams, which gives individual layers flexibility to:
  - process a single data stream ($n_{out} = n_{in} = 1$),
  - process multiple parallel data streams ($n_{out} = n_{in} = 2, 3, ... $),
  - split data streams ($n_{out} > n_{in}$), or
  - merge data streams ($n_{out} < n_{in}$).

The Trax library handles residual connections, for example, as three layers that in turn do a split, a parallel process, and a merge:
```
def Residual(*layers, **kwargs):
  """Adds a residual connection in parallel to a series of layers."""
  shortcut = kwargs.get('shortcut')  # default None signals no-op
  return [
      Dup(),  # pylint: disable=no-value-for-parameter
      Parallel(shortcut, layers),
      Add(),  # pylint: disable=no-value-for-parameter
  ]
```

In more detail, the logic is:
  - `Dup()`: make two identical copies of the single incoming data stream
  - `Parallel(shortcut, layers)`: pass one copy via the shortcut (typically a no-op) and process the other copy via the given layers, applied in series
  - `Add()`: combine the two streams back into one by adding elementwise
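
The three steps above can be sketched with plain NumPy. The function `residual_reference` is a hypothetical stand-in written for this note: it takes an ordinary Python function in place of the wrapped layers and assumes the shortcut is a no-op.

```python
import numpy as onp  # plain NumPy, following this notebook's `onp` alias


def residual_reference(fn, x):
  """Computes x + fn(x), mirroring the Dup / Parallel / Add structure."""
  shortcut_copy, main_copy = x, x                          # Dup: one stream becomes two
  shortcut_out, main_out = shortcut_copy, fn(main_copy)    # Parallel: no-op alongside fn
  return shortcut_out + main_out                           # Add: two streams become one


x_check = onp.array([1., -2., 3.])
y_check = residual_reference(lambda v: onp.maximum(v, 0.0), x_check)
print(y_check.tolist())  # [2.0, -2.0, 6.0]
```

In stream-count terms, the split has $n_{in} = 1, n_{out} = 2$, the parallel step has $n_{in} = n_{out} = 2$, and the merge has $n_{in} = 2, n_{out} = 1$.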

### Example 6. Residual connections

# 3. Data Stack

# 4. Defining New Layer Classes

## Simpler layers, with the `@layer` decorator

## Full subclass definitions, where necessary

# 5. Models