In [7]:
# Standard libraries
import sys
# Add your custom path
gems_tco_path = "/Users/joonwonlee/Documents/GEMS_TCO-1/src"
sys.path.append(gems_tco_path)
import logging
import argparse # Argument parsing

# Data manipulation and analysis
import pandas as pd
import numpy as np
import pickle
import torch
import torch.optim as optim
import copy                    # clone tensor
import time

# Custom imports
import GEMS_TCO
from GEMS_TCO import kernels

from GEMS_TCO import kernels 
from GEMS_TCO import orderings as _orderings 
from GEMS_TCO import load_data
from GEMS_TCO import alg_optimization, alg_opt_Encoder
from GEMS_TCO import configuration as config

from typing import Optional, List, Tuple
from pathlib import Path
import typer
import json
from json import JSONEncoder

from GEMS_TCO import configuration as config
from GEMS_TCO import data_preprocess as dmbh

import os
from sklearn.neighbors import BallTree

from GEMS_TCO.data_loader import load_data2

In [8]:
import pickle
import os
# Assume your 'config' object is available
# import config

# --- 1. Configuration ---
# Specify the year and month you want to load
YEAR_TO_LOAD = 2024
MONTH_TO_LOAD = 7

# Use the same base path as your saving script
BASE_PATH = config.mac_data_load_path

# --- 2. Construct the File Path ---
# This must exactly match the naming convention from your saving script
month_str = f"{MONTH_TO_LOAD:02d}"
pickle_path = os.path.join(BASE_PATH, f'pickle_{YEAR_TO_LOAD}')
filename = f"coarse_cen_map_without_decrement_latitude{str(YEAR_TO_LOAD)[2:]}_{month_str}.pkl"
filepath_to_load = os.path.join(pickle_path, filename)

print(f"Attempting to load data from: {filepath_to_load}")

# --- 3. Load the Data ---
try:
    with open(filepath_to_load, 'rb') as pickle_file:
        # Use pickle.load() to read the data from the file
        loaded_coarse_map = pickle.load(pickle_file)
    
    print("\nData loaded successfully! ✅")
    
    # --- 4. Verify the Loaded Data ---
    # The loaded data is a dictionary. Let's inspect it.
    print(f"Type of loaded data: {type(loaded_coarse_map)}")
    if isinstance(loaded_coarse_map, dict):
        print(f"Number of entries (hours) in the map: {len(loaded_coarse_map)}")
        # Print the first 5 keys to see what they look like
        first_five_keys = list(loaded_coarse_map.keys())[:5]
        print(f"Example keys: {first_five_keys}")
        
        # You can now access the data for a specific hour, for example:
        # first_hour_data = loaded_coarse_map[first_five_keys[0]]
        # print(f"\nData for first hour is a tensor of shape: {first_hour_data.shape}")

except FileNotFoundError:
    print(f"\nError: File not found. Please check if the file exists at the specified path.")
except Exception as e:
    print(f"\nAn error occurred: {e}")


print(loaded_coarse_map['y24m07day01_hm00:53']['Longitude'].nunique())
print(loaded_coarse_map['y24m07day01_hm00:53']['Latitude'].nunique())

import GEMS_TCO
load_data_instance = GEMS_TCO.load_data('')

df_day_aggregated_list = []
df_day_map_list = []
for i in range(31):
    cur_map, cur_df =load_data_instance.load_working_data_byday_wo_mm(loaded_coarse_map,[i*8, (i+1)*8])
    df_day_aggregated_list.append( cur_df )
    df_day_map_list.append( cur_map )
mac_data_path = config.mac_data_load_path

v05_base_path = Path("/Users/joonwonlee/Documents/GEMS_TCO-1/outputs/day/estimates/df_cv_smooth_05/")
full_day_v05_r2s10_1127 = pd.read_csv(v05_base_path / "full_day_v05_r2s10_1127.csv")
vecchia_v05_r2s10_1127 = pd.read_csv( v05_base_path / "vecchia_v05_r2s10_1127.csv")
vecchia_v05_r2s10_4508 = pd.read_csv( v05_base_path / "vecchia_v05_r2s10_4508.csv")
vecchia_v05_r2s10_18033 = pd.read_csv( v05_base_path / "vecchia_v05_r2s10_18033.csv")

Attempting to load data from: /Users/joonwonlee/Documents/GEMS_DATA/pickle_2024/coarse_cen_map_without_decrement_latitude24_07.pkl

Data loaded successfully! ✅
Type of loaded data: <class 'dict'>
Number of entries (hours) in the map: 248
Example keys: ['y24m07day01_hm00:53', 'y24m07day01_hm01:53', 'y24m07day01_hm02:53', 'y24m07day01_hm03:53', 'y24m07day01_hm04:49']
270
273


## Once differencing in both space and then another differencing in  time

### Models the temporal change of the spatial curvature (or gradient).

In [9]:
import torch
import numpy as np
import torch.nn.functional as F
import os
import pickle

# --- Helper Functions (REVISED FOR CONVOLUTION) ---

def subset_tensor(df_tensor: torch.Tensor) -> torch.Tensor:
    """
    Subsets a tensor to a specific lat/lon range.
    Columns are assumed to be [lat, lon, ozone, time].
    """
    lat_mask = (df_tensor[:, 0] >= 0) & (df_tensor[:, 0] <= 5)
    lon_mask = (df_tensor[:, 1] >= 123) & (df_tensor[:, 1] <= 133)
    
    df_sub = df_tensor[lat_mask & lon_mask].clone()
    return df_sub

def apply_spatial_diff_convolution(df_tensor: torch.Tensor) -> torch.Tensor:
    """
    Applies the first-order spatial difference Z(s) = [X(s+d_lat) - X(s)] + [X(s+d_lon) - X(s)]
    using a 2D convolution, assuming the input tensor is a complete grid (non-sparse).
    """
    if df_tensor.size(0) == 0:
        return torch.empty(0, 4, dtype=df_tensor.dtype)

    # 1. Get grid dimensions and enforce non-sparse grid constraint
    unique_lats = torch.unique(df_tensor[:, 0])
    unique_lons = torch.unique(df_tensor[:, 1])
    lat_count, lon_count = unique_lats.size(0), unique_lons.size(0)

    if df_tensor.size(0) != lat_count * lon_count:
        # Since you confirmed data is not sparse, this should not trigger.
        raise ValueError("Tensor size does not match grid dimensions. Must be a complete grid for convolution.")
    if lat_count < 2 or lon_count < 2:
        return torch.empty(0, 4, dtype=df_tensor.dtype)

    # 2. Map coordinates to indices and Reshape data (Ozone values)
    lat_map = {lat.item(): i for i, lat in enumerate(unique_lats)}
    lon_map = {lon.item(): i for i, lon in enumerate(unique_lons)}

    ozone_grid = torch.zeros((lat_count, lon_count), dtype=df_tensor.dtype)
    for row in df_tensor:
        i = lat_map[row[0].item()]
        j = lon_map[row[1].item()]
        ozone_grid[i, j] = row[2]
        
    ozone_data = ozone_grid.reshape(1, 1, lat_count, lon_count)
    
    # Kernel for Z(i,j) = X(i+1,j) + X(i,j+1) - 2*X(i,j)
    # This assumes the first dimension is latitude (i) and the second is longitude (j)
    diff_kernel = torch.tensor([[[[-2., 1.],
                                  [ 1., 0.]]]], dtype=df_tensor.dtype)

    # 3. Apply convolution (cross-correlation)
    filtered_grid = F.conv2d(ozone_data, diff_kernel, padding='valid').squeeze()

    # 4. Determine coordinates for the new, smaller grid
    # The new grid corresponds to the anchor points (top-left of the kernel)
    new_lats = unique_lats[:-1]
    new_lons = unique_lons[:-1]

    # 5. Reconstruct the output tensor
    new_lat_grid, new_lon_grid = torch.meshgrid(new_lats, new_lons, indexing='ij')
    filtered_values = filtered_grid.flatten()
    time_value = df_tensor[0, 3].repeat(filtered_values.size(0))

    new_tensor = torch.stack([
        new_lat_grid.flatten(),
        new_lon_grid.flatten(),
        filtered_values,
        time_value
    ], dim=1)
    
    return new_tensor


# ----------------------------------------------------------------------
# --- Data Loading (Kept structure, placeholder variables must be defined) ---
# ----------------------------------------------------------------------
# ⚠️ NOTE: You must define these variables in your environment
# mac_data_path = "..."
# year = 2022
# month_str = "01"
# class GEMS_TCO: # Placeholder
#     def load_data(self, path): return self
#     def load_working_data_byday_wo_mm(self, data, indices):
#         return {'key': torch.randn(100, 4)}, torch.randn(100, 4)

# (Assuming data loading variables are defined...)
# NOTE: Removed the try/except block to keep the data loading structure clean as requested,
# but ensure 'mac_data_path', 'year', 'month_str', and 'GEMS_TCO' are defined externally.

pickle_path = os.path.join(mac_data_path, f'pickle_{YEAR_TO_LOAD}')
output_filename = f"coarse_cen_map_without_decrement_latitude{str(YEAR_TO_LOAD)[2:]}_{month_str}.pkl"
output_filepath = os.path.join(pickle_path, output_filename)
print(f"Loading data from: {output_filepath}")

with open(output_filepath, 'rb') as pickle_file:
    cbmap_ori = pickle.load(pickle_file)

load_data_instance = GEMS_TCO.load_data('')
df_day_map_list = []
for i in range(31): # Adjust if necessary
    cur_map, _ = load_data_instance.load_working_data_byday_wo_mm(cbmap_ori, [i*8, (i+1)*8])
    df_day_map_list.append(cur_map)
print(f"Loaded {len(df_day_map_list)} days of raw data.")

# ----------------------------------------------------------------------
# --- Main Processing Loop (STAGE 1 uses convolution function) ---
# ----------------------------------------------------------------------

# ✅ STAGE 1: Apply the spatial filter to each day independently.
spatially_filtered_days = []

print("Starting STAGE 1: Spatial Differencing (Convolution)...")
for day_idx, day_map in enumerate(df_day_map_list):
    tensors_to_aggregate = []
    
    for key, tensor in day_map.items():
        subsetted = subset_tensor(tensor)
        
        if subsetted.size(0) > 0:
            try:
                # --- ✅ CALLING THE NEW CONVOLUTION FUNCTION ---
                diff_applied = apply_spatial_diff_convolution(subsetted)
                
                if diff_applied.size(0) > 0:
                    tensors_to_aggregate.append(diff_applied)
            except ValueError as e:
                # This catches incomplete grid chunks or chunks with < 2 lats/lons
                print(f"Skipping chunk on day {day_idx+1}, key {key}: {e}")

    if tensors_to_aggregate:
        aggregated_day_tensor = torch.cat(tensors_to_aggregate, dim=0)
        spatially_filtered_days.append(aggregated_day_tensor)
print(f"STAGE 1 Complete. Created {len(spatially_filtered_days)} spatially filtered day-tensors.")

# ----------------------------------------------------------------------
# ✅ STAGE 2: Apply the temporal first difference (value_t - value_t-1).
# ----------------------------------------------------------------------
spacetime_diff_tensors = []

print("Starting STAGE 2: Temporal Differencing...")
if len(spatially_filtered_days) > 1:
    for i in range(1, len(spatially_filtered_days)):
        prev_day_tensor = spatially_filtered_days[i-1]
        current_day_tensor = spatially_filtered_days[i]
        
        # Round keys to avoid floating point mismatches
        prev_day_lookup = {
            (round(row[0].item(), 5), round(row[1].item(), 5)): row[2].item() 
            for row in prev_day_tensor
        }
        
        temporally_differenced_rows = []
        for row in current_day_tensor:
            lat = round(row[0].item(), 5)
            lon = round(row[1].item(), 5)
            
            # This lookup is more robust now because Stage 1 ensures a consistent grid
            if (lat, lon) in prev_day_lookup:
                current_ozone = row[2].item()
                prev_ozone = prev_day_lookup[(lat, lon)]
                
                time_diff_ozone = current_ozone - prev_ozone
                
                # Keep dtype consistent with the source row
                new_row = torch.tensor([lat, lon, time_diff_ozone, row[3]], dtype=row.dtype) 
                temporally_differenced_rows.append(new_row)
        
        if temporally_differenced_rows:
            day_diff_tensor = torch.stack(temporally_differenced_rows, dim=0)
            spacetime_diff_tensors.append(day_diff_tensor)
print(f"STAGE 2 Complete. Created {len(spacetime_diff_tensors)} final tensors.")

# --- Verification ---
print("\n--- Results ---")
print("Number of spatially filtered day tensors:", len(spatially_filtered_days))
print("Number of final spatio-temporally differenced tensors:", len(spacetime_diff_tensors))

if spacetime_diff_tensors:
    # Save the processed data
    processed_output_path = "spacetime_differenced_data.pkl"
    with open(processed_output_path, 'wb') as f:
        pickle.dump(spacetime_diff_tensors, f)
    print(f"Processed data saved to {processed_output_path}")

    print("\nShape of the first final tensor:", spacetime_diff_tensors[0].shape)
    print("First final tensor head:")
    print(spacetime_diff_tensors[0][:5])
else:
    print("\nNo final differenced tensors were created. Check data or filter logic.")

Loading data from: /Users/joonwonlee/Documents/GEMS_DATA/pickle_2024/coarse_cen_map_without_decrement_latitude24_07.pkl
Loaded 31 days of raw data.
Starting STAGE 1: Spatial Differencing (Convolution)...
STAGE 1 Complete. Created 31 spatially filtered day-tensors.
Starting STAGE 2: Temporal Differencing...
STAGE 2 Complete. Created 30 final tensors.

--- Results ---
Number of spatially filtered day tensors: 31
Number of final spatio-temporally differenced tensors: 30
Processed data saved to spacetime_differenced_data.pkl

Shape of the first final tensor: torch.Size([142832, 4])
First final tensor head:
tensor([[ 4.0000e-03,  1.2303e+02, -1.6478e+01,  4.5000e+01],
        [ 4.0000e-03,  1.2309e+02,  2.6450e+00,  4.5000e+01],
        [ 4.0000e-03,  1.2316e+02, -8.7633e+00,  4.5000e+01],
        [ 4.0000e-03,  1.2322e+02,  1.6463e+01,  4.5000e+01],
        [ 4.0000e-03,  1.2328e+02,  1.0886e+01,  4.5000e+01]])


In [12]:
import torch
import numpy as np
import matplotlib.pyplot as plt # Keep if plotting might be added later
import cmath
import pickle
import time # For timing
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.nn import Parameter
import pandas as pd # Make sure pandas is imported
import os # Make sure os is imported

# =========================================================================
# 1. Modeling Functions (Adapted for Spatio-Temporal Differencing)
# =========================================================================

# --- Bartlett Kernel (Used for c_gn when g_s=1) ---
def cgn_2dbartlett_kernel(u1, u2, n1, n2):
    """
    Computes the 2D Bartlett kernel: Product(1 - |ui|/ni). (Unchanged)
    """
    device = u1.device if isinstance(u1, torch.Tensor) else (u2.device if isinstance(u2, torch.Tensor) else torch.device('cpu'))
    u1_tensor = u1.to(device) if isinstance(u1, torch.Tensor) else torch.tensor(u1, device=device, dtype=torch.float32)
    u2_tensor = u2.to(device) if isinstance(u2, torch.Tensor) else torch.tensor(u2, device=device, dtype=torch.float32)
    n1_eff = float(n1) if n1 > 0 else 1.0
    n2_eff = float(n2) if n2 > 0 else 1.0
    kernel = (1.0 - torch.abs(u1_tensor) / n1_eff) * (1.0 - torch.abs(u2_tensor) / n2_eff)
    return torch.clamp(kernel, min=0.0)

