In [1]:
import torch

# Report current CUDA memory usage in MB (bytes -> MB via 1024**2).
bytes_per_mb = 1024**2
allocated_mb = torch.cuda.memory_allocated() / bytes_per_mb
reserved_mb = torch.cuda.memory_reserved() / bytes_per_mb

print(f"Allocated: {allocated_mb:.2f} MB")
# NOTE: labelled "Cached", but the value comes from torch.cuda.memory_reserved()
print(f"Cached:    {reserved_mb:.2f} MB")

Allocated: 0.00 MB
Cached:    0.00 MB


In [4]:
# REAL DATA EXPERIMENTS
#               _                 _   _      
#              | |               | | (_)     
#    __ _ _ __ | |_ __ _ _ __ ___| |_ _  ___ 
#   / _` | '_ \| __/ _` | '__/ __| __| |/ __|
#  | (_| | | | | || (_| | | | (__| |_| | (__ 
#   \__,_|_| |_|\__\__,_|_|  \___|\__|_|\___|
# 
# Name of the model under test. It is used to look up model-specific settings in
# `configs` via getattr, to gate the conditional imports below and to name output files.
model_name = "dfNGP"

# import configs so we can access the hypers with getattr
import configs
from configs import PATIENCE, MAX_NUM_EPOCHS, NUM_RUNS, WEIGHT_DECAY

# Reiterating import for visibility
# NOTE: these self-assignments are no-ops; they only list the imported constants
MAX_NUM_EPOCHS = MAX_NUM_EPOCHS
NUM_RUNS = NUM_RUNS
WEIGHT_DECAY = WEIGHT_DECAY
PATIENCE = PATIENCE

# TODO: Delete overwrite, run full
# NOTE(review): with a single run, the "std" row of the metrics summary is computed from
# one sample only - confirm this override is removed before reporting results
NUM_RUNS = 1

# assign model-specific variable
MODEL_LEARNING_RATE = getattr(configs, f"{model_name}_REAL_LEARNING_RATE")
# NOTE(review): the next line discards the learning rate read from configs just above and
# hard-codes 0.005 instead - confirm this override is intentional
MODEL_LEARNING_RATE = 0.005
MODEL_REAL_RESULTS_DIR = getattr(configs, f"{model_name}_REAL_RESULTS_DIR")
import os
# create the results directory if it does not exist yet (no error if it already exists)
os.makedirs(MODEL_REAL_RESULTS_DIR, exist_ok = True)

# imports for probabilistic models
if model_name in ["GP", "dfGP", "dfNGP"]:
    from GP_models import GP_predict
    from metrics import compute_NLL_sparse, compute_NLL_full
    from configs import L_RANGE
    if model_name in ["dfGP", "dfNGP"]:
        from configs import SIGMA_F_RANGE
    if model_name == "GP":
        from configs import B_DIAGONAL_RANGE, B_OFFDIAGONAL_RANGE

# for all models with NN components train on batches
if model_name in ["dfNGP", "dfNN", "PINN"]:
    from configs import BATCH_SIZE

# dfNN is needed both as a standalone model and as the mean function of dfNGP
if model_name in ["dfNGP", "dfNN"]:
    from NN_models import dfNN

# universals
from metrics import compute_RMSE, compute_MAE, compute_divergence_field

# basics
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim

# utilitarian
from utils import set_seed
# reproducibility: seed once, before any random initialisation below
set_seed(42)
import gc

# setting device to GPU if available, else CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# overwrite if needed: # device = 'cpu'
print('Using device:', device)
print()

### START TIMING ###
import time
start_time = time.time()  # Start timing after imports

#############################
### LOOP 1 - over REGIONS ###
#############################

