# In [None]:
# ==============================================================================
#      Numerical Experiment for Operator-Theoretic Error Analysis
# ==============================================================================
#
# Objective:
# To empirically validate the theory that the local discretization error in
# diffusion model sampling correlates with a proxy for the Koopman operator norm,
# which is tied to the Lipschitz properties of the learned score network.
#
# Steps:
# 1. Train two score-based models on a 2D toy dataset:
#    - Model A: Standard training.
#    - Model B: Training with regularization to enforce smoothness.
# 2. For each model, compute two metrics across time t:
#    - Metric 1 (Theory): A proxy for the operator norm (gradient norm of the score).
#    - Metric 2 (Empirical): The actual local discretization error.
# 3. Plot the results to show the correlation between the two metrics.
#
# ==============================================================================

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from sklearn.datasets import make_moons
import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

# --- 1. Setup and Hyperparameters ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Data parameters
n_samples = 5000
batch_size = 256

# Diffusion parameters (continuous-time VP-SDE on t in [0, T_end]).
# beta(t) is a *rate* that alpha_t / sigma_t integrate over t, not a per-step
# DDPM variance. The familiar discrete schedule (1e-4 .. 0.02 over 1000 steps)
# corresponds to 0.1 .. 20 in continuous time. With the per-step numbers the
# integrated beta over [0, 1] is only ~0.01, i.e. alpha_t(T_end) ~ 0.995 and
# sigma_t(T_end) ~ 0.1, so x_T would be almost clean data rather than the
# N(0, I) prior that generate_samples() starts from.
T_end = 1.0
beta_0 = 0.1
beta_T = 20.0

# Model and Training parameters
n_epochs = 2000
lr = 1e-4
lambda_reg = 0.01 # Regularization strength for Model B

# Analysis parameters
# Start slightly above t = 0, where sigma_t vanishes.
timesteps_to_eval = np.linspace(1e-5, T_end, 20)
n_error_samples = 1000 # Number of samples to estimate error

# --- 2. Diffusion Process Helper Functions (VP-SDE) ---

def beta_t(t):
    """Linear noise-rate schedule beta(t) of the VP-SDE.

    Equals ``beta_0`` at ``t = 0`` and grows with slope ``beta_T - beta_0``,
    reaching ``beta_T`` at ``t = 1``. ``t`` may be a scalar or a tensor; the
    result has the same kind.
    """
    slope = beta_T - beta_0
    return beta_0 + t * slope

def alpha_t(t):
    """Signal coefficient of the perturbation ``x_t = alpha_t * x_0 + sigma_t * noise``.

    Computes ``exp(-0.5 * B(t))`` where
    ``B(t) = beta_0 * t + 0.5 * (beta_T - beta_0) * t**2`` is the integral of
    the linear schedule from 0 to ``t``. ``t`` must be a tensor because the
    result goes through ``torch.exp``.
    """
    quadratic_part = -0.25 * t**2 * (beta_T - beta_0)
    linear_part = 0.5 * t * beta_0
    return torch.exp(quadratic_part - linear_part)

def sigma_t(t):
    """Noise standard deviation of the perturbation ``x_t = alpha_t * x_0 + sigma_t * noise``.

    Computes ``sqrt(1 - exp(-B(t)))`` where
    ``B(t) = beta_0 * t + 0.5 * (beta_T - beta_0) * t**2``. ``t`` must be a
    tensor.

    The value is evaluated as ``sqrt(-expm1(-B))`` rather than
    ``sqrt(1 - exp(-B))``: for small ``B`` the latter cancels catastrophically
    in float32 (``exp(-B)`` rounds to exactly 1), returning a sigma of exactly
    0 and turning every subsequent division by sigma into inf/NaN.
    """
    integrated_beta = 0.5 * t**2 * (beta_T - beta_0) + t * beta_0
    return torch.sqrt(-torch.expm1(-integrated_beta))

# --- 3. Score Network Model ---

