In [1]:
# for path in sys.path:
#   print(path)

import sys
gems_tco_path = "/Users/joonwonlee/Documents/GEMS_TCO-1/src"
sys.path.append(gems_tco_path)

import logging
import argparse # Argument parsing
import math
import os
from collections import defaultdict
import concurrent
from concurrent.futures import ThreadPoolExecutor  # Importing specific executor for clarity
import time

# Data manipulation and analysis
import pandas as pd
import numpy as np

# Nearest neighbor search
import sklearn
from sklearn.neighbors import BallTree

# Special functions and optimizations
from scipy.special import gamma, kv  # Bessel function and gamma function
from scipy.stats import multivariate_normal  # Simulation
from scipy.optimize import minimize
from scipy.spatial.distance import cdist  # For space and time distance
from scipy.spatial import distance  # Find closest spatial point
from scipy.optimize import differential_evolution

# Plotting and visualization
import matplotlib.pyplot as plt
import seaborn as sns

# Type hints
from typing import Callable, Union, Tuple

# Add your custom path
# sys.path.append("/cache/home/jl2815/tco")

# Custom imports

from GEMS_TCO import orbitmap 
from GEMS_TCO import kernels 
from GEMS_TCO import evaluate
from GEMS_TCO import orderings as _orderings

import pickle
import torch
import torch.optim as optim
import copy                    # clone tensor

In [2]:
# Notebook configuration.
# Stride used to subsample the distinct latitudes / longitudes: [lat stride, lon stride].
lat_lon_resolution = [15] * 2
# Maximum number of nearest neighbours in each conditioning set.
mm_cond_number = 10
# NOTE(review): no cell reads this list before later cells rebind `params`; confirm it is still needed.
params = [20, 8.25, 5.25, 0.2, 0.5, 5]
# Half-open [start, stop) positions in the sorted key list used for the initial data map.
idx_for_datamap = [0, 8]

In [3]:
# Load the one dictionary to set spaital coordinates
# filepath = "C:/Users/joonw/TCO/GEMS_data/data_2023/sparse_cen_map23_01.pkl"
filepath = "/Users/joonwonlee/Documents/GEMS_DATA/pickle_2023/coarse_cen_map23_01.pkl"
with open(filepath, 'rb') as pickle_file:
    coarse_dict_24_1 = pickle.load(pickle_file)

sample_df = coarse_dict_24_1['y23m01day01_hm02:12']

sample_key = coarse_dict_24_1.get('y23m01day01_hm02:12')
if sample_key is None:
    print("Key 'y23m01day01_hm02:12' not found in the dictionary.")

# { (20,20):(5,1), (5,5):(20,40) }
rho_lat = lat_lon_resolution[0]          
rho_lon = lat_lon_resolution[1]
lat_n = sample_df['Latitude'].unique()[::rho_lat]
lon_n = sample_df['Longitude'].unique()[::rho_lon]

lat_number = len(lat_n)
lon_number = len(lon_n)

# Set spatial coordinates for each dataset
coarse_dicts = {}

years = ['2024']
for year in years:
    for month in range(7, 8):  # Iterate over all months
        # filepath = f"C:/Users/joonw/TCO/GEMS_data/data_{year}/sparse_cen_map{year[2:]}_{month:02d}.pkl"
        filepath = f"/Users/joonwonlee/Documents/GEMS_DATA/pickle_{year}/coarse_cen_map{year[2:]}_{month:02d}.pkl"
        with open(filepath, 'rb') as pickle_file:
            loaded_map = pickle.load(pickle_file)
            for key in loaded_map:
                tmp_df = loaded_map[key]
                coarse_filter = (tmp_df['Latitude'].isin(lat_n)) & (tmp_df['Longitude'].isin(lon_n))
                coarse_dicts[f"{year}_{month:02d}_{key}"] = tmp_df[coarse_filter].reset_index(drop=True)


