# Trax Layers

This notebook introduces the core concepts of the Trax library through a series of code samples and explanations. The topics covered in the following sections are:

1. [Layers](#1): the basic building blocks and how to combine them
1. [Inputs and Outputs](#2): how data streams flow through layers
1. Defining New Layer Classes (if combining existing layers isn’t enough)
1. Testing and Debugging Layer Classes

In [None]:
#@title Install dependencies
#@markdown - Trax
%%capture
!pip install -Uqq trax

In [38]:
#@title Import packages
import os

import numpy as np
import tensorflow as tf
import trax
from trax import fastmath
from trax import layers as tl
from trax import shapes
from trax.fastmath import numpy as jnp
from trax.shapes import ShapeDtype, signature

np.set_printoptions(precision=3)

print("numpy      ", np.__version__)
print("tensorflow ", tf.__version__)
!pip list | grep trax

numpy       1.19.5
tensorflow  2.4.0
trax                          1.3.7                


In [None]:
def show_layer_properties(layer_obj, layer_name):
    template = (
        f"{layer_name}.n_in:      {layer_obj.n_in}\n"
        f"{layer_name}.n_out:     {layer_obj.n_out}\n"
        f"{layer_name}.sublayers: {layer_obj.sublayers}\n"
        f"{layer_name}.weights:   {layer_obj.weights}\n"
    )
    print(template)

<a name='1'></a>
## 1. Layers

The Layer class represents Trax's basic building blocks.

> The inputs and outputs are NumPy arrays or JAX objects that behave like NumPy arrays.

$tl.Relu[n_{in}=1, n_{out}=1]$

In [None]:
relu = tl.Relu()
x = np.array([[-2, -1, 0, 1, 2], [-20, -10, 0, 10, 20]])
y = relu(x)
print(
    f"x:\n{x}\n\n"
    f"relu(x):\n{y}\n\n"
    f"Number of inputs expected by this layer: {relu.n_in}\n"
    f"Number of outputs promised by this layer: {relu.n_out}"
)

x:
[[ -2  -1   0   1   2]
 [-20 -10   0  10  20]]

relu(x):
[[ 0  0  0  1  2]
 [ 0  0  0 10 20]]

Number of inputs expected by this layer: 1
Number of outputs promised by this layer: 1


$tl.Concatenate[n_{in}=2, n_{out}=1]$

In [None]:
concat = tl.Concatenate()
x0 = np.array([[1, 2, 3], [7, 8, 9]])
x1 = np.array([[4, 5, 6], [10, 11, 12]])
y = concat([x0, x1])
print(
    f"x0:\n{x0}\n\n"
    f"x1:\n{x1}\n\n"
    f"concat([x0, x1]):\n{y}\n\n"
    f"Number of inputs expected by this layer: {concat.n_in}\n"
    f"Number of outputs promised by this layer: {concat.n_out}"
)

x0:
[[1 2 3]
 [7 8 9]]

x1:
[[ 4  5  6]
 [10 11 12]]

concat([x0, x1]):
[[ 1  2  3  4  5  6]
 [ 7  8  9 10 11 12]]

Number of inputs expected by this layer: 2
Number of outputs promised by this layer: 1


### 1.1. Layers are configurable

Many layer types have creation-time parameters for flexibility. The Concatenate layer type, for instance, has two optional parameters:

- `axis`: index of axis along which to concatenate the tensors; default value of -1 means to use the last axis.
- `n_items`: number of tensors to join into one by concatenation; default value is 2.

The following example shows `Concatenate` configured for 3 input tensors, and concatenation along the initial ($0^{th}$) axis.



$tl.Concatenate[n_{items}=3, axis=0]$

In [None]:
concat3 = tl.Concatenate(n_items=3, axis=0)

x0 = np.array([[1, 2, 3], [4, 5, 6]])
x1 = np.array([[10, 20, 30], [40, 50, 60]])
x2 = np.array([[100, 200, 300], [400, 500, 600]])

y = concat3([x0, x1, x2])

print(
    f"x0:\n{x0}\n\n"
    f"x1:\n{x1}\n\n"
    f"x2:\n{x2}\n\n"
    f"concat3([x0, x1, x2]):\n{y}"
)

x0:
[[1 2 3]
 [4 5 6]]

x1:
[[10 20 30]
 [40 50 60]]

x2:
[[100 200 300]
 [400 500 600]]

concat3([x0, x1, x2]):
[[  1   2   3]
 [  4   5   6]
 [ 10  20  30]
 [ 40  50  60]
 [100 200 300]
 [400 500 600]]
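
The effect of `axis` can be cross-checked against plain NumPy. The sketch below assumes only that `Concatenate` joins its inputs the way `np.concatenate` does, which matches the output shown above; it is not Trax's implementation.

```python
import numpy as np

x0 = np.array([[1, 2, 3], [4, 5, 6]])
x1 = np.array([[10, 20, 30], [40, 50, 60]])
x2 = np.array([[100, 200, 300], [400, 500, 600]])

# axis=0 stacks the three (2, 3) arrays row-wise into one (6, 3) array.
rows = np.concatenate([x0, x1, x2], axis=0)

# axis=-1 (the default described above) joins along the last axis: (2, 9).
cols = np.concatenate([x0, x1, x2], axis=-1)

print(rows.shape, cols.shape)
```

Only the axis being joined grows; all other dimensions of the inputs must already agree.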


### 1.2. Layers are trainable
Many layer types include weights that affect the computation of outputs from inputs, and they use back-propagated gradients to update those weights.

🚧🚧 A very small subset of layer types, such as ``BatchNorm``, also include modifiable weights (called ``state``) that are updated based on forward-pass inputs/computation rather than back-propagated gradients.

Initialization

Trainable layers must be initialized before use. Trax can take care of this as part of the overall training process. In other settings (e.g., in tests or interactively in a Colab notebook), you need to initialize the outermost/topmost layer explicitly. For this, use `init`:

$tl.LayerNorm[n_{in}=1, n_{out}=1]$

In [None]:
layer_norm = tl.LayerNorm()

x = np.array([[-2, -1, 0, 1, 2], [1, 2, 3, 4, 5], [10, 20, 30, 40, 50]]).astype(
    np.float32
)

layer_norm.init(shapes.signature(x))

y = layer_norm(x)

print(
    f"x:\n{x}\n\n"
    f"layer_norm(x):\n{y}\n"
    f"layer_norm.weights:\n{layer_norm.weights}"
)

x:
[[-2. -1.  0.  1.  2.]
 [ 1.  2.  3.  4.  5.]
 [10. 20. 30. 40. 50.]]

layer_norm(x):
[[-1.414 -0.707  0.     0.707  1.414]
 [-1.414 -0.707  0.     0.707  1.414]
 [-1.414 -0.707  0.     0.707  1.414]]
layer_norm.weights:
(DeviceArray([1., 1., 1., 1., 1.], dtype=float32), DeviceArray([0., 0., 0., 0., 0.], dtype=float32))
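
To see where the numbers above come from, here is a minimal NumPy sketch of layer normalization. It assumes the two printed weights act as a per-feature scale (initialized to ones) and bias (initialized to zeros), and it uses an assumed small `epsilon` for numerical stability. The function name `layer_norm_sketch` is hypothetical and this is not Trax's code.

```python
import numpy as np


def layer_norm_sketch(x, scale, bias, epsilon=1e-6):
    """Normalizes each row to zero mean and unit variance, then rescales."""
    mean = np.mean(x, axis=-1, keepdims=True)
    variance = np.mean((x - mean) ** 2, axis=-1, keepdims=True)
    normalized = (x - mean) / np.sqrt(variance + epsilon)  # epsilon is assumed
    return normalized * scale + bias


x = np.array(
    [[-2, -1, 0, 1, 2], [1, 2, 3, 4, 5], [10, 20, 30, 40, 50]], dtype=np.float32
)
scale = np.ones(5, dtype=np.float32)  # assumed initial value of the first weight
bias = np.zeros(5, dtype=np.float32)  # assumed initial value of the second weight

y = layer_norm_sketch(x, scale, bias)
print(y)
```

All three rows normalize to the same values because each row is an evenly spaced sequence, so the rows differ only by a shift and a scale.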


### 1.3. Layers combine into layers

The Trax library authors encourage users to build networks and network components as combinations of existing layers, by means of a small set of combinator layers. A combinator makes a list of layers behave as a single layer – by combining the sublayer computations yet looking from the outside like any other layer. The combined layer, like other layers, can:

- compute outputs from inputs,
- update parameters from gradients, and
- combine with yet more layers.


Combine with ``Serial``<br>
$h(.) = g(f(.))$

```python
layer_h = Serial(
    layer_f,
    layer_g,
)
```

In [None]:
layer_block = tl.Serial(
    tl.Relu(),
    tl.LayerNorm(),
)

x = np.array([[-2, -1, 0, 1, 2],
              [-20, -10, 0, 10, 20]]).astype(np.float32)

layer_block.init(shapes.signature(x))
y = layer_block(x)

print(
    f'x:\n{x}\n\n'
    f'layer_block(x):\n{y}'
)

x:
[[ -2.  -1.   0.   1.   2.]
 [-20. -10.   0.  10.  20.]]

layer_block(x):
[[-0.75 -0.75 -0.75  0.5   1.75]
 [-0.75 -0.75 -0.75  0.5   1.75]]
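
The same result can be reproduced with ordinary function composition, which is a reasonable way to think about `Serial` in this single-stream case. The helper names `serial_sketch` and `normalize_rows` below are hypothetical, and the normalization assumes unit scale, zero bias, and a small assumed `epsilon`.

```python
import numpy as np


def serial_sketch(*fns):
    """Returns a function that applies fns left to right, like h = g(f(.))."""
    def apply(x):
        for fn in fns:
            x = fn(x)
        return x
    return apply


def relu(x):
    return np.maximum(x, 0)


def normalize_rows(x, epsilon=1e-6):
    # Layer normalization with unit scale and zero bias (assumed initial weights).
    mean = np.mean(x, axis=-1, keepdims=True)
    variance = np.mean((x - mean) ** 2, axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(variance + epsilon)


block = serial_sketch(relu, normalize_rows)

x = np.array([[-2, -1, 0, 1, 2], [-20, -10, 0, 10, 20]], dtype=np.float32)
y = block(x)
print(y)
```

After the ReLU, each row is `[0, 0, 0, a, 2a]`, with mean `0.6a` and standard deviation `0.8a`, which gives the normalized values -0.75, 0.5, and 1.75.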


In [None]:
print(
    f"layer_block:         {layer_block}\n\n"
    f"layer_block.weights: {layer_block.weights}"
)

layer_block:         Serial[
  Serial[
    Relu
  ]
  LayerNorm
]

layer_block.weights: (((), (), ()), (DeviceArray([1., 1., 1., 1., 1.], dtype=float32), DeviceArray([0., 0., 0., 0., 0.], dtype=float32)))


Combine with ``Branch``

The Branch combinator arranges layers into parallel computational channels. Trax's own `Residual` function, shown here, is built from `Branch` and `Serial`:

```python
def Residual(*layers, shortcut=None):
    layers = _ensure_flat(layers)
    layer = layers[0] if len(layers) == 1 else Serial(layers)
    return Serial(
        Branch(shortcut, layer),
        Add(),
    )
```

In [None]:
relu = tl.Relu()
times_100 = tl.Fn("Times100", lambda x: x * 100.0)
branch_relu_t100 = tl.Branch(relu, times_100)

x = np.array([[-2, -1, 0, 1, 2],
              [-20, -10, 0, 10, 20]])
branch_relu_t100.init(shapes.signature(x))

y0, y1 = branch_relu_t100(x)

print(
    f"x:\n{x}\n\n"
    f"y0:\n{y0}\n\n"
    f"y1:\n{y1}"
)

x:
[[ -2  -1   0   1   2]
 [-20 -10   0  10  20]]

y0:
[[ 0  0  0  1  2]
 [ 0  0  0 10 20]]

y1:
[[ -200.  -100.     0.   100.   200.]
 [-2000. -1000.     0.  1000.  2000.]]
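
Conceptually, a branch hands the same input to every sublayer and returns one output per sublayer. The following NumPy sketch (with the hypothetical helper `branch_sketch`) mirrors the example above under that assumption; it does not reflect how Trax implements `Branch` internally.

```python
import numpy as np


def branch_sketch(*fns):
    """Returns a function that feeds one input to every fn in parallel."""
    def apply(x):
        return tuple(fn(x) for fn in fns)
    return apply


def relu(x):
    return np.maximum(x, 0)


def times_100(x):
    return x * 100.0


branch = branch_sketch(relu, times_100)

x = np.array([[-2, -1, 0, 1, 2], [-20, -10, 0, 10, 20]])
y0, y1 = branch(x)  # one input stream in, two output streams out

print(y0)
print(y1)
```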


<a name='2'></a>
## 2. Inputs and Outputs

Trax allows layers to have multiple input streams and output streams. When
designing a network, you have the flexibility to use layers that:

  - process a single data stream ($n_{in} = n_{out} = 1$),
  - process multiple parallel data streams ($n_{in} = n_{out} = 2, 3, ... $),
  - split or inject data streams ($n_{in} < n_{out}$), or
  - merge or remove data streams ($n_{in} > n_{out}$).

We saw in section 1 the example of `Residual`, which involves both a split and a merge:
```
  ...
  return Serial(
      Branch(shortcut, layer),
      Add(),
  )
```
In other words, layer by layer:

  - `Branch(shortcut, layer)`: makes two copies of the single incoming data stream, passes one copy via the shortcut (typically a no-op), and processes the other copy via the given layers (applied in series). [$n_{in} = 1$, $n_{out} = 2$]
  - `Add()`: combines the two streams back into one by adding two tensors elementwise. [$n_{in} = 2$, $n_{out} = 1$]
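
The split-then-merge pattern described in the two bullets above can be sketched in a few lines of NumPy. This assumes the shortcut is a no-op; `residual_sketch` is a hypothetical helper, not the Trax function.

```python
import numpy as np


def residual_sketch(fn):
    """Returns a function computing x + fn(x) via an explicit split and merge."""
    def apply(x):
        # Split: one stream in, two streams out (shortcut copy and processed copy).
        shortcut_out, fn_out = x, fn(x)
        # Merge: two streams in, one stream out, added elementwise.
        return shortcut_out + fn_out
    return apply


def relu(x):
    return np.maximum(x, 0)


x = np.array([-2, -1, 0, 1, 2])
y = residual_sketch(relu)(x)
print(y)
```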

**Simple Case 1 -- Each layer takes one input and has one output.**

This is in effect a single data stream pipeline, and the successive layers
behave like function composition:

```
#  s(.) = h(g(f(.)))
layer_s = Serial(
    layer_f,
    layer_g,
    layer_h,
)
```
Note how, inside `Serial`, function composition is expressed naturally as a
succession of operations, so that no nested parentheses are needed and the
order of operations matches the textual order of layers.

**Simple Case 2 -- Each layer consumes all outputs of the preceding layer.**

This is still a single pipeline, but data streams internal to it can split and
merge. The `Residual` example above illustrates this kind.


**General Case -- Successive layers interact via the data stack.**

As described in the `Serial` class docstring, each layer gets its inputs from
the data stack after the preceding layer has put its outputs onto the stack.
This covers the simple cases above, but also allows for more flexible data
interactions between non-adjacent layers. The following example is schematic:
```
x, y_target = get_batch_of_labeled_data()

model_plus_eval = Serial(
    my_fancy_deep_model(),  # Takes one arg (x) and has one output (y_hat)
    my_eval(),  # Takes two args (y_hat, y_target) and has one output (score)
)

eval_score = model_plus_eval((x, y_target))
```

Here is the corresponding progression of stack states:

0. At start: _--empty--_
0. After `get_batch_of_labeled_data()`: *x*, *y_target*
0. After `my_fancy_deep_model()`: *y_hat*, *y_target*
0. After `my_eval()`: *score*

Note in particular how the application of the model (between stack states 1
and 2) only uses and affects the top element on the stack: `x` --> `y_hat`.
The rest of the data stack (`y_target`) comes in use only later, for the
eval function.
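
The stack progression above can be simulated with a small pure-Python runner. Everything in this sketch is hypothetical: `run_serial`, `toy_model`, and `toy_eval` are illustrative stand-ins, and each layer is represented as a `(function, n_in)` pair rather than a Trax layer object.

```python
import numpy as np


def run_serial(layers, inputs):
    """Applies (fn, n_in) pairs in order against a shared data stack."""
    stack = list(inputs)  # index 0 is the top of the stack
    for fn, n_in in layers:
        # Each layer consumes its n_in inputs from the top of the stack ...
        args, rest = stack[:n_in], stack[n_in:]
        outputs = fn(*args)
        if not isinstance(outputs, tuple):
            outputs = (outputs,)
        # ... and pushes its outputs back on top of whatever was left.
        stack = list(outputs) + rest
    return stack[0] if len(stack) == 1 else tuple(stack)


def toy_model(x):
    # Stand-in for my_fancy_deep_model: one input (x), one output (y_hat).
    return 2.0 * x


def toy_eval(y_hat, y_target):
    # Stand-in for my_eval: two inputs, one output (mean squared error).
    return float(np.mean((y_hat - y_target) ** 2))


x = np.array([1.0, 2.0, 3.0])
y_target = np.array([2.0, 4.0, 7.0])

score = run_serial([(toy_model, 1), (toy_eval, 2)], (x, y_target))
print(score)
```

While `toy_model` runs, `y_target` sits untouched below the top of the stack and is only consumed when `toy_eval` asks for two inputs.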

## 3. Defining New Layer Classes

If you need a layer type that is not easily defined as a combination of
existing layer types, you can define your own layer classes in a couple of
different ways.

**Example.** Use `Fn` to define a new layer type:

In [None]:
def Gcd():
    """Returns a layer to compute the greatest common divisor, elementwise."""
    return tl.Fn("Gcd", lambda x0, x1: jnp.gcd(x0, x1))


gcd = Gcd()

x0 = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
x1 = np.array([11, 12, 13, 14, 15, 16, 17, 18, 19, 20])

y = gcd((x0, x1))

print(f"x0:\n{x0}\n\n" f"x1:\n{x1}\n\n" f"gcd((x0, x1)):\n{y}")

x0:
[ 1  2  3  4  5  6  7  8  9 10]

x1:
[11 12 13 14 15 16 17 18 19 20]

gcd((x0, x1)):
[ 1  2  1  2  5  2  1  2  1 10]
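
As a rough mental model, a function-wrapping layer needs little more than a name, the function itself, and the number of inputs it expects. The class `FnSketch` below is a hypothetical illustration of that idea in plain Python and NumPy. It assumes the input count can be read from the function's parameter list and makes no claim about how `tl.Fn` is implemented.

```python
import inspect

import numpy as np


class FnSketch:
    """Wraps a plain function so it can be called on a tuple of inputs."""

    def __init__(self, name, f):
        self.name = name
        self.f = f
        # Assumption: the number of inputs equals the number of parameters of f.
        self.n_in = len(inspect.signature(f).parameters)

    def __call__(self, inputs):
        # A single-input layer receives a bare array; otherwise unpack the tuple.
        args = inputs if self.n_in > 1 else (inputs,)
        return self.f(*args)


gcd = FnSketch("Gcd", lambda x0, x1: np.gcd(x0, x1))

x0 = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
x1 = np.array([11, 12, 13, 14, 15, 16, 17, 18, 19, 20])

y = gcd((x0, x1))
print(gcd.n_in, y)
```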