# For region_name in ["regiona", "regionb", "regionc"]:
for region_name in ["regionc"]:

    print(f"\nTraining for {region_name.upper()}...")

    # Store metrics for the current region (one row per run; used for both CSV reports below)
    region_results = []

    ##########################################
    ### x_train & y_train, x_test & y_test ###
    ##########################################

    # define paths based on region_name
    path_to_training_tensor = "data/real_data/" + region_name + "_train_tensor.pt"
    path_to_test_tensor = "data/real_data/" + region_name + "_test_tensor.pt"

    # load and transpose to have rows as points
    # NOTE(security): weights_only = False allows arbitrary unpickling; only load trusted files
    train = torch.load(path_to_training_tensor, weights_only = False).T
    test = torch.load(path_to_test_tensor, weights_only = False).T

    # The train and test tensors have the following columns:
    # [:, 0] = x
    # [:, 1] = y
    # [:, 2] = surface elevation (s)
    # [:, 3] = ice flux in x direction (u)
    # [:, 4] = ice flux in y direction (v)
    # [:, 5] = ice flux error in x direction (u_err)
    # [:, 6] = ice flux error in y direction (v_err)

    # train
    x_train = train[:, [0, 1]].to(device)
    y_train = train[:, [3, 4]].to(device)

    # test
    x_test = test[:, [0, 1]].to(device)
    y_test = test[:, [3, 4]].to(device)

    # local measurement errors as noise, stacked as [u_err_1..u_err_N, v_err_1..v_err_N]
    train_noise_diag = torch.concat((train[:, 5], train[:, 6]), dim = 0).to(device)

    # Print train details
    print(f"=== {region_name.upper()} ===")
    print(f"Training inputs shape: {x_train.shape}")
    print(f"Training observations shape: {y_train.shape}")
    print(f"Training inputs dtype: {x_train.dtype}")
    print()

    # Print test details
    print(f"=== {region_name.upper()} ===")
    print(f"Test inputs shape: {x_test.shape}")
    print(f"Test observations shape: {y_test.shape}")
    print(f"Test inputs dtype: {x_test.dtype}")
    print()

    ##################################
    ### LOOP 2 - over training run ###
    ##################################

    # NOTE: GPs don't train on batches, use full data, even for dfNGP

    for run in range(NUM_RUNS):

        print(f"\n--- Training Run {run + 1}/{NUM_RUNS} ---")

        # NOTE: The dfNN mean function uses autograd and thus requires x_train to be set to .requires_grad
        x_train = x_train.to(device).requires_grad_(True)
        # same for x_test for eval round
        x_test = x_test.to(device).requires_grad_(True)

        ### Initialise dfNGP hyperparameters ###
        # 3 learnable HPs, same as dfGP
        # NOTE: at every run this initialisation changes, introducing some randomness
        # HACK: we need to use nn.Parameter for trainable hypers to avoid leaf variable error

        # initialising (trainable) output scalar from a uniform distribution over a predefined range
        sigma_f = nn.Parameter(torch.empty(1, device = device).uniform_( * SIGMA_F_RANGE))

        # initialising (trainable) lengthscales from a uniform distribution over a predefined range
        # each dimension has its own lengthscale
        l = nn.Parameter(torch.empty(2, device = device).uniform_( * L_RANGE))

        # For every run initialise a (new) mean model
        dfNN_mean_model = dfNN().to(device)

        # NOTE: We don't need a criterion either

        # AdamW as optimizer for some regularisation/weight decay
        # The mean-model weights and both GP hypers are optimised jointly
        optimizer = optim.AdamW(list(dfNN_mean_model.parameters()) + [sigma_f, l], lr = MODEL_LEARNING_RATE, weight_decay = WEIGHT_DECAY)

        # _________________
        # BEFORE EPOCH LOOP

        # Export the convergence just for first run only
        if run == 0:
            # initialise tensors to store losses over epochs (for convergence plot)
            # NOTE: entries for epochs skipped by early stopping keep their initial value 0
            train_losses_NLML_over_epochs = torch.zeros(MAX_NUM_EPOCHS) # objective
            train_losses_RMSE_over_epochs = torch.zeros(MAX_NUM_EPOCHS) # by-product
            # monitor performance transfer to test (only RMSE easy to calc without covar)
            test_losses_RMSE_over_epochs = torch.zeros(MAX_NUM_EPOCHS)

            sigma_f_over_epochs = torch.zeros(MAX_NUM_EPOCHS)
            l1_over_epochs = torch.zeros(MAX_NUM_EPOCHS)
            l2_over_epochs = torch.zeros(MAX_NUM_EPOCHS)

        # Early stopping variables
        best_loss = float('inf')
        # counter starts at 0
        epochs_no_improve = 0

        ############################
        ### LOOP 3 - over EPOCHS ###
        ############################
        print("\nStart Training")

        for epoch in range(MAX_NUM_EPOCHS):

            # Assure model is in training mode
            dfNN_mean_model.train()

            # For Run 1 we save a bunch of metrics and update, while for the rest we only update
            if run == 0:
                mean_pred_train, _, lml_train = GP_predict(
                        x_train,
                        y_train,
                        x_train, # predict training data
                        [train_noise_diag, sigma_f, l], # list of (current) hypers
                        mean_func = dfNN_mean_model, # dfNN as mean function
                        divergence_free_bool = True) # ensures we use a df kernel

                # Compute test predictions for the loss convergence plot
                mean_pred_test, _, _ = GP_predict(
                        x_train,
                        y_train,
                        x_test.to(device), # predict test data
                        # HACK: This is rather an eval, so we use detached hypers to avoid the computational tree
                        [train_noise_diag, sigma_f.detach().clone(), l.detach().clone()], # list of (current) hypers
                        mean_func = dfNN_mean_model, # dfNN as mean function
                        divergence_free_bool = True) # ensures we use a df kernel

                # UPDATE HYPERS (after test predictions are computed, so both use the same model)
                optimizer.zero_grad() # don't accumulate gradients
                # negative for NLML. loss is always on train
                loss = - lml_train
                loss.backward()
                optimizer.step()

                # Plain Python float for logging and early stopping, so that no autograd
                # graph is kept alive across epochs
                loss_value = loss.item()

                # NOTE: it is important to detach here
                train_RMSE = compute_RMSE(y_train.detach(), mean_pred_train.detach())
                test_RMSE = compute_RMSE(y_test.detach(), mean_pred_test.detach())

                # Save losses for convergence plot
                train_losses_NLML_over_epochs[epoch] = loss_value
                train_losses_RMSE_over_epochs[epoch] = train_RMSE
                # NOTE: lml is always just given training data. There is no TEST NLML
                test_losses_RMSE_over_epochs[epoch] = test_RMSE

                # Save evolution of hypers for convergence plot
                # NOTE: recorded after optimizer.step(), i.e. these are the updated values.
                # .item() keeps the log tensors free of autograd history
                sigma_f_over_epochs[epoch] = sigma_f[0].item()
                l1_over_epochs[epoch] = l[0].item()
                l2_over_epochs[epoch] = l[1].item()

                print(f"{region_name} {model_name} Run {run + 1}/{NUM_RUNS}, Epoch {epoch + 1}/{MAX_NUM_EPOCHS}, Training Loss (NLML): {loss_value:.4f}, (RMSE): {train_RMSE:.4f}")

                # delete after printing and saving
                del mean_pred_train, mean_pred_test, lml_train, loss, train_RMSE, test_RMSE

                # Free up memory every 20 epochs
                # NOTE: two separate statements on purpose. gc.collect() returns the number of
                # collected objects, so chaining with `and` would skip empty_cache() whenever it returns 0
                if epoch % 20 == 0:
                    gc.collect()
                    torch.cuda.empty_cache()

            # For all runs after the first we run a minimal version using only lml_train
            else:

                # NOTE: We can use x_train[0:2] since the predictions don't matter and we only care about lml_train
                _, _, lml_train = GP_predict(
                        x_train,
                        y_train,
                        x_train[0:2], # predictions don't matter and we output lml_train already
                        [train_noise_diag, sigma_f, l], # list of (current) hypers
                        mean_func = dfNN_mean_model, # dfNN as mean function
                        divergence_free_bool = True) # ensures we use a df kernel

                # UPDATE HYPERS
                optimizer.zero_grad() # don't accumulate gradients
                # negative for NLML
                loss = - lml_train
                loss.backward()
                optimizer.step()

                # Plain Python float for logging and early stopping
                loss_value = loss.item()

                # After run 1 we only print lml, nothing else
                print(f"{region_name} {model_name} Run {run + 1}/{NUM_RUNS}, Epoch {epoch + 1}/{MAX_NUM_EPOCHS}, Training Loss (NLML): {loss_value:.4f}")

                del lml_train, loss

                # Free up memory every 20 epochs (see note in the first-run branch)
                if epoch % 20 == 0:
                    gc.collect()
                    torch.cuda.empty_cache()

            # EVERY EPOCH: Early stopping check
            if loss_value < best_loss:
                best_loss = loss_value
                # reset counter if loss improves
                epochs_no_improve = 0
            else:
                epochs_no_improve += 1

            if epochs_no_improve >= PATIENCE:
                print(f"Early stopping triggered after {epoch + 1} epochs.")
                # exit epoch loop
                break

        ##############################
        ### END LOOP 3 over EPOCHS ###
        ##############################

        # for every run...
        #######################################################
        ### EVALUATE after all training for RUN is finished ###
        #######################################################

        # Evaluate the trained model after all epochs are finished or early stopping was triggered
        # NOTE: These are the hypers of the LAST epoch (no checkpoint of the best epoch is kept),
        # detached from the computational graph
        # NOTE(review): dfNN_mean_model is still in train() mode here - confirm dfNN has no
        # dropout/batch-norm layers, otherwise call .eval() before evaluating
        best_sigma_f = sigma_f.detach().clone()
        best_l = l.detach().clone()

        # Need gradients for autograd divergence: clone of x_test with requires_grad enabled
        x_test_grad = x_test.to(device).clone().requires_grad_(True)

        mean_pred_test, covar_pred_test, _ = GP_predict(
            x_train,
            y_train,
            x_test_grad,
            [train_noise_diag, best_sigma_f, best_l], # list of (tuned) hypers
            mean_func = dfNN_mean_model, # dfNN as mean function
            divergence_free_bool = True) # ensures we use a df kernel

        # Compute divergence field
        dfGP_test_div_field = compute_divergence_field(mean_pred_test, x_test_grad)

        # Only save mean_pred, covar_pred and divergence fields for the first run
        if run == 0:

            # (1) Save predictions from first run so we can visualise them later
            torch.save(mean_pred_test, f"{MODEL_REAL_RESULTS_DIR}/{region_name}_{model_name}_test_mean_predictions.pt")
            torch.save(covar_pred_test, f"{MODEL_REAL_RESULTS_DIR}/{region_name}_{model_name}_test_covar_predictions.pt")

            # (2) Save best hyperparameters
            # Stack tensors into a single tensor: [sigma_f, l1, l2]
            best_hypers_tensor = torch.cat([
                best_sigma_f,
                best_l
            ])

            torch.save(best_hypers_tensor, f"{MODEL_REAL_RESULTS_DIR}/{region_name}_{model_name}_best_hypers.pt")

            # (3) Since all epoch training is finished, we can save the losses over epochs
            df_losses = pd.DataFrame({
                'Epoch': list(range(train_losses_NLML_over_epochs.shape[0])), # pythonic indexing
                'Train Loss NLML': train_losses_NLML_over_epochs.tolist(),
                'Train Loss RMSE': train_losses_RMSE_over_epochs.tolist(),
                'Test Loss RMSE': test_losses_RMSE_over_epochs.tolist(),
                'Sigma_f': sigma_f_over_epochs.tolist(),
                'l1': l1_over_epochs.tolist(),
                'l2': l2_over_epochs.tolist()
                })

            df_losses.to_csv(f"{MODEL_REAL_RESULTS_DIR}/{region_name}_{model_name}_losses_over_epochs.csv", index = False, float_format = "%.5f") # reduce to 5 decimals for readability

            # (4) Save divergence field (computed above for all runs)
            torch.save(dfGP_test_div_field, f"{MODEL_REAL_RESULTS_DIR}/{region_name}_{model_name}_test_prediction_divergence_field.pt")

        # Same evaluation on the training inputs (needed for the train metrics below)
        x_train_grad = x_train.to(device).clone().requires_grad_(True)

        mean_pred_train, covar_pred_train, _ = GP_predict(
                     x_train,
                     y_train,
                     x_train_grad,
                     [train_noise_diag, best_sigma_f, best_l], # list of (tuned) hypers
                     mean_func = dfNN_mean_model, # dfNN as mean function
                     divergence_free_bool = True) # ensures we use a df kernel

        dfGP_train_div_field = compute_divergence_field(mean_pred_train, x_train_grad)

        # Divergence: Convert field to metric: mean absolute divergence
        # NOTE: It is important to use the absolute value of the divergence field, since positive and negative deviations are violations and shouldn't cancel each other out
        dfGP_train_div = dfGP_train_div_field.abs().mean().item()
        dfGP_test_div = dfGP_test_div_field.abs().mean().item()

        # Compute metrics (convert tensors to float) for every run's tuned model
        dfGP_train_RMSE = compute_RMSE(y_train, mean_pred_train).item()
        dfGP_train_MAE = compute_MAE(y_train, mean_pred_train).item()
        dfGP_train_NLL = compute_NLL_full(y_train, mean_pred_train, covar_pred_train).item()

        dfGP_test_RMSE = compute_RMSE(y_test, mean_pred_test).item()
        dfGP_test_MAE = compute_MAE(y_test, mean_pred_test).item()
        # TODO: full NLL
        dfGP_test_NLL = compute_NLL_full(y_test, mean_pred_test, covar_pred_test).item()

        # Column order must match the DataFrame columns defined after the run loop
        region_results.append([
            run + 1,
            dfGP_train_RMSE, dfGP_train_MAE, dfGP_train_NLL, dfGP_train_div,
            dfGP_test_RMSE, dfGP_test_MAE, dfGP_test_NLL, dfGP_test_div
        ])

        # clean up
        # NOTE: after this the prediction tensors are no longer available to later cells
        del mean_pred_train, mean_pred_test, covar_pred_train, covar_pred_test
        gc.collect()
        torch.cuda.empty_cache()

    ############################
    ### END LOOP 2 over RUNS ###
    ############################

    # Convert results to a Pandas DataFrame
    results_per_run = pd.DataFrame(
        region_results,
        columns = ["Run",
                   "Train RMSE", "Train MAE", "Train NLL", "Train MAD",
                   "Test RMSE", "Test MAE", "Test NLL", "Test MAD"])

    # Compute mean and standard deviation for each metric
    mean_std_df = results_per_run.iloc[:, 1:].agg(["mean", "std"]) # Exclude "Run" column

    # Add region_name and model_name as columns in the DataFrame _metrics_summary to be able to copy df
    mean_std_df["region name"] = region_name
    mean_std_df["model name"] = model_name

    # Save "_metrics_per_run.csv" to CSV
    path_to_metrics_per_run = os.path.join(MODEL_REAL_RESULTS_DIR, f"{region_name}_{model_name}_metrics_per_run.csv")
    results_per_run.to_csv(path_to_metrics_per_run, index = False, float_format = "%.5f") # reduce to 5 decimals
    print(f"\nResults per run saved to {path_to_metrics_per_run}")

    # Save "_metrics_summary.csv" to CSV
    path_to_metrics_summary = os.path.join(MODEL_REAL_RESULTS_DIR, f"{region_name}_{model_name}_metrics_summary.csv")
    mean_std_df.to_csv(path_to_metrics_summary, float_format = "%.5f") # reduce to 5 decimals
    print(f"\nMean & Std saved to {path_to_metrics_summary}")

