**Set Up**

In [2]:
# importing needed libraries
import numpy as np
from matplotlib import pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import ConcatDataset, random_split, DataLoader
from torch.utils.data import DataLoader, TensorDataset
import torchvision
import torchvision.transforms as transforms

In [3]:


# Define a general MLP class that uses ReLU, tanh, or sigmoid activation and has prescribed widths/depths.

import torch.nn as nn
import torch.nn.init as init

class mult_layer_percep(nn.Module):
    def __init__(
        self,
        input_len: int,
        layer_widths: list,
        output_len: int,
        activation: str,
        init_method: str = "default",
        dropout: float | None = None,
    ):
        super().__init__()

        assert input_len > 0 and output_len > 0
        assert all(h > 0 for h in layer_widths)

        activation = activation.lower()
        if activation == "relu":
            act_factory = nn.ReLU
            nonlin_for_kaim = "relu"
        elif activation == "tanh":
            act_factory = nn.Tanh
            nonlin_for_kaim = "tanh"  
        elif activation == "sigmoid":
            act_factory = nn.Sigmoid
            nonlin_for_kaim = "sigmoid"  
        else:
            raise ValueError("activation must be 'relu', 'tanh', or 'sigmoid'")

        layers = []
        sizes = [input_len] + list(layer_widths) + [output_len]

        for i in range(len(sizes) - 1):
            lin = nn.Linear(sizes[i], sizes[i + 1])

            
            if init_method != "default":
                m = init_method.lower()
                if m in ("kaim"):
                    init.kaiming_uniform_(lin.weight, nonlinearity=nonlin_for_kaim)
                elif m in ("xav"):
                    init.xavier_uniform_(lin.weight)
                elif m == "normal":
                    init.normal_(lin.weight, mean=0.0, std=0.01)
                elif m == "uniform":
                    init.uniform_(lin.weight, a=-0.01, b=0.01)
                else:
                    raise ValueError("init_method must be 'default', 'kaim', 'xavier', 'normal', or 'uniform'")

                if lin.bias is not None:
                    init.zeros_(lin.bias)

            layers.append(lin)

            # Hidden layers only: activation (and optional dropout)
            if i < len(sizes) - 2:
                layers.append(act_factory())
                if dropout is not None:
                    layers.append(nn.Dropout(p=float(dropout)))

        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)

        
        


Loading CIFAR-10 data

In [4]:


# Transform each sample to a torch tensor and flatten it into vector form.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(torch.flatten),
])

data_root = "./CIFAR10_data"

# Load (downloading if necessary) the CIFAR-10 train and test sets.
train_data = torchvision.datasets.CIFAR10(root=data_root, train=True, download=True, transform=transform)
test_data = torchvision.datasets.CIFAR10(root=data_root, train=False, download=True, transform=transform)

# Pool both sets, then re-split them 0.8 / 0.1 / 0.1 into train / val / test.
full_data = ConcatDataset([train_data, test_data])

n_total = len(full_data)
n_train = int(0.8 * n_total)
n_val = int(0.1 * n_total)
n_test = n_total - n_train - n_val  # remainder, so the three sizes always sum to n_total

g = torch.Generator().manual_seed(20)  # reproducible split
train_data, val_data, test_data = random_split(full_data, [n_train, n_val, n_test], generator=g)

batch_size = 200

# Data loaders for easy batching. Only the training loader shuffles: the
# evaluation code sums loss and correct counts over every sample, so batch
# order does not matter for validation or test data.
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)
val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)


Training


In [5]:

def train_one_epoch(model, loader, loss_fn, optimizer, device):
    """Run one optimisation pass over ``loader`` and report training metrics.

    The model is put in training mode. For every batch the gradients are
    reset, the loss is back-propagated and the optimizer takes one step.

    Args:
        model: Module mapping a batch of inputs to per-class scores.
        loader: Iterable yielding ``(inputs, labels)`` batches.
        loss_fn: Callable ``loss_fn(scores, labels)`` returning a scalar tensor.
        optimizer: Optimizer updating ``model``'s parameters.
        device: Device the batches are moved to.

    Returns:
        ``(avg_loss, accuracy)``. Each batch loss is weighted by its batch
        size before averaging, which gives the per-sample mean provided
        ``loss_fn`` returns a batch mean. Accuracy is the fraction of samples
        whose arg-max score equals the label.
    """
    model.train()

    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    for inputs, targets in loader:
        inputs = inputs.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        optimizer.zero_grad()
        scores = model(inputs)
        batch_loss = loss_fn(scores, targets)
        batch_loss.backward()
        optimizer.step()

        batch_n = inputs.size(0)
        loss_sum += batch_loss.item() * batch_n
        n_correct += (scores.argmax(dim=1) == targets).sum().item()
        n_seen += batch_n

    return loss_sum / n_seen, n_correct / n_seen