# --- Covariance of the Original Field X (EXPONENTIAL Kernel) ---
def cov_x_exponential(u1, u2, t, params):
    """
    Computes the autocovariance of the ORIGINAL process X. (Unchanged)
    Expects log-scale params [0,1,2,6].
    """
    device = params.device 
    u1_dev = u1.to(device) if isinstance(u1, torch.Tensor) else torch.tensor(u1, device=device, dtype=torch.float32)
    u2_dev = u2.to(device) if isinstance(u2, torch.Tensor) else torch.tensor(u2, device=device, dtype=torch.float32)
    t_dev = t.to(device) if isinstance(t, torch.Tensor) else torch.tensor(t, device=device, dtype=torch.float32)

    log_params_indices = [0, 1, 2, 6]
    if torch.isnan(params[log_params_indices]).any() or torch.isinf(params[log_params_indices]).any():
         out_shape = torch.broadcast_shapes(u1_dev.shape, u2_dev.shape, t_dev.shape)
         return torch.full(out_shape, float('nan'), device=device, dtype=torch.float32)

    sigmasq, r_lat, r_lon, nugget = torch.exp(params[log_params_indices])
    a_lat, a_lon, beta = params[3], params[4], params[5]

    r_lat = torch.clamp(r_lat, min=1e-6)
    r_lon = torch.clamp(r_lon, min=1e-6)

    x1 = u1_dev / r_lat - a_lat * t_dev
    x2 = u2_dev / r_lon - a_lon * t_dev
    x3 = beta * t_dev
    distance_sq = x1**2 + x2**2 + x3**2
    epsilon = 1e-12
    distance_sq_clamped = torch.clamp(distance_sq, min=0.0)
    D = torch.sqrt(distance_sq_clamped + epsilon) 
    cov_smooth = sigmasq * torch.exp(-D) 

    is_zero_lag = (torch.abs(u1_dev) < 1e-9) & (torch.abs(u2_dev) < 1e-9) & (torch.abs(t_dev) < 1e-9)
    final_cov = torch.where(is_zero_lag, cov_smooth + nugget, cov_smooth)

    if torch.isnan(final_cov).any():
        print("Warning: NaN detected in cov_x_exponential output.")
    return final_cov


# --- Covariance of the Spatially Differenced Field Z ---
def cov_spatial_difference(u1, u2, t, params, delta1, delta2):
    """
    Calculates covariance Cov(Z(s), Z(s+u)) for the SPATIAL-ONLY filter:
    Z(s) = X(s+d1) + X(s+d2) - 2X(s). (Unchanged)
    """
    weights = {(0, 0): -2.0, (1, 0): 1.0, (0, 1): 1.0}
    device = params.device
    out_shape = torch.broadcast_shapes(u1.shape if isinstance(u1, torch.Tensor) else (),
                                     u2.shape if isinstance(u2, torch.Tensor) else (),
                                     t.shape if isinstance(t, torch.Tensor) else ())
    cov = torch.zeros(out_shape, device=device, dtype=torch.float32)

    u1_dev = u1.to(device) if isinstance(u1, torch.Tensor) else torch.tensor(u1, device=device, dtype=torch.float32)
    u2_dev = u2.to(device) if isinstance(u2, torch.Tensor) else torch.tensor(u2, device=device, dtype=torch.float32)
    t_dev = t.to(device) if isinstance(t, torch.Tensor) else torch.tensor(t, device=device, dtype=torch.float32)

    for (a_idx, b_idx), w_ab in weights.items():
        offset_a1 = a_idx * delta1
        offset_a2 = b_idx * delta2
        for (c_idx, d_idx), w_cd in weights.items():
            offset_c1 = c_idx * delta1
            offset_c2 = d_idx * delta2
            lag_u1 = u1_dev + (offset_a1 - offset_c1)
            lag_u2 = u2_dev + (offset_a2 - offset_c2)
            term_cov = cov_x_exponential(lag_u1, lag_u2, t_dev, params)
            if torch.isnan(term_cov).any():
                 return torch.full_like(cov, float('nan'))
            cov += w_ab * w_cd * term_cov

    if torch.isnan(cov).any():
        print("Warning: NaN detected in final cov_spatial_difference output.")
    return cov

# --- (NEW) Covariance of Spatio-Temporal Differenced Field Y ---
def cov_spacetime_difference(u1, u2, t, params, delta1, delta2):
    """
    ✅ Calculates covariance for the new Spatio-Temporal filter:
    Y(s,t) = Z(s,t) - Z(s,t-1), where Z is the spatially-differenced field.
    This is C_Y(u, tau) = 2*C_Z(u, tau) - C_Z(u, tau-1) - C_Z(u, tau+1)
    """
    device = params.device
    u1_dev = u1.to(device) if isinstance(u1, torch.Tensor) else torch.tensor(u1, device=device, dtype=torch.float32)
    u2_dev = u2.to(device) if isinstance(u2, torch.Tensor) else torch.tensor(u2, device=device, dtype=torch.float32)
    t_dev = t.to(device) if isinstance(t, torch.Tensor) else torch.tensor(t, device=device, dtype=torch.float32)

    # C_Z(u, tau)
    term_center = cov_spatial_difference(u1_dev, u2_dev, t_dev, params, delta1, delta2)
    # C_Z(u, tau - 1)
    term_minus_1 = cov_spatial_difference(u1_dev, u2_dev, t_dev - 1.0, params, delta1, delta2)
    # C_Z(u, tau + 1)
    term_plus_1 = cov_spatial_difference(u1_dev, u2_dev, t_dev + 1.0, params, delta1, delta2)

    if torch.isnan(term_center).any() or torch.isnan(term_minus_1).any() or torch.isnan(term_plus_1).any():
        print("Warning: NaN detected in one of the terms of cov_spacetime_difference")
        out_shape = torch.broadcast_shapes(u1_dev.shape, u2_dev.shape, t_dev.shape)
        return torch.full(out_shape, float('nan'), device=device, dtype=torch.float32)

    return 2.0 * term_center - term_minus_1 - term_plus_1


# --- (MODIFIED) cn_bar for NO TAPERING ---
def cn_bar_no_taper(u1, u2, t, params, n1, n2, delta1, delta2):
    """
    Computes c_Y(u) * c_gn(u) where c_Y is cov_spacetime_difference
    and c_gn(u) is the Bartlett kernel.
    """
    device = params.device
    u1_dev = u1.to(device) if isinstance(u1, torch.Tensor) else torch.tensor(u1, device=device, dtype=torch.float32)
    u2_dev = u2.to(device) if isinstance(u2, torch.Tensor) else torch.tensor(u2, device=device, dtype=torch.float32)
    t_dev = t.to(device) if isinstance(t, torch.Tensor) else torch.tensor(t, device=device, dtype=torch.float32)

    # ✅ --- Call the new spatio-temporal covariance function ---
    cov_Y_value = cov_spacetime_difference(u1_dev, u2_dev, t_dev, params, delta1, delta2)

    c_gn_value = cgn_2dbartlett_kernel(u1_dev, u2_dev, n1, n2)

    if torch.isnan(cov_Y_value).any() or torch.isnan(c_gn_value).any():
        print("Warning: NaN detected before multiplication in cn_bar_no_taper.")
        out_shape = torch.broadcast_shapes(cov_Y_value.shape, c_gn_value.shape)
        return torch.full(out_shape, float('nan'), device=device, dtype=torch.float32)

    result = cov_Y_value * c_gn_value
    if torch.isnan(result).any():
        print("Warning: NaN detected after multiplication in cn_bar_no_taper.")
    return result


# --- Expected Periodogram (uses cn_bar_no_taper) ---
def expected_periodogram_fft_no_taper(params, n1, n2, p, delta1, delta2):
    """
    Calculates the expected periodogram. (Unchanged)
    This function is correct because it calls the modified cn_bar_no_taper.
    """
    device = params.device if isinstance(params, torch.Tensor) else params[0].device
    params_tensor = params.to(device)

    product_tensor = torch.zeros((n1, n2, p, p), dtype=torch.complex64, device=device)
    t_lags = torch.arange(p, dtype=torch.float32, device=device)
    u1_mesh_grid, u2_mesh_grid = torch.meshgrid(
        torch.arange(n1, dtype=torch.float32, device=device),
        torch.arange(n2, dtype=torch.float32, device=device),
        indexing='ij'
    )

    for q in range(p):
        for r in range(p):
            t_diff = t_lags[q] - t_lags[r]
            # This call now correctly leads to cov_spacetime_difference
            cov_times_bartlett = cn_bar_no_taper(
                u1_mesh_grid, u2_mesh_grid, t_diff,
                params_tensor, n1, n2, delta1, delta2
            )
            if torch.isnan(cov_times_bartlett).any():
                 product_tensor[:, :, q, r] = float('nan')
            else:
                 product_tensor[:, :, q, r] = cov_times_bartlett.to(torch.complex64)

    if torch.isnan(product_tensor).any():
        print("Warning: NaN detected in product_tensor before FFT.")
        nan_shape = (n1, n2, p, p)
        return torch.full(nan_shape, float('nan'), dtype=torch.complex64, device=device)

    fft_result = torch.fft.fft2(product_tensor, dim=(0, 1))
    normalization_factor = 1.0 / (4.0 * cmath.pi**2)
    result = fft_result * normalization_factor

    if torch.isnan(result).any():
        print("Warning: NaN detected in expected_periodogram_fft_no_taper output after FFT.")
    return result


# =========================================================================
# 2. Data Processing (Unchanged)
# =========================================================================
def generate_Jvector_no_taper(tensor_list, lat_col, lon_col, val_col, device):
    """
    Generates J-vector for g_s=1 (NO taper). (Unchanged)
    """
    p = len(tensor_list)
    if p == 0: return torch.empty(0, 0, 0, device=device), 0, 0, 0

    valid_tensors = [t for t in tensor_list if t.numel() > 0 and t.shape[1] > max(lat_col, lon_col, val_col)]
    if not valid_tensors:
         return torch.empty(0, 0, 0, device=device), 0, 0, 0

    try:
        all_lats_cpu = torch.cat([t[:, lat_col] for t in valid_tensors])
        all_lons_cpu = torch.cat([t[:, lon_col] for t in valid_tensors])
    except IndexError:
        return torch.empty(0, 0, 0, device=device), 0, 0, 0

    all_lats_cpu = all_lats_cpu[~torch.isnan(all_lats_cpu) & ~torch.isinf(all_lats_cpu)]
    all_lons_cpu = all_lons_cpu[~torch.isnan(all_lons_cpu) & ~torch.isinf(all_lons_cpu)]
    if all_lats_cpu.numel() == 0 or all_lons_cpu.numel() == 0:
        return torch.empty(0, 0, 0, device=device), 0, 0, 0

    unique_lats_cpu, unique_lons_cpu = torch.unique(all_lats_cpu), torch.unique(all_lons_cpu)
    n1, n2 = len(unique_lats_cpu), len(unique_lons_cpu)
    if n1 == 0 or n2 == 0:
        return torch.empty(0, 0, 0, device=device), 0, 0, 0

    lat_map = {lat.item(): i for i, lat in enumerate(unique_lats_cpu)}
    lon_map = {lon.item(): i for i, lon in enumerate(unique_lons_cpu)}

    fft_results = []
    for tensor in tensor_list:
        data_grid = torch.zeros((n1, n2), dtype=torch.float32, device=device)
        # Move tensor to device before iterating
        tensor_dev = tensor.to(device)
        for row in tensor_dev:
            lat_item, lon_item = row[lat_col].item(), row[lon_col].item()
            if not (np.isnan(lat_item) or np.isnan(lon_item)):
                i = lat_map.get(lat_item)
                j = lon_map.get(lon_item)
                if i is not None and j is not None:
                    val = row[val_col]
                    val_num = val.item() if isinstance(val, torch.Tensor) else val
                    if not np.isnan(val_num) and not np.isinf(val_num):
                        data_grid[i, j] = val_num

        data_grid = torch.nan_to_num(data_grid, nan=0.0, posinf=0.0, neginf=0.0)
        fft_results.append(torch.fft.fft2(data_grid))

    if not fft_results:
         return torch.empty(0, 0, 0, device=device), n1, n2, 0

    J_vector_tensor = torch.stack(fft_results, dim=2).to(device)

    H = float(n1 * n2)
    if H < 1e-9:
        norm_factor = torch.tensor(0.0, device=device)
    else:
        norm_factor = (torch.sqrt(torch.tensor(1.0 / H, device=device)) / (2.0 * cmath.pi))

    result = J_vector_tensor * norm_factor
    return result, n1, n2, p


def calculate_sample_periodogram_vectorized(J_vector_tensor):
    """Calculates sample periodogram I_n = J J^H. (Unchanged)"""
    if torch.isnan(J_vector_tensor).any() or torch.isinf(J_vector_tensor).any():
        n1, n2, p = J_vector_tensor.shape
        return torch.full((n1, n2, p, p), float('nan'), dtype=torch.complex64, device=J_vector_tensor.device)

    J_col = J_vector_tensor.unsqueeze(-1)
    J_row_conj = J_vector_tensor.unsqueeze(-2).conj()
    result = J_col @ J_row_conj
    return result


# =========================================================================
# 4. Likelihood Calculation (Unchanged)
# =========================================================================

def whittle_likelihood_loss_no_taper(params, I_sample, n1, n2, p, delta1, delta2):
    """
    Whittle Likelihood Loss. (Unchanged)
    """
    device = I_sample.device
    params_tensor = params.to(device)

    if torch.isnan(params_tensor).any() or torch.isinf(params_tensor).any():
        return torch.tensor(float('nan'), device=device)

    # This call now correctly leads to cov_spacetime_difference
    I_expected = expected_periodogram_fft_no_taper(
        params_tensor, n1, n2, p, delta1, delta2
    )

    if torch.isnan(I_expected).any() or torch.isinf(I_expected).any():
        return torch.tensor(float('nan'), device=device)

    eye_matrix = torch.eye(p, dtype=torch.complex64, device=device)
    diag_vals = torch.abs(I_expected.diagonal(dim1=-2, dim2=-1))
    mean_diag_abs = diag_vals.mean().item() if diag_vals.numel() > 0 and not torch.isnan(diag_vals).all() else 1.0
    diag_load = max(mean_diag_abs * 1e-8, 1e-9) 
    
    I_expected_stable = I_expected + eye_matrix * diag_load

    sign, logabsdet = torch.linalg.slogdet(I_expected_stable)
    if torch.any(sign.real <= 1e-9):
        log_det_term = torch.where(sign.real > 1e-9, logabsdet, torch.tensor(1e10, device=device))
    else:
        log_det_term = logabsdet

    if torch.isnan(I_sample).any() or torch.isinf(I_sample).any():
        return torch.tensor(float('nan'), device=device)

    try:
        solved_term = torch.linalg.solve(I_expected_stable, I_sample)
        trace_term = torch.einsum('...ii->...', solved_term).real
    except torch.linalg.LinAlgError as e:
        return torch.tensor(float('inf'), device=device)

    if torch.isnan(trace_term).any() or torch.isinf(trace_term).any():
        return torch.tensor(float('nan'), device=device)

    likelihood_terms = log_det_term + trace_term

    if torch.isnan(likelihood_terms).any():
        return torch.tensor(float('nan'), device=device)

    total_sum = torch.sum(likelihood_terms)
    dc_term = likelihood_terms[0, 0] if n1 > 0 and n2 > 0 else torch.tensor(0.0, device=device)
    if torch.isnan(dc_term).any() or torch.isinf(dc_term).any():
        dc_term = torch.tensor(0.0, device=device)

    loss = total_sum - dc_term if (n1 > 1 or n2 > 1) else total_sum

    if torch.isnan(loss) or torch.isinf(loss):
         return torch.tensor(float('inf'), device=device) 

    return loss


