Derived from <a href="https://trax-ml.readthedocs.io/en/latest/notebooks/layers_intro.html">Trax Docs/Tutorial</a>.

In [1]:
# Copyright 2018 Google LLC.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# https://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

In [2]:
import numpy as np
from   trax import fastmath
from   trax import layers as tl
from   trax import shapes
from   trax.fastmath import numpy as jnp
from   trax.shapes import signature, ShapeDtype

In [3]:
# Limit printed floating-point values to 3 decimal places so array outputs stay compact.
np.set_printoptions(precision=3)

In [4]:
def show_layer_props(layer, layer_name):
    '''
    Prints the basic properties of a layer, one per line, under a label.

    Args:
      - layer: Object exposing `n_in`, `n_out`, `sublayers` and `weights`
        attributes (e.g. a Trax layer).
      - layer_name (str): Label used as the prefix of every printed line.
    '''
    # Fixed: the attribute is `weights`; the former `weigths` spelling raised
    # AttributeError as soon as the function was called.
    print(f'{layer_name}.n_in:      {layer.n_in}\n'
          f'{layer_name}.n_out:     {layer.n_out}\n'
          f'{layer_name}.sublayers: {layer.sublayers}\n'
          f'{layer_name}.weights:   {layer.weights}\n')

In [5]:
# Build a ReLU layer and apply it to a small integer batch.
# `relu`, `x` and `y` are inspected (and `relu` reused) by later cells.
relu = tl.Relu()
x = np.array([[-2, -1, 0, 1, 2], [-20, -10, 0, 10, 20]])
y = relu(x)
# Trailing expression: display the input array.
x



array([[ -2,  -1,   0,   1,   2],
       [-20, -10,   0,  10,  20]])

In [6]:
# Display the layer output: negative entries of `x` come back as 0.
y

DeviceArray([[ 0,  0,  0,  1,  2],
             [ 0,  0,  0, 10, 20]], dtype=int32)

In [7]:
# Number of inputs and outputs declared by the layer.
relu.n_in, relu.n_out

(1, 1)

In [8]:
# Concatenate with default settings takes two inputs, passed here as a list.
concat = tl.Concatenate()
x0 = np.array([[1, 2, 3], [4, 5, 6]])
x1 = np.array([[7, 8, 9], [10, 11, 12]])
y = concat([x0, x1])

In [9]:
# First input to the concatenation.
x0

array([[1, 2, 3],
       [4, 5, 6]])

In [10]:
# Second input to the concatenation.
x1

array([[ 7,  8,  9],
       [10, 11, 12]])

In [11]:
# Result: rows of `x0` and `x1` joined side by side.
y

DeviceArray([[ 1,  2,  3,  7,  8,  9],
             [ 4,  5,  6, 10, 11, 12]], dtype=int32)

In [12]:
# Number of inputs and outputs declared by the layer.
concat.n_in, concat.n_out

(2, 1)

### Layers Are Configurable

In [13]:
# Configure the layer at construction time: three inputs, joined along axis 0.
# Reuses `x0` and `x1` from the earlier Concatenate cell.
concat3 = tl.Concatenate(n_items=3, axis=0)
x2 = np.array([[13, 14, 15], [16, 17, 18]])
y = concat3([x0, x1, x2])
y

DeviceArray([[ 1,  2,  3],
             [ 4,  5,  6],
             [ 7,  8,  9],
             [10, 11, 12],
             [13, 14, 15],
             [16, 17, 18]], dtype=int32)

### Layers Are Trainable

In [14]:
# LayerNorm carries weights, so it is initialized before being applied.
layer_norm = tl.LayerNorm()
x = np.array([[-2, -1, -0,  1,  2],
              [ 1,  2,  3,  4,  5],
              [10, 20, 30, 40, 50]]
).astype(np.float32)
# Initialize from the signature (shape/dtype description) of the input.
layer_norm.init(shapes.signature(x))

# Each row is normalized independently, so all three rows give the same output.
y = layer_norm(x)
y

DeviceArray([[-1.414, -0.707,  0.   ,  0.707,  1.414],
             [-1.414, -0.707,  0.   ,  0.707,  1.414],
             [-1.414, -0.707,  0.   ,  0.707,  1.414]], dtype=float32)

### Combining Layers

In [15]:
# Serial chains its sublayers: Relu runs first, LayerNorm runs on its output.
layer_block = tl.Serial(tl.Relu(), tl.LayerNorm())

# float32 input; this `x` is also reused by the Branch cell below.
x = np.array([[ -2,  -1, 0,  1,  2],
              [-20, -10, 0, 10, 20]]
).astype(np.float32)

# Initialize the combined layer from the input signature before applying it.
layer_block.init(shapes.signature(x))
y = layer_block(x)
y

DeviceArray([[-0.75, -0.75, -0.75,  0.5 ,  1.75],
             [-0.75, -0.75, -0.75,  0.5 ,  1.75]], dtype=float32)

In [16]:
# Show the combinator's structure and its weights, grouped per sublayer
# (an empty tuple for Relu, two arrays for LayerNorm).
print(f'layer_block: {layer_block}\n\nweights: {layer_block.weights}')