class ScoreNet(nn.Module):
    """Small MLP mapping a 2-D point and a scalar time to a 2-D output.

    The time is appended to ``x`` as a raw third input feature; the network
    itself is ``Linear(3, 128) -> SiLU -> Linear(128, 128) -> SiLU -> Linear(128, 2)``.
    """

    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(3, 128), nn.SiLU(),
            nn.Linear(128, 128), nn.SiLU(),
            nn.Linear(128, 2)
        )
        # Time-embedding MLP. It is NOT used by forward(): the output has never
        # depended on it. It is kept only so that state_dict keys and the
        # parameter-initialisation RNG order stay the same as before.
        self.t_embed = nn.Sequential(nn.Linear(1, 128), nn.SiLU(), nn.Linear(128, 128))

    def forward(self, x, t):
        """Evaluate the network.

        Args:
            x: tensor whose last dimension is 2 (rows are points).
            t: tensor with one time value per row of ``x``; reshaped to a
                column with ``view(-1, 1)``.

        Returns:
            Tensor with one 2-D output per row of ``x``.
        """
        # Previously the (discarded) time embedding was also evaluated here on
        # every call; that dead computation has been removed.
        t_column = t.view(-1, 1)
        return self.net(torch.cat([x, t_column], dim=1))

# --- 4. Training Function ---

def train_model(is_regularized=False):
    """Train a ScoreNet on the two-moons dataset with denoising score matching.

    For each mini-batch a time ``t`` is drawn uniformly from ``[0, T_end)``,
    the data are perturbed as ``x_t = alpha_t * x_0 + sigma_t * noise`` and the
    network is fitted so that its output approximates the conditional score
    ``-noise / sigma_t``.

    Args:
        is_regularized: when True, add ``lambda_reg`` times the mean squared
            norm of the gradient of the summed network outputs w.r.t. ``x_t``
            (a cheap smoothness penalty) to the loss.

    Returns:
        The trained ``ScoreNet`` (on ``device``).
    """
    print(f"--- Training Model {'B (Regularized)' if is_regularized else 'A (Standard)'} ---")
    score_net = ScoreNet().to(device)
    optimizer = optim.Adam(score_net.parameters(), lr=lr)

    # Create dataset
    X, _ = make_moons(n_samples=n_samples, noise=0.05)
    dataset = TensorDataset(torch.from_numpy(X).float())
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    for _ in tqdm(range(n_epochs), desc="Training"):
        for data, in loader:
            x0 = data.to(device)
            optimizer.zero_grad()

            # Sample random time t
            t = torch.rand(x0.shape[0], device=device) * T_end

            # Perturb data
            alpha = alpha_t(t).view(-1, 1)
            sigma = sigma_t(t).view(-1, 1)
            noise = torch.randn_like(x0)
            xt = alpha * x0 + sigma * noise

            if is_regularized:
                # Track gradients w.r.t. the input so the same forward pass
                # serves both the DSM loss and the smoothness penalty.
                xt.requires_grad_(True)

            # The network output is the score estimate itself.
            predicted_score = score_net(xt, t)

            # Sigma-weighted DSM loss: || sigma * s_theta(x_t, t) + noise ||^2.
            # This is the plain objective || s_theta - (-noise / sigma) ||^2
            # multiplied by sigma^2. It has the same minimiser but never
            # divides by sigma, so a sigma of 0 (t at or near 0) can no longer
            # produce an infinite loss and NaN weights.
            loss = ((sigma * predicted_score + noise) ** 2).mean()

            # --- Regularization for Model B ---
            if is_regularized:
                # Vector-Jacobian product with a vector of ones: the gradient
                # of the summed outputs w.r.t. x_t (proxy for Lipschitz).
                gradients = torch.autograd.grad(
                    outputs=predicted_score,
                    inputs=xt,
                    grad_outputs=torch.ones_like(predicted_score),
                    create_graph=True
                )[0]
                # Squared L2 norm per sample, taken without an intermediate
                # sqrt, then averaged over the batch.
                reg_loss = gradients.view(gradients.shape[0], -1).pow(2).sum(dim=1).mean()

                loss = loss + lambda_reg * reg_loss

            loss.backward()
            optimizer.step()

    return score_net

# --- 5. Metric Calculation Functions ---

