# MNIST Handwritten Digit Classification

This notebook demonstrates a classic classification task using the `clownpiece` library to recognize handwritten digits from the MNIST dataset.

We will:
1.  Load the MNIST dataset using `torchvision`.
2.  Preprocess the data and convert it to `clownpiece` Tensors.
3.  Build and train a simple MLP classifier.
4.  Use `CrossEntropyLoss` for training.
5.  Evaluate the model's accuracy on a test set.

In [None]:
import sys
sys.path.append('../../../')

import numpy as np
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
import torch

from clownpiece import Tensor
from clownpiece.autograd import no_grad
from clownpiece.nn import Module, Linear, ReLU, Sequential, CrossEntropyLoss

### Helper Functions

To make the code cleaner and more readable, we define two helper functions. `to_CT_tensor` converts a `torch.Tensor` to a `clownpiece.Tensor`, and `to_numpy` converts a `clownpiece.Tensor` back to a `numpy.ndarray`.

In [None]:
def to_CT_tensor(torch_tensor, requires_grad=False):
    """Converts a torch.Tensor to a clownpiece.Tensor."""
    return Tensor(torch_tensor.numpy().tolist(), requires_grad=requires_grad)

def to_numpy(clownpiece_tensor):
    """Converts a clownpiece.Tensor to a numpy.ndarray."""
    return np.array(clownpiece_tensor.tolist())

### 1. Loading and Preprocessing the Data

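We use `torchvision` to load the MNIST dataset (60,000 training images and 10,000 test images). `ToTensor()` converts each image to a tensor with values in [0, 1], and `Normalize` standardizes it with the MNIST mean (0.1307) and standard deviation (0.3081). We then use `torch.utils.data.DataLoader` to create iterators that feed the training and test sets to our model in batches of 100 and 1000, respectively.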

In [None]:
# Load MNIST dataset using torchvision
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('./data', train=False, transform=transform)

# Create data loaders
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=100, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1000, shuffle=False)

print("MNIST dataset loaded successfully.")
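
As a rough sketch of what the transform does to each pixel, here is a NumPy-only approximation (not the actual torchvision implementation): `ToTensor()` scales `uint8` values from [0, 255] into [0, 1], and `Normalize` then computes (x - mean) / std. The helper name `normalize_like_torchvision` is hypothetical, and the channel dimension that `ToTensor()` adds is omitted for simplicity.

```python
import numpy as np

# Mean and std passed to transforms.Normalize in the cell above
MNIST_MEAN = 0.1307
MNIST_STD = 0.3081

def normalize_like_torchvision(pixels_uint8):
    """Approximate ToTensor() + Normalize() for one grayscale uint8 image.

    ToTensor() rescales uint8 values in [0, 255] to floats in [0, 1];
    Normalize() then subtracts the mean and divides by the std.
    (The channel dimension that ToTensor() adds is omitted here.)
    """
    x = np.asarray(pixels_uint8, dtype=np.float32) / 255.0
    return (x - MNIST_MEAN) / MNIST_STD

img = np.zeros((28, 28), dtype=np.uint8)
img[10:18, 12:16] = 255            # a white stroke on a black background
out = normalize_like_torchvision(img)
print(out.min(), out.max())        # approximately -0.4242 and 2.8215
```

Black background pixels therefore become slightly negative and full-intensity strokes become large positive values, so the inputs are roughly zero-mean with unit variance over the dataset.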

### 2. Defining the Model Architecture

Our model is a simple Multi-Layer Perceptron (MLP) created using the `Sequential` container. It consists of three linear layers with ReLU activation functions in between. The input layer takes flattened 28x28 images (784 features), and the final layer outputs logits for the 10 digit classes.

In [None]:
# Define the model
input_features = 784  # 28x28 images flattened
num_classes = 10

model = Sequential(
    Linear(input_features, 128),
    ReLU(),
    Linear(128, 64),
    ReLU(),
    Linear(64, num_classes)
)

print("Model Architecture:")
print(model)
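
To make the shapes concrete, here is a plain NumPy sketch of the same 784-128-64-10 stack (not clownpiece code; the helper names are hypothetical). It assumes each `Linear` layer has a bias and, for simplicity, stores weights as (in_features, out_features), which may differ from clownpiece's internal layout. The parameter count does not depend on that layout.

```python
import numpy as np

# Layer sizes of the Sequential MLP defined above
layer_sizes = [784, 128, 64, 10]

def count_linear_params(sizes):
    """Weights (in * out) plus biases (out) for each Linear layer."""
    return sum(n_in * n_out + n_out for n_in, n_out in zip(sizes[:-1], sizes[1:]))

def mlp_forward(x, weights, biases):
    """Linear -> ReLU -> Linear -> ReLU -> Linear, written with plain NumPy."""
    for i, (W, b) in enumerate(zip(weights, biases)):
        x = x @ W + b
        if i < len(weights) - 1:  # no ReLU after the final layer (raw logits)
            x = np.maximum(x, 0.0)
    return x

rng = np.random.default_rng(0)
weights = [rng.normal(0, 0.01, (n_in, n_out))
           for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:])]
biases = [np.zeros(n_out) for n_out in layer_sizes[1:]]

batch = rng.normal(size=(100, 784))       # one batch of flattened images
logits = mlp_forward(batch, weights, biases)
print(logits.shape)                       # (100, 10)
print(count_linear_params(layer_sizes))   # 109386
```

Most of the parameters sit in the first layer (100,480 of 109,386), because every one of the 784 input pixels connects to each of the 128 hidden units.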

### 3. Training the Model

Here, we set up the training parameters. We use `CrossEntropyLoss` as our loss function, a fixed `learning_rate` of 0.01, and train for 3 epochs. We'll track the training loss and test accuracy during training.

In [None]:
# Loss and training parameters
loss_fn = CrossEntropyLoss()
learning_rate = 1e-2
epochs = 3
train_losses = []
test_accuracies = []
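
A minimal NumPy sketch of what a cross-entropy loss on logits computes, assuming `clownpiece`'s `CrossEntropyLoss` follows PyTorch's convention (raw logits, integer class labels, mean over the batch). `cross_entropy` is a hypothetical helper, not the library implementation.

```python
import numpy as np

def cross_entropy(logits, labels):
    """Mean cross-entropy over a batch, from raw logits and integer labels.

    Uses the log-sum-exp trick: subtracting the row-wise max before
    exponentiating avoids overflow without changing the result.
    """
    logits = np.asarray(logits, dtype=np.float64)
    labels = np.asarray(labels, dtype=np.int64)
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # Pick the log-probability of the correct class for each sample
    return -log_probs[np.arange(len(labels)), labels].mean()

# Uniform logits: every class gets probability 1/10, so the loss is log(10)
print(cross_entropy(np.zeros((4, 10)), [0, 3, 7, 9]))  # ~2.302585
```

A freshly initialized 10-class model should therefore start with a loss near log(10), roughly 2.3, which is a useful sanity check for the first logged value.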

The training loop iterates through the dataset for a specified number of epochs. In each iteration (batch), it performs the following steps:
1.  **Forward Pass**: Computes the model's predictions (logits).
2.  **Loss Calculation**: Measures the difference between predictions and actual labels.
3.  **Backward Pass**: Computes gradients of the loss with respect to model parameters.
4.  **Weight Update**: Adjusts model weights using gradient descent.
5.  **Zero Gradients**: Resets gradients for the next iteration.
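
The five steps can be traced by hand on a much smaller problem. The sketch below is not clownpiece code: it is plain NumPy softmax regression (a single linear layer) on synthetic data, with the gradient derived analytically instead of by autograd, so each numbered step appears explicitly. The function name and toy sizes are illustrative choices.

```python
import numpy as np

def softmax_regression_demo(steps=100, lr=0.1, seed=0):
    rng = np.random.default_rng(seed)
    N, D, C = 64, 20, 3                       # toy sizes: samples, features, classes
    X = rng.normal(size=(N, D))
    y = np.argmax(X @ rng.normal(size=(D, C)), axis=1)  # labels a linear model can fit
    W, b = np.zeros((D, C)), np.zeros(C)
    losses = []
    for _ in range(steps):
        # 1. Forward pass: logits
        logits = X @ W + b
        # 2. Loss: softmax cross-entropy averaged over the batch
        probs = np.exp(logits - logits.max(axis=1, keepdims=True))
        probs /= probs.sum(axis=1, keepdims=True)
        losses.append(-np.log(probs[np.arange(N), y]).mean())
        # 3. Backward pass: d(loss)/d(logits) = (softmax - one_hot(y)) / N
        d_logits = probs.copy()
        d_logits[np.arange(N), y] -= 1.0
        d_logits /= N
        grad_W, grad_b = X.T @ d_logits, d_logits.sum(axis=0)
        # 4. Weight update: plain gradient descent, param <- param - lr * grad
        W -= lr * grad_W
        b -= lr * grad_b
        # 5. Zero gradients: grad_W / grad_b are recomputed from scratch next
        #    iteration, so nothing needs resetting here. Autograd libraries such
        #    as PyTorch typically accumulate into .grad instead, which is why
        #    the training loop sets param.grad = None.
    return losses

losses = softmax_regression_demo()
print(f"loss: {losses[0]:.4f} -> {losses[-1]:.4f}")
```

With all-zero initial weights the first loss is exactly log(3), and it should fall steadily from there because the labels were generated by a linear rule.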

Every 20 batches, we evaluate the model's accuracy on the entire test set and record both the training loss and test accuracy for later visualization.

In [None]:
# Training loop
# May take up to ~10 minutes
for epoch in range(epochs):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        # Flatten the data and convert to clownpiece Tensors
        data_flat = data.view(data.shape[0], -1)
        X = to_CT_tensor(data_flat)
        y = to_CT_tensor(target, requires_grad=False)

        # Forward pass
        logits = model(X)

        # Calculate loss
        loss = loss_fn(logits, y)

        # Backward pass
        loss.backward()

        # Update weights
        with no_grad():
            for param in model.parameters():
                if param.grad is not None:
                    param.copy_(param - param.grad * learning_rate)
        
        # Zero gradients
        for param in model.parameters():
            if param.grad is not None:
                param.grad = None

        if batch_idx % 20 == 0:
            train_losses.append(loss.item())
            
            # Evaluation on test set
            model.eval()
            correct = 0
            with no_grad():
                for test_data, test_target in test_loader:
                    test_data_flat = test_data.view(test_data.shape[0], -1)
                    X_test = to_CT_tensor(test_data_flat)
                    
                    logits_test = model(X_test)
                    pred = np.argmax(to_numpy(logits_test), axis=1)
                    correct += np.sum(pred == test_target.numpy())
            
            accuracy = 100. * correct / len(test_loader.dataset)
            test_accuracies.append(accuracy)
            
            print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)}]\tLoss: {loss.item():.6f}, Accuracy: {accuracy:.2f}%', flush=True)
            model.train() # Switch back to train mode

### 4. Visualizing the Results

Finally, we plot the training loss and test accuracy over the course of training. Each point is one evaluation checkpoint, recorded every 20 training batches, so the x-axis shows the number of training iterations divided by 20 and lets us see how both metrics evolved as the model processed more data.

In [None]:
# Plotting results
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

ax1.plot(train_losses)
ax1.set_title("Training Loss")
ax1.set_xlabel("Iterations (x20)")
ax1.set_ylabel("Cross-Entropy Loss")
ax1.grid(True)

ax2.plot(test_accuracies)
ax2.set_title("Test Accuracy")
ax2.set_xlabel("Iterations (x20)")
ax2.set_ylabel("Accuracy (%)")
ax2.grid(True)

plt.show()
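
To read a point on either curve as a position in training, it helps to count the checkpoints. The sketch below assumes the settings used in this notebook (60,000 training images, `batch_size=100`, an evaluation every 20 batches, 3 epochs); `checkpoint_to_position` is a hypothetical helper.

```python
# Assumes the settings used above: 60,000 MNIST training images,
# batch_size=100, an evaluation every 20 batches, and 3 epochs.
train_size = 60_000
batch_size = 100
eval_every = 20
num_epochs = 3

batches_per_epoch = -(-train_size // batch_size)             # ceil division -> 600
checkpoints_per_epoch = -(-batches_per_epoch // eval_every)  # batch_idx 0, 20, ..., 580 -> 30

def checkpoint_to_position(k):
    """Map the k-th entry of train_losses / test_accuracies to (epoch, batch_idx)."""
    epoch, j = divmod(k, checkpoints_per_epoch)
    return epoch, j * eval_every

print(checkpoints_per_epoch * num_epochs)   # 90 points per curve
print(checkpoint_to_position(31))           # (1, 20)
```

Each recorded training loss is the loss of a single batch rather than an average, so some jitter between neighboring points is expected.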

---

# Exploratory Task


Below is another model that was the original candidate for this task: a toy Transformer model.

However, the TAs found that it failed to converge :(

If you are interested, **make a copy of this notebook**, then replace the model definition with the following code. (Perhaps you can find a change that fixes it? We are not sure whether the model structure is simply impractical.)

> This task is not graded.

```python
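# MultiheadAttention and LayerNorm are assumed to be exported by clownpiece.nn;
# they are not part of the imports at the top of this notebook.
from clownpiece.nn import MultiheadAttention, LayerNorm

# Define the model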
input_features = 784  # 28x28 images flattened
num_classes = 10

hidden_dim = 32
kernel_size = 4

class ImageEmbedding(Module):
    def __init__(self, input_features, hidden_dim, kernel_size=4):
        super().__init__()
        self.input_features = input_features
        self.patch_size = kernel_size * kernel_size
        
        assert input_features % self.patch_size == 0, "input_features must be divisible by patch_size"
        self.num_patches = self.input_features // self.patch_size
        
        self.linear = Linear(self.patch_size, hidden_dim)

    def forward(self, x):
        # x: (batch_size, input_features)
        # Reshape to (batch_size, num_patches, patch_size).
        # Note: this groups runs of patch_size consecutive pixels in row-major
        # order, not kernel_size x kernel_size squares of the image.
        patches = x.reshape((-1, self.num_patches, self.patch_size))

        # Project patches to embeddings
        return self.linear(patches)  # (batch_size, num_patches, hidden_dim)
    
class TransformerBlock(Module):
    def __init__(self, hidden_dim, num_heads, ffn_dim):
        super().__init__()
        self.attention = MultiheadAttention(hidden_dim, num_heads, True)
        
        self.mlp = Sequential(
            Linear(hidden_dim, ffn_dim),
            ReLU(),
            Linear(ffn_dim, hidden_dim)
        )

        self.layer_norm1 = LayerNorm(hidden_dim)
        self.layer_norm2 = LayerNorm(hidden_dim)

    def forward(self, x):
        x = x + self.attention(x)
        x = self.layer_norm1(x)
        x = x + self.mlp(x)
        x = self.layer_norm2(x)
        return x
    
class Reduce(Module):
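    def forward(self, x):
        # Average over the hidden dimension:
        # (batch_size, num_patches, hidden_dim) -> (batch_size, num_patches)
        return x.mean(dim=-1)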

model = Sequential(
    ImageEmbedding(input_features=input_features, hidden_dim=hidden_dim, kernel_size=kernel_size),
    ReLU(),
    TransformerBlock(hidden_dim, num_heads=4, ffn_dim=2*hidden_dim),
    ReLU(),
    Reduce(),
    Linear(input_features // (kernel_size * kernel_size), num_classes)  # in_features = num_patches = 49
)

print("Model Architecture:")
print(model)
```
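
One change worth trying: `ImageEmbedding.forward` reshapes the flattened row-major image directly to (batch_size, 49, 16), so each "patch" is a run of 16 consecutive pixels that can straddle two image rows. The NumPy sketch below instead extracts true non-overlapping 4x4 squares. It is an illustration only: `spatial_patches` is a hypothetical helper, porting it to the model assumes clownpiece Tensors support an equivalent `reshape` plus axis-permutation operation, and whether this alone fixes convergence is untested.

```python
import numpy as np

def spatial_patches(x_flat, image_size=28, kernel_size=4):
    """Split flattened square images into non-overlapping k x k patches.

    x_flat: (batch_size, image_size * image_size), row-major like data.view(B, -1)
    returns: (batch_size, num_patches, kernel_size * kernel_size)
    """
    b = x_flat.shape[0]
    n = image_size // kernel_size          # patches per side (7 for 28 / 4)
    x = x_flat.reshape(b, n, kernel_size, n, kernel_size)
    # Bring the two patch-grid axes together, then the two in-patch axes
    x = x.transpose(0, 1, 3, 2, 4)         # (b, n, n, k, k)
    return x.reshape(b, n * n, kernel_size * kernel_size)

imgs = np.arange(2 * 28 * 28, dtype=np.float32).reshape(2, 28, 28)
patches = spatial_patches(imgs.reshape(2, -1))
print(patches.shape)  # (2, 49, 16)
```

Patch `i` covers rows `4 * (i // 7)` to `4 * (i // 7) + 3` and the matching four columns, so neighboring pixels in the image stay together in the same embedding input.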