In [None]:
from skorch.callbacks import EarlyStopping
from torchvision import datasets, transforms
import torch
import torch.optim as optim
from skorch import NeuralNetClassifier

import seaborn as sns
from itertools import islice
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

from utils.mlflow import is_mlflow_server_running, set_mlflow_tracking_uri
import mlflow.pytorch

In [None]:
# Side length (in pixels) of the square input images; reused further down to
# size the models' inputs.
IMG_SIZE = 32

# Every figure saved by this notebook goes under this directory.
FIGURES_DIR = Path("figures/cifar10/")
FIGURES_DIR.mkdir(exist_ok=True, parents=True)

# All experiment tracking goes through MLflow, so fail fast when the server
# cannot be reached instead of failing halfway through a training run.
if is_mlflow_server_running():
    set_mlflow_tracking_uri()
else:
    raise RuntimeError("MLflow server is not running. Please start the MLflow server before running this notebook.")

In [None]:
# Base (mandatory) preprocessing: convert each image to a tensor, then
# normalize the three channels with fixed statistics.
# NOTE(review): the mean/std values look like the usual CIFAR-10 training-set
# statistics — confirm.
normalize = transforms.Normalize(
    mean=(0.4914, 0.4822, 0.4465),
    std=(0.2470, 0.2435, 0.2616),
)
transform = transforms.Compose([transforms.ToTensor(), normalize])

# Both splits share the same root directory and preprocessing, and are
# downloaded on first use.
train_dataset = datasets.CIFAR10(root="../data", train=True, transform=transform, download=True)
test_dataset = datasets.CIFAR10(root="../data", train=False, transform=transform, download=True)

# Visualize some images

In [None]:
# Take the first seven (image, label) pairs of the training set and split
# them into one tuple of images and one tuple of labels.
first_samples = islice(train_dataset, 7)
X_example, y_example = zip(*first_samples)

In [None]:
from computer_vision.src.figures import plot_example

# Stack the individual example tensors into one batch before plotting them
# together with their labels and the dataset's class names.
example_batch = torch.stack(X_example)
plot_example(example_batch, y_example, train_dataset.classes, n=7);

# Distribution of classes in the training set

In [None]:
from computer_vision.src.figures import plot_label_distribution

# Target path of the class-distribution figure; kept in a variable because the
# baseline MLflow run later attaches the same file as an artifact.
figures_name = FIGURES_DIR.joinpath("class_distribution.png")

plot_label_distribution(train_dataset, figures_name)

# Training a baseline model

In [None]:
# Label arrays passed to skorch as `y` for fit/score.
# torchvision's CIFAR10 keeps its integer labels in `.targets`, and no
# target_transform is configured on these datasets, so reading that list gives
# the same values as iterating over the dataset — without decoding and
# transforming every image just to discard it.
y_train = np.array(train_dataset.targets)
y_test = np.array(test_dataset.targets)

In [None]:
# Run on the GPU when PyTorch reports CUDA as available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
from computer_vision.src.baseline import BaselineModel

# Seed PyTorch's global RNG so repeated runs start from the same random state.
torch.manual_seed(0)

# Training hyperparameters live in a dict so the very same values can be
# passed to the classifier and logged to MLflow.
params = dict(max_epochs=10, lr=0.1)

# skorch wrapper around BaselineModel. `module__*` arguments are forwarded to
# the module's constructor and `iterator_*__*` arguments to the data loaders.
baseline = NeuralNetClassifier(
    BaselineModel,
    module__input_dim=IMG_SIZE * IMG_SIZE * 3,
    device=device,
    callbacks=[EarlyStopping(patience=5)],
    iterator_train__num_workers=2,
    iterator_valid__num_workers=2,
    **params,
)

