In [1]:
import sys
import jax.numpy as jnp
import numpy as np
from trax import layers as tl
from trax.shapes import signature
from trax.layers import combinators as cb
from trax.layers.assert_shape import assert_shape

sys.path.insert(0, "..")
from src.models.summary import summary


In the trax lesson, we implemented a Hadamard layer

$$ Hadamard(x_1, x_2) = x_1 \otimes x_2 $$

And a GLU model

$$ GLU(X) = \sigma(W_1X + b_1) \otimes (W_2X + b_2) $$

Where $\sigma$ is the softmax function.

In [2]:
def Hadamard():
    """Build a layer that multiplies its two inputs element-wise.

    Returns:
        A `tl.Fn` layer named "Hadamard" wrapping a two-argument function
        that returns `jnp.multiply` of its arguments as a single output.
    """

    def elementwise_product(left, right):
        return jnp.multiply(left, right)

    hadamard_layer = tl.Fn("Hadamard", elementwise_product, n_out=1)
    return hadamard_layer


@assert_shape("bd->bd")
def GLU(units: int):
    """Build a gated linear unit: softmax(Dense(x)) multiplied by Dense(x).

    The input is fed to two branches, each with its own `tl.Dense(units)`
    layer. The gate branch additionally applies a softmax over the last
    axis. The two branch outputs are then combined with `Hadamard()`.

    Args:
        units: Output width of both Dense layers.

    Returns:
        A trax layer that maps one input to one output.

    Note:
        The `"bd->bd"` shape spec declares the same two dimensions for
        input and output, so `units` is expected to equal the size of the
        input's last axis.
    """
    gate_branch = cb.Serial(tl.Dense(units), tl.Softmax(axis=-1))
    value_branch = tl.Dense(units)
    both_branches = cb.Branch(gate_branch, value_branch)
    return cb.Serial(both_branches, Hadamard())


We will use this to implement a Gated Residual Network (GRN). First, we make a model that chains a single Dense layer and an Elu activation with the GLU:

$$f_1(X) = Elu(W\cdot X + b) $$
$$f_2(X) = GLU(f_1(X))$$

Or, written as a chain:

$$X \rightarrow Dense \rightarrow Elu \rightarrow GLU$$

Or, visually:

<img src="../figures/f2.png">

Implement $f_2$ as a trax model.

In [22]:
# TODO: implement f_2 as a trax model (Dense -> Elu -> GLU); about 8 lines of code

Now, we want to make a parallel model.
One branch goes through just a Linear model:

$$f_3(X) = W \cdot X + b$$

The other branch goes through the $f_2$ chain:
$$f_2(X) = GLU(f_1(X))$$

These two outputs need to be added, and normalized with `tl.LayerNorm()`

$$ GRN(X) = LayerNorm(f_3(X) + f_2(X)) $$

Or, if you prefer visual:

<img src="../figures/grn.png" >


In [44]:
def GRN(units: int):
    """Build the gated residual model GRN(X) = LayerNorm(f_3(X) + f_2(X)).

    The input is fed to two parallel branches:

    * f_3: a single `tl.Dense(units)` layer (the linear skip path).
    * f_2: `tl.Dense(units)` -> `tl.Elu()` -> `GLU(units)` (the gated path).

    The two branch outputs are summed and passed through `tl.LayerNorm()`.

    Args:
        units: Output width of every Dense layer, and therefore of the
            model's output.

    Returns:
        A trax layer that maps one input to one output whose last axis has
        size `units`.
    """
    # Gated path f_2. The leading Dense already projects to `units`, so the
    # GLU sees the same width on its input and output.
    gated_branch = cb.Serial(tl.Dense(units), tl.Elu(), GLU(units))
    # Linear path f_3, projecting the input to the same width so it can be added.
    linear_branch = tl.Dense(units)

    return cb.Serial(
        cb.Branch(linear_branch, gated_branch),
        tl.Add(),
        tl.LayerNorm(),
    )


To test the model:

In [45]:
# Seeded generator so the smoke-test input is identical on every run.
rng = np.random.default_rng(seed=42)
X = rng.random((32, 20))
grn = GRN(128)
grn.init_weights_and_state(signature(X))


In [46]:
yhat = grn(X)
# One 128-wide output row per input row; fail loudly if the model is miswired.
assert yhat.shape == (X.shape[0], 128), yhat.shape
signature(yhat)


ShapeDtype{shape:(32, 128), dtype:float32}