In [15]:
import equinox as eqx
from jaxtyping import Int, Array
from typing import List
from jax.nn import silu
import jax.numpy as jnp
from jax import vmap
import jax
import optax
from datasets import load_dataset

In [None]:
class MLP(eqx.Module):
    stack: List

    def __init__(
        self, rng, in_features: int, out_features: int, hidden_size: int, n_layers: int
    ):
        keys = jax.random.split(rng, n_layers)
        first_layer = eqx.nn.Linear(
            key=keys[0], in_features=in_features, out_features=hidden_size
        )
        last_layer = eqx.nn.Linear(
            key=keys[-1], in_features=hidden_size, out_features=out_features
        )
        self.stack = eqx.nn.Sequential(
            [first_layer, eqx.nn.RMSNorm(hidden_size)]
            + list(
                *zip(
                    [
                        eqx.nn.Linear(
                            key=keys[i],
                            in_features=hidden_size,
                            out_features=hidden_size,
                        )
                        for i in range(1, n_layers - 1)
                    ],
                    [eqx.nn.RMSNorm(hidden_size) for _ in range(1, n_layers - 1)],
                )
            )
            + [last_layer]
        )

    def __call__(self, x):
        for i in range(0, len(self.stack) - 1, 2):
            x = silu(self.stack[i + 1](self.stack[i](x)))
        return self.stack[-1](x)


rng = jax.random.PRNGKey(42)
mlp = MLP(rng, 4, 2, 128, 3)
mlp

MLP(
  stack=Sequential(
    layers=(
      Linear(
        weight=f32[128,4],
        bias=f32[128],
        in_features=4,
        out_features=128,
        use_bias=True
      ),
      RMSNorm(
        shape=(128,),
        eps=1e-05,
        use_weight=True,
        use_bias=True,
        weight=f32[128],
        bias=f32[128]
      ),
      Linear(
        weight=f32[128,128],
        bias=f32[128],
        in_features=128,
        out_features=128,
        use_bias=True
      ),
      RMSNorm(
        shape=(128,),
        eps=1e-05,
        use_weight=True,
        use_bias=True,
        weight=f32[128],
        bias=f32[128]
      ),
      Linear(
        weight=f32[28,128],
        bias=f32[28],
        in_features=128,
        out_features=28,
        use_bias=True
      )
    )
  )
)

In [17]:
vmapped = jax.jit(jax.vmap(mlp))
input = jax.random.normal(rng, (8, 4))
output = jax.random.normal(rng, (8, 2))
vmapped(input)

2025-07-04 19:58:24.282314: W external/xla/xla/service/gpu/autotuning/dot_search_space.cc:200] All configs were filtered out because none of them sufficiently match the hints. Maybe the hints set does not contain a good representative set of valid configs?Working around this by using the full hints set instead.