layer_block: Serial[
  Relu
  LayerNorm
]

weights: ((), (DeviceArray([1., 1., 1., 1., 1.], dtype=float32), DeviceArray([0., 0., 0., 0., 0.], dtype=float32)))


In [17]:
# Branch feeds the same input to each sublayer and returns one output per sublayer.
# NOTE: depends on `relu` from the first ReLU cell and on the float32 `x`
# defined in the Serial cell above.
times_100 = tl.Fn('x100', lambda x: x * 100.)
branch_relu_t100 = tl.Branch(relu, times_100)
branch_relu_t100.init(shapes.signature(x))
y0, y1 = branch_relu_t100(x)
y0, y1

(DeviceArray([[ 0.,  0.,  0.,  1.,  2.],
              [ 0.,  0.,  0., 10., 20.]], dtype=float32),
 array([[ -200.,  -100.,     0.,   100.,   200.],
        [-2000., -1000.,     0.,  1000.,  2000.]], dtype=float32))

### 3. Defining New Layer Classes

In [20]:
def GCD():
    '''
    Builds a layer named 'GCD' that takes two inputs and computes their
    greatest common divisor, element by element.
    '''
    def gcd_pair(a, b):
        return jnp.gcd(a, b)

    return tl.Fn('GCD', gcd_pair)

In [22]:
# Apply the GCD layer to two equal-length integer vectors, passed together as a tuple.
gcd = GCD()
a = np.array([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10])
b = np.array([11, 12, 13, 14, 15, 16, 17, 18, 19, 20])
y = gcd((a, b))
y

DeviceArray([ 1,  2,  1,  2,  5,  2,  1,  2,  1, 10], dtype=int32)

In [23]:
def SumAndMax():
    '''
    Builds a layer named 'SumAndMax' with two outputs: the elementwise sum
    and the elementwise maximum of its two input tensors.
    '''
    def sum_and_max_pair(a, b):
        total = a + b
        larger = jnp.maximum(a, b)
        return total, larger

    return tl.Fn('SumAndMax', sum_and_max_pair, n_out=2)

In [24]:
# The layer has two outputs, so the call result unpacks into (sum, max).
sum_and_max = SumAndMax()
a = np.array([ 1,   2,  3,   4,  5])
b = np.array([10, -20, 30, -40, 50])
y, z = sum_and_max([a, b])
y, z

(array([ 11, -18,  33, -36,  55]),
 DeviceArray([10,  2, 30,  4, 50], dtype=int32))

In [25]:
def Flatten(n_axes_to_keep=1):
    '''
    Returns a layer that combines one or more trailing axes of a tensor.

    Flattening keeps all the values of the input tensor, but reshapes it by
    collapsing one or more trailing axes into a single axis. For example,
    `Flatten(2)` on a tensor with shape `(2, 3, 5, 7, 11)` -> shape
    `(2, 3, 385)`.

    Args:
      - n_axes_to_keep (int): Number of leading axes to keep (unchanged).

    Returns:
      A layer built with `tl.Fn`, named `Flatten_keep_{n_axes_to_keep}`.

    Raises:
      ValueError: Raised by the layer's underlying function when the input
        rank is not greater than `n_axes_to_keep`.
    '''
    layer_name = f'Flatten_keep_{n_axes_to_keep}'

    def f(x):
        # Normalize to a tuple so slicing and `+ (-1,)` also work when the
        # shape is exposed as a list-like object.
        shape = tuple(x.shape)
        in_rank = len(shape)
        if in_rank <= n_axes_to_keep:
            # Fixed message typo: "greater that" -> "greater than".
            raise ValueError(
                f'Input rank ({in_rank}) must be greater than the number '
                f'of axes to keep ({n_axes_to_keep})')
        # -1 lets reshape infer the size of the single collapsed trailing axis.
        return jnp.reshape(x, shape[:n_axes_to_keep] + (-1,))

    return tl.Fn(layer_name, f)

In [26]:
# Compare keeping one leading axis (the default) with keeping two.
flatten_keep_1 = Flatten()
flatten_keep_2 = Flatten(2)

# Rank-3 input of shape (2, 3, 3).
x = np.array([[[  1,   2,   3],
               [ 10,  20,  30],
               [100, 200, 300]],
              [[  4,   5,   6],
               [ 40,  50,  60],
               [400, 500, 600]]])
# Keeping 1 axis collapses the last two axes into one: shape (2, 9).
y = flatten_keep_1(x)
# Keeping 2 axes leaves only one trailing axis to collapse, so `z` has the same shape as `x`.
z = flatten_keep_2(x)

y, z

(array([[  1,   2,   3,  10,  20,  30, 100, 200, 300],
        [  4,   5,   6,  40,  50,  60, 400, 500, 600]]),
 array([[[  1,   2,   3],
         [ 10,  20,  30],
         [100, 200, 300]],
 
        [[  4,   5,   6],
         [ 40,  50,  60],
         [400, 500, 600]]]))