In [1]:
import equinox as eqx
from jaxtyping import Int, Array
from typing import List
from jax.nn import silu
import jax.numpy as jnp
from jax import vmap
import jax
import optax
from datasets import load_dataset

In [None]:
class MLP(eqx.Module):
    stack: List

    def __init__(
        self, rng, in_features: int, out_features: int, hidden_size: int, n_layers: int
    ):
        keys = jax.random.split(rng, n_layers)
        first_layer = eqx.nn.Linear(
            key=keys[0], in_features=in_features, out_features=hidden_size
        )
        last_layer = eqx.nn.Linear(
            key=keys[-1], in_features=hidden_size, out_features=out_features
        )
        self.stack = eqx.nn.Sequential(
            [first_layer, eqx.nn.RMSNorm(hidden_size)]
            + list(
                *zip(
                    [
                        eqx.nn.Linear(
                            key=keys[i],
                            in_features=hidden_size,
                            out_features=hidden_size,
                        )
                        for i in range(1, n_layers - 1)
                    ],
                    [eqx.nn.RMSNorm(hidden_size) for _ in range(1, n_layers - 1)],
                )
            )
            + [last_layer]
        )

    def __call__(self, x):
        for i in range(0, len(self.stack) - 1, 2):
            x = silu(self.stack[i + 1](self.stack[i](x)))
        return self.stack[-1](x)


rng = jax.random.PRNGKey(42)
mlp = MLP(rng, 4, 2, 128, 3)
mlp

MLP(
  stack=Sequential(
    layers=(
      Linear(
        weight=f32[128,4],
        bias=f32[128],
        in_features=4,
        out_features=128,
        use_bias=True
      ),
      RMSNorm(
        shape=(128,),
        eps=1e-05,
        use_weight=True,
        use_bias=True,
        weight=f32[128],
        bias=f32[128]
      ),
      Linear(
        weight=f32[128,128],
        bias=f32[128],
        in_features=128,
        out_features=128,
        use_bias=True
      ),
      RMSNorm(
        shape=(128,),
        eps=1e-05,
        use_weight=True,
        use_bias=True,
        weight=f32[128],
        bias=f32[128]
      ),
      Linear(
        weight=f32[2,128],
        bias=f32[2],
        in_features=128,
        out_features=2,
        use_bias=True
      )
    )
  )
)

In [3]:
vmapped = jax.jit(jax.vmap(mlp))
input = jax.random.normal(rng, (8, 4))
output = jax.random.normal(rng, (8, 2))
vmapped(input)

Array([[ 0.39269477, -0.11479907],
       [-0.0050479 ,  0.41385213],
       [ 0.27943578, -0.02031695],
       [-0.0377685 , -0.21515   ],
       [-0.11294541,  0.15727353],
       [-0.38071093,  0.14485542],
       [ 0.31257132,  0.323749  ],
       [ 0.3690706 , -0.47695047]], dtype=float32)

In [4]:
output

Array([[-0.02830462,  0.46713185],
       [ 0.29570296,  0.15354592],
       [-0.12403282,  0.21692315],
       [-1.4408789 ,  0.7558599 ],
       [ 0.52140963,  0.9101704 ],
       [-0.3844966 ,  1.1398233 ],
       [ 1.4457862 ,  1.0809066 ],
       [-0.05629321,  0.9095945 ]], dtype=float32)

In [None]:
@jax.jit
def cross_entropy_loss(model, x, y):
    return jnp.linalg.norm(vmap(model)(x) - y)


optim = optax.adamw(1e-5)


In [None]:
val, grads = jax.value_and_grad(cross_entropy_loss)(mlp, input, output)
opt_state = optim.init(mlp)
updates, opt_state = optim.update(grads, opt_state, eqx.filter(mlp, eqx.is_array))

In [7]:
updates

MLP(
  stack=Sequential(
    layers=(
      Linear(
        weight=f32[128,4],
        bias=f32[128],
        in_features=4,
        out_features=128,
        use_bias=True
      ),
      RMSNorm(
        shape=(128,),
        eps=1e-05,
        use_weight=True,
        use_bias=True,
        weight=f32[128],
        bias=f32[128]
      ),
      Linear(
        weight=f32[128,128],
        bias=f32[128],
        in_features=128,
        out_features=128,
        use_bias=True
      ),
      RMSNorm(
        shape=(128,),
        eps=1e-05,
        use_weight=True,
        use_bias=True,
        weight=f32[128],
        bias=f32[128]
      ),
      Linear(
        weight=f32[2,128],
        bias=f32[2],
        in_features=128,
        out_features=2,
        use_bias=True
      )
    )
  )
)

In [8]:
mlp_ = eqx.apply_updates(mlp, updates)
vmap(mlp_)(input)