###############################
### END LOOP 1 over REGIONS ###
###############################

#############################
### WALL time & GPU model ###
#############################

end_time = time.time()
# compute elapsed time (seconds since start_time was taken)
elapsed_time = end_time - start_time
# convert elapsed time to minutes
elapsed_time_minutes = elapsed_time / 60

# NOTE: `device` is a torch.device (or a plain string if it was overwritten manually).
# Comparing a torch.device directly with the string "cuda" is not a reliable check,
# so we normalise to torch.device and compare its .type instead.
used_device = torch.device(device)

if used_device.type == "cuda":
    # get name of the GPU model that was actually used (current CUDA device if no index is set)
    gpu_name = torch.cuda.get_device_name(used_device)
else:
    gpu_name = "N/A"

print(f"Elapsed wall time: {elapsed_time:.4f} seconds")

# Define full path for the file
wall_time_and_gpu_path = os.path.join(MODEL_REAL_RESULTS_DIR, model_name + "_run_" + "wall_time.txt")

# Save to the correct folder with both seconds and minutes
with open(wall_time_and_gpu_path, "w") as f:
    f.write(f"Elapsed wall time: {elapsed_time:.4f} seconds\n")
    f.write(f"Elapsed wall time: {elapsed_time_minutes:.2f} minutes\n")
    f.write(f"Device used: {device}\n")
    f.write(f"GPU model: {gpu_name}\n")

