In [None]:
import json
import random
from pathlib import Path
from typing import Tuple

import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.model_selection import GridSearchCV

# Notebook-wide configuration: filesystem locations, compute device, image size.
PROJECT_ROOT = Path.cwd()
# Datasets are stored in a "data" directory next to the working directory.
DATA_ROOT = PROJECT_ROOT.parent.joinpath("data").resolve()

# Side length (pixels) used by the transforms and the model by default.
DEFAULT_IMG_SIZE = 32
# Prefer the GPU when one is visible to PyTorch, otherwise fall back to CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Pinned host memory is only requested when a CUDA device is in use.
PIN_MEMORY = DEVICE.type == "cuda"


def set_seed(seed: int = 42) -> None:
    """Seed every random number generator this notebook relies on.

    Covers Python's ``random`` module, NumPy's legacy global generator and
    PyTorch (CPU, plus all CUDA devices when CUDA is available).

    Args:
        seed: Value handed to each generator.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    # Nothing GPU-specific to seed on CPU-only machines.
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)


# Seed all RNGs once at start-up (default seed 42) so a fresh
# "Restart & Run All" starts from the same random state.
set_seed()
# Report the selected compute device and the resolved dataset directory.
print(f"Initialized notebook on {DEVICE} | data root: {DATA_ROOT}")

In [None]:
def build_transforms(img_size: int = DEFAULT_IMG_SIZE) -> transforms.Compose:
    """Build the Fashion-MNIST preprocessing pipeline.

    The pipeline resizes with ``transforms.Resize(img_size)``, converts the
    image to a tensor, and normalizes the single channel with mean 0.5 and
    standard deviation 0.5.

    Args:
        img_size: Target size passed to ``transforms.Resize``.

    Returns:
        The composed transform, applied in the order resize -> tensor -> normalize.
    """
    resize = transforms.Resize(img_size)
    to_tensor = transforms.ToTensor()
    normalize = transforms.Normalize((0.5,), (0.5,))
    return transforms.Compose([resize, to_tensor, normalize])


def get_train_loader(
    batch_size: int = 256,
    img_size: int = DEFAULT_IMG_SIZE,
    num_workers: int = 2,
    shuffle: bool = True,
) -> DataLoader:
    """Create a ``DataLoader`` over the Fashion-MNIST training split.

    The dataset is rooted at ``DATA_ROOT`` and requested with
    ``download=True``; images go through ``build_transforms(img_size)``.

    Args:
        batch_size: Number of samples per batch.
        img_size: Image size forwarded to ``build_transforms``.
        num_workers: Number of loader worker processes.
        shuffle: Whether the loader shuffles the samples.

    Returns:
        A loader that keeps the final partial batch (``drop_last=False``) and
        uses pinned memory according to ``PIN_MEMORY``.
    """
    dataset_options = dict(
        root=DATA_ROOT,
        train=True,
        download=True,
        transform=build_transforms(img_size),
    )
    train_split = datasets.FashionMNIST(**dataset_options)

    loader_options = dict(
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=PIN_MEMORY,
        drop_last=False,
    )
    return DataLoader(train_split, **loader_options)


def dataloader_to_numpy(loader: DataLoader, max_batches: int | None = None) -> Tuple[np.ndarray, np.ndarray]:
    xs, ys = [], []
    for idx, (xb, yb) in enumerate(loader):
        xs.append(xb.cpu().numpy())
        ys.append(yb.cpu().numpy())
        if max_batches is not None and idx + 1 >= max_batches:
            break
    return np.concatenate(xs, axis=0), np.concatenate(ys, axis=0)

In [None]:
class LeNet5(nn.Module):
    """LeNet-5 style CNN: two conv + average-pool stages, then three linear layers.

    ReLU follows each convolution and each of the first two linear layers;
    the last linear layer returns raw class logits.

    Args:
        input_channels: Number of channels in the input images.
        num_classes: Length of the output logit vector.
        img_size: Side length of the square probe image used to size the
            first fully connected layer.
    """

    def __init__(self, input_channels: int = 1, num_classes: int = 10, img_size: int = DEFAULT_IMG_SIZE):
        super().__init__()
        self.img_size = img_size

        # Convolutional feature extractor (5x5 kernels, no padding, stride 1).
        self.conv1 = nn.Conv2d(in_channels=input_channels, out_channels=6, kernel_size=5, stride=1, padding=0)
        self.pool = nn.AvgPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1, padding=0)

        # Push one all-zero image through the extractor to discover how many
        # features reach the classifier head for this img_size.
        with torch.no_grad():
            probe = torch.zeros(1, input_channels, img_size, img_size)
            flatten_dim = self._forward_features(probe).flatten(1).size(1)

        # Classifier head: flatten_dim -> 120 -> 84 -> num_classes.
        self.fc1 = nn.Linear(flatten_dim, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)

    def _forward_features(self, x: torch.Tensor) -> torch.Tensor:
        """Run both conv -> ReLU -> average-pool stages and return the feature maps."""
        for conv in (self.conv1, self.conv2):
            x = self.pool(torch.relu(conv(x)))
        return x

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a batch of images to class logits (no softmax applied)."""
        hidden = self._forward_features(x).flatten(start_dim=1)
        for layer in (self.fc1, self.fc2):
            hidden = torch.relu(layer(hidden))
        return self.fc3(hidden)

