In [1]:
import jax
import jax.numpy as jnp
from jax import Array
from jaxtyping import Array, Float, PyTree
from typing import Callable, Tuple
import equinox as eqx
import optax

In [2]:
class Model(eqx.Module):
    """Sequential network that applies its stored layers one after another."""

    # Ordered collection of callables; __call__ applies them first to last.
    layers: list

    def __init__(self, layers):
        """Keep a reference to the given ordered layers."""
        self.layers = layers

    def __call__(self, x: Array) -> Array:
        """Feed ``x`` through every layer in order and return the final result."""
        out = x
        for apply_layer in self.layers:
            # Each layer consumes the previous layer's output.
            out = apply_layer(out)
        return out



In [5]:
import torch
import torchvision

# --- Configuration -------------------------------------------------------
# BATCH_SIZE is used by the data loaders below; LEARNING_RATE, STEPS and
# PRINT_EVERY are not used in this cell (presumably read by a training loop
# defined elsewhere in the notebook).
SEED = 5678
BATCH_SIZE = 64
LEARNING_RATE = 3e-4
STEPS = 300
PRINT_EVERY = 30

# Root JAX PRNG key derived from the fixed seed.
key = jax.random.PRNGKey(SEED)

# --- Data ----------------------------------------------------------------
# Per-sample preprocessing: convert to a tensor, then normalise the single
# channel with the (0.5,), (0.5,) arguments passed to Normalize.
normalise_data = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5,), (0.5,)),
    ]
)

# Build the training split first, then the test split; both are stored
# under (and downloaded to) the local "MNIST" directory and share the same
# transform.
train_dataset, test_dataset = (
    torchvision.datasets.MNIST(
        "MNIST",
        train=is_train,
        download=True,
        transform=normalise_data,
    )
    for is_train in (True, False)
)

# One shuffled, fixed-batch-size loader per split (train first, then test).
trainloader, testloader = (
    torch.utils.data.DataLoader(split, batch_size=BATCH_SIZE, shuffle=True)
    for split in (train_dataset, test_dataset)
)


In [10]:
# Pull a single batch from the training loader and show the shape of its
# first element (presumably the images; the second would be the labels).
next(iter(trainloader))[0].numpy().shape

(64, 1, 28, 28)

In [20]:
def f(a, b):
    """Return twice the sum of ``a`` and ``b``.

    Works for any operands that support ``+`` with each other and whose sum
    can be multiplied by an integer (plain numbers, arrays, ...).
    """
    total = a + b
    return 2 * total

In [21]:
# Call f with the same two scalars in both argument orders and display the
# pair of results.
f(1, 2), f(2, 1)

(6, 6)

In [15]:
# Integer vector holding the values 1 through 9.
a = jnp.asarray(list(range(1, 10)))

In [23]:
# Apply f to array operands: `a` and `a` shifted up by one.
f(a, a+1)

Array([ 6, 10, 14, 18, 22, 26, 30, 34, 38], dtype=int32)

In [17]:
# b is a shifted up by one.
b = a + 1

In [18]:
# Display the contents of b.
b


Array([ 2,  3,  4,  5,  6,  7,  8,  9, 10], dtype=int32)

In [19]:
# Dot product of the two vectors a and b, returned as a scalar array.
jnp.vdot(a, b)

Array(330, dtype=int32)