In [3]:
import torch
from torchvision import datasets, transforms
import numpy as np
import jax
import jax.numpy as jnp
from jax import random
import optax

# Hyperparameters
neurons = [28*28, 28*28, 28*28, 28*28, 28*28, 10]  # layer widths, input -> 10 class logits
# NOTE(review): the run recorded with lr = 0.01 became unstable after roughly
# 8 epochs (loss rising, accuracy collapsing); a smaller step such as 1e-3 is
# worth trying. The value is left unchanged here.
lr = 0.01
batch_size = 8
total_samples = 60000  # number of training images to use
epochs = 1000

# Load MNIST data using PyTorch; every image is flattened to a 28*28 vector
transform = transforms.Compose([transforms.ToTensor(), lambda x: torch.flatten(x)])
mnist_train = datasets.MNIST(root="./data", train=True, download=True, transform=transform)

# Never ask for more samples than the dataset holds, so that total_samples
# always matches the number of rows in X_train / y_train below.
total_samples = min(total_samples, len(mnist_train))

# Randomly pick total_samples images (the whole dataset when they are equal).
subset_idx = torch.randperm(len(mnist_train))[:total_samples].tolist()

# Fetch each sample exactly once: indexing the dataset re-runs the transform,
# so reading the image and the label through separate lookups doubles the work.
train_data = [
    (image.numpy(), label)
    for image, label in (mnist_train[i] for i in subset_idx)
]

# Prepare data
# Stack on the NumPy side first so JAX receives one contiguous array instead
# of converting a long Python list element by element.
X_train = jnp.asarray(np.stack([image for image, _ in train_data]))  # Shape: (total_samples, 28*28)
y_train = jnp.array([label for _, label in train_data])  # Shape: (total_samples,)

# One-hot encode the labels
y_train_onehot = jax.nn.one_hot(y_train, num_classes=10)

# Initialize weights and biases
key = random.PRNGKey(0)
params = []
for in_dim, out_dim in zip(neurons[:-1], neurons[1:]):
    # Split off a fresh subkey per layer so no two layers share random draws.
    k1, key = random.split(key)
    # Normal draws scaled by sqrt(2 / fan_in); biases start at zero.
    W = random.normal(k1, (in_dim, out_dim)) * jnp.sqrt(2 / in_dim)
    b = jnp.zeros(out_dim)
    params.append((W, b))

# Convert params into a single structure for Optax
def flatten_params(params):
    """Turn a list of (W, b) layer tuples into a nested dict.

    The result maps "layer_<index>" to {"W": W, "b": b}, with indices
    following the order of the input list.
    """
    flat = {}
    for idx, layer in enumerate(params):
        weights, bias = layer
        flat[f"layer_{idx}"] = {"W": weights, "b": bias}
    return flat

def unflatten_params(flat_params):
    """Rebuild the list of (W, b) layer tuples from a flattened parameter dict.

    Args:
        flat_params: mapping with keys "layer_0", "layer_1", ... where each
            value is a dict holding "W" and "b", as produced by
            flatten_params.

    Returns:
        A list of (W, b) tuples ordered by layer index.

    Raises:
        KeyError: if a "layer_<i>" key, or its "W"/"b" entry, is missing.
    """
    # Take the layer count from the argument itself rather than from a
    # module-level variable, so the function works on any flattened pytree
    # (parameters, gradients, ...) regardless of global state.
    layers = (flat_params[f"layer_{i}"] for i in range(len(flat_params)))
    return [(layer["W"], layer["b"]) for layer in layers]

# Nested-dict view of the layer list; this is the structure the optimizer
# state is initialised from and that update_fn receives and returns.
opt_params = flatten_params(params)

# Define the forward pass
def forward(params, x):
    """Run x through the dense layers described by params.

    params is a sequence of (W, b) pairs. Each layer computes
    dot(input, W) + b; a ReLU follows every layer except the last one,
    whose raw output is returned.
    """
    final_idx = len(params) - 1
    activations = x
    for idx, layer in enumerate(params):
        weights, bias = layer
        pre_activation = jnp.dot(activations, weights) + bias
        # The output layer stays linear; hidden layers are rectified.
        activations = pre_activation if idx == final_idx else jax.nn.relu(pre_activation)
    return activations

# Loss function (cross-entropy)
def loss_fn(params, x, y):
    """Mean cross-entropy between the network output for x and targets y.

    y is multiplied element-wise with the log-softmax of the logits and
    summed along axis 1, so it must have the same 2-D layout as the logits
    (the caller in this file passes one-hot labels).
    """
    log_probs = jax.nn.log_softmax(forward(params, x))
    per_example = -jnp.sum(y * log_probs, axis=1)
    return jnp.mean(per_example)

