In [1]:
import numpy as np
from trax import layers as tl
from trax import shapes
from trax import fastmath

INFO:tensorflow:tokens_length=568 inputs_length=512 targets_length=114 noise_density=0.15 mean_noise_span_length=3.0 


In [6]:
# Shell command: filter the `pip list` output down to the trax entry to show its version.
!pip list | grep trax

trax                          1.3.1


## Layers
- Layers are the core building blocks in Trax
- They take inputs, compute functions/custom calculations and return outputs
- One can introspect layer properties

### ReLU Layer
- No initialization is needed; it works just like a math function

In [5]:
# NOTE(review): this cell sits under the "ReLU" heading but exercises
# tl.Concatenate (the same layer as the next section) -- confirm whether
# tl.Relu was intended here.

# Create the layer; it is called further down without any init() step.
concat = tl.Concatenate()

# Report the layer's metadata in a single call, one item per line.
print(
    "--Properties--",
    f"Name: {concat.name}",
    f'Expected Inputs: {concat.n_in}',
    f'Promised Ouputs: {concat.n_out}\n',
    sep="\n",
)

# Inputs: x2 is derived from x1; true division turns the ints into floats.
x1 = np.array([-10, -20, -30])
x2 = x1 / -10
print(f'x1: {x1}, x2: {x2}\n')

# Outputs: the layer takes its two inputs as one list.
y = concat([x1, x2])
print("Outputs", f'y: {y}', sep="\n")

--Properties--
Name: Concatenate
Expected Inputs: 2
Promised Ouputs: 1

x1: [-10 -20 -30], x2: [1. 2. 3.]

Outputs
y: [-10. -20. -30.   1.   2.   3.]


### Concatenate Layer

In [7]:
# Set up the concatenate layer and both input arrays first, then report.
concat = tl.Concatenate()
x1 = np.array([-10, -20, -30])
x2 = x1 / -10  # float copy of x1 scaled by -1/10

# Layer metadata
print("-- Properties --")
for label, value in (
    ("name :", concat.name),
    ("expected inputs :", concat.n_in),
):
    print(label, value)
print("promised outputs :", concat.n_out, "\n")

# Inputs
print("-- Inputs --")
print("x1 :", x1)
print("x2 :", x2, "\n")

# Outputs: apply the layer to the pair of arrays
y = concat([x1, x2])
print("-- Outputs --")
print("y :", y)

-- Properties --
name : Concatenate
expected inputs : 2
promised outputs : 1 

-- Inputs --
x1 : [-10 -20 -30]
x2 : [1. 2. 3.] 

-- Outputs --
y : [-10. -20. -30.   1.   2.   3.]


### Layers are Configurable

In [9]:
# Same layer type as before, configured to accept three inputs instead of two.
concat_3 = tl.Concatenate(n_items=3)

# Layer metadata
print("-- Properties --")
for label, value in (
    ("name :", concat_3.name),
    ("expected inputs :", concat_3.n_in),
):
    print(label, value)
print("promised outputs :", concat_3.n_out, "\n")

# Inputs: each array is derived from the previous one
x1 = np.array([-10, -20, -30])
x2 = x1 / -10
x3 = x2 * 0.99

print("-- Inputs --")
for label, value in (("x1 :", x1), ("x2 :", x2)):
    print(label, value)
print("x3 :", x3, "\n")

# Outputs: all three arrays are handed to the layer as one list
y = concat_3([x1, x2, x3])
print("-- Outputs --")
print("y :", y)

-- Properties --
name : Concatenate
expected inputs : 3
promised outputs : 1 

-- Inputs --
x1 : [-10 -20 -30]
x2 : [1. 2. 3.]
x3 : [0.99 1.98 2.97] 

-- Outputs --
y : [-10.   -20.   -30.     1.     2.     3.     0.99   1.98   2.97]


In [11]:
# help(tl.Concatenate)

