In [2]:
# Standard libraries
import sys
# Add your custom path
gems_tco_path = "/Users/joonwonlee/Documents/GEMS_TCO-1/src"
sys.path.append(gems_tco_path)

# Data manipulation and analysis
import pandas as pd
import numpy as np

from GEMS_TCO import kernels
from GEMS_TCO import data_preprocess 
from GEMS_TCO import kernels_new, kernels_reparametrization_space as kernels_repar_space
from GEMS_TCO import orderings as _orderings 
from GEMS_TCO import load_data
from GEMS_TCO import alg_optimization, alg_opt_Encoder
from GEMS_TCO import configuration as config

from typing import Optional, List, Tuple
from pathlib import Path
from json import JSONEncoder

from GEMS_TCO.data_loader import load_data2
import torch
import torch.optim as optim
import time

Load monthly data

In [3]:
space: List[str] = ['2', '1']
lat_lon_resolution = [int(s) for s in space]
mm_cond_number: int = 20
years = ['2024']
month_range = [7] 

output_path = input_path = Path(config.mac_estimates_day_path)
data_load_instance = load_data2(config.mac_data_load_path)


df_map, ord_mm, nns_map = data_load_instance.load_maxmin_ordered_data_bymonthyear(
lat_lon_resolution=lat_lon_resolution, 
mm_cond_number=mm_cond_number,
years_=years, 
months_=month_range,
lat_range=[0.0, 5.0],      
lon_range=[123.0, 133.0] 
)

#days: List[str] = ['0', '31']
#days_s_e = [int(d) for d in days]
#days_list = list(range(days_s_e[0], days_s_e[1]))

Subsetting data to lat: [0.0, 5.0], lon: [123.0, 133.0]


Load daily data applying max-min ordering

In [4]:
daily_aggregated_tensors = [] 
daily_hourly_maps = []        

analysis_hour =2
for day_index in range(31):
  
    hour_start_index = day_index * 8
    #hour_end_index = (day_index + 1) * 8
    #hour_end_index = day_index*8 + 1
    hour_end_index = day_index*8 + analysis_hour

    hour_indices = [hour_start_index, hour_end_index]
    
    # Load the data for the current day
    day_hourly_map, day_aggregated_tensor = data_load_instance.load_working_data(
        df_map, 
        hour_indices, 
        ord_mm=None,  
        dtype=torch.float 
    )
    # Append the day's data to their respective lists
    daily_aggregated_tensors.append(day_aggregated_tensor)
    daily_hourly_maps.append(day_hourly_map) 

print(daily_aggregated_tensors[0].shape)
#print(daily_hourly_maps[0])

torch.Size([18126, 4])


Hyper parameters

In [5]:
v = 0.5 # smooth
mm_cond_number = 8
nheads = 300
#nheads = 1230
#lr = 0.01
#step = 80
#gamma_par = 0.5

# --- Placeholder Global Variables ---
# ðŸ’¥ REVISED: Added lr, patience, factor. Removed step, gamma_par
lr=0.1
patience = 5       # Scheduler: Epochs to wait for improvement
factor = 0.5         # Scheduler: Factor to reduce LR by (e.g., 0.5 = 50% cut)
epochs=150

In [7]:
import torch
from typing import Callable

# =========================================================================
# 1. Distance and Covariance Functions (Optimized)
# =========================================================================

