In [None]:
from torchvision import transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
from datetime import datetime
from tqdm import tqdm
import torch.nn as nn
import numpy as np
import os
import torch
import mlflow
import random

In [None]:
# conf.
# Experiment identity and tracking-server settings read by the MLflow run.
EXPERIMENT_NAME = "MNIST with CNN v.4"
TRACKING_URL = "http://localhost:5500"
# Timestamp is taken once, at the moment this cell is executed.
RUN_NAME = "run_" + datetime.now().strftime("%Y%m%d_%H%M%S")
RUN_DESCRIPTION = (
    "Runs training with the top 5 parameter combinations based on 1-epoch "
    "validation accuracy (dropout, batch size, learning rate, gamma, step, "
    "weight decay)."
)

In [None]:
def seed_everything(seed: int) -> None:
    """
    Seed every random number generator this notebook relies on.

    Args:
        seed: Value handed to Python's `random`, NumPy's global RNG and
            PyTorch (CPU, plus the CUDA devices when CUDA is available).
    """
    # Host-side generators are always seeded, in this fixed order.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if not torch.cuda.is_available():
        return

    # Seed the current CUDA device, then every device (multi-GPU setups).
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

In [None]:
class ConvNeuralNet(nn.Module):
    """
    A convolutional neural network for single-channel image classification.

    Architecture:
    - Two convolutional blocks, each Conv-ReLU-Conv-ReLU-MaxPool-Dropout.
    - Fully connected classifier with two hidden layers (512 and 128 units),
      each followed by ReLU and dropout.
    - Outputs 10 raw class logits (e.g. for CrossEntropyLoss).

    The convolutions preserve spatial size and each block halves it, so a
    28x28 input (e.g. MNIST) reaches the classifier as 64 x 7 x 7 features,
    which is what its first linear layer expects.

    Args:
        dropout1: Dropout probability applied before the final linear layer.
        dropout2: Dropout probability applied at the end of each
            convolutional block.

    Note:
        Only layers used by `forward` are registered. Checkpoints written by
        an earlier revision that also registered an unused `classifier1` head
        contain extra `classifier1.*` keys; load those with
        `load_state_dict(..., strict=False)` or drop the keys first.
    """

    def __init__(self, dropout1=0.5, dropout2=0.25):
        super().__init__()
        self.conv_block1 = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.Dropout(p=dropout2),
        )

        self.conv_block2 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Dropout(p=dropout2),
        )

        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * 7 * 7, 512),
            nn.ReLU(),
            # NOTE(review): this rate is fixed at 0.5 and is not tied to
            # `dropout1` - confirm that is intended before tuning dropout1.
            nn.Dropout(0.5),
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Dropout(p=dropout1),
            nn.Linear(128, 10),
        )

    def forward(self, x):
        """Run both conv blocks and the classifier; returns (batch, 10) logits."""
        features = self.conv_block2(self.conv_block1(x))
        return self.classifier(features)

In [None]:
def select_device():
    """
    Pick the device name to run on, preferring accelerators over the CPU.

    Returns:
        "cuda" when CUDA is available, otherwise "mps" when the MPS backend
        is available, otherwise "cpu".
    """
    if torch.cuda.is_available():
        return "cuda"
    # The MPS backend is only probed once CUDA has been ruled out.
    return "mps" if torch.backends.mps.is_available() else "cpu"

