# MNIST Classification with JaxFlow
This notebook demonstrates how to build and train a Convolutional Neural Network on the MNIST dataset using JaxFlow.

In [None]:
! pip install --upgrade jaxflow

In [None]:
import jax
import jax.numpy as jnp
import jaxflow as jf
from jaxflow.models import Model
from jaxflow.layers import Conv2D, MaxPooling2D, Dense, Flatten
from jaxflow.initializers import GlorotUniform, Zeros
from jaxflow.optimizers import Adam
from jaxflow.losses import SparseCategoricalCrossentropy
from tensorflow.keras.datasets import mnist
import matplotlib.pyplot as plt

## 1. Load and preprocess MNIST
We first load the MNIST dataset and normalize pixel values to the [0, 1] range. We also add a channel dimension for compatibility with Conv2D layers.

In [None]:
# Fetch the MNIST train/test splits as (images, labels) pairs.
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Cast each image array to float32, divide by 255.0, then append a trailing
# axis (the channel dimension expected by the Conv2D layers).
x_train, x_test = (
    (images.astype(jnp.float32) / 255.0)[..., None]
    for images in (x_train, x_test)
)

# Show the shapes of all four arrays as a quick sanity check.
print(x_train.shape, y_train.shape, x_test.shape, y_test.shape)

## 2. Define the CNN model
We define a simple CNN with two convolutional blocks followed by a fully connected layer and an output layer.

In [None]:
class CNN(Model):
    """Small convolutional classifier: two conv/pool stages followed by two dense layers.

    Args:
        num_classes: Number of units in the final softmax-activated Dense layer.
        name: Name forwarded to the ``Model`` base class.
    """

    def __init__(self, num_classes: int = 10, name: str = "MyCNN"):
        super().__init__(name=name)

        # Every parameterised layer below uses the same initializers and
        # (except for the output layer) the same ReLU activation.
        init_kwargs = dict(kernel_initializer=GlorotUniform, bias_initializer=Zeros)
        relu = jf.activations.relu

        # Stage 1: 3x3 convolution with 32 filters, then 2x2 max pooling.
        self.conv1 = Conv2D(filters=32, kernel_size=(3, 3), activation=relu, **init_kwargs)
        self.pool1 = MaxPooling2D(pool_size=(2, 2))

        # Stage 2: 3x3 convolution with 64 filters, then 2x2 max pooling.
        self.conv2 = Conv2D(filters=64, kernel_size=(3, 3), activation=relu, **init_kwargs)
        self.pool2 = MaxPooling2D(pool_size=(2, 2))

        # Classifier head: flatten, 64-unit hidden layer, softmax output layer.
        self.flatten = Flatten()
        self.dense1 = Dense(units=64, activation=relu, **init_kwargs)
        self.outputs = Dense(
            units=num_classes,
            activation=jf.activations.softmax,
            **init_kwargs,
        )

    def call(self, inputs, training: bool = False):
        """Run the forward pass through all layers in order.

        Args:
            inputs: Input passed to the first convolution.
            training: Flag forwarded to every layer except ``flatten``.

        Returns:
            The result of the final softmax-activated Dense layer.
        """
        hidden = self.pool1(self.conv1(inputs, training=training), training=training)
        hidden = self.pool2(self.conv2(hidden, training=training), training=training)
        hidden = self.flatten(hidden)
        hidden = self.dense1(hidden, training=training)
        return self.outputs(hidden, training=training)

# Instantiate the network and build it for 28x28 inputs with one channel.
# The leading None is presumably the (unspecified) batch dimension — TODO confirm.
NUM_CLASSES = 10
INPUT_SHAPE = (None, 28, 28, 1)

model = CNN(num_classes=NUM_CLASSES)
model.build(input_shape=INPUT_SHAPE)

# NOTE(review): this prints the `summary` attribute itself; if `summary` is a
# method in JaxFlow it probably needs to be called instead — confirm.
print(model.summary)

## 3. Compile and train the model
We use the Adam optimizer with a learning rate of 0.001 and the sparse categorical cross-entropy loss. We train for 5 epochs with a batch size of 128, passing the test set as `validation_data`.

In [None]:
# Adam with a learning rate of 0.001, paired with sparse categorical cross-entropy.
optimizer = Adam(learning_rate=0.001)
loss_fn = SparseCategoricalCrossentropy()
model.compile(optimizer=optimizer, loss=loss_fn)

# Train for 5 epochs in batches of 128; the test split is passed as validation
# data. The value returned by fit() is kept for plotting later.
history = model.fit(
    x_train,
    y_train,
    epochs=5,
    batch_size=128,
    validation_data=(x_test, y_test),
    verbose=2,
)

## 4. Evaluate on the test set

In [None]:
# Evaluate the trained model on the test split in batches of 128 and report the result.
test_loss = model.evaluate(
    x_test,
    y_test,
    batch_size=128,
)
print(f"Test Loss: {test_loss}")

## 5. Plot training history

In [None]:
# One x position per recorded loss value, starting at epoch 1.
epochs = range(1, len(history['loss']) + 1)

# Only the loss curves are plotted, so a single axes is used: a 1x2 subplot
# grid would leave the right half of the figure blank.
fig, ax = plt.subplots(figsize=(6, 5))
ax.plot(epochs, history['loss'], label='Train Loss')
ax.plot(epochs, history['val_loss'], label='Val Loss')
ax.set_title('Loss')
# Label both axes so the figure can be read on its own.
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.legend()
plt.show()

### Conclusion
This notebook showed how to train a CNN on MNIST with JaxFlow. Feel free to experiment with different architectures, learning rates, and batch sizes to improve performance!