In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt

In [None]:
##############################################
# Data Loading
##############################################

# Seed every RNG up front: the subset choice, the random labels, DataLoader
# shuffling and all later network initialisations draw from these generators,
# so seeding here makes "Restart Kernel -> Run All" reproducible.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

N_SUBSET = 1000   # number of MNIST training images used (small for fast training)
BATCH_SIZE = 64

# We use MNIST as the dataset. We will also create a subset for faster training.
transform = transforms.Compose([transforms.ToTensor()])
mnist_train = datasets.MNIST(root='./data', train=True, download=True, transform=transform)

# Choose a small subset for demonstration (indices sampled without replacement).
subset_indices = np.random.choice(len(mnist_train), N_SUBSET, replace=False)
mnist_subset = Subset(mnist_train, subset_indices)
train_loader = DataLoader(mnist_subset, batch_size=BATCH_SIZE, shuffle=True)

# Random-label control: the same images as `mnist_subset`, but every label is
# replaced by an independent uniform draw from {0, ..., 9}.
random_labels = torch.randint(low=0, high=10, size=(len(mnist_subset),))
random_train_dataset = [(img, random_labels[idx].item()) for idx, (img, _) in enumerate(mnist_subset)]
random_train_loader = DataLoader(random_train_dataset, batch_size=BATCH_SIZE, shuffle=True)


In [None]:

##############################################
# Network Definitions
##############################################

class _ThreeLayerMLP(nn.Module):
    """Fully connected classifier: in_features -> hidden -> hidden -> num_classes.

    Both experiment networks share this architecture and differ only in their
    default hidden width, so the layers are defined once here. The submodule
    names (``fc1``, ``fc2``, ``fc3``, ``relu``) determine the ``state_dict`` keys
    that are saved after training and reloaded later, so keep them stable.
    """

    def __init__(self, hidden_size, in_features=784, num_classes=10):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, num_classes)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return unnormalised class logits of shape (batch, num_classes).

        Every dimension after the first is flattened, so (B, 1, 28, 28) image
        batches are accepted directly. Unlike ``view``, ``torch.flatten`` also
        works on non-contiguous inputs.
        """
        x = torch.flatten(x, start_dim=1)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        return self.fc3(x)


class OverParamNet(_ThreeLayerMLP):
    """Large (overparameterized) MLP: 784-5000-5000-10 by default."""

    def __init__(self, hidden_size=5000, in_features=784, num_classes=10):
        super().__init__(hidden_size, in_features=in_features, num_classes=num_classes)


class UnderParamNet(_ThreeLayerMLP):
    """Small (underparameterized) MLP: 784-50-50-10 by default."""

    def __init__(self, hidden_size=50, in_features=784, num_classes=10):
        super().__init__(hidden_size, in_features=in_features, num_classes=num_classes)

In [None]:

##############################################
# Training Function
##############################################

def train_model(train_loader, model_class, lr=0.1, max_epochs=30):
    """Train a freshly initialised ``model_class()`` with plain SGD on CPU.

    Training stops after ``max_epochs`` passes over ``train_loader``, or earlier
    as soon as every example seen in an epoch was classified correctly. That
    accuracy is accumulated from each batch's forward pass *before* its
    optimizer step, so it is a cheap running estimate rather than a separate
    evaluation of the final weights.

    Args:
        train_loader: iterable of ``(images, labels)`` batches; ``labels`` are
            integer class indices, as required by ``nn.CrossEntropyLoss``.
        model_class: zero-argument callable returning an ``nn.Module``.
        lr: SGD learning rate.
        max_epochs: maximum number of epochs.

    Returns:
        The trained model.

    Raises:
        ValueError: if ``train_loader`` yields no batches in an epoch.
    """
    device = torch.device('cpu')
    model = model_class().to(device)
    optimizer = optim.SGD(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    for _ in range(max_epochs):
        model.train()
        correct = 0
        total = 0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # detach() instead of the legacy `.data` attribute.
            predicted = outputs.detach().argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        if total == 0:
            raise ValueError('train_loader yielded no batches')
        # If we perfectly fit the data, we can stop early (exact integer check).
        if correct == total:
            break
    return model

In [None]:

##############################################
# Measure Flatness
##############################################

def _average_loss(model, loader, criterion, device):
    """Return the per-sample mean of ``criterion`` over all batches of ``loader``.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    total_loss = 0.0
    total_samples = 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            batch_size = labels.size(0)
            # The criterion returns a batch mean; re-weight by batch size so a
            # smaller final batch is not over-counted.
            total_loss += criterion(model(images), labels).item() * batch_size
            total_samples += batch_size
    if total_samples == 0:
        raise ValueError('loader yielded no samples')
    return total_loss / total_samples


