In [2]:
import sys
import os
# gems_tco_path = "/Users/joonwonlee/Documents/GEMS_TCO-1/src"
# sys.path.append(gems_tco_path)

# Data manipulation and analysis
import pandas as pd
import numpy as np
import pickle 

import GEMS_TCO
from GEMS_TCO import kernels 
from GEMS_TCO import orderings as _orderings
from GEMS_TCO import load_data_local_computer

import torch
from collections import defaultdict

import torch
from torch.func import grad, hessian, jacfwd, jacrev
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
import copy                    # clone tensor

# Summary

Two options: 1. torch.autograd 2. torch.func (recommended for both gradients and hessians)

Observations:
- In order to track gradients, ```sqrt()``` in distance function has to be removed and put ```sqrt(distance function output)``` in covariance function.   

- If dtypes don't match, both autograd and torch.func cannot track hessians, so consider ```.to(torch.float64)``` so ``` aggregated_data[:,:4].torch.float64()```   
for the consistency.
Actually, it turns out that if I use ```float32```, then autograd derivative can be different from analytical derivative by ```0.001 ~ 0.004```. 

the difference is on the order of one-thousandth 

- For hessians, torch.func is recommended. ``` torch.autograd.functional.hessian(compute_loss, params)``` this doesn't work.   

- It seems there is nontrivial difference between float32 and float64 settings. 

In [2]:
# CODE EXAMPLE
# Define the function to compute the loss
def compute_loss(params):
    return instance.full_likelihood(params, aggregated_data[:, :4].to(torch.float64), aggregated_data[:, 2].to(torch.float64), instance.matern_cov_ani)
    # return instance.vecchia_interpolation_1to6(params, instance.matern_cov_ani, 35)

# Convert parameters to a tensor with requires_grad=True
params = torch.tensor(df.iloc[0, :-1].values, dtype=torch.float64, requires_grad=True)
print(params)

# Compute the first derivative using torch.func.grad
grad_f = grad(compute_loss)
g1 = grad_f(params)
print(f'Gradient: {g1}')

# Compute the Hessian matrix using torch.func.hessian
try:
    hessian_matrix = hessian(compute_loss, params)      ## this is equivalent to jacfwd(jacrev(compute_loss))(params)
    print(hessian_matrix)
except Exception as e:
    print(f'Error computing Hessian: {e}')

NameError: name 'df' is not defined

# LOAD DATA 

In [4]:
lat_lon_resolution = [20,20]
day = 1
mm_cond_number = 20

years = ['2024']
month_range =[7,8]
idx_for_datamap= [ 8*(day-1),8*day]

instance = load_data_local_computer()
map, ord_mm, nns_map= instance.load_mm20k_data_bymonthyear( lat_lon_resolution= lat_lon_resolution, mm_cond_number=mm_cond_number,years_=years, months_=month_range)
analysis_data_map, aggregated_data = instance.load_working_data_byday( map, ord_mm, nns_map, idx_for_datamap=[0,8])

input_path = "/Users/joonwonlee/Documents/GEMS_TCO-1/Exercises/st_model/estimates"
output_filename = 'vecchia_inter_estimates_1250_july24.csv'
output_csv_path = os.path.join(input_path, output_filename)

df = pd.read_csv(output_csv_path)
df.head()

Unnamed: 0,sigmasq,range_lat,range_lon,advec_lat,advec_lon,beta,nugget,loss
0,25.822897,1.023014,1.131423,0.073286,-0.09581,0.17767,1.5697,50396.417969
1,24.242847,2.147953,1.945775,0.050045,-0.070032,0.140178,1.442902,43290.984375
2,26.15612,0.818184,1.061921,0.081691,-0.119475,0.24149,2.123204,54808.964844
3,26.00123,1.1942,1.434477,-0.206528,-0.174575,0.041952,2.593778,52697.542969
4,23.72466,1.819339,2.713623,-0.166419,-0.037361,-0.014584,1.945128,43821.054688


