# In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_moons
from torch.distributions import MultivariateNormal, kl_divergence, Normal

# ==========================================
# 1. PVI-GP Classifier Model
# ==========================================
class PVIGPClassifier(nn.Module):
    """Binary GP classifier with a full-covariance Gaussian variational posterior.

    The latent function values at the training inputs have prior
    ``p(f) = N(0, K)`` (RBF kernel) and variational posterior
    ``q(f) = N(q_mu, L L^T)``, where ``L`` is lower triangular with a
    strictly positive diagonal.

    All parameters are created with the dtype and device of ``X_train`` so
    that the model also works for float64 or CUDA inputs.

    Args:
        X_train: Training inputs of shape ``(n, d)``.
    """

    def __init__(self, X_train):
        super().__init__()
        # Non-persistent buffer: moves with ``.to()`` / ``.cuda()`` together
        # with the parameters, but is not added to ``state_dict``.
        self.register_buffer("X_train", X_train, persistent=False)
        self.n = X_train.shape[0]
        # NOTE(review): 1e-5 may be too small for a float32 Cholesky of a
        # large, dense RBF kernel matrix -- confirm on the real data.
        self.jitter = 1e-5

        factory = {"dtype": X_train.dtype, "device": X_train.device}

        # 1. Kernel hyperparameters.
        # Despite the ``log_`` prefix these raw values are mapped through
        # softplus (+0.1), not exp; see ``_kernel_hyperparams``.
        self.log_lengthscale = nn.Parameter(torch.tensor(np.log(0.5), **factory))
        self.log_scale = nn.Parameter(torch.tensor(np.log(1.0), **factory))

        # 2. Variational parameters of q(f) ~ N(mu, L L^T).
        # ``q_L_vec`` holds the lower triangle of L in ``torch.tril_indices``
        # order; diagonal entries are stored before exponentiation.
        self.q_mu = nn.Parameter(torch.zeros(self.n, **factory))
        self.q_L_vec = nn.Parameter(
            torch.randn(self.n * (self.n + 1) // 2, **factory) * 0.05
        )

    def _kernel_hyperparams(self):
        """Return the positive ``(lengthscale, scale)`` pair used by the kernel."""
        # softplus + 0.1 keeps both values bounded away from zero.
        lengthscale = F.softplus(self.log_lengthscale) + 0.1
        scale = F.softplus(self.log_scale) + 0.1
        return lengthscale, scale

    def _train_kernel(self):
        """Return the jittered kernel matrix ``K(X_train, X_train) + jitter * I``."""
        K = self.kernel_matrix(self.X_train, self.X_train)
        eye = torch.eye(self.n, dtype=K.dtype, device=K.device)
        return K + eye * self.jitter

    def kernel_matrix(self, x1, x2):
        """Evaluate the RBF kernel between two sets of points.

        Args:
            x1: Tensor of shape ``(a, d)``.
            x2: Tensor of shape ``(b, d)``.

        Returns:
            Tensor of shape ``(a, b)`` with entries
            ``scale * exp(-0.5 * ||x1_i - x2_j||^2 / lengthscale^2)``.
        """
        lengthscale, scale = self._kernel_hyperparams()
        dist_sq = torch.cdist(x1, x2, p=2) ** 2
        return scale * torch.exp(-0.5 * dist_sq / (lengthscale ** 2))

    def get_q_dist(self):
        """Build the variational posterior ``q(f) = N(q_mu, L L^T)``.

        The Cholesky factor is passed directly as ``scale_tril`` so that
        ``MultivariateNormal`` does not have to re-factorise an n x n
        covariance matrix on every call. Because L has a positive diagonal
        by construction, no jitter is needed for q.

        Returns:
            A ``MultivariateNormal`` over the ``n`` training latents.
        """
        rows, cols = torch.tril_indices(
            row=self.n, col=self.n, device=self.q_L_vec.device
        )
        raw = torch.zeros(
            self.n, self.n, dtype=self.q_L_vec.dtype, device=self.q_L_vec.device
        )
        raw[rows, cols] = self.q_L_vec

        # Diagonal entries must be positive: exponentiate them, keep the
        # strictly-lower part as is.
        positive_diag = torch.diag(torch.exp(torch.diagonal(raw)))
        scale_tril = torch.tril(raw, diagonal=-1) + positive_diag
        return MultivariateNormal(self.q_mu, scale_tril=scale_tril)

    def get_p_dist(self):
        """Build the GP prior ``p(f) = N(0, K + jitter * I)`` at the training inputs.

        Returns:
            A ``MultivariateNormal`` over the ``n`` training latents.

        Raises:
            torch.linalg.LinAlgError: If the jittered kernel matrix is not
                numerically positive definite.
        """
        K = self._train_kernel()
        # Factorise once and hand over the factor instead of the covariance.
        prior_tril = torch.linalg.cholesky(K)
        zero_mean = torch.zeros(self.n, dtype=K.dtype, device=K.device)
        return MultivariateNormal(zero_mean, scale_tril=prior_tril)

    def predict(self, X_test, num_samples=1000):
        """Estimate ``P(y = 1 | x*)`` for each test point by Monte Carlo.

        Only the marginal mean and variance of the latent at every test
        point are computed:

            mu*  = K_sx K^-1 mu
            var* = k(x*, x*) - diag(K_sx K^-1 K_xs)
                             + diag(K_sx K^-1 S K^-1 K_xs),   S = L L^T

        Args:
            X_test: Test inputs of shape ``(m, d)``.
            num_samples: Number of latent samples drawn per test point.

        Returns:
            Tensor of shape ``(m,)`` with the averaged sigmoid of the samples.
        """
        K_xx = self._train_kernel()
        K_sx = self.kernel_matrix(X_test, self.X_train)

        # For the RBF kernel k(x, x) == scale, so the m x m test kernel
        # matrix never needs to be materialised.
        _, scale = self._kernel_hyperparams()
        K_ss_diag = scale.to(K_sx.dtype).expand(K_sx.shape[0])

        q_dist = self.get_q_dist()

        # One factorisation of K_xx serves both right-hand sides.
        rhs = torch.cat([q_dist.mean.unsqueeze(1), K_sx.t()], dim=1)
        solved = torch.linalg.solve(K_xx, rhs)
        K_inv_mu = solved[:, 0]
        K_inv_K_xs = solved[:, 1:]  # (n, m)

        mu_star = K_sx @ K_inv_mu

        # Row-wise sums give the diagonals without forming m x m products.
        term2 = (K_sx * K_inv_K_xs.t()).sum(dim=1)
        # K_xx is symmetric, so K_sx K^-1 = (K^-1 K_xs)^T and
        # diag(A^T L L^T A) = column-wise ||L^T A||^2, which is >= 0.
        term3 = (q_dist.scale_tril.t() @ K_inv_K_xs).pow(2).sum(dim=0)

        var_star = (K_ss_diag - term2 + term3).clamp(min=1e-6)

        # Sample the marginal latents and average the logistic link.
        q_star = Normal(mu_star, torch.sqrt(var_star))
        f_samples = q_star.sample((num_samples,))
        probs = torch.sigmoid(f_samples).mean(dim=0)

        return probs

# ==========================================
# 2. PVI Objective (Log Score)
# ==========================================
def pvi_objective(model, y_target, num_samples=32, lambda_reg=0.05):
    """Return the negative predictive log score plus a weighted KL penalty.

    For every data point the predictive likelihood is the Monte Carlo
    average of the Bernoulli likelihood over samples from ``q(f)``, and the
    log is taken *after* averaging (log-mean-exp). The computation is done
    entirely in log space (``logsigmoid`` + ``logsumexp``), so saturated
    sigmoids neither underflow to zero nor lose their gradient.

    Args:
        model: Object exposing ``get_q_dist()`` and ``get_p_dist()``, each
            returning a distribution over the latent function values.
        y_target: Labels, one per latent dimension; entries equal to 1 are
            treated as the positive class, everything else as negative.
        num_samples: Number of reparameterised samples drawn from ``q(f)``;
            must be at least 1.
        lambda_reg: Weight of the ``KL(q || p)`` regulariser.

    Returns:
        Scalar tensor ``-log_score + lambda_reg * KL(q || p)``.
    """
    import math

    q_dist = model.get_q_dist()

    # 1. Reparameterisation trick: samples stay differentiable w.r.t. q.
    f_samples = q_dist.rsample((num_samples,))

    # 2. Logistic link in log space:
    #    log p(y=1 | f) = logsigmoid(f), log p(y=0 | f) = logsigmoid(-f).
    is_positive = y_target.unsqueeze(0) == 1
    log_lik = torch.where(
        is_positive, F.logsigmoid(f_samples), F.logsigmoid(-f_samples)
    )

    # 3. Log-mean-exp over the sample axis: average first, then take the log.
    log_predictive = torch.logsumexp(log_lik, dim=0) - math.log(num_samples)
    log_score = log_predictive.sum()

    # 4. Regularisation towards the prior.
    kl_reg = kl_divergence(q_dist, model.get_p_dist())

    return -log_score + lambda_reg * kl_reg

# ==========================================
# 3. Training Loop
# ==========================================
def train(
    n_samples=1000,
    noise=0.2,
    random_state=42,
    num_epochs=5000,
    num_samples=32,
    lambda_reg=0.05,
    lr_variational=0.02,
    lr_kernel=0.01,
    log_every=200,
):
    """Fit a ``PVIGPClassifier`` on the two-moons dataset.

    Called without arguments this reproduces the original fixed experiment;
    every setting can be overridden by keyword.

    Args:
        n_samples: Number of points generated by ``make_moons``.
        noise: Noise level passed to ``make_moons``.
        random_state: Seed passed to ``make_moons`` (data only; the torch
            RNG is not seeded here).
        num_epochs: Number of full-batch optimisation steps.
        num_samples: Monte Carlo samples per step for ``pvi_objective``.
        lambda_reg: KL weight for ``pvi_objective``.
        lr_variational: Adam learning rate for ``q_mu`` / ``q_L_vec``.
        lr_kernel: Adam learning rate for the kernel hyperparameters.
        log_every: Print the loss every this many epochs; a falsy value
            disables logging.

    Returns:
        Tuple ``(model, X, y)`` with the trained model and the float32
        training tensors.
    """
    # Data: noisy moons (to test robustness).
    X, y = make_moons(n_samples=n_samples, noise=noise, random_state=random_state)
    X = torch.tensor(X, dtype=torch.float32)
    y = torch.tensor(y, dtype=torch.float32)

    model = PVIGPClassifier(X)

    # Differential learning rates: the variational parameters move faster
    # than the kernel hyperparameters.
    optimizer = torch.optim.Adam([
        {'params': [model.q_mu, model.q_L_vec], 'lr': lr_variational},
        {'params': [model.log_lengthscale, model.log_scale], 'lr': lr_kernel},
    ])

    print("Training PVI-GP (Logistic Likelihood)...")
    for epoch in range(num_epochs):
        optimizer.zero_grad()
        loss = pvi_objective(model, y, num_samples=num_samples, lambda_reg=lambda_reg)
        loss.backward()
        optimizer.step()

        if log_every and epoch % log_every == 0:
            print(f"Epoch {epoch:4d} | Loss: {loss.item():.4f}")

    return model, X, y

# ==========================================
# 4. Visualization
# ==========================================
def visualize(model, X_train, y_train):
    """Plot the predictive probability surface together with the training data.

    The model is switched to eval mode, evaluated on a 0.1-spaced grid that
    covers the training inputs with a 0.5 margin, and the result is shown
    as a filled contour plot with the training points overlaid.

    Args:
        model: Trained classifier exposing ``predict(X, num_samples=...)``.
        X_train: Tensor of shape ``(n, 2)`` with the training inputs.
        y_train: Tensor of shape ``(n,)`` with labels 0 / 1.
    """
    model.eval()

    # Explicit NumPy copies: NumPy / matplotlib cannot consume CUDA tensors,
    # and implicit tensor-to-array conversion is fragile.
    X_np = X_train.detach().cpu().numpy()
    y_np = y_train.detach().cpu().numpy()

    x_min, x_max = float(X_np[:, 0].min()) - 0.5, float(X_np[:, 0].max()) + 0.5
    y_min, y_max = float(X_np[:, 1].min()) - 0.5, float(X_np[:, 1].max()) + 0.5
    xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.1),
                         np.arange(y_min, y_max, 0.1))

    # The grid must live on the same device / dtype as the training inputs.
    X_test = torch.tensor(
        np.c_[xx.ravel(), yy.ravel()], dtype=X_train.dtype, device=X_train.device
    )

    with torch.no_grad():
        probs = model.predict(X_test, num_samples=500)
        probs = probs.reshape(xx.shape).cpu().numpy()

    plt.figure(figsize=(8, 6))
    # NOTE(review): 'RdBu' maps low values to red and high values to blue,
    # while class 1 points are drawn in red below -- confirm that the
    # inverted colour convention is intended.
    contour = plt.contourf(xx, yy, probs, levels=20, cmap='RdBu', alpha=0.8, vmin=0, vmax=1)
    plt.colorbar(contour, label="Predictive Probability")

    plt.scatter(X_np[y_np == 0, 0], X_np[y_np == 0, 1], c='blue', edgecolors='k', label='Class 0')
    plt.scatter(X_np[y_np == 1, 0], X_np[y_np == 1, 1], c='red', edgecolors='k', label='Class 1')

    plt.title("PVI GP Classifier (Logistic Link, No Approximation)")
    plt.legend()
    plt.show()

# Script entry point: train the classifier, then plot its predictive surface.
# The results are bound to module-level names (trained_model, X, y).
if __name__ == "__main__":
    trained_model, X, y = train()
    visualize(trained_model, X, y)