In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, SubsetRandomSampler
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt
import os
from tqdm import tqdm

In [None]:
from configs import CONFIGS
device = CONFIGS["device"]
batch_size = CONFIGS["batch_size"]
num_components = CONFIGS["num_components"]
loss_sample_batch_size = CONFIGS["loss_sample_batch_size"]
# A direction counts as "flat" while the loss stays below L0 * this factor.
flatness_threshold_factor = np.e

print(f"Using device: {device}")

from models.mlp import MLP

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))  # commonly used MNIST mean / std
])

train_dataset = datasets.MNIST('./mnist_data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Fixed subset of training examples on which every loss-landscape point is evaluated.
# It is drawn once with a seeded RNG, so all evaluations (and re-runs) use the same
# samples. `indices` was previously referenced without ever being defined.
# NOTE(review): assumes CONFIGS is a dict; "loss_subset_size" / "seed" are optional keys.
# The subset defaults to a single batch of `loss_sample_batch_size` samples.
loss_subset_size = CONFIGS.get("loss_subset_size", loss_sample_batch_size)
subset_seed = CONFIGS.get("seed", 0)
indices = np.random.default_rng(subset_seed).choice(
    len(train_dataset), size=loss_subset_size, replace=False
).tolist()
sub_loader = DataLoader(
    train_dataset,
    batch_size=loss_sample_batch_size,
    sampler=SubsetRandomSampler(indices),
)

model = MLP(input_dim=28*28, hidden_dim=50, output_dim=10).to(device)
# `optim` was never imported; reach it through the already-imported `torch` package.
optimizer = torch.optim.SGD(model.parameters(), lr=CONFIGS["learning_rate"])
criterion = nn.CrossEntropyLoss()

def calculate_subset_loss(model, data_loader, criterion, device):
    """Return the per-sample mean loss of ``model`` over every batch in ``data_loader``.

    Switches ``model`` to eval mode (and leaves it there) and disables autograd.

    Args:
        model: module whose outputs are accepted by ``criterion``.
        data_loader: iterable yielding ``(data, target)`` batches.
        criterion: loss that averages over the batch (e.g. ``nn.CrossEntropyLoss()``
            with its default mean reduction). Each batch loss is re-weighted by its
            batch size, so uneven batches still give a true per-sample mean.
        device: device each batch is moved to before the forward pass.

    Returns:
        float: summed per-sample loss divided by the number of samples evaluated.

    Raises:
        ValueError: if ``data_loader`` yields no samples.
    """
    model.eval()
    total_loss = 0.0
    num_samples = 0
    with torch.no_grad():
        for data, target in data_loader:
            data, target = data.to(device), target.to(device)
            n_batch = data.size(0)
            total_loss += criterion(model(data), target).item() * n_batch
            num_samples += n_batch
    if num_samples == 0:
        raise ValueError("data_loader yielded no samples; cannot compute a mean loss")
    # Divide by samples actually seen, not len(data_loader.sampler): the two differ
    # with drop_last=True, and iterable-style loaders have no sampler length.
    return total_loss / num_samples

In [None]:
print("\nloading data, model, and PCA results...")
try:
    mean_weight = np.load('data/mean_weight.npy')
    pca_components = np.load('data/pca_components.npy')
    variances = np.load('data/pca_variances.npy')
    # map_location lets a checkpoint saved on GPU load on whatever `device` is configured.
    model_state_dict = torch.load('data/final_model.pth', map_location=device)
except FileNotFoundError as e:
    print(f"error: {e.filename} not found.")
    print("please run 'train_model.py' and 'run_pca.py' first.")
    # Re-raise instead of exit(): inside a Jupyter kernel exit() shuts the kernel down
    # rather than just stopping "Run All" at this cell.
    raise

In [None]:
print("calculating baseline loss L0 on the data subset...")
model.load_state_dict(model_state_dict)

# Shape the flattened mean weight like fc2's weight instead of hard-coding (50, 50).
fc2_shape = tuple(model.fc2.weight.shape)
w0_tensor = torch.as_tensor(mean_weight.reshape(fc2_shape), dtype=torch.float32, device=device)

# Copy into the parameter rather than rebinding `.data`, so the parameter never aliases
# w0_tensor (an in-place update of the weight would otherwise silently corrupt W0).
with torch.no_grad():
    model.fc2.weight.copy_(w0_tensor)

L0 = calculate_subset_loss(model, sub_loader, criterion, device)
flatness_threshold = L0 * flatness_threshold_factor
print(f"baseline Loss L0 (on subset): {L0:.6f}")
print(f"flatness Threshold (L0 * e): {flatness_threshold:.6f}")

os.makedirs('plots', exist_ok=True)

In [None]:
fig, ax = plt.subplots(figsize=(10, 6))

analysis_indices = [10, 20, 50, 100]
dtheta_range = np.linspace(-10, 10, 81)

# Fail early with a clear message instead of an IndexError mid-sweep.
assert max(analysis_indices) < len(pca_components), (
    f"PCA index {max(analysis_indices)} requested but only "
    f"{len(pca_components)} components were saved"
)

try:
    for i in tqdm(analysis_indices, desc="loss landscape"):
        # Reshape the flattened component to fc2's weight shape (not a hard-coded 50x50).
        p_i_tensor = torch.as_tensor(
            pca_components[i].reshape(tuple(w0_tensor.shape)),
            dtype=torch.float32,
            device=device,
        )

        losses = []
        for dtheta in dtheta_range:
            # Rebinding `.data` to a freshly built tensor never writes into w0_tensor itself.
            model.fc2.weight.data = w0_tensor + float(dtheta) * p_i_tensor
            losses.append(calculate_subset_loss(model, sub_loader, criterion, device))

        ax.plot(dtheta_range, losses, label=f"PCA index {i}")
finally:
    # Put the model back at the baseline W0 so later cells don't see the last perturbation.
    model.fc2.weight.data = w0_tensor.clone()

# Show the flatness threshold computed from L0 so flat directions are visible at a glance.
ax.axhline(flatness_threshold, color="gray", ls=":", label=f"flatness threshold ({flatness_threshold:.3f})")

ax.set_title("Loss landscape along PCA directions")
ax.set_xlabel(r"$\delta\theta$")
ax.set_ylabel("Loss")
ax.legend()
ax.grid(True, which="both", ls="--")

os.makedirs('plots', exist_ok=True)
fig.savefig('plots/loss_landscapes.png')

plt.show()