In [None]:
!pip install --upgrade jaxflow[gpu]
# please restart the kernel after installing jaxflow

Collecting jaxflow[gpu]
  Downloading jaxflow-0.1.2.dev0-py3-none-any.whl.metadata (4.0 kB)
Collecting jax>=0.6.0 (from jaxflow[gpu])
  Downloading jax-0.6.0-py3-none-any.whl.metadata (22 kB)
Collecting numpy>=2.1.0 (from jaxflow[gpu])
  Downloading numpy-2.2.6-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (62 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m62.0/62.0 kB[0m [31m495.2 kB/s[0m eta [36m0:00:00[0m
Collecting jaxlib<=0.6.0,>=0.6.0 (from jax>=0.6.0->jaxflow[gpu])
  Downloading jaxlib-0.6.0-cp311-cp311-manylinux2014_x86_64.whl.metadata (1.2 kB)
Collecting ml_dtypes>=0.5.0 (from jax>=0.6.0->jaxflow[gpu])
  Downloading ml_dtypes-0.5.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (21 kB)
Collecting jax-cuda12-plugin<=0.6.0,>=0.6.0 (from jax-cuda12-plugin[with-cuda]<=0.6.0,>=0.6.0; extra == "cuda12"->jax[cuda12]>=0.6.0; extra == "gpu"->jaxflow[gpu])
  Downloading jax_cuda12_plugin-0.6.0-cp311-cp311-manylinux2014_

In [1]:
import jax
import jax.numpy as jnp
import jaxflow as jf
import tensorflow as tf
import time
import numpy as np


## 1. Load and preprocess MNIST
We first load the MNIST dataset and normalize pixel values to the [0, 1] range. We also add a channel dimension for compatibility with Conv2D layers.


In [2]:
# Load and preprocess MNIST
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train = x_train.astype(jnp.float32) / 255.0
x_test = x_test.astype(jnp.float32) / 255.0
x_train = x_train[..., None]
x_test = x_test[..., None]
print(x_train.shape, y_train.shape, x_test.shape, y_test.shape)


(60000, 28, 28, 1) (60000,) (10000, 28, 28, 1) (10000,)



## 2. Define the CNN model
We define a simple CNN with two convolutional blocks followed by a fully connected layer and an output layer.

###  Subclassing Model Building


In [7]:


class CNN(jf.models.Model):
    def __init__(self, num_classes: int = 10, name: str = "MyCNN"):
        super().__init__(name=name)
        self.conv1 = jf.layers.Conv2D(filters=32, kernel_size=(3,3), activation=jf.activations.relu, kernel_initializer=jf.initializers.GlorotUniform, bias_initializer=jf.initializers.Zeros)
        self.pool1 = jf.layers.MaxPooling2D(pool_size=(2,2))
        self.conv2 = jf.layers.Conv2D(filters=64, kernel_size=(3,3), activation=jf.activations.relu, kernel_initializer=jf.initializers.GlorotUniform, bias_initializer=jf.initializers.Zeros)
        self.pool2 = jf.layers.MaxPooling2D(pool_size=(2,2))
        self.flatten = jf.layers.Flatten()
        self.dense1 = jf.layers.Dense(units=128, activation=jf.activations.relu, kernel_initializer=jf.initializers.GlorotUniform, bias_initializer=jf.initializers.Zeros)
        self.outputs = jf.layers.Dense(units=num_classes, activation=jf.activations.softmax, kernel_initializer=jf.initializers.GlorotUniform, bias_initializer=jf.initializers.Zeros)
    def call(self, inputs, training: bool = False):
        x = self.conv1(inputs, training=training)
        x = self.pool1(x, training=training)
        x = self.conv2(x, training=training)
        x = self.pool2(x, training=training)
        x = self.flatten(x)
        x = self.dense1(x, training=training)
        x = self.outputs(x, training=training)
        return x



# Build the model
model = CNN(num_classes=10)
model.build(input_shape=(None, 28, 28, 1))
print(model.summary())

Model 'MyCNN' summary:
  Block 0: <Conv2D filters=32, kernel_size=(3, 3), strides=(1, 1), padding=SAME, groups=1, built=True>
  Block 1: <MaxPooling2D pool_size=(2, 2), strides=(2, 2), padding=VALID, dilation=(1, 1), built=True>
  Block 2: <Conv2D filters=64, kernel_size=(3, 3), strides=(1, 1), padding=SAME, groups=1, built=True>
  Block 3: <MaxPooling2D pool_size=(2, 2), strides=(2, 2), padding=VALID, dilation=(1, 1), built=True>
  Block 4: <Flatten built=True, output_shape=(1, 3136)>
  Block 5: <Dense units=128, activation=relu, built=True>
  Block 6: <Dense units=10, activation=softmax, built=True>
None


## 3. Compile and train the model
We use the Adam optimizer and sparse categorical crossentropy loss. We train for 5 epochs with a batch size of 128.

In [8]:
strat_time = time.time()
# Train the model
optimizer = jf.optimizers.Adam(learning_rate=0.001)
loss_fn = jf.losses.SparseCategoricalCrossentropy()
model.compile(optimizer=optimizer, loss_fn=loss_fn,)
history = model.fit(x_train, y_train, epochs=5, batch_size=64, validation_data=(x_test, y_test), verbose=1)
print("Training time: ", time.time() - strat_time)
print("Model training completed.")

print("Model evaluation started:")
# Evaluate the model
pred = model.predict(x_test)
pred = jnp.argmax(pred, -1)
accuracy = jf.metrics.accuracy(y_test, pred)
precision = jf.metrics.precision(y_test, pred, average='macro',num_classes=10)
recall = jf.metrics.recall(y_test, pred, average='macro',num_classes=10)
f1 = jf.metrics.f1_score(y_test, pred, average='macro',num_classes=10)
print(f"Accuracy: {accuracy}, Precision: {precision}, Recall: {recall}, F1: {f1}")

Epoch 1/5


Training: 100%|██████████| 938/938 [00:06<00:00] • , loss=0.1485


loss: 0.1485 — val_loss: 0.0614
Epoch 2/5


Training: 100%|██████████| 938/938 [00:02<00:00] • , loss=0.0462


loss: 0.0462 — val_loss: 0.0395
Epoch 3/5


Training: 100%|██████████| 938/938 [00:03<00:00] • , loss=0.0295


loss: 0.0295 — val_loss: 0.0370
Epoch 4/5


Training: 100%|██████████| 938/938 [00:02<00:00] • , loss=0.0204


loss: 0.0204 — val_loss: 0.0339
Epoch 5/5


Training: 100%|██████████| 938/938 [00:02<00:00] • , loss=0.0162


loss: 0.0162 — val_loss: 0.0383
Training time:  19.923725605010986
Model training completed.
Model evaluation started:
Accuracy: 0.9896999597549438, Precision: 0.989700436592102, Recall: 0.9894947409629822, F1: 0.9895975589752197


# Tensorflow

In [9]:
# TensorFlow

# 1. Define the CNN model
class TensorFlowCNN(tf.keras.models.Model):
    def __init__(self, num_classes: int = 10, name: str = "MyCNN"):
        super().__init__(name=name)
        self.conv1 = tf.keras.layers.Conv2D(
            filters=32, kernel_size=(3, 3), activation=tf.nn.relu
        )
        self.pool1 = tf.keras.layers.MaxPooling2D(pool_size=(2, 2))
        self.conv2 = tf.keras.layers.Conv2D(
            filters=64, kernel_size=(3, 3), activation=tf.nn.relu
        )
        self.pool2 = tf.keras.layers.MaxPooling2D(pool_size=(2, 2))
        self.flatten = tf.keras.layers.Flatten()
        self.dense1 = tf.keras.layers.Dense(units=128, activation=tf.nn.relu)
        self.outputs = tf.keras.layers.Dense(
            units=num_classes, activation=tf.nn.softmax
        )

    def call(self, inputs, training: bool = False):
        x = self.conv1(inputs)
        x = self.pool1(x)
        x = self.conv2(x)
        x = self.pool2(x)
        x = self.flatten(x)
        x = self.dense1(x)
        return self.outputs(x)

# 2. Build & inspect
model = TensorFlowCNN(num_classes=10)
model.build(input_shape=(None, 28, 28, 1))
model.summary()




In [10]:

start_time = time.time()

optimizer = tf.optimizers.Adam(learning_rate=0.001)
loss_fn   = tf.losses.SparseCategoricalCrossentropy()

# add accuracy metric here
model.compile(
    optimizer=optimizer,
    loss=loss_fn,
)

history = model.fit(
    x_train, y_train,
    epochs=5,
    batch_size=64,
    validation_data=(x_test, y_test),
    verbose=1
)

print(f"Training time: {time.time() - start_time:.2f}s")
print("Model training completed.\n")

Epoch 1/5
[1m938/938[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 5ms/step - loss: 0.3446 - val_loss: 0.0490
Epoch 2/5
[1m938/938[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 4ms/step - loss: 0.0486 - val_loss: 0.0358
Epoch 3/5
[1m938/938[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 4ms/step - loss: 0.0311 - val_loss: 0.0308
Epoch 4/5
[1m938/938[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 4ms/step - loss: 0.0228 - val_loss: 0.0263
Epoch 5/5
[1m938/938[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 3ms/step - loss: 0.0168 - val_loss: 0.0283
Training time: 30.07s
Model training completed.



In [11]:


print("Model evaluation started:")

# 1. Predict class‐probabilities and pick the most likely class
y_pred_probs = model.predict(x_test, batch_size=64, verbose=1)
y_pred = np.argmax(y_pred_probs, axis=-1)

# 2. Set up tf.keras metrics
acc_metric   = tf.keras.metrics.Accuracy()
prec_metric  = tf.keras.metrics.Precision()
rec_metric   = tf.keras.metrics.Recall()

# 3. Update them with true vs. predicted labels
acc_metric.update_state(y_test, y_pred)
prec_metric.update_state(y_test, y_pred)
rec_metric.update_state(y_test, y_pred)

# 4. Extract scalar results
accuracy  = acc_metric.result().numpy()
precision = prec_metric.result().numpy()
recall    = rec_metric.result().numpy()
# avoid division by zero just in case
f1 = 2 * (precision * recall) / (precision + recall + 1e-7)

# 5. Print nicely
print(
    f"Accuracy:  {accuracy:.4f}\n"
    f"Precision: {precision:.4f}\n"
    f"Recall:    {recall:.4f}\n"
    f"F1 Score:  {f1:.4f}"
)

Model evaluation started:
[1m157/157[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 3ms/step
Accuracy:  0.9906
Precision: 0.9997
Recall:    0.9986
F1 Score:  0.9991


# PyTorch

In [12]:
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
import numpy as np

# 0. Setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 64
learning_rate = 1e-3
num_epochs = 5
num_classes = 10

# 1. Data
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

train_dataset = datasets.MNIST(
    root="./data", train=True, download=True, transform=transform
)
test_dataset = datasets.MNIST(
    root="./data", train=False, download=True, transform=transform
)

train_loader = DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True, num_workers=2
)
test_loader = DataLoader(
    test_dataset, batch_size=batch_size, shuffle=False, num_workers=2
)

# 2. Model
class CNN(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool1(x)
        x = F.relu(self.conv2(x))
        x = self.pool2(x)
        x = self.flatten(x)
        x = F.relu(self.fc1(x))
        return self.fc2(x)

model = CNN(num_classes).to(device)


100%|██████████| 9.91M/9.91M [00:00<00:00, 14.8MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 500kB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 4.59MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 9.33MB/s]


In [13]:

# 3. Loss & Optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# 4. Training loop with timing
total_start = time.perf_counter()
for epoch in range(1, num_epochs + 1):
    epoch_start = time.perf_counter()
    model.train()
    running_loss = 0.0
    for batch_idx, (images, labels) in enumerate(train_loader, 1):
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if batch_idx % 100 == 0:
            avg = running_loss / 100
            print(f"Epoch [{epoch}/{num_epochs}]  Batch [{batch_idx}]  Loss: {avg:.4f}")
            running_loss = 0.0

    epoch_time = time.perf_counter() - epoch_start
    print(f"Epoch {epoch} completed in {epoch_time:.2f} seconds.\n")

total_time = time.perf_counter() - total_start
print(f"Total training time over {num_epochs} epochs: {total_time:.2f} seconds.\n")


  self.pid = os.fork()


Epoch [1/5]  Batch [100]  Loss: 0.5212
Epoch [1/5]  Batch [200]  Loss: 0.1558
Epoch [1/5]  Batch [300]  Loss: 0.1102
Epoch [1/5]  Batch [400]  Loss: 0.0855
Epoch [1/5]  Batch [500]  Loss: 0.0901
Epoch [1/5]  Batch [600]  Loss: 0.0735
Epoch [1/5]  Batch [700]  Loss: 0.0585
Epoch [1/5]  Batch [800]  Loss: 0.0546
Epoch [1/5]  Batch [900]  Loss: 0.0665


  self.pid = os.fork()


Epoch 1 completed in 14.55 seconds.

Epoch [2/5]  Batch [100]  Loss: 0.0481
Epoch [2/5]  Batch [200]  Loss: 0.0456
Epoch [2/5]  Batch [300]  Loss: 0.0401
Epoch [2/5]  Batch [400]  Loss: 0.0429
Epoch [2/5]  Batch [500]  Loss: 0.0441
Epoch [2/5]  Batch [600]  Loss: 0.0418
Epoch [2/5]  Batch [700]  Loss: 0.0338
Epoch [2/5]  Batch [800]  Loss: 0.0447
Epoch [2/5]  Batch [900]  Loss: 0.0400
Epoch 2 completed in 14.45 seconds.

Epoch [3/5]  Batch [100]  Loss: 0.0247
Epoch [3/5]  Batch [200]  Loss: 0.0252
Epoch [3/5]  Batch [300]  Loss: 0.0260
Epoch [3/5]  Batch [400]  Loss: 0.0253
Epoch [3/5]  Batch [500]  Loss: 0.0314
Epoch [3/5]  Batch [600]  Loss: 0.0230
Epoch [3/5]  Batch [700]  Loss: 0.0283
Epoch [3/5]  Batch [800]  Loss: 0.0265
Epoch [3/5]  Batch [900]  Loss: 0.0268
Epoch 3 completed in 15.44 seconds.

Epoch [4/5]  Batch [100]  Loss: 0.0147
Epoch [4/5]  Batch [200]  Loss: 0.0206
Epoch [4/5]  Batch [300]  Loss: 0.0179
Epoch [4/5]  Batch [400]  Loss: 0.0232
Epoch [4/5]  Batch [500]  Loss:

In [14]:

# 5. Evaluation
model.eval()
all_preds = []
all_labels = []

with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(device)
        outputs = model(images)
        preds = outputs.argmax(dim=1).cpu().numpy()
        all_preds.append(preds)
        all_labels.append(labels.numpy())

all_preds = np.concatenate(all_preds)
all_labels = np.concatenate(all_labels)

accuracy  = accuracy_score(all_labels, all_preds)
precision = precision_score(all_labels, all_preds, average='macro', zero_division=0)
recall    = recall_score(all_labels, all_preds, average='macro', zero_division=0)
f1        = f1_score(all_labels, all_preds, average='macro', zero_division=0)

print("Model evaluation started:")
print(f"Accuracy:  {accuracy:.4f}")
print(f"Precision: {precision:.4f}")
print(f"Recall:    {recall:.4f}")
print(f"F1 Score:  {f1:.4f}")


Model evaluation started:
Accuracy:  0.9899
Precision: 0.9899
Recall:    0.9899
F1 Score:  0.9898