def calculate_proxy_lip(score_net, t_val, n_samples=1000):
    """Calculates the proxy for the operator norm (Metric 1).

    The proxy is the batch mean of the L2 norm of the gradient of the summed
    network outputs w.r.t. the input (a vector-Jacobian product with a vector
    of ones), i.e. the same quantity the training regulariser penalises. It is
    not the full Jacobian operator norm.

    Args:
        score_net: the network to probe; called as ``score_net(x, t)``.
        t_val: scalar time at which to evaluate.
        n_samples: number of evaluation points.

    Returns:
        The proxy as a Python float.
    """
    # NOTE: this function must not run under torch.no_grad(): with gradient
    # tracking disabled the network output has no grad_fn and
    # torch.autograd.grad raises a RuntimeError.
    t = torch.full((n_samples,), float(t_val), device=device)
    # Evaluation points: a crude stand-in for samples from p_t in which the
    # data term is replaced by a standard normal draw.
    z = torch.randn(n_samples, 2, device=device)
    xt = alpha_t(t).view(-1, 1) * torch.randn_like(z) + sigma_t(t).view(-1, 1) * z # Simple approximation of p_t

    # enable_grad() makes this work even if the caller is inside no_grad().
    with torch.enable_grad():
        xt = xt.detach().requires_grad_(True)
        output = score_net(xt, t)
        gradients = torch.autograd.grad(
            outputs=output,
            inputs=xt,
            grad_outputs=torch.ones_like(output),
            create_graph=False
        )[0]

    grad_norm = gradients.view(gradients.shape[0], -1).norm(2, dim=1).mean().item()
    return grad_norm

@torch.no_grad()
def calculate_local_error(score_net, t_val, n_samples=1000, dt=0.01, small_steps=10, t_min=1e-5):
    """Calculates the empirical local discretization error (Metric 2).

    Starting from the same points at time ``t_val``, integrates the
    deterministic part of the reverse-time dynamics backwards by ``dt`` twice:
    once with ``small_steps`` Euler sub-steps (reference) and once with a
    single Euler step. The mean squared difference between the two is returned.

    Args:
        score_net: network whose output is used directly as the score, which is
            what train_model fits it to (target ``-noise / sigma``).
        t_val: scalar start time.
        n_samples: number of start points.
        dt: size of the single large step.
        small_steps: number of sub-steps for the reference solution.
        t_min: smallest time at which the drift is evaluated. Sub-step times
            below it are clamped, because for ``t < 0`` the schedule gives a
            negative beta and the network is queried outside its training
            range.

    Returns:
        Mean squared error between the fine and coarse solutions (float).
    """
    t = torch.full((n_samples,), float(t_val), device=device)
    # Start points: a crude stand-in for samples from p_t in which the data
    # term is replaced by a standard normal draw.
    z = torch.randn(n_samples, 2, device=device)
    xt = alpha_t(t).view(-1, 1) * torch.randn_like(z) + sigma_t(t).view(-1, 1) * z

    def _reverse_drift(x, t_eval):
        # Drift f - g^2 * score = -0.5 * beta * x - beta * score.
        t_eval = t_eval.clamp(min=t_min)
        beta = beta_t(t_eval).view(-1, 1)
        # The network already outputs the score; do not rescale by -1/sigma.
        score = score_net(x, t_eval)
        return -0.5 * beta * (x + 2 * score)

    # (A) "Ground Truth" solution with very small steps
    x_true = xt.clone()
    dt_small = dt / small_steps
    for i in range(small_steps):
        x_true = x_true - _reverse_drift(x_true, t - i * dt_small) * dt_small

    # (B) Approximated solution with one large step
    x_approx = xt - _reverse_drift(xt, t) * dt

    error = ((x_true - x_approx)**2).mean().item()
    return error

# --- 6. Main Execution ---

# Train the two models
model_A = train_model(is_regularized=False)
model_B = train_model(is_regularized=True)

models = {
    'Model A (Standard)': model_A,
    'Model B (Regularized)': model_B,
}
results = {}

# For every model, sweep the evaluation times and record both metrics.
# Within each time step the proxy is computed before the local error, so the
# random draws are consumed in a fixed order.
for name, model in models.items():
    print(f"--- Evaluating {name} ---")
    metric_pairs = [
        (calculate_proxy_lip(model, t_val), calculate_local_error(model, t_val))
        for t_val in tqdm(timesteps_to_eval, desc=f"Evaluating {name}")
    ]
    results[name] = {
        'proxy_lip': np.array([lip for lip, _ in metric_pairs]),
        'error': np.array([err for _, err in metric_pairs]),
    }

# --- 7. Plotting the Results ---