print(f"Wall time saved to {wall_time_and_gpu_path}.")

Using device: cuda


Training for REGIONC...
=== REGIONC ===
Training inputs shape: torch.Size([462, 2])
Training observations shape: torch.Size([462, 2])
Training inputs dtype: torch.float32

=== REGIONC ===
Test inputs shape: torch.Size([115, 2])
Test observations shape: torch.Size([115, 2])
Test inputs dtype: torch.float32


--- Training Run 1/1 ---

Start Training
regionc dfNGP Run 1/1, Epoch 1/2000, Training Loss (NLML): 64446.4648, (RMSE): 0.3756
regionc dfNGP Run 1/1, Epoch 2/2000, Training Loss (NLML): 63928.1055, (RMSE): 0.3733
regionc dfNGP Run 1/1, Epoch 3/2000, Training Loss (NLML): 63227.0391, (RMSE): 0.3716
regionc dfNGP Run 1/1, Epoch 4/2000, Training Loss (NLML): 62845.0547, (RMSE): 0.3703
regionc dfNGP Run 1/1, Epoch 5/2000, Training Loss (NLML): 62135.0117, (RMSE): 0.3687
regionc dfNGP Run 1/1, Epoch 6/2000, Training Loss (NLML): 61321.2656, (RMSE): 0.3657
regionc dfNGP Run 1/1, Epoch 7/2000, Training Loss (NLML): 60987.1094, (RMSE): 0.3641
regionc dfNGP Run 1/1, Epoc