In [None]:
class TorchClassifier(BaseEstimator, ClassifierMixin):
    """Minimal sklearn wrapper used purely for hyper-parameter search.

    Every call to ``fit`` re-seeds the global RNGs, builds a fresh
    ``model_class`` instance and trains it with cross-entropy loss on
    shuffled mini-batches.

    Targets are passed to ``nn.CrossEntropyLoss`` unchanged and ``predict``
    returns the argmax index of the logits, so ``y`` is expected to contain
    integer class indices in ``[0, num_classes)``.

    Args:
        model_class: Callable that returns an ``nn.Module`` when invoked with
            ``**model_kwargs``.
        model_kwargs: Keyword arguments for ``model_class``. ``None`` selects
            the LeNet-5 defaults (1 input channel, 10 classes,
            ``DEFAULT_IMG_SIZE``); any dict, including an empty one, is used
            exactly as given.
        lr: Optimizer learning rate.
        batch_size: Mini-batch size for both training and inference.
        epochs: Number of passes over the training data.
        optimizer_name: ``"adam"`` or ``"sgd"`` (case-insensitive).
        weight_decay: Weight decay passed to the optimizer.
        momentum: Momentum, used by SGD only.
        device: Device string; ``None`` falls back to the notebook ``DEVICE``.
        seed: Seed applied at the start of every ``fit``.
        verbose: Truthy values print the mean loss after each epoch.
    """

    def __init__(
        self,
        model_class=LeNet5,
        model_kwargs=None,
        lr: float = 1e-3,
        batch_size: int = 64,
        epochs: int = 5,
        optimizer_name: str = "adam",
        weight_decay: float = 0.0,
        momentum: float = 0.9,
        device: str | None = None,
        seed: int = 42,
        verbose: int = 0,
    ):
        self.model_class = model_class
        # Test against None rather than truthiness: sklearn's clone() requires
        # the stored parameter to be the very object that was passed in, and
        # an explicit {} must not be replaced by LeNet-specific defaults.
        self.model_kwargs = self._default_model_kwargs() if model_kwargs is None else model_kwargs
        self.lr = lr
        self.batch_size = batch_size
        self.epochs = epochs
        self.optimizer_name = optimizer_name
        self.weight_decay = weight_decay
        self.momentum = momentum
        self.device = device
        self.seed = seed
        self.verbose = verbose
        # Populated by fit(); kept as None beforehand so callers can test them.
        self.model_ = None
        self.classes_ = None

    @staticmethod
    def _default_model_kwargs() -> dict:
        """Return a fresh copy of the LeNet-5 constructor defaults."""
        return {
            "input_channels": 1,
            "num_classes": 10,
            "img_size": DEFAULT_IMG_SIZE,
        }

    def _resolve_device(self):
        """Return the configured device, or the notebook-wide ``DEVICE`` when unset."""
        if self.device is not None:
            return torch.device(self.device)
        return DEVICE

    def _set_seed(self):
        """Seed all global RNGs with ``self.seed``."""
        set_seed(self.seed)

    def _make_optimizer(self, params):
        """Build the optimizer named by ``optimizer_name`` for ``params``.

        Raises:
            ValueError: If the name is neither ``"adam"`` nor ``"sgd"``.
        """
        name = self.optimizer_name.lower()
        if name == "adam":
            return torch.optim.Adam(params, lr=self.lr, weight_decay=self.weight_decay)
        if name == "sgd":
            return torch.optim.SGD(params, lr=self.lr, momentum=self.momentum, weight_decay=self.weight_decay)
        raise ValueError(f"Unknown optimizer {self.optimizer_name}")

    def fit(self, X, y):
        """Train a freshly initialized model on ``(X, y)`` and return ``self``.

        Args:
            X: Array-like of inputs, converted to a float32 tensor.
            y: Array-like of integer class indices, converted to a long tensor.
        """
        self._set_seed()
        device = self._resolve_device()
        X_t = torch.as_tensor(X, dtype=torch.float32)
        y_t = torch.as_tensor(y, dtype=torch.long)
        dataset = torch.utils.data.TensorDataset(X_t, y_t)
        loader = DataLoader(dataset, batch_size=self.batch_size, shuffle=True)

        # set_params(model_kwargs=None) bypasses __init__, so resolve the
        # default here as well instead of crashing on ``**None``.
        model_kwargs = self._default_model_kwargs() if self.model_kwargs is None else self.model_kwargs
        self.model_ = self.model_class(**model_kwargs).to(device)
        # Make training mode explicit for layers such as dropout or batch norm.
        self.model_.train()

        criterion = nn.CrossEntropyLoss()
        optimizer = self._make_optimizer(self.model_.parameters())
        self.classes_ = np.unique(np.asarray(y))
        for epoch in range(self.epochs):
            epoch_loss = 0.0
            total = 0
            for xb, yb in loader:
                xb = xb.to(device)
                yb = yb.to(device)
                optimizer.zero_grad(set_to_none=True)
                logits = self.model_(xb)
                loss = criterion(logits, yb)
                loss.backward()
                optimizer.step()
                # Weight by batch size so the epoch figure is a per-sample mean.
                epoch_loss += float(loss.item()) * xb.size(0)
                total += xb.size(0)
            if self.verbose:
                # max(..., 1) avoids ZeroDivisionError when X is empty.
                mean_loss = epoch_loss / max(total, 1)
                print(f"[TorchClassifier] epoch {epoch + 1}/{self.epochs} loss={mean_loss:.4f}")
        return self

    def _predict_batches(self, X, reduce_logits, caller: str):
        """Run batched, gradient-free inference and concatenate the results.

        Args:
            X: Array-like of inputs, converted to a float32 tensor.
            reduce_logits: Function mapping a batch of logits to the tensor
                that should be collected (for example argmax or softmax).
            caller: Public method name, used in the not-fitted error message.

        Raises:
            RuntimeError: If ``fit`` has not been called yet.
        """
        if self.model_ is None:
            raise RuntimeError(f"Call fit before {caller}.")
        device = self._resolve_device()
        X_t = torch.as_tensor(X, dtype=torch.float32)
        dataset = torch.utils.data.TensorDataset(X_t)
        loader = DataLoader(dataset, batch_size=self.batch_size, shuffle=False)
        # Inference must run in eval mode so that dropout / batch-norm layers
        # in an arbitrary model_class behave deterministically.
        self.model_.eval()
        outputs = []
        with torch.no_grad():
            for (xb,) in loader:
                logits = self.model_(xb.to(device))
                outputs.append(reduce_logits(logits).cpu().numpy())
        return np.concatenate(outputs, axis=0)

    def predict(self, X):
        """Return the argmax class index for every row of ``X``."""
        return self._predict_batches(X, lambda logits: torch.argmax(logits, dim=1), "predict")

    def predict_proba(self, X):
        """Return softmax class probabilities for every row of ``X``."""
        return self._predict_batches(X, lambda logits: torch.softmax(logits, dim=1), "predict_proba")

    def score(self, X, y):
        """Return plain accuracy of ``predict(X)`` against ``y`` as a float."""
        preds = self.predict(X)
        return float((preds == np.asarray(y)).mean())