key_idx = sorted(coarse_dicts)
if not key_idx:
    raise ValueError("coarse_dicts is empty")

# extract first hour data because all data shares the same spatial grid
data_for_coord = coarse_dicts[key_idx[0]]
x1 = data_for_coord['Longitude'].values
y1 = data_for_coord['Latitude'].values 
coords1 = np.stack((x1, y1), axis=-1)


# instance = orbitmap.MakeOrbitdata(data_for_coord, lat_s=5, lat_e=10, lon_s=110, lon_e=120)
# s_dist = cdist(coords1, coords1, 'euclidean')
# ord_mm, _ = instance.maxmin_naive(s_dist, 0)

ord_mm = _orderings.maxmin_cpp(coords1)
data_for_coord = data_for_coord.iloc[ord_mm].reset_index(drop=True)
coords1_reordered = np.stack((data_for_coord['Longitude'].values, data_for_coord['Latitude'].values), axis=-1)
# nns_map = instance.find_nns_naive(locs=coords1_reordered, dist_fun='euclidean', max_nn=mm_cond_number)
nns_map=_orderings.find_nns_l2(locs= coords1_reordered  ,max_nn = mm_cond_number)


analysis_data_map = {}
for i in range(idx_for_datamap[0],idx_for_datamap[1]):
    tmp = coarse_dicts[key_idx[i]].copy()
    tmp['Hours_elapsed'] = np.round(tmp['Hours_elapsed']-477700)

    tmp = tmp.iloc[ord_mm, :4].to_numpy()
    tmp = torch.from_numpy(tmp).float()  # Convert NumPy to Tensor
    # tmp = tmp.clone().detach().requires_grad_(True)  # Enable gradients
    
    analysis_data_map[key_idx[i]] = tmp

aggregated_data = pd.DataFrame()
for i in range(idx_for_datamap[0],idx_for_datamap[1]):
    tmp = coarse_dicts[key_idx[i]].copy()
    tmp['Hours_elapsed'] = np.round(tmp['Hours_elapsed']-477700)
    tmp = tmp.iloc[ord_mm].reset_index(drop=True)  
    aggregated_data = pd.concat((aggregated_data, tmp), axis=0)

aggregated_data = aggregated_data.iloc[:, :4].to_numpy()

aggregated_data = torch.from_numpy(aggregated_data).float()  # Convert NumPy to Tensor
# aggregated_np = aggregated_np.clone().detach().requires_grad_(True)  # Enable gradients




In [None]:
# Fit the model independently for each of the first N_DAYS days.
N_DAYS = 2
HOURS_PER_DAY = 8               # consecutive sorted keys that make up one day
HOURS_ELAPSED_OFFSET = 477700   # shift applied to 'Hours_elapsed' before rounding
SMOOTH = 0.5                    # passed as `smooth` to kernels.model_fitting
# Starting values; presumably ordered like the estimate columns tabulated later
# (sigmasq, range_lat, range_lon, advec_lat, advec_lon, beta, nugget) -- confirm.
INITIAL_PARAMS = [24.42, 1.92, 1.92, 0.001, -0.045, 0.237, 3.34]


def build_day_inputs(day_keys, frames, order, ref_coords, hours_offset=HOURS_ELAPSED_OFFSET):
    """Build the per-hour tensors and their stacked version for one day.

    Args:
        day_keys: keys of `frames` belonging to the day, in chronological order.
        frames: mapping key -> DataFrame on the shared coarse grid (e.g. coarse_dicts).
        order: row positions giving the max-min ordering (e.g. ord_mm).
        ref_coords: (n, 2) array of (lon, lat) in the frames' original row order (e.g. coords1).
        hours_offset: value subtracted from 'Hours_elapsed' before rounding.

    Returns:
        (input_map, aggregated): dict key -> float32 tensor of the first four
        columns in max-min order, and those tensors concatenated row-wise.

    Raises:
        ValueError: if a frame's locations differ from ref_coords (positional
        orderings and neighbour maps would then point at the wrong sites).
    """
    input_map = {}
    for key in day_keys:
        frame = frames[key].copy()
        coords = np.stack((frame['Longitude'].values, frame['Latitude'].values), axis=-1)
        if not np.array_equal(coords, ref_coords):
            raise ValueError(f"{key} does not share the reference spatial grid")
        frame['Hours_elapsed'] = np.round(frame['Hours_elapsed'] - hours_offset)
        input_map[key] = torch.from_numpy(frame.iloc[order, :4].to_numpy()).float()
    return input_map, torch.cat(list(input_map.values()), dim=0)