In [None]:
# Baseline training with MLflow tracking: logs the hyperparameters, per-epoch
# metrics, the loss figure, the class-distribution figure, the fitted module
# and the final test accuracy. The `with` block closes the run on exit, so no
# explicit mlflow.end_run() is needed.
with mlflow.start_run(run_name="baseline_run"):
    mlflow.set_tag("model_type", "BaselineModel")
    mlflow.log_params(params)

    baseline.fit(train_dataset, y=y_train)

    # skorch history key -> MLflow metric name. NeuralNetClassifier records
    # validation accuracy under 'valid_acc'; 'train_acc' only exists when a
    # matching scoring callback is configured, so absent keys are skipped.
    history_metrics = {
        'train_loss': 'train_loss',
        'valid_loss': 'val_loss',
        'train_acc': 'train_acc',
        'valid_acc': 'val_acc',
    }
    for epoch, row in enumerate(baseline.history):
        for history_key, metric_name in history_metrics.items():
            if history_key in row:
                mlflow.log_metric(metric_name, float(row[history_key]), step=epoch)

    # Save and log the loss curves (epochs are numbered from 1 on the x axis).
    train_loss_history = baseline.history[:, 'train_loss']
    val_loss_history = baseline.history[:, 'valid_loss']
    fig, ax = plt.subplots()
    sns.lineplot(x=range(1, len(train_loss_history) + 1), y=train_loss_history, label='Train Loss', ax=ax)
    sns.lineplot(x=range(1, len(val_loss_history) + 1), y=val_loss_history, label='Validation Loss', ax=ax)
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss")
    ax.set_title("Baseline: Training Loss over Epochs")
    baseline_fig = FIGURES_DIR / "baseline_loss.png"
    fig.savefig(baseline_fig)
    plt.show()
    mlflow.log_artifact(str(baseline_fig))

    # Best effort: also attach the class-distribution figure created earlier.
    # A failure here must not abort the run, but it should be visible.
    try:
        mlflow.log_artifact(str(figures_name))
    except Exception as exc:
        print(f"Could not log the class distribution figure: {exc!r}")

    # Log the trained module; fall back to a plain state_dict artifact when
    # mlflow.pytorch cannot serialize it.
    try:
        mlflow.pytorch.log_model(baseline.module_, artifact_path="baseline_model")
    except Exception as exc:
        print(f"mlflow.pytorch.log_model failed ({exc!r}); logging the state_dict instead")
        torch.save(baseline.module_.state_dict(), "baseline_model_state_dict.pth")
        mlflow.log_artifact("baseline_model_state_dict.pth")

    # Evaluate on the held-out test set (accuracy reported in percent).
    acc = 100 * baseline.score(test_dataset, y_test)
    print("Test set Accuracy: {:.2f}%".format(acc))
    mlflow.log_metric('test_accuracy', float(acc))

# Use BaseCNN with custom parameters

In [None]:
from computer_vision.src.BaseCNN import BaseCNN
from sklearn.model_selection import GridSearchCV
from skorch.helper import SliceDataset

# Training settings plus BaseCNN constructor arguments (skorch forwards the
# `module__*` entries to the module). Kept in one dict so the same values can
# be passed to the classifier and logged to MLflow.
params = dict(
    max_epochs=10,
    lr=0.001,
    optimizer=optim.Adam,
    callbacks=[EarlyStopping(patience=3)],
    module__img_size=IMG_SIZE,
    module__nb_conv_layers=2,
    module__nb_layers=2,
    module__net_width=256,
    module__dropout_rates=[0.25, 0.5],
)

# skorch wrapper around BaseCNN, configured from `params` above.
cnn = NeuralNetClassifier(
    BaseCNN,
    module__num_classes=10,
    device=device,
    **params,
)


In [None]:
# BaseCNN training with MLflow tracking: logs the hyperparameters, per-epoch
# metrics, the loss figure, the fitted module and the final test accuracy.
with mlflow.start_run(run_name="basecnn_run"):
    mlflow.set_tag("model_type", "BaseCNN")
    mlflow.log_params(params)

    cnn.fit(train_dataset, y_train)

    # skorch history key -> MLflow metric name. NeuralNetClassifier records
    # validation accuracy under 'valid_acc'; keys missing from a row are
    # skipped.
    history_metrics = {
        'train_loss': 'train_loss',
        'valid_loss': 'val_loss',
        'valid_acc': 'val_acc',
    }
    for epoch, row in enumerate(cnn.history):
        for history_key, metric_name in history_metrics.items():
            if history_key in row:
                mlflow.log_metric(metric_name, float(row[history_key]), step=epoch)

    # Save and log the loss curves (epochs are numbered from 1 on the x axis).
    train_loss_history = cnn.history[:, 'train_loss']
    valid_loss_history = cnn.history[:, 'valid_loss']
    fig, ax = plt.subplots()
    sns.lineplot(x=range(1, len(train_loss_history) + 1), y=train_loss_history, label='Train Loss', ax=ax)
    sns.lineplot(x=range(1, len(valid_loss_history) + 1), y=valid_loss_history, label='Validation Loss', ax=ax)
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss")
    ax.set_title("BaseCNN: Training Loss over Epochs")
    cnn_fig = FIGURES_DIR / "cnn_loss.png"
    fig.savefig(cnn_fig)
    plt.show()
    mlflow.log_artifact(str(cnn_fig))

    # Log the trained module; fall back to a plain state_dict artifact when
    # mlflow.pytorch cannot serialize it.
    try:
        mlflow.pytorch.log_model(cnn.module_, artifact_path="basecnn_model")
    except Exception as exc:
        print(f"mlflow.pytorch.log_model failed ({exc!r}); logging the state_dict instead")
        torch.save(cnn.module_.state_dict(), "basecnn_model_state_dict.pth")
        mlflow.log_artifact("basecnn_model_state_dict.pth")

    # Evaluate on the held-out test set (accuracy reported in percent).
    acc = 100 * cnn.score(test_dataset, y_test)
    print("Test set Accuracy: {:.2f}%".format(acc))
    mlflow.log_metric('test_accuracy', float(acc))