In [None]:
def compute_NLL_full(y_true, y_mean_pred, y_covar_pred, jitter = 0.5 * 1e-2):
    """Computes Negative Log-Likelihood (NLL) using the full covariance matrix.

    The NLL of a multivariate Gaussian is
        0.5 * [(y - mu)^T Sigma^-1 (y - mu) + log|Sigma| + d * log(2 * pi)]
    with d = 2 * N. Both the quadratic form and the log-determinant are obtained from a
    single Cholesky factorisation of the (jittered) covariance matrix.

    Args:
        y_true (torch.Tensor): True observations of shape (N, 2).
        y_mean_pred (torch.Tensor): Mean predictions of shape (N, 2).
        y_covar_pred (torch.Tensor): Full predicted covariance matrix of shape (N*2, N*2).(BLOCK FORMAT) [u1, u2, u3, ..., v1, v2, v3, ...]
        jitter (float, optional): Small value added to the diagonal for numerical stability. Defaults to 0.5 * 1e-2 - quite high but we need to keep it consistent across all models.

    Returns:
        torch.Tensor(): Negative Log-Likelihood (NLL) scalar. If the jittered covariance matrix
            is not positive definite, a warning is printed and +inf is returned.
    """
    # Extract number of points
    N = y_true.shape[0]

    # STEP 1: Flatten y_true and y_mean_pred to match the covariance layout (BLOCK structure)
    y_true_flat = torch.concat([y_true[:, 0], y_true[:, 1]], dim = 0).unsqueeze(-1)  # Shape: (2 * N, 1)
    y_mean_pred_flat = torch.concat([y_mean_pred[:, 0], y_mean_pred[:, 1]], dim = 0).unsqueeze(-1)  # Shape: (2 * N, 1)

    # NOTE: order is (true - pred) to match the Mahalanobis distance formula
    diff = y_true_flat - y_mean_pred_flat   # Shape: (2 * N, 1)

    # STEP 2: Stabilise covariance matrix with fixed jitter
    # NOTE: as this is our key metric the same jitter is added to every diagonal element.
    # The identity is created in the dtype of the covariance so the jitter is not rounded differently
    identity = torch.eye(y_covar_pred.shape[0], device = y_covar_pred.device, dtype = y_covar_pred.dtype)
    y_covar_pred_stable = y_covar_pred + identity * jitter

    # Cholesky factor L with L L^T = Sigma. cholesky_ex reports failure through `info`
    # instead of raising, so an unstable covariance can be turned into an inf NLL below
    chol, info = torch.linalg.cholesky_ex(y_covar_pred_stable) # chol shape: (2 * N, 2 * N)

    if info.item() != 0:
        print("Warning: Covariance matrix is not positive definite. Returning large NLL.")
        return torch.tensor(float("inf"), device = y_true.device)

    # (y - mu)^T Sigma^-1 (y - mu), solved via the Cholesky factor for better stability
    mahalanobis_dist = (torch.cholesky_solve(diff, chol).T @ diff).squeeze() # scalar

    # STEP 3: log|Sigma| = 2 * sum(log(diag(L))); avoids under/overflow of the determinant
    log_det_Sigma = 2.0 * torch.log(torch.diagonal(chol)).sum()

    # STEP 4: Compute normalisation term
    d = N * 2  # Dimensionality (since we have two outputs per point)
    normalisation_term = d * torch.log(torch.tensor(2 * torch.pi, device = y_true.device))

    # Step 5: Combine 3 scalar terms into negative log-likelihood (NLL)
    log_likelihood = - 0.5 * (mahalanobis_dist + log_det_Sigma + normalisation_term)

    return - log_likelihood  # Negative log-likelihood

