### Imports

In [1]:
import sys
import torch
import matplotlib.pyplot as plt
import numpy as np
sys.path.append("../src")  # Add src/ directory to the Python path

from data_loader import load_kmnist

In [4]:
import torch
import torchvision
import torchvision.transforms as transforms
def load_kmnist(root="../data", train=True, download=True):
    """Load one KMNIST split as a torchvision dataset.

    Images go through ``ToTensor`` (uint8 pixels -> floats in [0, 1]) and then
    ``Normalize(0.5, 0.5)``, i.e. ``(x - 0.5) / 0.5``, which maps them to [-1, 1].

    NOTE(review): this cell redefines the ``load_kmnist`` imported from
    ``data_loader`` in the first cell and silently shadows it. Keep a single
    definition so results do not depend on which cell ran last.

    Args:
        root (str): Directory the dataset is read from / downloaded into.
        train (bool): ``True`` for the training split, ``False`` for the test split.
        download (bool): Fetch the files into ``root`` if they are not present.

    Returns:
        torchvision.datasets.KMNIST: The requested split with the transform attached.
    """
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
    )
    return torchvision.datasets.KMNIST(
        root=root, train=train, transform=transform, download=download
    )

In [5]:
train_data = load_kmnist()  # KMNIST training split as a torchvision Dataset (60,000 samples)

In [6]:
# Show the dataset summary: sample count, root directory, split and attached transforms.
train_data

Dataset KMNIST
    Number of datapoints: 60000
    Root location: ../data
    Split: Train
    StandardTransform
Transform: Compose(
               ToTensor()
               Normalize(mean=(0.5,), std=(0.5,))
           )

In [7]:
# Candidate values per optimizer. Each inner key becomes a keyword argument for
# `get_optimizer`, and every combination of the listed values is tried.
hyperparameter_grid = dict(
    adamw=dict(lr=[1e-4, 1e-3, 1e-2], weight_decay=[1e-4, 1e-3]),
    adam=dict(lr=[1e-4, 1e-3, 1e-2], weight_decay=[1e-4, 1e-3]),
    rmsprop=dict(lr=[1e-4, 1e-3, 1e-2], momentum=[0.8, 0.9], alpha=[0.9, 0.99]),
)


In [8]:
from sklearn.model_selection import KFold

def get_k_folds(data, k=5, shuffle=True, random_state=42):
    """
    Split dataset indices into ``k`` train/validation folds.

    Only ``len(data)`` is read, so any sized container works (a list of
    tensors, a PyTorch ``Dataset``, ...). The yielded index arrays can be passed
    straight to ``torch.utils.data.Subset``.

    Args:
        data (Sized): The dataset; only its length is used.
        k (int): Number of folds. KFold requires 2 <= k <= len(data).
        shuffle (bool): Shuffle the indices before cutting them into folds.
        random_state (int | None): Seed for the shuffle so folds are
            reproducible. Ignored when ``shuffle`` is False, because KFold
            raises if a seed is given without shuffling.

    Yields:
        tuple[np.ndarray, np.ndarray]: ``(train_indices, val_indices)`` per fold.
    """
    kf = KFold(
        n_splits=k,
        shuffle=shuffle,
        random_state=random_state if shuffle else None,
    )
    yield from kf.split(np.arange(len(data)))


In [9]:
from train import train_model
from model import KMNISTModel
from model import get_optimizer
import torch
import torch.nn as nn
import itertools
import torch.optim as optim

# Prefer the GPU when PyTorch can see one; otherwise run everything on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Multi-class classification loss shared by every tuning run.
criterion = nn.CrossEntropyLoss()

def _expand_grid(search_space):
    """Return every combination of ``search_space`` values as a list of kwargs dicts."""
    # Explicit key list keeps key order stable and makes an empty space yield [{}]
    # (one run with no extra kwargs) instead of crashing on `zip(*...)` unpacking.
    keys = list(search_space)
    return [
        dict(zip(keys, values))
        for values in itertools.product(*(search_space[key] for key in keys))
    ]


def _cross_validate(optimizer_name, params, train_data, k, epochs, batch_size):
    """Train a fresh model on each of the ``k`` folds and return the mean fold score."""
    fold_scores = []
    for fold, (train_idx, val_idx) in enumerate(get_k_folds(train_data, k)):
        print(f"Fold {fold+1}/{k}")
        train_loader = torch.utils.data.DataLoader(
            torch.utils.data.Subset(train_data, train_idx), batch_size=batch_size, shuffle=True
        )
        val_loader = torch.utils.data.DataLoader(
            torch.utils.data.Subset(train_data, val_idx), batch_size=batch_size, shuffle=False
        )

        # New model and optimizer per fold so no weights carry over between folds.
        model = KMNISTModel().to(device)
        optimizer = get_optimizer(model, optimizer_name=optimizer_name, **params)

        _, _, _, valid_metrics = train_model(
            model, device, train_loader, val_loader, epochs, optimizer, criterion
        )
        # NOTE(review): assumes `valid_metrics` is one validation accuracy per epoch
        # (train_model comes from the `train` module) - confirm. Taking the max picks
        # the best epoch on the same fold it is scored on, so the estimate is optimistic.
        fold_scores.append(max(valid_metrics))
    return sum(fold_scores) / len(fold_scores)


def hyperparameter_tuning(optimizer_name, search_space, train_data, k=5, epochs=10, batch_size=64):
    """
    Grid-search optimizer hyperparameters with k-fold cross-validation.

    Each combination of values in ``search_space`` is scored by the mean over
    folds of the best per-epoch validation metric. The highest-scoring
    combination is returned.

    Args:
        optimizer_name (str): Name forwarded to ``get_optimizer`` (e.g. "adam").
        search_space (dict[str, list]): Maps each optimizer keyword argument to
            its candidate values. An empty dict gives a single run with no extra kwargs.
        train_data (Dataset): Dataset split into folds by ``get_k_folds``.
        k (int): Number of cross-validation folds.
        epochs (int): Training epochs per fold.
        batch_size (int): Mini-batch size for the train and validation loaders.

    Returns:
        dict: The best hyperparameter combination (the earliest one wins ties).
    """
    # -inf (not 0) guarantees a config is returned even if every score is 0.
    best_config, best_score = None, float("-inf")

    for params in _expand_grid(search_space):
        print(f"Testing {optimizer_name} with params: {params}")
        avg_val_acc = _cross_validate(optimizer_name, params, train_data, k, epochs, batch_size)
        print(f"Avg Validation Accuracy for {params}: {avg_val_acc:.4f}")
        print('--------------------------------')

        if avg_val_acc > best_score:
            best_score, best_config = avg_val_acc, params

    print(f"Best config for {optimizer_name}: {best_config} with validation accuracy: {best_score:.4f}")
    return best_config


In [None]:
# NOTE(review): never executed (In [None]) and duplicates the earlier `train_data`
# display cell - candidate for deletion.
train_data

In [10]:

# Tune every optimizer in the grid and keep each winner (not just the last one).
# Cost: (6 + 6 + 12) combos x 5 folds = 120 training runs of 10 epochs each;
# lower k/epochs while iterating - the previous full run was interrupted.
best_params_by_optimizer = {}
for optimizer_name, search_space in hyperparameter_grid.items():
    best_params = hyperparameter_tuning(optimizer_name, search_space, train_data, k=5, epochs=10)
    best_params_by_optimizer[optimizer_name] = best_params
    print(f"Best parameters for {optimizer_name}: {best_params}")
best_params_by_optimizer


Testing adamw with params: {'lr': 0.0001, 'weight_decay': 0.0001}
Fold 1/5


KeyboardInterrupt: 