# =========================================================================
# 5. Training Loop (Unchanged)
# =========================================================================
def run_full(params_list, optimizer, scheduler, I_sample, n1, n2, p, epochs=600, device='cpu'):
    """Corrected training loop. (Unchanged)"""
    best_loss = float('inf')
    params_list = [p.to(device) for p in params_list]
    best_params_state = [p.detach().clone() for p in params_list]
    epochs_completed = 0
    # DELTA_LAT, DELTA_LON are needed for the spatial part of the filter
    DELTA_LAT, DELTA_LON = 0.044, 0.063

    def get_printable_params(p_list):
        valid_tensors = [p for p in p_list if isinstance(p, torch.Tensor)]
        if not valid_tensors: return "Invalid params_list"
        p_cat = torch.cat([p.detach().clone().cpu() for p in valid_tensors])
        log_indices = [0, 1, 2, 6]
        if all(idx < len(p_cat) for idx in log_indices):
            log_vals = p_cat[log_indices]
            if not (torch.isnan(log_vals).any() or torch.isinf(log_vals).any()):
                 p_cat[log_indices] = torch.exp(log_vals)
            else:
                 p_cat[log_indices] = float('nan')
        return p_cat.numpy().round(4)

    I_sample_dev = I_sample.to(device)

    for epoch in range(epochs):
        epochs_completed = epoch + 1
        optimizer.zero_grad()
        params_tensor = torch.cat(params_list)

        loss = whittle_likelihood_loss_no_taper(
            params_tensor, I_sample_dev, n1, n2, p, DELTA_LAT, DELTA_LON
        )

        if torch.isnan(loss) or torch.isinf(loss):
            print(f"Loss became NaN or Inf at epoch {epoch+1}. Stopping.")
            if epoch == 0: best_params_state = None
            epochs_completed = epoch
            break 

        loss.backward()

        nan_grad = False
        for param in params_list:
            if param.grad is not None and (torch.isnan(param.grad).any() or torch.isinf(param.grad).any()):
                nan_grad = True
                break
        if nan_grad:
             optimizer.zero_grad() 
             continue 

        all_params_on_device = params_list
        if all_params_on_device:
            torch.nn.utils.clip_grad_norm_(all_params_on_device, max_norm=1.0)

        optimizer.step()
        scheduler.step() 

        current_loss_item = loss.item()
        if current_loss_item < best_loss:
            params_valid = not any(torch.isnan(p.data).any() or torch.isinf(p.data).any() for p in params_list)
            if params_valid:
                best_loss = current_loss_item
                best_params_state = [p.detach().clone() for p in params_list]

        if epoch % 50 == 0 or epoch == epochs - 1:
            current_lr = optimizer.param_groups[0]['lr'] if optimizer.param_groups else 0.0
            print(f'--- Epoch {epoch+1}/{epochs} (LR: {current_lr:.6f}) ---')
            print(f' Loss: {current_loss_item:.4f}')
            print(f' Parameters (Natural Scale): {get_printable_params(params_list)}')

    if best_params_state is None:
        return None, epochs_completed

    final_params_log_scale = torch.cat([p.cpu() for p in best_params_state])
    final_params_natural_scale = final_params_log_scale.detach().clone()
    log_indices = [0, 1, 2, 6]
    if all(idx < len(final_params_natural_scale) for idx in log_indices):
        log_vals = final_params_natural_scale[log_indices]
        if not (torch.isnan(log_vals).any() or torch.isinf(log_vals).any()):
            final_params_natural_scale[log_indices] = torch.exp(log_vals)
        else:
            final_params_natural_scale[log_indices] = float('nan')

    final_params_rounded = [round(p.item(), 4) if not np.isnan(p.item()) else float('nan') for p in final_params_natural_scale]
    final_loss_rounded = round(best_loss, 3) if best_loss != float('inf') else float('inf')

    print("\n--- Training Complete ---")
    print(f'\nFINAL BEST STATE ACHIEVED (during training):')
    print(f'Best Loss: {final_loss_rounded}')
    print(f'Parameters Corresponding to Best Loss (Natural Scale): {final_params_rounded}')

    return final_params_rounded + [final_loss_rounded], epochs_completed




In [43]:
# =========================================================================
# 6. Main Execution Script (REVISED with Fixed Initial Parameters)
# =========================================================================
if __name__ == '__main__':
    start_time = time.time()

    # --- Configuration ---
    DAY_TO_RUN = 1
    NUM_RUNS = 1 
    EPOCHS = 700 
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {DEVICE}")

    # --- Grid Spacing (Still needed for the spatial part of the filter) ---
    DELTA_LAT, DELTA_LON = 0.044, 0.063 

    # --- Column Indices ---
    LAT_COL, LON_COL = 0, 1
    VAL_COL = 2
    TIME_COL = 3

    # --- Load Spatio-Temporal Differenced Data ---
    try:
        with open("spacetime_differenced_data.pkl", 'rb') as f:
            processed_df = pickle.load(f)
        print(f"Loaded {len(processed_df)} days from spacetime_differenced_data.pkl.")
        
        processed_df = [
            torch.tensor(arr, dtype=torch.float32).cpu() if not isinstance(arr, torch.Tensor)
            else arr.cpu().to(torch.float32)
            for arr in processed_df
        ]
        if not processed_df: raise ValueError("'processed_df' is empty.")
    except FileNotFoundError:
        print("Error: `spacetime_differenced_data.pkl` not found.")
        print("Please run the data preparation script first.")
        exit()
    except Exception as e:
        print(f"Error loading or processing 'processed_df': {e}")
        exit()

    if DAY_TO_RUN > len(processed_df) or DAY_TO_RUN <= 0:
        print(f"Error: DAY_TO_RUN ({DAY_TO_RUN}) out of bounds.")
        exit()

    cur_df = processed_df[DAY_TO_RUN - 1]
    if cur_df.numel() == 0 or cur_df.shape[1] <= max(LAT_COL, LON_COL, VAL_COL, TIME_COL):
        print(f"Error: Data for Day {DAY_TO_RUN} is empty or invalid.")
        exit()

    unique_times = torch.unique(cur_df[:, TIME_COL])
    time_slices_list = [cur_df[cur_df[:, TIME_COL] == t_val] for t_val in unique_times]

    # --- 1. Pre-compute Sample Periodogram (NO Tapering) ---
    print("Pre-computing sample periodogram (NO data taper)...")
    J_vec, n1, n2, p = generate_Jvector_no_taper(
        time_slices_list,
        lat_col=LAT_COL, lon_col=LON_COL, val_col=VAL_COL,
        device=DEVICE
    )

    if J_vec.numel() == 0 or n1 == 0 or n2 == 0 or p == 0:
       print(f"Error: J-vector generation failed for Day {DAY_TO_RUN}.")
       exit()

    I_sample = calculate_sample_periodogram_vectorized(J_vec)

    if torch.isnan(I_sample).any() or torch.isinf(I_sample).any():
        print("Error: NaN or Inf detected in the sample periodogram. Cannot proceed.")
        exit()

    print(f"Data grid: {n1}x{n2} spatial points, {p} time points. Sample Periodogram on {DEVICE}.")
    
    # --- 2. Optimization Loop ---
    all_final_results = []
    all_final_losses = []

    for i in range(NUM_RUNS):
        print(f"\n{'='*30} Initialization Run {i+1}/{NUM_RUNS} {'='*30}")

        # ✅ --- Use Fixed Initial Parameters ---
        # User specified natural scale: [21.303, 1.307, 1.563, 0.022, -0.144, 0.198, 4.769]
        # Convert indices 0, 1, 2, 6 to log-scale for the model
        initial_params_values = [
            np.log(21.303), # log(sigmasq)
            np.log(1.307), # log(r_lat)
            np.log(1.563), # log(r_lon)
            0.022,         # a_lat
            -0.144,        # a_lon
            0.198,         # beta
            np.log(4.769)  # log(nugget)
        ]
        
        print(f"Starting with fixed params (log-scale for [0,1,2,6]): {[round(p, 4) for p in initial_params_values]}")

        params_list = [
            Parameter(torch.tensor([val], dtype=torch.float32))
            for val in initial_params_values
        ]

        lr_slow, lr_fast = 0.005, 0.02
        slow_indices = [0, 1, 2, 6]
        fast_indices = [3, 4, 5]

        valid_slow_indices = [idx for idx in slow_indices if idx < len(params_list)]
        valid_fast_indices = [idx for idx in fast_indices if idx < len(params_list)]

        param_groups = [
            {'params': [params_list[idx] for idx in valid_slow_indices], 'lr': lr_slow, 'name': 'slow_group'},
            {'params': [params_list[idx] for idx in valid_fast_indices], 'lr': lr_fast, 'name': 'fast_group'}
        ]

        optimizer = torch.optim.Adam(param_groups)
        scheduler = CosineAnnealingLR(optimizer, T_max=200, eta_min=1e-6)

        print(f"Starting optimization run {i+1} on device {DEVICE} (NO data taper, Spatio-Temporal Diff)...")
        final_results, epochs_run = run_full(
            params_list=params_list,
            optimizer=optimizer,
            scheduler=scheduler,
            I_sample=I_sample,
            n1=n1, n2=n2, p=p,
            epochs=EPOCHS,
            device=DEVICE
        )

        if final_results:
            all_final_results.append(final_results)
            all_final_losses.append(final_results[-1])
        else:
            all_final_results.append(None)
            all_final_losses.append(float('inf'))

    print(f"\n\n{'='*25} Overall Result from Run {'='*25}")
    valid_losses = [l for l in all_final_losses if l is not None and l != float('inf')]

    if not valid_losses:
        print(f"The run failed or resulted in an invalid loss for Day {DAY_TO_RUN}.")
    else:
        best_loss = valid_losses[0]
        best_run_index = 0
        best_results = all_final_results[best_run_index]
        print(f"Run Loss: {best_results[-1]}")
        print(f"Final Parameters (Natural Scale): {best_results[:-1]}")

    end_time = time.time()
    print(f"\nTotal execution time: {end_time - start_time:.2f} seconds")

Using device: cpu
Loaded 30 days from spacetime_differenced_data.pkl.
Pre-computing sample periodogram (NO data taper)...
Data grid: 113x158 spatial points, 8 time points. Sample Periodogram on cpu.

Starting with fixed params (log-scale for [0,1,2,6]): [3.0588, 0.2677, 0.4466, 0.022, -0.144, 0.198, 1.5621]
Starting optimization run 1 on device cpu (NO data taper, Spatio-Temporal Diff)...
--- Epoch 1/700 (LR: 0.005000) ---
 Loss: 837379.0000
 Parameters (Natural Scale): [21.4098  1.3005  1.5552  0.042  -0.124   0.178   4.7929]
--- Epoch 51/700 (LR: 0.004240) ---
 Loss: 636619.0000
 Parameters (Natural Scale): [ 2.57281e+01  1.08040e+00  1.28980e+00 -4.00000e-03 -1.11000e-02
  1.90000e-03  5.56710e+00]
--- Epoch 101/700 (LR: 0.002461) ---
 Loss: 587499.3750
 Parameters (Natural Scale): [ 2.80687e+01  1.00900e+00  1.17890e+00 -1.00000e-03 -9.40000e-03
 -1.00000e-03  5.79440e+00]
