<a href="https://colab.research.google.com/github/donlap/stat424/blob/main/MNIST.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

[Equinox Documentation](https://docs.kidger.site/equinox/)

In [None]:
!pip install -q equinox

In [None]:
import equinox as eqx
import optax
import jax
import jax.numpy as jnp
from jax import random
from tensorflow.keras.datasets import mnist

import matplotlib.pyplot as plt

In [None]:
# Fetch the MNIST train/test splits
(X_train, y_train), (X_test, y_test) = mnist.load_data()

# Keep only the first 5000 training examples to keep the demo fast
X_train, y_train = X_train[:5000], y_train[:5000]

# Convert to float32, divide by 255, and append a trailing channel axis
# so each image has shape (28, 28, 1)
X_train = (X_train.astype('float32') / 255.0)[..., None]

print(f"Training data: {X_train.shape}")

### Visualize data

In [None]:
# Show the first training image. Drop the trailing channel axis and use a
# grayscale colormap so it matches the prediction plots later in the notebook;
# the trailing semicolon hides the AxesImage repr.
plt.imshow(X_train[0].squeeze(), cmap='gray');

In [None]:
# Display the label stored for the first training example (X_train[0]).
y_train[0]

### Build a model

[Convolutional layer in Equinox](https://docs.kidger.site/equinox/api/nn/conv/)

In [None]:
class SimpleCNN(eqx.Module):
  """Small convolutional classifier for 28x28 single-channel images.

  Forward pass: conv(1->32, 3x3) -> ReLU -> conv(32->64, 3x3) -> ReLU
  -> 2x2 max-pool -> flatten -> linear(9216->128) -> ReLU -> linear(128->10).
  The model processes ONE image at a time; use `jax.vmap(model)` for a batch.
  """
  conv1: eqx.nn.Conv2d
  conv2: eqx.nn.Conv2d
  pool: eqx.nn.MaxPool2d
  lin1 : eqx.nn.Linear
  lin2: eqx.nn.Linear

  def __init__(self, key):
    """Build the layers, giving each parameterised layer its own PRNG key.

    Args:
      key: JAX PRNG key used to initialise the layer weights.
    """
    key1, key2, key3, key4 = random.split(key, 4)
    self.conv1 = eqx.nn.Conv2d(1, 32, kernel_size=3, key=key1)
    self.conv2 = eqx.nn.Conv2d(32, 64, kernel_size=3, key=key2)
    self.pool = eqx.nn.MaxPool2d(kernel_size=2, stride=2)
    # Spatial size goes 28 -> 26 -> 24 through the two 3x3 convolutions
    # (default padding), then 24 -> 12 after 2x2 pooling:
    # 64 channels * 12 * 12 = 9216 flattened features.
    self.lin1 = eqx.nn.Linear(9216, 128, key=key3)
    self.lin2 = eqx.nn.Linear(128, 10, key=key4)

  def __call__(self, x):
    """Compute class logits for a single (unbatched) image.

    Args:
      x: one image, given as channels-last (28, 28, 1), channels-first
        (1, 28, 28), or without a channel axis (28, 28).

    Returns:
      Array of shape (10,) holding unnormalised class scores (logits).
    """
    x = jnp.asarray(x)
    # Equinox convolutions take channels-first input (C, H, W); the notebook
    # stores images channels-last, so normalise the layout here.
    if x.ndim == 2:
      x = x[None, ...]
    elif x.ndim == 3 and x.shape[-1] == 1 and x.shape[0] != 1:
      x = jnp.moveaxis(x, -1, 0)

    x = jax.nn.relu(self.conv1(x))
    x = jax.nn.relu(self.conv2(x))
    x = self.pool(x)
    x = jnp.ravel(x)
    x = jax.nn.relu(self.lin1(x))
    return self.lin2(x)


### Loss function

[List of loss functions in Optax](https://optax.readthedocs.io/en/latest/api/losses.html)

In [None]:
def loss_fn(model, x, y):
  """Return the mean softmax cross-entropy of `model` over a batch.

  Args:
    model: callable applied to one example at a time; it is mapped over the
      leading axis of `x` with `jax.vmap`.
    x: batch of inputs, stacked along the leading axis.
    y: integer class labels, one per example in `x`.

  Returns:
    Scalar array: the per-example losses averaged over the batch.
  """
  logits = jax.vmap(model)(x)
  per_example_loss = optax.softmax_cross_entropy_with_integer_labels(logits, y)
  return per_example_loss.mean()

### Optimizer

[List of optimizers in Optax](https://optax.readthedocs.io/en/latest/api/optimizers.html)

In [None]:
# Instantiate the network before building the optimizer state; a fixed seed
# keeps the initial weights reproducible across Restart & Run All.
model = SimpleCNN(random.PRNGKey(0))

# Adam with learning rate 1e-3. The optimizer state is initialised only from
# the model's inexact (floating-point) array leaves, i.e. its trainable
# parameters; everything else in the module is filtered out.
optimizer = optax.adam(1e-3)
opt_state = optimizer.init(eqx.filter(model, eqx.is_inexact_array))

opt_state

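### Training

**TODO:** The training loop is not included in this notebook yet; until it is added, the model used in the next section has not been trained.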

### Visualize predictions

In [None]:
# Plot the first five training images side by side, each titled with the
# class index that receives the highest score from the model.
fig, axes = plt.subplots(1, 5, figsize=(12, 3))
for ax, image in zip(axes, X_train[:5]):
    predicted_class = jnp.argmax(model(image))
    ax.imshow(image.squeeze(), cmap='gray')
    ax.set_title(f"Predicted: {predicted_class}")
    ax.axis('off')
plt.show()

In [None]:
@jax.jit
def log2_with_print(x):
  """Return the base-2 logarithm of `x`, computed as ln(x) / ln(2).

  Args:
    x: value (or array of values) to take the logarithm of.

  Returns:
    ln(x) divided by ln(2).
  """
  log_of_two = jnp.log(2.0)
  # NOTE(review): this value is never used in the result; presumably kept to
  # illustrate how unused computations show up (or not) in the jaxpr — confirm.
  log_of_three = jnp.log(3.0)
  log_of_x = jnp.log(x)
  return log_of_x / log_of_two

# Print the jaxpr JAX traces for this function at a sample float input.
print(jax.make_jaxpr(log2_with_print)(3.))