### Layers can have Weights
- Some layer types include mutable weights and biases that are used in computation and training
- These types of layers require initialization
- For example, the `LayerNorm` layer calculates normalized data, which is then scaled by weights and biases
- During initialization, pass the data shape and data type of the inputs to initialize compatible arrays of W and b

In [13]:
#help(tl.LayerNorm)
# Print the docstring of shapes.signature, which returns the ShapeDtype
# description of an object that has `shape` and `dtype` attributes.
help(shapes.signature)

Help on function signature in module trax.shapes:

signature(obj)
    Returns a `ShapeDtype` signature for the given `obj`.
    
    A signature is either a `ShapeDtype` instance or a tuple of `ShapeDtype`
    instances. Note that this function is permissive with respect to its inputs
    (accepts lists or tuples or dicts, and underlying objects can be any type
    as long as they have shape and dtype attributes) and returns the corresponding
    nested structure of `ShapeDtype`.
    
    Args:
      obj: An object that has `shape` and `dtype` attributes, or a list/tuple/dict
          of such objects.
    
    Returns:
      A corresponding nested structure of `ShapeDtype` instances.



In [14]:
# Input data (float array) and the layer that will normalize it
x = np.array([0, 1, 2, 3], dtype="float")
norm = tl.LayerNorm()

# Weights and biases are created from the input's signature, so convert the
# array into a trax ShapeDtype (shape + dtype) rather than passing a plain tuple.
input_signature = shapes.signature(x)
norm.init(input_signature)

# Contrast the plain numpy shape tuple with the trax signature object
print(f'Normal Shape: {x.shape}, Data Type: {type(x.shape)}')
print(f'Shapes Trax: {shapes.signature(x)}, Data type: {type(shapes.signature(x))}')

# Inspect properties
print(
    "-- Properties --",
    f'name: {norm.name}',
    f'Expected Inputs: {norm.n_in}',
    f'Promised Outputs: {norm.n_out}',
    sep="\n",
)

# After init, norm.weights holds the pair (weights, biases)
layer_weights, layer_biases = norm.weights
print("weights :", layer_weights)
print("biases :", layer_biases, "\n")

# Inputs
print("-- Inputs --")
print("x :", x)

# Outputs
y = norm(x)
print("-- Outputs --")
print("y :", y)

Normal Shape: (4,), Data Type: <class 'tuple'>
Shapes Trax: ShapeDtype{shape:(4,), dtype:float64}, Data type: <class 'trax.shapes.ShapeDtype'>
-- Properties --
name: LayerNorm
Expected Inputs: 1
Promised Outputs: 1
weights : [1. 1. 1. 1.]
biases : [0. 0. 0. 0.] 

-- Inputs --
x : [0. 1. 2. 3.]
-- Outputs --
y : [-1.3416404  -0.44721344  0.44721344  1.3416404 ]




In [16]:
# Euclidean (L2) length of the vector x -- a single scalar, which is a different
# quantity from the element-wise normalized array that tl.LayerNorm returned.
np.linalg.norm(x)

3.7416573867739413

### Custom Layers
- You can also create custom layers and define custom functions for their computations by using `tl.Fn`

In [17]:
# Print the docstring of tl.Fn, the helper that wraps a plain function as a layer.
help(tl.Fn)

Help on function Fn in module trax.layers.base:

Fn(name, f, n_out=1)
    Returns a layer with no weights that applies the function `f`.
    
    `f` can take and return any number of arguments, and takes only positional
    arguments -- no default or keyword arguments. It often uses JAX-numpy (`jnp`).
    The following, for example, would create a layer that takes two inputs and
    returns two outputs -- element-wise sums and maxima:
    
        `Fn('SumAndMax', lambda x0, x1: (x0 + x1, jnp.maximum(x0, x1)), n_out=2)`
    
    The layer's number of inputs (`n_in`) is automatically set to number of
    positional arguments in `f`, but you must explicitly set the number of
    outputs (`n_out`) whenever it's not the default value 1.
    
    Args:
      name: Class-like name for the resulting layer; for use in debugging.
      f: Pure function from input tensors to output tensors, where each input
          tensor is a separate positional arg, e.g., `f(x0, x1) --> x0 + x1`.
          

