In [42]:
import sys
import os
gems_tco_path = "/Users/joonwonlee/Documents/GEMS_TCO-1/src"
sys.path.append(gems_tco_path)

# Data manipulation and analysis
import pandas as pd
import numpy as np
import pickle 

import GEMS_TCO
from GEMS_TCO import kernels 
from GEMS_TCO import orderings as _orderings
from GEMS_TCO import load_data_local_computer

import torch
from collections import defaultdict

import torch
from torch.func import grad, hessian, jacfwd, jacrev
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
import copy                    # clone tensor

# Summary

Two options: 1. torch.autograd 2. torch.func (recommended for both gradients and hessians)

Observations:
- In order to track gradients, ```sqrt()``` in distance function has to be removed and put ```sqrt(distance function output)``` in covariance function.   
- If dtypes don't match, both autograd and torch.func cannot track hessians, so consider ```.to(torch.float64)``` so ``` aggregated_data[:,:4].torch.float64()```   

- For hessians, torch.func is recommended. ``` torch.autograd.functional.hessian(compute_loss, params)``` this doesn't work.   
- It seems there is nontrivial difference between float32 and float64 settings. 

In [None]:
# CODE EXAMPLE

# Define the function to compute the loss
def compute_loss(params):
    return instance.full_likelihood(params, aggregated_data[:, :4].to(torch.float64), aggregated_data[:, 2].to(torch.float64), instance.matern_cov_ani)
    # return instance.vecchia_interpolation_1to6(params, instance.matern_cov_ani, 35)

# Convert parameters to a tensor with requires_grad=True
params = torch.tensor(df.iloc[0, :-1].values, dtype=torch.float64, requires_grad=True)
print(params)

# Compute the first derivative using torch.func.grad
grad_f = grad(compute_loss)
g1 = grad_f(params)
print(f'Gradient: {g1}')

# Compute the Hessian matrix using torch.func.hessian
try:
    hessian_matrix = hessian(compute_loss, params)      ## this is equivalent to jacfwd(jacrev(compute_loss))(params)
    print(hessian_matrix)
except Exception as e:
    print(f'Error computing Hessian: {e}')

tensor([25.8229,  1.0230,  1.1314,  0.0733, -0.0958,  0.1777,  1.5697],
       dtype=torch.float64, requires_grad=True)
Gradient: tensor([  -2.3781,  -20.0865,   -9.3401,  -61.2322,   37.6362, -191.4623,
         -14.3487], dtype=torch.float64, grad_fn=<AddBackward0>)
<function compute_loss at 0x3cf3e5260>


# LOAD DATA 

In [44]:
lat_lon_resolution = [10,10]
day = 1
mm_cond_number = 20

years = ['2024']
month_range =[7,8]
idx_for_datamap= [ 8*(day-1),8*day]

instance = load_data_local_computer()
map, ord_mm, nns_map= instance.load_mm20k_data_bymonthyear( lat_lon_resolution= lat_lon_resolution, mm_cond_number=mm_cond_number,years_=years, months_=month_range)
analysis_data_map, aggregated_data = instance.load_working_data_byday( map, ord_mm, nns_map, idx_for_datamap=[0,8])

input_path = "/Users/joonwonlee/Documents/GEMS_TCO-1/Exercises/st_model/estimates"
output_filename = 'vecchia_inter_estimates_1250_july24.csv'
output_csv_path = os.path.join(input_path, output_filename)

df = pd.read_csv(output_csv_path)
df.head()

Unnamed: 0,sigmasq,range_lat,range_lon,advec_lat,advec_lon,beta,nugget,loss
0,25.822897,1.023014,1.131423,0.073286,-0.09581,0.17767,1.5697,50396.417969
1,24.242847,2.147953,1.945775,0.050045,-0.070032,0.140178,1.442902,43290.984375
2,26.15612,0.818184,1.061921,0.081691,-0.119475,0.24149,2.123204,54808.964844
3,26.00123,1.1942,1.434477,-0.206528,-0.174575,0.041952,2.593778,52697.542969
4,23.72466,1.819339,2.713623,-0.166419,-0.037361,-0.014584,1.945128,43821.054688


