In [1]:
import jaxon

In [2]:
# Show the NVIDIA driver / CUDA versions and current GPU memory use before selecting a device.
# NOTE(review): the `pid, fd = os.forkpty()` line in this cell's output is presumably the tail of a
# warning about forking after JAX (imported via jaxon in the previous cell) started threads — confirm.
!nvidia-smi

  pid, fd = os.forkpty()


Sat Mar 23 22:57:52 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.10              Driver Version: 551.61         CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA GeForce RTX 4090        On  |   00000000:01:00.0  On |                  Off |
|  0%   38C    P8             14W /  450W |    1283MiB /  24564MiB |     26%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [3]:
print(jaxon.current_device())
print(jaxon.device_count())
print(jaxon.cuda_is_available())
# set_device() prints its own confirmation and returns None (the old
# print(...) wrapper only added a stray "None" line), so call it directly.
jaxon.set_device("gpu")

cuda:0
1
True
Set device to gpu
None


In [4]:
import jax

# Report the platform of JAX's default backend (e.g. "gpu" or "cpu").
# jax.default_backend() is the public equivalent of the deprecated
# jax.lib.xla_bridge.get_backend().platform.
print(jax.default_backend())

gpu


In [5]:
import jax
import jax.numpy as jnp
import tensorflow as tf
import numpy as np
import jax.numpy as jnp
from jaxon import Sequential, Linear, ReLU, Conv2D, Flatten
from jax import random


def load_and_preprocess_data():
    """Load MNIST and return ``(x_train, y_train, x_test, y_test)`` as JAX arrays.

    Images are divided by 255 and laid out channel-first as (N, 1, H, W).
    Labels are one-hot encoded over 10 classes.
    """
    (train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()

    def to_channel_first(images):
        # (N, H, W) -> (N, 1, H, W): insert the single channel axis directly at position 1.
        return images[:, np.newaxis, :, :] / 255.0

    def to_one_hot(labels):
        return tf.keras.utils.to_categorical(labels, 10)

    return (
        jnp.array(to_channel_first(train_images)),
        jnp.array(to_one_hot(train_labels)),
        jnp.array(to_channel_first(test_images)),
        jnp.array(to_one_hot(test_labels)),
    )


# TODO: replace with a DataLoader; for now load_and_preprocess_data() returns the
# whole dataset at once, and every split is cast to float32 here.
x_train_jax, y_train_jax, x_test_jax, y_test_jax = (
    split.astype(jnp.float32) for split in load_and_preprocess_data()
)


def create_cnn_model(num_classes=10, height=28, width=28):
    """Build a small two-convolution CNN classifier.

    Layers: Conv(16, 3x3) -> ReLU -> Conv(32, 3x3) -> ReLU -> Flatten -> Linear.
    Both convolutions use stride 1 with "SAME" padding, so the final feature map
    keeps the input's ``height`` x ``width``. The Linear layer's input size is
    derived from that rather than hard-coded.

    Args:
        num_classes: number of output logits (default 10, the MNIST digits).
        height: input image height in pixels (default 28, MNIST).
        width: input image width in pixels (default 28, MNIST).

    Returns:
        A ``Sequential`` model; its parameters are created separately via
        ``model.init_params``.
    """
    last_conv_channels = 32  # shared by the last Conv2D and the Linear input size
    return Sequential(
        Conv2D(16, (3, 3), stride=1, padding="SAME"),
        ReLU(),
        Conv2D(last_conv_channels, (3, 3), stride=1, padding="SAME"),
        ReLU(),
        Flatten(),
        Linear(last_conv_channels * height * width, num_classes),
    )


# Fixed seed so parameter initialization is repeatable from run to run.
rng = random.PRNGKey(0)
model = create_cnn_model()
# One dummy MNIST-shaped batch in (N, C, H, W) layout; presumably only used to infer
# parameter shapes — confirm against jaxon's init_params.
input_shape = (1, 1, 28, 28)
output_shape, params = model.init_params(rng, input_shape)



In [6]:
# The optimizer setup, loss, and training step in the following cells are hand-implemented for now; they can be added to jaxon directly later on.

In [7]:
from jax.example_libraries import optimizers

# Adam with a constant learning rate. opt_update(step, grads, state) applies one
# update; get_params(state) recovers the parameters.
step_size = 1e-3
opt_init, opt_update, get_params = optimizers.adam(step_size=step_size)
opt_state = opt_init(params)

In [8]:
from jax import value_and_grad

def cross_entropy_loss(logits, labels):
    """Mean softmax cross-entropy of ``logits`` against one-hot ``labels``.

    Expects (batch, num_classes) inputs: log-softmax is taken over the last axis,
    per-example losses are summed over axis 1, then averaged over the batch.
    """
    log_probs = jax.nn.log_softmax(logits)
    per_example_loss = -(labels * log_probs).sum(axis=1)
    return per_example_loss.mean()


@jax.jit
def train_step(opt_state, x_batch, y_batch):
    params = get_params(opt_state)

    def loss_fn(params):
        logits = model(x_batch, params)
        loss = cross_entropy_loss(logits, y_batch)
        return loss

    grads = jax.grad(loss_fn)(params)
    return opt_update(0, grads, opt_state)

In [12]:
def batch_accuracy(params, x, y):
    logits = model(x, params)
    pred_classes = jnp.argmax(logits, axis=1)
    true_classes = jnp.argmax(y, axis=1)
    accuracy = jnp.mean(pred_classes == true_classes)
    return accuracy

def compute_accuracy_over_dataset(params, x_data, y_data, batch_size):
    num_batches = len(x_data) // batch_size
    acc_sum = 0.0
    for i in range(0, len(x_data), batch_size):
        x_batch = x_data[i:i+batch_size]
        y_batch = y_data[i:i+batch_size]
        acc_sum += batch_accuracy(params, x_batch, y_batch)
    return acc_sum / num_batches

num_epochs = 10
batch_size = 128

for epoch in range(num_epochs):
    for i in range(0, len(x_train_jax), batch_size):
        x_batch = x_train_jax[i:i+batch_size]
        y_batch = y_train_jax[i:i+batch_size]
        opt_state = train_step(opt_state, x_batch, y_batch)
    
    params = get_params(opt_state)
    train_acc = compute_accuracy_over_dataset(params, x_train_jax, y_train_jax, batch_size)
    test_acc = compute_accuracy_over_dataset(params, x_test_jax, y_test_jax, batch_size)
    print(f"Epoch {epoch + 1} completed. Training Accuracy: {train_acc:.4f}, Test Accuracy: {test_acc:.4f}")

Epoch 1 completed. Training Accuracy: 0.9366, Test Accuracy: 0.9504
Epoch 2 completed. Training Accuracy: 0.9716, Test Accuracy: 0.9823
Epoch 3 completed. Training Accuracy: 0.9807, Test Accuracy: 0.9893
Epoch 4 completed. Training Accuracy: 0.9846, Test Accuracy: 0.9922
Epoch 5 completed. Training Accuracy: 0.9868, Test Accuracy: 0.9936
Epoch 6 completed. Training Accuracy: 0.9884, Test Accuracy: 0.9933
Epoch 7 completed. Training Accuracy: 0.9897, Test Accuracy: 0.9939
Epoch 8 completed. Training Accuracy: 0.9906, Test Accuracy: 0.9939
Epoch 9 completed. Training Accuracy: 0.9916, Test Accuracy: 0.9947
Epoch 10 completed. Training Accuracy: 0.9923, Test Accuracy: 0.9946