In [18]:
# Define Custom Layer
# In this example you will create a layer to calculate the input times 2

def TimesTwo():
    """Build a custom layer named "TimesTwo" that multiplies its input by 2.

    Returns:
        The layer produced by `tl.Fn`, which wraps the doubling function; it
        takes one input because the wrapped function has one positional
        argument.
    """
    return tl.Fn("TimesTwo", lambda x: x * 2)

# Instantiate the custom layer and prepare a small integer input
times_two = TimesTwo()
x = np.array([1, 2, 3])

# Inspect properties
print("-- Properties --")
for label, value in (
    ("name :", times_two.name),
    ("expected inputs :", times_two.n_in),
):
    print(label, value)
print("promised outputs :", times_two.n_out, "\n")

# Inputs
print("-- Inputs --")
print("x :", x, "\n")

# Outputs: the layer is applied like an ordinary function
y = times_two(x)
print("-- Outputs --")
print("y :", y)

-- Properties --
name : TimesTwo
expected inputs : 1
promised outputs : 1 

-- Inputs --
x : [1 2 3] 

-- Outputs --
y : [2 4 6]


### Combinators
- Combine layers to build more complex layers
- Trax provides a set of objects named combinator layers to make this happen
- Combinators are themselves layers, so behavior composes

**Serial Combinators**.  
- A simple neural network can be built by combining layers into a single layer using the `Serial` combinator
- This new layer then acts just like a single layer
- One can inspect inputs, outputs and weights
- Or combine it into another layer
- Combinators can then be used as trainable models

In [21]:
# help(tl.Serial)
# help(tl.Parallel)

In [22]:
# Serial combinator: chain LayerNorm -> Relu -> TimesTwo into one layer
pipeline_stages = [tl.LayerNorm(), tl.Relu(), times_two]
serial = tl.Serial(*pipeline_stages)

# Initialization from the input's signature.
# NOTE(review): x is an integer array, so the LayerNorm weights are created as
# int32 (see the printed weights). Consider a float input if this model is
# ever meant to be trained.
x = np.array([-2, -1, 0, 1, 2])
serial.init(shapes.signature(x))

print("-- Serial Model --")
print(serial,"\n")

# The combined layer exposes the same properties as a single layer
print("-- Properties --")
for label, value in (
    ("name :", serial.name),
    ("sublayers :", serial.sublayers),
    ("expected inputs :", serial.n_in),
    ("promised outputs :", serial.n_out),
):
    print(label, value)
print("weights & biases:", serial.weights, "\n")

# Inputs
print("-- Inputs --")
print("x :", x, "\n")

# Outputs
y = serial(x)
print("-- Outputs --")
print("y :", y)



-- Serial Model --
Serial[
  LayerNorm
  Relu
  TimesTwo
] 

-- Properties --
name : Serial
sublayers : [LayerNorm, Relu, TimesTwo]
expected inputs : 1
promised outputs : 1
weights & biases: [(DeviceArray([1, 1, 1, 1, 1], dtype=int32), DeviceArray([0, 0, 0, 0, 0], dtype=int32)), (), ()] 

-- Inputs --
x : [-2 -1  0  1  2] 

-- Outputs --
y : [0.        0.        0.        1.4142132 2.8284264]


## JAX


In [23]:
# Build the same three values with plain numpy and with trax's fastmath numpy,
# then compare the resulting array types.
values = [1, 2, 3]

x_numpy = np.array(values)
print(f"Good old numpy: {type(x_numpy)}")

# Fastmath exposes a numpy-like module
fast_np = fastmath.numpy
x_jax = fast_np.array(values)
print(f'JAX Trax Numpy: {type(x_jax)}')

Good old numpy: <class 'numpy.ndarray'>
JAX Trax Numpy: <class 'jax.interpreters.xla.DeviceArray'>
