In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, SubsetRandomSampler
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt
import os
from tqdm import tqdm

In [None]:
from configs import CONFIGS
from models.mlp import MLP

# Number of training images drawn (once) to form the subset on which every
# loss value in this notebook is evaluated.
N_SAMPLES_FOR_LOSS = 10_000
# Number of leading PCA directions whose flatness is measured.
N_COMPONENTS_TO_ANALYZE = 50

def calculate_subset_loss(model, data_loader, criterion, device):
    """Return the per-sample average of `criterion` over everything in `data_loader`.

    The model is switched to eval mode and gradients are disabled for the
    whole pass. Each batch's loss value is multiplied by the batch size before
    being accumulated, and the total is divided by ``len(data_loader.sampler)``,
    so a smaller final batch is weighted correctly when `criterion` returns a
    batch mean.

    Args:
        model: module called as ``model(batch)``.
        data_loader: iterable of ``(inputs, labels)`` batches exposing a
            ``sampler`` attribute with a length.
        criterion: callable ``criterion(output, labels)`` returning a scalar tensor.
        device: device the batches are moved to before the forward pass.

    Returns:
        float: accumulated weighted loss divided by the sampler length.
    """
    model.eval()
    weighted_sum = 0.0
    with torch.no_grad():
        for inputs, labels in data_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            batch_loss = criterion(model(inputs), labels).item()
            # Undo the batch-level averaging so batches of unequal size are weighted fairly.
            weighted_sum += batch_loss * inputs.size(0)
    n_samples = len(data_loader.sampler)
    return weighted_sum / n_samples

In [None]:
# Load the artifacts this notebook analyses: the mean weight vector, the PCA
# components and variances, and the trained model's state dict.
print("\nLoading data, model, and PCA results...")
try:
    mean_weight = np.load('data/mean_weight.npy')
    pca_components = np.load('data/pca_components.npy')
    variances = np.load('data/pca_variances.npy')
    # map_location remaps the checkpoint's tensors onto the configured device,
    # so a checkpoint saved on a GPU still loads on a CPU-only machine.
    model_state_dict = torch.load('data/final_model.pth', map_location=CONFIGS["device"])
except FileNotFoundError as e:
    print(f"Error: {e.filename} not found.")
    print("Please run 'train_model.py' and 'run_pca.py' first.")
    # Re-raise instead of calling exit(): inside a notebook exit() shuts the
    # kernel down, whereas an exception simply halts execution at this cell
    # and keeps the traceback visible.
    raise

In [None]:
# Build the fixed evaluation subset, instantiate the model, and compute the
# baseline loss L0 with fc2's weight replaced by the mean weight w0.

# ToTensor followed by a fixed mean/std normalisation
# (0.1307 / 0.3081 — presumably the usual MNIST statistics; confirm against train_model.py).
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_dataset = datasets.MNIST('./mnist_data', train=True, download=True, transform=transform)

# The NumPy seed fixes WHICH samples are in the subset. The sampler below still
# reshuffles their order on every pass, which only affects the summation order
# of the averaged loss.
np.random.seed(42)
indices = np.random.permutation(len(train_dataset))[:N_SAMPLES_FOR_LOSS]

subset_sampler = SubsetRandomSampler(indices)
sub_loader = DataLoader(train_dataset, batch_size=CONFIGS["batch_size"], sampler=subset_sampler)

# Use the configured hidden size for the model as well: every reshape below
# uses CONFIGS["hidden_dim"], so a hard-coded width would silently diverge from
# it and break load_state_dict / the fc2 weight assignment.
hidden_dim = CONFIGS["hidden_dim"]
model = MLP(input_dim=28*28, hidden_dim=hidden_dim, output_dim=10).to(CONFIGS["device"])
criterion = nn.CrossEntropyLoss()

print("Calculating baseline loss L0 on the data subset...")
# NOTE(review): assumes mean_weight holds hidden_dim * hidden_dim values and
# that model.fc2.weight has shape (hidden_dim, hidden_dim) — MLP is defined
# elsewhere; confirm.
w0_tensor = torch.tensor(mean_weight.reshape(hidden_dim, hidden_dim),
                         dtype=torch.float32).to(CONFIGS["device"])
model.load_state_dict(model_state_dict)
# All other parameters keep their trained values; only fc2's weight is replaced by w0.
model.fc2.weight.data = w0_tensor

