In [1]:
# Configuration: location of the GEMS_TCO source tree that is added to sys.path below.
# NOTE(review): this is an absolute, machine-specific path; the notebook will not
# import GEMS_TCO on another machine unless this value is edited.
gems_tco_path = "/Users/joonwonlee/Documents/GEMS_TCO-1/src"

# --- Standard Libraries ---
import sys
import os
import json
import time
import copy
import cmath
import pickle
import logging
import argparse

# Make the GEMS_TCO package importable. Re-running this cell appends the same
# entry to sys.path again, so it only needs to be executed once per kernel.
sys.path.append(gems_tco_path)

# --- Third-Party Libraries ---
from pathlib import Path
from typing import Optional, List, Tuple, Dict, Any, Callable
from json import JSONEncoder

# Data manipulation and analysis
import pandas as pd
import numpy as np
from sklearn.neighbors import BallTree
import typer

# Torch and Numerical Libraries
import torch
import torch.optim as optim
import torch.fft
import torch.nn.functional as F
from torch.nn import Parameter
from torch.optim.lr_scheduler import CosineAnnealingLR, ReduceLROnPlateau
import matplotlib.pyplot as plt 

# --- Custom (GEMS_TCO) Imports ---
import GEMS_TCO
from GEMS_TCO import kernels_reparam_space_time 
from GEMS_TCO import data_preprocess, data_preprocess as dmbh
from GEMS_TCO import orderings as _orderings 

from GEMS_TCO import alg_optimization, alg_opt_Encoder
from GEMS_TCO import configuration as config
from GEMS_TCO.data_loader import load_data2
from GEMS_TCO import debiased_whittle

Load monthly data

In [2]:
# Resolution tokens are kept as strings and converted to integers for the loader.
space: List[str] = ['1', '1']
lat_lon_resolution = list(map(int, space))
mm_cond_number: int = 8  # forwarded to the loader as `mm_cond_number`
years = ['2024']
month_range = [7]

# Input and output paths refer to the same estimates directory object.
input_path = Path(config.mac_estimates_day_path)
output_path = input_path
data_load_instance = load_data2(config.mac_data_load_path)

# Spatial window passed to the loader.
# (An alternative window used earlier: lat [1, 3], lon [125.0, 129.0].)
lat_range_input = [3, 5]
lon_range_input = [129, 133.0]

df_map, ord_mm, nns_map = data_load_instance.load_maxmin_ordered_data_bymonthyear(
    lat_lon_resolution=lat_lon_resolution,
    mm_cond_number=mm_cond_number,
    years_=years,
    months_=month_range,
    lat_range=lat_range_input,
    lon_range=lon_range_input,
)

Subsetting data to lat: [3, 5], lon: [129, 133.0]


In [3]:
# Per-day containers. The "_dw" lists are loaded with ord_mm=None, the "_vecc"
# lists with the ordering `ord_mm` returned by the loader in the previous cell.
daily_aggregated_tensors_dw, daily_hourly_maps_dw = [], []
daily_aggregated_tensors_vecc, daily_hourly_maps_vecc = [], []

for day_index in range(31):
    # Each day covers 8 consecutive hour slots: [day*8, (day+1)*8].
    hour_start_index = 8 * day_index
    hour_end_index = 8 * (day_index + 1)
    hour_indices = [hour_start_index, hour_end_index]

    # Same loader call twice per day; only the `ord_mm` argument differs.
    for ordering, tensor_store, map_store in (
        (None, daily_aggregated_tensors_dw, daily_hourly_maps_dw),
        (ord_mm, daily_aggregated_tensors_vecc, daily_hourly_maps_vecc),
    ):
        day_hourly_map, day_aggregated_tensor = data_load_instance.load_working_data(
            df_map,
            hour_indices,
            ord_mm=ordering,
            dtype=torch.float64,
            keep_ori=False,
        )
        tensor_store.append(day_aggregated_tensor)
        map_store.append(day_hourly_map)

print(daily_aggregated_tensors_vecc[0].shape)
# Number of rows in the first day's aggregated tensor.
nn = daily_aggregated_tensors_vecc[0].shape[0]

torch.Size([23040, 4])


In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from scipy.fft import fftshift
import cmath

