In [1]:
# Standard libraries
import sys
# Add your custom path
gems_tco_path = "/Users/joonwonlee/Documents/GEMS_TCO-1/src"
sys.path.append(gems_tco_path)
import logging
import argparse # Argument parsing

# Data manipulation and analysis
import pandas as pd
import numpy as np
import pickle
import torch
import torch.optim as optim
import copy                    # clone tensor
import time

# Custom imports
import GEMS_TCO
from GEMS_TCO import kernels
from GEMS_TCO import data_preprocess 
from GEMS_TCO import kernels 
from GEMS_TCO import orderings as _orderings 
from GEMS_TCO import load_data
from GEMS_TCO import alg_optimization, alg_opt_Encoder
from GEMS_TCO import configuration as config

from typing import Optional, List, Tuple
from pathlib import Path
import typer
import json
from json import JSONEncoder

from GEMS_TCO.data_loader import load_data2


In [3]:
from pathlib import Path
# Assuming 'config' and 'load_data' class are defined and imported elsewhere

# --- Parameters derived from your framework ---
v: float = 0.5
space: List[str] = ['4', '4']
days: List[str] = ['0', '31']
mm_cond_number: int = 20
# --- End of framework parameters ---

lat_lon_resolution = [int(s) for s in space]
days_s_e = [int(d) for d in days]
days_list = list(range(days_s_e[0], days_s_e[1]))

# These values were not in the framework, so they remain as set in your snippet
years = ['2024']
month_range = [7] 

# Assuming 'config' is available in your environment
output_path = input_path = Path(config.mac_estimates_day_path)

## load ozone data from amarel
data_load_instance = load_data2(config.mac_data_load_path)

# Call the function using the variables from the framework
df_map, ord_mm, nns_map = data_load_instance.load_maxmin_ordered_data_bymonthyear(
lat_lon_resolution=lat_lon_resolution, 
mm_cond_number=mm_cond_number,
years_=years, 
months_=month_range,
lat_range=[0.0, 5.0],      # <-- Add this
lon_range=[123.0, 133.0]   # <-- Add this
)

Subsetting data to lat: [0.0, 5.0], lon: [123.0, 133.0]


In [10]:
df_day_aggregated_list = []
df_day_map_list = []
for i in range(31):
    idx_for_datamap = [i*8, (i+1)*8]

    cur_map, cur_df = analysis_map_no_mm, agg_data_no_mm = data_load_instance.load_working_data_keep_ori(
    df_map, 
    idx_for_datamap, 
    ord_mm=None,  # or just omit it
    dtype=torch.float # or just omit it
)
    df_day_aggregated_list.append( cur_df )
    df_day_map_list.append( cur_map )

print(df_day_aggregated_list[0].shape)
print(df_day_aggregated_list[1][:100])