In [45]:
class matern_advec_beta_torch_vecchia:
    def __init__(self, analaysis_data_map: torch.Tensor, params: torch.Tensor, nns_map=nns_map, mm_cond_number=mm_cond_number):
        
        self.key_list = sorted(analysis_data_map)
        self.input_map = analysis_data_map

        self.mm_cond_number = mm_cond_number
        self.nns_map = nns_map 
        self.input_map = analaysis_data_map
        self.smooth = 0.5  
        sample_df = analaysis_data_map[self.key_list[0]]

        self.size_per_hour = len(sample_df)

    def custom_distance_matrix(self, U, V):
        # Efficient distance computation with broadcasting
        spatial_diff = torch.norm(U[:, :2].unsqueeze(1) - V[:, :2].unsqueeze(0), dim=2)

        temporal_diff = torch.abs(U[:, 2].unsqueeze(1) - V[:, 2].unsqueeze(0))
        distance = (spatial_diff**2 + temporal_diff**2)  # move torch.sqrt to covariance function to track gradients of beta and avec
        return distance
    
    def precompute_coords_ani(self, params, y: torch.Tensor, x: torch.Tensor)-> torch.Tensor:
        sigmasq, range_lat, range_lon, advec_lat, advec_lon, beta, nugget = params

        if y is None or x is None:
            raise ValueError("Both y and x_df must be provided.")

        x1 = x[:, 0]
        y1 = x[:, 1]
        t1 = x[:, 3]

        x2 = y[:, 0]
        y2 = y[:, 1]
        t2 = y[:, 3]

        # spat_coord1 = torch.stack((self.x1 , self.y1 - advec * self.t1), dim=-1)
        spat_coord1 = torch.stack(( (x1 - advec_lat * t1)/range_lat, (y1 - advec_lon * t1)/range_lon ), dim=-1)
        spat_coord2 = torch.stack(( (x2 - advec_lat * t2)/range_lat, (y2 - advec_lon * t2)/range_lon ), dim=-1)

        U = torch.cat((spat_coord1, (beta * t1).reshape(-1, 1)), dim=1)
        V = torch.cat((spat_coord2, (beta * t2).reshape(-1, 1)), dim=1)

        distance = self.custom_distance_matrix(U,V)
        non_zero_indices = distance != 0
        return distance, non_zero_indices
    
    # anisotropic in three 
    def matern_cov_ani(self,params: torch.Tensor, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        sigmasq, range_lat, range_lon, advec_lat, advec_lon, beta, nugget = params
        

        distance, non_zero_indices = self.precompute_coords_ani(params, x,y)
        out = torch.zeros_like(distance)

        non_zero_indices = distance != 0
        if torch.any(non_zero_indices):
            out[non_zero_indices] = sigmasq * torch.exp(- torch.sqrt(distance[non_zero_indices]))
        out[~non_zero_indices] = sigmasq

        # Add a small jitter term to the diagonal for numerical stability
        out += torch.eye(out.shape[0]) * nugget 

        return out
    
    def full_likelihood(self,params: torch.Tensor, input_np: torch.Tensor, y: torch.Tensor, covariance_function) -> torch.Tensor:
        input_arr = input_np[:, :4]
        y_arr = y

        # Compute the covariance matrix
        cov_matrix = covariance_function(params=params, y=input_arr, x=input_arr)
        
        # Compute the log determinant of the covariance matrix
        sign, log_det = torch.slogdet(cov_matrix)
        #log_det = torch.log(torch.linalg.det(cov_matrix))
        #if sign <= 0:
        #    raise ValueError("Covariance matrix is not positive definite")
        
        # Extract locations
        locs = input_arr[:, :2]

        # Compute beta
        tmp1 = torch.matmul(locs.T, torch.linalg.solve(cov_matrix, locs))
        tmp2 = torch.matmul(locs.T, torch.linalg.solve(cov_matrix, y_arr))
        beta = torch.linalg.solve(tmp1, tmp2)

        # Compute the mean
        mu = torch.matmul(locs, beta)
        y_mu = y_arr - mu

        # Compute the quadratic form
        quad_form = torch.matmul(y_mu, torch.linalg.solve(cov_matrix, y_mu))

        # Compute the negative log likelihood
        neg_log_lik = 0.5 * (log_det + quad_form)
        # neg_log_lik = 0.5 * ( log_det )
        return  neg_log_lik 




# Likelihood investigation

In [47]:
lat_lon_resolution = [10,10]

head100map = defaultdict(list)

headn = 10
b = [0]*7
b1=b2=b3=b4=b5=b6=0
for day in range(1,2):
    mm_cond_number = 10

    years = ['2024']
    month_range =[7,8]
    idx_for_datamap= [ 8*(day-1),8*day]

    instance = load_data_local_computer()
    map, ord_mm, nns_map= instance.load_mm20k_data_bymonthyear( lat_lon_resolution= lat_lon_resolution, mm_cond_number=mm_cond_number,years_=years, months_=month_range)
    analysis_data_map, aggregated_data = instance.load_working_data_byday( map, ord_mm, nns_map, idx_for_datamap= idx_for_datamap)

    params = [24.42, 1.92, 1.92, 0.001, -0.045, 0.237, 3.34]
    params = torch.tensor(params, requires_grad=True)
    instance = matern_advec_beta_torch_vecchia(analysis_data_map, params, nns_map, mm_cond_number)
    out = instance.full_likelihood(params, aggregated_data[:,:4],aggregated_data[:,2], instance.matern_cov_ani)
    print(out)  # 15105

    print(f'day {day} finished')

tensor(2588.8357, grad_fn=<MulBackward0>)
day 1 finished


# Gradients and hessians

In [56]:
params = [24.42, 1.92, 1.92, 0.001, -0.045, 0.237, 3.34]
params = torch.tensor(params, dtype=torch.float64, requires_grad=True)

# Define the function to compute the loss
def compute_loss(params):
    return instance.full_likelihood(params, aggregated_data[:, :4].to(torch.float64), aggregated_data[:, 2].to(torch.float64), instance.matern_cov_ani)
    # return instance.vecchia_interpolation_1to6(params, instance.matern_cov_ani, 35)
    
# Compute the first derivative using torch.func.grad
grad_f = torch.autograd.grad(compute_loss(params), params)

grad_f

(tensor([  0.5812,  19.9819,  11.6205,   2.1511,  18.5098, 447.8913,   5.8960],
        dtype=torch.float64),)

In [57]:
params = [24.42, 1.92, 1.92, 0.001, -0.045, 0.237, 3.34]
params = torch.tensor(params, dtype=torch.float32, requires_grad=True)

# Convert parameters to a tensor with requires_grad=True
params = torch.tensor(df.iloc[0, :-1].values, dtype=torch.float64, requires_grad=True)
print(params)

# Define the function to compute the loss
def compute_loss(params):
    return instance.full_likelihood(params, aggregated_data[:, :4].to(torch.float64), aggregated_data[:, 2].to(torch.float64), instance.matern_cov_ani)
    # return instance.vecchia_interpolation_1to6(params, instance.matern_cov_ani, 35)



# Compute the first derivative using torch.func.grad
grad_f = grad(compute_loss)
g1 = grad_f(params)
print(f'Gradient: {g1}')

# Compute the Hessian matrix using torch.func.hessian
try:
    hessian_matrix = hessian(compute_loss, params)
    print(hessian_matrix)
except Exception as e:
    print(f'Error computing Hessian: {e}')


print(hessian(compute_loss)(params))



tensor([25.8229,  1.0230,  1.1314,  0.0733, -0.0958,  0.1777,  1.5697],
       dtype=torch.float64, requires_grad=True)
Gradient: tensor([  0.9324, -43.9642, -35.9082,  59.9937, -17.1091, -76.0932,  -0.6668],
       dtype=torch.float64, grad_fn=<AddBackward0>)
<function compute_loss at 0x3cf321800>
tensor([[ 6.8449e-01, -3.0525e+00, -4.1062e+00,  5.7238e+00, -5.5499e+00,
          3.8050e+01,  3.0690e+00],
        [-3.0525e+00,  1.5700e+02, -8.9039e+00,  1.1649e+02,  5.2665e+01,
         -3.1752e+01, -1.3032e+01],
        [-4.1062e+00, -8.9039e+00,  1.5732e+02, -4.8812e+01, -6.4641e+01,
          5.5250e+01, -1.4956e+01],
        [ 5.7238e+00,  1.1649e+02, -4.8812e+01,  1.6052e+03, -9.2467e+02,
          1.2072e+03,  7.3658e+01],
        [-5.5499e+00,  5.2665e+01, -6.4641e+01, -9.2467e+02,  2.2599e+03,
         -1.1357e+03, -7.7113e+01],
        [ 3.8050e+01, -3.1752e+01,  5.5250e+01,  1.2072e+03, -1.1357e+03,
          6.4662e+03,  2.6074e+02],
        [ 3.0690e+00, -1.3032e+01, -1.49