In [20]:
class matern_advec_beta_torch_vecchia:
    def __init__(self, analaysis_data_map: torch.Tensor, params: torch.Tensor, nns_map=nns_map, mm_cond_number=mm_cond_number):
        
        self.key_list = sorted(analysis_data_map)
        self.input_map = analysis_data_map

        self.mm_cond_number = mm_cond_number
        self.nns_map = nns_map 
        self.input_map = analaysis_data_map
        self.smooth = 0.5  
        sample_df = analaysis_data_map[self.key_list[0]]

        self.size_per_hour = len(sample_df)

    def custom_distance_matrix(self, U, V):
        # Efficient distance computation with broadcasting
        spatial_diff = torch.norm(U[:, :2].unsqueeze(1) - V[:, :2].unsqueeze(0), dim=2)

        temporal_diff = torch.abs(U[:, 2].unsqueeze(1) - V[:, 2].unsqueeze(0))
        distance = (spatial_diff**2 + temporal_diff**2)  # move torch.sqrt to covariance function to track gradients of beta and avec
        return distance
    
    def precompute_coords_ani(self, params, y: torch.Tensor, x: torch.Tensor)-> torch.Tensor:
        sigmasq, range_lat, range_lon, advec_lat, advec_lon, beta, nugget = params

        if y is None or x is None:
            raise ValueError("Both y and x_df must be provided.")

        x1 = x[:, 0]
        y1 = x[:, 1]
        t1 = x[:, 3]

        x2 = y[:, 0]
        y2 = y[:, 1]
        t2 = y[:, 3]

        # spat_coord1 = torch.stack((self.x1 , self.y1 - advec * self.t1), dim=-1)
        spat_coord1 = torch.stack(( (x1 - advec_lat * t1)/range_lat, (y1 - advec_lon * t1)/range_lon ), dim=-1)
        spat_coord2 = torch.stack(( (x2 - advec_lat * t2)/range_lat, (y2 - advec_lon * t2)/range_lon ), dim=-1)

        U = torch.cat((spat_coord1, (beta * t1).reshape(-1, 1)), dim=1)
        V = torch.cat((spat_coord2, (beta * t2).reshape(-1, 1)), dim=1)

        distance = self.custom_distance_matrix(U,V)
        non_zero_indices = distance != 0
        return distance, non_zero_indices
    
    # anisotropic in three 
    def matern_cov_ani(self,params: torch.Tensor, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        sigmasq, range_lat, range_lon, advec_lat, advec_lon, beta, nugget = params
        

        distance, non_zero_indices = self.precompute_coords_ani(params, x,y)
        out = torch.zeros_like(distance)

        non_zero_indices = distance != 0
        if torch.any(non_zero_indices):
            out[non_zero_indices] = sigmasq * torch.exp(- torch.sqrt(distance[non_zero_indices]))
        out[~non_zero_indices] = sigmasq

        # Add a small jitter term to the diagonal for numerical stability
        out += torch.eye(out.shape[0], dtype=torch.float64) * nugget 
        
        return out
    
    def full_likelihood(self,params: torch.Tensor, input_np: torch.Tensor, y: torch.Tensor, covariance_function) -> torch.Tensor:
        input_arr = input_np[:, :4]
        y_arr = y

        # Compute the covariance matrix
        cov_matrix = covariance_function(params=params, y=input_arr, x=input_arr)
        
        # Compute the log determinant of the covariance matrix
        sign, log_det = torch.slogdet(cov_matrix)
        #log_det = torch.log(torch.linalg.det(cov_matrix))
        #if sign <= 0:
        #    raise ValueError("Covariance matrix is not positive definite")
        
        # Extract locations
        locs = input_arr[:, :2]

        # Compute beta
        tmp1 = torch.matmul(locs.T, torch.linalg.solve(cov_matrix, locs))
        tmp2 = torch.matmul(locs.T, torch.linalg.solve(cov_matrix, y_arr))
        beta = torch.linalg.solve(tmp1, tmp2)

        # Compute the mean
        mu = torch.matmul(locs, beta)
        y_mu = y_arr - mu

        # Compute the quadratic form
        quad_form = torch.matmul(y_mu, torch.linalg.solve(cov_matrix, y_mu))

        # Compute the negative log likelihood
        neg_log_lik = 0.5 * (log_det + quad_form)
        # neg_log_lik = 0.5 * ( log_det )
        return  neg_log_lik 

    def vecchia_like_local_computer(self, params: torch.Tensor, covariance_function) -> torch.Tensor:
        self.cov_map = defaultdict(list)
        neg_log_lik = 0.0
        
        for time_idx in range(len(self.input_map)):
            current_np = self.input_map[self.key_list[time_idx]]

            # Use below when working on local computer to avoid singular matrix
            #cur_heads = current_np[:21, :]
            #neg_log_lik += self.full_likelihood(params, cur_heads, cur_heads[:, 2], covariance_function)

            for index in range(0, self.size_per_hour):

                
                current_row = current_np[index].reshape(1, -1)
                current_y = current_row[0, 2]

                # Construct conditioning set
                mm_neighbors = self.nns_map[index]
                past = list(mm_neighbors)
                data_list = []

                if past:
                    data_list.append(current_np[past])

                if time_idx > 1:
                    cov_matrix = self.cov_map[index]['cov_matrix']
                    tmp_for_beta = self.cov_map[index]['tmp_for_beta']
                    cov_xx_inv = self.cov_map[index]['cov_xx_inv']
                    L_inv = self.cov_map[index]['L_inv']
                    cov_ygivenx = self.cov_map[index]['cov_ygivenx']
                    cond_mean_tmp = self.cov_map[index]['cond_mean_tmp']
                    log_det = self.cov_map[index]['log_det']
                    locs = self.cov_map[index]['locs']
                    
                    last_hour_np = self.input_map[self.key_list[time_idx - 1]]
                    past_conditioning_data = last_hour_np[past + [index], :]
                    data_list.append(past_conditioning_data)

                    if data_list:
                        conditioning_data = torch.vstack(data_list)
                    else:
                        conditioning_data = torch.empty((0, current_row.shape[1]), dtype=torch.float32)

                    np_arr = torch.vstack((current_row, conditioning_data))
                    y_and_neighbors = np_arr[:, 2]

                    cov_yx = cov_matrix[0, 1:]

                    tmp2 = torch.matmul(torch.matmul(L_inv, locs).T, torch.matmul(L_inv, y_and_neighbors))
                    beta = torch.linalg.solve(tmp_for_beta, tmp2)

                    mu = torch.matmul(locs, beta)
                    mu_current = mu[0]
                    mu_neighbors = mu[1:]
                    
                    # Mean and variance of y|x
                    cond_mean = mu_current + torch.matmul(cond_mean_tmp, (y_and_neighbors[1:] - mu_neighbors))
                    alpha = current_y - cond_mean
                    quad_form = alpha**2 * (1 / cov_ygivenx)
                    neg_log_lik += 0.5 * (log_det + quad_form)

                    continue

                if time_idx > 0:
                    last_hour_np = self.input_map[self.key_list[time_idx - 1]]
                    past_conditioning_data = last_hour_np[past + [index], :]
                    data_list.append(past_conditioning_data)

                if data_list:
                    conditioning_data = torch.vstack(data_list)
                else:
                    conditioning_data = torch.empty((0, current_row.shape[1]), dtype=torch.float32)

                np_arr = torch.vstack((current_row, conditioning_data))
                y_and_neighbors = np_arr[:, 2]
                locs = np_arr[:, :2]

                cov_matrix = covariance_function(params=params, y= np_arr, x= np_arr)
                # print(f'Condition number: {torch.linalg.cond(cov_matrix)}')
                L = torch.linalg.cholesky(cov_matrix)
                L11 = L[:1, :1]
                L12 = torch.zeros(L[:1, 1:].shape)
                L21 = L[1:, :1]
                L22 = L[1:, 1:]
                L11_inv = torch.linalg.inv(L11)
                L22_inv = torch.linalg.inv(L22)

                # First block: [L11_inv, L12]
                upper_block = torch.cat((L11_inv, L12), dim=1)  # Concatenate along columns (dim=1)

                # Second block: [-torch.matmul(torch.matmul(L22_inv, L21), L11_inv), L22_inv]
                lower_left = -torch.matmul(torch.matmul(L22_inv, L21), L11_inv)
                lower_block = torch.cat((lower_left, L22_inv), dim=1)  # Concatenate along columns (dim=1)

                # Combine the upper and lower blocks
                L_inv = torch.cat((upper_block, lower_block), dim=0)  # Concatenate along rows (dim=0)

                cov_yx = cov_matrix[0, 1:]

                tmp1 = torch.matmul(L_inv, locs)
                tmp2 = torch.matmul(torch.matmul(L_inv, locs).T, torch.matmul(L_inv, y_and_neighbors))
                tmp_for_beta = torch.matmul(tmp1.T, tmp1)
                beta = torch.linalg.solve(tmp_for_beta, tmp2)

                mu = torch.matmul(locs, beta)
                mu_current = mu[0]
                mu_neighbors = mu[1:]

                # Mean and variance of y|x
                sigma = cov_matrix[0, 0]
                cov_xx = cov_matrix[1:, 1:]
                cov_xx_inv = torch.linalg.inv(cov_xx)

                cov_ygivenx = sigma - torch.matmul(cov_yx, torch.matmul(cov_xx_inv, cov_yx))
                cond_mean_tmp = torch.matmul(cov_yx, cov_xx_inv)
                cond_mean = mu_current + torch.matmul(cond_mean_tmp, (y_and_neighbors[1:] - mu_neighbors))
                
                alpha = current_y - cond_mean
                quad_form = alpha**2 * (1 / cov_ygivenx)
                log_det = torch.log(cov_ygivenx)
                neg_log_lik += 0.5 * (log_det + quad_form)

             
                if time_idx == 1:
                    self.cov_map[index] = {
                        'tmp_for_beta': tmp_for_beta,
                        'cov_xx_inv': cov_xx_inv,
                        'cov_matrix': cov_matrix,
                        'L_inv': L_inv,
                        'cov_ygivenx': cov_ygivenx,
                        'cond_mean_tmp': cond_mean_tmp,
                        'log_det': log_det,
                        'locs': locs
                    }

        return neg_log_lik  

    def vecchia_local_extra_base(self, params: torch.Tensor, covariance_function) -> torch.Tensor:
        self.cov_map = defaultdict(list)
        neg_log_lik = 0.0
        
        for time_idx in range(len(self.input_map)):
            current_np = self.input_map[self.key_list[time_idx]]

            # Use below when working on local computer to avoid singular matrix
            # cur_heads = current_np[:21, :]
            # neg_log_lik += self.full_likelihood(params, cur_heads, cur_heads[:, 2], covariance_function)

            for index in range(0, self.size_per_hour):
                current_row = current_np[index].reshape(1, -1)
                current_y = current_row[0, 2]

                # Construct conditioning set
                mm_neighbors = self.nns_map[index]
                past = list(mm_neighbors) + base_list
                data_list = []

                if past:
                    data_list.append(current_np[past])

                if time_idx > 1:
                    cov_matrix = self.cov_map[index]['cov_matrix']
                    tmp_for_beta = self.cov_map[index]['tmp_for_beta']
                    cov_xx_inv = self.cov_map[index]['cov_xx_inv']
            
                    cov_ygivenx = self.cov_map[index]['cov_ygivenx']
                    cond_mean_tmp = self.cov_map[index]['cond_mean_tmp']
                    log_det = self.cov_map[index]['log_det']
                    locs = self.cov_map[index]['locs']
                    
                    last_hour_np = self.input_map[self.key_list[time_idx - 1]]
                    past_conditioning_data = last_hour_np[past + [index] + base_list, :]
                    data_list.append(past_conditioning_data)

                    if data_list:
                        conditioning_data = torch.vstack(data_list)
                    else:
                        conditioning_data = torch.empty((0, current_row.shape[1]), dtype=torch.float32)

                    np_arr = torch.vstack((current_row, conditioning_data))
                    y_and_neighbors = np_arr[:, 2]

                    cov_yx = cov_matrix[0, 1:]

                    y_arr = y_and_neighbors
                    tmp1 = tmp_for_beta
                    tmp2 = torch.matmul(locs.T, torch.linalg.solve(cov_matrix, y_arr))
                    beta = torch.linalg.solve(tmp1, tmp2)

                    mu = torch.matmul(locs, beta)
                    mu_current = mu[0]
                    mu_neighbors = mu[1:]
                    
                    # Mean and variance of y|x
                    cond_mean = mu_current + torch.matmul(cond_mean_tmp, (y_and_neighbors[1:] - mu_neighbors))
                    alpha = current_y - cond_mean
                    quad_form = alpha**2 * (1 / cov_ygivenx)
                    neg_log_lik += 0.5 * (log_det + quad_form)

                    continue

                if time_idx > 0:
                    last_hour_np = self.input_map[self.key_list[time_idx - 1]]
                    past_conditioning_data = last_hour_np[past + [index]+ base_list, :]
                    data_list.append(past_conditioning_data)

                if data_list:
                    conditioning_data = torch.vstack(data_list)
                else:
                    conditioning_data = torch.empty((0, current_row.shape[1]), dtype=torch.float32)

                np_arr = torch.vstack((current_row, conditioning_data))
                y_and_neighbors = np_arr[:, 2]
                locs = np_arr[:, :2]

                cov_matrix = covariance_function(params=params, y= np_arr, x= np_arr)
                # print(f'Condition number: {torch.linalg.cond(cov_matrix)}')
                cov_yx = cov_matrix[0, 1:]
                        # Compute the log determinant of the covariance matrix
                sign, log_det = torch.slogdet(cov_matrix)
                # if sign <= 0:
                #     raise ValueError("Covariance matrix is not positive definite")
            
                y_arr = y_and_neighbors
                # Compute beta
                tmp1 = torch.matmul(locs.T, torch.linalg.solve(cov_matrix, locs))
                tmp2 = torch.matmul(locs.T, torch.linalg.solve(cov_matrix, y_arr))
                beta = torch.linalg.solve(tmp1, tmp2)

                mu = torch.matmul(locs, beta)
                mu_current = mu[0]
                mu_neighbors = mu[1:]

                # Mean and variance of y|x
                sigma = cov_matrix[0, 0]
                cov_xx = cov_matrix[1:, 1:]
                cov_xx_inv = torch.linalg.inv(cov_xx)

                cov_ygivenx = sigma - torch.matmul(cov_yx, torch.matmul(cov_xx_inv, cov_yx))
                cond_mean_tmp = torch.matmul(cov_yx, cov_xx_inv)
                cond_mean = mu_current + torch.matmul(cond_mean_tmp, (y_and_neighbors[1:] - mu_neighbors))
                
                alpha = current_y - cond_mean
                quad_form = alpha**2 * (1 / cov_ygivenx)
                log_det = torch.log(cov_ygivenx)
                neg_log_lik += 0.5 * (log_det + quad_form)
 
                if time_idx == 1:
                    self.cov_map[index] = {
                        'tmp_for_beta': tmp1,
                        'cov_xx_inv': cov_xx_inv,
                        'cov_matrix': cov_matrix,
               
                        'cov_ygivenx': cov_ygivenx,
                        'cond_mean_tmp': cond_mean_tmp,
                        'log_det': log_det,
                        'locs': locs
                    }
        return neg_log_lik
    

    def vecchia_b2(self, params: torch.Tensor, covariance_function, cut_line=35) -> torch.Tensor:
        self.cov_map = defaultdict(list)
        neg_log_lik = 0.0
        key_list = sorted(analysis_data_map)
        cut_line = cut_line
        heads = analysis_data_map[key_list[0]][:cut_line,:]
        for time_idx in range(1, len(analysis_data_map)):
            tmp = analysis_data_map[key_list[time_idx]][:cut_line,:]
            heads = torch.cat( (heads,tmp), dim=0)

        neg_log_lik += self.full_likelihood(params, heads, heads[:, 2], covariance_function)          
        
        for time_idx in range(0,len(self.input_map)):
            current_np = self.input_map[self.key_list[time_idx]]

            # Use below when working on local computer to avoid singular matrix
            for index in range(cut_line, self.size_per_hour):
                current_row = current_np[index].reshape(1, -1)
                current_y = current_row[0, 2]

                # Construct conditioning set
                mm_neighbors = self.nns_map[index]
                past = list(mm_neighbors) 
                data_list = []

                if past:
                    data_list.append(current_np[past])

                if time_idx > 0:
                    last_hour_np = self.input_map[self.key_list[time_idx - 1]]
                    past_conditioning_data = last_hour_np[past + [index], :]
                    data_list.append(past_conditioning_data)

                if time_idx > 1:
                    last_hour_np = self.input_map[self.key_list[time_idx -2]]
                    # if index==200:
                    #     print(self.input_map[self.key_list[time_idx-6]])
                    past_conditioning_data = last_hour_np[past + [index], :]
                    data_list.append(past_conditioning_data)
                
                if data_list:
                    conditioning_data = torch.vstack(data_list)
                else:
                    conditioning_data = torch.empty((0, current_row.shape[1]), dtype=torch.float32)

                np_arr = torch.vstack((current_row, conditioning_data))
                y_and_neighbors = np_arr[:, 2]
                locs = np_arr[:, :2]

                cov_matrix = covariance_function(params=params, y= np_arr, x= np_arr)
                # print(f'Condition number: {torch.linalg.cond(cov_matrix)}')
                cov_yx = cov_matrix[0, 1:]
                        # Compute the log determinant of the covariance matrix
                sign, log_det = torch.slogdet(cov_matrix)
                # if sign <= 0:
                #     raise ValueError("Covariance matrix is not positive definite")
            
                y_arr = y_and_neighbors
                # Compute beta
                tmp1 = torch.matmul(locs.T, torch.linalg.solve(cov_matrix, locs))
                tmp2 = torch.matmul(locs.T, torch.linalg.solve(cov_matrix, y_arr))
                beta = torch.linalg.solve(tmp1, tmp2)

                mu = torch.matmul(locs, beta)
                mu_current = mu[0]
                mu_neighbors = mu[1:]

                # Mean and variance of y|x
                sigma = cov_matrix[0, 0]
                cov_xx = cov_matrix[1:, 1:]
                cov_xx_inv = torch.linalg.inv(cov_xx)

                cov_ygivenx = sigma - torch.matmul(cov_yx, torch.matmul(cov_xx_inv, cov_yx))
                cond_mean_tmp = torch.matmul(cov_yx, cov_xx_inv)
                cond_mean = mu_current + torch.matmul(cond_mean_tmp, (y_and_neighbors[1:] - mu_neighbors))
                
                alpha = current_y - cond_mean
                quad_form = alpha**2 * (1 / cov_ygivenx)
                log_det = torch.log(cov_ygivenx)
                neg_log_lik += 0.5 * (log_det + quad_form)
 
        return neg_log_lik
    

    def vecchia_interpolation_1to6(self, params: torch.Tensor, covariance_function, cut_line=200) -> torch.Tensor:
        self.cov_map = defaultdict(list)
        neg_log_lik = 0.0
        key_list = sorted(analysis_data_map)
        cut_line = cut_line
        heads = analysis_data_map[key_list[0]][:cut_line,:]
        for time_idx in range(1, len(analysis_data_map)):
            tmp = analysis_data_map[key_list[time_idx]][:cut_line,:]
            heads = torch.cat( (heads,tmp), dim=0)

        neg_log_lik += self.full_likelihood(params, heads, heads[:, 2], covariance_function)          
        
        for time_idx in range(0,len(self.input_map)):
            current_np = self.input_map[self.key_list[time_idx]]

            # Use below when working on local computer to avoid singular matrix
            for index in range(cut_line, self.size_per_hour):
                current_row = current_np[index].reshape(1, -1)
                current_y = current_row[0, 2]

                # Construct conditioning set
                mm_neighbors = self.nns_map[index]
                past = list(mm_neighbors) 
                data_list = []

                if past:
                    data_list.append(current_np[past])

                if time_idx > 0 and time_idx<7:
                    last_hour_np = self.input_map[self.key_list[time_idx - 1]]
                    past_conditioning_data = last_hour_np[past + [index], :]
                    data_list.append(past_conditioning_data)

                    last_hour_np = self.input_map[self.key_list[time_idx +1]]
                    # if index==200:
                    #     print(self.input_map[self.key_list[time_idx-6]])
                    past_conditioning_data = last_hour_np[past + [index], :]
                    data_list.append(past_conditioning_data)
                
                if data_list:
                    conditioning_data = torch.vstack(data_list)
                else:
                    conditioning_data = torch.empty((0, current_row.shape[1]), dtype=torch.float32)

                np_arr = torch.vstack((current_row, conditioning_data))
                y_and_neighbors = np_arr[:, 2]
                locs = np_arr[:, :2]

                cov_matrix = covariance_function(params=params, y= np_arr, x= np_arr)
                # print(f'Condition number: {torch.linalg.cond(cov_matrix)}')
                cov_yx = cov_matrix[0, 1:]
                        # Compute the log determinant of the covariance matrix
                sign, log_det = torch.slogdet(cov_matrix)
                # if sign <= 0:
                #     raise ValueError("Covariance matrix is not positive definite")
            
                y_arr = y_and_neighbors
                # Compute beta
                tmp1 = torch.matmul(locs.T, torch.linalg.solve(cov_matrix, locs))
                tmp2 = torch.matmul(locs.T, torch.linalg.solve(cov_matrix, y_arr))
                beta = torch.linalg.solve(tmp1, tmp2)

                mu = torch.matmul(locs, beta)
                mu_current = mu[0]
                mu_neighbors = mu[1:]

                # Mean and variance of y|x
                sigma = cov_matrix[0, 0]
                cov_xx = cov_matrix[1:, 1:]
                cov_xx_inv = torch.linalg.inv(cov_xx)

                cov_ygivenx = sigma - torch.matmul(cov_yx, torch.matmul(cov_xx_inv, cov_yx))
                cond_mean_tmp = torch.matmul(cov_yx, cov_xx_inv)
                cond_mean = mu_current + torch.matmul(cond_mean_tmp, (y_and_neighbors[1:] - mu_neighbors))
                
                alpha = current_y - cond_mean
                quad_form = alpha**2 * (1 / cov_ygivenx)
                log_det = torch.log(cov_ygivenx)
                neg_log_lik += 0.5 * (log_det + quad_form)
 
        return neg_log_lik


# Likelihood investigation

In [6]:
lat_lon_resolution = [10,10]

head100map = defaultdict(list)

headn = 10
b = [0]*7
b1=b2=b3=b4=b5=b6=0
for day in range(1,2):
    mm_cond_number = 10

    years = ['2024']
    month_range =[7,8]
    idx_for_datamap= [ 8*(day-1),8*day]

    instance = load_data_local_computer()
    map, ord_mm, nns_map= instance.load_mm20k_data_bymonthyear( lat_lon_resolution= lat_lon_resolution, mm_cond_number=mm_cond_number,years_=years, months_=month_range)
    analysis_data_map, aggregated_data = instance.load_working_data_byday( map, ord_mm, nns_map, idx_for_datamap= idx_for_datamap)

    params = [ 24.793444,	1.5845289,	1.7182478,	0.009088504,	-0.10729945,	0.13103764,	2.7172387]  #1250
    params = [ 27.25, 2.18, 2.294, 4.099e-4, -0.07915, 0.0999, 3.65]   #200
    # params = [24.42, 1.92, 1.92, 0.001, -0.045, 0.237, 3.34]
    params = torch.tensor(params, requires_grad=True)
    instance = matern_advec_beta_torch_vecchia(analysis_data_map, params, nns_map, mm_cond_number)
    
    out = instance.full_likelihood(params, aggregated_data[:,:4].to(torch.float64),aggregated_data[:,2].to(torch.float64), instance.matern_cov_ani)
    print(out)  # 15105


   
    print(f'day {day} finished')

tensor(2547.2583, dtype=torch.float64, grad_fn=<MulBackward0>)
day 1 finished


# Gradients and hessians

In [44]:
# Convert parameters to a tensor with requires_grad=True
params = torch.tensor(df.iloc[0, :-1].values, dtype=torch.float64, requires_grad=True)
print(f'input parameters: {params}')

# Define the function to compute the loss
def compute_loss(params):
    return instance.full_likelihood(params, aggregated_data[:, :4].to(torch.float64), aggregated_data[:, 2].to(torch.float64), instance.matern_cov_ani)
    # return instance.vecchia_interpolation_1to6(params, instance.matern_cov_ani, 35)
    
# Compute the first derivative using torch.func.grad
grad_f = torch.autograd.grad(compute_loss(params), params)

print(f' the gradient: {grad_f}')

grad_function = torch.func.grad(compute_loss)
gradient = grad_function(params)
print(f' the gradient: {gradient}')

#[  0.9324, -43.9642, -35.9082,  59.9937, -17.1091, -76.0932,  -0.6668]
torch.autograd.gradcheck(compute_loss, params, atol=1e-9, rtol=1e-6)


input parameters: tensor([25.8229,  1.0230,  1.1314,  0.0733, -0.0958,  0.1777,  1.5697],
       dtype=torch.float64, requires_grad=True)
 the gradient: (tensor([  0.9324, -43.9643, -35.9081,  59.9936, -17.1093, -76.0930,  -0.6668],
       dtype=torch.float64),)
 the gradient: tensor([  0.9324, -43.9643, -35.9081,  59.9936, -17.1093, -76.0930,  -0.6668],
       dtype=torch.float64, grad_fn=<AddBackward0>)


True

# Vecchia interpolation, conditioning on both past and future

In [None]:
# params = [24.42, 1.92, 1.92, 0.001, -0.045, 0.237, 3.34]


# Convert parameters to a tensor with requires_grad=True
params = torch.tensor(df.iloc[0, :-1].values, dtype=torch.float64, requires_grad=True)

params = [ 27.25, 2.18, 2.294, 4.099e-4, -0.07915, 0.0999, 3.65]   #200
params = torch.tensor(params, dtype=torch.float64, requires_grad=True)
  

def compute_statistic_full(params, data,y):
    instance = matern_advec_beta_torch_vecchia(analysis_data_map, params, nns_map, mm_cond_number)
    
    # Define the function to compute the loss
    def compute_loss(params):
        ll = instance.full_likelihood(params,data , y, instance.matern_cov_ani)
        print(f'full likelihood {ll}')
        return ll
        # return instance.vecchia_interpolation_1to6(params, instance.matern_cov_ani, 35)
    grad_function = torch.func.grad(compute_loss)
    gradient = grad_function(params)

    print(f'Gradient: {gradient}')

    # Compute the Hessian matrix using torch.func.hessian
    try:
        hessian_matrix =  torch.func.hessian(compute_loss)(params)
        # print(hessian_matrix)
    except Exception as e:
        print(f'Error computing Hessian: {e}')

    statistic  = torch.matmul(gradient, torch.linalg.solve(hessian_matrix, gradient))
    # print(f' statistic is {statistic}')
    print(f'full statistic is {statistic}')
    return statistic

def compute_statistic_vecc(params, mm_cond_number):
    instance = load_data_local_computer()
    map, ord_mm, nns_map= instance.load_mm20k_data_bymonthyear( lat_lon_resolution= lat_lon_resolution, mm_cond_number=mm_cond_number,years_=years, months_=month_range)
    analysis_data_map, aggregated_data = instance.load_working_data_byday( map, ord_mm, nns_map, idx_for_datamap= idx_for_datamap)
    instance = matern_advec_beta_torch_vecchia(analysis_data_map, params, nns_map, mm_cond_number)
    
    # Define the function to compute the loss
    def compute_loss(params):
        # return instance.vecchia_like_local_computer(params, instance.matern_cov_ani)
        ll = instance.vecchia_interpolation_1to6(params, instance.matern_cov_ani, 35)
        print(f'likelihood {ll}')
        return ll

    grad_function = torch.func.grad(compute_loss)
    gradient = grad_function(params)
    try:
        hessian_matrix =  torch.func.hessian(compute_loss)(params)
        # print(hessian_matrix)
    except Exception as e:
        print(f'Error computing Hessian: {e}')

    statistic  = torch.matmul(gradient, torch.linalg.solve(hessian_matrix, gradient))
    # print(f' statistic is {statistic}')
    print(f'vecc statistic is {statistic}')
    return statistic

compute_statistic_full(params,aggregated_data[:, :4].to(torch.float64),aggregated_data[:, 2].to(torch.float64))
compute_statistic_vecc(params, 10)

full likelihood 2547.258276245673
Gradient: tensor([-0.6224, -0.2272, -0.0769, -1.2766,  2.0406,  2.8008,  0.0732],
       dtype=torch.float64, grad_fn=<AddBackward0>)
full likelihood 2547.258276245673
full statistic is 5.100717330882949
likelihood 2505.8255159564155
likelihood 2505.8255159564155
vecc statistic is 4.490207541744903


tensor(4.4902, dtype=torch.float64, grad_fn=<DotBackward0>)

### Now compare statistics

In [15]:
for i in range(11,20):
    compute_statistic_vecc(params, i)

likelihood 2507.000597851662
likelihood 2507.000597851662
vecc statistic is 4.648368261380997
likelihood 2507.7412012081904
likelihood 2507.7412012081904
vecc statistic is 4.963277645222633
likelihood 2507.274915903053
likelihood 2507.274915903053
vecc statistic is 4.934515341837704
likelihood 2507.0920177422545
likelihood 2507.0920177422545
vecc statistic is 5.09900386943201
likelihood 2507.168987461403
likelihood 2507.168987461403
vecc statistic is 5.32417886417015
likelihood 2507.003368575228
likelihood 2507.003368575228
vecc statistic is 5.409555477662674
likelihood 2507.2581959007302
likelihood 2507.2581959007302
vecc statistic is 5.296278515140054
likelihood 2507.768033104974
likelihood 2507.768033104974
vecc statistic is 5.4752822606873774
likelihood 2507.6583166714668
likelihood 2507.6583166714668
vecc statistic is 5.4529586855481975


# vecchia condition on two lags

In [27]:
params = [ 27.25, 2.18, 2.294, 4.099e-4, -0.07915, 0.0999, 3.65]   #200
params = torch.tensor(params, dtype=torch.float64, requires_grad=True)

def compute_statistic_vecc(params, mm_cond_number):
    instance = load_data_local_computer()
    map, ord_mm, nns_map= instance.load_mm20k_data_bymonthyear( lat_lon_resolution= lat_lon_resolution, mm_cond_number=mm_cond_number,years_=years, months_=month_range)
    analysis_data_map, aggregated_data = instance.load_working_data_byday( map, ord_mm, nns_map, idx_for_datamap= idx_for_datamap)
    instance = matern_advec_beta_torch_vecchia(analysis_data_map, params, nns_map, mm_cond_number)
    
    # Define the function to compute the loss
    def compute_loss(params):
        # return instance.vecchia_like_local_computer(params, instance.matern_cov_ani)
        ll = instance.vecchia_b2(params, instance.matern_cov_ani, 35)
        print(f'vecc_b2 likelihood {ll}')
        return ll

    grad_function = torch.func.grad(compute_loss)
    gradient = grad_function(params)
    try:
        hessian_matrix =  torch.func.hessian(compute_loss)(params)
        # print(hessian_matrix)
    except Exception as e:
        print(f'Error computing Hessian: {e}')

    statistic  = torch.matmul(gradient, torch.linalg.solve(hessian_matrix, gradient))
    # print(f' statistic is {statistic}')
    print(f'vecc statistic is {statistic}')
    return statistic

compute_statistic_full(params,aggregated_data[:, :4].to(torch.float64),aggregated_data[:, 2].to(torch.float64))
compute_statistic_vecc(params, 10)

full likelihood 2547.258276245673
Gradient: tensor([-0.6224, -0.2272, -0.0769, -1.2766,  2.0406,  2.8008,  0.0732],
       dtype=torch.float64, grad_fn=<AddBackward0>)
full likelihood 2547.258276245673
full statistic is 5.100717330882949
vecc_b2 likelihood 2561.9232868312483
vecc_b2 likelihood 2561.9232868312483
vecc statistic is 7.17562895976746


tensor(7.1756, dtype=torch.float64, grad_fn=<DotBackward0>)

In [28]:
for i in range(4,20):
    compute_statistic_vecc(params, i)

vecc_b2 likelihood 2567.9938103573804
vecc_b2 likelihood 2567.9938103573804
vecc statistic is 8.095305048640391
vecc_b2 likelihood 2567.724168564479
vecc_b2 likelihood 2567.724168564479
vecc statistic is 5.7202865316076785
vecc_b2 likelihood 2566.1323226707896
vecc_b2 likelihood 2566.1323226707896
vecc statistic is 4.295006659870506
vecc_b2 likelihood 2563.7262156730662
vecc_b2 likelihood 2563.7262156730662
vecc statistic is 4.92789794545219
vecc_b2 likelihood 2563.624953626223
vecc_b2 likelihood 2563.624953626223
vecc statistic is 5.242287801146139
vecc_b2 likelihood 2563.532784455573
vecc_b2 likelihood 2563.532784455573
vecc statistic is -1.8559311918389358
vecc_b2 likelihood 2561.9232868312483
vecc_b2 likelihood 2561.9232868312483
vecc statistic is 7.17562895976746
vecc_b2 likelihood 2562.294788225645
vecc_b2 likelihood 2562.294788225645
vecc statistic is 5.441139626193181
vecc_b2 likelihood 2563.4818904439358
vecc_b2 likelihood 2563.4818904439358
vecc statistic is 4.767300390127459

# vecchia_local_extra_base (condition on past with some base conditioning set)

In [None]:
sd = analysis_data_map['2024_07_y24m07day01_hm01:00']
# Compute the required statistics
# Compute the required statistics
max_lat = torch.max(sd[:, 0])
min_lat = torch.min(sd[:, 0])
median_lat = torch.median(sd[:, 0])

max_lon = torch.max(sd[:, 1])
min_lon = torch.min(sd[:, 1])
median_lon = torch.median(sd[:, 1])

# Extract the 9 points along with their locations (indices)
points = [
    (min_lon, min_lat),
    #(min_lon, median_lat),
    (min_lon, max_lat),
    (median_lon, min_lat),
    #(median_lon, median_lat),
    (median_lon, max_lat),
    (max_lon, min_lat),
    #(max_lon, median_lat),
    (max_lon, max_lat)
]
print(points)

indices = []
for lon, lat in points:
    condition = (sd[:, 0] == lat) & (sd[:, 1] == lon)
    indices.append(torch.where(condition)[0])

# Create the indices tensor
indices_tensor = torch.cat(indices)

print("Indices in Tensor Frame:")
print(indices_tensor)

base_list = indices_tensor.clone().detach().tolist()

def compute_statistic_vecc(params, mm_cond_number):
    instance = load_data_local_computer()
    map, ord_mm, nns_map= instance.load_mm20k_data_bymonthyear( lat_lon_resolution= lat_lon_resolution, mm_cond_number=mm_cond_number,years_=years, months_=month_range)
    analysis_data_map, aggregated_data = instance.load_working_data_byday( map, ord_mm, nns_map, idx_for_datamap= idx_for_datamap)
    instance = matern_advec_beta_torch_vecchia(analysis_data_map, params, nns_map, mm_cond_number)
    
    # Define the function to compute the loss
    def compute_loss(params):
        # return instance.vecchia_like_local_computer(params, instance.matern_cov_ani)
        ll = instance.vecchia_local_extra_base(params, instance.matern_cov_ani)
        print(f'likelihood {ll}')
        return ll

    grad_function = torch.func.grad(compute_loss)
    gradient = grad_function(params)
    try:
        hessian_matrix =  torch.func.hessian(compute_loss)(params)
        # print(hessian_matrix)
    except Exception as e:
        print(f'Error computing Hessian: {e}')

    statistic  = torch.matmul(gradient, torch.linalg.solve(hessian_matrix, gradient))
    # statistic  = torch.linalg.solve(hessian_matrix, gradient)
    # print(f' statistic is {statistic}')
    print(f'vecc statistic is {statistic}')
    return statistic

params = [ 27.25, 2.18, 2.294, 4.099e-4, -0.07915, 0.0999, 3.65]   #200
params = torch.tensor(params, dtype=torch.float64, requires_grad=True)

for i in range(4,10):
    compute_statistic_vecc(params, i)

[(tensor(110.0250, dtype=torch.float64), tensor(5.0250, dtype=torch.float64)), (tensor(110.0250, dtype=torch.float64), tensor(9.5250, dtype=torch.float64)), (tensor(114.5250, dtype=torch.float64), tensor(5.0250, dtype=torch.float64)), (tensor(114.5250, dtype=torch.float64), tensor(9.5250, dtype=torch.float64)), (tensor(119.5250, dtype=torch.float64), tensor(5.0250, dtype=torch.float64)), (tensor(119.5250, dtype=torch.float64), tensor(9.5250, dtype=torch.float64))]
Indices in Tensor Frame:
tensor([  1, 145, 160,  50,   2,   4])
grad*hessian*grad from <bound method matern_advec_beta_torch_vecchia.vecchia_local_extra_base of <__main__.matern_advec_beta_torch_vecchia object at 0x30823ca40>> is 53.17959587711484


KeyboardInterrupt: 

[(tensor(110.0250, dtype=torch.float64), tensor(5.0250, dtype=torch.float64)), (tensor(110.0250, dtype=torch.float64), tensor(9.5250, dtype=torch.float64)), (tensor(114.5250, dtype=torch.float64), tensor(5.0250, dtype=torch.float64)), (tensor(114.5250, dtype=torch.float64), tensor(9.5250, dtype=torch.float64)), (tensor(119.5250, dtype=torch.float64), tensor(5.0250, dtype=torch.float64)), (tensor(119.5250, dtype=torch.float64), tensor(9.5250, dtype=torch.float64))]
Indices in Tensor Frame:
tensor([  1, 145, 160,  50,   2,   4])
vecc statistic is 59.01199704244943
vecc statistic is 59.449997640819724
vecc statistic is 52.416180825260284
vecc statistic is 52.62143386846466
vecc statistic is 52.21418603035472
vecc statistic is 55.220113253939715

In [60]:
num_points =5
max_lat, min_lat, median_lat, max_lon, min_lon, median_lon = compute_statistics(sd)

lat_points = torch.linspace(min_lat, max_lat, num_points)
lon_points = torch.linspace(min_lon, max_lon, num_points)

points = [(lon, lat) for lat in lat_points for lon in lon_points]

points

[(tensor(110.0250), tensor(5.0250)),
 (tensor(112.2750), tensor(5.0250)),
 (tensor(114.5250), tensor(5.0250)),
 (tensor(116.7750), tensor(5.0250)),
 (tensor(119.0250), tensor(5.0250)),
 (tensor(110.0250), tensor(6.0250)),
 (tensor(112.2750), tensor(6.0250)),
 (tensor(114.5250), tensor(6.0250)),
 (tensor(116.7750), tensor(6.0250)),
 (tensor(119.0250), tensor(6.0250)),
 (tensor(110.0250), tensor(7.0250)),
 (tensor(112.2750), tensor(7.0250)),
 (tensor(114.5250), tensor(7.0250)),
 (tensor(116.7750), tensor(7.0250)),
 (tensor(119.0250), tensor(7.0250)),
 (tensor(110.0250), tensor(8.0250)),
 (tensor(112.2750), tensor(8.0250)),
 (tensor(114.5250), tensor(8.0250)),
 (tensor(116.7750), tensor(8.0250)),
 (tensor(119.0250), tensor(8.0250)),
 (tensor(110.0250), tensor(9.0250)),
 (tensor(112.2750), tensor(9.0250)),
 (tensor(114.5250), tensor(9.0250)),
 (tensor(116.7750), tensor(9.0250)),
 (tensor(119.0250), tensor(9.0250))]

In [54]:

sd = analysis_data_map['2024_07_y24m07day01_hm01:00']
num_points = 4
indices_tensor = get_indices(sd, num_points)
base_list = indices_tensor.clone().detach().tolist()


for i in range(4,20):
    num_points = i
    indices_tensor = get_indices(sd, num_points)
    print(indices_tensor)

tensor([], dtype=torch.int64)
tensor([], dtype=torch.int64)
tensor([], dtype=torch.int64)
tensor([], dtype=torch.int64)
tensor([], dtype=torch.int64)
tensor([], dtype=torch.int64)
tensor([], dtype=torch.int64)
tensor([], dtype=torch.int64)
tensor([], dtype=torch.int64)
tensor([], dtype=torch.int64)
tensor([], dtype=torch.int64)
tensor([], dtype=torch.int64)
tensor([], dtype=torch.int64)
tensor([], dtype=torch.int64)
tensor([], dtype=torch.int64)
tensor([], dtype=torch.int64)


In [50]:
def compute_statistics(sd):
    max_lat = torch.max(sd[:, 0])
    min_lat = torch.min(sd[:, 0])
    median_lat = torch.median(sd[:, 0])

    max_lon = torch.max(sd[:, 1])
    min_lon = torch.min(sd[:, 1])
    median_lon = torch.median(sd[:, 1])

    return max_lat, min_lat, median_lat, max_lon, min_lon, median_lon


def get_indices(sd, num_points):
    max_lat, min_lat, median_lat, max_lon, min_lon, median_lon = compute_statistics(sd)
    
    lat_points = torch.linspace(min_lat, max_lat, num_points)
    lon_points = torch.linspace(min_lon, max_lon, num_points)
    
    points = [(lon, lat) for lat in lat_points for lon in lon_points]
    
    indices = []
    for lon, lat in points:
        condition = (sd[:, 0] == lat) & (sd[:, 1] == lon)
        indices.append(torch.where(condition)[0])
    
    indices_tensor = torch.cat(indices)
    
    return indices_tensor

# Example usage
sd = analysis_data_map['2024_07_y24m07day01_hm01:00']
num_points = 4
indices_tensor = get_indices(sd, num_points)
base_list = indices_tensor.clone().detach().tolist()


for i in range(4,20):
    num_points = i
    indices_tensor = get_indices(sd, num_points)
    base_list = indices_tensor.clone().detach().tolist()

    compute_statistic_vecc(params, 5)

vecc statistic is 61.74867066606819
vecc statistic is 61.74867066606819
vecc statistic is 61.74867066606819
vecc statistic is 61.74867066606819
vecc statistic is 61.74867066606819


KeyboardInterrupt: 