# =========================================================================
# 1. Tapering and Modeling Functions
# =========================================================================
def cgn_hamming(u, n1, n2):
    """Evaluate a separable 2D Hamming-type window at the lag pair ``u``.

    Args:
        u: Pair ``(u1, u2)`` of lags; each entry may be a tensor or a plain
            number (numbers are converted to float32 tensors).
        n1: Period used for the first axis.
        n2: Period used for the second axis.

    Returns:
        Tensor equal to the product of the two one-dimensional windows
        ``0.54 + 0.46 * cos(2 * pi * u_k / n_k)``.
    """
    def _axis_window(lag, size):
        # Tensors pass through untouched; anything else becomes float32.
        if not isinstance(lag, torch.Tensor):
            lag = torch.tensor(lag, dtype=torch.float32)
        return 0.54 + 0.46 * torch.cos(2 * torch.pi * lag / size)

    lag_first, lag_second = u
    window_first = _axis_window(lag_first, n1)
    window_second = _axis_window(lag_second, n2)
    return window_first * window_second

def cgn_2dbartlett(u, n1, n2):
    """Evaluate a separable 2D Bartlett (triangular) window at the lag pair ``u``.

    Args:
        u: Pair ``(u1, u2)`` of lags. Each entry may be a tensor or a plain
            Python number; numbers are converted to float32 tensors, matching
            the behaviour of ``cgn_hamming``.
        n1: Window length along the first axis.
        n2: Window length along the second axis.

    Returns:
        Tensor ``(1 - |u1| / n1) * (1 - |u2| / n2)``. No clamping is applied,
        so lags with ``|u_k| > n_k`` yield negative factors.
    """
    u1, u2 = u
    # torch.abs only accepts tensors; coerce scalars so both window functions
    # in this file accept the same kinds of input.
    u1_tensor = u1 if isinstance(u1, torch.Tensor) else torch.tensor(u1, dtype=torch.float32)
    u2_tensor = u2 if isinstance(u2, torch.Tensor) else torch.tensor(u2, dtype=torch.float32)
    return (1 - torch.abs(u1_tensor) / n1) * (1 - torch.abs(u2_tensor) / n2)

# ✅ CHANGE: Replaced the old cov_x with the new, pure PyTorch Matérn model.
def cov_x(u1, u2, t, params, nu):
    """Space-time covariance used for plotting (only ``nu == 1.5`` is supported).

    Args:
        u1: Tensor of lags along the first spatial axis.
        u2: Tensor of lags along the second spatial axis.
        t: Temporal lag.
        params: Seven values, unpacked in order as
            ``(sigmasq, range_lat, range_lon, advec_lat, advec_lon, beta, nugget)``,
            given on their natural scale.
        nu: Smoothness; must equal 1.5.

    Returns:
        ``sigmasq * (1 + d) * exp(-d)`` where ``d`` is the scaled, advected
        space-time distance, except where the squared distance is at most
        ``1e-9``; there the value is ``sigmasq + nugget``.

    Raises:
        ValueError: If ``nu`` is not 1.5.
    """
    if nu != 1.5:
        raise ValueError("This implementation of cov_x is only for nu=1.5")

    sigmasq, range_lat, range_lon, advec_lat, advec_lon, beta, nugget = params

    # Scaled spatial lags shifted by advection, plus a pure temporal component.
    lat_term = u1 / range_lat - advec_lat * t
    lon_term = u2 / range_lon - advec_lon * t
    time_term = beta * t
    sq_dist = lat_term**2 + lon_term**2 + time_term**2

    # The small offset keeps the square root away from exactly zero.
    dist = torch.sqrt(sq_dist + 1e-9)
    smooth_part = sigmasq * (1 + dist) * torch.exp(-dist)

    # The nugget is only added at (numerically) zero space-time lag.
    away_from_origin = sq_dist > 1e-9
    return torch.where(away_from_origin, smooth_part, sigmasq + nugget)

