In [3]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.neighbors import NearestNeighbors
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_squared_error
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import trange

# ==============================
# Simulator
# ==============================

def sample_theta(n, a=0.5, b=1.5):
    """Draw ``n`` parameter pairs i.i.d. from Uniform(a, b).

    Parameters
    ----------
    n : int
        Number of parameter pairs to draw.
    a, b : float
        Lower and upper bound of the uniform distribution.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(n, 2)``; one pair per row.  Sampling uses NumPy's
        global random state.
    """
    shape = (n, 2)
    return np.random.uniform(a, b, shape)

def simulate(theta, T=10, sigma=0.05):
    """Simulate noisy trajectories of a diagonal 2-D linear system.

    For every row ``(alpha, beta)`` of ``theta`` the state starts at
    ``(1, 1)`` and is advanced ``T`` times by
    ``x <- diag(alpha, beta) @ x + eps`` with ``eps ~ N(0, sigma**2)`` drawn
    independently per coordinate from NumPy's global random state.

    Parameters
    ----------
    theta : array-like, shape (n, 2)
        One ``(alpha, beta)`` pair per row.
    T : int
        Number of time steps to record.
    sigma : float
        Standard deviation of the additive Gaussian noise.

    Returns
    -------
    numpy.ndarray, shape (n, 2 * T)
        Row ``i`` holds the states after steps 1..T of trajectory ``i``,
        flattened time-major: ``[x1[0], x1[1], x2[0], x2[1], ...]``.

    Raises
    ------
    ValueError
        If ``theta`` is non-empty and not of shape ``(n, 2)``.

    Notes
    -----
    All trajectories are advanced together, one time step at a time, so
    noise is drawn in ``(n, 2)`` slabs per step.  With a fixed seed the
    individual numbers therefore differ from a trajectory-by-trajectory
    draw order; the distribution of the output is the same.
    """
    theta = np.asarray(theta, dtype=float)
    n_steps = max(T, 0)

    if theta.size == 0:
        # No parameter sets: still return a proper 2-D array so callers can
        # index / scale / stack it like any other batch.
        return np.empty((0, 2 * n_steps))
    if theta.ndim != 2 or theta.shape[1] != 2:
        raise ValueError(
            f"theta must have shape (n, 2), got {theta.shape}"
        )

    n = theta.shape[0]
    state = np.ones((n, 2))
    traj = np.empty((n, n_steps, 2))
    for t in range(n_steps):
        # The transition matrix is diagonal, so A @ x is an element-wise
        # product with (alpha, beta); this lets all n systems step at once.
        state = theta * state + np.random.normal(0, sigma, size=(n, 2))
        traj[:, t, :] = state
    return traj.reshape(n, 2 * n_steps)

# ==============================
# Bayes-optimal predictor
# ==============================

def compute_bayes_optimal_predictor(x_query, x_mc, theta_mc, k=50):
    """Estimate theta for each query by Gaussian-weighted k-NN averaging.

    For every row of ``x_query`` the ``k`` nearest rows of ``x_mc`` are
    looked up, weighted with a Gaussian kernel whose bandwidth is half the
    median neighbour distance of that query, and the matching rows of
    ``theta_mc`` are averaged with those weights.  The script uses this as
    its Monte Carlo stand-in for the Bayes predictor.

    Parameters
    ----------
    x_query : array-like, shape (n_query, d)
        Points at which to predict.
    x_mc : array-like, shape (n_ref, d)
        Reference simulations searched for neighbours.
    theta_mc : array-like, shape (n_ref, p)
        Parameters that generated the rows of ``x_mc``.
    k : int
        Number of neighbours; clamped to ``n_ref`` when the reference pool
        is smaller than ``k``.

    Returns
    -------
    numpy.ndarray, shape (n_query, p)
        Kernel-weighted mean of the neighbours' parameters.
    """
    theta_mc = np.asarray(theta_mc)

    # A reference pool smaller than k would make kneighbors raise; using
    # every available point is the natural degenerate case.
    n_neighbors = min(k, len(x_mc))

    # Not called `nn`: at module level that alias refers to torch.nn.
    knn = NearestNeighbors(n_neighbors=n_neighbors).fit(x_mc)
    dists, indices = knn.kneighbors(x_query)

    # Per-query bandwidth; the epsilon keeps the division finite when at
    # least half of the neighbours coincide with the query (median == 0).
    median_dist = np.median(dists, axis=1, keepdims=True)
    bandwidth = 0.5 * median_dist + 1e-8

    weights = np.exp(-dists**2 / (2 * bandwidth**2))
    weights /= weights.sum(axis=1, keepdims=True)

    # (n_query, k) weights x (n_query, k, p) neighbours -> (n_query, p)
    return np.einsum('ij,ijk->ik', weights, theta_mc[indices])

