In [1]:
import jax
import jax.numpy as jnp
from jax import random, grad, jit
from jax.nn import relu, log_softmax
from jax.nn.initializers import glorot_uniform
import numpy as np
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

# Hyperparameters
batch_size = 4
learning_rate = 0.01
epochs = 5

# Load and preprocess MNIST data
train_data = MNIST(root="./", train=True, download=True, transform=ToTensor())
test_data = MNIST(root="./", train=False, download=True, transform=ToTensor())

x_train, y_train = zip(*(list(train_data)[:40]))
x_test, y_test = zip(*test_data)

x_train, y_train = jnp.array(np.stack(x_train)[:, 0, :, :]), jnp.array(y_train)
x_test, y_test = jnp.array(np.stack(x_test)[:, 0, :, :]), jnp.array(y_test)

classes = len(set(y_train.tolist()))
y_train = jax.nn.one_hot(y_train, classes)
y_test = jax.nn.one_hot(y_test, classes)

# Helper to compute output size after conv and pooling layers
def compute_output_shape(input_shape, kernel_size, stride, padding):
    return (input_shape - kernel_size + 2 * padding) // stride + 1

# Define CNN model
def init_cnn_params(rng, input_shape):
    k1, k2, k3 = random.split(rng, 3)

    conv1_out_channels = 4
    conv2_out_channels = 2

    # Calculate spatial size after Conv1 and pooling
    conv1_out_size = compute_output_shape(28, 3, 1, 1)  # SAME padding
    pool1_out_size = compute_output_shape(conv1_out_size, 2, 2, 0)  # VALID pooling

    # Calculate spatial size after Conv2 and pooling
    conv2_out_size = compute_output_shape(pool1_out_size, 3, 1, 1)  # SAME padding
    pool2_out_size = compute_output_shape(conv2_out_size, 2, 2, 0)  # VALID pooling

    # Final flattened size
    flattened_size = pool2_out_size * pool2_out_size * conv2_out_channels

    params = {
        "conv1": glorot_uniform()(k1, (conv1_out_channels, 1, 3, 3)),  # 4 filters, 3x3 kernel
        "conv2": glorot_uniform()(k2, (conv2_out_channels, conv1_out_channels, 3, 3)),  # 2 filters, 3x3 kernel
        "fc": glorot_uniform()(k3, (flattened_size, classes)),  # Fully connected
    }
    return params, pool2_out_size

def cnn_forward(params, x):
    # Conv Layer 1
    x = jax.lax.conv_general_dilated(
        x, params["conv1"], window_strides=(1, 1), padding="SAME", dimension_numbers=("NCHW", "OIHW", "NCHW")
    )
    x = relu(x)
    x = jax.lax.reduce_window(
        x, 
        -jnp.inf, 
        jax.lax.max, 
        window_dimensions=(1, 1, 2, 2),  # Pooling over spatial dimensions only
        window_strides=(1, 1, 2, 2), 
        padding="VALID"
    )
    
    # Conv Layer 2
    x = jax.lax.conv_general_dilated(
        x, params["conv2"], window_strides=(1, 1), padding="SAME", dimension_numbers=("NCHW", "OIHW", "NCHW")
    )
    x = relu(x)
    x = jax.lax.reduce_window(
        x, 
        -jnp.inf, 
        jax.lax.max, 
        window_dimensions=(1, 1, 2, 2),  # Pooling over spatial dimensions only
        window_strides=(1, 1, 2, 2), 
        padding="VALID"
    )
    
    # Flatten
    x = x.reshape(x.shape[0], -1)
    
    # Fully Connected Layer
    logits = jnp.dot(x, params["fc"])
    return log_softmax(logits)

# Loss function
def cross_entropy_loss(params, x, y):
    logits = cnn_forward(params, x)
    return -jnp.mean(jnp.sum(logits * y, axis=-1))

# Accuracy
def accuracy(params, x, y):
    preds = jnp.argmax(cnn_forward(params, x), axis=-1)
    labels = jnp.argmax(y, axis=-1)
    return jnp.mean(preds == labels)

# Training function
@jit
def train_step(params, x, y, opt_state):
    grads = grad(cross_entropy_loss)(params, x, y)
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state

# Shuffle function
def shuffle_data(x, y, rng):
    indices = jax.random.permutation(rng, len(x))
    return x[indices], y[indices]

# Initialize parameters and optimizer
rng = random.PRNGKey(0)
input_shape = (batch_size, 1, 28, 28)
params, pool2_out_size = init_cnn_params(rng, input_shape)

import optax
optimizer = optax.adam(learning_rate)
opt_state = optimizer.init(params)

# Training loop
for epoch in range(epochs):
    rng, subkey = random.split(rng)
    x_train, y_train = shuffle_data(x_train, y_train, subkey)
    
    for i in range(0, len(x_train), batch_size):
        x_batch = x_train[i:i+batch_size][:, None, :, :]
        y_batch = y_train[i:i+batch_size]
        params, opt_state = train_step(params, x_batch, y_batch, opt_state)
    
    # Evaluate on the test set
    test_acc = accuracy(params, x_test[:, None, :, :], y_test)
    print(f"Epoch {epoch + 1}, Test Accuracy: {test_acc:.4f}")


2024-11-25 21:52:49.831918: W external/xla/xla/tsl/framework/bfc_allocator.cc:306] Allocator (GPU_0_bfc) ran out of memory trying to allocate 352.46MiB with freed_by_count=0. The caller indicates that this is not a failure, but this may mean that there could be performance gains if more memory were available.
2024-11-25 21:52:50.217083: W external/xla/xla/tsl/framework/bfc_allocator.cc:306] Allocator (GPU_0_bfc) ran out of memory trying to allocate 348.03MiB with freed_by_count=0. The caller indicates that this is not a failure, but this may mean that there could be performance gains if more memory were available.


Epoch 1, Test Accuracy: 0.1795
Epoch 2, Test Accuracy: 0.2028
Epoch 3, Test Accuracy: 0.2521
Epoch 4, Test Accuracy: 0.3597
Epoch 5, Test Accuracy: 0.5163
