In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt

In [None]:

# Synthetic regression data: y = X @ true_w, plus optional Gaussian label noise.
# N << P makes the problem overparameterized: a Gaussian X has rank N almost
# surely, so the set of weights that fit y exactly is a (P - N)-dim affine set.
SEED = 0
torch.manual_seed(SEED)  # reproducible X, true_w and SGLD noise on Restart & Run All

N = 50           # number of samples
P = 2000         # number of parameters (P >> N)
NOISE_STD = 0.0  # label-noise std; 0.0 keeps the targets exactly noiseless

X = torch.randn(N, P)
true_w = torch.randn(P) * 0.1
y = X @ true_w
if NOISE_STD > 0:
    # Only draw noise when requested so the seeded RNG stream is unchanged otherwise.
    y = y + NOISE_STD * torch.randn(N)


In [None]:

# Bias-free linear model: predictions are X @ w with one weight per feature.
class OverParamLinear(nn.Module):
    """Linear model without a bias term whose weight vector starts at zero.

    Args:
        p: number of input features, i.e. the length of the weight vector ``w``.
    """

    def __init__(self, p):
        super().__init__()
        # Single trainable parameter, initialised to the origin.
        self.w = nn.Parameter(torch.zeros(p))

    def forward(self, X):
        """Return ``X @ w`` for a design matrix ``X`` whose last dim has size ``p``."""
        return torch.matmul(X, self.w)


In [None]:

# SGLD hyperparameters. The update w <- w - lr*grad + sqrt(2*lr*T)*xi is the
# Langevin discretisation for a density proportional to exp(-loss(w) / T)
# (not normalisable along the null space of X, where the loss is flat).
lr = 1e-3        # step size
T = 1e-4         # "Temperature": scales the injected Gaussian noise
steps = 20000    # total number of update steps
burn_in = 10000  # iterates before this step are not recorded

model = OverParamLinear(P)
criterion = nn.MSELoss()

# The optimizer is only used to clear gradients (zero_grad); the SGLD
# parameter update itself is applied by hand in the training loop.
optimizer = optim.SGD(model.parameters(), lr=lr)

w_samples = []  # recorded weight vectors after burn-in

In [None]:

# Run SGLD: a gradient step on the MSE loss plus isotropic Gaussian noise with
# variance 2 * lr * T per coordinate, recording every `sample_every`-th iterate
# from step `burn_in` onward.
noise_std = np.sqrt(2 * lr * T)  # loop-invariant, so computed once
sample_every = 10                # thinning interval between recorded samples
# Fresh local list: appending to the global `w_samples` would crash on a
# re-run, because this cell turns it into a tensor at the end.
# (Note: a re-run continues the chain from the model's current weights.)
collected = []

for step in range(steps):
    # Only zero_grad() is taken from the optimizer; its step() is never used.
    optimizer.zero_grad()
    loss = criterion(model(X), y)
    loss.backward()

    # Manual SGLD update, applied in place outside autograd (instead of via .data).
    with torch.no_grad():
        for p in model.parameters():
            p.add_(p.grad, alpha=-lr)
            p.add_(torch.randn_like(p), alpha=noise_std)

    # Discard steps 0 .. burn_in-1; keep every `sample_every`-th iterate afterwards.
    if step >= burn_in and step % sample_every == 0:
        collected.append(model.w.detach().clone())

if not collected:
    raise ValueError(f"no samples collected: steps={steps} must exceed burn_in={burn_in}")

# NOTE: the loss is flat along the (P - N)-dim null space of X, so in those
# directions the injected noise makes the weights random-walk (their spread
# keeps growing with the number of steps) instead of settling to a stationary law.
w_samples = torch.stack(collected)  # shape: [num_samples, P]

In [None]:

# Check flatness: for each sample, perturb along random unit directions and
# measure the average loss increase.
def flatness_score(w, eps=0.01, directions=10, X_data=None, y_data=None, loss_fn=None):
    """Average loss increase when ``w`` is moved a distance ``eps`` along random directions.

    Args:
        w: 1-D weight tensor at which the loss is probed.
        eps: length of each perturbation.
        directions: number of random unit directions to average over (>= 1).
        X_data: design matrix; defaults to the notebook-level ``X``.
        y_data: targets; defaults to the notebook-level ``y``.
        loss_fn: callable ``loss_fn(pred, target)``; defaults to the notebook-level
            ``criterion``.

    Returns:
        Mean over the sampled unit directions ``d`` of
        ``loss(w + eps * d) - loss(w)``, as a NumPy float.

    Raises:
        ValueError: if ``directions < 1``.

    Note:
        With a linear model and MSE loss the loss is quadratic in ``w``, so the
        second-order part of this increase is the same for every ``w``.
    """
    # Fall back to the notebook globals so existing calls keep working.
    X_data = X if X_data is None else X_data
    y_data = y if y_data is None else y_data
    loss_fn = criterion if loss_fn is None else loss_fn
    if directions < 1:
        raise ValueError("directions must be >= 1")

    # Pure evaluation: no autograd graph is needed even if `w` requires grad.
    with torch.no_grad():
        base_loss = loss_fn(X_data @ w, y_data).item()
        increases = []
        for _ in range(directions):
            direction = torch.randn_like(w)
            direction = direction / direction.norm()  # unit-length perturbation
            pert_loss = loss_fn(X_data @ (w + eps * direction), y_data).item()
            increases.append(pert_loss - base_loss)
    return np.mean(increases)

# One flatness score per recorded SGLD sample (iterating w_samples yields its rows).
scores = list(map(flatness_score, w_samples))


In [None]:

# Histogram of flatness scores across the recorded SGLD samples.
fig, ax = plt.subplots(figsize=(6, 4))
ax.hist(scores, bins=30, color='steelblue', alpha=0.7)
ax.set_xlabel('Average Loss Increase under Perturbation', fontsize=14)
ax.set_ylabel('Frequency', fontsize=14)
ax.set_title('Distribution of Flatness Scores for Sampled Solutions', fontsize=16)
ax.grid(alpha=0.3)
fig.tight_layout()
fig.savefig('flatness_distribution.png', dpi=300)
plt.show()

# Interpretation and caveats:
# - A small loss increase under perturbation is the usual proxy for a "flat"
#   solution.
# - For this linear model with MSE loss, the loss is exactly quadratic:
#       L(w + eps*d) - L(w) = (2*eps/N) * d^T X^T (X w - y) + (eps**2/N) * ||X d||**2
#   The curvature term does not depend on w, so all weight vectors are equally
#   flat. The spread in this histogram comes from the random directions d and
#   from the residual X w - y, not from samples landing in regions of
#   different curvature.
# - For d uniform on the unit sphere, E||X d||**2 = ||X||_F**2 / P, which is
#   about N for standard-normal X. So with the default eps = 0.01, scores
#   around eps**2 = 1e-4 are the expected baseline.
# - The loss is flat along the (P - N)-dim null space of X. Without a prior,
#   the SGLD chain therefore has no stationary distribution in those directions
#   and simply random-walks. The samples show where the chain wandered, not
#   draws from a stationary measure.