# In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np

# Simulator parameters
d = 5               # dimension of the simulator's inputs and outputs
noise_std = 0.1     # standard deviation of the additive Gaussian noise
A0 = torch.eye(d)   # nominal d x d simulator matrix (identity)

# Set seeds so that sampling and weight initialisation are reproducible
np.random.seed(0)
torch.manual_seed(0)

# Neural net
class Net(nn.Module):
    """Fully connected regressor from ``d`` inputs to ``d`` outputs.

    The layer stack is Linear -> ReLU -> Linear -> ReLU -> Linear with two
    hidden layers of width 128; ``d`` is read from module scope when the
    network is constructed.
    """

    def __init__(self):
        super().__init__()
        hidden = 128
        # Layers are created in input-to-output order and stored under the
        # attribute name ``net``.
        stack = [
            nn.Linear(d, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, d),
        ]
        self.net = nn.Sequential(*stack)

    def forward(self, x):
        """Return the network output for the input batch ``x``."""
        out = self.net(x)
        return out

# Simulate data from a given A
def simulate(A, n=1000, sigma=None):
    """Draw ``n`` samples from the noisy linear simulator ``y = x A^T + noise``.

    Args:
        A: 2-D tensor of shape ``(out_dim, in_dim)`` defining the linear map.
        n: Number of samples to draw.
        sigma: Standard deviation of the additive Gaussian noise. When
            ``None`` (the default) the module-level ``noise_std`` is used.

    Returns:
        A tuple ``(x, y)`` where ``x`` has shape ``(n, in_dim)`` with standard
        normal entries and ``y`` has shape ``(n, out_dim)``.
    """
    # Take the dimensions from A itself rather than from a global so that
    # non-square or differently sized simulators work as well.
    out_dim, in_dim = A.shape
    # Resolve the default at call time, so a later change to the module-level
    # noise_std is picked up.
    scale = noise_std if sigma is None else sigma

    # Inputs are drawn before the noise; keeping this order keeps seeded runs
    # reproducible.
    x = torch.randn(n, in_dim)
    noise = scale * torch.randn(n, out_dim)
    y = x @ A.T + noise
    return x, y

# Theoretical worst-case bound
def theoretical_bound(empirical_loss, delta, sigma=noise_std, d=d):
    """Return ``empirical_loss + 0.5 * delta * sqrt(d) / sigma``.

    Args:
        empirical_loss: Measured loss that the penalty term is added to.
        delta: Simulator mismatch level.
        sigma: Noise standard deviation; defaults to the module-level
            ``noise_std`` captured when the function is defined.
        d: Problem dimension; defaults to the module-level ``d`` captured
            when the function is defined.

    Returns:
        The loss plus a penalty that grows linearly in ``delta``.
    """
    mismatch_penalty = 0.5 * delta * (d ** 0.5) / sigma
    return empirical_loss + mismatch_penalty

# Empirical data-dependent bound
def empirical_bound(empirical_loss, A0, A_star, X_test, sigma=noise_std):
    """Return the loss plus a penalty computed from the observed output shift.

    The penalty is the mean over the rows of ``X_test`` of
    ``||(A_star - A0) x||_2``, divided by ``2 * sigma``.

    Args:
        empirical_loss: Measured loss that the penalty term is added to.
        A0: Nominal simulator matrix.
        A_star: Perturbed simulator matrix.
        X_test: Batch of test inputs, one sample per row.
        sigma: Noise standard deviation; defaults to the module-level
            ``noise_std`` captured when the function is defined.

    Returns:
        ``empirical_loss + mean_shift / (2 * sigma)`` as a Python float.
    """
    # Per-sample displacement of the simulator output caused by the mismatch.
    displacement = X_test @ (A_star - A0).T
    per_sample_norm = displacement.norm(dim=1)
    mean_shift = per_sample_norm.mean().item()
    return empirical_loss + mean_shift / (2 * sigma)