# Hyperparameter Tuning with Grid Search

In [None]:
# Base estimator handed to the grid search. skorch's internal validation split
# is switched off (train_split=False), so early stopping monitors the training
# loss rather than a validation loss.
# NOTE(review): unlike the two classifiers above, no `device` is passed here —
# confirm that this is intended.
net = NeuralNetClassifier(
    BaseCNN,
    module__num_classes=10,
    module__img_size=IMG_SIZE,
    max_epochs=10,
    lr=0.01,
    train_split=False,
    callbacks=[EarlyStopping(monitor='train_loss', patience=3)],
)

In [None]:
# Search space for the grid search: optimizer class, number of convolutional
# layers and network width.
# Not searched for now: lr, max_epochs, module__nb_layers and
# module__dropout_rates keep the values configured on the estimator.
params = dict(
    optimizer=[optim.Adam, optim.SGD],
    module__nb_conv_layers=[2, 3],
    module__net_width=[128, 256],
)

In [None]:
# Exhaustive search over `params` with 3-fold cross-validation, scored on
# accuracy; n_jobs=-1 lets scikit-learn run the fits in parallel.
grid = GridSearchCV(
    estimator=net,
    param_grid=params,
    scoring='accuracy',
    cv=3,
    n_jobs=-1,
    verbose=2,
)

In [None]:
# Wrap the torch dataset so it can be indexed and sliced like an array, which
# is how it is passed to the grid search. idx=0 selects the first element of
# each dataset item.
train_dataset_sliceable = SliceDataset(train_dataset, idx=0)

In [None]:
# Grid search with MLflow tracking: logs the searched grid, the best
# parameters and CV score, the refitted best module, the test accuracy and the
# best estimator's per-epoch losses.
with mlflow.start_run(run_name="gridsearch_run"):
    mlflow.set_tag('procedure', 'GridSearchCV')

    # Best effort: record the grid that was searched. A failure here should
    # not stop the search, but it should be visible.
    try:
        mlflow.log_param('param_grid', str(params))
    except Exception as exc:
        print(f"Could not log param_grid: {exc!r}")

    grid.fit(train_dataset_sliceable, y_train)
    mlflow.log_param('best_params', str(grid.best_params_))
    mlflow.log_metric('best_cv_accuracy', float(grid.best_score_)*100)

    # Best effort: log the best estimator's module, falling back to a plain
    # state_dict artifact; every failure is reported instead of being dropped.
    try:
        best_model = grid.best_estimator_.module_
        mlflow.pytorch.log_model(best_model, artifact_path='grid_best_model')
    except Exception as exc:
        print(f"mlflow.pytorch.log_model failed ({exc!r}); logging the state_dict instead")
        try:
            torch.save(grid.best_estimator_.module_.state_dict(), 'grid_best_model_state_dict.pth')
            mlflow.log_artifact('grid_best_model_state_dict.pth')
        except Exception as fallback_exc:
            print(f"Could not log the best model's state_dict: {fallback_exc!r}")

    # Score the test set once and reuse the value for both the printout and
    # the MLflow metric (the metric is stored in percent).
    test_accuracy = grid.score(SliceDataset(test_dataset), y_test)

    print("Best parameters found: ", grid.best_params_)
    print("Best cross-validation accuracy: ", grid.best_score_)
    print("Test set accuracy: ", test_accuracy)

    mlflow.log_metric('test_accuracy', float(test_accuracy)*100)

    # Per-epoch losses of the refitted best estimator; keys missing from a
    # history row are skipped.
    history_metrics = {'train_loss': 'train_loss', 'valid_loss': 'val_loss'}
    for epoch, row in enumerate(grid.best_estimator_.history):
        for history_key, metric_name in history_metrics.items():
            if history_key in row:
                mlflow.log_metric(metric_name, float(row[history_key]), step=epoch)