Array([[ 0.3927798 , -0.11479208, -0.0966991 , -0.1175942 , -0.3287928 ,
        -0.22363243,  0.32632768,  0.25224185,  0.3651327 ,  0.65944344,
        -0.09246576,  0.5197641 , -0.24735266, -0.35326168,  0.2137914 ,
        -0.73322034,  0.2726527 ,  0.34407556,  0.27747425,  0.17600326,
         0.40959322,  0.03169852,  0.40187883,  0.2058509 ,  0.1483576 ,
        -0.00236163,  0.0848546 , -0.33372128],
       [-0.00490307,  0.41410923,  0.35056818, -0.10977131,  0.15678823,
         0.39076808, -0.6864499 ,  0.27163365, -0.27491945,  0.9686847 ,
         0.11093152,  0.6773967 ,  0.17039207, -0.48786467,  0.20055461,
        -0.82962775,  0.0548249 , -0.17941645,  0.10705339, -0.31347838,
         0.19729082, -0.03615243,  0.59632343,  0.76601845,  0.11756848,
         0.36236757, -0.07336544, -0.0105912 ],
       [ 0.27948594, -0.02019237,  0.27692765,  0.11006484, -0.19795156,
        -0.00253476, -0.39190334,  0.16499692,  0.24604544,  0.79526997,
        -0.13977648,  0.6046

In [18]:
output

Array([[-0.02830462,  0.46713185],
       [ 0.29570296,  0.15354592],
       [-0.12403282,  0.21692315],
       [-1.4408789 ,  0.7558599 ],
       [ 0.52140963,  0.9101704 ],
       [-0.3844966 ,  1.1398233 ],
       [ 1.4457862 ,  1.0809066 ],
       [-0.05629321,  0.9095945 ]], dtype=float32)

In [19]:
@jax.jit
def loss(model, x, y):
    return jnp.linalg.norm(vmap(model)(x) - y)


optim = optax.adamw(1e-5)


In [20]:
val, grads = jax.value_and_grad(loss)(mlp, input, output)
opt_state = optim.init(mlp)
updates, opt_state = optim.update(grads, opt_state, eqx.filter(mlp, eqx.is_array))

TypeError: sub got incompatible shapes for broadcasting: (8, 28), (8, 2).

In [None]:
updates

MLP(
  stack=Sequential(
    layers=(
      Linear(
        weight=f32[128,4],
        bias=f32[128],
        in_features=4,
        out_features=128,
        use_bias=True
      ),
      RMSNorm(
        shape=(128,),
        eps=1e-05,
        use_weight=True,
        use_bias=True,
        weight=f32[128],
        bias=f32[128]
      ),
      Linear(
        weight=f32[128,128],
        bias=f32[128],
        in_features=128,
        out_features=128,
        use_bias=True
      ),
      RMSNorm(
        shape=(128,),
        eps=1e-05,
        use_weight=True,
        use_bias=True,
        weight=f32[128],
        bias=f32[128]
      ),
      Linear(
        weight=f32[2,128],
        bias=f32[2],
        in_features=128,
        out_features=2,
        use_bias=True
      )
    )
  )
)

In [None]:
mlp_ = eqx.apply_updates(mlp, updates)
vmap(mlp_)(input)

Array([[ 0.39128035, -0.11116907],
       [-0.0026536 ,  0.41814664],
       [ 0.27826947, -0.01608192],
       [-0.03945597, -0.21209641],
       [-0.11179535,  0.1616001 ],
       [-0.37851998,  0.14782508],
       [ 0.31495827,  0.3271922 ],
       [ 0.36712554, -0.4736853 ]], dtype=float32)

In [None]:
opt_state

(ScaleByAdamState(count=Array(1, dtype=int32), mu=MLP(
   stack=Sequential(
     layers=(
       Linear(
         weight=f32[128,4],
         bias=f32[128],
         in_features=4,
         out_features=128,
         use_bias=True
       ),
       RMSNorm(
         shape=(128,),
         eps=1e-05,
         use_weight=True,
         use_bias=True,
         weight=f32[128],
         bias=f32[128]
       ),
       Linear(
         weight=f32[128,128],
         bias=f32[128],
         in_features=128,
         out_features=128,
         use_bias=True
       ),
       RMSNorm(
         shape=(128,),
         eps=1e-05,
         use_weight=True,
         use_bias=True,
         weight=f32[128],
         bias=f32[128]
       ),
       Linear(
         weight=f32[2,128],
         bias=f32[2],
         in_features=128,
         out_features=2,
         use_bias=True
       )
     )
   )
 ), nu=MLP(
   stack=Sequential(
     layers=(
       Linear(
         weight=f32[128,4],
         bias=f32[

In [None]:
emb = eqx.nn.Embedding(num_embeddings=9, embedding_size=3, key=rng)
jnp.expand_dims(emb(0), 1)

Array([[-0.02830462],
       [ 0.46713185],
       [ 0.29570296]], dtype=float32)

In [None]:
attention = eqx.nn.MultiheadAttention(3, 15, 15, 15, 9, key=rng)

In [None]:
key1, key2, key3 = jax.random.split(rng, 3)
Q = jax.random.normal(key1, (100, 15))
K = jax.random.normal(key2, (100, 15))
V = jax.random.normal(key3, (100, 15))
attention(Q, K, V).shape

(100, 9)

In [None]:
from torch.utils.data import DataLoader, Dataset


def collate(datapoint):
    images = []
    labels = []
    for d in datapoint:
        images.append(d["image"])
        labels.append(d["label"])

    return jnp.array(images), jnp.array(labels)


dataset_train = load_dataset("mnist", split="train")
dataset_train.set_format(type="torch")

dataset_test = load_dataset("mnist", split="test")
dataset_test.set_format(type="torch")
train_loader = DataLoader(dataset_train, batch_size=2, collate_fn=collate)
test_loader = DataLoader(dataset_test, batch_size=2, collate_fn=collate)


In [21]:
from utils import train

mlp = MLP(rng, 28, 10, 128, 3)

train(
    mlp,
    trainloader=train_loader,
    testloader=test_loader,
    optim=optim,
    steps=200,
    print_every=400,
    loss_fn=loss,
    evaluate_fn=lambda x: x,
)

ValueError: Incompatible shapes for broadcasting: shapes=[(1, 128, 28), (128,)]