torch.Size([8960, 6])
tensor([[  4.8880, 132.9840, 264.4437,  45.0000,   4.8891, 132.9843],
        [  4.8880, 132.7320, 271.4257,  45.0000,   4.8896, 132.7330],
        [  4.8880, 132.4800, 266.4241,  45.0000,   4.8903, 132.4819],
        [  4.8880, 132.2280, 257.9527,  45.0000,   4.8904, 132.2309],
        [  4.8880, 131.9760, 259.0169,  45.0000,   4.8904, 131.9805],
        [  4.8880, 131.7240, 258.4678,  45.0000,   4.8906, 131.7299],
        [  4.8880, 131.4720, 272.3340,  45.0000,   4.8911, 131.4797],
        [  4.8880, 131.2200, 264.8994,  45.0000,   4.8916, 131.2293],
        [  4.8880, 130.9680, 269.7025,  45.0000,   4.8925, 130.9793],
        [  4.8880, 130.7160, 261.1764,  45.0000,   4.8925, 130.7288],
        [  4.8880, 130.4640, 267.1273,  45.0000,   4.8931, 130.4789],
        [  4.8880, 130.2120, 266.5924,  45.0000,   4.8936, 130.2286],
        [  4.8880, 129.9600, 256.9923,  45.0000,   4.8940, 129.9786],
        [  4.8880, 129.7080, 262.3289,  45.0000,   4.8944, 129.7282]

In [5]:
import torch
import numpy as np
import torch.nn.functional as F
import os
import pickle

# Assume GEMS_TCO is a custom class/module you have available
# from your_project import GEMS_TCO

# =========================================================================
# 1. Helper Functions
# =========================================================================

def subset_tensor(df_tensor: torch.Tensor) -> torch.Tensor:
    """Subsets a tensor to a specific lat/lon range."""
    #lat_mask = (df_tensor[:, 0] >= -5) & (df_tensor[:, 0] <= 6.3)
    #lon_mask = (df_tensor[:, 1] >= 118) & (df_tensor[:, 1] <= 134.2)
    lat_mask = (df_tensor[:, 0] >= 0) & (df_tensor[:, 0] <= 5)
    lon_mask = (df_tensor[:, 1] >= 123) & (df_tensor[:, 1] <= 133)

    df_sub = df_tensor[lat_mask & lon_mask].clone()
    return df_sub

def apply_first_difference_2d_tensor(df_tensor: torch.Tensor) -> torch.Tensor:
    """
    Applies a 2D first-order difference filter using convolution.
    This approximates Z(s) = [X(s+d_lat) - X(s)] + [X(s+d_lon) - X(s)].
    """
    if df_tensor.size(0) == 0:
        return torch.empty(0, 4)

    # 1. Get grid dimensions and validate
    unique_lats = torch.unique(df_tensor[:, 0])
    unique_lons = torch.unique(df_tensor[:, 1])
    lat_count, lon_count = unique_lats.size(0), unique_lons.size(0)

    if df_tensor.size(0) != lat_count * lon_count:
        raise ValueError("Tensor size does not match grid dimensions. Must be a complete grid.")
    if lat_count < 2 or lon_count < 2:
        return torch.empty(0, 4)

    # 2. Reshape data and define the correct kernel
    ozone_data = df_tensor[:, 2].reshape(1, 1, lat_count, lon_count)
    
    # ✅ CORRECT KERNEL: This kernel results in the standard first-order difference:
    # Z(i,j) = X(i+1,j) + X(i,j+1) - 2*X(i,j)
    # Note: F.conv2d in PyTorch actually performs cross-correlation. To get a true
    # convolution result, the kernel would need to be flipped. However, for a 
    # forward difference operator, defining the kernel for cross-correlation is more direct.
    # The kernel below is designed for cross-correlation to achieve the desired differencing.
    diff_kernel = torch.tensor([[[[-2., 1.],
                                  [ 1., 0.]]]], dtype=torch.float32)

    # 3. Apply convolution (which acts as cross-correlation)
    filtered_grid = F.conv2d(ozone_data, diff_kernel, padding='valid').squeeze()

    # 4. Determine coordinates for the new, smaller grid
    # The new grid corresponds to the anchor points of the kernel
    new_lats = unique_lats[:-1]
    new_lons = unique_lons[:-1]

    # 5. Reconstruct the output tensor
    new_lat_grid, new_lon_grid = torch.meshgrid(new_lats, new_lons, indexing='ij')
    filtered_values = filtered_grid.flatten()
    time_value = df_tensor[0, 3].repeat(filtered_values.size(0))

    new_tensor = torch.stack([
        new_lat_grid.flatten(),
        new_lon_grid.flatten(),
        filtered_values,
        time_value
    ], dim=1)
    
    return new_tensor

# =========================================================================
# 2. Data Loading (Unchanged)
# =========================================================================
# ⚠️ NOTE: You must define these variables
# mac_data_path = "..."
# year = 2022
# month_str = "01"
# class GEMS_TCO: # Placeholder
#     def load_data(self, path): return self
#     def load_working_data_byday_wo_mm(self, data, indices):
#         return {'key': torch.randn(100, 4)}, torch.randn(100, 4)

df_day_aggregated_list = []
df_day_map_list = []
for i in range(31):
    idx_for_datamap = [i*8, (i+1)*8]

    cur_map, cur_df = analysis_map_no_mm, agg_data_no_mm = data_load_instance.load_working_data(
    df_map, 
    idx_for_datamap, 
    ord_mm=None,  # or just omit it
    dtype=torch.float # or just omit it
)
    df_day_aggregated_list.append( cur_df )
    df_day_map_list.append( cur_map )


# =========================================================================
# 3. Main Processing Loop (REVISED)
# =========================================================================
spatially_filtered_day_maps = [] # This will be a list of dicts
spatially_filtered_day_aggregates = [] # This will be a list of tensors

print("Starting data filtering...")
for day_idx, day_map in enumerate(df_day_map_list):

    filtered_map_for_this_day = {}
    tensors_to_aggregate_for_this_day = []
    
    # Sort keys to ensure proper time ordering (e.g., '0', '1', '2'...)
    # Adjust this sort if your keys are not integer strings
    try:
        sorted_keys = sorted(day_map.keys(), key=lambda k: int(k))
    except ValueError:
        sorted_keys = sorted(day_map.keys()) # Fallback for non-integer keys

    for key in sorted_keys:
        tensor = day_map[key]
        subsetted = subset_tensor(tensor)
        if subsetted.size(0) > 0:
            try:
                diff_applied = apply_first_difference_2d_tensor(subsetted)
                if diff_applied.size(0) > 0:
                    # Add the filtered tensor to the map for this day
                    filtered_map_for_this_day[key] = diff_applied
                    # Add it to the list for the aggregated tensor
                    tensors_to_aggregate_for_this_day.append(diff_applied)
            except ValueError as e:
                print(f"Skipping data chunk {key} on day {day_idx+1} due to error: {e}")

    # Add the new filtered map (dict) to the list
    spatially_filtered_day_maps.append(filtered_map_for_this_day)
    
    # Add the aggregated tensor for this day to the other list
    if tensors_to_aggregate_for_this_day:
        aggregated_day_tensor = torch.cat(tensors_to_aggregate_for_this_day, dim=0)
        spatially_filtered_day_aggregates.append(aggregated_day_tensor)
    else:
        # Handle case where a day has no valid data after filtering
        print(f"Warning: Day {day_idx+1} has no data after filtering.")
        spatially_filtered_day_aggregates.append(torch.empty(0, 4, dtype=torch.float)) 

print("Data filtering complete.")

Starting data filtering...
Data filtering complete.


In [6]:
instance1 = kernels.vecchia_experiment(0.5, df_day_map_list[0], df_day_aggregated_list[0], nns_map, mm_cond_number, nheads=10)

a = [21.303, 1.307, 1.563, 0.022, -0.144, 0.198, 4.769]
#a = [30.2594, 0.665, 1.8981, 0.0, 0.1317, -0.0, 1.9785]
#a = [45.1402, 0.6299, 0.7308, -0.0003, -0.0151, 0.0, 7.8922]
#a = [21.7335, 1.2817, 1.5946, 0.042, -0.1241, 0.218, 4.8654]
#a = [20.453542336448137, 1.4506118600616982, 2.43096923637867, -0.03476556019978718, -0.1559262606484541, 0.1254833595232136, 3.938183829354925]
params = torch.tensor(a, dtype=torch.float64, requires_grad=True)

cov_map = instance1.cov_structure_saver(params, instance1.matern_cov_anisotropy_v05)  
instance1.vecchia_oct22( params, instance1.matern_cov_anisotropy_v05, cov_map )

v = 0.5 # smooth
mm_cond_number = 20
nheads = 300
lr = 0.02
step = 100
gamma_par = 0.3
epochs = 900

DELTA_LAT = 0.044  # <-- REPLACE WITH YOUR ACTUAL VALUE (e.g., unique_lats[1] - unique_lats[0])
DELTA_LON = 0.063  # <-- REPLACE WITH YOUR ACTUAL VALUE (e.g., unique_lons[1] - unique_lons[0])

In [7]:
# This list is now just for iterating
#days_list = range(len(df_day_map_list)) 

from functools import partial

days_list = [0]
for day in days_list:  
    
    # ==========================================================
    # --- ‼️ CRITICAL FIX ‼️ ---
    # Load the NEW filtered data for this day
    analysis_data_map = spatially_filtered_day_maps[day]
    aggregated_data = spatially_filtered_day_aggregates[day]

    # If this day had no data, skip it
    if aggregated_data.size(0) == 0:
        print(f"Skipping Day {day+1}, no data after filtering.")
        continue

    #a = [21.303, 1.307, 1.563, 0.022, -0.144, 0.198, 4.769]
    #a = [28.75, 0.98, 1.06, 0, 0, 0, 1.890]
    a = [28.75, 0.98, 1.06, 0.036, -0.155, 0.179, 1.890]
    params = torch.tensor(a, dtype=torch.float64, requires_grad=True)
    
    # Calculate resolution for printing
    res_calc = (200 / lat_lon_resolution[0]) * (100 / lat_lon_resolution[0])
    print(f'\n--- Starting Day {day+1} (2024-07-{day+1}) ---')
    print(f'Data size per day: { res_calc }, smooth: {v}')
    print(f'mm_cond_number: {mm_cond_number},\ninitial parameters: \n {params.detach().numpy()}')
            
    # --- Data loading is now done *before* the loop ---

    # We need to define the device (though we aren't passing it anymore)
    device_str = 'cuda' if torch.cuda.is_available() else 'cpu'

    model_instance = kernels.model_fitting(
            smooth = v,
            input_map = analysis_data_map,
            aggregated_data = aggregated_data,
            nns_map = nns_map,
            mm_cond_number = mm_cond_number,
            nheads = nheads
            # device = device_str  <--- REMOVED: This was causing the TypeError
        )

    start_time = time.time()
    # Adjusted optimizer call based on expected return values (step size changed to step)
    optimizer, scheduler = model_instance.optimizer_fun(
        params, 
        lr=lr, 
        betas=(0.9, 0.8), 
        eps=1e-8, 
        step_size=step, # Using the 'step' variable here
        gamma=gamma_par  # Using gamma_par
    ) 

    # --- CRITICAL CORRECTION ---
    # 1. We no longer need to create a separate 'instance_map'.
    #    'model_instance' is already the correct instance.
    # 2. We do NOT pre-calculate 'cov_map'. The optimized training loop
    #    'run_vecc_may9' does this internally on each epoch
    #    to ensure the gradients are correct.
    # 3. We call 'run_vecc_may9' (the optimized loop) instead of 'run_vecc_grp9'.
    #    This version does not take 'cov_map' as an argument.
    
    final_covariance_function = partial(
            model_instance.build_cov_matrix_spatial_difference_anisotropy, 
            delta1=DELTA_LAT, 
            delta2=DELTA_LON
        )


    # Calling the optimized 'run_vecc_may9'
    out, epoch_ran = model_instance.run_vecc_oct22(
        params, 
        optimizer,
        scheduler, 
        final_covariance_function, 
        epochs=epochs
    )
    # --- End Correction ---

    end_time = time.time()
    epoch_time = end_time - start_time
    print(f"Day {day+1} optimization finished in {epoch_time:.2f}s over {epoch_ran+1} epochs.")
    print(f"Day {day+1} final results: {out}")

print("\n--- All Days Processed ---")


--- Starting Day 1 (2024-07-1) ---
Data size per day: 1250.0, smooth: 0.5
mm_cond_number: 20,
initial parameters: 
 [28.75   0.98   1.06   0.036 -0.155  0.179  1.89 ]
Epoch 1, Gradients: [-196.87185871 3030.8760038  2457.8183985  -352.71639333 -819.59381364
 -252.23272776    0.        ]
 Loss: 31227.155567381862, Parameters: [28.75   0.98   1.06   0.036 -0.155  0.179  1.89 ]
Epoch 51, Gradients: [   13.78772642 -2119.456945   -1263.4687549    -70.14938054
   -74.18145216    15.86189106     0.        ]
 Loss: 20196.289473327266, Parameters: [ 2.98522138e+01  7.78091877e-02  1.35500314e-01 -5.54604269e-03
 -3.75252782e-02  6.34196943e-01  1.89000000e+00]
Epoch 101, Gradients: [   1.96599319 -257.1479175  -139.71705761   14.13488168   -7.91772452
   -0.89160506    0.        ]
 Loss: 19813.565429684313, Parameters: [ 3.01839854e+01  1.28558234e-01  2.07867132e-01 -9.13521642e-04
 -6.69047038e-02  4.18648487e-01  1.89000000e+00]
Epoch 151, Gradients: [  0.14235057 -47.31547926 -29.98173637

KeyboardInterrupt: 