def cov_first_difference(u1, u2, t, params, nu):
    """Autocovariance of the field after applying a three-point difference stencil.

    The stencil has weight -2 at offset (0, 0) and weight 1 at offsets (1, 0)
    and (0, 1). The result is the double sum over stencil pairs of
    ``w_a * w_b * cov_x(...)`` evaluated at the correspondingly shifted lags.

    Args:
        u1: Tensor of integer-valued grid lags along the first axis.
        u2: Tensor of integer-valued grid lags along the second axis.
        t: Temporal lag, forwarded to ``cov_x``.
        params: Covariance parameters, forwarded to ``cov_x``.
        nu: Smoothness, forwarded to ``cov_x``.

    Returns:
        float32 tensor shaped like ``u1``.
    """
    # Multipliers converting grid lags into the units cov_x expects
    # (presumably the lat/lon grid spacing -- TODO confirm).
    step_first, step_second = 0.044, 0.063
    taps = (((0, 0), -2), ((1, 0), 1), ((0, 1), 1))

    total = torch.zeros_like(u1, dtype=torch.float32)
    for (shift_a1, shift_a2), weight_a in taps:
        for (shift_b1, shift_b2), weight_b in taps:
            shifted_first = (u1 + shift_a1 - shift_b1) * step_first
            shifted_second = (u2 + shift_a2 - shift_b2) * step_second
            total += weight_a * weight_b * cov_x(shifted_first, shifted_second, t, params, nu)
    return total

def cn_bar(u1, u2, t, params, n1, n2, taper_func, nu):
    """Filtered-field covariance multiplied by the window ``taper_func``.

    Args:
        u1, u2: Tensors of grid lags along the two spatial axes.
        t: Temporal lag.
        params: Covariance parameters, forwarded to ``cov_first_difference``.
        n1, n2: Grid sizes, forwarded to ``taper_func``.
        taper_func: Callable invoked as ``taper_func((u1, u2), n1, n2)``.
        nu: Smoothness, forwarded to ``cov_first_difference``.

    Returns:
        Element-wise product of the covariance and the window values.
    """
    filtered_cov = cov_first_difference(u1, u2, t, params, nu)
    window = taper_func((u1, u2), n1, n2)
    return filtered_cov * window

def expected_periodogram_fft_multivariate(params, n1, n2, p, taper_func, nu):
    """Compute the expected periodogram of the filtered field via a 2D FFT.

    For every pair of time indices ``(q, r)`` the tapered covariance
    ``cn_bar`` is evaluated on the ``n1 x n2`` grid of spatial lags
    ``0..n1-1`` by ``0..n2-1`` at temporal lag ``q - r``; the result is then
    Fourier-transformed over the two spatial axes.

    Args:
        params: Covariance parameters, forwarded to ``cn_bar``.
        n1: Number of grid points along the first spatial axis.
        n2: Number of grid points along the second spatial axis.
        p: Number of time points.
        taper_func: Window callable, forwarded to ``cn_bar``.
        nu: Smoothness, forwarded to ``cn_bar``.

    Returns:
        complex64 tensor of shape ``(n1, n2, p, p)`` scaled by ``1 / (2*pi)**2``.
    """
    product_tensor = torch.zeros((n1, n2, p, p), dtype=torch.complex64)
    t_lags = torch.arange(p, dtype=torch.float32)
    # NOTE(review): only non-negative spatial lags are placed on the grid --
    # confirm this matches the intended expected-periodogram definition.
    u1_mesh, u2_mesh = torch.meshgrid(
        torch.arange(n1, dtype=torch.float32),
        torch.arange(n2, dtype=torch.float32),
        indexing='ij',
    )

    # cn_bar depends on (q, r) only through the temporal lag q - r, so there
    # are just 2p - 1 distinct grids; evaluate each once and reuse it instead
    # of recomputing it for all p * p index pairs.
    grid_by_lag = {}
    for q in range(p):
        for r in range(p):
            lag = q - r
            if lag not in grid_by_lag:
                t = t_lags[q] - t_lags[r]
                grid_by_lag[lag] = cn_bar(u1_mesh, u2_mesh, t, params, n1, n2, taper_func, nu)
            product_tensor[:, :, q, r] = grid_by_lag[lag]

    # Transform over the two spatial axes only (dims 0 and 1).
    fft_result = torch.fft.fft2(product_tensor, dim=(-4, -3))
    normalization_factor = 1 / (2 * cmath.pi)**2
    return fft_result * normalization_factor