# The seaborn style sheets were renamed to 'seaborn-v0_8-*' in Matplotlib 3.6
# and the old names were removed in 3.8, where plt.style.use('seaborn-whitegrid')
# raises OSError. Use whichever name this Matplotlib provides; if neither
# exists, keep the default style.
for style_name in ('seaborn-v0_8-whitegrid', 'seaborn-whitegrid'):
    if style_name in plt.style.available:
        plt.style.use(style_name)
        break

fig, axes = plt.subplots(1, 3, figsize=(21, 6))
fig.suptitle("Operator-Theoretic Analysis of Discretization Error", fontsize=16)

# Marker used for each model's curve in all three panels.
model_markers = {'Model A (Standard)': 'o', 'Model B (Regularized)': 's'}

# Plot 1: Proxy for Operator Norm vs. Time
ax = axes[0]
for label, marker in model_markers.items():
    ax.plot(timesteps_to_eval, results[label]['proxy_lip'], f'{marker}-', label=label)
ax.set_title("Result 1: Proxy for Operator Norm vs. Time")
ax.set_xlabel("Time (t)")
ax.set_ylabel("Proxy for Operator Norm (Score Gradient Norm)")
ax.legend()
ax.set_yscale('log')

# Plot 2: Local Discretization Error vs. Time
ax = axes[1]
for label, marker in model_markers.items():
    ax.plot(timesteps_to_eval, results[label]['error'], f'{marker}-', label=label)
ax.set_title("Result 2: Local Discretization Error vs. Time")
ax.set_xlabel("Time (t)")
ax.set_ylabel("Empirical Local Error (MSE)")
ax.legend()
ax.set_yscale('log')

# Plot 3: Correlation Plot (one point per evaluated time, no connecting line)
ax = axes[2]
for label, marker in model_markers.items():
    ax.plot(results[label]['proxy_lip'], results[label]['error'], marker, label=label)
ax.set_title("Result 3: Correlation between Theory and Practice")
ax.set_xlabel("Theoretical Proxy (Score Gradient Norm)")
ax.set_ylabel("Empirical Error (MSE)")
ax.legend()
ax.set_xscale('log')
ax.set_yscale('log')

# Leave room at the top for the figure-level title.
plt.tight_layout(rect=[0, 0, 1, 0.96])
plt.show()

# Visualize some generated samples
@torch.no_grad()
def generate_samples(score_net, n_samples=500, n_steps=100):
    """Draw samples by Euler-Maruyama integration of the reverse-time SDE.

    Starts from standard normal noise and takes ``n_steps`` steps of equal
    size over the time grid ``linspace(T_end, 1e-5, n_steps)``. No noise is
    injected on the final step (the one evaluated at ``t == 1e-5``).

    Args:
        score_net: network whose output is used directly as the score, which is
            what train_model fits it to (target ``-noise / sigma``).
        n_samples: number of points to generate.
        n_steps: number of grid points / update steps (must be at least 2,
            since the step size is taken from the first two grid points).

    Returns:
        NumPy array of shape ``(n_samples, 2)``.
    """
    xt = torch.randn(n_samples, 2, device=device)
    ts = np.linspace(T_end, 1e-5, n_steps)
    dt = float(ts[0] - ts[1])

    for t_val in ts:
        t = torch.full((n_samples,), float(t_val), device=device)
        beta = beta_t(t).view(-1, 1)
        # The network already outputs the score; do not rescale by -1/sigma.
        score = score_net(xt, t)
        # Reverse-SDE drift f - g^2 * score = -0.5 * beta * x - beta * score.
        drift = -0.5 * beta * (xt + 2 * score)
        noise = torch.randn_like(xt) if t_val > 1e-5 else torch.zeros_like(xt)
        xt = xt - drift * dt + torch.sqrt(beta * dt) * noise
    return xt.cpu().numpy()

samples_A = generate_samples(model_A)
samples_B = generate_samples(model_B)

# One scatter panel per model, side by side with equal axis scaling.
fig, axes = plt.subplots(1, 2, figsize=(12, 5))
sample_panels = (
    (samples_A, "Samples from Model A (Standard)"),
    (samples_B, "Samples from Model B (Regularized)"),
)
for ax, (points, title) in zip(axes, sample_panels):
    ax.scatter(points[:, 0], points[:, 1], s=5, alpha=0.5)
    ax.set_title(title)
    ax.set_aspect('equal', adjustable='box')
plt.show()