# Search space: 4 learning rates x 4 batch sizes x 3 epoch counts
# x 2 optimizers x 4 weight decays = 384 candidate configurations.
param_grid = {
    "lr": [1e-2, 5e-3, 1e-3, 5e-4],
    "batch_size": [32, 64, 128, 256],
    "epochs": [5, 10, 15],
    "optimizer_name": ["adam", "sgd"],
    "weight_decay": [0.0, 1e-4, 5e-4, 1e-3],
}

# Tunables for the search itself, named so they are easy to find and adjust.
SEARCH_LOADER_BATCH_SIZE = 256  # batch size used only to draw the subset
SEARCH_MAX_BATCHES = 120        # subset size = at most 120 * 256 = 30,720 images
SEARCH_CV_FOLDS = 3
SEARCH_N_JOBS = 4

# The subset below is drawn from a shuffled loader, so it depends on the
# global RNG state. Re-seed first so that re-running this cell on its own
# selects the same images as a fresh "Restart & Run All".
set_seed()
train_loader = get_train_loader(batch_size=SEARCH_LOADER_BATCH_SIZE, img_size=DEFAULT_IMG_SIZE)
X_train_sub, y_train_sub = dataloader_to_numpy(train_loader, max_batches=SEARCH_MAX_BATCHES)

