# CS224N: JAX Tutorial (Spring '24)

## Introduction
Let's start by importing JAX:

In [87]:
import warnings

import jax
import jax.numpy as jnp
from jax import lax

# Import pprint, module we use for making our print statements prettier
import pprint
pp = pprint.PrettyPrinter()

We are all set to start our tutorial. Let's dive in!

## Part 1: Tensors

Each tensor is a multi-dimensional matrix; for example, a 256x256 square image might be represented by a `3x256x256` tensor, where the first dimension represents color. Here's how to create a tensor:


In [88]:
list_of_lists = [
  [1, 2, 3],
  [4, 5, 6],
]
print(list_of_lists)

[[1, 2, 3], [4, 5, 6]]


In [89]:
data = jnp.array(list_of_lists)
print(data)

[[1 2 3]
 [4 5 6]]


In [90]:
# Initializing a tensor
data = jnp.array([
    [0, 1],
    [2, 3],
    [4, 5]
])
print(data)

[[0 1]
 [2 3]
 [4 5]]


Each tensor has a **data type**: the major data types you'll need to worry about are floats (`float32`) and integers (`int`). You can specify the data type explicitly when you create the tensor:

In [91]:
# Initializing a tensor with an explicit data type
# Notice the dots after the numbers, which specify that they're floats
data = jnp.array([
    [0, 1],
    [2, 3],
    [4, 5]
], dtype="float32")
print(data)

[[0. 1.]
 [2. 3.]
 [4. 5.]]


In [92]:
# Initializing a tensor with an explicit data type
# Notice the dots after the numbers, which specify that they're floats
data = jnp.array([
    [0.11111111, 1],
    [2, 3],
    [4, 5]
], dtype="float32")
print(data)

[[0.11111111 1.        ]
 [2.         3.        ]
 [4.         5.        ]]


In [93]:
# Initializing a tensor with an implicit data type
# Notice the dots after the numbers, which specify that they're floats
data = jnp.array([
    [0.11111111, 1],
    [2, 3],
    [4, 5]
])
print(data)

[[0.11111111 1.        ]
 [2.         3.        ]
 [4.         5.        ]]


Utility functions also exist to create tensors with given shapes and contents:

In [94]:
zeros = jnp.zeros((2, 5))  # a tensor of all zeros
print(zeros)

[[0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]]


In [95]:
ones = jnp.ones((3, 4))   # a tensor of all ones
print(ones)

[[1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]]


In [96]:
rr = jnp.arange(1, 10) # range from [1, 10)
print(rr)

[1 2 3 4 5 6 7 8 9]


In [97]:
rr + 2

Array([ 3,  4,  5,  6,  7,  8,  9, 10, 11], dtype=int32)

In [98]:
rr * 2

Array([ 2,  4,  6,  8, 10, 12, 14, 16, 18], dtype=int32)