In [None]:
def compute_NLL_sparse(y_true, y_mean_pred, y_covar_pred, jitter = 0.0):
    """ Computes a sparse version of the Negative Log-Likelihood (NLL) for a 2D Gaussian distribution. This sparse version neglects cross-covariance terms between different points and is more efficient for large datasets.

    NLL: The NLL quantifies how well the predicted Gaussian distribution fits the observed data.
    Sparse format: each of the N points has its own 2×2 covariance matrix. (This is more than just the diagonal of the covariance matrix, but not the full covar.)

    Args:
        y_true (torch.Tensor): True observations of shape (N, 2).
        y_mean_pred (torch.Tensor): Mean predictions of shape (N, 2).
        y_covar_pred (torch.Tensor): Full predicted covariance matrix of shape (N * 2, N * 2).(BLOCK FORMAT) [u1, u2, u3, ..., v1, v2, v3, ...]
            If N = 400, then y_covar_pred is torch.Size([800, 800]) so 640000 elements, while N x 2 x 2 = only 1600 elements are used here.
        jitter (float, optional): Value added to both variances (the diagonal of every 2×2 block) for numerical stability. Defaults to 0.0, i.e. no jitter. NOTE: compute_NLL_full adds 0.5 * 1e-2 by default; pass the same value here when the two metrics should be comparable.

    Returns:
        torch.Tensor(): Negative Log-Likelihood (NLL) scalar. If any per-point 2×2 covariance
            is not positive definite, a warning is printed and +inf is returned.
    """
    # Extract number of points
    N = y_true.shape[0]

    # Step 1: Sparsify the covariance matrix
    # Change format of y_covar_pred from (N x 2, N x 2) to (N, 2, 2) so N (2, 2) matrices.
    # NOTE: This is a sparse version of the covariance matrix, neglecting cross-covariance terms.

    # diagonal of upper left quadrant: variance of the first output (y1) at each point
    var_y1_y1 = torch.diag(y_covar_pred[:N, :N]) + jitter
    # diagonal of lower right quadrant: variance of the second output (y2) at each point
    var_y2_y2 = torch.diag(y_covar_pred[N:, N:]) + jitter

    # diagonal of upper right quadrant: how much do y1 and y2 covary at each point
    covar_y1_y2 = torch.diag(y_covar_pred[:N, N:])
    # diagonal of lower left quadrant (equal to the above for a symmetric covariance)
    covar_y2_y1 = torch.diag(y_covar_pred[N:, :N])

    # Per-point matrices [[var_y1_y1, covar_y2_y1], [covar_y1_y2, var_y2_y2]]
    row1 = torch.stack([var_y1_y1, covar_y2_y1], dim = -1)
    row2 = torch.stack([covar_y1_y2, var_y2_y2], dim = -1)
    covar_N22 = torch.stack([row1, row2], dim = -2) # shape: torch.Size([N, 2, 2])

    # STEP 2: Log determinant and validity check (before solving, so singular blocks are caught here)
    sign, log_absdet = torch.linalg.slogdet(covar_N22)

    # A symmetric 2×2 matrix is positive definite iff its determinant AND its leading
    # entry are positive; the determinant sign alone would accept two negative variances
    is_positive_definite = (sign > 0) & (var_y1_y1 > 0)
    if not torch.all(is_positive_definite):
        print("Warning: Non-positive definite matrix encountered.")
        return torch.tensor(float("inf"), device = covar_N22.device)

    # sum of the N per-point log determinants
    log_det_Sigma = log_absdet.sum()

    # STEP 3: Compute Mahalanobis distance
    # NOTE: order is (true - pred) to match the Mahalanobis distance formula
    diff = (y_true - y_mean_pred).unsqueeze(-1)  # shape: (N, 2, 1)

    # Σ⁻¹ @ diff for every point, via a batched solve instead of an explicit inverse
    maha_component = torch.linalg.solve(covar_N22, diff)  # shape: (N, 2, 1)

    # diff^T @ Σ⁻¹ @ diff per point, summed over all N points into a single value
    mahalanobis_distances = (diff * maha_component).sum()

    # STEP 4: Compute normalisation term
    d = N * 2  # Dimensionality (since we have two outputs per point)
    normalisation_term = d * torch.log(torch.tensor(2 * torch.pi, device = y_true.device))

    # Step 5: Combine 3 scalars into negative log-likelihood (NLL)
    log_likelihood =  - 0.5 * (mahalanobis_distances + log_det_Sigma + normalisation_term)

    # return the negative log-likelihood
    return - log_likelihood

