In [1]:
# work environment: jl2815
# Standard libraries
import sys
import logging
import argparse # Argument parsing
import math
from collections import defaultdict
import concurrent
from concurrent.futures import ThreadPoolExecutor  # Importing specific executor for clarity
import time

# Data manipulation and analysis
import pandas as pd
import numpy as np

# Nearest neighbor search
import sklearn
from sklearn.neighbors import BallTree

# Special functions and optimizations
from scipy.special import gamma, kv  # Bessel function and gamma function
from scipy.stats import multivariate_normal  # Simulation
from scipy.optimize import minimize
from scipy.spatial.distance import cdist  # For space and time distance
from scipy.spatial import distance  # Find closest spatial point
from scipy.optimize import differential_evolution

# Plotting and visualization
import matplotlib.pyplot as plt
import seaborn as sns

# Type hints
from typing import Callable, Union, Tuple

# Add your custom path
sys.path.append("/cache/home/jl2815/tco")

# Custom imports
from GEMS_TCO import orbitmap 
from GEMS_TCO import kernels 
from GEMS_TCO import evaluate

import pickle

  '''


In [2]:
# --- Configuration ---------------------------------------------------------
# Thinning step for the grid: every lat_lon_resolution[0]-th unique latitude and
# every lat_lon_resolution[1]-th unique longitude is kept.
lat_lon_resolution = [10,10]
# Passed as `max_nn` to find_nns_naive and as `mm_cond_number` to the likelihood object.
mm_cond_number = 10
# matern_cov_yx unpacks this as (sigmasq, range_lat, range_lon, advec, beta, nugget).
params= [60, 5.25, 5.25, 0.2, 0.5, 5]
# params= [20, 8.25, 5.25, 0.2, 0.5, 5]
# Number of leading entries of the sorted key list that are analysed.
key_for_dict= 8


# Load one dictionary to set the spatial coordinates.
# NOTE(review): pickle.load can execute arbitrary code; only open trusted files.
# NOTE(review): the variable is named *_24_1 but the file is the 2023-01 map.
filepath = "C:/Users/joonw/TCO/GEMS_data/data_2023/sparse_cen_map23_01.pkl"

with open(filepath, 'rb') as pickle_file:
    coarse_dict_24_1 = pickle.load(pickle_file)

# Look the sample frame up once and fail with a clear message. Previously the
# dictionary was indexed directly before the `.get()` check, so a missing key
# raised a bare KeyError and the check (which only printed) could never help.
sample_name = 'y23m01day01_hm02:12'
sample_key = coarse_dict_24_1.get(sample_name)
if sample_key is None:
    raise KeyError(f"Key '{sample_name}' not found in the dictionary.")
sample_df = sample_key

# { (20,20):(5,1), (5,5):(20,40) }
rho_lat = lat_lon_resolution[0]
rho_lon = lat_lon_resolution[1]
# Reference grid: every rho-th unique coordinate value of the sample frame.
lat_n = sample_df['Latitude'].unique()[::rho_lat]
lon_n = sample_df['Longitude'].unique()[::rho_lon]

lat_number = len(lat_n)
lon_number = len(lon_n)

# Thin every frame of every loaded map to the reference grid (lat_n x lon_n)
# and store it under "<year>_<month>_<original key>".
coarse_dicts = {}

years = ['2024']
for year in years:
    for month in range(7, 8):  # range(7, 8) covers month 7 only
        filepath = f"C:/Users/joonw/TCO/GEMS_data/data_{year}/sparse_cen_map{year[2:]}_{month:02d}.pkl"
        with open(filepath, 'rb') as pickle_file:
            loaded_map = pickle.load(pickle_file)
        # The file is closed by now; filtering works on the in-memory map only.
        for key, tmp_df in loaded_map.items():
            on_lat_grid = tmp_df['Latitude'].isin(lat_n)
            on_lon_grid = tmp_df['Longitude'].isin(lon_n)
            coarse_filter = on_lat_grid & on_lon_grid
            coarse_dicts[f"{year}_{month:02d}_{key}"] = tmp_df[coarse_filter].reset_index(drop=True)


key_idx = sorted(coarse_dicts)
if len(key_idx) == 0:
    raise ValueError("coarse_dicts is empty")

# Coordinates are taken from the first (sorted) key only; per the original
# note, all frames are expected to share the same spatial grid.
data_for_coord = coarse_dicts[key_idx[0]]
x1 = data_for_coord['Longitude'].values
y1 = data_for_coord['Latitude'].values
coords1 = np.stack([x1, y1], axis=1)  # one (longitude, latitude) pair per row

# Pairwise Euclidean distances feed the maxmin ordering routine.
s_dist = cdist(coords1, coords1, 'euclidean')
instance = orbitmap.MakeOrbitdata(data_for_coord, lat_s=5, lat_e=10, lon_s=110, lon_e=120)
ord_mm, _ = instance.maxmin_naive(s_dist, 0)

# Re-order the reference frame with ord_mm, then build the neighbour map from
# the re-ordered coordinates.
data_for_coord = data_for_coord.iloc[ord_mm].reset_index(drop=True)
lon_reordered = data_for_coord['Longitude'].values
lat_reordered = data_for_coord['Latitude'].values
coords1_reordered = np.stack((lon_reordered, lat_reordered), axis=-1)
nns_map = instance.find_nns_naive(
    locs=coords1_reordered,
    dist_fun='euclidean',
    max_nn=mm_cond_number,
)


# key_for_dict comes from the configuration lines at the top of this cell; the
# duplicate re-assignment that used to sit here was dropped so there is a single
# source of truth. Fail early if more keys are requested than were loaded.
if not 0 <= key_for_dict <= len(key_idx):
    raise ValueError(
        f"key_for_dict={key_for_dict} is outside the range of available keys (0..{len(key_idx)})"
    )
selected_keys = key_idx[:key_for_dict]

# Per-key arrays: first four columns, permuted by ord_mm, then sorted by
# column 1 with column 0 as tie-breaker.
# NOTE(review): the lexsort discards the ord_mm ordering, while nns_map was
# built from ord_mm-ordered coordinates -- confirm this is what
# kernels.likelihood_function expects.
analysis_data_map = {}
for hour_key in selected_keys:
    tmp = coarse_dicts[hour_key]
    # Rounds in place on the stored frame, so the aggregated frame built below
    # (and any later use of coarse_dicts) sees the rounded hours as well.
    tmp['Hours_elapsed'] = np.round(tmp['Hours_elapsed'])
    tmp = tmp.iloc[ord_mm, :4].to_numpy()
    sorted_indices = np.lexsort((tmp[:, 0], tmp[:, 1]))
    analysis_data_map[hour_key] = tmp[sorted_indices]

# Aggregate the selected frames. The ord_mm permutation is applied exactly once
# (it used to be applied twice by a duplicated line), and the frames are
# concatenated in a single call instead of growing a DataFrame inside the loop.
reordered_frames = [
    coarse_dicts[hour_key].iloc[ord_mm].reset_index(drop=True)
    for hour_key in selected_keys
]
aggregated_data = pd.concat(reordered_frames, axis=0) if reordered_frames else pd.DataFrame()

aggregated_np = aggregated_data.iloc[:,:4].to_numpy()
# np.lexsort is stable, so rows with equal (column 1, column 0) keep key order.
sorted_indices = np.lexsort((aggregated_np[:, 0], aggregated_np[:, 1]))
aggregated_np = aggregated_np[sorted_indices]
print(f'Aggregated data shape: {aggregated_data.shape}')



# Re-bind `instance` to the likelihood object; later cells call
# instance.full_likelihood and instance.matern_cov_yx on it.
instance = kernels.likelihood_function(
    smooth=0.5,
    input_map=analysis_data_map,
    nns_map=nns_map,
    mm_cond_number=mm_cond_number,
)

from sklearn.preprocessing import MinMaxScaler, StandardScaler

# Only column 3 is rescaled; every other column is copied through unchanged.
columns_to_scale = [ 3]

# Work on a copy so aggregated_np keeps the unscaled values.
aggregated_np2 = aggregated_np.copy()

# Min-max scale the selected column(s) and write them back into the copy.
scaler = MinMaxScaler()
scaled_columns = scaler.fit_transform(aggregated_np2[:, columns_to_scale])
aggregated_np2[:, columns_to_scale] = scaled_columns

Aggregated data shape: (1600, 5)


In [3]:
# Evaluate full_likelihood on the unscaled aggregated array. The first four
# columns and column 2 are passed as separate arguments, together with the
# matern_cov_yx covariance callable. The value is the cell's displayed result.
likelihood_inputs = aggregated_np[:, :4]
likelihood_response = aggregated_np[:, 2]
instance.full_likelihood(params, likelihood_inputs, likelihood_response, instance.matern_cov_yx)

np.float64(4312.9922255818)

In [4]:
from line_profiler import LineProfiler

# Line-by-line timing of matern_cov_yx, called with the min-max-scaled copy of
# the data as both the `y` and the `x` argument.
lp = LineProfiler()
lp_wrapper = lp(instance.matern_cov_yx)
profile_input = aggregated_np2[:, :4]
lp_wrapper(params, profile_input, profile_input)
lp.print_stats()

Timer unit: 1e-07 s

Total time: 0.733604 s
File: c:\Users\joonw\TCO\GEMS_TCO-1\GEMS_TCO\kernels.py
Function: matern_cov_yx at line 162

Line #      Hits         Time  Per Hit   % Time  Line Contents
   162                                               def matern_cov_yx(self,params: Tuple[float,float,float,float,float,float], y: np.ndarray, x: np.ndarray) -> np.ndarray:
   163                                               
   164         1         11.0     11.0      0.0          sigmasq, range_lat, range_lon, advec, beta, nugget  = params
   165                                                   # Validate inputs
   166         1          7.0      7.0      0.0          if y is None or x is None:
   167                                                       raise ValueError("Both y and x_df must be provided.")
   168                                                   # Extract values
   169         1         24.0     24.0      0.0          x1 = x[:, 0]
   170         1          6.0      6.