def _load_flat_params(model, flat_params):
    """Copy a flat parameter vector into ``model``'s parameters in place, in ``model.parameters()`` order."""
    offset = 0
    with torch.no_grad():
        for p in model.parameters():
            n = p.numel()
            p.copy_(flat_params[offset:offset + n].view_as(p))
            offset += n


def measure_loss_increase(model_state_dict, model_class, train_loader, epsilons,
                          num_directions=5, generator=None):
    """Probe flatness via the loss increase under random weight perturbations.

    For each ``eps`` in ``epsilons``, the flattened parameter vector ``theta`` is
    replaced by ``theta + eps * d`` for ``num_directions`` random unit-norm
    directions ``d`` (a normalised Gaussian draw). The training loss at the
    perturbed point is then compared with the unperturbed loss.

    Note: ``eps`` is an absolute L2 norm in the full parameter space, so a
    network with more parameters receives a smaller per-coordinate
    perturbation (roughly ``eps / sqrt(num_params)``). Keep this in mind when
    comparing networks of very different size.

    Args:
        model_state_dict: weights loaded into a fresh ``model_class()``.
        model_class: zero-argument callable returning an ``nn.Module``.
        train_loader: iterable of ``(images, labels)`` batches used for the loss.
        epsilons: sequence of perturbation norms.
        num_directions: random directions sampled per ``eps`` (must be >= 1).
        generator: optional ``torch.Generator`` for reproducible directions;
            ``None`` uses the global torch RNG.

    Returns:
        ``(mean_increases, std_increases)``: two lists with one entry per
        ``eps``. They hold the mean and the population standard deviation of
        ``perturbed_loss - baseline_loss`` over the sampled directions.

    Raises:
        ValueError: if ``num_directions < 1`` or ``train_loader`` is empty.
    """
    if num_directions < 1:
        raise ValueError('num_directions must be >= 1')

    device = torch.device('cpu')
    model = model_class().to(device)
    model.load_state_dict(model_state_dict)
    model.eval()
    criterion = nn.CrossEntropyLoss()

    baseline_loss = _average_loss(model, train_loader, criterion, device)

    # Flat copy of the unperturbed weights; every perturbation starts from it.
    params = nn.utils.parameters_to_vector(model.parameters()).detach()
    dim = params.numel()

    mean_increases = []
    std_increases = []
    for eps in epsilons:
        increases = []
        for _ in range(num_directions):
            direction = torch.randn(dim, device=device, generator=generator)
            direction = direction / torch.norm(direction)
            _load_flat_params(model, params + eps * direction)
            pert_loss = _average_loss(model, train_loader, criterion, device)
            increases.append(pert_loss - baseline_loss)
        mean_increases.append(np.mean(increases))
        std_increases.append(np.std(increases))

    return mean_increases, std_increases


In [None]:

##############################################
# Experiments
##############################################

# Independent training runs per configuration. Each run builds a fresh model
# (train_model calls model_class()), so runs differ in their random init.
N_RUNS = 3

