Derived from the <a href="https://trax-ml.readthedocs.io/en/latest/notebooks/layers_intro.html">Trax layers tutorial</a> in the Trax documentation.

In [1]:
# Copyright 2018 Google LLC.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# https://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

In [2]:
import numpy as np
from   trax import fastmath
from   trax import layers as tl
from   trax import shapes
from   trax.fastmath import numpy as jnp
from   trax.shapes import signature, ShapeDtype

In [3]:
# Print NumPy floats with 3 decimal places so array outputs stay compact.
np.set_printoptions(precision=3)

In [4]:
def show_layer_props(layer, layer_name: str) -> None:
    """Print a layer's basic structural properties.

    Prints ``n_in``, ``n_out``, ``sublayers`` and ``weights``, one per line,
    each prefixed with ``layer_name``, followed by a blank line.

    Args:
        layer: Object exposing ``n_in``, ``n_out``, ``sublayers`` and
            ``weights`` attributes, e.g. a Trax layer such as ``tl.Relu()``.
        layer_name: Label printed as the prefix of every line.
    """
    print(f'{layer_name}.n_in:      {layer.n_in}\n'
          f'{layer_name}.n_out:     {layer.n_out}\n'
          f'{layer_name}.sublayers: {layer.sublayers}\n'
          f'{layer_name}.weights:   {layer.weights}\n')

In [5]:
# Two rows of sample values -- negatives, zero and positives -- at two scales.
x = np.array([[ -2,  -1, 0,  1,  2],
              [-20, -10, 0, 10, 20]])
relu = tl.Relu()
# A layer is applied by calling it on its input, like a function.
y = relu(x)
x



array([[ -2,  -1,   0,   1,   2],
       [-20, -10,   0,  10,  20]])

In [6]:
# ReLU output: negative entries become 0, non-negative entries pass through.
y

DeviceArray([[ 0,  0,  0,  1,  2],
             [ 0,  0,  0, 10, 20]], dtype=int32)

In [7]:
# How many inputs Relu consumes and how many outputs it produces.
(relu.n_in, relu.n_out)

(1, 1)

In [8]:
# Two 2x3 inputs; with no arguments, Concatenate joins them side by side.
x0 = np.array([[1, 2, 3],
               [4, 5, 6]])
x1 = np.array([[ 7,  8,  9],
               [10, 11, 12]])
concat = tl.Concatenate()
# A layer taking two inputs receives them together in a single list.
y = concat([x0, x1])

In [9]:
# First input to the Concatenate layer.
x0

array([[1, 2, 3],
       [4, 5, 6]])

In [10]:
# Second input to the Concatenate layer.
x1

array([[ 7,  8,  9],
       [10, 11, 12]])

In [11]:
# Concatenated result: the rows of x0 and x1 joined into 2x6.
y

DeviceArray([[ 1,  2,  3,  7,  8,  9],
             [ 4,  5,  6, 10, 11, 12]], dtype=int32)

In [12]:
# Concatenate consumes two inputs and produces a single output.
(concat.n_in, concat.n_out)

(2, 1)

### Layers Are Configurable

In [13]:
# A third 2x3 input. Constructor arguments configure the layer:
# n_items=3 makes it take three inputs, and axis=0 stacks them vertically.
x2 = np.array([[13, 14, 15],
               [16, 17, 18]])
concat3 = tl.Concatenate(n_items=3, axis=0)
y = concat3([x0, x1, x2])
y

DeviceArray([[ 1,  2,  3],
             [ 4,  5,  6],
             [ 7,  8,  9],
             [10, 11, 12],
             [13, 14, 15],
             [16, 17, 18]], dtype=int32)

### Layers Can Be Trainable

In [14]:
# Three evenly spaced rows at very different scales; after LayerNorm each
# row comes out the same (see output below).
x = np.array([[-2, -1,  0,  1,  2],
              [ 1,  2,  3,  4,  5],
              [10, 20, 30, 40, 50]], dtype=np.float32)
# LayerNorm is trainable: set it up from the input's signature before use.
layer_norm = tl.LayerNorm()
layer_norm.init(shapes.signature(x))

y = layer_norm(x)
y

DeviceArray([[-1.414, -0.707,  0.   ,  0.707,  1.414],
             [-1.414, -0.707,  0.   ,  0.707,  1.414],
             [-1.414, -0.707,  0.   ,  0.707,  1.414]], dtype=float32)

### Combining Layers