lenet_estimator = TorchClassifier(device=str(DEVICE))
lenet_grid = GridSearchCV(
    lenet_estimator,
    param_grid=param_grid,
    cv=SEARCH_CV_FOLDS,
    n_jobs=SEARCH_N_JOBS,
    verbose=2,
)

# Exhaustive search: every candidate is fitted once per fold.
lenet_grid.fit(X_train_sub, y_train_sub)
print("Best params:", lenet_grid.best_params_)
print("Best CV score:", lenet_grid.best_score_)

# Save the outcome of the grid search so later sessions can reuse it:
# the model weights, a JSON summary, and the fitted sklearn wrapper.
ARTIFACTS_DIR = PROJECT_ROOT.joinpath("artifacts")
ARTIFACTS_DIR.mkdir(parents=True, exist_ok=True)

# 1) Weights only (state_dict) of the best estimator's underlying network.
best_estimator = lenet_grid.best_estimator_
best_model = best_estimator.model_
weights_path = ARTIFACTS_DIR.joinpath("lenet5_grid_best_weights.pt")
torch.save(best_model.state_dict(), weights_path)

# 2) Human-readable summary of the winning configuration.
metadata = dict(
    best_params=lenet_grid.best_params_,
    cv_score=float(lenet_grid.best_score_),
    model_kwargs=best_estimator.model_kwargs,
)
params_path = ARTIFACTS_DIR.joinpath("lenet5_grid_best_params.json")
with open(params_path, "w", encoding="utf-8") as fh:
    fh.write(json.dumps(metadata, indent=2))

# 3) The whole wrapper object. NOTE(review): torch.save pickles the object,
# so loading it requires the TorchClassifier / LeNet5 definitions to be
# importable and should only be done with files from a trusted source.
estimator_path = ARTIFACTS_DIR.joinpath("lenet5_grid_best_estimator.pt")
torch.save(best_estimator, estimator_path)

print(f"Saved weights to {weights_path}")
print(f"Saved params to {params_path}")
print(f"Saved sklearn wrapper to {estimator_path}")