In [None]:
# Compare full vs. sparse NLL on the TRAINING data.
# NOTE: the training cell deletes its prediction tensors at the end of every run, so they are
# recomputed here from the objects that survive it (data, tuned hypers, trained mean model).
# x_train already has requires_grad enabled by the training cell.
mean_pred_train, covar_pred_train, _ = GP_predict(
    x_train,
    y_train,
    x_train, # predict training data
    [train_noise_diag, best_sigma_f, best_l], # tuned hypers of the last run
    mean_func = dfNN_mean_model, # dfNN as mean function
    divergence_free_bool = True) # ensures we use a df kernel

print(compute_NLL_full(y_train, mean_pred_train, covar_pred_train).item())
print(compute_NLL_sparse(y_train, mean_pred_train, covar_pred_train).item())

In [None]:
# Compare full vs. sparse NLL on the TEST data.
# NOTE: the training cell deletes its prediction tensors at the end of every run, so they are
# recomputed here from the objects that survive it (data, tuned hypers, trained mean model).
# x_test already has requires_grad enabled by the training cell.
mean_pred_test, covar_pred_test, _ = GP_predict(
    x_train,
    y_train,
    x_test, # predict test data
    [train_noise_diag, best_sigma_f, best_l], # tuned hypers of the last run
    mean_func = dfNN_mean_model, # dfNN as mean function
    divergence_free_bool = True) # ensures we use a df kernel

print(compute_NLL_full(y_test, mean_pred_test, covar_pred_test).item())
print(compute_NLL_sparse(y_test, mean_pred_test, covar_pred_test).item())