## Imports

In [None]:
import sys
sys.path.append("..")
from src import preprocess_mnist
from src import NNModel
from src import train_model
from src import plot_training_curves
from src import detect_convergence, plot_convergence
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import os

## Displaying Non-Flattened MNIST Images

In [None]:
# Build the MNIST loaders with flattening disabled so each sample can be drawn as an image.
train_loader, val_loader, test_loader = preprocess_mnist(flatten=False)

# Take one training batch and preview its first eight samples in a single row.
images, labels = next(iter(train_loader))
fig, axes = plt.subplots(nrows=1, ncols=8, figsize=(12, 2))
for i, ax in enumerate(axes):
    # squeeze() removes size-1 dimensions before plotting
    # (presumably a channel axis -- confirm against preprocess_mnist).
    ax.imshow(images[i].squeeze(), cmap='gray')
    # Use the sample's label as the panel title.
    ax.set_title(str(labels[i].item()))
    ax.axis('off')
plt.show()


## Displaying Flattened MNIST Data Information

In [None]:
# Recreate the loaders for training: batches of 64, no augmentation, flattened inputs.
# These replace the loaders created for the image preview above.
train_loader, val_loader, test_loader = preprocess_mnist(
    batch_size=64,
    augment=False,
    flatten=True,
)

# Fetch a single training batch and report its tensor shapes and first ten labels.
first_batch = next(iter(train_loader))
images, labels = first_batch
print(f"Images batch shape: {images.shape}")
print(f"Labels batch shape: {labels.shape}")
print(f"Example labels: {labels[:10]}")

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Instantiate the network on the chosen device. Module.apply() calls the
# given function on every submodule and on the model itself, so this runs
# the model's own _init_weights hook across the whole network.
model = NNModel().to(device)
model.apply(model._init_weights)

# Cross-entropy loss, optimised with plain SGD at a fixed learning rate.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Number of epochs passed to train_model.
epochs = 50

# Path handed to train_model for checkpoints; create its parent folder up
# front (exist_ok=True makes this a no-op when the folder is already there).
checkpoint_path = "./checkpoints/mnist.pth"
checkpoint_dir = os.path.dirname(checkpoint_path)
os.makedirs(checkpoint_dir, exist_ok=True)

# Train on the loaders created in the previous cell; the returned history
# is indexed below by "train_loss_mean" and "val_loss_mean".
history = train_model(
    model,
    train_loader,
    val_loader,
    criterion,
    optimizer,
    epochs=epochs,
    device=device,
    checkpoint_path=checkpoint_path,
)

# Plot the training curves, then the convergence view. The convergence
# epoch is derived from the validation-loss series only.
plot_training_curves(history)
conv_epoch = detect_convergence(history["val_loss_mean"])
plot_convergence(
    history["train_loss_mean"],
    history["val_loss_mean"],
    conv_epoch,
)