L0 = calculate_subset_loss(model, sub_loader, criterion, CONFIGS["device"])
# A direction counts as having left the flat region once the loss exceeds e * L0.
flatness_threshold = L0 * np.e
print(f"Baseline Loss L0 (on subset): {L0:.6f}")
print(f"Flatness Threshold (L0 * e): {flatness_threshold:.6f}")

In [None]:
# One-dimensional loss scans: perturb fc2's weight around w0 along a handful of
# individual PCA directions and plot the subset loss against the signed step.

# Row indices into pca_components that are scanned (0-based indices, reused verbatim in the legend).
SCAN_COMPONENT_INDICES = (10, 20, 50, 100)
SCAN_HALF_WIDTH = 10


def scan_loss_along_direction(direction_tensor, steps):
    """Evaluate the subset loss at ``w0 + step * direction`` for every step.

    Args:
        direction_tensor: tensor with the same shape as ``w0_tensor``.
        steps: iterable of signed scalar step sizes.

    Returns:
        np.ndarray: one loss value per entry of `steps`, in the same order.

    Side effect: leaves ``model.fc2.weight`` at the last evaluated point.
    """
    losses = []
    for step in steps:
        # float() keeps the arithmetic in the tensor's own dtype regardless of
        # the NumPy scalar type produced by linspace.
        model.fc2.weight.data = w0_tensor + float(step) * direction_tensor
        losses.append(calculate_subset_loss(model, sub_loader, criterion, CONFIGS["device"]))
    return np.array(losses)


# A single signed grid drives both the evaluation and the x-axis, so each
# plotted loss sits at the step it was actually computed for. 99 points give
# the same 10/49 spacing as two 50-point half-grids without evaluating step 0
# twice. Negative steps move along -p and positive steps along +p.
scan_steps = np.linspace(-SCAN_HALF_WIDTH, SCAN_HALF_WIDTH, 99)

scan_losses = {}
for component_index in SCAN_COMPONENT_INDICES:
    direction = pca_components[component_index]
    direction_tensor = torch.tensor(
        direction.reshape(CONFIGS["hidden_dim"], CONFIGS["hidden_dim"]),
        dtype=torch.float32).to(CONFIGS["device"])
    scan_losses[component_index] = scan_loss_along_direction(direction_tensor, scan_steps)

# Put the model back at w0 so later cells do not start from a perturbed weight.
model.fc2.weight.data = w0_tensor

plt.figure(figsize=(12, 8))
for component_index, losses in scan_losses.items():
    plt.plot(scan_steps, losses, label=f'PCA Component {component_index}')
plt.axhline(y=flatness_threshold, color='r', linestyle='--', label='Flatness Threshold (L0 * e)')
plt.xlabel('Step Size along PCA Component')
plt.ylabel('Loss on Subset')
plt.title('Loss Landscape along PCA Directions')
plt.legend()
plt.grid()
plt.savefig('loss_landscape_pca_directions.png')
plt.show()

In [None]:
# Flatness measurement: for each of the leading PCA directions, find how far
# fc2's weight can move from w0 in the +p and -p directions before the subset
# loss exceeds flatness_threshold. Flatness is the width between the two crossings.

def first_threshold_crossing(direction_tensor, sign, steps):
    """Return the first step at which the loss exceeds ``flatness_threshold``.

    The weight is set to ``w0 + sign * step * direction`` for each step in
    order; the search stops at the first step whose loss is above the threshold.

    Args:
        direction_tensor: tensor with the same shape as ``w0_tensor``.
        sign: +1.0 to search along the direction, -1.0 to search against it.
        steps: increasing positive step sizes to try.

    Returns:
        float: the (positive) step of the first crossing, or NaN when the
        threshold is never exceeded within `steps`.
    """
    for step in steps:
        model.fc2.weight.data = w0_tensor + (sign * float(step)) * direction_tensor
        loss = calculate_subset_loss(model, sub_loader, criterion, CONFIGS["device"])
        if loss > flatness_threshold:
            return float(step)
    # No crossing inside the search range: report "unknown" rather than 0.0,
    # which would masquerade as an extremely sharp direction.
    return float('nan')