# ==============================
# Normalizer
# ==============================

class Normalizer:
    """Bundle of two ``StandardScaler`` objects: one for inputs, one for targets.

    ``x_scaler`` and ``y_scaler`` are public so existing code that reaches
    into them keeps working.
    """

    def __init__(self):
        self.x_scaler = StandardScaler()
        self.y_scaler = StandardScaler()

    def fit(self, x, y):
        """Fit the input scaler on ``x`` and the target scaler on ``y``.

        Returns ``self`` so construction and fitting can be chained
        (``Normalizer().fit(x, y)``).
        """
        self.x_scaler.fit(x)
        self.y_scaler.fit(y)
        return self

    def transform(self, x, y):
        """Return ``(scaled_x, scaled_y)`` using the fitted scalers."""
        return self.transform_x(x), self.transform_y(y)

    def transform_x(self, x):
        """Scale inputs only."""
        return self.x_scaler.transform(x)

    def transform_y(self, y):
        """Scale targets only (counterpart of ``transform_x``)."""
        return self.y_scaler.transform(y)

    def inverse_y(self, y):
        """Map scaled targets back to their original units."""
        return self.y_scaler.inverse_transform(y)

# ==============================
# SGNN Model
# ==============================

class ResidualBlock(nn.Module):
    """Width-preserving residual unit.

    The branch is Linear -> LayerNorm -> GELU -> Dropout -> Linear ->
    LayerNorm, all of width ``dim``; ``forward`` adds the branch output to
    the input.

    Parameters
    ----------
    dim : int
        Feature width of input and output.
    dropout : float
        Drop probability of the single Dropout layer in the branch.
    """

    def __init__(self, dim, dropout=0.05):
        super().__init__()
        # Layers are created in branch order so parameter initialisation
        # consumes the RNG in the same sequence as the indices in `block`.
        layers = [
            nn.Linear(dim, dim),
            nn.LayerNorm(dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim, dim),
            nn.LayerNorm(dim),
        ]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        """Return ``x`` plus the branch applied to ``x``."""
        branch_out = self.block(x)
        return x + branch_out