Array([[ 0.39128035, -0.11116907],
       [-0.00265342,  0.41814655],
       [ 0.2782694 , -0.01608188],
       [-0.039456  , -0.21209644],
       [-0.11179534,  0.16160005],
       [-0.37851924,  0.14782496],
       [ 0.31495827,  0.32719198],
       [ 0.36712554, -0.4736853 ]], dtype=float32)

In [9]:
opt_state

(ScaleByAdamState(count=Array(1, dtype=int32), mu=MLP(
   stack=Sequential(
     layers=(
       Linear(
         weight=f32[128,4],
         bias=f32[128],
         in_features=4,
         out_features=128,
         use_bias=True
       ),
       RMSNorm(
         shape=(128,),
         eps=1e-05,
         use_weight=True,
         use_bias=True,
         weight=f32[128],
         bias=f32[128]
       ),
       Linear(
         weight=f32[128,128],
         bias=f32[128],
         in_features=128,
         out_features=128,
         use_bias=True
       ),
       RMSNorm(
         shape=(128,),
         eps=1e-05,
         use_weight=True,
         use_bias=True,
         weight=f32[128],
         bias=f32[128]
       ),
       Linear(
         weight=f32[2,128],
         bias=f32[2],
         in_features=128,
         out_features=2,
         use_bias=True
       )
     )
   )
 ), nu=MLP(
   stack=Sequential(
     layers=(
       Linear(
         weight=f32[128,4],
         bias=f32[

In [10]:
emb = eqx.nn.Embedding(num_embeddings=9, embedding_size=3, key=rng)
jnp.expand_dims(emb(0), 1)

Array([[-0.02830462],
       [ 0.46713185],
       [ 0.29570296]], dtype=float32)

In [11]:
attention = eqx.nn.MultiheadAttention(3, 15, 15, 15, 9, key=rng)

In [12]:
key1, key2, key3 = jax.random.split(rng, 3)
Q = jax.random.normal(key1, (100, 15))
K = jax.random.normal(key2, (100, 15))
V = jax.random.normal(key3, (100, 15))
attention(Q, K, V).shape

(100, 9)

In [None]:
from torch.utils.data import DataLoader, Dataset
import torchvision

normalise_data = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.5,), (0.5,)),
])


def collate(datapoint):
    images = []
    labels = []
    for d in datapoint:
        images.append(d["image"])
        labels.append(d["label"])

    return (jnp.array(images) / jnp.array(images).max()) * 255, jnp.array(labels)


dataset_train = load_dataset("mnist", split="train")
# dataset_train.set_format(type="torch")

dataset_test = load_dataset("mnist", split="test")
# dataset_test.set_format(type="torch")
train_loader = DataLoader(dataset_train, batch_size=2, collate_fn=collate)
test_loader = DataLoader(dataset_test, batch_size=2, collate_fn=collate)


In [None]:
class CNN(eqx.Module):
    layers: list

    def __init__(self, key):
        key1, key2, key3, key4 = jax.random.split(key, 4)
        # Standard CNN setup: convolutional layer, followed by flattening,
        # with a small MLP on top.
        self.layers = [
            eqx.nn.Conv2d(1, 3, kernel_size=4, key=key1),
            eqx.nn.MaxPool2d(kernel_size=2),
            jax.nn.relu,
            jnp.ravel,
            eqx.nn.Linear(1728, 512, key=key2),
            jax.nn.sigmoid,
            eqx.nn.Linear(512, 64, key=key3),
            jax.nn.relu,
            eqx.nn.Linear(64, 10, key=key4),
            jax.nn.log_softmax,
        ]

    def __call__(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


def cross_entropy(y, pred_y):
    pred_y = jnp.take_along_axis(pred_y, jnp.expand_dims(y, 1), axis=1)
    return -jnp.mean(pred_y)


def cross_entropy_loss(model: CNN, x, y):
    pred_y = jax.vmap(model)(x)
    return cross_entropy(y, pred_y)


def l2_loss(model, x, y):
    return jnp.linalg.norm(vmap(model)(x) - y)


In [None]:
from utils import train

key, subkey = jax.random.split(rng, 2)
model = CNN(subkey)


train(
    model,
    trainloader=train_loader,
    testloader=test_loader,
    optim=optim,
    steps=200,
    print_every=400,
    loss_fn=l2_loss,
    evaluate_fn=lambda x, y: (jnp.array([-1.38]), jnp.array([0.76])),
)

ValueError: Incompatible shapes for broadcasting: shapes=[(2, 10), (2,)]

In [None]:
from transformers import MultiHeadAttention, TransformerBlock
import jax

rng = jax.random.PRNGKey(42)
attention = MultiHeadAttention(rng)

In [None]:
x = jax.random.normal(rng, shape=(16, 12))

In [None]:
y = attention(x)

In [4]:
block = TransformerBlock(rng)

In [None]:
block(x)

(16, 12)