def precompute_coords_anisotropy(params: torch.Tensor, y: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """
    Pre-computes the anisotropic coordinates and squared distance
    using broadcasting to be more memory-efficient.
    """
    sigmasq, range_lat, range_lon, advec_lat, advec_lon, beta, nugget = params

    if y is None or x is None:
        raise ValueError("Both y and x must be provided.")
        
    # x is (N, 4), y is (M, 4)
    x1, y1, t1 = x[:, 0], x[:, 1], x[:, 3]
    x2, y2, t2 = y[:, 0], y[:, 1], y[:, 3]
    
    # Pre-calculate advected coordinates
    # x_adv = (x - v_lat*t)
    # y_adv = (y - v_lon*t)
    x1_adv = x1 - advec_lat * t1
    y1_adv = y1 - advec_lon * t1
    
    x2_adv = x2 - advec_lat * t2
    y2_adv = y2 - advec_lon * t2

    # Use broadcasting to compute pairwise differences (N, M)
    # (N, 1) - (1, M) -> (N, M)
    delta_x_adv = x1_adv.unsqueeze(1) - x2_adv.unsqueeze(0)
    delta_y_adv = y1_adv.unsqueeze(1) - y2_adv.unsqueeze(0)
    delta_t     = t1.unsqueeze(1)     - t2.unsqueeze(0)

    # Calculate squared distance terms
    # d^2 = (delta_x_adv / r_lat)^2 + (delta_y_adv / r_lon)^2 + (delta_t * beta)^2
    term1_sq = (delta_x_adv / range_lat).pow(2)
    term2_sq = (delta_y_adv / range_lon).pow(2)
    term3_sq = (delta_t * beta).pow(2)
    
    distance_sq = term1_sq + term2_sq + term3_sq
    return distance_sq

def matern_cov_anisotropy_v05(params: torch.Tensor, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """
    Calculates the Matern 0.5 (Exponential) covariance.
    (Optimized to remove boolean masking)
    """
    sigmasq, range_lat, range_lon, advec_lat, advec_lon, beta, nugget = params
    
    # Call the optimized precompute function
    distance_sq = precompute_coords_anisotropy(params, x, y)

    # Add a small epsilon for numerical stability at d=0
    # This prevents NaN gradients when distance is exactly zero
    epsilon = 1e-12
    sqrt_distance = torch.sqrt(distance_sq + epsilon)
    
    # Compute covariance directly.
    # torch.exp(-0.0) is 1.0, which correctly handles the zero-lag case.
    out = sigmasq * torch.exp(-sqrt_distance)

    # Add nugget to the diagonal
    if x.shape[0] == y.shape[0]:
        out.diagonal().add_(nugget) 
    
    return out

# =========================================================================
# 2. Full Likelihood Function (Corrected)
# =========================================================================

def full_likelihood(params: torch.Tensor, input_data: torch.Tensor, y: torch.Tensor, covariance_function: Callable) -> torch.Tensor:
    """
    Optimized likelihood function using Cholesky decomposition.
    (Corrected to include an intercept in the mean trend)
    """
    input_data = input_data.to(torch.float64)
    y = y.to(torch.float64)
            
    cov_matrix = covariance_function(params=params, y=input_data, x=input_data)
    
    try:
        # Add a small jitter for stability
        jitter = torch.eye(cov_matrix.shape[0], device=cov_matrix.device, dtype=torch.float64) * 1e-6
        L = torch.linalg.cholesky(cov_matrix + jitter)
    except torch.linalg.LinAlgError:
        print("Warning: Cholesky decomposition failed. Matrix may not be positive definite.")
        return torch.tensor(torch.inf, device=params.device, dtype=params.dtype)

    log_det = 2 * torch.sum(torch.log(torch.diag(L)))
    
    # --- ðŸ’¥ START FIX: Add intercept to the trend model ðŸ’¥ ---
    locs_original = input_data[:, :2].to(torch.float64) # [lat, lon]
    intercept = torch.ones(locs_original.shape[0], 1, 
                           device=locs_original.device, 
                           dtype=torch.float64)
    # X matrix is now [1, lat, lon]
    locs = torch.cat((intercept, locs_original), dim=1) 
    # --- END FIX ---
    
    if y.dim() == 1:
        y_col = y.unsqueeze(-1).to(torch.float64)
    else:
        y_col = y.to(torch.float64)

    # Solve for C_inv_X and C_inv_y
    C_inv_X = torch.cholesky_solve(locs, L, upper=False)
    C_inv_y = torch.cholesky_solve(y_col, L, upper=False)

    # Compute beta
    tmp1 = torch.matmul(locs.T, C_inv_X) # (3, N) @ (N, 3) = (3, 3)
    tmp2 = torch.matmul(locs.T, C_inv_y) # (3, N) @ (N, 1) = (3, 1)
    
    try:
        # Add jitter to the small system as well
        jitter_beta = torch.eye(tmp1.shape[0], device=tmp1.device, dtype=torch.float64) * 1e-8
        beta = torch.linalg.solve(tmp1 + jitter_beta, tmp2) # Solves (3,3) system
    except torch.linalg.LinAlgError:
        print("Warning: Could not solve for beta. X^T C_inv X may be singular.")
        return torch.tensor(torch.inf, device=locs.device, dtype=locs.dtype)

    # Compute the mean
    mu = torch.matmul(locs, beta) # (N, 3) @ (3, 1) = (N, 1)
    y_mu = y_col - mu

    # Compute the quadratic form
    C_inv_y_mu = torch.cholesky_solve(y_mu, L, upper=False)
    quad_form = torch.matmul(y_mu.T, C_inv_y_mu)

    # Compute the negative log likelihood
    neg_log_lik = 0.5 * (log_det + quad_form.squeeze())
    return neg_log_lik

dw adams

In [8]:

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# 1. Create dummy parameters (sigmasq, range_lat, range_lon, advec_lat, advec_lon, beta, nugget)
params_tensor = torch.tensor([
    11.7989,  # sigmasq
    0.1104,   # range_lat
    0.1643,   # range_lon
    0.0223,   # advec_lat
    -0.1672,  # advec_lon
    0.1864,   # beta
    0.0000   # nugget
], dtype=torch.float64, device=device)


# 4. Call the standalone likelihood function
loss = full_likelihood(
    params=params_tensor,
    input_data= day_aggregated_tensor,
    y=day_aggregated_tensor[:, 2],
    covariance_function=matern_cov_anisotropy_v05
)

print(f"Full Negative Log-Likelihood: {loss.item()}")

Full Negative Log-Likelihood: 24807.323750727002


vecchia adams

In [9]:

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# 1. Create dummy parameters (sigmasq, range_lat, range_lon, advec_lat, advec_lon, beta, nugget)
params_tensor = torch.tensor([
  12.826226,  # sigma_sq
  0.353326,   # range_lat
  0.442368,   # range_lon
  0.035456,   # advec_lat
  -0.214677,  # advec_lon
  0.180337,   # beta
  2.856746   # nugget
], dtype=torch.float64, device=device)

# 4. Call the standalone likelihood function
loss = full_likelihood(
    params=params_tensor,
    input_data= day_aggregated_tensor,
    y=day_aggregated_tensor[:, 2],
    covariance_function=matern_cov_anisotropy_v05
)

print(f"Full Negative Log-Likelihood: {loss.item()}")

Full Negative Log-Likelihood: 23293.543004957515


vecchia l bfgs

In [10]:

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# 1. Create dummy parameters (sigmasq, range_lat, range_lon, advec_lat, advec_lon, beta, nugget)
params_tensor = torch.tensor([
    12.635289,  # sigmasq
    0.562862,   # range_lon
    0.412029,   # range_lat
    0.204350,   # beta
    0.038643,   # advec_lat
    -0.214493,  # advec_lon
    3.488041    # nugget
], dtype=torch.float64, device=device)

# 4. Call the standalone likelihood function
loss = full_likelihood(
    params=params_tensor,
    input_data= day_aggregated_tensor,
    y=day_aggregated_tensor[:, 2],
    covariance_function=matern_cov_anisotropy_v05
)

print(f"Full Negative Log-Likelihood: {loss.item()}")

Full Negative Log-Likelihood: 23797.00324943286