# Compute accuracy
def accuracy_fn(params, x, y):
    """Fraction of rows where the arg-max logit matches the arg-max of y.

    Both arg-max operations run along axis 1.
    """
    predicted = jnp.argmax(forward(params, x), axis=1)
    expected = jnp.argmax(y, axis=1)
    return jnp.mean(predicted == expected)

# Optax Adam optimizer using the learning rate `lr` defined above.
# opt_state is the optimizer state initialised for the structure of
# opt_params; update_fn threads it through every step.
optimizer = optax.adam(lr)
opt_state = optimizer.init(opt_params)

# Gradient update function
@jax.jit
def update_fn(opt_params, opt_state, x, y):
    """Apply one optimizer step for the batch (x, y).

    Differentiates loss_fn with respect to the flattened parameter dict,
    feeds the gradients to the module-level optimizer, and returns the
    updated parameters together with the new optimizer state.
    """
    def batch_loss(flat_params):
        # loss_fn expects the list-of-tuples layout, so convert first.
        return loss_fn(unflatten_params(flat_params), x, y)

    gradients = jax.grad(batch_loss)(opt_params)
    updates, next_state = optimizer.update(gradients, opt_state)
    return optax.apply_updates(opt_params, updates), next_state

# Training loop
# Size the permutation from the rows actually loaded. JAX clamps out-of-range
# indices instead of raising, so a total_samples larger than X_train would
# otherwise silently train on repeated rows.
num_samples = min(total_samples, X_train.shape[0])
for epoch in range(epochs):
    # Shuffle data: one fresh permutation per epoch, applied to inputs and
    # labels alike so the pairs stay aligned.
    perm = np.random.permutation(num_samples)
    X_train_shuffled = X_train[perm]
    y_train_onehot_shuffled = y_train_onehot[perm]

    # Process each batch (the last one is shorter when num_samples is not a
    # multiple of batch_size)
    for i in range(0, num_samples, batch_size):
        X_batch = X_train_shuffled[i:i+batch_size]
        y_batch = y_train_onehot_shuffled[i:i+batch_size]

        opt_params, opt_state = update_fn(opt_params, opt_state, X_batch, y_batch)

    # Compute loss and accuracy on the full training set
    params = unflatten_params(opt_params)
    train_loss = loss_fn(params, X_train, y_train_onehot)
    train_acc = accuracy_fn(params, X_train, y_train_onehot)
    print(f"Epoch {epoch+1}, Loss: {train_loss:.4f}, Accuracy: {train_acc:.4f}")
    # Stop early once the training loss is effectively zero.
    if train_loss < 0.00001:
        break


Epoch 1, Loss: 0.6938, Accuracy: 0.7916
Epoch 2, Loss: 0.3918, Accuracy: 0.9053
Epoch 3, Loss: 0.5479, Accuracy: 0.8582
Epoch 4, Loss: 0.4336, Accuracy: 0.9182
Epoch 5, Loss: 0.2932, Accuracy: 0.9292
Epoch 6, Loss: 0.3755, Accuracy: 0.9213
Epoch 7, Loss: 0.3928, Accuracy: 0.9186
Epoch 8, Loss: 0.4687, Accuracy: 0.9205
Epoch 9, Loss: 0.9370, Accuracy: 0.7455
Epoch 10, Loss: 0.8330, Accuracy: 0.7128
Epoch 11, Loss: 0.9690, Accuracy: 0.6544
Epoch 12, Loss: 0.8165, Accuracy: 0.6888
Epoch 13, Loss: 0.7127, Accuracy: 0.7513
Epoch 14, Loss: 0.6788, Accuracy: 0.7712
Epoch 15, Loss: 1.2492, Accuracy: 0.5151
Epoch 16, Loss: 0.9478, Accuracy: 0.6216
Epoch 17, Loss: 0.9827, Accuracy: 0.5879
Epoch 18, Loss: 1.2964, Accuracy: 0.4739
Epoch 19, Loss: 1.4637, Accuracy: 0.4333
Epoch 20, Loss: 1.1621, Accuracy: 0.5499
Epoch 21, Loss: 1.6922, Accuracy: 0.2813
Epoch 22, Loss: 1.7020, Accuracy: 0.2791
Epoch 23, Loss: 1.9011, Accuracy: 0.2376
Epoch 24, Loss: 1.6968, Accuracy: 0.2785
Epoch 25, Loss: 1.7327, A

KeyboardInterrupt: 