--- Epoch 151/700 (LR: 0.000706) ---
 Loss: 563940.6875
 Parameters (Natural Scale): [ 2.98396e+01  9.63200e-0

# 3d once difference not two stage

ey Change: From Sparse Lookup to Grid Convolution

The logic now:

Reshapes the day-long tensor into a 3D grid: [1, 1, N_lat, N_lon, N_time].

Applies the 3D first-order difference kernel K: Y(s,t)=X(s+Δ 
lat
​	
 ,t)+X(s+Δ 
lon
​	
 ,t)+X(s,t+1)−3X(s,t).

The kernel K has weights {−3,1,1,1} at (0,0,0),(Δ 
lat
​	
 ,0,0),(0,Δ 
lon
​	
 ,0),(0,0,Δ 
time
​	
 ).

In [44]:
import torch
import numpy as np
import os
import pickle
import torch.nn.functional as F # <-- Added F

# --- Constants from your reference code ---
STEP_LAT = 0.044
STEP_LON = 0.063

# =========================================================================
# 1. Helper Functions
# =========================================================================

def subset_tensor(df_tensor: torch.Tensor) -> torch.Tensor:
    """Subsets a tensor to lat [0, 5] and lon [123, 133]."""
    lat_mask = (df_tensor[:, 0] >= 0) & (df_tensor[:, 0] <= 5)
    lon_mask = (df_tensor[:, 1] >= 123) & (df_tensor[:, 1] <= 133)
    
    df_sub = df_tensor[lat_mask & lon_mask].clone()
    return df_sub

def reshape_day_tensor_to_3d_grid(day_tensor: torch.Tensor):
    """
    Utility function to reshape a 1D spatio-temporal tensor into a 3D grid [N_lat, N_lon, N_time].
    Requires a non-sparse grid for each time point.
    """
    if day_tensor.size(0) == 0:
        raise ValueError("Input tensor is empty.")
    
    unique_lats = torch.unique(day_tensor[:, 0])
    unique_lons = torch.unique(day_tensor[:, 1])
    unique_times = torch.unique(day_tensor[:, 3])
    
    n_lat, n_lon, n_time = len(unique_lats), len(unique_lons), len(unique_times)
    
    # Check for non-sparse grid across all time points
    if day_tensor.size(0) != n_lat * n_lon * n_time:
         # This check is crucial for the convolution approach
         raise ValueError(f"Input tensor size ({day_tensor.size(0)}) does not match expected 3D grid size ({n_lat*n_lon*n_time}). Data must be a complete, non-sparse grid.")

    lat_map = {lat.item(): i for i, lat in enumerate(unique_lats)}
    lon_map = {lon.item(): j for j, lon in enumerate(unique_lons)}
    time_map = {t.item(): k for k, t in enumerate(unique_times)}
    
    # Create the 3D grid
    grid_data = torch.zeros((n_lat, n_lon, n_time), dtype=day_tensor.dtype, device=day_tensor.device)
    
    for row in day_tensor:
        lat, lon, ozone, t = row[0].item(), row[1].item(), row[2], row[3].item()
        i, j, k = lat_map[lat], lon_map[lon], time_map[t]
        grid_data[i, j, k] = ozone

    return grid_data, unique_lats, unique_lons, unique_times

def apply_3d_filter_convolution(day_tensor: torch.Tensor) -> torch.Tensor:
    """
    ✅ Applies the ONE-STAGE 3D filter Y(s,t) = X(s+d_lat) + X(s+d_lon) + X(s,t+1) - 3X(s,t)
    using 3D convolution, requiring a complete, non-sparse 3D grid input.
    """
    if day_tensor.size(0) == 0:
        return torch.empty(0, 4, dtype=day_tensor.dtype, device=day_tensor.device)

    # 1. Reshape data into a 3D grid [N_lat, N_lon, N_time]
    grid_data, unique_lats, unique_lons, unique_times = reshape_day_tensor_to_3d_grid(day_tensor)
    
    n_lat, n_lon, n_time = grid_data.shape
    if n_lat < 2 or n_lon < 2 or n_time < 2:
        # Not enough dimensions for 3D differencing
        return torch.empty(0, 4, dtype=day_tensor.dtype, device=day_tensor.device)

    # 2. Define the 3D Kernel (for cross-correlation)
    # Filter: X(i+1, j, k) + X(i, j+1, k) + X(i, j, k+1) - 3*X(i, j, k)
    kernel_weights = torch.zeros((1, 1, 2, 2, 2), dtype=day_tensor.dtype, device=day_tensor.device)
    kernel_weights[0, 0, 0, 0, 0] = -3.0 # X(i, j, k)
    kernel_weights[0, 0, 1, 0, 0] = 1.0  # X(i+1, j, k) - Lat Diff
    kernel_weights[0, 0, 0, 1, 0] = 1.0  # X(i, j+1, k) - Lon Diff
    kernel_weights[0, 0, 0, 0, 1] = 1.0  # X(i, j, k+1) - Time Diff

    # 3. Apply 3D convolution
    # Input shape: [1, 1, N_lat, N_lon, N_time]
    input_conv = grid_data.unsqueeze(0).unsqueeze(0)
    
    # Output shape: [1, 1, N_lat-1, N_lon-1, N_time-1]
    filtered_grid = F.conv3d(input_conv, kernel_weights, padding='valid').squeeze()
    
    # 4. Determine coordinates for the new, smaller grid
    new_lats = unique_lats[:-1]
    new_lons = unique_lons[:-1]
    new_times = unique_times[:-1]

    # 5. Reconstruct the output tensor
    new_lat_grid, new_lon_grid, new_time_grid = torch.meshgrid(
        new_lats, new_lons, new_times, indexing='ij'
    )
    
    filtered_values = filtered_grid.flatten()
    
    new_tensor = torch.stack([
        new_lat_grid.flatten(),
        new_lon_grid.flatten(),
        filtered_values,
        new_time_grid.flatten() # Anchored at time t
    ], dim=1)
    
    return new_tensor


# =========================================================================
# 2. Data Loading (Structure retained)
# =========================================================================
# ⚠️ NOTE: You must define 'mac_data_path', 'year', 'month', and the GEMS_TCO class
# (Assuming data loading variables are defined...)
# NOTE: The data loading assumes 'load_working_data_byday_wo_mm' returns a
# dictionary of chunks/hours (cur_map) AND an aggregated tensor (aggregated_day_tensor).
# The original code used the aggregated tensor, so we will use it directly below.

year = 2024
month = 7
month_str = f"{month:02d}"

# Placeholder definitions for running:
# class GEMS_TCO:
#     def load_data(self, path): return self
#     def load_working_data_byday_wo_mm(self, data, indices):
#         # Placeholder for a complete 3D grid (N_lat x N_lon x N_hour)
#         lats = torch.linspace(0.0, 5.0, 5) # N_lat=5
#         lons = torch.linspace(123.0, 133.0, 5) # N_lon=5
#         times = torch.arange(8) + indices[0] # N_hour=8
#         grid_lats, grid_lons, grid_times = torch.meshgrid(lats, lons, times, indexing='ij')
#         ozone = torch.randn_like(grid_lats) * 10
#         aggregated_tensor = torch.stack([grid_lats.flatten(), grid_lons.flatten(), ozone.flatten(), grid_times.flatten()], dim=1)
#         return {'chunk': aggregated_tensor}, aggregated_tensor # cur_map (dict), aggregated_day_tensor

# (Need actual data loading setup here)

pickle_path = os.path.join(mac_data_path, f'pickle_{year}')
output_filename = f"coarse_cen_map_without_decrement_latitude{str(year)[2:]}_{month_str}.pkl"
output_filepath = os.path.join(pickle_path, output_filename)
print(f"Loading data from: {output_filepath}")

try:
    with open(output_filepath, 'rb') as pickle_file:
        cbmap_ori = pickle.load(pickle_file)
except:
     # Placeholder data loading in case actual file is missing
     class GEMS_TCO: # Placeholder
        def load_data(self, path): return self
        def load_working_data_byday_wo_mm(self, data, indices):
             lats = torch.linspace(0.0, 5.0, 5) # N_lat=5
             lons = torch.linspace(123.0, 133.0, 5) # N_lon=5
             times = torch.arange(8) + indices[0] # N_hour=8
             grid_lats, grid_lons, grid_times = torch.meshgrid(lats, lons, times, indexing='ij')
             ozone = torch.randn_like(grid_lats) * 10
             aggregated_tensor = torch.stack([grid_lats.flatten(), grid_lons.flatten(), ozone.flatten(), grid_times.flatten()], dim=1)
             return {'chunk': aggregated_tensor}, aggregated_tensor
     cbmap_ori = {}


load_data_instance = GEMS_TCO.load_data('')

# =========================================================================
# 3. Build df_day_aggregated_list (MODIFIED to use aggregated_day_tensor)
# =========================================================================
df_day_aggregated_list = []
num_days_to_process = 31 # For July

print(f"\nLoading and SUBSETTING aggregated data for {num_days_to_process} days...")
for i in range(num_days_to_process): 
    # Use the entire aggregated_day_tensor for 3D processing
    cur_map, aggregated_day_tensor = load_data_instance.load_working_data_byday_wo_mm(
        cbmap_ori, [i*8, (i+1)*8]
    )
    
    if aggregated_day_tensor is not None and aggregated_day_tensor.numel() > 0:
        
        subsetted_tensor = subset_tensor(aggregated_day_tensor)
        
        if subsetted_tensor.size(0) > 0:
            df_day_aggregated_list.append(subsetted_tensor)
            print(f"  Aggregated & Subset tensor shape for day {i+1}: {subsetted_tensor.shape}")
        else:
            print(f"  No valid data found after SUBSETTING for day {i+1}.")
            
    else:
         print(f"  No valid aggregated data (cur_df) found for day {i+1}.")

print(f"\nFinished loading. Created {len(df_day_aggregated_list)} aggregated day tensors.")


# =========================================================================
# 4. Main 3D Filtering Loop (MODIFIED to use convolution)
# =========================================================================
all_filtered_days = [] 

print("\nApplying ONE-STAGE 3D Convolution filter...")
if not df_day_aggregated_list:
     print("Error: `df_day_aggregated_list` is empty after loading/subsetting.")
else:
    for day_idx, aggregated_day_tensor in enumerate(df_day_aggregated_list):
        print(f"Filtering Day {day_idx+1}...")
        if day_idx==5:
            break
        
        # Move to GPU if available for convolution speed
        tensor_dev = aggregated_day_tensor.to(torch.device("cuda" if torch.cuda.is_available() else "cpu")) 

        try:
            # Apply the 3D convolution filter
            filtered_day_tensor = apply_3d_filter_convolution(tensor_dev)
            
            if filtered_day_tensor.numel() > 0:
                # Move back to CPU for consistency before saving
                all_filtered_days.append(filtered_day_tensor.cpu()) 
                print(f"  Successfully filtered day {day_idx+1}. New shape: {filtered_day_tensor.shape}")
            else:
                print(f"  Skipping Day {day_idx+1}: filter resulted in an empty tensor (no valid points found).")
                 
        except ValueError as e:
            # Catch the non-sparse grid error here
            print(f" Skipping Day {day_idx+1}: Data structure error: {e}")
        except Exception as e: 
            print(f" An unexpected error occurred filtering Day {day_idx+1}: {e}")


print(f"\nFiltering complete. Generated {len(all_filtered_days)} final filtered day-tensors.")

# =========================================================================
# 5. Verification and Saving (Unchanged)
# =========================================================================
if all_filtered_days:
    # Filename now reflects the correct year and month
    processed_output_path = f"filtered_3d_convolution_data_{year}_{month_str}.pkl" 
    with open(processed_output_path, 'wb') as f:
        pickle.dump(all_filtered_days, f)
    print(f"Processed data for {len(all_filtered_days)} days saved to {processed_output_path}")

    first_day_tensor = all_filtered_days[0] 
    print("\nShape of the first filtered day tensor:", first_day_tensor.shape)
    print("Head of the first filtered day tensor:")
    print(first_day_tensor[:5])
    
    if first_day_tensor.numel() > 0:
        print("Unique time values in first tensor:", torch.unique(first_day_tensor[:, 3]))

else:
    print(f"\nNo final filtered tensors were created for {year}-{month_str}.")

Loading data from: /Users/joonwonlee/Documents/GEMS_DATA/pickle_2024/coarse_cen_map_without_decrement_latitude24_07.pkl

Loading and SUBSETTING aggregated data for 31 days...
  Aggregated & Subset tensor shape for day 1: torch.Size([145008, 4])
  Aggregated & Subset tensor shape for day 2: torch.Size([145008, 4])
  Aggregated & Subset tensor shape for day 3: torch.Size([145008, 4])
  Aggregated & Subset tensor shape for day 4: torch.Size([145008, 4])
  Aggregated & Subset tensor shape for day 5: torch.Size([145008, 4])
  Aggregated & Subset tensor shape for day 6: torch.Size([145008, 4])
  Aggregated & Subset tensor shape for day 7: torch.Size([145008, 4])
  Aggregated & Subset tensor shape for day 8: torch.Size([145008, 4])
  Aggregated & Subset tensor shape for day 9: torch.Size([145008, 4])
  Aggregated & Subset tensor shape for day 10: torch.Size([145008, 4])
  Aggregated & Subset tensor shape for day 11: torch.Size([145008, 4])
  Aggregated & Subset tensor shape for day 12: torch.

In [46]:
import torch
import numpy as np
import matplotlib.pyplot as plt # Keep if plotting might be added later
import cmath
import pickle
import time # For timing
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.nn import Parameter
import pandas as pd # Make sure pandas is imported
import os # Make sure os is imported

# =========================================================================
# 1. Modeling Functions (Adapted for 3D Differencing)
# =========================================================================

# --- Bartlett Kernel (Used for c_gn when g_s=1) ---
def cgn_2dbartlett_kernel(u1, u2, n1, n2):
    """Computes the 2D Bartlett kernel. (Unchanged)"""
    device = u1.device if isinstance(u1, torch.Tensor) else (u2.device if isinstance(u2, torch.Tensor) else torch.device('cpu'))
    u1_tensor = u1.to(device) if isinstance(u1, torch.Tensor) else torch.tensor(u1, device=device, dtype=torch.float32)
    u2_tensor = u2.to(device) if isinstance(u2, torch.Tensor) else torch.tensor(u2, device=device, dtype=torch.float32)
    n1_eff = float(n1) if n1 > 0 else 1.0
    n2_eff = float(n2) if n2 > 0 else 1.0
    kernel = (1.0 - torch.abs(u1_tensor) / n1_eff) * (1.0 - torch.abs(u2_tensor) / n2_eff)
    return torch.clamp(kernel, min=0.0)

# --- Covariance of the Original Field X (EXPONENTIAL Kernel) ---
def cov_x_exponential(u1, u2, t, params):
    """Computes the autocovariance of the ORIGINAL process X. (Unchanged)"""
    device = params.device
    u1_dev = u1.to(device) if isinstance(u1, torch.Tensor) else torch.tensor(u1, device=device, dtype=torch.float32)
    u2_dev = u2.to(device) if isinstance(u2, torch.Tensor) else torch.tensor(u2, device=device, dtype=torch.float32)
    t_dev = t.to(device) if isinstance(t, torch.Tensor) else torch.tensor(t, device=device, dtype=torch.float32)

    log_params_indices = [0, 1, 2, 6]
    if torch.isnan(params[log_params_indices]).any() or torch.isinf(params[log_params_indices]).any():
         out_shape = torch.broadcast_shapes(u1_dev.shape, u2_dev.shape, t_dev.shape)
         return torch.full(out_shape, float('nan'), device=device, dtype=torch.float32)

    sigmasq, r_lat, r_lon, nugget = torch.exp(params[log_params_indices])
    a_lat, a_lon, beta = params[3], params[4], params[5]

    r_lat = torch.clamp(r_lat, min=1e-6)
    r_lon = torch.clamp(r_lon, min=1e-6)

    x1 = u1_dev / r_lat - a_lat * t_dev
    x2 = u2_dev / r_lon - a_lon * t_dev
    x3 = beta * t_dev
    distance_sq = x1**2 + x2**2 + x3**2
    epsilon = 1e-12
    distance_sq_clamped = torch.clamp(distance_sq, min=0.0)
    D = torch.sqrt(distance_sq_clamped + epsilon)
    cov_smooth = sigmasq * torch.exp(-D)

    is_zero_lag = (torch.abs(u1_dev) < 1e-9) & (torch.abs(u2_dev) < 1e-9) & (torch.abs(t_dev) < 1e-9)
    final_cov = torch.where(is_zero_lag, cov_smooth + nugget, cov_smooth)

    if torch.isnan(final_cov).any():
        print("Warning: NaN detected in cov_x_exponential output.")
    return final_cov

# --- (NEW) Covariance of the 3D Differenced Field Y ---
def cov_3d_difference(u1, u2, t, params, delta1, delta2):
    """
    ✅ Calculates covariance for the 3D filter:
    Y(s,t) = X(s+d1,t) + X(s+d2,t) + X(s,t+1) - 3X(s,t)
    """
    device = params.device
    u1_dev = u1.to(device) if isinstance(u1, torch.Tensor) else torch.tensor(u1, device=device, dtype=torch.float32)
    u2_dev = u2.to(device) if isinstance(u2, torch.Tensor) else torch.tensor(u2, device=device, dtype=torch.float32)
    t_dev = t.to(device) if isinstance(t, torch.Tensor) else torch.tensor(t, device=device, dtype=torch.float32)

    d1 = delta1 # spatial lag lat
    d2 = delta2 # spatial lag lon
    
    # Pre-calculate terms C_X(u+h_j-h_k, tau)
    term_00 = cov_x_exponential(u1_dev, u2_dev, t_dev, params) # u, tau
    
    term_p10 = cov_x_exponential(u1_dev + d1, u2_dev, t_dev, params) # u+d1, tau
    term_m10 = cov_x_exponential(u1_dev - d1, u2_dev, t_dev, params) # u-d1, tau
    
    term_0p1 = cov_x_exponential(u1_dev, u2_dev + d2, t_dev, params) # u+d2, tau
    term_0m1 = cov_x_exponential(u1_dev, u2_dev - d2, t_dev, params) # u-d2, tau
    
    term_m1p1 = cov_x_exponential(u1_dev - d1, u2_dev + d2, t_dev, params) # u-d1+d2, tau
    term_p1m1 = cov_x_exponential(u1_dev + d1, u2_dev - d2, t_dev, params) # u+d1-d2, tau
    
    term_00_tp1 = cov_x_exponential(u1_dev, u2_dev, t_dev + 1.0, params) # u, tau+1
    term_00_tm1 = cov_x_exponential(u1_dev, u2_dev, t_dev - 1.0, params) # u, tau-1
    
    term_m10_tp1 = cov_x_exponential(u1_dev - d1, u2_dev, t_dev + 1.0, params) # u-d1, tau+1
    term_p10_tm1 = cov_x_exponential(u1_dev + d1, u2_dev, t_dev - 1.0, params) # u+d1, tau-1
    
    term_0m1_tp1 = cov_x_exponential(u1_dev, u2_dev - d2, t_dev + 1.0, params) # u-d2, tau+1
    term_0p1_tm1 = cov_x_exponential(u1_dev, u2_dev + d2, t_dev - 1.0, params) # u+d2, tau-1

    # Check for NaNs in any term
    all_terms = [term_00, term_p10, term_m10, term_0p1, term_0m1, term_m1p1, term_p1m1,
                 term_00_tp1, term_00_tm1, term_m10_tp1, term_p10_tm1, term_0m1_tp1, term_0p1_tm1]
    if any(torch.isnan(term).any() for term in all_terms):
        print("Warning: NaN detected in one of the terms of cov_3d_difference")
        out_shape = torch.broadcast_shapes(u1_dev.shape, u2_dev.shape, t_dev.shape)
        return torch.full(out_shape, float('nan'), device=device, dtype=torch.float32)

    # Combine terms according to the derived formula
    cov_Y = (12 * term_00
             - 3 * (term_p10 + term_m10)
             - 3 * (term_0p1 + term_0m1)
             + (term_m1p1 + term_p1m1)
             - 3 * (term_00_tp1 + term_00_tm1)
             + (term_m10_tp1 + term_p10_tm1)
             + (term_0m1_tp1 + term_0p1_tm1)
            )

    return cov_Y


# --- (MODIFIED) cn_bar for NO TAPERING ---
def cn_bar_no_taper(u1, u2, t, params, n1, n2, delta1, delta2):
    """
    Computes c_Y(u) * c_gn(u) where c_Y is cov_3d_difference
    and c_gn(u) is the Bartlett kernel.
    """
    device = params.device
    u1_dev = u1.to(device) if isinstance(u1, torch.Tensor) else torch.tensor(u1, device=device, dtype=torch.float32)
    u2_dev = u2.to(device) if isinstance(u2, torch.Tensor) else torch.tensor(u2, device=device, dtype=torch.float32)
    t_dev = t.to(device) if isinstance(t, torch.Tensor) else torch.tensor(t, device=device, dtype=torch.float32)

    # ✅ --- Call the NEW 3D difference covariance function ---
    cov_Y_value = cov_3d_difference(u1_dev, u2_dev, t_dev, params, delta1, delta2)

    c_gn_value = cgn_2dbartlett_kernel(u1_dev, u2_dev, n1, n2)

    if torch.isnan(cov_Y_value).any() or torch.isnan(c_gn_value).any():
        print("Warning: NaN detected before multiplication in cn_bar_no_taper.")
        out_shape = torch.broadcast_shapes(cov_Y_value.shape, c_gn_value.shape)
        return torch.full(out_shape, float('nan'), device=device, dtype=torch.float32)

    result = cov_Y_value * c_gn_value
    if torch.isnan(result).any():
        print("Warning: NaN detected after multiplication in cn_bar_no_taper.")
    return result


# --- Expected Periodogram (uses cn_bar_no_taper) ---
def expected_periodogram_fft_no_taper(params, n1, n2, p, delta1, delta2):
    """
    Calculates the expected periodogram. (Unchanged structure)
    This function is correct because it calls the modified cn_bar_no_taper.
    """
    device = params.device if isinstance(params, torch.Tensor) else params[0].device
    params_tensor = params.to(device)

    product_tensor = torch.zeros((n1, n2, p, p), dtype=torch.complex64, device=device)
    t_lags = torch.arange(p, dtype=torch.float32, device=device)
    u1_mesh_grid, u2_mesh_grid = torch.meshgrid(
        torch.arange(n1, dtype=torch.float32, device=device),
        torch.arange(n2, dtype=torch.float32, device=device),
        indexing='ij'
    )

    for q in range(p):
        for r in range(p):
            t_diff = t_lags[q] - t_lags[r] # This is tau
            # This call now correctly leads to cov_3d_difference
            cov_times_bartlett = cn_bar_no_taper(
                u1_mesh_grid, u2_mesh_grid, t_diff,
                params_tensor, n1, n2, delta1, delta2
            )
            if torch.isnan(cov_times_bartlett).any():
                 product_tensor[:, :, q, r] = float('nan')
            else:
                 product_tensor[:, :, q, r] = cov_times_bartlett.to(torch.complex64)

    if torch.isnan(product_tensor).any():
        print("Warning: NaN detected in product_tensor before FFT.")
        nan_shape = (n1, n2, p, p)
        return torch.full(nan_shape, float('nan'), dtype=torch.complex64, device=device)

    fft_result = torch.fft.fft2(product_tensor, dim=(0, 1))
    normalization_factor = 1.0 / (4.0 * cmath.pi**2)
    result = fft_result * normalization_factor

    if torch.isnan(result).any():
        print("Warning: NaN detected in expected_periodogram_fft_no_taper output after FFT.")
    return result


# =========================================================================
# 2. Data Processing (Unchanged)
# =========================================================================
def generate_Jvector_no_taper(tensor_list, lat_col, lon_col, val_col, device):
    """Generates J-vector for g_s=1 (NO taper). (Unchanged)"""
    p = len(tensor_list) # p is now the number of hours (e.g., 7)
    if p == 0: return torch.empty(0, 0, 0, device=device), 0, 0, 0

    valid_tensors = [t for t in tensor_list if t is not None and t.numel() > 0 and t.shape[1] > max(lat_col, lon_col, val_col)]
    if not valid_tensors:
         print("Warning: No valid tensors found in tensor_list.")
         return torch.empty(0, 0, 0, device=device), 0, 0, 0

    try:
        # Collect coords only from valid tensors
        all_lats_cpu = torch.cat([t[:, lat_col] for t in valid_tensors])
        all_lons_cpu = torch.cat([t[:, lon_col] for t in valid_tensors])
    except IndexError:
        print(f"Error: Invalid column index. Check tensor shapes.")
        return torch.empty(0, 0, 0, device=device), 0, 0, 0
    except Exception as e:
        print(f"Error concatenating coordinates: {e}")
        return torch.empty(0, 0, 0, device=device), 0, 0, 0

    # Ensure coordinates are finite
    all_lats_cpu = all_lats_cpu[~torch.isnan(all_lats_cpu) & ~torch.isinf(all_lats_cpu)]
    all_lons_cpu = all_lons_cpu[~torch.isnan(all_lons_cpu) & ~torch.isinf(all_lons_cpu)]
    if all_lats_cpu.numel() == 0 or all_lons_cpu.numel() == 0:
        print("Warning: No valid coordinates found after filtering.")
        return torch.empty(0, 0, 0, device=device), 0, 0, 0

    unique_lats_cpu, unique_lons_cpu = torch.unique(all_lats_cpu), torch.unique(all_lons_cpu)
    n1, n2 = len(unique_lats_cpu), len(unique_lons_cpu)
    if n1 == 0 or n2 == 0:
        print("Warning: Grid dimensions are zero.")
        return torch.empty(0, 0, 0, device=device), 0, 0, 0

    lat_map = {round(lat.item(), 5): i for i, lat in enumerate(unique_lats_cpu)}
    lon_map = {round(lon.item(), 5): i for i, lon in enumerate(unique_lons_cpu)}

    fft_results = []
    # Iterate through the HOURLY tensors in the list
    for tensor in tensor_list:
        data_grid = torch.zeros((n1, n2), dtype=torch.float32, device=device)
        if tensor is not None and tensor.numel() > 0: # Check if tensor is valid
            tensor_dev = tensor.to(device)
            for row in tensor_dev:
                # Round coordinates for lookup consistency
                lat_item = round(row[lat_col].item(), 5)
                lon_item = round(row[lon_col].item(), 5)
                
                # Check if coordinates exist in the map (handles potentially missing coords)
                i = lat_map.get(lat_item)
                j = lon_map.get(lon_item)
                
                if i is not None and j is not None:
                    val = row[val_col]
                    val_num = val.item() if isinstance(val, torch.Tensor) else val
                    if not np.isnan(val_num) and not np.isinf(val_num):
                        data_grid[i, j] = val_num
        
        # Ensure grid is finite before FFT
        data_grid = torch.nan_to_num(data_grid, nan=0.0, posinf=0.0, neginf=0.0)
        fft_results.append(torch.fft.fft2(data_grid))

    if not fft_results:
         print("Warning: No FFT results generated.")
         return torch.empty(0, 0, 0, device=device), n1, n2, 0

    J_vector_tensor = torch.stack(fft_results, dim=2).to(device) # Shape [n1, n2, p]

    H = float(n1 * n2)
    if H < 1e-9:
        norm_factor = torch.tensor(0.0, device=device)
    else:
        norm_factor = (torch.sqrt(torch.tensor(1.0 / H, device=device)) / (2.0 * cmath.pi))

    result = J_vector_tensor * norm_factor
    return result, n1, n2, p # p is the number of hours


def calculate_sample_periodogram_vectorized(J_vector_tensor):
    """Calculates sample periodogram I_n = J J^H. (Unchanged)"""
    if torch.isnan(J_vector_tensor).any() or torch.isinf(J_vector_tensor).any():
        n1, n2, p = J_vector_tensor.shape
        return torch.full((n1, n2, p, p), float('nan'), dtype=torch.complex64, device=J_vector_tensor.device)

    J_col = J_vector_tensor.unsqueeze(-1)
    J_row_conj = J_vector_tensor.unsqueeze(-2).conj()
    result = J_col @ J_row_conj
    return result


# =========================================================================
# 4. Likelihood Calculation (Unchanged Structure)
# =========================================================================

def whittle_likelihood_loss_no_taper(params, I_sample, n1, n2, p, delta1, delta2):
    """Whittle Likelihood Loss. (Unchanged structure)"""
    device = I_sample.device
    params_tensor = params.to(device)

    if torch.isnan(params_tensor).any() or torch.isinf(params_tensor).any():
        return torch.tensor(float('nan'), device=device)

    # This call now correctly leads to cov_3d_difference
    I_expected = expected_periodogram_fft_no_taper(
        params_tensor, n1, n2, p, delta1, delta2
    )

    if torch.isnan(I_expected).any() or torch.isinf(I_expected).any():
        print("Warning: NaN/Inf returned from expected_periodogram calculation.")
        return torch.tensor(float('nan'), device=device)

    eye_matrix = torch.eye(p, dtype=torch.complex64, device=device)
    diag_vals = torch.abs(I_expected.diagonal(dim1=-2, dim2=-1))
    mean_diag_abs = diag_vals.mean().item() if diag_vals.numel() > 0 and not torch.isnan(diag_vals).all() else 1.0
    diag_load = max(mean_diag_abs * 1e-8, 1e-9)

    I_expected_stable = I_expected + eye_matrix * diag_load

    sign, logabsdet = torch.linalg.slogdet(I_expected_stable)
    if torch.any(sign.real <= 1e-9):
        log_det_term = torch.where(sign.real > 1e-9, logabsdet, torch.tensor(1e10, device=device))
    else:
        log_det_term = logabsdet

    if torch.isnan(I_sample).any() or torch.isinf(I_sample).any():
        return torch.tensor(float('nan'), device=device)

    try:
        # Check shapes before solve
        if I_expected_stable.shape[-2:] != (p, p) or I_sample.shape[-2:] != (p,p):
             print(f"Shape mismatch: I_expected_stable {I_expected_stable.shape}, I_sample {I_sample.shape}, p={p}")
             return torch.tensor(float('nan'), device=device)
        solved_term = torch.linalg.solve(I_expected_stable, I_sample)
        trace_term = torch.einsum('...ii->...', solved_term).real
    except torch.linalg.LinAlgError as e:
         print(f"Warning: LinAlgError during solve: {e}. Applying high loss penalty.")
         return torch.tensor(float('inf'), device=device)
    except RuntimeError as e: # Catch other potential errors like shape mismatch if not caught above
        print(f"Runtime Error during solve/trace: {e}")
        return torch.tensor(float('inf'), device=device)


    if torch.isnan(trace_term).any() or torch.isinf(trace_term).any():
        return torch.tensor(float('nan'), device=device)

    likelihood_terms = log_det_term + trace_term

    if torch.isnan(likelihood_terms).any():
        return torch.tensor(float('nan'), device=device)

    total_sum = torch.sum(likelihood_terms)
    # DC term is at index (0,0) in the spatial frequency domain
    dc_term = likelihood_terms[0, 0] if n1 > 0 and n2 > 0 else torch.tensor(0.0, device=device)
    if torch.isnan(dc_term).any() or torch.isinf(dc_term).any():
        dc_term = torch.tensor(0.0, device=device)

    # Subtract DC term only if there are non-DC frequencies
    loss = total_sum - dc_term if (n1 > 1 or n2 > 1) else total_sum

    if torch.isnan(loss) or torch.isinf(loss):
         return torch.tensor(float('inf'), device=device)

    return loss


# =========================================================================
# 5. Training Loop (Unchanged)
# =========================================================================
def run_full(params_list, optimizer, scheduler, I_sample, n1, n2, p, epochs=600, device='cpu'):
    """Corrected training loop. (Unchanged)"""
    best_loss = float('inf')
    params_list = [p.to(device) for p in params_list]
    best_params_state = [p.detach().clone() for p in params_list]
    epochs_completed = 0
    DELTA_LAT, DELTA_LON = 0.044, 0.063 # Still needed

    def get_printable_params(p_list):
        valid_tensors = [p for p in p_list if isinstance(p, torch.Tensor)]
        if not valid_tensors: return "Invalid params_list"
        p_cat = torch.cat([p.detach().clone().cpu() for p in valid_tensors])
        log_indices = [0, 1, 2, 6]
        if all(idx < len(p_cat) for idx in log_indices):
            log_vals = p_cat[log_indices]
            if not (torch.isnan(log_vals).any() or torch.isinf(log_vals).any()):
                 # Ensure log_vals are positive before exp
                 if torch.all(log_vals > -torch.inf):
                     try:
                         p_cat[log_indices] = torch.exp(log_vals)
                     except RuntimeError: # Handle potential overflow if log_vals are huge
                          p_cat[log_indices] = float('inf')
                 else:
                     p_cat[log_indices] = float('nan') # Cannot exp non-finite
            else:
                 p_cat[log_indices] = float('nan')
        return p_cat.numpy().round(4)

    I_sample_dev = I_sample.to(device)

    for epoch in range(epochs):
        epochs_completed = epoch + 1
        optimizer.zero_grad()
        params_tensor = torch.cat(params_list)

        loss = whittle_likelihood_loss_no_taper(
            params_tensor, I_sample_dev, n1, n2, p, DELTA_LAT, DELTA_LON
        )

        if torch.isnan(loss) or torch.isinf(loss):
            print(f"Loss became NaN or Inf at epoch {epoch+1}. Stopping.")
            if epoch == 0: best_params_state = None
            epochs_completed = epoch
            break

        loss.backward()

        nan_grad = False
        for param in params_list:
            if param.grad is not None and (torch.isnan(param.grad).any() or torch.isinf(param.grad).any()):
                nan_grad = True
                break
        if nan_grad:
             print(f"Warning: NaN/Inf gradient detected at epoch {epoch+1}. Skipping step.")
             optimizer.zero_grad()
             # Optionally reduce LR or revert params here if needed
             continue # Skip optimizer step and scheduler step

        # Only clip and step if gradients are valid
        all_params_on_device = params_list
        if all_params_on_device:
            torch.nn.utils.clip_grad_norm_(all_params_on_device, max_norm=1.0)
        
        optimizer.step()
        scheduler.step() # Step scheduler after optimizer

        current_loss_item = loss.item()
        if current_loss_item < best_loss:
            params_valid = not any(torch.isnan(p.data).any() or torch.isinf(p.data).any() for p in params_list)
            if params_valid:
                best_loss = current_loss_item
                best_params_state = [p.detach().clone() for p in params_list]
            # else: # Optionally print warning if params become invalid after step
            #     print(f"Warning: Params became invalid after step {epoch+1}. Not saving state.")


        if epoch % 50 == 0 or epoch == epochs - 1:
            current_lr = optimizer.param_groups[0]['lr'] if optimizer.param_groups else 0.0
            print(f'--- Epoch {epoch+1}/{epochs} (LR: {current_lr:.6f}) ---')
            print(f' Loss: {current_loss_item:.4f}')
            print(f' Parameters (Natural Scale): {get_printable_params(params_list)}')

    if best_params_state is None:
        print("Training failed to find a valid model state.")
        return None, epochs_completed

    final_params_log_scale = torch.cat([p.cpu() for p in best_params_state])
    final_params_natural_scale = final_params_log_scale.detach().clone()
    log_indices = [0, 1, 2, 6]
    if all(idx < len(final_params_natural_scale) for idx in log_indices):
        log_vals = final_params_natural_scale[log_indices]
        if not (torch.isnan(log_vals).any() or torch.isinf(log_vals).any()):
             if torch.all(log_vals > -torch.inf):
                 try:
                      final_params_natural_scale[log_indices] = torch.exp(log_vals)
                 except RuntimeError:
                      final_params_natural_scale[log_indices] = float('inf')
             else:
                  final_params_natural_scale[log_indices] = float('nan')
        else:
            final_params_natural_scale[log_indices] = float('nan')

    final_params_rounded = [round(p.item(), 4) if not np.isnan(p.item()) else float('nan') for p in final_params_natural_scale]
    final_loss_rounded = round(best_loss, 3) if best_loss != float('inf') else float('inf')

    print("\n--- Training Complete ---")
    print(f'\nFINAL BEST STATE ACHIEVED (during training):')
    print(f'Best Loss: {final_loss_rounded}')
    print(f'Parameters Corresponding to Best Loss (Natural Scale): {final_params_rounded}')

    return final_params_rounded + [final_loss_rounded], epochs_completed

# =========================================================================
# 6. Main Execution Script (MODIFIED to load correct sparse data)
# =========================================================================
if __name__ == '__main__':
    start_time = time.time()

    # --- Configuration ---
    DAY_INDEX_TO_RUN = 0 # Index in the outer list (0 corresponds to the first day processed)
    NUM_RUNS = 1
    EPOCHS = 700
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {DEVICE}")

    # --- Grid Spacing (Needed for the covariance model) ---
    DELTA_LAT, DELTA_LON = 0.044, 0.063

    # --- Column Indices ---
    LAT_COL, LON_COL = 0, 1
    VAL_COL = 2
    TIME_COL = 3 # This column now represents the original hour (0-6)

    # --- (MODIFIED) Load 3D Filtered **Sparse** Data ---
    try:
        # ✅ --- LOAD THE NEW, CORRECT FILE ---
        # (Make sure year/month match the file you just created)
        year = 2024
        month = 7
        month_str = f"{month:02d}"
        data_filename = f"filtered_3d_sparse_data_{year}_{month_str}.pkl"
        
        with open(data_filename, 'rb') as f:
            # This is a list of Tensors: [day1_tensor, day2_tensor, ...]
            all_filtered_days = pickle.load(f)
        
        print(f"Loaded {len(all_filtered_days)} days from {data_filename}.")

        if not all_filtered_days: raise ValueError("Loaded data is empty.")

        # ✅ --- Select the single tensor for the desired day ---
        if DAY_INDEX_TO_RUN < 0 or DAY_INDEX_TO_RUN >= len(all_filtered_days):
             raise IndexError(f"DAY_INDEX_TO_RUN ({DAY_INDEX_TO_RUN}) is out of bounds for the loaded data.")
        
        day_tensor = all_filtered_days[DAY_INDEX_TO_RUN]
        
        if day_tensor is None or day_tensor.numel() == 0:
            raise ValueError(f"Data for day index {DAY_INDEX_TO_RUN} is empty.")

        day_tensor = day_tensor.cpu().to(torch.float32)
        
        # ✅ --- SPLIT the day tensor into a list of hourly tensors ---
        # This is the new step required by generate_Jvector_no_taper
        unique_times_in_day = torch.unique(day_tensor[:, TIME_COL])
        time_slices_list = []
        for t in unique_times_in_day:
            hourly_tensor = day_tensor[day_tensor[:, TIME_COL] == t]
            time_slices_list.append(hourly_tensor)
        
        if not time_slices_list:
            raise ValueError(f"Could not split day {DAY_INDEX_TO_RUN} into time slices.")

    except FileNotFoundError:
        print(f"Error: `{data_filename}` not found.")
        print("Please run the 3D data preparation script first.")
        exit()
    except Exception as e:
        print(f"Error loading or processing data: {e}")
        exit()

    # --- 1. Pre-compute Sample Periodogram (NO Tapering) ---
    print(f"Pre-computing sample periodogram for Day Index {DAY_INDEX_TO_RUN} (NO data taper)...")
    
    # ✅ Pass the NEW list of hourly tensors
    J_vec, n1, n2, p = generate_Jvector_no_taper(
        time_slices_list,  # This is now correctly formatted
        lat_col=LAT_COL, lon_col=LON_COL, val_col=VAL_COL,
        device=DEVICE
    )

    if J_vec.numel() == 0 or n1 == 0 or n2 == 0 or p == 0:
       print(f"Error: J-vector generation failed for Day Index {DAY_INDEX_TO_RUN}.")
       print(f"Number of time slices: {len(time_slices_list)}")
       print(f"Shapes of time slices: {[t.shape for t in time_slices_list if t is not None]}")
       exit()

    I_sample = calculate_sample_periodogram_vectorized(J_vec)

    if torch.isnan(I_sample).any() or torch.isinf(I_sample).any():
        print("Error: NaN or Inf detected in the sample periodogram. Cannot proceed.")
        exit()

    # p now represents the number of hours used (e.g., 7)
    print(f"Data grid: {n1}x{n2} spatial points, {p} time points (hours). Sample Periodogram on {DEVICE}.")

    # --- 2. Optimization Loop (Unchanged from here down) ---
    all_final_results = []
    all_final_losses = []

    for i in range(NUM_RUNS):
        print(f"\n{'='*30} Initialization Run {i+1}/{NUM_RUNS} {'='*30}")

        # ✅ --- Use Fixed Initial Parameters ---
        initial_params_values = [
            np.log(21.303), np.log(1.307), np.log(1.563),
            0.022, -0.144, 0.198,
            np.log(4.769)
        ]

        print(f"Starting with fixed params (log-scale for [0,1,2,6]): {[round(p, 4) for p in initial_params_values]}")

        params_list = [
            Parameter(torch.tensor([val], dtype=torch.float32))
            for val in initial_params_values
        ]

        lr_slow, lr_fast = 0.005, 0.02
        slow_indices = [0, 1, 2, 6]
        fast_indices = [3, 4, 5]

        valid_slow_indices = [idx for idx in slow_indices if idx < len(params_list)]
        valid_fast_indices = [idx for idx in fast_indices if idx < len(params_list)]

        param_groups = [
            {'params': [params_list[idx] for idx in valid_slow_indices], 'lr': lr_slow, 'name': 'slow_group'},
            {'params': [params_list[idx] for idx in valid_fast_indices], 'lr': lr_fast, 'name': 'fast_group'}
        ]

        optimizer = torch.optim.Adam(param_groups)
        scheduler = CosineAnnealingLR(optimizer, T_max=200, eta_min=1e-6)

        print(f"Starting optimization run {i+1} on device {DEVICE} (NO data taper, 3D Diff)...")
        final_results, epochs_run = run_full(
            params_list=params_list,
            optimizer=optimizer,
            scheduler=scheduler,
            I_sample=I_sample,
            n1=n1, n2=n2, p=p, # p is now number of hours
            epochs=EPOCHS,
            device=DEVICE
        )

        if final_results:
            all_final_results.append(final_results)
            all_final_losses.append(final_results[-1])
        else:
            all_final_results.append(None)
            all_final_losses.append(float('inf'))

    print(f"\n\n{'='*25} Overall Result from Run {'='*25}")
    valid_losses = [l for l in all_final_losses if l is not None and l != float('inf')]

    if not valid_losses:
        print(f"The run failed or resulted in an invalid loss for Day Index {DAY_INDEX_TO_RUN}.")
    else:
        best_loss = valid_losses[0]
        best_run_index = 0
        best_results = all_final_results[best_run_index]
        print(f"Run Loss: {best_results[-1]}")
        print(f"Final Parameters (Natural Scale): {best_results[:-1]}")

    end_time = time.time()
    print(f"\nTotal execution time: {end_time - start_time:.2f} seconds")

Using device: cpu
Loaded 5 days from filtered_3d_sparse_data_2024_07.pkl.
Pre-computing sample periodogram for Day Index 0 (NO data taper)...
Data grid: 113x110 spatial points, 7 time points (hours). Sample Periodogram on cpu.

Starting with fixed params (log-scale for [0,1,2,6]): [3.0588, 0.2677, 0.4466, 0.022, -0.144, 0.198, 1.5621]
Starting optimization run 1 on device cpu (NO data taper, 3D Diff)...
--- Epoch 1/700 (LR: 0.005000) ---
 Loss: 111587.6562
 Parameters (Natural Scale): [21.1967  1.3136  1.5708  0.042  -0.124   0.178   4.7452]
--- Epoch 51/700 (LR: 0.004240) ---
 Loss: 107623.2734
 Parameters (Natural Scale): [ 1.84694e+01  1.52260e+00  1.78040e+00  1.54000e-02  2.62000e-02
 -7.90000e-03  4.00510e+00]
--- Epoch 101/700 (LR: 0.002461) ---
 Loss: 107617.1875
 Parameters (Natural Scale): [ 1.88965e+01  1.59290e+00  1.61780e+00  5.80000e-03  2.71000e-02
 -2.00000e-03  4.00710e+00]
--- Epoch 151/700 (LR: 0.000706) ---
 Loss: 107613.0859
 Parameters (Natural Scale): [ 1.94704e

## Time twice + space once

In [3]:
import torch
import numpy as np
import torch.nn.functional as F
import os
import pickle

# --- Helper Functions (Unchanged) ---

def subset_tensor(df_tensor: torch.Tensor) -> torch.Tensor:
    """
    Subsets a tensor to a specific lat/lon range.
    Columns are assumed to be [lat, lon, ozone, time].
    """
    lat_mask = (df_tensor[:, 0] >= 0) & (df_tensor[:, 0] <= 5)
    lon_mask = (df_tensor[:, 1] >= 123) & (df_tensor[:, 1] <= 133)
    
    df_sub = df_tensor[lat_mask & lon_mask].clone()
    return df_sub

def apply_spatial_diff_sparse(df_tensor: torch.Tensor) -> torch.Tensor:
    """
    Applies the 2D spatial filter Z(s) = X(s+d_lat) + X(s+d_lon) - 2X(s)
    to sparse data using a dictionary lookup.
    """
    if df_tensor.size(0) == 0:
        return torch.empty(0, 4)

    unique_lats = torch.unique(df_tensor[:, 0])
    unique_lons = torch.unique(df_tensor[:, 1])

    if len(unique_lats) < 2 or len(unique_lons) < 2:
        raise ValueError("Not enough unique lat/lon points to find grid spacing.")

    sorted_lats = torch.sort(unique_lats)[0]
    sorted_lons = torch.sort(unique_lons)[0]
    delta_lat = round((sorted_lats[1:] - sorted_lats[:-1]).min().item(), 5)
    delta_lon = round((sorted_lons[1:] - sorted_lons[:-1]).min().item(), 5)

    if delta_lat == 0 or delta_lon == 0:
        raise ValueError("Could not determine a valid grid spacing.")

    data_lookup = {
        (round(row[0].item(), 5), round(row[1].item(), 5)): row[2].item() 
        for row in df_tensor
    }
    time_value = df_tensor[0, 3] 

    filtered_rows = []
    for (lat, lon), val_s in data_lookup.items():
        val_s_lat = data_lookup.get((round(lat + delta_lat, 5), lon))
        val_s_lon = data_lookup.get((lat, round(lon + delta_lon, 5)))

        if val_s_lat is not None and val_s_lon is not None:
            diff_val = val_s_lat + val_s_lon - (2 * val_s)
            new_row = torch.tensor([lat, lon, diff_val, time_value])
            filtered_rows.append(new_row)

    if not filtered_rows:
        return torch.empty(0, 4)

    return torch.stack(filtered_rows, dim=0)


# --- Data Loading (Unchanged) ---
# ⚠️ NOTE: You must define these variables in your environment
# mac_data_path = "..."
# year = 2022
# month_str = "01"
# class GEMS_TCO: # Placeholder
#     def load_data(self, path): return self
#     def load_working_data_byday_wo_mm(self, data, indices):
#         return {'key': torch.randn(100, 4)}, torch.randn(100, 4)

# (Assuming data loading variables are defined...)
pickle_path = os.path.join(mac_data_path, f'pickle_{year}')
output_filename = f"coarse_cen_map_without_decrement_latitude{str(year)[2:]}_{month_str}.pkl"
output_filepath = os.path.join(pickle_path, output_filename)
print(f"Loading data from: {output_filepath}")

with open(output_filepath, 'rb') as pickle_file:
    cbmap_ori = pickle.load(pickle_file)

load_data_instance = GEMS_TCO.load_data('')
df_day_map_list = []
for i in range(31): # Adjust if necessary
    cur_map, _ = load_data_instance.load_working_data_byday_wo_mm(cbmap_ori, [i*8, (i+1)*8])
    df_day_map_list.append(cur_map)
print(f"Loaded {len(df_day_map_list)} days of raw data.")


# --- Main Processing Loop ---

# ✅ STAGE 1: Apply the spatial filter (Unchanged)
spatially_filtered_days = [] # Result: Z(s,t)

print("Starting STAGE 1: Spatial Differencing...")
for day_idx, day_map in enumerate(df_day_map_list):
    tensors_to_aggregate = []
    for key, tensor in day_map.items():
        subsetted = subset_tensor(tensor)
        if subsetted.size(0) > 0:
            try:
                diff_applied = apply_spatial_diff_sparse(subsetted)
                if diff_applied.size(0) > 0:
                    tensors_to_aggregate.append(diff_applied)
            except ValueError as e:
                print(f"Skipping chunk on day {day_idx+1}, key {key}: {e}")
    if tensors_to_aggregate:
        aggregated_day_tensor = torch.cat(tensors_to_aggregate, dim=0)
        spatially_filtered_days.append(aggregated_day_tensor)
print(f"STAGE 1 Complete. Created {len(spatially_filtered_days)} spatially filtered day-tensors.")

# ✅ STAGE 2: Apply the FIRST temporal difference (Unchanged logic, variable renamed)
first_temporal_diff_tensors = [] # Result: Y1(s,t) = Z(s,t) - Z(s,t-1)

print("Starting STAGE 2: First Temporal Differencing...")
if len(spatially_filtered_days) > 1:
    for i in range(1, len(spatially_filtered_days)):
        prev_day_tensor = spatially_filtered_days[i-1] # Z(s, t-1)
        current_day_tensor = spatially_filtered_days[i] # Z(s, t)
        
        prev_day_lookup = {
            (round(row[0].item(), 5), round(row[1].item(), 5)): row[2].item() 
            for row in prev_day_tensor
        }
        
        temporally_differenced_rows = []
        for row in current_day_tensor:
            lat = round(row[0].item(), 5)
            lon = round(row[1].item(), 5)
            
            if (lat, lon) in prev_day_lookup:
                current_Z = row[2].item()
                prev_Z = prev_day_lookup[(lat, lon)]
                Y1_value = current_Z - prev_Z # First difference
                
                new_row = torch.tensor([lat, lon, Y1_value, row[3]])
                temporally_differenced_rows.append(new_row)
        
        if temporally_differenced_rows:
            day_diff_tensor = torch.stack(temporally_differenced_rows, dim=0)
            first_temporal_diff_tensors.append(day_diff_tensor)
print(f"STAGE 2 Complete. Created {len(first_temporal_diff_tensors)} first-temporal-difference tensors.")

# ✅ STAGE 3: Apply the SECOND temporal difference
second_temporal_diff_tensors = [] # Result: Y2(s,t) = Y1(s,t) - Y1(s,t-1)

print("Starting STAGE 3: Second Temporal Differencing...")
# We need at least two Y1 tensors to compute the second difference
if len(first_temporal_diff_tensors) > 1:
    # Iterate from the second Y1 tensor (which corresponds to day index i=2 of Z)
    for i in range(1, len(first_temporal_diff_tensors)):
        prev_Y1_tensor = first_temporal_diff_tensors[i-1] # Y1(s, t-1)
        current_Y1_tensor = first_temporal_diff_tensors[i]  # Y1(s, t)
        
        # Create lookup for the previous Y1 values
        prev_Y1_lookup = {
            (round(row[0].item(), 5), round(row[1].item(), 5)): row[2].item() 
            for row in prev_Y1_tensor
        }
        
        second_diff_rows = []
        for row in current_Y1_tensor:
            lat = round(row[0].item(), 5)
            lon = round(row[1].item(), 5)
            
            if (lat, lon) in prev_Y1_lookup:
                current_Y1 = row[2].item()
                prev_Y1 = prev_Y1_lookup[(lat, lon)]
                Y2_value = current_Y1 - prev_Y1 # Second difference
                
                # Keep coordinates and the time stamp of the *current* Y1 tensor
                new_row = torch.tensor([lat, lon, Y2_value, row[3]])
                second_diff_rows.append(new_row)
        
        if second_diff_rows:
            day_second_diff_tensor = torch.stack(second_diff_rows, dim=0)
            second_temporal_diff_tensors.append(day_second_diff_tensor)
print(f"STAGE 3 Complete. Created {len(second_temporal_diff_tensors)} second-temporal-difference tensors.")

# --- Verification ---
print("\n--- Results ---")
print("Number of spatially filtered day tensors (Z):", len(spatially_filtered_days))
print("Number of first-temporal difference tensors (Y1):", len(first_temporal_diff_tensors))
print("Number of final second-temporal difference tensors (Y2):", len(second_temporal_diff_tensors))

if second_temporal_diff_tensors:
    # Save the final processed data
    processed_output_path = "spacetime_second_diff_data.pkl" # New filename
    with open(processed_output_path, 'wb') as f:
        pickle.dump(second_temporal_diff_tensors, f)
    print(f"Processed data saved to {processed_output_path}")

    print("\nShape of the first final tensor (Y2):", second_temporal_diff_tensors[0].shape)
    print("First final tensor head (Y2):")
    print(second_temporal_diff_tensors[0][:5])
else:
    print("\nNo final second-differenced tensors were created. Check data or filter logic.")

Loading data from: /Users/joonwonlee/Documents/GEMS_DATA/pickle_2024/coarse_cen_map_without_decrement_latitude24_07.pkl
Loaded 31 days of raw data.
Starting STAGE 1: Spatial Differencing...
STAGE 1 Complete. Created 31 spatially filtered day-tensors.
Starting STAGE 2: First Temporal Differencing...
STAGE 2 Complete. Created 30 first-temporal-difference tensors.
Starting STAGE 3: Second Temporal Differencing...
STAGE 3 Complete. Created 29 second-temporal-difference tensors.

--- Results ---
Number of spatially filtered day tensors (Z): 31
Number of first-temporal difference tensors (Y1): 30
Number of final second-temporal difference tensors (Y2): 29
Processed data saved to spacetime_second_diff_data.pkl

Shape of the first final tensor (Y2): torch.Size([7232, 4])
First final tensor head (Y2):
tensor([[  4.9320, 132.4170,   3.0988,  69.0000],
        [  4.9320, 131.5980,  -5.1850,  69.0000],
        [  4.9320, 131.2830,  -9.9230,  69.0000],
        [  4.9320, 130.7790,   8.8235,  69.000

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt # Keep if plotting might be added later
import cmath
import pickle
import time # For timing
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.nn import Parameter
import pandas as pd # Make sure pandas is imported
import os # Make sure os is imported

# =========================================================================
# 1. Modeling Functions (Adapted for Spat + 2x Temp Differencing)
# =========================================================================

# --- Bartlett Kernel (Used for c_gn when g_s=1) ---
def cgn_2dbartlett_kernel(u1, u2, n1, n2):
    """
    Computes the 2D Bartlett kernel: Product(1 - |ui|/ni). (Unchanged)
    """
    device = u1.device if isinstance(u1, torch.Tensor) else (u2.device if isinstance(u2, torch.Tensor) else torch.device('cpu'))
    u1_tensor = u1.to(device) if isinstance(u1, torch.Tensor) else torch.tensor(u1, device=device, dtype=torch.float32)
    u2_tensor = u2.to(device) if isinstance(u2, torch.Tensor) else torch.tensor(u2, device=device, dtype=torch.float32)
    n1_eff = float(n1) if n1 > 0 else 1.0
    n2_eff = float(n2) if n2 > 0 else 1.0
    kernel = (1.0 - torch.abs(u1_tensor) / n1_eff) * (1.0 - torch.abs(u2_tensor) / n2_eff)
    return torch.clamp(kernel, min=0.0)

# --- Covariance of the Original Field X (EXPONENTIAL Kernel) ---
def cov_x_exponential(u1, u2, t, params):
    """
    Computes the autocovariance of the ORIGINAL process X. (Unchanged)
    Expects log-scale params [0,1,2,6].
    """
    device = params.device 
    u1_dev = u1.to(device) if isinstance(u1, torch.Tensor) else torch.tensor(u1, device=device, dtype=torch.float32)
    u2_dev = u2.to(device) if isinstance(u2, torch.Tensor) else torch.tensor(u2, device=device, dtype=torch.float32)
    t_dev = t.to(device) if isinstance(t, torch.Tensor) else torch.tensor(t, device=device, dtype=torch.float32)

    log_params_indices = [0, 1, 2, 6]
    if torch.isnan(params[log_params_indices]).any() or torch.isinf(params[log_params_indices]).any():
         out_shape = torch.broadcast_shapes(u1_dev.shape, u2_dev.shape, t_dev.shape)
         return torch.full(out_shape, float('nan'), device=device, dtype=torch.float32)

    sigmasq, r_lat, r_lon, nugget = torch.exp(params[log_params_indices])
    a_lat, a_lon, beta = params[3], params[4], params[5]

    r_lat = torch.clamp(r_lat, min=1e-6)
    r_lon = torch.clamp(r_lon, min=1e-6)

    x1 = u1_dev / r_lat - a_lat * t_dev
    x2 = u2_dev / r_lon - a_lon * t_dev
    x3 = beta * t_dev
    distance_sq = x1**2 + x2**2 + x3**2
    epsilon = 1e-12
    distance_sq_clamped = torch.clamp(distance_sq, min=0.0)
    D = torch.sqrt(distance_sq_clamped + epsilon) 
    cov_smooth = sigmasq * torch.exp(-D) 

    is_zero_lag = (torch.abs(u1_dev) < 1e-9) & (torch.abs(u2_dev) < 1e-9) & (torch.abs(t_dev) < 1e-9)
    final_cov = torch.where(is_zero_lag, cov_smooth + nugget, cov_smooth)

    if torch.isnan(final_cov).any():
        print("Warning: NaN detected in cov_x_exponential output.")
    return final_cov


# --- Covariance of the Spatially Differenced Field Z ---
def cov_spatial_difference(u1, u2, t, params, delta1, delta2):
    """
    Calculates covariance Cov(Z(s), Z(s+u)) for the SPATIAL-ONLY filter:
    Z(s) = X(s+d1) + X(s+d2) - 2X(s). (Unchanged)
    """
    weights = {(0, 0): -2.0, (1, 0): 1.0, (0, 1): 1.0}
    device = params.device
    out_shape = torch.broadcast_shapes(u1.shape if isinstance(u1, torch.Tensor) else (),
                                     u2.shape if isinstance(u2, torch.Tensor) else (),
                                     t.shape if isinstance(t, torch.Tensor) else ())
    cov = torch.zeros(out_shape, device=device, dtype=torch.float32)

    u1_dev = u1.to(device) if isinstance(u1, torch.Tensor) else torch.tensor(u1, device=device, dtype=torch.float32)
    u2_dev = u2.to(device) if isinstance(u2, torch.Tensor) else torch.tensor(u2, device=device, dtype=torch.float32)
    t_dev = t.to(device) if isinstance(t, torch.Tensor) else torch.tensor(t, device=device, dtype=torch.float32)

    for (a_idx, b_idx), w_ab in weights.items():
        offset_a1 = a_idx * delta1
        offset_a2 = b_idx * delta2
        for (c_idx, d_idx), w_cd in weights.items():
            offset_c1 = c_idx * delta1
            offset_c2 = d_idx * delta2
            lag_u1 = u1_dev + (offset_a1 - offset_c1)
            lag_u2 = u2_dev + (offset_a2 - offset_c2)
            term_cov = cov_x_exponential(lag_u1, lag_u2, t_dev, params)
            if torch.isnan(term_cov).any():
                 return torch.full_like(cov, float('nan'))
            cov += w_ab * w_cd * term_cov

    if torch.isnan(cov).any():
        print("Warning: NaN detected in final cov_spatial_difference output.")
    return cov

# --- Covariance of Spatio-Temporal (First Temp Diff) Field Y1 ---
def cov_spacetime_difference(u1, u2, t, params, delta1, delta2):
    """
    Calculates covariance for Y1(s,t) = Z(s,t) - Z(s,t-1). (Unchanged from previous)
    """
    device = params.device
    u1_dev = u1.to(device) if isinstance(u1, torch.Tensor) else torch.tensor(u1, device=device, dtype=torch.float32)
    u2_dev = u2.to(device) if isinstance(u2, torch.Tensor) else torch.tensor(u2, device=device, dtype=torch.float32)
    t_dev = t.to(device) if isinstance(t, torch.Tensor) else torch.tensor(t, device=device, dtype=torch.float32)

    term_center = cov_spatial_difference(u1_dev, u2_dev, t_dev, params, delta1, delta2)
    term_minus_1 = cov_spatial_difference(u1_dev, u2_dev, t_dev - 1.0, params, delta1, delta2)
    term_plus_1 = cov_spatial_difference(u1_dev, u2_dev, t_dev + 1.0, params, delta1, delta2)

    if torch.isnan(term_center).any() or torch.isnan(term_minus_1).any() or torch.isnan(term_plus_1).any():
        out_shape = torch.broadcast_shapes(u1_dev.shape, u2_dev.shape, t_dev.shape)
        return torch.full(out_shape, float('nan'), device=device, dtype=torch.float32)

    return 2.0 * term_center - term_minus_1 - term_plus_1

# --- (NEW) Covariance of Spatio-Temporal (Second Temp Diff) Field Y2 ---
def cov_spacetime_second_diff(u1, u2, t, params, delta1, delta2):
    """
    ✅ Calculates covariance for the new Spat + 2x Temp filter:
    Y2(s,t) = Y1(s,t) - Y1(s,t-1).
    This is C_Y2(u, tau) = 2*C_Y1(u, tau) - C_Y1(u, tau-1) - C_Y1(u, tau+1)
    """
    device = params.device
    u1_dev = u1.to(device) if isinstance(u1, torch.Tensor) else torch.tensor(u1, device=device, dtype=torch.float32)
    u2_dev = u2.to(device) if isinstance(u2, torch.Tensor) else torch.tensor(u2, device=device, dtype=torch.float32)
    t_dev = t.to(device) if isinstance(t, torch.Tensor) else torch.tensor(t, device=device, dtype=torch.float32)

    # C_Y1(u, tau)
    term_center = cov_spacetime_difference(u1_dev, u2_dev, t_dev, params, delta1, delta2)
    # C_Y1(u, tau - 1)
    term_minus_1 = cov_spacetime_difference(u1_dev, u2_dev, t_dev - 1.0, params, delta1, delta2)
    # C_Y1(u, tau + 1)
    term_plus_1 = cov_spacetime_difference(u1_dev, u2_dev, t_dev + 1.0, params, delta1, delta2)

    if torch.isnan(term_center).any() or torch.isnan(term_minus_1).any() or torch.isnan(term_plus_1).any():
        print("Warning: NaN detected in one of the terms of cov_spacetime_second_diff")
        out_shape = torch.broadcast_shapes(u1_dev.shape, u2_dev.shape, t_dev.shape)
        return torch.full(out_shape, float('nan'), device=device, dtype=torch.float32)

    return 2.0 * term_center - term_minus_1 - term_plus_1


# --- (MODIFIED) cn_bar for NO TAPERING ---
def cn_bar_no_taper(u1, u2, t, params, n1, n2, delta1, delta2):
    """
    Computes c_Y2(u) * c_gn(u) where c_Y2 is cov_spacetime_second_diff
    and c_gn(u) is the Bartlett kernel.
    """
    device = params.device
    u1_dev = u1.to(device) if isinstance(u1, torch.Tensor) else torch.tensor(u1, device=device, dtype=torch.float32)
    u2_dev = u2.to(device) if isinstance(u2, torch.Tensor) else torch.tensor(u2, device=device, dtype=torch.float32)
    t_dev = t.to(device) if isinstance(t, torch.Tensor) else torch.tensor(t, device=device, dtype=torch.float32)

    # ✅ --- Call the NEW second-difference spatio-temporal covariance ---
    cov_Y2_value = cov_spacetime_second_diff(u1_dev, u2_dev, t_dev, params, delta1, delta2)

    c_gn_value = cgn_2dbartlett_kernel(u1_dev, u2_dev, n1, n2)

    if torch.isnan(cov_Y2_value).any() or torch.isnan(c_gn_value).any():
        print("Warning: NaN detected before multiplication in cn_bar_no_taper.")
        out_shape = torch.broadcast_shapes(cov_Y2_value.shape, c_gn_value.shape)
        return torch.full(out_shape, float('nan'), device=device, dtype=torch.float32)

    result = cov_Y2_value * c_gn_value
    if torch.isnan(result).any():
        print("Warning: NaN detected after multiplication in cn_bar_no_taper.")
    return result


# --- Expected Periodogram (uses cn_bar_no_taper) ---
def expected_periodogram_fft_no_taper(params, n1, n2, p, delta1, delta2):
    """
    Calculates the expected periodogram. (Unchanged structure)
    This function is correct because it calls the modified cn_bar_no_taper.
    """
    device = params.device if isinstance(params, torch.Tensor) else params[0].device
    params_tensor = params.to(device)

    product_tensor = torch.zeros((n1, n2, p, p), dtype=torch.complex64, device=device)
    t_lags = torch.arange(p, dtype=torch.float32, device=device)
    u1_mesh_grid, u2_mesh_grid = torch.meshgrid(
        torch.arange(n1, dtype=torch.float32, device=device),
        torch.arange(n2, dtype=torch.float32, device=device),
        indexing='ij'
    )

    for q in range(p):
        for r in range(p):
            t_diff = t_lags[q] - t_lags[r]
            # This call now correctly leads to cov_spacetime_second_diff
            cov_times_bartlett = cn_bar_no_taper(
                u1_mesh_grid, u2_mesh_grid, t_diff,
                params_tensor, n1, n2, delta1, delta2
            )
            if torch.isnan(cov_times_bartlett).any():
                 product_tensor[:, :, q, r] = float('nan')
            else:
                 product_tensor[:, :, q, r] = cov_times_bartlett.to(torch.complex64)

    if torch.isnan(product_tensor).any():
        print("Warning: NaN detected in product_tensor before FFT.")
        nan_shape = (n1, n2, p, p)
        return torch.full(nan_shape, float('nan'), dtype=torch.complex64, device=device)

    fft_result = torch.fft.fft2(product_tensor, dim=(0, 1))
    normalization_factor = 1.0 / (4.0 * cmath.pi**2)
    result = fft_result * normalization_factor

    if torch.isnan(result).any():
        print("Warning: NaN detected in expected_periodogram_fft_no_taper output after FFT.")
    return result


# =========================================================================
# 2. Data Processing (Unchanged)
# =========================================================================
def generate_Jvector_no_taper(tensor_list, lat_col, lon_col, val_col, device):
    """
    Generates J-vector for g_s=1 (NO taper). (Unchanged)
    """
    p = len(tensor_list)
    if p == 0: return torch.empty(0, 0, 0, device=device), 0, 0, 0

    valid_tensors = [t for t in tensor_list if t.numel() > 0 and t.shape[1] > max(lat_col, lon_col, val_col)]
    if not valid_tensors:
         return torch.empty(0, 0, 0, device=device), 0, 0, 0

    try:
        all_lats_cpu = torch.cat([t[:, lat_col] for t in valid_tensors])
        all_lons_cpu = torch.cat([t[:, lon_col] for t in valid_tensors])
    except IndexError:
        return torch.empty(0, 0, 0, device=device), 0, 0, 0

    all_lats_cpu = all_lats_cpu[~torch.isnan(all_lats_cpu) & ~torch.isinf(all_lats_cpu)]
    all_lons_cpu = all_lons_cpu[~torch.isnan(all_lons_cpu) & ~torch.isinf(all_lons_cpu)]
    if all_lats_cpu.numel() == 0 or all_lons_cpu.numel() == 0:
        return torch.empty(0, 0, 0, device=device), 0, 0, 0

    unique_lats_cpu, unique_lons_cpu = torch.unique(all_lats_cpu), torch.unique(all_lons_cpu)
    n1, n2 = len(unique_lats_cpu), len(unique_lons_cpu)
    if n1 == 0 or n2 == 0:
        return torch.empty(0, 0, 0, device=device), 0, 0, 0

    lat_map = {lat.item(): i for i, lat in enumerate(unique_lats_cpu)}
    lon_map = {lon.item(): i for i, lon in enumerate(unique_lons_cpu)}

    fft_results = []
    for tensor in tensor_list:
        data_grid = torch.zeros((n1, n2), dtype=torch.float32, device=device)
        tensor_dev = tensor.to(device)
        for row in tensor_dev:
            lat_item, lon_item = row[lat_col].item(), row[lon_col].item()
            if not (np.isnan(lat_item) or np.isnan(lon_item)):
                i = lat_map.get(lat_item)
                j = lon_map.get(lon_item)
                if i is not None and j is not None:
                    val = row[val_col]
                    val_num = val.item() if isinstance(val, torch.Tensor) else val
                    if not np.isnan(val_num) and not np.isinf(val_num):
                        data_grid[i, j] = val_num

        data_grid = torch.nan_to_num(data_grid, nan=0.0, posinf=0.0, neginf=0.0)
        fft_results.append(torch.fft.fft2(data_grid))

    if not fft_results:
         return torch.empty(0, 0, 0, device=device), n1, n2, 0

    J_vector_tensor = torch.stack(fft_results, dim=2).to(device)

    H = float(n1 * n2)
    if H < 1e-9:
        norm_factor = torch.tensor(0.0, device=device)
    else:
        norm_factor = (torch.sqrt(torch.tensor(1.0 / H, device=device)) / (2.0 * cmath.pi))

    result = J_vector_tensor * norm_factor
    return result, n1, n2, p


def calculate_sample_periodogram_vectorized(J_vector_tensor):
    """Calculates sample periodogram I_n = J J^H. (Unchanged)"""
    if torch.isnan(J_vector_tensor).any() or torch.isinf(J_vector_tensor).any():
        n1, n2, p = J_vector_tensor.shape
        return torch.full((n1, n2, p, p), float('nan'), dtype=torch.complex64, device=J_vector_tensor.device)

    J_col = J_vector_tensor.unsqueeze(-1)
    J_row_conj = J_vector_tensor.unsqueeze(-2).conj()
    result = J_col @ J_row_conj
    return result


# =========================================================================
# 4. Likelihood Calculation (Unchanged)
# =========================================================================

def whittle_likelihood_loss_no_taper(params, I_sample, n1, n2, p, delta1, delta2):
    """
    Whittle Likelihood Loss. (Unchanged structure)
    """
    device = I_sample.device
    params_tensor = params.to(device)

    if torch.isnan(params_tensor).any() or torch.isinf(params_tensor).any():
        return torch.tensor(float('nan'), device=device)

    # This call now correctly leads to cov_spacetime_second_diff
    I_expected = expected_periodogram_fft_no_taper(
        params_tensor, n1, n2, p, delta1, delta2
    )

    if torch.isnan(I_expected).any() or torch.isinf(I_expected).any():
        return torch.tensor(float('nan'), device=device)

    eye_matrix = torch.eye(p, dtype=torch.complex64, device=device)
    diag_vals = torch.abs(I_expected.diagonal(dim1=-2, dim2=-1))
    mean_diag_abs = diag_vals.mean().item() if diag_vals.numel() > 0 and not torch.isnan(diag_vals).all() else 1.0
    diag_load = max(mean_diag_abs * 1e-8, 1e-9) 
    
    I_expected_stable = I_expected + eye_matrix * diag_load

    sign, logabsdet = torch.linalg.slogdet(I_expected_stable)
    if torch.any(sign.real <= 1e-9):
        log_det_term = torch.where(sign.real > 1e-9, logabsdet, torch.tensor(1e10, device=device))
    else:
        log_det_term = logabsdet

    if torch.isnan(I_sample).any() or torch.isinf(I_sample).any():
        return torch.tensor(float('nan'), device=device)

    try:
        solved_term = torch.linalg.solve(I_expected_stable, I_sample)
        trace_term = torch.einsum('...ii->...', solved_term).real
    except torch.linalg.LinAlgError as e:
        return torch.tensor(float('inf'), device=device)

    if torch.isnan(trace_term).any() or torch.isinf(trace_term).any():
        return torch.tensor(float('nan'), device=device)

    likelihood_terms = log_det_term + trace_term

    if torch.isnan(likelihood_terms).any():
        return torch.tensor(float('nan'), device=device)

    total_sum = torch.sum(likelihood_terms)
    dc_term = likelihood_terms[0, 0] if n1 > 0 and n2 > 0 else torch.tensor(0.0, device=device)
    if torch.isnan(dc_term).any() or torch.isinf(dc_term).any():
        dc_term = torch.tensor(0.0, device=device)

    loss = total_sum - dc_term if (n1 > 1 or n2 > 1) else total_sum

    if torch.isnan(loss) or torch.isinf(loss):
         return torch.tensor(float('inf'), device=device) 

    return loss


# =========================================================================
# 5. Training Loop (Unchanged)
# =========================================================================
def run_full(params_list, optimizer, scheduler, I_sample, n1, n2, p, epochs=600, device='cpu'):
    """Corrected training loop. (Unchanged)"""
    best_loss = float('inf')
    params_list = [p.to(device) for p in params_list]
    best_params_state = [p.detach().clone() for p in params_list]
    epochs_completed = 0
    DELTA_LAT, DELTA_LON = 0.044, 0.063 # Still needed

    def get_printable_params(p_list):
        valid_tensors = [p for p in p_list if isinstance(p, torch.Tensor)]
        if not valid_tensors: return "Invalid params_list"
        p_cat = torch.cat([p.detach().clone().cpu() for p in valid_tensors])
        log_indices = [0, 1, 2, 6]
        if all(idx < len(p_cat) for idx in log_indices):
            log_vals = p_cat[log_indices]
            if not (torch.isnan(log_vals).any() or torch.isinf(log_vals).any()):
                 p_cat[log_indices] = torch.exp(log_vals)
            else:
                 p_cat[log_indices] = float('nan')
        return p_cat.numpy().round(4)

    I_sample_dev = I_sample.to(device)

    for epoch in range(epochs):
        epochs_completed = epoch + 1
        optimizer.zero_grad()
        params_tensor = torch.cat(params_list)

        loss = whittle_likelihood_loss_no_taper(
            params_tensor, I_sample_dev, n1, n2, p, DELTA_LAT, DELTA_LON
        )

        if torch.isnan(loss) or torch.isinf(loss):
            print(f"Loss became NaN or Inf at epoch {epoch+1}. Stopping.")
            if epoch == 0: best_params_state = None
            epochs_completed = epoch
            break 

        loss.backward()

        nan_grad = False
        for param in params_list:
            if param.grad is not None and (torch.isnan(param.grad).any() or torch.isinf(param.grad).any()):
                nan_grad = True
                break
        if nan_grad:
             optimizer.zero_grad() 
             continue 

        all_params_on_device = params_list
        if all_params_on_device:
            torch.nn.utils.clip_grad_norm_(all_params_on_device, max_norm=1.0)

        optimizer.step()
        scheduler.step() 

        current_loss_item = loss.item()
        if current_loss_item < best_loss:
            params_valid = not any(torch.isnan(p.data).any() or torch.isinf(p.data).any() for p in params_list)
            if params_valid:
                best_loss = current_loss_item
                best_params_state = [p.detach().clone() for p in params_list]

        if epoch % 50 == 0 or epoch == epochs - 1:
            current_lr = optimizer.param_groups[0]['lr'] if optimizer.param_groups else 0.0
            print(f'--- Epoch {epoch+1}/{epochs} (LR: {current_lr:.6f}) ---')
            print(f' Loss: {current_loss_item:.4f}')
            print(f' Parameters (Natural Scale): {get_printable_params(params_list)}')

    if best_params_state is None:
        return None, epochs_completed

    final_params_log_scale = torch.cat([p.cpu() for p in best_params_state])
    final_params_natural_scale = final_params_log_scale.detach().clone()
    log_indices = [0, 1, 2, 6]
    if all(idx < len(final_params_natural_scale) for idx in log_indices):
        log_vals = final_params_natural_scale[log_indices]
        if not (torch.isnan(log_vals).any() or torch.isinf(log_vals).any()):
            final_params_natural_scale[log_indices] = torch.exp(log_vals)
        else:
            final_params_natural_scale[log_indices] = float('nan')

    final_params_rounded = [round(p.item(), 4) if not np.isnan(p.item()) else float('nan') for p in final_params_natural_scale]
    final_loss_rounded = round(best_loss, 3) if best_loss != float('inf') else float('inf')

    print("\n--- Training Complete ---")
    print(f'\nFINAL BEST STATE ACHIEVED (during training):')
    print(f'Best Loss: {final_loss_rounded}')
    print(f'Parameters Corresponding to Best Loss (Natural Scale): {final_params_rounded}')

    return final_params_rounded + [final_loss_rounded], epochs_completed


# =========================================================================
# 6. Main Execution Script (MODIFIED for Spat + 2x Temp Data)
# =========================================================================
if __name__ == '__main__':
    start_time = time.time()

    # --- Configuration ---
    DAY_TO_RUN = 1
    NUM_RUNS = 1 
    EPOCHS = 700 
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {DEVICE}")

    # --- Grid Spacing ---
    DELTA_LAT, DELTA_LON = 0.044, 0.063 

    # --- Column Indices ---
    LAT_COL, LON_COL = 0, 1
    VAL_COL = 2
    TIME_COL = 3

    # --- (MODIFIED) Load Spat + 2x Temp Differenced Data ---
    try:
        # ✅ --- LOAD THE NEW FILE ---
        with open("spacetime_second_diff_data.pkl", 'rb') as f:
            processed_df = pickle.load(f)
        print(f"Loaded {len(processed_df)} days from spacetime_second_diff_data.pkl.")
        
        processed_df = [
            torch.tensor(arr, dtype=torch.float32).cpu() if not isinstance(arr, torch.Tensor)
            else arr.cpu().to(torch.float32)
            for arr in processed_df
        ]
        if not processed_df: raise ValueError("'processed_df' is empty.")
    except FileNotFoundError:
        print("Error: `spacetime_second_diff_data.pkl` not found.")
        print("Please run the data preparation script first.")
        exit()
    except Exception as e:
        print(f"Error loading or processing 'processed_df': {e}")
        exit()

    # Day selection needs care: 2 temporal diffs mean we lose 2 days
    # If DAY_TO_RUN refers to the *original* day index, adjust access
    # processed_df[0] corresponds to original day 3 (index 2)
    adjusted_day_index = DAY_TO_RUN - 3 # Example: if DAY_TO_RUN=3, access index 0

    # Ensure adjusted index is valid
    if adjusted_day_index < 0 or adjusted_day_index >= len(processed_df):
        print(f"Error: DAY_TO_RUN ({DAY_TO_RUN}) is invalid for the doubly differenced data (valid range approx 3 to {len(processed_df)+2}).")
        exit()

    cur_df = processed_df[adjusted_day_index] # Use adjusted index
    
    if cur_df.numel() == 0 or cur_df.shape[1] <= max(LAT_COL, LON_COL, VAL_COL, TIME_COL):
        print(f"Error: Data for Day {DAY_TO_RUN} (adjusted index {adjusted_day_index}) is empty or invalid.")
        exit()

    unique_times = torch.unique(cur_df[:, TIME_COL])
    time_slices_list = [cur_df[cur_df[:, TIME_COL] == t_val] for t_val in unique_times]

    # --- 1. Pre-compute Sample Periodogram (NO Tapering) ---
    print("Pre-computing sample periodogram (NO data taper)...")
    J_vec, n1, n2, p = generate_Jvector_no_taper(
        time_slices_list,
        lat_col=LAT_COL, lon_col=LON_COL, val_col=VAL_COL,
        device=DEVICE
    )

    if J_vec.numel() == 0 or n1 == 0 or n2 == 0 or p == 0:
       print(f"Error: J-vector generation failed for Day {DAY_TO_RUN} (adjusted index {adjusted_day_index}).")
       exit()

    I_sample = calculate_sample_periodogram_vectorized(J_vec)

    if torch.isnan(I_sample).any() or torch.isinf(I_sample).any():
        print("Error: NaN or Inf detected in the sample periodogram. Cannot proceed.")
        exit()

    print(f"Data grid: {n1}x{n2} spatial points, {p} time points. Sample Periodogram on {DEVICE}.")
    
    # --- 2. Optimization Loop ---
    all_final_results = []
    all_final_losses = []

    for i in range(NUM_RUNS):
        print(f"\n{'='*30} Initialization Run {i+1}/{NUM_RUNS} {'='*30}")

        # ✅ --- Use Fixed Initial Parameters ---
        initial_params_values = [
            np.log(21.303), np.log(1.307), np.log(1.563), 
            0.022, -0.144, 0.198, 
            np.log(4.769)  
        ]
        
        print(f"Starting with fixed params (log-scale for [0,1,2,6]): {[round(p, 4) for p in initial_params_values]}")

        params_list = [
            Parameter(torch.tensor([val], dtype=torch.float32))
            for val in initial_params_values
        ]

        lr_slow, lr_fast = 0.005, 0.02
        slow_indices = [0, 1, 2, 6]
        fast_indices = [3, 4, 5]

        valid_slow_indices = [idx for idx in slow_indices if idx < len(params_list)]
        valid_fast_indices = [idx for idx in fast_indices if idx < len(params_list)]

        param_groups = [
            {'params': [params_list[idx] for idx in valid_slow_indices], 'lr': lr_slow, 'name': 'slow_group'},
            {'params': [params_list[idx] for idx in valid_fast_indices], 'lr': lr_fast, 'name': 'fast_group'}
        ]

        optimizer = torch.optim.Adam(param_groups)
        scheduler = CosineAnnealingLR(optimizer, T_max=200, eta_min=1e-6)

        print(f"Starting optimization run {i+1} on device {DEVICE} (NO data taper, Spat + 2x Temp Diff)...")
        final_results, epochs_run = run_full(
            params_list=params_list,
            optimizer=optimizer,
            scheduler=scheduler,
            I_sample=I_sample,
            n1=n1, n2=n2, p=p,
            epochs=EPOCHS,
            device=DEVICE
        )

        if final_results:
            all_final_results.append(final_results)
            all_final_losses.append(final_results[-1])
        else:
            all_final_results.append(None)
            all_final_losses.append(float('inf'))

    print(f"\n\n{'='*25} Overall Result from Run {'='*25}")
    valid_losses = [l for l in all_final_losses if l is not None and l != float('inf')]

    if not valid_losses:
        print(f"The run failed or resulted in an invalid loss for Day {DAY_TO_RUN} (adjusted index {adjusted_day_index}).")
    else:
        best_loss = valid_losses[0]
        best_run_index = 0
        best_results = all_final_results[best_run_index]
        print(f"Run Loss: {best_results[-1]}")
        print(f"Final Parameters (Natural Scale): {best_results[:-1]}")

    end_time = time.time()
    print(f"\nTotal execution time: {end_time - start_time:.2f} seconds")

Using device: cpu
Loaded 29 days from spacetime_second_diff_data.pkl.
Error: DAY_TO_RUN (1) is invalid for the doubly differenced data (valid range approx 3 to 31).
Pre-computing sample periodogram (NO data taper)...
Data grid: 113x8 spatial points, 8 time points. Sample Periodogram on cpu.

Starting with fixed params (log-scale for [0,1,2,6]): [3.0588, 0.2677, 0.4466, 0.022, -0.144, 0.198, 1.5621]
Starting optimization run 1 on device cpu (NO data taper, Spat + 2x Temp Diff)...
--- Epoch 1/700 (LR: 0.005000) ---
 Loss: 140469.1250
 Parameters (Natural Scale): [21.4098  1.3005  1.5552  0.042  -0.124   0.178   4.7929]
--- Epoch 51/700 (LR: 0.004240) ---
 Loss: 111063.8125
 Parameters (Natural Scale): [ 2.6071e+01  1.0695e+00  1.2731e+00 -4.5000e-03  1.2600e-02  2.0000e-04
  5.6888e+00]
--- Epoch 101/700 (LR: 0.002461) ---
 Loss: 94843.8984
 Parameters (Natural Scale): [ 3.03786e+01  9.43800e-01  1.08220e+00  1.80000e-03  1.68000e-02
 -7.00000e-04  6.34220e+00]
--- Epoch 151/700 (LR: 0.0

: 