class SGNN(nn.Module):
    """Residual MLP regressor: input projection, residual stack, linear head.

    Parameters
    ----------
    input_dim : int
        Width of the input vectors.
    hidden_dim : int
        Width of the projection and of every residual block.
    n_blocks : int
        Number of ``ResidualBlock`` modules stacked after the projection.
    output_dim : int
        Width of the prediction.  Defaults to 2, the value that was
        previously hard-coded.
    dropout : float
        Dropout probability forwarded to every residual block.  Defaults to
        0.05, the blocks' own default.
    """

    def __init__(self, input_dim=20, hidden_dim=1024, n_blocks=8,
                 output_dim=2, dropout=0.05):
        super().__init__()
        self.input_proj = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
        )
        self.blocks = nn.Sequential(
            *[ResidualBlock(hidden_dim, dropout=dropout) for _ in range(n_blocks)]
        )
        self.out_proj = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Map a batch ``(batch, input_dim)`` to ``(batch, output_dim)``."""
        x = self.input_proj(x)
        x = self.blocks(x)
        return self.out_proj(x)

# ==============================
# Main Training + Evaluation Loop
# ==============================

# Reference simulations: the pool the k-NN Bayes estimator searches.
theta_ref = sample_theta(10000)
x_ref = simulate(theta_ref)

# Separate draw used only to score the Bayes estimator.
theta_val = sample_theta(1000)
x_val = simulate(theta_val)

# Standardise inputs and targets with statistics of the reference pool.
normalizer = Normalizer()
normalizer.fit(x_ref, theta_ref)
x_ref_norm = normalizer.transform_x(x_ref)
theta_ref_norm = normalizer.y_scaler.transform(theta_ref)
x_val_norm, theta_val_norm = normalizer.transform(x_val, theta_val)

# Predict in standardised space, then map back to θ units before scoring.
bayes_y_val_norm = compute_bayes_optimal_predictor(
    x_val_norm,
    x_ref_norm,
    theta_ref_norm,
)
bayes_y_val = normalizer.inverse_y(bayes_y_val_norm)
bayes_rmse = np.sqrt(mean_squared_error(bayes_y_val, theta_val))
print(f"[Bayes] RMSE to true θ: {bayes_rmse:.4f}")

# SGNN training on 5M
theta_train = sample_theta(5_000_000)
x_train = simulate(theta_train)

# 0.5% of the simulations are held out for periodic evaluation.
x_tr, x_holdout, y_tr, y_holdout = train_test_split(x_train, theta_train, test_size=0.005)

# x_ref_norm / theta_ref_norm were standardised with the normalizer fitted on
# the reference set.  Keep a handle on it before `normalizer` is rebound to
# the training-set statistics, so the k-NN query and its inverse transform
# use the same coordinates as the reference pool.
ref_normalizer = normalizer

normalizer = Normalizer()
normalizer.fit(x_tr, y_tr)
x_tr_norm, y_tr_norm = normalizer.transform(x_tr, y_tr)
x_hold_norm = normalizer.transform_x(x_holdout)

# Bayes predictions for the holdout set: query and invert in reference
# coordinates to obtain θ units ...
y_hold_bayes = ref_normalizer.inverse_y(
    compute_bayes_optimal_predictor(
        ref_normalizer.transform_x(x_holdout), x_ref_norm, theta_ref_norm
    )
)
# ... then express them in training coordinates, which is the space the
# network's raw outputs live in.
y_hold_bayes_norm = normalizer.y_scaler.transform(y_hold_bayes)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# The whole training set is placed on `device` up front; the loader then
# only slices already-resident tensors.
x_train_tensor = torch.tensor(x_tr_norm, dtype=torch.float32).to(device)
y_train_tensor = torch.tensor(y_tr_norm, dtype=torch.float32).to(device)
train_dataset = torch.utils.data.TensorDataset(x_train_tensor, y_train_tensor)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=512, shuffle=True)

model = SGNN(input_dim=x_tr.shape[1]).to(device)
optimizer = optim.AdamW(model.parameters(), lr=1e-4)
loss_fn = nn.MSELoss()
# One cosine half-period spanning the single pass over the loader.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=len(train_loader))

# Tracking
mse_to_bayes_list = []  # holds RMSE values despite the name (name kept for the plotting code)
sgnn_vs_true_rmse_list = []
bayes_vs_true_rmse_list = []
training_steps = []

# Evaluate roughly six times per pass; max(1, ...) avoids a modulo-by-zero
# when the loader has fewer than six batches.
eval_every = max(1, len(train_loader) // 6)

# Loop invariants: the holdout inputs on the device and the Bayes baseline.
x_hold_tensor = torch.tensor(x_hold_norm, dtype=torch.float32).to(device)
bayes_vs_true_rmse = np.sqrt(mean_squared_error(y_hold_bayes, y_holdout))

samples_seen = 0

for batch_idx, (xb, yb) in enumerate(train_loader):
    model.train()  # re-enable dropout after any evaluation below
    preds = model(xb)
    loss = loss_fn(preds, yb)
    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()
    scheduler.step()

    # Count actual rows rather than assuming a fixed batch size; the final
    # batch of a pass may be smaller.
    samples_seen += xb.shape[0]

    if (batch_idx + 1) % eval_every == 0:
        model.eval()
        with torch.no_grad():
            val_preds = model(x_hold_tensor)
            val_preds_np = val_preds.cpu().numpy()

        # Distance to the Bayes predictions is measured in normalised target
        # units: the raw network outputs are already in that space.
        rmse_to_bayes = np.sqrt(mean_squared_error(y_hold_bayes_norm, val_preds_np))

        # Distance to the true parameters is measured in θ units.
        val_preds_denorm = normalizer.inverse_y(val_preds_np)
        sgnn_vs_true_rmse = np.sqrt(mean_squared_error(val_preds_denorm, y_holdout))
        step_count = samples_seen  # samples seen

        print(f"[Step {step_count}] "
              f"SGNN→Bayes RMSE = {rmse_to_bayes:.4f}, "
              f"SGNN→θ RMSE = {sgnn_vs_true_rmse:.4f}, "
              f"Bayes→θ RMSE = {bayes_vs_true_rmse:.4f}")

        mse_to_bayes_list.append(rmse_to_bayes)
        sgnn_vs_true_rmse_list.append(sgnn_vs_true_rmse)
        bayes_vs_true_rmse_list.append(bayes_vs_true_rmse)
        training_steps.append(step_count)

# === Plotting ===
plt.rcParams.update({
    "font.family": "serif",
    "font.size": 14,
    "figure.dpi": 300,
    "text.usetex": False
})

# The SGNN curve is evaluated on the holdout split, so the baseline drawn next
# to it should be the Bayes RMSE on that same split.  Fall back to the
# validation-set figure only if no evaluation was recorded.
bayes_baseline_rmse = bayes_vs_true_rmse_list[-1] if bayes_vs_true_rmse_list else bayes_rmse

fig, axs = plt.subplots(1, 2, figsize=(12, 5), sharex=False)

# Left panel: distance between network and Bayes predictions (RMSE).
axs[0].plot(training_steps, mse_to_bayes_list, label='SGNN → Bayes RMSE', marker='o', linewidth=2)
axs[0].set_title('Convergence to Bayes Predictor', fontsize=14)
axs[0].set_xlabel('Training Samples Seen', fontsize=13)
axs[0].set_ylabel('RMSE to Bayes Predictions', fontsize=13)  # the plotted values are RMSE, not MSE
axs[0].legend(fontsize=12)  # without this the line label above is never shown
axs[0].grid(True)
axs[0].tick_params(axis='both', which='major', labelsize=12)

# Right panel: error against the true parameters, with the Bayes baseline.
axs[1].plot(training_steps, sgnn_vs_true_rmse_list, label='SGNN → θ RMSE', marker='o', linewidth=2)
axs[1].axhline(y=bayes_baseline_rmse, color='green', linestyle='--', linewidth=2, label='Bayes → θ RMSE')
axs[1].set_title('Parameter Estimation Accuracy', fontsize=14)
axs[1].set_xlabel('Training Samples Seen', fontsize=13)
axs[1].set_ylabel('RMSE to True θ', fontsize=13)
axs[1].legend(fontsize=12)
axs[1].grid(True)
axs[1].tick_params(axis='both', which='major', labelsize=12)

plt.tight_layout()
plt.savefig("sgnn_accuracy.pdf", format="pdf", bbox_inches='tight')
plt.close(fig)  # release this figure explicitly

[Bayes] RMSE to true θ: 0.0271
[Step 828928] SGNN→Bayes RMSE = 0.0845, SGNN→θ RMSE = 0.0325, Bayes→θ RMSE = 0.0269
[Step 1657856] SGNN→Bayes RMSE = 0.0527, SGNN→θ RMSE = 0.0243, Bayes→θ RMSE = 0.0269
[Step 2486784] SGNN→Bayes RMSE = 0.0538, SGNN→θ RMSE = 0.0237, Bayes→θ RMSE = 0.0269
[Step 3315712] SGNN→Bayes RMSE = 0.0519, SGNN→θ RMSE = 0.0232, Bayes→θ RMSE = 0.0269
[Step 4144640] SGNN→Bayes RMSE = 0.0526, SGNN→θ RMSE = 0.0228, Bayes→θ RMSE = 0.0269
[Step 4973568] SGNN→Bayes RMSE = 0.0509, SGNN→θ RMSE = 0.0226, Bayes→θ RMSE = 0.0269