# Training and evaluation
def train_and_eval(A_star, *, n_train=2000, n_test=1000, epochs=2000, lr=1e-3):
    """Train a fresh ``Net`` on nominal data and score it on shifted data.

    Training data comes from the nominal simulator ``A0``; test data comes
    from ``A_star``. The model is fitted with full-batch Adam on the MSE loss.

    Args:
        A_star: Simulator matrix used to generate the test set.
        n_train: Number of training samples drawn from ``A0``.
        n_test: Number of test samples drawn from ``A_star``.
        epochs: Number of full-batch optimisation steps.
        lr: Adam learning rate.

    Returns:
        A tuple ``(test_loss, X_test)``: the test MSE as a Python float and
        the test inputs that were used.
    """
    # Both datasets are sampled before the model is built, so the random
    # stream is consumed in a fixed order.
    X_train, y_train = simulate(A0, n=n_train)
    X_test, y_test = simulate(A_star, n=n_test)

    model = Net()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.MSELoss()

    # Nothing inside the loop changes the mode, so switch to training mode
    # once instead of on every step.
    model.train()
    for _ in range(epochs):
        optimizer.zero_grad()
        loss = loss_fn(model(X_train), y_train)
        loss.backward()
        optimizer.step()

    model.eval()
    with torch.no_grad():
        test_loss = loss_fn(model(X_test), y_test).item()

    return test_loss, X_test

# Sweep the mismatch level delta over [0, 2] and record, for each level, the
# test loss together with both bounds.
delta_vals = torch.linspace(0, 2, 20)
empirical_losses, theoretical_bounds, empirical_bounds = [], [], []

for delta in delta_vals:
    # Perturb A0 along a random direction U scaled to unit norm, so that
    # A* = A0 + delta * U / ||U||.
    U = torch.randn(d, d)
    U = U / U.norm()
    A_star = A0 + delta * U

    # A new model is trained for every mismatch level.
    test_loss, X_test = train_and_eval(A_star)

    empirical_losses.append(test_loss)
    # Theoretical worst-case bound: depends only on the scalar delta.
    theoretical_bounds.append(theoretical_bound(test_loss, delta.item()))
    # Empirical bound: uses the shift actually observed on the test inputs.
    empirical_bounds.append(empirical_bound(test_loss, A0, A_star, X_test))

# Convert everything to numpy arrays for plotting
delta_vals = delta_vals.numpy()
empirical_losses = np.asarray(empirical_losses)
theoretical_bounds = np.asarray(theoretical_bounds)
empirical_bounds = np.asarray(empirical_bounds)

# Global figure styling
plt.rcParams.update({
    "font.family": "serif",
    "font.size": 14,
    "figure.dpi": 300,
    "text.usetex": False,  # set to True only if a LaTeX installation is available
})

fig, axs = plt.subplots(1, 2, figsize=(12, 5), sharey=True)

# One panel per bound: (axis, bound curve, legend label, line colour, title).
# Left: theoretical worst-case bound. Right: empirical data-dependent bound.
panels = (
    (axs[0], theoretical_bounds, 'Worst-Case Bound', 'orange',
     'Worst-Case Generalization Bound'),
    (axs[1], empirical_bounds, 'Empirical Bound', 'green',
     'Data-Dependent Generalization Bound'),
)

for ax, bound_curve, bound_label, bound_color, panel_title in panels:
    # The empirical test loss is drawn in both panels as the reference curve.
    ax.plot(delta_vals, empirical_losses, label='Empirical Test Loss', marker='o', linewidth=2)
    ax.plot(delta_vals, bound_curve, '--', label=bound_label, color=bound_color, linewidth=2)
    ax.set_title(panel_title, fontsize=14)
    ax.set_xlabel(r'Simulator Mismatch $\delta$', fontsize=13)
    ax.set_yscale('log')
    ax.legend(fontsize=12)
    ax.tick_params(axis='both', which='major', labelsize=12)

# The y axis is shared, so only the left panel carries the axis label.
axs[0].set_ylabel('Test MSE Loss', fontsize=13)

# Tight layout and save
fig.tight_layout()
fig.savefig("sgnn_generalization_bounds.pdf", format="pdf", bbox_inches='tight')
plt.close(fig)