In [13]:
import pandas as pd
import numpy as np
import random
import torch
import torch.nn.functional as F
from torchvision import datasets
from torch.autograd.functional import hessian
import torch.nn as nn
import plotly.express as px

In [14]:
def get_cropped_images(X, y, n):
    """Returns images that can be cropped by n pixels on all sides without
    losing any information (all cropped pixels = 0).

    Args:
        X np.ndarray: Input images.
        y np.ndarray: Corresponding labels.
        n (int): Number of pixels to crop from each side.

    Returns:
        np.ndarray: Cropped images.
    """

    mask = np.zeros((X.shape[1], X.shape[2]), dtype=bool)
    mask[:n, :] = True
    mask[-n:, :] = True
    mask[:, :n] = True
    mask[:, -n:] = True

    border_pixels = X[:, mask]
    croppable_mask = (border_pixels.sum(axis=1) == 0)
    return X[croppable_mask, n:-n, n:-n], y[croppable_mask]


In [15]:
NUM_LABELS = 7

In [16]:
rand_state = 42
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

torch.backends.cudnn.benchmark = True
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [17]:
# Load MNIST dataset
train = datasets.MNIST(root="./data", train=True,  download=True)
test  = datasets.MNIST(root="./data", train=False, download=True)
X_full = torch.cat([train.data, test.data], dim=0).numpy() / 255.0
y_full = torch.cat([train.targets, test.targets], dim=0).numpy()

# Crop images by n pixels on each side
X_cropped, y_cropped = get_cropped_images(X_full, y_full, 4)
y_cropped = pd.Series(y_cropped)

# Select most frequent classes
classes = y_cropped.value_counts().index[:NUM_LABELS]

# Create train/test split with 1000 train / 200 test samples per class
X_train_list = []
y_train_list = []
X_test_list = []
y_test_list = []

for clss in classes:
    indices = y_cropped[y_cropped == clss].sample(1000).index
    X_train_list.append(X_cropped[indices[:800]])
    y_train_list.append(y_cropped[indices[:800]])
    X_test_list.append(X_cropped[indices[800:]])
    y_test_list.append(y_cropped[indices[800:]])

X_train = np.concatenate(X_train_list, axis=0)
y_train = pd.concat(y_train_list, axis=0).reset_index(drop=True)
X_test = np.concatenate(X_test_list, axis=0)
y_test = pd.concat(y_test_list, axis=0).reset_index(drop=True)

# Convert y from image label to class lablel (see classes variable)
class_map = {clss: idx for idx, clss in enumerate(classes)}
y_train = y_train.map(class_map)
y_test = y_test.map(class_map)

# Shuffle data
shuffle_indices_train = y_train.sample(y_train.shape[0], random_state=rand_state).index
shuffle_indices_test = y_test.sample(y_test.shape[0], random_state=rand_state).index

X = X_train[shuffle_indices_train]
y = y_train[shuffle_indices_train].reset_index(drop=True).to_numpy()
X_test = X_test[shuffle_indices_test]
y_test = y_test[shuffle_indices_test].reset_index(drop=True).to_numpy()

# Convert to torch tensors
X = torch.tensor(X, dtype=torch.float32, device=device)
y = torch.tensor(y, dtype=torch.long, device=device)
X_test = torch.tensor(X_test, dtype=torch.float32, device=device)
y_test = torch.tensor(y_test, dtype=torch.long, device=device)

In [18]:
class FullyConnectedNet(nn.Module):
    def __init__(self, input_size, num_hidden_layers, hidden_layer_size, 
                 num_labels, activation):
        super(FullyConnectedNet, self).__init__()

        self.input_size = input_size
        self.num_hidden_layers = num_hidden_layers
        self.hidden_layers_size = hidden_layer_size
        self.num_labels = num_labels
        self.activation = activation

        layers = [nn.Flatten()]
        in_size = input_size

        for _ in range(num_hidden_layers):
            layers += [nn.Linear(in_size, hidden_layer_size), activation()]
            in_size = hidden_layer_size

        layers.append(nn.Linear(in_size, num_labels))
        self.network = nn.Sequential(*layers)

    def forward(self, x):
        return self.network(x)

In [19]:
def flatten_params(model):
    return torch.cat([p.contiguous().view(-1) for p in model.parameters()])

def unflatten_params(flat_params, model):
    params = []
    i = 0
    for p in model.parameters():
        n = p.numel()
        params.append(flat_params[i:i+n].view_as(p))
        i += n
    return params

def forward_with_params(model, X, params):
    """Functional forward that uses explicit parameter tensors instead of model weights."""
    x = X
    i = 0
    for layer in model.network:
        if isinstance(layer, nn.Flatten):
            x = x.view(x.size(0), -1)
        elif isinstance(layer, nn.Linear):
            w = params[i]
            b = params[i + 1]
            x = F.linear(x, w, b)
            i += 2
        elif isinstance(layer, nn.ReLU):
            x = F.relu(x)
        else:
            raise ValueError(f"Unsupported layer type: {type(layer)}")
    return x

def loss_wrt_params(flat_params, model, criterion):
    params = unflatten_params(flat_params, model)
    outputs = forward_with_params(model, X, params)
    loss = criterion(outputs, y)
    return loss

input_size = X.shape[1] * X.shape[2] # 400
num_hidden_layers = 2
hidden_layer_size = 24
learning_rate = .01
epochs = 2

model = FullyConnectedNet(
    input_size=input_size,
    num_hidden_layers=num_hidden_layers,
    hidden_layer_size=hidden_layer_size,
    num_labels=NUM_LABELS,
    activation=nn.ReLU
)

optimizer = torch.optim.RMSprop(model.parameters(), lr=learning_rate)

criterion = nn.CrossEntropyLoss()

model.to(device)
model.train()

train_losses = np.empty(epochs)
for epoch in range(epochs):
    optimizer.zero_grad()
    outputs = model(X)
    loss = criterion(outputs, y)
    loss.backward()
    optimizer.step()
    train_losses[epoch] = loss.item()

    if (epoch+1) % 200 == 0:
        print(f"Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}")

params_flat = flatten_params(model).detach().clone().requires_grad_(True).to(device)
H = hessian(lambda p: loss_wrt_params(p, model, criterion), params_flat)
H.max()


KeyboardInterrupt: 

In [435]:
px.line(x=np.arange(1, epochs+1), y=train_losses,
           labels={'x': 'Epoch', 'y': 'Training Loss'},
           title='Training Loss over Epochs').show()

In [434]:
model.eval()
with torch.no_grad():
    outputs = model(X_test.to(device))
    preds = outputs.argmax(dim=1)
    correct = (preds == y_test.to(device)).sum().item()
    accuracy = correct / len(y_test)

print(f"Test accuracy: {accuracy:.4f}")

Test accuracy: 0.9407