In [99]:
a = jnp.array([[1, 2], [2, 3], [4, 5]], dtype=jnp.float32)      # (3, 2)
b = jnp.array([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=jnp.float32)  # (2, 4)

print("A is", a)
print("B is", b)
print("The product is", a.dot(b)) #(3, 4)
print("The other product is", a @ b) # +, -, *, @

A is [[1. 2.]
 [2. 3.]
 [4. 5.]]
B is [[1. 2. 3. 4.]
 [5. 6. 7. 8.]]
The product is [[11. 14. 17. 20.]
 [17. 22. 27. 32.]
 [29. 38. 47. 56.]]
The other product is [[11. 14. 17. 20.]
 [17. 22. 27. 32.]
 [29. 38. 47. 56.]]


The **shape** of a matrix (which can be accessed by `.shape`) is defined as the dimensions of the matrix. Here's some examples:

In [100]:
matr_2d = jnp.array([[1, 2, 3], [4, 5, 6]])
print(matr_2d.shape)
print(matr_2d)

(2, 3)
[[1 2 3]
 [4 5 6]]


In [101]:
matr_3d = jnp.array([[[1, 2, 3, 4], [-2, 5, 6, 9]], [[5, 6, 7, 2], [8, 9, 10, 4]], [[-3, 2, 2, 1], [4, 6, 5, 9]]])
print(matr_3d)
print(matr_3d.shape)

[[[ 1  2  3  4]
  [-2  5  6  9]]

 [[ 5  6  7  2]
  [ 8  9 10  4]]

 [[-3  2  2  1]
  [ 4  6  5  9]]]
(3, 2, 4)


**Reshaping** tensors can be used to make batch operations easier (more on that later), but be careful that the data is reshaped in the order you expect:

In [102]:
rr = jnp.arange(1, 16)
print("The shape is currently", rr.shape)
print("The contents are currently", rr)
print()
rr = rr.reshape(5, 3)
print("After reshaping, the shape is currently", rr.shape)
print("The contents are currently", rr)

The shape is currently (15,)
The contents are currently [ 1  2  3  4  5  6  7  8  9 10 11 12 13 14 15]

After reshaping, the shape is currently (5, 3)
The contents are currently [[ 1  2  3]
 [ 4  5  6]
 [ 7  8  9]
 [10 11 12]
 [13 14 15]]


Finally, you can also inter-convert tensors with **NumPy arrays**:

In [103]:
import numpy as np

# numpy.ndarray --> torch.Tensor:
arr = np.array([[1, 0, 5]])
data = jnp.array(arr)
print(f"This is a {type(data)}", data)

# torch.Tensor --> numpy.ndarray:
new_arr = np.array(data)
print(f"This is a {type(new_arr)}", new_arr)

This is a <class 'jaxlib.xla_extension.ArrayImpl'> [[1 0 5]]
This is a <class 'numpy.ndarray'> [[1 0 5]]


One of the reasons why we use **tensors** is *vectorized operations*: operations that be conducted in parallel over a particular dimension of a tensor.

In [104]:
data = jnp.arange(1, 36, dtype=jnp.float32).reshape(5, 7)
print("Data is:", data)

# We can perform operations like *sum* over each row...
print("Taking the sum over rows:")
print(data.sum(axis=1)) #(5,)

# or over each column.
print("Taking thep sum over columns:")
print(data.sum(axis=0)) #(7,)

# Other operations are available:
print("Taking the stdev over rows:")
print(data.std(axis=1))


Data is: [[ 1.  2.  3.  4.  5.  6.  7.]
 [ 8.  9. 10. 11. 12. 13. 14.]
 [15. 16. 17. 18. 19. 20. 21.]
 [22. 23. 24. 25. 26. 27. 28.]
 [29. 30. 31. 32. 33. 34. 35.]]
Taking the sum over rows:
[ 28.  77. 126. 175. 224.]
Taking thep sum over columns:
[ 75.  80.  85.  90.  95. 100. 105.]
Taking the stdev over rows:
[2. 2. 2. 2. 2.]


In [105]:
data = jnp.arange(1, 7, dtype=jnp.float32).reshape(1, 2, 3)
print(data)
print(data.sum(axis=0).sum(axis=0))
print(data.sum(axis=0).sum(axis=0).shape)

[[[1. 2. 3.]
  [4. 5. 6.]]]
[5. 7. 9.]
(3,)


In [106]:
data.sum()

Array(21., dtype=float32)

### Quiz

Write code that creates a `jno.array` with the following contents:
$\begin{bmatrix} 1 & 2.2 & 9.6 \\ 4 & -7.2 & 6.3 \end{bmatrix}$

Now compute the average of each row (`.mean()`) and each column.

What's the shape of the results?



In [107]:
m = jnp.array([[1, 2.2, 9.6], [4, -7.2, 6.3]])
# mean of each row
print(m.mean(axis=1))
# mean of each column
print(m.mean(axis=0))

# The shapes are (2) and (3) respectively

[4.266667  1.0333335]
[ 2.5       -2.5        7.9500003]


**Indexing**

You can access arbitrary elements of a tensor using the `[]` operator.

In [108]:
# Initialize an example tensor
x = jnp.array([
    [[1, 2], [3, 4]],
    [[5, 6], [7, 8]],
    [[9, 10], [11, 12]]
])
x

Array([[[ 1,  2],
        [ 3,  4]],

       [[ 5,  6],
        [ 7,  8]],

       [[ 9, 10],
        [11, 12]]], dtype=int32)

In [109]:
x.shape

(3, 2, 2)

In [110]:
# Access the 0th element, which is the first row
x[0] # Equivalent to x[0, :]

Array([[1, 2],
       [3, 4]], dtype=int32)

In [111]:
x[:, 0]

Array([[ 1,  2],
       [ 5,  6],
       [ 9, 10]], dtype=int32)

We can also index into multiple dimensions with `:`.

In [112]:
# Get the top left element of each element in our tensor
x[:, 0, 0]

Array([1, 5, 9], dtype=int32)

In [113]:
x[:, :, :]

Array([[[ 1,  2],
        [ 3,  4]],

       [[ 5,  6],
        [ 7,  8]],

       [[ 9, 10],
        [11, 12]]], dtype=int32)

We can also access arbitrary elements in each dimension.

In [114]:
# Let's access the 0th and 1st elements, each twice
# same as stacking x[0], x[0], x[1], x[1]
i = jnp.array([0, 0, 1, 1])
x[i]

Array([[[1, 2],
        [3, 4]],

       [[1, 2],
        [3, 4]],

       [[5, 6],
        [7, 8]],

       [[5, 6],
        [7, 8]]], dtype=int32)

In [115]:
# Let's access the 0th elements of the 1st and 2nd elements

i = jnp.array([1, 2])
j = jnp.array([0])
x[i, j]

Array([[ 5,  6],
       [ 9, 10]], dtype=int32)

We can get a `Python` scalar value from a tensor with `item()`.

In [116]:
x[0, 0, 0]

Array(1, dtype=int32)

In [117]:
x[0, 0, 0].item()

1

### Exercise:

Write code that creates a `jnp.array` with the following contents:
$\begin{bmatrix} 1 & 2.2 & 9.6 \\ 4 & -7.2 & 6.3 \end{bmatrix}$

How do you get the first column? The first row?



In [118]:
m2 = jnp.array([[1, 2.2, 9.6], [4, -7.2, 6.3]])
# get the first column
print(m2[:, 0])

# get the first row
print(m2[0, :])

[1. 4.]
[1.  2.2 9.6]


## Autograd

In [119]:
# Create an example tensor
x = jnp.array(2.)

# Calculating the gradient of y with respect to x
y = lambda x: x * x * 3 # 3x^2
y_grad = jax.grad(y)
y_grad(x)

Array(12., dtype=float32, weak_type=True)

## Neural Network Module

So far we have looked into the tensors, their properties and basic operations on tensors. These are especially useful to get familiar with if we are building the layers of our network from scratch. We will utilize these in Assignment 2, but moving forward, we will use predefined blocks in the `nnx` module of `flax`. We will then put together these blocks to create complex networks. Let's start by importing this module with an alias so that we don't have to type `flax` every time we use it.

In [120]:
from flax import nnx

### **Linear Layer**
We can use `nn.Linear(H_in, H_out)` to create a a linear layer. This will take a matrix of `(N, *, H_in)` dimensions and output a matrix of `(N, *, H_out)`. The `*` denotes that there could be arbitrary number of dimensions in between. The linear layer performs the operation `Ax+b`, where `A` and `b` are initialized randomly. If we don't want the linear layer to learn the bias parameters, we can initialize our layer with `bias=False`.

In [121]:
# Create the inputs
input = jnp.ones((2, 3, 4))
# N*H_in -> N*H_out

# Make a linear layers transforming N,*,H_in dimensinal inputs to N,*,H_out
# dimensional outputs
linear = nnx.Linear(4, 2, rngs=nnx.Rngs(0))
linear_output = linear(input)
linear_output

Array([[[-0.17026094, -0.6833216 ],
        [-0.17026094, -0.6833216 ],
        [-0.17026094, -0.6833216 ]],

       [[-0.17026094, -0.6833216 ],
        [-0.17026094, -0.6833216 ],
        [-0.17026094, -0.6833216 ]]], dtype=float32)

In [122]:
linear.kernel.value.shape, linear.bias.value.shape

((4, 2), (2,))

In [123]:
linear.kernel.value, linear.bias.value

(Array([[-0.31055146, -0.30089378],
        [ 0.44153705, -0.25188616],
        [-0.03567746, -0.9629547 ],
        [-0.26556906,  0.8324131 ]], dtype=float32),
 Array([0., 0.], dtype=float32))

### **Other Module Layers**
There are several other preconfigured layers in the `nnx` module. Some commonly used examples are `nnx.Conv`, `nnx.ConvTranspose`, `nn.BatchNorm`, among many others. We will learn more about these as we progress in the course. For now, the only important thing to remember is that we can treat each of these layers as plug and play components: we will be providing the required dimensions and `jax` will take care of setting them up.

### **Activation Function Layer**
We can also use the `nnx` module to apply activations functions to our tensors. Activation functions are used to add non-linearity to our network. Some examples of activations functions are `nnx.relu`, `nnx.sigmoid` and `nnx.leaky_relu`. Activation functions operate on each element seperately, so the shape of the tensors we get as an output are the same as the ones we pass in.

In [124]:
output = nnx.sigmoid(linear_output)
output

Array([[[0.45753732, 0.3355204 ],
        [0.45753732, 0.3355204 ],
        [0.45753732, 0.3355204 ]],

       [[0.45753732, 0.3355204 ],
        [0.45753732, 0.3355204 ],
        [0.45753732, 0.3355204 ]]], dtype=float32)

### **Putting the Layers Together**
So far we have seen that we can create layers and pass the output of one as the input of the next. Instead of creating intermediate tensors and passing them around, we can use `nn.Sequentual`, which does exactly that.

In [125]:
block = nnx.Sequential(
    nnx.Linear(4, 2, rngs=nnx.Rngs(0))
)

input = jnp.ones((2,3,4))
output = nnx.sigmoid(block(input))
output

Array([[[0.45753732, 0.3355204 ],
        [0.45753732, 0.3355204 ],
        [0.45753732, 0.3355204 ]],

       [[0.45753732, 0.3355204 ],
        [0.45753732, 0.3355204 ],
        [0.45753732, 0.3355204 ]]], dtype=float32)

### Custom Modules

Instead of using the predefined modules, we can also build our own by extending the `nnx.Module` class. For example, we can build a the `nnx.Linear` (which also extends `nnx.Module`) on our own using the tensor introduced earlier! We can also build new, more complex modules, such as a custom neural network. You will be practicing these in the later assignment.

To create a custom module, the first thing we have to do is to extend the `nnx.Module`. We can then initialize our parameters in the `__init__` function, starting with a call to the `__init__` function of the super class. All the class attributes we define which are `nn` module objects are treated as parameters, which can be learned during the training.

All classes extending `nnx.Module` are also expected to implement a `__call__(x)` function, where `x` is a tensor. This is the function that is called when a parameter is passed to our module, such as in `model(x)`.

In [126]:
class MultiLayerPerceptron(nnx.Module):

  def __init__(self, input_size: int, hidden_size: int, output_size: int, rngs: nnx.Rngs = nnx.Rngs(0)):
    # Call to the __init__ function of the super class
    super(MultiLayerPerceptron, self).__init__()

    # Bookkeeping: Saving the initialization parameters
    self.input_size = input_size
    self.hidden_size = hidden_size
    self.output_size = output_size

    self.linear = nnx.Linear(self.input_size, self.hidden_size, rngs=rngs)
    self.linear2 = nnx.Linear(self.hidden_size, self.output_size, rngs=rngs)

  def __call__(self, x):
    output = nnx.softmax(self.linear2(nnx.relu(self.linear(x))))
    return output

Now that we have defined our class, we can instantiate it and see what it does.

In [127]:
key = jax.random.key(seed=100)
input = jax.random.normal(shape=(2, 5), key=key)

model = MultiLayerPerceptron(5, 64, 3, rngs=nnx.Rngs(0))
model(input)


Array([[0.2449855 , 0.30524373, 0.44977078],
       [0.3758987 , 0.20768349, 0.41641778]], dtype=float32)

We can inspect the parameters of our model with `nnx.display(model)`

In [128]:
nnx.display(model)

MultiLayerPerceptron(
  input_size=5,
  hidden_size=64,
  output_size=3,
  linear=Linear(
    kernel=Param(
      value=Array(shape=(5, 64), dtype=float32)
    ),
    bias=Param(
      value=Array(shape=(64,), dtype=float32)
    ),
    in_features=5,
    out_features=64,
    use_bias=True,
    dtype=None,
    param_dtype=<class 'jax.numpy.float32'>,
    precision=None,
    kernel_init=<function variance_scaling.<locals>.init at 0x11eca6ca0>,
    bias_init=<function zeros at 0x111902ac0>,
    dot_general=<function dot_general at 0x111196340>
  ),
  linear2=Linear(
    kernel=Param(
      value=Array(shape=(64, 3), dtype=float32)
    ),
    bias=Param(
      value=Array(shape=(3,), dtype=float32)
    ),
    in_features=64,
    out_features=3,
    use_bias=True,
    dtype=None,
    param_dtype=<class 'jax.numpy.float32'>,
    precision=None,
    kernel_init=<function variance_scaling.<locals>.init at 0x11eca6ca0>,
    bias_init=<function zeros at 0x111902ac0>,
    dot_general=<function do

## Optimization
Having the gradients isn't enought for our models to learn. We also need to know how to update the parameters of our models. This is where the optimizers comes in. `optax` module contains several optimizers that we can use. Some popular examples are `optax.sgd` and `optax.adam`. Optimizers has a learning rate (`learning_rate`) parameter, which determines how big of an update will be made in every step. Different optimizers have different hyperparameters as well.

In [129]:
import optax

After we have our optimization function, we can define a `loss` that we want to optimize for. We can either define the loss ourselves, or use one of the predefined loss function in such as `optax.losses.squared_error`. Let's put everything together now! We will start by creating some dummy data.

Note: In JAX, `loss` function, when used with `nnx.value_and_grad(loss)(model, y)` needs to accept `model` as first argument, since `value_and_grad` calculates gradient w.r.t. first positional argument.

In [130]:
y_ohe = jax.nn.one_hot(jnp.ones((10, 5)), 5)
y_ohe

Array([[[0., 1., 0., 0., 0.],
        [0., 1., 0., 0., 0.],
        [0., 1., 0., 0., 0.],
        [0., 1., 0., 0., 0.],
        [0., 1., 0., 0., 0.]],

       [[0., 1., 0., 0., 0.],
        [0., 1., 0., 0., 0.],
        [0., 1., 0., 0., 0.],
        [0., 1., 0., 0., 0.],
        [0., 1., 0., 0., 0.]],

       [[0., 1., 0., 0., 0.],
        [0., 1., 0., 0., 0.],
        [0., 1., 0., 0., 0.],
        [0., 1., 0., 0., 0.],
        [0., 1., 0., 0., 0.]],

       [[0., 1., 0., 0., 0.],
        [0., 1., 0., 0., 0.],
        [0., 1., 0., 0., 0.],
        [0., 1., 0., 0., 0.],
        [0., 1., 0., 0., 0.]],

       [[0., 1., 0., 0., 0.],
        [0., 1., 0., 0., 0.],
        [0., 1., 0., 0., 0.],
        [0., 1., 0., 0., 0.],
        [0., 1., 0., 0., 0.]],

       [[0., 1., 0., 0., 0.],
        [0., 1., 0., 0., 0.],
        [0., 1., 0., 0., 0.],
        [0., 1., 0., 0., 0.],
        [0., 1., 0., 0., 0.]],

       [[0., 1., 0., 0., 0.],
        [0., 1., 0., 0., 0.],
        [0., 1., 0., 0., 0.]

In [131]:
key = jax.random.key(seed=100)

# Create the y data
# y = jnp.ones((10, 5))
y = y_ohe

# Add some noise to our goal y to generate our x
# We want out model to predict our original data, albeit the noise
x = y + jax.random.normal(key=key, shape=y.shape)
x

Array([[[-1.2262332 ,  1.4810336 , -0.5197212 ,  1.5202414 ,
          0.45887414],
        [ 0.10311025,  1.4125631 , -0.39090064, -0.09483104,
         -0.74713457],
        [ 0.25137082,  1.4861655 , -0.76141524, -0.13572052,
          0.91742074],
        [ 0.08653314,  1.8053247 ,  1.2471547 ,  0.9763958 ,
          1.220573  ],
        [ 1.8219421 ,  0.20261335, -0.15477557,  0.21770734,
          0.6405267 ]],

       [[ 1.1118156 ,  1.0775955 ,  0.6424729 , -1.0053515 ,
          0.99456733],
        [-0.3395303 ,  1.3200217 ,  0.10800842, -0.28815672,
         -1.5289674 ],
        [ 1.3016586 ,  2.3308759 ,  0.14011298, -1.5961647 ,
         -0.45299408],
        [-1.8585554 , -0.55924857,  0.52025634,  0.864114  ,
         -0.93108404],
        [-0.44154626,  1.0209888 , -0.153729  ,  1.3608798 ,
         -0.42661062]],

       [[-0.78405386,  0.55537456, -1.4077625 ,  0.5685388 ,
         -0.79089546],
        [-1.0732454 ,  1.2097561 , -0.99139154,  0.74864846,
          0

Now, we can define our model, optimizer and the loss function.

In [132]:
# Instantiate the model
model = MultiLayerPerceptron(5, 64, 5)

# Define the optimizer
adam = optax.adam(learning_rate=1e-1)

# Calculate how our model is doing now
y_pred = model(x)
loss = optax.losses.squared_error(y_pred, y).mean().item()
loss

0.18577083945274353

In [133]:
model

MultiLayerPerceptron(
  input_size=5,
  hidden_size=64,
  output_size=5,
  linear=Linear(
    kernel=Param(
      value=Array(shape=(5, 64), dtype=float32)
    ),
    bias=Param(
      value=Array(shape=(64,), dtype=float32)
    ),
    in_features=5,
    out_features=64,
    use_bias=True,
    dtype=None,
    param_dtype=<class 'jax.numpy.float32'>,
    precision=None,
    kernel_init=<function variance_scaling.<locals>.init at 0x11eca6ca0>,
    bias_init=<function zeros at 0x111902ac0>,
    dot_general=<function dot_general at 0x111196340>
  ),
  linear2=Linear(
    kernel=Param(
      value=Array(shape=(64, 5), dtype=float32)
    ),
    bias=Param(
      value=Array(shape=(5,), dtype=float32)
    ),
    in_features=64,
    out_features=5,
    use_bias=True,
    dtype=None,
    param_dtype=<class 'jax.numpy.float32'>,
    precision=None,
    kernel_init=<function variance_scaling.<locals>.init at 0x11eca6ca0>,
    bias_init=<function zeros at 0x111902ac0>,
    dot_general=<function do

In [134]:
optimizer = nnx.Optimizer(model, optax.adam(1e-3))

def loss(model, y):
    y_pred = model(x)
    loss_ = optax.losses.squared_error(y_pred, y).mean()
    return loss_

# Set the number of epoch, which determines the number of training iterations
n_epoch = 100

for epoch in range(n_epoch):

  loss_, grads = nnx.value_and_grad(loss)(model, y)
  optimizer.update(grads)  # In place updates.

  # # Print stats
  print(f"Epoch {epoch}: traing loss: {loss_}")

Epoch 0: traing loss: 0.18577083945274353
Epoch 1: traing loss: 0.1820353865623474
Epoch 2: traing loss: 0.17819486558437347
Epoch 3: traing loss: 0.17424796521663666
Epoch 4: traing loss: 0.1701943427324295
Epoch 5: traing loss: 0.16603554785251617
Epoch 6: traing loss: 0.1617729216814041
Epoch 7: traing loss: 0.15740953385829926
Epoch 8: traing loss: 0.1529502123594284
Epoch 9: traing loss: 0.14840103685855865
Epoch 10: traing loss: 0.14376921951770782
Epoch 11: traing loss: 0.1390639990568161
Epoch 12: traing loss: 0.13429582118988037
Epoch 13: traing loss: 0.1294756978750229
Epoch 14: traing loss: 0.12461644411087036
Epoch 15: traing loss: 0.1197335496544838
Epoch 16: traing loss: 0.11484318971633911
Epoch 17: traing loss: 0.10996294766664505
Epoch 18: traing loss: 0.10511116683483124
Epoch 19: traing loss: 0.10030772536993027
Epoch 20: traing loss: 0.09557072818279266
Epoch 21: traing loss: 0.09092001616954803
Epoch 22: traing loss: 0.08637544512748718
Epoch 23: traing loss: 0.081

You can see that our loss is decreasing. Let's check the predictions of our model now and see if they are close to our original `y`, which was all `1s`.

In [135]:
# See how our model performs on the training data
y_pred = model(x)
y_pred

Array([[[2.67716479e-02, 9.47577178e-01, 1.36385802e-02, 5.20828832e-03,
         6.80439686e-03],
        [6.35656118e-02, 8.13051045e-01, 6.00242838e-02, 2.36921944e-02,
         3.96668203e-02],
        [4.55217399e-02, 8.83715987e-01, 3.61319482e-02, 1.23430826e-02,
         2.22871881e-02],
        [8.32984317e-03, 9.69358802e-01, 8.59795883e-03, 3.20940767e-03,
         1.05040418e-02],
        [2.86420155e-02, 9.07064795e-01, 2.70510837e-02, 1.87784657e-02,
         1.84635464e-02]],

       [[1.99175198e-02, 9.14099336e-01, 3.17909084e-02, 1.43124331e-02,
         1.98798031e-02],
        [3.56519148e-02, 8.78746033e-01, 3.81462388e-02, 1.44328000e-02,
         3.30230631e-02],
        [6.28624018e-03, 9.83483851e-01, 6.50156848e-03, 1.14651758e-03,
         2.58181687e-03],
        [8.98824483e-02, 8.34748805e-01, 2.82318518e-02, 1.36638843e-02,
         3.34729291e-02],
        [6.88102767e-02, 8.36020827e-01, 4.01259586e-02, 2.32554059e-02,
         3.17875594e-02]],

      

In [136]:
# Create test data and check how our model performs on it
x2 = y + jax.random.normal(key=jax.random.key(seed=101), shape=y.shape)
y_pred = model(x2)
y_pred

Array([[[1.60775296e-02, 9.59791422e-01, 1.24694565e-02, 4.27387562e-03,
         7.38764415e-03],
        [5.99917397e-03, 9.87714767e-01, 4.03423421e-03, 6.78302778e-04,
         1.57347543e-03],
        [4.86144796e-02, 8.47507834e-01, 5.12109138e-02, 1.03687225e-02,
         4.22981009e-02],
        [1.84902817e-01, 4.91020918e-01, 9.47416723e-02, 8.48343968e-02,
         1.44500211e-01],
        [1.28983073e-02, 9.60945904e-01, 1.14672389e-02, 5.58851613e-03,
         9.10003111e-03]],

       [[1.52008524e-02, 9.65144753e-01, 1.07962517e-02, 2.49629864e-03,
         6.36180490e-03],
        [1.12328585e-02, 9.53729510e-01, 1.86248086e-02, 8.69716611e-03,
         7.71567831e-03],
        [1.09880820e-01, 7.75187790e-01, 3.24392989e-02, 2.11186018e-02,
         6.13734350e-02],
        [3.37357223e-02, 8.14603388e-01, 7.68600702e-02, 2.95931660e-02,
         4.52076495e-02],
        [1.57899912e-02, 9.38233197e-01, 1.79915857e-02, 5.94395120e-03,
         2.20412761e-02]],

      

Great! Looks like our model almost perfectly learned to filter out the noise from the `x` that we passed in!

## Demo: Word Window Classification

Until this part of the notebook, we have learned the fundamentals of JAX and built a basic network solving a toy task. Now we will attempt to solve an example NLP task. Here are the things we will learn:

1. Data: Creating a Dataset of Batched Tensors
2. Modeling
3. Training
4. Prediction

In this section, our goal will be to train a model that will find the words in a sentence corresponding to a `LOCATION`, which will be always of span `1` (meaning that `San Fransisco` won't be recognized as a `LOCATION`). Our task is called `Word Window Classification` for a reason. Instead of letting our model to only take a look at one word in each forward pass, we would like it to be able to consider the context of the word in question. That is, for each word, we want our model to be aware of the surrounding words. Let's dive in!

### Data

The very first task of any machine learning project is to set up our training set. Usually, there will be a training corpus we will be utilizing. In NLP tasks, the corpus would generally be a `.txt` or `.csv` file where each row corresponds to a sentence or a tabular datapoint. In our toy task, we will assume that we have already read our data and the corresponding labels into a `Python` list.

In [137]:
# Our raw data, which consists of sentences
corpus = [
    "We always come to Paris",
    "The professor is from Australia",
    "I live in Stanford",
    "He comes from Taiwan",
    "The capital of Turkey is Ankara"
]

#### Preprocessing

To make it easier for our models to learn, we usually apply a few preprocessing steps to our data. This is especially important when dealing with text data. Here are some examples of text preprocessing:
* **Tokenization**: Tokenizing the sentences into words.
* **Lowercasing**: Changing all the letters to be lowercase.
* **Noise removal:** Removing special characters (such as punctuations).
* **Stop words removal**: Removing commonly used words.

Which preprocessing steps are necessary is determined by the task at hand. For example, although it is useful to remove special characters in some tasks, for others they may be important (for example, if we are dealing with multiple languages). For our task, we will lowercase our words and tokenize.


In [138]:
# The preprocessing function we will use to generate our training examples
# Our function is a simple one, we lowercase the letters
# and then tokenize the words.
def preprocess_sentence(sentence):
  return sentence.lower().split()

# Create our training set
train_sentences = [preprocess_sentence(sent) for sent in corpus]
train_sentences

[['we', 'always', 'come', 'to', 'paris'],
 ['the', 'professor', 'is', 'from', 'australia'],
 ['i', 'live', 'in', 'stanford'],
 ['he', 'comes', 'from', 'taiwan'],
 ['the', 'capital', 'of', 'turkey', 'is', 'ankara']]

For each training example we have, we should also have a corresponding label. Recall that the goal of our model was to determine which words correspond to a `LOCATION`. That is, we want our model to output `0` for all the words that are not `LOCATION`s and `1` for the ones that are `LOCATION`s.

In [139]:
# Set of locations that appear in our corpus
locations = set(["australia", "ankara", "paris", "stanford", "taiwan", "turkey"])

# Our train labels
train_labels = [[1 if word in locations else 0 for word in sent] for sent in train_sentences]
train_labels

[[0, 0, 0, 0, 1],
 [0, 0, 0, 0, 1],
 [0, 0, 0, 1],
 [0, 0, 0, 1],
 [0, 0, 0, 1, 0, 1]]

#### Converting Words to Embeddings

Let's look at our training data a little more closely. Each datapoint we have is a sequence of words. On the other hand, we know that machine learning models work with numbers in vectors. How are we going to turn words into numbers? You may be thinking embeddings and you are right!

Imagine that we have an embedding lookup table `E`, where each row corresponds to an embedding. That is, each word in our vocabulary would have a corresponding embedding row `i` in this table. Whenever we want to find an embedding for a word, we will follow these steps:
1. Find the corresponding index `i` of the word in the embedding table: `word->index`.
2. Index into the embedding table and get the embedding: `index->embedding`.

Let's look at the first step. We should assign all the words in our vocabulary to a corresponding index. We can do it as follows:
1. Find all the unique words in our corpus.
2. Assign an index to each.

In [140]:
# Find all the unique words in our corpus
vocabulary = set(w for s in train_sentences for w in s)
vocabulary

{'always',
 'ankara',
 'australia',
 'capital',
 'come',
 'comes',
 'from',
 'he',
 'i',
 'in',
 'is',
 'live',
 'of',
 'paris',
 'professor',
 'stanford',
 'taiwan',
 'the',
 'to',
 'turkey',
 'we'}

`vocabulary` now contains all the words in our corpus. On the other hand, during the test time, we can see words that are not contained in our vocabulary. If we can figure out a way to represent the unknown words, our model can still reason about whether they are a `LOCATION` or not, since we are also looking at the neighboring words for each prediction.

We introduce a special token, `<unk>`, to tackle the words that are out of vocabulary. We could pick another string for our unknown token if we wanted. The only requirement here is that our token should be unique: we should only be using this token for unknown words. We will also add this special token to our vocabulary.

In [141]:
# Add the unknown token to our vocabulary
vocabulary.add("<unk>")

Earlier we mentioned that our task was called `Word Window Classification` because our model is looking at the surroundings words in addition to the given word when it needs to make a prediction.

For example, let's take the sentence "We always come to Paris". The corresponding training label for this sentence is `0, 0, 0, 0, 1` since only Paris, the last word, is a `LOCATION`. In one pass (meaning a call to `forward()`), our model will try to generate the correct label for one word. Let's say our model is trying to generate the correct label `1` for `Paris`. If we only allow our model to see `Paris`, but nothing else, we will miss out on the important information that the word `to` often times appears with `LOCATION`s.

Word windows allow our model to consider the surrounding `+N` or `-N` words of each word when making a prediction. In our earlier example for `Paris`, if we have a window size of 1, that means our model will look at the words that come immediately before and after `Paris`, which are `to`, and, well, nothing. Now, this raises another issue. `Paris` is at the end of our sentence, so there isn't another word following it. Remember that we define the input dimensions of our models when we are initializing them. If we set the window size to be `1`, it means that our model will be accepting `3` words in every pass. We cannot have our model expect `2` words from time to time.

The solution is to introduce a special token, such as `<pad>`, that will be added to our sentences to make sure that every word has a valid window around them. Similar to `<unk>` token, we could pick another string for our pad token if we wanted, as long as we make sure it is used for a unique purpose.

In [142]:
# Add the <pad> token to our vocabulary
vocabulary.add("<pad>")

# Function that pads the given sentence
# We are introducing this function here as an example
# We will be utilizing it later in the tutorial
def pad_window(sentence, window_size, pad_token="<pad>"):
  window = [pad_token] * window_size
  return window + sentence + window

# Show padding example
window_size = 2
pad_window(train_sentences[0], window_size=window_size)

['<pad>', '<pad>', 'we', 'always', 'come', 'to', 'paris', '<pad>', '<pad>']

Now that our vocabularly is ready, let's assign an index to each of our words.

In [143]:
# We are just converting our vocabularly to a list to be able to index into it
# Sorting is not necessary, we sort to show an ordered word_to_ind dictionary
# That being said, we will see that having the index for the padding token
# be 0 is convenient as some PyTorch functions use it as a default value
# such as nn.utils.rnn.pad_sequence, which we will cover in a bit
ix_to_word = sorted(list(vocabulary))

# Creating a dictionary to find the index of a given word
word_to_ix = {word: ind for ind, word in enumerate(ix_to_word)}
word_to_ix

{'<pad>': 0,
 '<unk>': 1,
 'always': 2,
 'ankara': 3,
 'australia': 4,
 'capital': 5,
 'come': 6,
 'comes': 7,
 'from': 8,
 'he': 9,
 'i': 10,
 'in': 11,
 'is': 12,
 'live': 13,
 'of': 14,
 'paris': 15,
 'professor': 16,
 'stanford': 17,
 'taiwan': 18,
 'the': 19,
 'to': 20,
 'turkey': 21,
 'we': 22}

In [144]:
ix_to_word[1]

'<unk>'

Great! We are ready to convert our training sentences into a sequence of indices corresponding to each token.

In [145]:
# Given a sentence of tokens, return the corresponding indices
def convert_token_to_indices(sentence, word_to_ix):
  indices = []
  for token in sentence:
    # Check if the token is in our vocabularly. If it is, get it's index.
    # If not, get the index for the unknown token.
    if token in word_to_ix:
      index = word_to_ix[token]
    else:
      index = word_to_ix["<unk>"]
    indices.append(index)
  return indices

# More compact version of the same function
def _convert_token_to_indices(sentence, word_to_ix):
  return [word_to_ind.get(token, word_to_ix["<unk>"]) for token in sentence]

# Show an example
example_sentence = ["we", "always", "come", "to", "kuwait"]
example_indices = convert_token_to_indices(example_sentence, word_to_ix)
restored_example = [ix_to_word[ind] for ind in example_indices]

print(f"Original sentence is: {example_sentence}")
print(f"Going from words to indices: {example_indices}")
print(f"Going from indices to words: {restored_example}")

Original sentence is: ['we', 'always', 'come', 'to', 'kuwait']
Going from words to indices: [22, 2, 6, 20, 1]
Going from indices to words: ['we', 'always', 'come', 'to', '<unk>']


In the example above, `kuwait` shows up as `<unk>`, because it is not included in our vocabulary. Let's convert our `train_sentences` to `example_padded_indices`.

In [146]:
# Converting our sentences to indices
example_padded_indices = [convert_token_to_indices(s, word_to_ix) for s in train_sentences]
example_padded_indices

[[22, 2, 6, 20, 15],
 [19, 16, 12, 8, 4],
 [10, 13, 11, 17],
 [9, 7, 8, 18],
 [19, 5, 14, 21, 12, 3]]

Now that we have an index for each word in our vocabularly, we can create an embedding table with `nnx.Embed` class in `JAX (Flax)`. It is called as follows `nnx.Embed(num_words, embedding_dimension)` where `num_words` is the number of words in our vocabulary and the `embedding_dimension` is the dimension of the embeddings we want to have. There is nothing fancy about `nnx.Embed`: it is just a wrapper class around a trainabe `NxE` dimensional tensor, where `N` is the number of words in our vocabulary and `E` is the number of embedding dimensions. This table is initially random, but it will change over time. As we train our network, the gradients will be backpropagated all the way to the embedding layer, and hence our word embeddings would be updated. We will initiliaze the embedding layer we will use for our model in our model, but we are showing an example here.

In [147]:
# Creating an embedding table for our words
embedding_dim = 5
embeds = nnx.Embed(len(vocabulary), embedding_dim, rngs=nnx.Rngs(0))
nnx.state(embeds)

State({
  'embedding': VariableState(
    type=Param,
    value=Array([[ 3.51217270e-01,  5.26763022e-01, -4.20501053e-01,
            -7.62650669e-01, -7.22563803e-01],
           [-4.73633148e-02,  6.94514513e-01,  1.99710235e-01,
            -8.48488331e-01,  4.44375962e-01],
           [-7.83036351e-01, -3.28497887e-01,  3.33785653e-01,
             6.07807398e-01,  1.22393698e-01],
           [-3.56444776e-01,  2.08796978e-01,  6.49359897e-02,
             7.55196088e-04, -8.47356141e-01],
           [-1.13405414e-01,  1.18730918e-01, -5.13231337e-01,
             6.54561281e-01,  4.54957813e-01],
           [-9.56119150e-02, -3.83494079e-01,  1.86414085e-02,
             2.14622915e-01, -2.31866419e-01],
           [ 4.09342647e-01,  7.63373852e-01,  8.30148607e-02,
            -5.72090387e-01, -5.33070326e-01],
           [ 6.10590100e-01,  1.57741010e-01,  3.44011247e-01,
            -3.61440815e-02, -4.50284958e-01],
           [-3.84995013e-01, -3.05854857e-01, -5.37834108e-0

To get the word embedding for a word in our vocabulary, all we need to do is to create a lookup tensor. The lookup tensor is just a tensor containing the index we want to look up `nn.Embedding` class expects an index tensor that is of type Long Tensor, so we should create our tensor accordingly.

In [148]:
# Get the embedding for the word Paris
index = word_to_ix["paris"]
index_tensor = jnp.array(index, dtype=jnp.int32)
paris_embed = embeds(index_tensor)
paris_embed

Array([ 0.00570847,  0.13735327, -0.4158411 ,  0.43066612, -0.46840468],      dtype=float32)

In [149]:
# We can also get multiple embeddings at once
index_paris = word_to_ix["paris"]
index_ankara = word_to_ix["ankara"]
indices = [index_paris, index_ankara]
indices_tensor = jnp.array(indices, dtype=jnp.int32)
embeddings = embeds(indices_tensor)
embeddings

Array([[ 5.7084723e-03,  1.3735327e-01, -4.1584110e-01,  4.3066612e-01,
        -4.6840468e-01],
       [-3.5644478e-01,  2.0879698e-01,  6.4935990e-02,  7.5519609e-04,
        -8.4735614e-01]], dtype=float32)

Usually, we define the embedding layer as part of our model, which you will see in the later sections of our notebook.

#### Batching Sentences

We have learned about batches in class. Waiting our whole training corpus to be processed before making an update is costly. On the other hand, updating the parameters after every training example causes the loss to be less stable between updates. To combat these issues, we instead update our parameters after training on a batch of data. This allows us to get a better estimate of the gradient of the global loss. In this section, we will learn how to structure our data into batches using the `torch.util.data.DataLoader` class.

We will be calling the `DataLoader` class as follows: `DataLoader(data, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)`.  The `batch_size` parameter determines the number of examples per batch. In every epoch, we will be iterating over all the batches using the `DataLoader`. The order of batches is deterministic by default, but we can ask `DataLoader` to shuffle the batches by setting the `shuffle` parameter to `True`. This way we ensure that we don't encounter a bad batch multiple times.

If provided, `DataLoader` passes the batches it prepares to the `collate_fn`. We can write a custom function to pass to the `collate_fn` parameter in order to print stats about our batch or perform extra processing. In our case, we will use the `collate_fn` to:
1. Window pad our train sentences.
2. Convert the words in the training examples to indices.
3. Pad the training examples so that all the sentences and labels have the same length. Similarly, we also need to pad the labels. This creates an issue because when calculating the loss, we need to know the actual number of words in a given example. We will also keep track of this number in the function we pass to the `collate_fn` parameter.

Because our version of the `collate_fn` function will need to access to our `word_to_ix` dictionary (so that it can turn words into indices), we will make use of the `partial` function in `Python`, which passes the parameters we give to the function we pass it.

In [150]:
from functools import partial
# TODO: see how we can remove dependency on torch
import torch
from torch import nn

def custom_collate_fn(batch, window_size, word_to_ix):
  # Break our batch into the training examples (x) and labels (y)
  # We are turning our x and y into tensors because nn.utils.rnn.pad_sequence
  # method expects tensors. This is also useful since our model will be
  # expecting tensor inputs.
  x, y = zip(*batch)

  # Now we need to window pad our training examples. We have already defined a
  # function to handle window padding. We are including it here again so that
  # everything is in one place.
  def pad_window(sentence, window_size, pad_token="<pad>"):
    window = [pad_token] * window_size
    return window + sentence + window

  # Pad the train examples.
  x = [pad_window(s, window_size=window_size) for s in x]

  # Now we need to turn words in our training examples to indices. We are
  # copying the function defined earlier for the same reason as above.
  def convert_tokens_to_indices(sentence, word_to_ix):
    return [word_to_ix.get(token, word_to_ix["<unk>"]) for token in sentence]

  # Convert the train examples into indices.
  x = [convert_tokens_to_indices(s, word_to_ix) for s in x]

  # We will now pad the examples so that the lengths of all the example in
  # one batch are the same, making it possible to do matrix operations.
  # We set the batch_first parameter to True so that the returned matrix has
  # the batch as the first dimension.
  pad_token_ix = word_to_ix["<pad>"]

  # pad_sequence function expects the input to be a tensor, so we turn x into one
  x = [torch.LongTensor(x_i) for x_i in x]
  # TODO: see how we can efficiently pad_sequence with JAX
  x_padded = nn.utils.rnn.pad_sequence(x, batch_first=True, padding_value=pad_token_ix)

  # We will also pad the labels. Before doing so, we will record the number
  # of labels so that we know how many words existed in each example.
  lengths = [len(label) for label in y]
  lenghts = jnp.array(lengths, dtype=jnp.int32)

  y = [torch.LongTensor(y_i) for y_i in y]
  y_padded = nn.utils.rnn.pad_sequence(y, batch_first=True, padding_value=0)

  # We are now ready to return our variables. The order we return our variables
  # here will match the order we read them in our training loop.
  return jnp.array(x_padded), jnp.array(y_padded), lenghts

In [151]:
import tensorflow as tf

# Parameters to be passed to the DataLoader
data = list(zip(train_sentences, train_labels))
batch_size = 2
shuffle = True
window_size = 2
collate_fn = partial(custom_collate_fn, window_size=window_size, word_to_ix=word_to_ix)

# Instantiate the DataLoader
# loader = DataLoader(data, batch_size=batch_size, shuffle=shuffle, collate_fn=collate_fn)
dataset = tf.data.Dataset.from_tensor_slices(collate_fn(data)).batch(batch_size=batch_size)

# Go through one loop
counter = 0
for batched_x, batched_y, batched_lengths in dataset:
  print(f"Iteration {counter}")
  print("Batched Input:")
  print(batched_x)
  print("Batched Labels:")
  print(batched_y)
  print("Batched Lengths:")
  print(batched_lengths)
  print("")
  counter += 1

Iteration 0
Batched Input:
tf.Tensor(
[[ 0  0 22  2  6 20 15  0  0  0]
 [ 0  0 19 16 12  8  4  0  0  0]], shape=(2, 10), dtype=int32)
Batched Labels:
tf.Tensor(
[[0 0 0 0 1 0]
 [0 0 0 0 1 0]], shape=(2, 6), dtype=int32)
Batched Lengths:
tf.Tensor([5 5], shape=(2,), dtype=int32)

Iteration 1
Batched Input:
tf.Tensor(
[[ 0  0 10 13 11 17  0  0  0  0]
 [ 0  0  9  7  8 18  0  0  0  0]], shape=(2, 10), dtype=int32)
Batched Labels:
tf.Tensor(
[[0 0 0 1 0 0]
 [0 0 0 1 0 0]], shape=(2, 6), dtype=int32)
Batched Lengths:
tf.Tensor([4 4], shape=(2,), dtype=int32)

Iteration 2
Batched Input:
tf.Tensor([[ 0  0 19  5 14 21 12  3  0  0]], shape=(1, 10), dtype=int32)
Batched Labels:
tf.Tensor([[0 0 0 1 0 1]], shape=(1, 6), dtype=int32)
Batched Lengths:
tf.Tensor([6], shape=(1,), dtype=int32)



The batched input tensors you see above will be passed into our model. On the other hand, we started off saying that our model will be a window classifier. The way our input tensors are currently formatted, we have all the words in a sentence in one datapoint. When we pass this input to our model, it needs to create the windows for each word, make a prediction as to whether the center word is a `LOCATION` or not for each window, put the predictions together and return.

We could avoid this problem if we formatted our data by breaking it into windows beforehand. In this example, we will instead how our model take care of the formatting.

Given that our `window_size` is `N` we want our model to make a prediction on every `2N+1` tokens. That is, if we have an input with `9` tokens, and a `window_size` of `2`, we want our model to return `5` predictions. This makes sense because before we padded it with `2` tokens on each side, our input also had `5` tokens in it!

We can create these windows by using for loops, but there is a faster `PyTorch` alternative, which is the `unfold(dimension, size, step)` method. We can create the windows we need using this method as follows:

In [None]:
def sliding_window_batched(inputs: jnp.array, window_size: int):
    batch_size, seq_length = inputs.shape
    num_windows = seq_length - window_size + 1
    
    # This function gets a single window for all sequences in the batch
    def get_window(start_idx: int):
        # dynamic_slice takes (operand, start_indices, slice_sizes)
        # We need start indices for both batch and sequence dimensions
        start_indices = jnp.array([0, start_idx])  # Start at beginning of batch
        slice_sizes = jnp.array([batch_size, window_size])
        return lax.dynamic_slice(inputs, start_indices, slice_sizes)
    
    # Create all windows by mapping over start indices
    windows = jax.vmap(get_window)(jnp.arange(num_windows))
    
    # Rearrange from (num_windows, batch_size, window_size) 
    # to (batch_size, num_windows, window_size)
    return jnp.transpose(windows, (1, 0, 2))

# Print the original tensor
print(f"Original Tensor: ")
print(batched_x)
print("")

# Create the 2 * 2 + 1 chunks
chunk = sliding_window_batched(jnp.array(batched_x), 2 * window_size + 1)
print(f"Windows: ")
print(jnp.array(chunk))

Original Tensor: 
tf.Tensor([[ 0  0 19  5 14 21 12  3  0  0]], shape=(1, 10), dtype=int32)

Windows: 
[[[ 0  0 19  5 14]
  [ 0 19  5 14 21]
  [19  5 14 21 12]
  [ 5 14 21 12  3]
  [14 21 12  3  0]
  [21 12  3  0  0]]]


### Model

Now that we have prepared our data, we are ready to build our model. We have learned how to write custom `nnx.Module` classes. We will do the same here and put everything we have learned so far together.

In [159]:
class WordWindowClassifier(nnx.Module):

  def __init__(self, hyperparameters, vocab_size, pad_ix=0, rng_seed=0):
    super(WordWindowClassifier, self).__init__()

    """ Instance variables """
    self.window_size = hyperparameters["window_size"]
    self.embed_dim = hyperparameters["embed_dim"]
    self.hidden_dim = hyperparameters["hidden_dim"]
    self.freeze_embeddings = hyperparameters["freeze_embeddings"]
    self.rng_seed = rng_seed

    """ Embedding Layer
    Takes in a tensor containing embedding indices, and returns the
    corresponding embeddings. The output is of dim
    (number_of_indices * embedding_dim).

    If freeze_embeddings is True, set the embedding layer parameters to be
    non-trainable. This is useful if we only want the parameters other than the
    embeddings parameters to change.

    """
    # self.embeds = nn.Embedding(vocab_size, self.embed_dim, padding_idx=pad_ix)
    self.embeds = nnx.Embed(vocab_size, self.embed_dim, rngs=nnx.Rngs(self.rng_seed))
    # if self.freeze_embeddings:
    #   self.embed_layer.weight.requires_grad = False

    """ Hidden Layer
    """
    full_window_size = 2 * window_size + 1
    self.hidden_layer = nnx.Sequential(
      nnx.Linear(full_window_size * self.embed_dim, self.hidden_dim, rngs=nnx.Rngs(self.rng_seed))
    )

    """ Output Layer
    """
    self.output_layer = nnx.Linear(self.hidden_dim, 1, rngs=nnx.Rngs(self.rng_seed))

  def __call__(self, inputs):
    """
    Let B:= batch_size
        L:= window-padded sentence length
        D:= self.embed_dim
        S:= self.window_size
        H:= self.hidden_dim

    inputs: a (B, L) tensor of token indices
    """
    B, L = inputs.shape

    """
    Reshaping.
    Takes in a (B, L) LongTensor
    Outputs a (B, L~, S) LongTensor
    """
    # Fist, get our word windows for each word in our input.
    token_windows = sliding_window_batched(inputs, 2 * self.window_size + 1)
    _, n_windows, _ = token_windows.shape

    # Good idea to do internal tensor-size sanity checks, at the least in comments!
    assert token_windows.shape == (B, n_windows, 2 * self.window_size + 1)

    """
    Embedding.
    Takes in a tensor of size (B, L~, S)
    Outputs a (B, L~, S, D) FloatTensor.
    """
    embedded_windows = self.embeds(token_windows)

    """
    Reshaping.
    Takes in a (B, L~, S, D) FloatTensor.
    Resizes it into a (B, L~, S*D) FloatTensor.
    -1 argument "infers" what the last dimension should be based on leftover axes.
    """
    embedded_windows = embedded_windows.reshape(B, n_windows, -1)

    """
    Layer 1.
    Takes in a (B, L~, S*D) FloatTensor.
    Resizes it into a (B, L~, H) FloatTensor
    """
    layer_1 = nnx.tanh(self.hidden_layer(embedded_windows))

    """
    Layer 2
    Takes in a (B, L~, H) FloatTensor.
    Resizes it into a (B, L~, 1) FloatTensor.
    """
    output = self.output_layer(layer_1)

    """
    Softmax.
    Takes in a (B, L~, 1) FloatTensor of unnormalized class scores.
    Outputs a (B, L~, 1) FloatTensor of (log-)normalized class scores.
    """
    output = nnx.sigmoid(output)
    output = output.reshape(B, -1)

    return output
  
@staticmethod
def sliding_window_batched(inputs: jnp.array, window_size: int):
    batch_size, seq_length = inputs.shape
    num_windows = seq_length - window_size + 1
    
    # This function gets a single window for all sequences in the batch
    def get_window(start_idx: int):
        # dynamic_slice takes (operand, start_indices, slice_sizes)
        # We need start indices for both batch and sequence dimensions
        start_indices = jnp.array([0, start_idx])  # Start at beginning of batch
        slice_sizes = jnp.array([batch_size, window_size])
        return lax.dynamic_slice(inputs, start_indices, slice_sizes)
    
    # Create all windows by mapping over start indices
    windows = jax.vmap(get_window)(jnp.arange(num_windows))
    
    # Rearrange from (num_windows, batch_size, window_size) 
    # to (batch_size, num_windows, window_size)
    return jnp.transpose(windows, (1, 0, 2))

### Training

We are now ready to put everything together. Let's start with preparing our data and intializing our model. We can then intialize our optimizer and define our loss function. This time, instead of using one of the predefined loss function as we did before, we will define our own loss function.

In [160]:
# Prepare the data
data = list(zip(train_sentences, train_labels))
batch_size = 2
shuffle = True
window_size = 2
collate_fn = partial(custom_collate_fn, window_size=window_size, word_to_ix=word_to_ix)

# Instantiate a DataLoader
loader = tf.data.Dataset.from_tensor_slices(collate_fn(data)).batch(batch_size=batch_size)

# Initialize a model
# It is useful to put all the model hyperparameters in a dictionary
model_hyperparameters = {
    "batch_size": 4,
    "window_size": 2,
    "embed_dim": 25,
    "hidden_dim": 25,
    "freeze_embeddings": False,
}

vocab_size = len(word_to_ix)
model = WordWindowClassifier(model_hyperparameters, vocab_size)

# Define an optimizer
learning_rate = 0.01
sgd = optax.sgd(learning_rate=learning_rate)
optimizer = nnx.Optimizer(model, sgd)

# Define a loss function, which computes to binary cross entropy loss
def loss_function(model, batch_inputs, batch_labels, batch_lengths):
    # jax.value_and_grad computes the grad wrt the first argument
    # so model object needs to be the first argument
    batch_preds = model(batch_inputs)

    # Calculate the loss for the whole batch
    loss = optax.losses.sigmoid_binary_cross_entropy(batch_preds, batch_labels).sum()

    # Rescale the loss. Remember that we have used lengths to store the
    # number of words in each training example
    loss = loss / batch_lengths.sum()

    return loss

Unlike our earlier example, this time instead of passing all of our training data to the model at once in each epoch, we will be utilizing batches. Hence, in each training epoch iteration, we also iterate over the batches.

In [161]:
def tf_to_jax(arr):
  return jax.dlpack.from_dlpack(tf.experimental.dlpack.to_dlpack(arr))

def jax_to_tf(arr):
  return tf.experimental.dlpack.from_dlpack(jax.dlpack.to_dlpack(arr))

In [162]:
# Function that will be called in every epoch
def train_epoch(loss_function, optimizer, model, loader):

  # Keep track of the total loss for the batch
  total_loss = 0
  for batch_inputs, batch_labels, batch_lengths in loader:
    # Convert to JAX w/o copying memory
    batch_inputs, batch_labels, batch_lengths = tf_to_jax(batch_inputs), tf_to_jax(batch_labels), tf_to_jax(batch_lengths)
    loss_, grads = nnx.value_and_grad(loss_function)(model, batch_inputs, batch_labels, batch_lengths)
    optimizer.update(grads)
    total_loss += loss_

  return total_loss


# Function containing our main training loop
def train(loss_function, optimizer, model, loader, num_epochs=10000):

  # Iterate through each epoch and call our train_epoch function
  for epoch in range(num_epochs):
    epoch_loss = train_epoch(loss_function, optimizer, model, loader)
    if epoch % 100 == 0: print(epoch_loss)

Let's start training!

In [163]:
warnings.filterwarnings('ignore')
num_epochs = 1000
train(loss_function, optimizer, model, loader, num_epochs=num_epochs)

3.2505436
2.8879259
2.716799
2.645041
2.6102877
2.5898073
2.5750124
2.5617964
2.5473454
2.5287414


### Prediction

Let's see how well our model is at making predictions. We can start by creating our test data.

In [73]:
# Create test sentences
test_corpus = ["She comes from Paris", "Stanford is not where he lives"]
test_sentences = [s.lower().split() for s in test_corpus]
test_labels = [[0, 0, 0, 1], [1, 0, 0, 0, 0, 0]]

# Create a test loader
test_data = list(zip(test_sentences, test_labels))
batch_size = 1
shuffle = False
window_size = 2
collate_fn = partial(custom_collate_fn, window_size=2, word_to_ix=word_to_ix)

test_loader = tf.data.Dataset.from_tensor_slices(collate_fn(test_data)).batch(batch_size=batch_size)

Let's loop over our test examples to see how well we are doing.

In [None]:
for test_instance, labels, _ in test_loader:
  batch_inputs, batch_labels = tf_to_jax(test_instance), tf_to_jax(labels)
  outputs = model(batch_inputs)
  print(batch_labels)
  print(outputs)

[[0 0 0 1 0 0]]
[[0.01398687 0.02790629 0.04582718 0.31703278 0.02662954 0.02201691]]
[[1 0 0 0 0 0]]
[[0.01715227 0.01450206 0.05668084 0.12209436 0.04397918 0.06782205]]