In [None]:
def save_and_log_model_checkpoint(model, path, artifact_dir):
    """
    Save the model's state_dict to `path` and log that file as an MLflow artifact.

    Args:
        model: Object exposing `state_dict()` (e.g. an `nn.Module`).
        path: Destination file for the serialized weights. Missing parent
            directories are created.
        artifact_dir: Passed to `mlflow.log_artifact` as `artifact_path`.
    """
    # A bare filename has an empty dirname and os.makedirs("") raises
    # FileNotFoundError, so only create the parent when there is one.
    parent_dir = os.path.dirname(path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    torch.save(model.state_dict(), path)
    mlflow.log_artifact(path, artifact_path=artifact_dir)

In [None]:
def run_training_pipeline(params=None):
    """
    Train and validate ConvNeuralNet on MNIST, tracking the run in MLflow.

    Merges `params` over the built-in defaults, builds the model, Adam
    optimizer, StepLR scheduler and data loaders, then trains for up to
    `epochs` epochs inside a single MLflow run. The MNIST `train=False` split
    is used as the validation set. Per-epoch loss and accuracy are logged to
    MLflow. The model's state_dict is saved and logged:

    - whenever the validation loss improves ("best_model"),
    - every third epoch once validation accuracy reaches
      `min_save_accuracy` ("checkpoints"),
    - when training ends, either on the final epoch or through early
      stopping ("last_model").

    Early stopping triggers after `patience` consecutive epochs without an
    improvement in validation loss.

    Args:
        params: Optional dict overriding any of the defaults: "epochs",
            "scheduler_step", "scheduler_gamma", "learning_rate",
            "weight_decay", "batch_size", "num_workers", "dropout1",
            "dropout2", "patience", "min_save_accuracy".

    Returns:
        None.
    """

    default_params = {
        "epochs": 100,
        "scheduler_step": 1,
        "scheduler_gamma": 0.9,
        "learning_rate": 0.001,
        "weight_decay": 0.00001,
        "batch_size": 128,
        "num_workers": 4,
        "dropout1": 0.5,
        "dropout2": 0.25,
        "patience": 5,
        "min_save_accuracy": 99.5,  # min validation accuracy to save the model
    }

    # update with custom parameters
    if params is not None:
        default_params.update(params)

    # extract parameters
    epochs = default_params["epochs"]
    scheduler_step = default_params["scheduler_step"]
    scheduler_gamma = default_params["scheduler_gamma"]
    learning_rate = default_params["learning_rate"]
    weight_decay = default_params["weight_decay"]
    batch_size = default_params["batch_size"]
    num_workers = default_params["num_workers"]
    dropout1 = default_params["dropout1"]
    dropout2 = default_params["dropout2"]
    patience = default_params["patience"]
    min_save_accuracy = default_params["min_save_accuracy"]

    # set up model, optimizer, scheduler and data transforms
    device = select_device()
    # Pass the configured dropout rates through: building the model with its
    # own defaults would silently ignore the "dropout1"/"dropout2" values
    # that are logged as this run's parameters.
    model = ConvNeuralNet(dropout1=dropout1, dropout2=dropout2).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer=optimizer, step_size=scheduler_step, gamma=scheduler_gamma)

    # augmentation is applied to the training images only
    transform = transforms.Compose([
        transforms.RandomRotation(12),
        transforms.RandomResizedCrop(28, scale=(0.9, 1.0)),
        transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
        transforms.ToTensor(),
    ])

    # load the MNIST dataset and prepare DataLoaders for training and validation
    train_dataset = MNIST(root=".", train=True, transform=transform, download=True)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    val_dataset = MNIST(root=".", train=False, transform=transforms.ToTensor(), download=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, num_workers=num_workers)

    # set up and start the MLflow experiment run
    mlflow.set_tracking_uri(TRACKING_URL)
    mlflow.set_experiment(EXPERIMENT_NAME)

    # Local directory for this run's checkpoint files.
    # NOTE(review): RUN_NAME is a module-level constant, so repeated calls in
    # the same session write to the same directory and overwrite each other's
    # local files; the copies logged to MLflow are stored per run.
    run_dir = f"ml_models/{EXPERIMENT_NAME}/{RUN_NAME}"

    with mlflow.start_run(run_name=RUN_NAME):
        run_id = mlflow.active_run().info.run_id
        client = mlflow.tracking.MlflowClient()
        client.set_tag(run_id, "mlflow.note.content", RUN_DESCRIPTION)

        # log the parameters
        all_params = {
            "model_summary": str(model),
            "criterion": "CrossEntropyLoss",
            "optimizer": "ADAM",
            "epochs": epochs,
            "scheduler_step": scheduler_step,
            "scheduler_gamma": scheduler_gamma,
            "learning_rate": learning_rate,
            "weight_decay": weight_decay,
            "batch_size": batch_size,
            "data_loader_workers": num_workers,
            "dropout1": dropout1,
            "dropout2": dropout2,
            "patience": patience,
            "min_save_accuracy": min_save_accuracy,
        }
        mlflow.log_params(all_params)

        # initialize early stopping counter and best validation loss
        early_stop_counter, best_val_loss = 0, float("inf")

        for epoch in range(1, epochs + 1):
            total_train_loss, total_train_correct, total_train_samples = 0.0, 0, 0
            total_val_loss, total_val_correct, total_val_samples = 0.0, 0, 0

            model.train()
            for imgs, labels in tqdm(train_loader, desc=f"Epoch {epoch}/{epochs}"):
                imgs, labels = imgs.to(device), labels.to(device)

                optimizer.zero_grad()
                outputs = model(imgs)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                # The criterion returns the batch mean; weight it by the batch
                # size so a smaller final batch is not over-represented in the
                # epoch average.
                n_samples = labels.size(0)
                total_train_loss += loss.item() * n_samples
                preds = outputs.argmax(1)
                total_train_correct += (preds == labels).sum().item()
                total_train_samples += n_samples

            avg_train_loss = total_train_loss / total_train_samples
            train_accuracy = 100 * total_train_correct / total_train_samples

            model.eval()
            with torch.no_grad():
                for imgs, labels in val_loader:
                    imgs, labels = imgs.to(device), labels.to(device)
                    outputs = model(imgs)
                    loss = criterion(outputs, labels)

                    n_samples = labels.size(0)
                    total_val_loss += loss.item() * n_samples
                    preds = outputs.argmax(1)
                    total_val_correct += (preds == labels).sum().item()
                    total_val_samples += n_samples

            avg_val_loss = total_val_loss / total_val_samples
            val_accuracy = 100 * total_val_correct / total_val_samples

            print(f"Epoch {epoch} | Train Loss: {avg_train_loss:.2f} | Train Acc: {train_accuracy:.2f}% | "
                f"Val Loss: {avg_val_loss:.2f} | Val Acc: {val_accuracy:.2f}%")

            # log to MLflow
            mlflow.log_metrics(
                {
                    "loss/train": avg_train_loss,
                    "loss/validate": avg_val_loss,
                    "accuracy/train": train_accuracy,
                    "accuracy/validate": val_accuracy,
                },
                step=epoch,
            )
            scheduler.step()

            improved = avg_val_loss < best_val_loss
            if improved:
                best_val_loss = avg_val_loss
                early_stop_counter = 0
                best_model_path = f"{run_dir}/best_model.pt"
                save_and_log_model_checkpoint(model, best_model_path, artifact_dir="best_model")
            else:
                early_stop_counter += 1

            # Evaluate the stop condition only after this epoch has been
            # counted; checking the stale counter would delay stopping by one
            # epoch and contradict the message printed below.
            should_early_stop = not improved and early_stop_counter >= patience
            should_save_checkpoint = val_accuracy >= min_save_accuracy and epoch % 3 == 0
            is_final_epoch = epoch == epochs

            if should_save_checkpoint:
                checkpoint_path = f"{run_dir}/epoch_{epoch}.pt"
                save_and_log_model_checkpoint(model, checkpoint_path, artifact_dir="checkpoints")

            if should_early_stop:
                last_model_path = f"{run_dir}/epoch_{epoch}.pt"
                save_and_log_model_checkpoint(model, last_model_path, artifact_dir="last_model")
                print(f"Early stopping triggered after {epoch} epochs! No improvement for {patience} epochs.")
                break

            if is_final_epoch:
                last_model_path = f"{run_dir}/epoch_{epoch}.pt"
                save_and_log_model_checkpoint(model, last_model_path, artifact_dir="last_model")

In [None]:
# main
banner = "=" * 50
print(banner)
print(EXPERIMENT_NAME.upper(), "RUN NAME: ", RUN_NAME.upper())
print(banner)
seed_everything(42)

# Top 5 configs from 1-epoch grid search.
# All of them share the settings below and differ only in the scheduler
# gamma, scheduler step and weight decay listed in `config_variants`.
shared_config = {
    "epochs": 100,
    "dropout1": 0.5,
    "dropout2": 0.4,
    "batch_size": 128,
    "learning_rate": 0.001,
}
config_variants = [
    # (scheduler_gamma, scheduler_step, weight_decay)
    (0.95, 2, 0.0001),
    (0.8, 2, 0.0001),
    (0.95, 1, 1e-5),
    (0.8, 1, 1e-5),
    (0.95, 2, 1e-5),
]
top_param_grid = [
    {
        **shared_config,
        "scheduler_gamma": gamma,
        "scheduler_step": step,
        "weight_decay": decay,
    }
    for gamma, step, decay in config_variants
]

# Train each configuration in turn, announcing it first.
total_configs = len(top_param_grid)
for i, params in enumerate(top_param_grid):
    print(f"\n[RUN] Training top config {i+1}/{total_configs}: {params}")
    run_training_pipeline(params)