# =========================================================================
# 2. Data Processing Functions
# =========================================================================
def generate_Jvector_final(tensor_list, taper_func, lat_col=0, lon_col=1, val_col=2):
    """Build the tapered, normalised 2D DFT of each time slice.

    Every tensor in ``tensor_list`` is one time slice whose rows carry a
    latitude, a longitude and a value in the given columns. The rows are
    scattered onto a common ``n1 x n2`` grid spanned by the sorted unique
    latitudes and longitudes across all slices (cells with no row stay 0),
    multiplied by the taper, and Fourier-transformed.

    Args:
        tensor_list: List of 2D tensors, one per time point.
        taper_func: Callable invoked as ``taper_func((u1_mesh, u2_mesh), n1, n2)``
            returning an ``n1 x n2`` tensor of taper weights.
        lat_col: Column index of the latitude.
        lon_col: Column index of the longitude.
        val_col: Column index of the observed value.

    Returns:
        Tuple ``(J, n1, n2, p)`` where ``J`` has shape ``(n1, n2, p)`` and is
        scaled by ``sqrt(1 / sum(taper**2)) / (2*pi)``. For an empty
        ``tensor_list`` the result is ``(torch.empty(0, 0, 0), 0, 0, 0)``.
    """
    p = len(tensor_list)
    if p == 0:
        return torch.empty(0, 0, 0), 0, 0, 0

    all_lats = torch.cat([t[:, lat_col] for t in tensor_list])
    all_lons = torch.cat([t[:, lon_col] for t in tensor_list])
    # torch.unique returns the values in ascending order, which is what
    # searchsorted below relies on.
    unique_lats, unique_lons = torch.unique(all_lats), torch.unique(all_lons)
    n1, n2 = len(unique_lats), len(unique_lons)

    u1_mesh, u2_mesh = torch.meshgrid(
        torch.arange(n1, dtype=torch.float32),
        torch.arange(n2, dtype=torch.float32),
        indexing='ij',
    )
    taper_grid = taper_func((u1_mesh, u2_mesh), n1, n2)

    fft_results = []
    for tensor in tensor_list:
        data_grid = torch.zeros((n1, n2), dtype=torch.float32)

        # Vectorised replacement for a per-row Python loop with dict lookups:
        # every coordinate is an exact member of the sorted unique arrays, so
        # searchsorted returns its grid index directly.
        row_idx = torch.searchsorted(
            unique_lats, tensor[:, lat_col].to(unique_lats.dtype).contiguous()
        ).to(data_grid.device)
        col_idx = torch.searchsorted(
            unique_lons, tensor[:, lon_col].to(unique_lons.dtype).contiguous()
        ).to(data_grid.device)

        # NOTE: if a slice contains the same (lat, lon) more than once, which
        # duplicate ends up in the cell is not guaranteed by indexed assignment.
        data_grid[row_idx, col_idx] = tensor[:, val_col].to(
            dtype=torch.float32, device=data_grid.device
        )

        data_grid_tapered = data_grid * taper_grid
        fft_results.append(torch.fft.fft2(data_grid_tapered))

    # Stack the per-time transforms along a trailing time axis.
    J_vector_tensor = torch.stack(fft_results, dim=2)

    # Normalise by the taper energy and the (2*pi) factor of the 2D transform.
    H = torch.sum(taper_grid**2)
    norm_factor = torch.sqrt(1 / H) / (2 * cmath.pi)

    return J_vector_tensor * norm_factor, n1, n2, p


def calculate_sample_periodogram_vectorized(J_vector_tensor):
    """Form the outer product ``J J^H`` over the last axis of ``J_vector_tensor``.

    Args:
        J_vector_tensor: Tensor of shape ``(..., p)``.

    Returns:
        Tensor of shape ``(..., p, p)`` whose ``[..., q, r]`` entry is
        ``J[..., q] * conj(J[..., r])``.
    """
    # (..., p, 1) @ (..., 1, p) -> (..., p, p), batched over leading axes.
    as_column = J_vector_tensor.unsqueeze(-1)
    as_conj_row = J_vector_tensor.unsqueeze(-2).conj()
    return torch.matmul(as_column, as_conj_row)