result = {}
for day in range(N_DAYS):
    start = HOURS_PER_DAY * day
    day_keys = key_idx[start:start + HOURS_PER_DAY]
    if len(day_keys) != HOURS_PER_DAY:
        raise IndexError(f"day {day + 1}: expected {HOURS_PER_DAY} keys, found {len(day_keys)}")
    # Assumes keys look like "<year>_<month>_yYYmMMdayDD_hmHH:MM": everything before
    # the last "_" identifies the date, so one day's keys must share that prefix.
    # A missing hour would otherwise silently shift every later day.
    day_tags = {key.rsplit('_', 1)[0] for key in day_keys}
    if len(day_tags) != 1:
        raise ValueError(f"day {day + 1}: keys span several dates {sorted(day_tags)}")

    analysis_data_map, aggregated_data = build_day_inputs(day_keys, coarse_dicts, ord_mm, coords1)

    # Fresh starting point each day so the days are fitted independently.
    params = torch.tensor(INITIAL_PARAMS, requires_grad=True)

    instance = kernels.model_fitting(
        smooth=SMOOTH,
        input_map=analysis_data_map,
        aggregated_data=aggregated_data,
        nns_map=nns_map,
        mm_cond_number=mm_cond_number,
    )
    optimizer, scheduler = instance.optimizer_fun(
        params, lr=0.01, betas=(0.9, 0.8), eps=1e-8, step_size=20, gamma=0.9
    )
    out = instance.run_full(params, optimizer, scheduler, epochs=3000)
    # 1-based day number -> run_full output (saved results show [parameter array, loss]).
    result[day + 1] = out
        

### Save estimates to a pickle file

In [None]:
# Persist the per-day estimates, then read the file back to confirm the round trip.
# (`os` and `pickle` come from the imports cell.)
assert isinstance(result, dict), "result should map day number -> run_full output"

output_filename = "estimation_1250_july24.pkl"
# base_path = "/home/jl2815/tco/data/pickle_data"
output_path = "/Users/joonwonlee/Documents/"
output_filepath = os.path.join(output_path, output_filename)
with open(output_filepath, 'wb') as pickle_file:
    pickle.dump(result, pickle_file)

input_filepath = output_filepath
with open(input_filepath, 'rb') as pickle_file:
    loaded_map = pickle.load(pickle_file)

assert loaded_map.keys() == result.keys(), "reloaded estimates do not match what was saved"
loaded_map

{1: [array([ 2.6412588e+01,  2.5630600e+00,  2.5388060e+00,  2.7070900e-02,
         -2.4097290e-02,  1.0483751e-01,  5.3329625e+00], dtype=float32),
  1374.03466796875],
 2: [array([24.588697  ,  3.3285947 ,  3.9052231 ,  0.13455719, -0.03456095,
          0.0597389 ,  2.938084  ], dtype=float32),
  1148.790771484375]}

## Load estimates from Amarel and convert them into a pandas DataFrame

In [None]:

# Tabulate the July 2024 estimates produced on Amarel: one row per day with the
# seven fitted parameters and the final loss. (`os`/`pickle` come from the imports cell.)
input_path = "/Users/joonwonlee/Documents/GEMS_TCO-1/Exercises/st_model/estimates"
# Other available run: "estimation_200_july24.pkl"
input_filename = "estimation_1250_july24.pkl"
input_filepath = os.path.join(input_path, input_filename)
# NOTE(review): pickle.load can execute arbitrary code -- only load trusted files.
with open(input_filepath, 'rb') as pickle_file:
    amarel_map1250 = pickle.load(pickle_file)
if not amarel_map1250:
    raise ValueError(f"{input_filepath} contains no estimates")

param_columns = ['sigmasq', 'range_lat', 'range_lon', 'advec_lat', 'advec_lon', 'beta', 'nugget']
# Each value is [parameter array, loss]. Build the frame in one go: the previous
# row-by-row pd.concat was quadratic and left every row with index 0.
estimates = list(amarel_map1250.values())
df_1250 = pd.DataFrame(
    np.vstack([np.asarray(entry[0]).reshape(1, -1) for entry in estimates]),
    columns=param_columns,
    index=list(amarel_map1250.keys()),
)
df_1250['loss'] = [entry[1] for entry in estimates]

# Label rows with calendar dates when there is exactly one estimate per day of
# July 2024; otherwise keep the original dictionary keys as the index.
date_range = pd.date_range(start='2024-07-01', end='2024-07-31')
if len(date_range) == len(df_1250):
    df_1250.index = date_range
else:
    print("The number of dates does not match the number of rows in the DataFrame.")

df_1250


              sigmasq  range_lat  range_lon  advec_lat  advec_lon      beta  \
2024-07-01  25.964643   2.120772   2.225752   0.001676  -0.079428  0.102517   
2024-07-02  23.878902   3.283870   3.612040   0.048285   0.019385  0.079440   
2024-07-03  26.226320   1.905852   2.220417  -0.010698  -0.116340  0.119043   
2024-07-04  24.515753   2.836888   3.613807  -0.132429  -0.165947  0.072058   
2024-07-05  23.338093   3.692467   3.726604  -0.060823  -0.136659  0.059121   
2024-07-06  25.878857   2.543730   2.867047   0.003362  -0.137558  0.094139   
2024-07-07  26.262951   2.129691   3.281083   0.050776  -0.301583  0.087698   
2024-07-08  26.109503   1.798699   2.407871   0.027176  -0.394423  0.111374   
2024-07-09  24.766806   2.113744   2.775355  -0.026761  -0.209648  0.105213   
2024-07-10  26.156773   1.482763   2.269531  -0.007672  -0.005897  0.145461   
2024-07-11  24.291435   2.281666   3.302899  -0.030417   0.053161  0.104988   
2024-07-12  22.386654   3.815042   3.928914  -0.0625

              sigmasq  range_lat  range_lon  advec_lat  advec_lon      beta  \
2024-07-01  24.793444   1.584529   1.718248   0.009089  -0.107299  0.131038   
2024-07-02  24.424301   1.997055   1.942683   0.043588  -0.072679  0.137124   
2024-07-03  26.009497   1.215236   1.558868   0.023392  -0.150548  0.199850   
2024-07-04  24.701347   1.612308   1.822960  -0.164069  -0.237443  0.131595   
2024-07-05  22.598671   2.901185   3.722327  -0.011729  -0.152072  0.072866   
2024-07-06  25.594908   1.702692   2.255174   0.017462  -0.158125  0.098684   
2024-07-07  26.030510   1.261084   2.831952   0.054831  -0.343255  0.103045   
2024-07-08  26.043682   0.995279   1.629503  -0.019824  -0.411626  0.164296   
2024-07-09  24.052071   1.377774   2.357721   0.021439  -0.220316  0.142847   
2024-07-10  25.766109   1.392051   2.358171   0.026684  -0.077366  0.150648   
2024-07-11  23.945438   1.490333   2.470762  -0.009915   0.027429  0.137959   
2024-07-12  23.036034   2.299998   3.346955  -0.0542