flatness_values = []
# Log-spaced candidate steps from 1e-2 to 10**1.5.
search_steps = np.logspace(-2, 1.5, 50)
print(f"\nMeasuring flatness for the top {N_COMPONENTS_TO_ANALYZE} PCA directions...")

for i in tqdm(range(N_COMPONENTS_TO_ANALYZE), desc="Analyzing Directions"):
    p_i = pca_components[i]
    p_i_tensor = torch.tensor(p_i.reshape(CONFIGS["hidden_dim"], CONFIGS["hidden_dim"]),
                              dtype=torch.float32).to(CONFIGS["device"])

    delta_theta_r = first_threshold_crossing(p_i_tensor, 1.0, search_steps)
    delta_theta_l = -first_threshold_crossing(p_i_tensor, -1.0, search_steps)

    # Width of the flat region; NaN if either side never crossed the threshold.
    flatness = delta_theta_r - delta_theta_l
    flatness_values.append(flatness)

flatness_values = np.array(flatness_values)

# Put the model back at w0 so it is not left at the last probed point.
model.fc2.weight.data = w0_tensor

n_unresolved = int(np.isnan(flatness_values).sum())
if n_unresolved:
    print(f"Warning: {n_unresolved} direction(s) never exceeded the flatness threshold "
          f"within the search range; their flatness is recorded as NaN.")

print("\nAnalysis complete. Generating plots...")
os.makedirs('plots', exist_ok=True)

In [None]:
# Persist the flatness values and draw the three summary panels:
# variance vs. PC index, flatness vs. PC index, and variance vs. flatness
# with a log-log linear fit.
np.save('data/flatness_values.npy', flatness_values)

# Size everything from the array actually produced, so this cell stays
# consistent even if N_COMPONENTS_TO_ANALYZE was edited after the measurement ran.
n_analyzed = len(flatness_values)
indices_analyzed = np.arange(1, n_analyzed + 1)
variances_analyzed = variances[:n_analyzed]

fig, (ax_var, ax_flat, ax_rel) = plt.subplots(1, 3, figsize=(18, 5))

ax_var.plot(indices_analyzed, variances_analyzed, 'o-')
ax_var.set_yscale('log'); ax_var.set_xscale('log')
ax_var.set_title('Variance vs. PC Index'); ax_var.set_xlabel('PC Index (i)'); ax_var.set_ylabel('Variance (σ²)')
ax_var.grid(True, which="both", ls="--")

ax_flat.plot(indices_analyzed, flatness_values, 'o-')
ax_flat.set_title('Flatness vs. PC Index'); ax_flat.set_xlabel('PC Index (i)'); ax_flat.set_ylabel('Flatness (F)')
ax_flat.set_yscale('log'); ax_flat.set_xscale('log')
ax_flat.grid(True, which="both", ls="--")

# Only points that are finite and strictly positive in BOTH quantities can be
# log-transformed for the fit.
valid_mask = np.isfinite(flatness_values) & (flatness_values > 0)
valid_mask &= np.isfinite(variances_analyzed) & (variances_analyzed > 0)
if np.sum(valid_mask) > 1:
    log_F = np.log(flatness_values[valid_mask])
    log_var = np.log(variances_analyzed[valid_mask])
    # A degree-1 fit needs at least two distinct x values.
    if np.unique(log_F).size > 1:
        m, c = np.polyfit(log_F, log_var, 1)
        # Draw the fit through x sorted ascending so the dashed line does not retrace itself.
        order = np.argsort(log_F)
        ax_rel.plot(flatness_values[valid_mask][order], np.exp(m * log_F[order] + c),
                    'r--', label=f'Fit (slope ≈ {m:.2f})')

ax_rel.plot(flatness_values, variances_analyzed, 'o', label='Experimental Data')
ax_rel.set_title('Variance vs. Flatness'); ax_rel.set_xlabel('Flatness (F)'); ax_rel.set_ylabel('Variance (σ²)')
ax_rel.set_yscale('log'); ax_rel.set_xscale('log')
ax_rel.legend(); ax_rel.grid(True, which="both", ls="--")

fig.tight_layout()
# Ensure the output directory exists even when this cell is run on its own.
os.makedirs('plots', exist_ok=True)
fig.savefig('plots/feng_tu_2021_reproduction.png')
print("Saved plot to 'plots/feng_tu_2021_reproduction.png'")
plt.show()