In [None]:
def calculate_loo_means(model, likelihood, train_x, train_y):
    """
    Calculate leave-one-out (LOO) predictive means of a GP in closed form.

    For a GP with prior mean m and prior covariance K, observed with Gaussian
    noise, let C = K + diag(σ²). The predictive mean of yᵢ given all other
    targets is

        μᵢ = yᵢ - [C⁻¹(y - m)]ᵢ / [C⁻¹]ᵢᵢ

    so all n LOO means come from one Cholesky factorization instead of n refits.

    The prior is taken from ``model.forward`` rather than ``model(...)``:
    when a GPyTorch ExactGP is in eval mode, ``model(train_x)`` returns the
    *posterior* conditioned on the training data, which is not the K the
    formula requires.

    Args:
        model: GPyTorch ExactGP model whose ``forward`` returns the prior
            distribution (an object with ``.mean`` and ``.covariance_matrix``).
        likelihood: GPyTorch Gaussian likelihood. ``likelihood.noise`` may be a
            scalar / single-element tensor (homoskedastic) or hold one variance
            per training point.
        train_x: Training inputs, shape (n,) or (n, d).
        train_y: Training targets, shape (n,).

    Returns:
        torch.Tensor: LOO predictive means, shape (n,).

    Raises:
        torch.linalg.LinAlgError: if K + noise is not positive definite
            (raised by ``torch.linalg.cholesky``).
    """
    with torch.no_grad():
        # ExactGP.__call__ unsqueezes 1-D inputs before forward(); mirror that
        # because we call forward() directly.
        if train_x.dim() == 1:
            train_x = train_x.unsqueeze(-1)

        # Prior distribution of the latent function at the training inputs.
        prior = model.forward(train_x)
        K = prior.covariance_matrix
        n = K.shape[-1]

        # Add observation noise to the diagonal. Reshaping to a vector and
        # expanding to length n supports both scalar and per-point noise.
        noise = torch.as_tensor(likelihood.noise, dtype=K.dtype, device=K.device)
        noise = noise.reshape(-1).expand(n)
        C = K + torch.diag(noise)

        # The LOO identity applies to residuals about the prior mean; with a
        # non-zero mean, using y directly would bias every prediction.
        residual = train_y - prior.mean

        L = torch.linalg.cholesky(C)
        # alpha = C⁻¹(y - m) via triangular solves (more accurate than K_inv @ r).
        alpha = torch.cholesky_solve(residual.unsqueeze(-1), L).squeeze(-1)
        # The diagonal of C⁻¹ still requires the (Cholesky-based) inverse.
        C_inv_diag = torch.cholesky_inverse(L).diagonal()

        loo_means = train_y - alpha / C_inv_diag
        return loo_means

# Example usage:
if __name__ == "__main__":
    # Generate synthetic data: one period of a sine wave plus Gaussian noise
    # with variance 0.04. randn_like draws the noise on train_x's device, so
    # the addition also works when `device` is a GPU.
    train_x = torch.linspace(0, 1, 100).to(device)
    train_y = torch.sin(train_x * (2 * math.pi)) + torch.randn_like(train_x) * math.sqrt(0.04)
    train_x = train_x.reshape(-1, 1)

    # Initialize model with fixed parameters
    likelihood = gpytorch.likelihoods.GaussianLikelihood().to(device)
    model = ExactGPModel(train_x, train_y, likelihood).to(device)

    # Set fixed parameters (assumes covar_module is a ScaleKernel wrapping a
    # kernel with a lengthscale, as accessed below)
    model.covar_module.base_kernel.lengthscale = 0.2
    model.covar_module.outputscale = 1.0
    likelihood.noise = 0.04

    # Disable gradient tracking
    for param in model.parameters():
        param.requires_grad = False

    # LOO needs the *prior* covariance of the training data. Stay in train
    # mode: in eval mode, calling an ExactGP on its own training inputs
    # returns the posterior, whose covariance would corrupt the LOO formula.
    model.train()
    likelihood.train()
    loo_means = calculate_loo_means(model, likelihood, train_x, train_y)

    # Plot results
    plt.figure(figsize=(10, 6))
    plt.plot(train_x.cpu().numpy(), train_y.cpu().numpy(), 'k*', label='Actual')
    plt.plot(train_x.cpu().numpy(), loo_means.cpu().numpy(), 'r-', label='LOO Predictions')
    plt.xlabel('x')
    plt.ylabel('y')
    plt.title('LOO Predictive Means')
    plt.legend()
    plt.show()

    # K, K_inv and K_inv_diag are locals of calculate_loo_means and are not
    # visible here, so recompute them from the prior for the diagnostics.
    with torch.no_grad():
        K = model.forward(train_x).covariance_matrix
        K = K + likelihood.noise * torch.eye(K.shape[0], dtype=K.dtype, device=K.device)
        K_inv = torch.cholesky_inverse(torch.linalg.cholesky(K))
        K_inv_diag = K_inv.diag()

    # Print numerical precision info
    print("\nNumerical Precision Information:")
    print(f"Condition number of K: {torch.linalg.cond(K).item():.2e}")
    print(f"Max absolute value in K_inv: {torch.max(torch.abs(K_inv)).item():.2e}")
    print(f"Min diagonal element of K_inv: {torch.min(K_inv_diag).item():.2e}")