@torch.no_grad()
def evaluate(model, loader, loss_fn, device):
    """Compute average loss and accuracy of ``model`` over ``loader``.

    Runs with gradient tracking disabled and the model in eval mode; no
    parameters are updated.

    Args:
        model: Module mapping a batch of inputs to per-class scores.
        loader: Iterable yielding ``(inputs, labels)`` batches.
        loss_fn: Callable ``loss_fn(scores, labels)`` returning a scalar tensor.
        device: Device the batches are moved to.

    Returns:
        ``(avg_loss, accuracy)``, with each batch loss weighted by its batch
        size and accuracy measured as the fraction of arg-max predictions
        that equal the label.
    """
    model.eval()

    loss_sum, n_correct, n_seen = 0.0, 0, 0

    for inputs, targets in loader:
        inputs = inputs.to(device)
        targets = targets.to(device)

        scores = model(inputs)
        batch_n = inputs.size(0)

        loss_sum += loss_fn(scores, targets).item() * batch_n
        n_correct += (scores.argmax(dim=1) == targets).sum().item()
        n_seen += batch_n

    return loss_sum / n_seen, n_correct / n_seen


def train_for_n_epochs(n,model,train_loader,val_loader,loss_fn,optimizer,device,checkpoint_path="best_mlp.pt"):
    """Train ``model`` for ``n`` epochs, checkpointing the best validation loss.

    Each epoch runs ``train_one_epoch`` on ``train_loader`` followed by
    ``evaluate`` on ``val_loader``. Whenever the validation loss is strictly
    lower than every previous epoch's, ``model.state_dict()`` is written to
    ``checkpoint_path`` with ``torch.save``.

    Args:
        n: Number of epochs; ``n <= 0`` trains nothing and returns empty lists.
        model: Module being trained.
        train_loader: Iterable of training batches.
        val_loader: Iterable of validation batches.
        loss_fn: Loss callable passed through to the per-epoch helpers.
        optimizer: Optimizer passed through to ``train_one_epoch``.
        device: Device passed through to the per-epoch helpers.
        checkpoint_path: Where the best state dict is saved. Defaults to
            ``"best_mlp.pt"``; pass a distinct path per experiment so runs do
            not overwrite each other's checkpoints.

    Returns:
        ``(train_losses, train_accs, val_losses, val_accs)``: four lists with
        one entry per epoch.
    """
    train_losses = []
    train_accs = []
    val_losses = []
    val_accs = []
    best_loss = float("inf")

    for _ in range(n):
        train_loss, train_acc = train_one_epoch(model, train_loader, loss_fn, optimizer, device)
        val_loss, val_acc = evaluate(model, val_loader, loss_fn, device)

        train_losses.append(train_loss)
        train_accs.append(train_acc)
        val_losses.append(val_loss)
        val_accs.append(val_acc)

        # Keep only the weights of the epoch with the lowest validation loss.
        if val_loss < best_loss:
            best_loss = val_loss
            torch.save(model.state_dict(), checkpoint_path)

    return train_losses, train_accs, val_losses, val_accs


@torch.no_grad()
def predict_all(model, loader, loss_fn, device):
    """Collect labels and arg-max predictions for every sample in ``loader``.

    Runs with gradient tracking disabled and the model in eval mode.

    Args:
        model: Module mapping a batch of inputs to per-class scores.
        loader: Iterable yielding ``(inputs, labels)`` batches.
        loss_fn: Callable ``loss_fn(scores, labels)`` returning a scalar tensor.
        device: Device the batches are moved to.

    Returns:
        ``(y_true, y_pred, avg_loss)``: the concatenated labels and predicted
        class indices (both moved to the CPU, in loader order) and the
        batch-size-weighted average loss.
    """
    model.eval()

    label_chunks, pred_chunks = [], []
    loss_sum = 0.0
    n_seen = 0

    for inputs, targets in loader:
        inputs = inputs.to(device)
        targets = targets.to(device)

        scores = model(inputs)
        batch_n = inputs.size(0)

        loss_sum += loss_fn(scores, targets).item() * batch_n
        n_seen += batch_n

        pred_chunks.append(scores.argmax(dim=1).cpu())
        label_chunks.append(targets.cpu())

    y_true = torch.cat(label_chunks)
    y_pred = torch.cat(pred_chunks)

    return y_true, y_pred, loss_sum / n_seen