# Train overparameterized and underparameterized networks on natural labels
# (the two sizes are trained alternately within each run).
solutions_overparam = []
solutions_underparam = []
for _ in range(N_RUNS):
    solutions_overparam.append(train_model(train_loader, OverParamNet).state_dict())
    solutions_underparam.append(train_model(train_loader, UnderParamNet).state_dict())

# Train overparameterized networks on random labels (memorisation control).
solutions_random = [
    train_model(random_train_loader, OverParamNet).state_dict()
    for _ in range(N_RUNS)
]


In [None]:
##############################################
# Measure Flatness and Plot
##############################################

epsilons = [0.001, 0.005, 0.01, 0.05, 0.1]

# Only the per-epsilon mean increase (element [0] of the returned pair) is kept.
# Each list has one row per trained solution and one column per epsilon.
# Comprehensions keep loop temporaries out of the notebook's global namespace.
mean_inc_overparam = [
    measure_loss_increase(state, OverParamNet, train_loader, epsilons)[0]
    for state in solutions_overparam
]
mean_inc_underparam = [
    measure_loss_increase(state, UnderParamNet, train_loader, epsilons)[0]
    for state in solutions_underparam
]
# Random-label networks are evaluated on the random-label data they were fit to.
mean_inc_random = [
    measure_loss_increase(state, OverParamNet, random_train_loader, epsilons)[0]
    for state in solutions_random
]



In [None]:
# Define colors
color_overparam = 'steelblue'
color_underparam = 'darkred'  # also used for the random-label curve in the second figure


def _plot_flatness_comparison(eps_values, curves, filename, title=None):
    """Plot mean loss increase vs. perturbation size for groups of solutions.

    Args:
        eps_values: x-axis values (perturbation magnitudes).
        curves: sequence of ``(per_solution_means, label, fmt, color)`` tuples.
            ``per_solution_means`` has one row per trained solution and one
            column per entry of ``eps_values``. The marker is the mean across
            solutions and the error bar is the (population) std across them.
        filename: path the figure is saved to (300 dpi).
        title: optional axes title; no title is drawn when ``None``.
    """
    fig, ax = plt.subplots(figsize=(8, 6))
    for per_solution_means, label, fmt, color in curves:
        ax.errorbar(
            eps_values,
            np.mean(per_solution_means, axis=0),
            yerr=np.std(per_solution_means, axis=0),
            label=label,
            fmt=fmt, color=color, capsize=5, markersize=8, linewidth=2,
        )
    # Raw string: '\e' in a normal literal is an invalid escape sequence
    # (a SyntaxWarning on Python 3.12+).
    ax.set_xlabel(r'Perturbation Magnitude ($\epsilon$)', fontsize=14)
    ax.set_ylabel('Loss Increase', fontsize=14)
    if title is not None:
        ax.set_title(title, fontsize=16)
    ax.tick_params(axis='both', labelsize=12)
    ax.grid(alpha=0.3)
    ax.legend(fontsize=12)
    fig.tight_layout()
    fig.savefig(filename, dpi=300)
    plt.show()


# Plot: Overparam vs Underparam
# (saved without a title; pass title='Flatness Comparison: Overparam vs Underparam' to add one)
_plot_flatness_comparison(
    epsilons,
    [
        (mean_inc_overparam, 'Overparameterized (Natural)', 'o-', color_overparam),
        (mean_inc_underparam, 'Underparameterized (Natural)', 's--', color_underparam),
    ],
    'flatness_over_vs_under.png',
)

# Plot: Natural vs Random (both Overparam)
# (saved without a title; pass title='Flatness Comparison: Natural vs Random Labels (Overparam)' to add one)
_plot_flatness_comparison(
    epsilons,
    [
        (mean_inc_overparam, 'Overparam (Natural)', 'o-', color_overparam),
        (mean_inc_random, 'Overparam (Random Labels)', 's--', color_underparam),
    ],
    'flatness_natural_vs_random.png',
)