# =========================================================================
# 3. Main Execution Logic for Plotting
# =========================================================================
# Diagnostic script: compares the sample periodogram of one day of data with
# the expected periodogram under fixed parameters and scatter-plots the ratio
# of their diagonals against the norm of the spatial frequency.
if __name__ == '__main__':
    # --- IMPORTANT ---
    # Before running, ensure your `processed_df_diff1` list of tensors is loaded.
    # NOTE(review): `processed_df_diff1` is not defined by any cell visible above;
    # on a fresh kernel this block raises NameError unless it is created elsewhere.
    
    # --- Configuration ---
    DAY_TO_RUN = 1  # 1-based day; converted to a 0-based list index below
    # The same callable is used as the data taper and as the covariance window.
    TAMPERING_FUNC = cgn_2dbartlett
    # Must be 1.5: cov_x raises ValueError for any other value.
    NU_SMOOTHNESS = 1.5
    
    print(f"--- Day {DAY_TO_RUN} Diagnostic Plot for Matérn (nu=1.5) Model ---")
    
    # --- Define the parameters you want to test ---
    # Order follows the unpacking in cov_x:
    # (sigmasq, range_lat, range_lon, advec_lat, advec_lon, beta, nugget).
    best_params = [22.54, 3.00, 4.00, 0.004, -0.08, 0.01, 2.36] # Example parameters
    print(f"Using pre-defined parameters: {best_params}")
    
    # --- Data Preparation ---
    # Split the selected day's tensor into one tensor per distinct value of
    # column 3 (presumably the time stamp -- TODO confirm).
    cur_df = processed_df_diff1[DAY_TO_RUN - 1]
    unique_times = torch.unique(cur_df[:, 3])
    tensor_list = [cur_df[cur_df[:, 3] == t_val] for t_val in unique_times]
    
    # --- Calculation ---
    J_vec, n1, n2, p = generate_Jvector_final(tensor_list, TAMPERING_FUNC, lat_col=0, lon_col=1, val_col=2)
    print(f"Data grid: {n1}x{n2} spatial points, {p} time points.")
    
    I_sample = calculate_sample_periodogram_vectorized(J_vec)
    # The smoothness is passed through to the covariance model.
    I_expected = expected_periodogram_fft_multivariate(best_params, n1, n2, p, TAMPERING_FUNC, nu=NU_SMOOTHNESS)

    # Move the zero frequency to the centre of the two spatial axes so the
    # arrays line up with the shifted frequency grids built below.
    I_sample_shifted = torch.fft.fftshift(I_sample, dim=(-4, -3))
    I_expected_shifted = torch.fft.fftshift(I_expected, dim=(-4, -3))
    
    # Small offset added to the denominator to avoid division by exactly zero.
    epsilon = 1e-15
    ratio_complex = I_sample_shifted / (I_expected_shifted + epsilon)
    
    # Keep only the p diagonal (same-time) entries at each frequency and
    # average their magnitudes over time.
    diagonal_ratio = torch.diagonal(ratio_complex, dim1=-2, dim2=-1) 
    ratio_magnitude_avg = torch.mean(torch.abs(diagonal_ratio), dim=-1)
    
    # --- Plotting Data Preparation ---
    # Frequencies are in cycles per sample (np.fft.fftfreq default spacing).
    freq_lat = np.fft.fftfreq(n1)
    freq_lon = np.fft.fftfreq(n2)
    freq_lon_shifted, freq_lat_shifted = fftshift(freq_lon), fftshift(freq_lat)
    # Default 'xy' meshgrid indexing yields (n1, n2) arrays, matching the ratio grid.
    freq_grid_lon, freq_grid_lat = np.meshgrid(freq_lon_shifted, freq_lat_shifted)
    frequency_norm = np.sqrt(freq_grid_lat**2 + freq_grid_lon**2)
    
    norm_flat = frequency_norm.flatten()
    ratio_flat = ratio_magnitude_avg.detach().numpy().flatten()

    # --- Plotting ---
    plt.figure(figsize=(12, 7))
    plt.scatter(norm_flat, ratio_flat, s=10, alpha=0.6)
    plt.axhline(1.0, color='r', linestyle='--', linewidth=2, label='Ideal Ratio (1.0)')
    plt.yscale('log')
    # Points outside [1e-2, 1e2] are not visible in the plot.
    plt.ylim(1e-2, 1e2)
    plt.title(f'Diagnostic Plot for Matérn Model (nu={NU_SMOOTHNESS}, Day {DAY_TO_RUN})')
    plt.xlabel(r'Frequency Norm $||\boldsymbol{\lambda}||$')
    plt.ylabel(r'Average Diagonal Ratio $|\mathbf{I} / E[\mathbf{I}]|$ (Log Scale)')
    plt.legend()
    plt.grid(True, which="both", ls="--")
    
    print(f"\nRatio Statistics: Max={np.max(ratio_flat):.2e}, Min={np.min(ratio_flat):.2e}, Mean={np.mean(ratio_flat):.2f}")
    plt.show()