In [1]:
import numpy as np
from   trax import layers as tl
from   trax import fastmath, shapes

In [2]:
!pip3 list | grep trax

trax                     1.3.5
You should consider upgrading via the '/Library/Frameworks/Python.framework/Versions/3.8/bin/python3.8 -m pip install --upgrade pip' command.[0m


In [3]:
# Layers
# Create a trax layer:
relu = tl.Relu()

print('name:', relu.name)
print('expected inputs:', relu.n_in)
print('promised outputs:', relu.n_out)

name: Relu
expected inputs: 1
promised outputs: 1


In [4]:
# input
x = np.array([-2, -1, 0, 1, 2])
print('Inputs:', x)

# out
y = relu(x)
print('Outputs:', y)

Inputs: [-2 -1  0  1  2]
Outputs: [0 0 0 1 2]




In [5]:
concat = tl.Concatenate()
concat.name, concat.n_in, concat.n_out

('Concatenate', 2, 1)

In [6]:
x1 = np.array([-10, -20, -30])
x2 = x1 / -10
print(x1)
print(x2)

y = concat([x1, x2])
y

[-10 -20 -30]
[1. 2. 3.]


DeviceArray([-10., -20., -30.,   1.,   2.,   3.], dtype=float32)

In [7]:
concat_3 = tl.Concatenate(n_items=3)
concat_3.name, concat_3.n_in, concat_3.n_out

('Concatenate', 3, 1)

In [8]:
x3 = x2 * 0.99
y = concat_3([x1, x2, x3])
y

DeviceArray([-10.  , -20.  , -30.  ,   1.  ,   2.  ,   3.  ,   0.99,
               1.98,   2.97], dtype=float32)

In [9]:
norm = tl.LayerNorm()
x = np.array([0, 1, 2, 3], dtype='float')
norm.init(shapes.signature(x))

print(x.shape, type(x.shape))
print(shapes.signature(x), type(shapes.signature(x)))

(4,) <class 'tuple'>
ShapeDtype{shape:(4,), dtype:float64} <class 'trax.shapes.ShapeDtype'>




In [10]:
norm.name, norm.n_in, norm.n_out

('LayerNorm', 1, 1)

In [11]:
(norm.weights[0], # W
 norm.weights[1]) # b

(DeviceArray([1., 1., 1., 1.], dtype=float32),
 DeviceArray([0., 0., 0., 0.], dtype=float32))

In [12]:
y = norm(x)
y

DeviceArray([-1.3416404 , -0.44721344,  0.44721344,  1.3416404 ], dtype=float32)

### Custom Layers

In [14]:
help(tl.Fn)

Help on function Fn in module trax.layers.base:

Fn(name, f, n_out=1)
    Returns a layer with no weights that applies the function `f`.
    
    `f` can take and return any number of arguments, and takes only positional
    arguments -- no default or keyword arguments. It often uses JAX-numpy (`jnp`).
    The following, for example, would create a layer that takes two inputs and
    returns two outputs -- element-wise sums and maxima:
    
        `Fn('SumAndMax', lambda x0, x1: (x0 + x1, jnp.maximum(x0, x1)), n_out=2)`
    
    The layer's number of inputs (`n_in`) is automatically set to number of
    positional arguments in `f`, but you must explicitly set the number of
    outputs (`n_out`) whenever it's not the default value 1.
    
    Args:
      name: Class-like name for the resulting layer; for use in debugging.
      f: Pure function from input tensors to output tensors, where each input
          tensor is a separate positional arg, e.g., `f(x0, x1) --> x0 + x1`.
          

In [16]:
# Def a custom layer (looks like an ideal candidate for decorator)
def times_two():
    layer_name = 'times_two'

    def f(x):
        return 2*x
    
    return tl.Fn(layer_name, f)

In [17]:
ttwo = times_two()

print(ttwo.name)
print(ttwo.n_in)
print(ttwo.n_out)

times_two
1
1


In [18]:
x = np.array([1, 2, 3])
y = ttwo(x)
y

array([2, 4, 6])

### Combinators

In [23]:
# Serial combinator
serial = tl.Serial(tl.LayerNorm(), 
                   tl.Relu(), 
                   ttwo, 
                   tl.Dense(n_units=2), 
                   tl.Dense(n_units=1), 
                   tl.LogSoftmax())

In [24]:
x = np.array([-2, -1, 0, 1, 2])
serial.init(shapes.signature(x))
serial



Serial[
  LayerNorm
  Relu
  times_two
  Dense_2
  Dense_1
  LogSoftmax
]

In [26]:
print(serial.name)
print(serial.sublayers)
print(serial.n_in)
print(serial.n_out)
print(serial.weights)

Serial
[LayerNorm, Relu, times_two, Dense_2, Dense_1, LogSoftmax]
1
1
((DeviceArray([1, 1, 1, 1, 1], dtype=int32), DeviceArray([0, 0, 0, 0, 0], dtype=int32)), (), (), (DeviceArray([[ 0.11138923, -0.20193268],
             [ 0.3218486 , -0.6938446 ],
             [-0.29520795, -0.5566491 ],
             [ 0.03566048,  0.39482087],
             [-0.5446372 ,  0.90716356]], dtype=float32), DeviceArray([-1.2630071e-06, -1.3831954e-06], dtype=float32)), (DeviceArray([[ 0.5734267 ],
             [-0.61399156]], dtype=float32), DeviceArray([7.5263193e-07], dtype=float32)), ())


In [27]:
y = serial(x)
y

DeviceArray([0.], dtype=float32)

### JAX

In [28]:
xnp = np.array([1, 2, 3])
type(xnp)

numpy.ndarray

In [29]:
xjax = fastmath.numpy.array([1, 2, 3])
type(xjax)

jax.interpreters.xla.DeviceArray