In [53]:
# Standard libraries
import sys
# Add your custom path
gems_tco_path = "/Users/joonwonlee/Documents/GEMS_TCO-1/src"
sys.path.append(gems_tco_path)
import logging
import argparse # Argument parsing

# Data manipulation and analysis
import pandas as pd
import numpy as np
import pickle
import torch
import torch.optim as optim
import copy                    # clone tensor
import time

# Custom imports
import GEMS_TCO
from GEMS_TCO import kernels

from GEMS_TCO import kernels 
from GEMS_TCO import orderings as _orderings 
from GEMS_TCO import load_data
from GEMS_TCO import alg_optimization, alg_opt_Encoder
from GEMS_TCO import configuration as config

from typing import Optional, List, Tuple
from pathlib import Path
import typer
import json
from json import JSONEncoder

from GEMS_TCO import configuration as config
from GEMS_TCO import data_preprocess as dmbh

import os
from sklearn.neighbors import BallTree

In [54]:
import pickle
import os
# Assume your 'config' object is available
# import config

# --- 1. Configuration ---
# Specify the year and month you want to load
YEAR_TO_LOAD = 2024
MONTH_TO_LOAD = 7

# Use the same base path as your saving script
BASE_PATH = config.mac_data_load_path

# --- 2. Construct the File Path ---
# This must exactly match the naming convention from your saving script
month_str = f"{MONTH_TO_LOAD:02d}"
pickle_path = os.path.join(BASE_PATH, f'pickle_{YEAR_TO_LOAD}')
filename = f"coarse_cen_map_without_decrement_latitude{str(YEAR_TO_LOAD)[2:]}_{month_str}.pkl"
filepath_to_load = os.path.join(pickle_path, filename)

print(f"Attempting to load data from: {filepath_to_load}")

# --- 3. Load the Data ---
try:
    with open(filepath_to_load, 'rb') as pickle_file:
        # Use pickle.load() to read the data from the file
        loaded_coarse_map = pickle.load(pickle_file)
    
    print("\nData loaded successfully! ✅")
    
    # --- 4. Verify the Loaded Data ---
    # The loaded data is a dictionary. Let's inspect it.
    print(f"Type of loaded data: {type(loaded_coarse_map)}")
    if isinstance(loaded_coarse_map, dict):
        print(f"Number of entries (hours) in the map: {len(loaded_coarse_map)}")
        # Print the first 5 keys to see what they look like
        first_five_keys = list(loaded_coarse_map.keys())[:5]
        print(f"Example keys: {first_five_keys}")
        
        # You can now access the data for a specific hour, for example:
        # first_hour_data = loaded_coarse_map[first_five_keys[0]]
        # print(f"\nData for first hour is a tensor of shape: {first_hour_data.shape}")

except FileNotFoundError:
    print(f"\nError: File not found. Please check if the file exists at the specified path.")
except Exception as e:
    print(f"\nAn error occurred: {e}")


print(loaded_coarse_map['y24m07day01_hm00:53']['Longitude'].nunique())
print(loaded_coarse_map['y24m07day01_hm00:53']['Latitude'].nunique())

import GEMS_TCO
load_data_instance = GEMS_TCO.load_data('')

df_day_aggregated_list = []
df_day_map_list = []
for i in range(31):
    cur_map, cur_df =load_data_instance.load_working_data_byday_wo_mm(loaded_coarse_map,[i*8, (i+1)*8])
    df_day_aggregated_list.append( cur_df )
    df_day_map_list.append( cur_map )

Attempting to load data from: /Users/joonwonlee/Documents/GEMS_DATA/pickle_2024/coarse_cen_map_without_decrement_latitude24_07.pkl

Data loaded successfully! ✅
Type of loaded data: <class 'dict'>
Number of entries (hours) in the map: 248
Example keys: ['y24m07day01_hm00:53', 'y24m07day01_hm01:53', 'y24m07day01_hm02:53', 'y24m07day01_hm03:53', 'y24m07day01_hm04:49']
270
273


In [55]:
lon_s = 123
lon_e = 133
step_lat = 0.044
step_lon = 0.063

lat_coords = np.arange( 5 -0.044- 0.0002, 0 -0.044, -0.044)
lon_coords = np.arange( lon_e-step_lon- 0.0002, lon_s-step_lon, -step_lon)

# Apply the shift as in the original code
# These are the unique lat/lon values for the "center_points" grid
final_lat_values = lat_coords + step_lat 
final_lon_values = lon_coords + step_lon 

# Create 2D grid with broadcasting
#decrement = 0.00012
decrement = 0 
lat_grid = final_lat_values[:, None] + np.arange(len(final_lon_values)) * decrement  # shape: (228, 152)


mac_data_path = config.mac_data_load_path
years = [2024]  # years = [2023,2024]
months = list( range(7,8))
year = years[0]
month = months[0]
month_str = f"{month:02d}"  
filename = f"pickle_2024/orbit_map{str(year)[2:]}_{month_str}.pkl"
picklefile_path = Path(mac_data_path) / filename
print(picklefile_path)

with open(picklefile_path, 'rb') as pickle_file:
    data_map_hour = pickle.load(pickle_file)

# Base file path and settings
# base_path = "C:\\Users\\joonw\\TCO\\GEMS_data"    MSI notebook

mac_data_path = config.mac_data_load_path
lat_start, lat_end, lon_start, lon_end = 0, 5, 123, 133
step_lat, step_lon = 0.044, 0.063

# df = pd.read_csv("C:\\Users\\joonw\\TCO\\GEMS_data\\data_2024\\data_24_07_0131_N510_E110120.csv")  MSI notebook
df = pd.read_csv("/Users/joonwonlee/Documents/GEMS_DATA/data_2024/data_24_07_0131_N05_E123133.csv")  # MAC

/Users/joonwonlee/Documents/GEMS_DATA/pickle_2024/orbit_map24_07.pkl


# before preceding, change cur_map, _ = load_data_instance.load_working_data_byday_wo_mm(cbmap_ori, [i*8, i*8+1])

In [69]:
import torch
import numpy as np
import torch.nn.functional as F
import os
import pickle

# Assume GEMS_TCO is a custom class/module you have available
# from your_project import GEMS_TCO

# =========================================================================
# 1. Helper Functions
# =========================================================================

def subset_tensor(df_tensor: torch.Tensor) -> torch.Tensor:
    """Subsets a tensor to a specific lat/lon range."""
    #lat_mask = (df_tensor[:, 0] >= -5) & (df_tensor[:, 0] <= 6.3)
    #lon_mask = (df_tensor[:, 1] >= 118) & (df_tensor[:, 1] <= 134.2)
    lat_mask = (df_tensor[:, 0] >= 0) & (df_tensor[:, 0] <= 5)
    lon_mask = (df_tensor[:, 1] >= 123) & (df_tensor[:, 1] <= 133)

    df_sub = df_tensor[lat_mask & lon_mask].clone()
    return df_sub

def apply_first_difference_2d_tensor(df_tensor: torch.Tensor) -> torch.Tensor:
    """
    Applies a 2D first-order difference filter using convolution.
    This approximates Z(s) = [X(s+d_lat) - X(s)] + [X(s+d_lon) - X(s)].
    """
    if df_tensor.size(0) == 0:
        return torch.empty(0, 4)

    # 1. Get grid dimensions and validate
    unique_lats = torch.unique(df_tensor[:, 0])
    unique_lons = torch.unique(df_tensor[:, 1])
    lat_count, lon_count = unique_lats.size(0), unique_lons.size(0)

    if df_tensor.size(0) != lat_count * lon_count:
        raise ValueError("Tensor size does not match grid dimensions. Must be a complete grid.")
    if lat_count < 2 or lon_count < 2:
        return torch.empty(0, 4)

    # 2. Reshape data and define the correct kernel
    ozone_data = df_tensor[:, 2].reshape(1, 1, lat_count, lon_count)
    
    # ✅ CORRECT KERNEL: This kernel results in the standard first-order difference:
    # Z(i,j) = X(i+1,j) + X(i,j+1) - 2*X(i,j)
    # Note: F.conv2d in PyTorch actually performs cross-correlation. To get a true
    # convolution result, the kernel would need to be flipped. However, for a 
    # forward difference operator, defining the kernel for cross-correlation is more direct.
    # The kernel below is designed for cross-correlation to achieve the desired differencing.
    diff_kernel = torch.tensor([[[[-2., 1.],
                                  [ 1., 0.]]]], dtype=torch.float32)

    # 3. Apply convolution (which acts as cross-correlation)
    filtered_grid = F.conv2d(ozone_data, diff_kernel, padding='valid').squeeze()

    # 4. Determine coordinates for the new, smaller grid
    # The new grid corresponds to the anchor points of the kernel
    new_lats = unique_lats[:-1]
    new_lons = unique_lons[:-1]

    # 5. Reconstruct the output tensor
    new_lat_grid, new_lon_grid = torch.meshgrid(new_lats, new_lons, indexing='ij')
    filtered_values = filtered_grid.flatten()
    time_value = df_tensor[0, 3].repeat(filtered_values.size(0))

    new_tensor = torch.stack([
        new_lat_grid.flatten(),
        new_lon_grid.flatten(),
        filtered_values,
        time_value
    ], dim=1)
    
    return new_tensor

# =========================================================================
# 2. Data Loading (Unchanged)
# =========================================================================
# ⚠️ NOTE: You must define these variables
# mac_data_path = "..."
# year = 2022
# month_str = "01"
# class GEMS_TCO: # Placeholder
#     def load_data(self, path): return self
#     def load_working_data_byday_wo_mm(self, data, indices):
#         return {'key': torch.randn(100, 4)}, torch.randn(100, 4)

pickle_path = os.path.join(mac_data_path, f'pickle_{year}')
output_filename = f"coarse_cen_map_without_decrement_latitude{str(year)[2:]}_{month_str}.pkl"
output_filepath = os.path.join(pickle_path, output_filename)
print(f"Loading data from: {output_filepath}")
with open(output_filepath, 'rb') as pickle_file:
    cbmap_ori = pickle.load(pickle_file)

load_data_instance = GEMS_TCO.load_data('')
df_day_map_list = []
for i in range(31): # Adjust if necessary


    
    #cur_map, _ = load_data_instance.load_working_data_byday_wo_mm(cbmap_ori, [i*8, (i+1)*8])

    # hour of data
    cur_map, _ = load_data_instance.load_working_data_byday_wo_mm(cbmap_ori, [i*8, i*8+1])
    df_day_map_list.append(cur_map)
print(f"Loaded {len(df_day_map_list)} days of raw data.")

# =========================================================================
# 3. Main Processing Loop (Unchanged)
# =========================================================================
spatially_filtered_days = []
for day_idx, day_map in enumerate(df_day_map_list):
    tensors_to_aggregate = []
    for key, tensor in day_map.items():
        subsetted = subset_tensor(tensor)
        if subsetted.size(0) > 0:
            try:
                diff_applied = apply_first_difference_2d_tensor(subsetted)
                if diff_applied.size(0) > 0:
                    tensors_to_aggregate.append(diff_applied)
            except ValueError as e:
                print(f"Skipping data chunk on day {day_idx+1} due to error: {e}")

    if tensors_to_aggregate:
        aggregated_day_tensor = torch.cat(tensors_to_aggregate, dim=0)
        spatially_filtered_days.append(aggregated_day_tensor)

# =========================================================================
# 4. Verification (Unchanged)
# =========================================================================
print("\n--- Results ---")
print(f"Number of final spatially-differenced day tensors: {len(spatially_filtered_days)}")
if spatially_filtered_days:
    # Save the processed data for the next script
    processed_output_path = "spatial_first_difference_data.pkl"
    with open(processed_output_path, 'wb') as f:
        pickle.dump(spatially_filtered_days, f)
    print(f"Processed data saved to {processed_output_path}")

    print(f"\nShape of the first final tensor: {spatially_filtered_days[0].shape}")
    print("First final tensor head:")
    print(spatially_filtered_days[0][:5])
else:
    print("\nNo final differenced tensors were created.")

Loading data from: /Users/joonwonlee/Documents/GEMS_DATA/pickle_2024/coarse_cen_map_without_decrement_latitude24_07.pkl
Loaded 31 days of raw data.

--- Results ---
Number of final spatially-differenced day tensors: 31
Processed data saved to spatial_first_difference_data.pkl

Shape of the first final tensor: torch.Size([17854, 4])
First final tensor head:
tensor([[ 4.0000e-03,  1.2303e+02,  2.9422e+00,  2.1000e+01],
        [ 4.0000e-03,  1.2309e+02,  1.9636e+00,  2.1000e+01],
        [ 4.0000e-03,  1.2316e+02, -1.3187e+00,  2.1000e+01],
        [ 4.0000e-03,  1.2322e+02, -3.1683e+00,  2.1000e+01],
        [ 4.0000e-03,  1.2328e+02, -5.4922e-01,  2.1000e+01]])


## 10/20/25. No tapering, no approximate autocorrelation of taper by bartlett taper

In [42]:
import torch
import numpy as np
import matplotlib.pyplot as plt # Keep if plotting might be added later
import cmath
import pickle
import time # For timing
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.nn import Parameter

# =========================================================================
# 1. Modeling Functions (Adapted for No Tapering, EXPONENTIAL Kernel)
# =========================================================================

# --- Bartlett Kernel (Used for c_gn when g_s=1) ---
def cgn_2dbartlett_kernel(u1, u2, n1, n2):
    """
    Computes the 2D Bartlett kernel: Product(1 - |ui|/ni).
    This represents c_gn(u) when g_s=1.
    """
    device = u1.device if isinstance(u1, torch.Tensor) else (u2.device if isinstance(u2, torch.Tensor) else torch.device('cpu'))
    u1_tensor = u1.to(device) if isinstance(u1, torch.Tensor) else torch.tensor(u1, device=device, dtype=torch.float32)
    u2_tensor = u2.to(device) if isinstance(u2, torch.Tensor) else torch.tensor(u2, device=device, dtype=torch.float32)
    n1_eff = float(n1) if n1 > 0 else 1.0
    n2_eff = float(n2) if n2 > 0 else 1.0
    kernel = (1.0 - torch.abs(u1_tensor) / n1_eff) * (1.0 - torch.abs(u2_tensor) / n2_eff)
    return torch.clamp(kernel, min=0.0)

# --- Covariance of the Original Field X (Corrected EXPONENTIAL Kernel) ---
def cov_x_exponential(u1, u2, t, params):
    """
    Computes the spatio-temporal autocovariance of the original, unfiltered process X
    using an EXPONENTIAL kernel (sqrt distance in exponent).
    Expects log-scale params [0,1,2,6].
    """
    device = params.device # Assuming params is a tensor now
    # Ensure inputs are tensors and on correct device
    u1_dev = u1.to(device) if isinstance(u1, torch.Tensor) else torch.tensor(u1, device=device, dtype=torch.float32)
    u2_dev = u2.to(device) if isinstance(u2, torch.Tensor) else torch.tensor(u2, device=device, dtype=torch.float32)
    t_dev = t.to(device) if isinstance(t, torch.Tensor) else torch.tensor(t, device=device, dtype=torch.float32)

    log_params_indices = [0, 1, 2, 6]
    if torch.isnan(params[log_params_indices]).any() or torch.isinf(params[log_params_indices]).any():
         print("Warning: NaN/Inf in log-params before exp in cov_x_exponential")
         out_shape = torch.broadcast_shapes(u1_dev.shape, u2_dev.shape, t_dev.shape)
         return torch.full(out_shape, float('nan'), device=device, dtype=torch.float32)

    sigmasq, r_lat, r_lon, nugget = torch.exp(params[log_params_indices])
    a_lat, a_lon, beta = params[3], params[4], params[5]

    if r_lat < 1e-6 or r_lon < 1e-6:
        print(f"Warning: Very small range detected (r_lat={r_lat:.2e}, r_lon={r_lon:.2e}). Clamping.")
        r_lat = torch.clamp(r_lat, min=1e-6)
        r_lon = torch.clamp(r_lon, min=1e-6)

    # Calculate Distance D for Exponential kernel
    x1 = u1_dev / r_lat - a_lat * t_dev
    x2 = u2_dev / r_lon - a_lon * t_dev
    x3 = beta * t_dev
    distance_sq = x1**2 + x2**2 + x3**2
    epsilon = 1e-12
    # Clamp before sqrt and add epsilon inside for stability
    distance_sq_clamped = torch.clamp(distance_sq, min=0.0)
    D = torch.sqrt(distance_sq_clamped + epsilon) # Use sqrt distance D

    # Calculate covariance C_X = sigmasq * exp(-D)
    cov_smooth = sigmasq * torch.exp(-D) # Use D here

    is_zero_lag = (torch.abs(u1_dev) < 1e-9) & (torch.abs(u2_dev) < 1e-9) & (torch.abs(t_dev) < 1e-9)
    final_cov = torch.where(is_zero_lag, cov_smooth + nugget, cov_smooth)

    if torch.isnan(final_cov).any():
        print("Warning: NaN detected in cov_x_exponential output.")
    return final_cov


# --- Covariance of the Spatially Differenced Field Y ---
def cov_spatial_difference(u1, u2, t, params, delta1, delta2):
    """
    Calculates covariance Cov(Y(s), Y(s+u)) for the field Y filtered by:
    Y(s) = X(s+d1) + X(s+d2) - 2X(s)
    Based on the underlying EXPONENTIAL covariance cov_x_exponential.
    """
    weights = {(0, 0): -2.0, (1, 0): 1.0, (0, 1): 1.0}
    device = params.device
    out_shape = torch.broadcast_shapes(u1.shape if isinstance(u1, torch.Tensor) else (),
                                     u2.shape if isinstance(u2, torch.Tensor) else (),
                                     t.shape if isinstance(t, torch.Tensor) else ())
    cov = torch.zeros(out_shape, device=device, dtype=torch.float32)

    u1_dev = u1.to(device) if isinstance(u1, torch.Tensor) else torch.tensor(u1, device=device, dtype=torch.float32)
    u2_dev = u2.to(device) if isinstance(u2, torch.Tensor) else torch.tensor(u2, device=device, dtype=torch.float32)
    t_dev = t.to(device) if isinstance(t, torch.Tensor) else torch.tensor(t, device=device, dtype=torch.float32)

    for (a_idx, b_idx), w_ab in weights.items():
        offset_a1 = a_idx * delta1
        offset_a2 = b_idx * delta2
        for (c_idx, d_idx), w_cd in weights.items():
            offset_c1 = c_idx * delta1
            offset_c2 = d_idx * delta2
            lag_u1 = u1_dev + (offset_a1 - offset_c1)
            lag_u2 = u2_dev + (offset_a2 - offset_c2)
            # --- Call the EXPONENTIAL version ---
            term_cov = cov_x_exponential(lag_u1, lag_u2, t_dev, params)
            # ---
            if torch.isnan(term_cov).any():
                 print(f"Warning: NaN detected in term_cov within cov_spatial_difference.")
                 return torch.full_like(cov, float('nan'))
            cov += w_ab * w_cd * term_cov

    if torch.isnan(cov).any():
        print("Warning: NaN detected in final cov_spatial_difference output.")
    return cov


# --- Modified cn_bar for NO TAPERING (uses Bartlett kernel for c_gn) ---
def cn_bar_no_taper(u1, u2, t, params, n1, n2, delta1, delta2):
    """
    Computes c_X(u) * c_gn(u) where c_X is cov_spatial_difference (using Exponential base)
    and c_gn(u) is the Bartlett kernel (Eq. 16).
    """
    device = params.device
    u1_dev = u1.to(device) if isinstance(u1, torch.Tensor) else torch.tensor(u1, device=device, dtype=torch.float32)
    u2_dev = u2.to(device) if isinstance(u2, torch.Tensor) else torch.tensor(u2, device=device, dtype=torch.float32)
    t_dev = t.to(device) if isinstance(t, torch.Tensor) else torch.tensor(t, device=device, dtype=torch.float32)

    # Calculate theoretical covariance of the differenced field Y (based on Exponential X)
    cov_X_value = cov_spatial_difference(u1_dev, u2_dev, t_dev, params, delta1, delta2)

    # Calculate Bartlett Kernel c_gn(u) for the no-taper case
    c_gn_value = cgn_2dbartlett_kernel(u1_dev, u2_dev, n1, n2)

    if torch.isnan(cov_X_value).any() or torch.isnan(c_gn_value).any():
        print("Warning: NaN detected before multiplication in cn_bar_no_taper.")
        out_shape = torch.broadcast_shapes(cov_X_value.shape, c_gn_value.shape)
        return torch.full(out_shape, float('nan'), device=device, dtype=torch.float32)

    result = cov_X_value * c_gn_value
    if torch.isnan(result).any():
        print("Warning: NaN detected after multiplication in cn_bar_no_taper.")
    return result


# --- Expected Periodogram (uses cn_bar_no_taper) ---
def expected_periodogram_fft_no_taper(params, n1, n2, p, cov_func, delta1, delta2):
    """
    Calculates the expected periodogram I(omega_s) (a pxp matrix in time)
    for the spatially differenced process assuming NO data taper (g_s=1).
    Uses the provided cov_func (cov_spatial_difference based on Exponential X).
    Returns shape [n1_freq, n2_freq, p, p]
    """
    device = params.device if isinstance(params, torch.Tensor) else params[0].device
    if isinstance(params, list):
        params_tensor = torch.cat([p.to(device) for p in params])
    else:
        params_tensor = params.to(device)

    product_tensor = torch.zeros((n1, n2, p, p), dtype=torch.complex64, device=device)
    t_lags = torch.arange(p, dtype=torch.float32, device=device)
    u1_mesh_grid, u2_mesh_grid = torch.meshgrid(
        torch.arange(n1, dtype=torch.float32, device=device),
        torch.arange(n2, dtype=torch.float32, device=device),
        indexing='ij'
    )

    for q in range(p):
        for r in range(p):
            t_diff = t_lags[q] - t_lags[r]
            # Calls cn_bar_no_taper, which internally calls cov_spatial_difference
            # based on the EXPONENTIAL cov_x_exponential
            cov_times_bartlett = cn_bar_no_taper(
                u1_mesh_grid, u2_mesh_grid, t_diff,
                params_tensor, n1, n2, delta1, delta2
            )
            if torch.isnan(cov_times_bartlett).any():
                 print(f"Warning: NaN detected in cov_times_bartlett for t_lag {t_diff.item():.2f}.")
                 product_tensor[:, :, q, r] = float('nan')
            else:
                 product_tensor[:, :, q, r] = cov_times_bartlett.to(torch.complex64)

    if torch.isnan(product_tensor).any():
        print("Warning: NaN detected in product_tensor before FFT.")
        nan_shape = (n1, n2, p, p)
        return torch.full(nan_shape, float('nan'), dtype=torch.complex64, device=device)

    fft_result = torch.fft.fft2(product_tensor, dim=(0, 1))
    normalization_factor = 1.0 / (4.0 * cmath.pi**2)
    result = fft_result * normalization_factor

    if torch.isnan(result).any():
        print("Warning: NaN detected in expected_periodogram_fft_no_taper output after FFT.")
    return result


# =========================================================================
# 2. Data Processing (MODIFIED for NO TAPERING)
# =========================================================================
def generate_Jvector_no_taper(tensor_list, lat_col, lon_col, val_col, device):
    """
    Generates J-vector for a single component assuming g_s=1 (NO taper),
    placing result on device.
    """
    p = len(tensor_list)
    if p == 0: return torch.empty(0, 0, 0, device=device), 0, 0, 0

    valid_tensors = [t for t in tensor_list if t.numel() > 0 and t.shape[1] > max(lat_col, lon_col, val_col)]
    if not valid_tensors:
         print("Warning: No valid tensors found in tensor_list (empty or insufficient columns).")
         return torch.empty(0, 0, 0, device=device), 0, 0, 0

    try:
        all_lats_cpu = torch.cat([t[:, lat_col] for t in valid_tensors])
        all_lons_cpu = torch.cat([t[:, lon_col] for t in valid_tensors])
    except IndexError:
        print(f"Error: Invalid column index detected. lat_col={lat_col}, lon_col={lon_col}. Check tensor shapes.")
        return torch.empty(0, 0, 0, device=device), 0, 0, 0

    all_lats_cpu = all_lats_cpu[~torch.isnan(all_lats_cpu) & ~torch.isinf(all_lats_cpu)]
    all_lons_cpu = all_lons_cpu[~torch.isnan(all_lons_cpu) & ~torch.isinf(all_lons_cpu)]
    if all_lats_cpu.numel() == 0 or all_lons_cpu.numel() == 0:
        print("Warning: No valid coordinates found after NaN/Inf filtering.")
        return torch.empty(0, 0, 0, device=device), 0, 0, 0

    unique_lats_cpu, unique_lons_cpu = torch.unique(all_lats_cpu), torch.unique(all_lons_cpu)
    n1, n2 = len(unique_lats_cpu), len(unique_lons_cpu)
    if n1 == 0 or n2 == 0:
        print("Warning: Grid dimensions are zero after finding unique coordinates.")
        return torch.empty(0, 0, 0, device=device), 0, 0, 0

    lat_map = {lat.item(): i for i, lat in enumerate(unique_lats_cpu)}
    lon_map = {lon.item(): i for i, lon in enumerate(unique_lons_cpu)}

    fft_results = []
    for tensor in tensor_list:
        data_grid = torch.zeros((n1, n2), dtype=torch.float32, device=device)
        if tensor.numel() > 0 and tensor.shape[1] > max(lat_col, lon_col, val_col):
            for row in tensor:
                lat_item, lon_item = row[lat_col].item(), row[lon_col].item()
                if not (np.isnan(lat_item) or np.isnan(lon_item)):
                    i = lat_map.get(lat_item)
                    j = lon_map.get(lon_item)
                    if i is not None and j is not None:
                        val = row[val_col]
                        val_num = val.item() if isinstance(val, torch.Tensor) else val
                        if not np.isnan(val_num) and not np.isinf(val_num):
                           data_grid[i, j] = val_num

        if torch.isnan(data_grid).any() or torch.isinf(data_grid).any():
             print("Warning: NaN/Inf detected in data_grid before FFT. Replacing with zeros.")
             data_grid = torch.nan_to_num(data_grid, nan=0.0, posinf=0.0, neginf=0.0)

        fft_results.append(torch.fft.fft2(data_grid))

    if not fft_results:
         print("Warning: No FFT results generated.")
         return torch.empty(0, 0, 0, device=device), n1, n2, 0

    J_vector_tensor = torch.stack(fft_results, dim=2).to(device)

    H = float(n1 * n2)
    if H < 1e-9:
        print("Warning: Normalization factor H is near zero.")
        norm_factor = torch.tensor(0.0, device=device)
    else:
        norm_factor = (torch.sqrt(torch.tensor(1.0 / H, device=device)) / (2.0 * cmath.pi))

    result = J_vector_tensor * norm_factor
    if torch.isnan(result).any():
        print("Warning: NaN detected in J_vector output.")
    return result, n1, n2, p


def calculate_sample_periodogram_vectorized(J_vector_tensor):
    """Calculates sample periodogram I_n = J J^H (pxp matrix for each spatial freq)."""
    if torch.isnan(J_vector_tensor).any() or torch.isinf(J_vector_tensor).any():
        print("Warning: NaN/Inf detected in J_vector_tensor input.")
        n1, n2, p = J_vector_tensor.shape
        return torch.full((n1, n2, p, p), float('nan'), dtype=torch.complex64, device=J_vector_tensor.device)

    J_col = J_vector_tensor.unsqueeze(-1)
    J_row_conj = J_vector_tensor.unsqueeze(-2).conj()
    result = J_col @ J_row_conj

    if torch.isnan(result).any():
        print("Warning: NaN detected in calculate_sample_periodogram_vectorized output.")
    return result


# =========================================================================
# 4. Likelihood Calculation (Adapted for No Tapering)
# =========================================================================

def whittle_likelihood_loss_no_taper(params, I_sample, n1, n2, p, delta1, delta2):
    """
    ✅ Whittle Likelihood Loss assuming NO data taper (g_s=1).
    Models a single field (the spatially differenced one).
    Uses Exponential kernel based cov_spatial_difference.
    """
    device = I_sample.device
    params_tensor = params.to(device)

    if torch.isnan(params_tensor).any() or torch.isinf(params_tensor).any():
        print("Warning: NaN/Inf detected in input parameters to likelihood.")
        return torch.tensor(float('nan'), device=device)

    # Use cov_spatial_difference (based on Exponential X) here
    I_expected = expected_periodogram_fft_no_taper(
        params_tensor, n1, n2, p, cov_spatial_difference, delta1, delta2
    )

    if torch.isnan(I_expected).any() or torch.isinf(I_expected).any():
        print("Warning: NaN/Inf returned from expected_periodogram calculation.")
        return torch.tensor(float('nan'), device=device)

    eye_matrix = torch.eye(p, dtype=torch.complex64, device=device)
    # Adaptive diagonal loading based on diagonal magnitude
    diag_vals = torch.abs(I_expected.diagonal(dim1=-2, dim2=-1))
    # Handle case where p=0 or all diags are zero/nan
    mean_diag_abs = diag_vals.mean().item() if diag_vals.numel() > 0 and not torch.isnan(diag_vals).all() else 1.0
    diag_load = max(mean_diag_abs * 1e-8, 1e-9) # Ensure at least 1e-9
    
    I_expected_stable = I_expected + eye_matrix * diag_load

    sign, logabsdet = torch.linalg.slogdet(I_expected_stable)
    # Check determinant sign and use penalty for bad params
    if torch.any(sign.real <= 1e-9):
        print("Warning: Non-positive determinant encountered. Applying penalty.")
        # Apply penalty: large positive value for log_det_term
        log_det_term = torch.where(sign.real > 1e-9, logabsdet, torch.tensor(1e10, device=device))
    else:
        log_det_term = logabsdet

    if torch.isnan(I_sample).any() or torch.isinf(I_sample).any():
        print("Warning: NaN/Inf detected in I_sample input to likelihood.")
        return torch.tensor(float('nan'), device=device)

    try:
        solved_term = torch.linalg.solve(I_expected_stable, I_sample)
        trace_term = torch.einsum('...ii->...', solved_term).real
    except torch.linalg.LinAlgError as e:
        print(f"Warning: LinAlgError during solve: {e}. Applying high loss penalty.")
        # Return a large loss value instead of NaN to guide optimizer away
        return torch.tensor(float('inf'), device=device)

    if torch.isnan(trace_term).any() or torch.isinf(trace_term).any():
        print("Warning: NaN/Inf detected in trace_term. Returning NaN loss.")
        return torch.tensor(float('nan'), device=device)

    likelihood_terms = log_det_term + trace_term

    if torch.isnan(likelihood_terms).any():
        print("Warning: NaN detected in likelihood_terms before summation. Returning NaN loss.")
        return torch.tensor(float('nan'), device=device)

    # Sum over non-zero spatial frequencies
    total_sum = torch.sum(likelihood_terms)
    dc_term = likelihood_terms[0, 0] if n1 > 0 and n2 > 0 else torch.tensor(0.0, device=device)
    if torch.isnan(dc_term).any() or torch.isinf(dc_term).any():
        print("Warning: NaN/Inf detected in DC term. Setting to 0.")
        dc_term = torch.tensor(0.0, device=device)

    loss = total_sum - dc_term if (n1 > 1 or n2 > 1) else total_sum

    if torch.isnan(loss) or torch.isinf(loss):
         print("Warning: NaN/Inf detected in final loss. Returning Inf penalty.")
         return torch.tensor(float('inf'), device=device) # Use Inf penalty

    return loss


# =========================================================================
# 5. Training Loop (CORRECTED version)
# =========================================================================
def run_full(params_list, optimizer, scheduler, I_sample, n1, n2, p, epochs=600, device='cpu'):
    """Corrected training loop using parameter list and no taper likelihood."""
    best_loss = float('inf')
    params_list = [p.to(device) for p in params_list]
    best_params_state = [p.detach().clone() for p in params_list]
    epochs_completed = 0
    DELTA_LAT, DELTA_LON = 0.044, 0.063

    def get_printable_params(p_list):
        valid_tensors = [p for p in p_list if isinstance(p, torch.Tensor)]
        if not valid_tensors: return "Invalid params_list"
        p_cat = torch.cat([p.detach().clone().cpu() for p in valid_tensors])
        log_indices = [0, 1, 2, 6]
        if all(idx < len(p_cat) for idx in log_indices):
            log_vals = p_cat[log_indices]
            if not (torch.isnan(log_vals).any() or torch.isinf(log_vals).any()):
                 p_cat[log_indices] = torch.exp(log_vals)
            else:
                 print("Warning: NaN/Inf in log-params for printing.")
                 # Set corresponding natural scale values to NaN for clarity
                 p_cat[log_indices] = float('nan')
        return p_cat.numpy().round(4)

    I_sample_dev = I_sample.to(device)

    for epoch in range(epochs):
        epochs_completed = epoch + 1
        optimizer.zero_grad()
        params_tensor = torch.cat(params_list)

        loss = whittle_likelihood_loss_no_taper(
            params_tensor, I_sample_dev, n1, n2, p, DELTA_LAT, DELTA_LON
        )

        if torch.isnan(loss) or torch.isinf(loss):
            print(f"Loss became NaN or Inf at epoch {epoch+1}. Stopping.")
            if epoch == 0: best_params_state = None
            epochs_completed = epoch
            break # Stop this run

        loss.backward()

        nan_grad = False
        for param in params_list:
            if param.grad is not None and (torch.isnan(param.grad).any() or torch.isinf(param.grad).any()):
                print(f"Warning: NaN/Inf gradient detected at epoch {epoch+1}.")
                nan_grad = True
                break
        if nan_grad:
             optimizer.zero_grad() # Zero bad gradients
             print("Skipping optimizer step due to invalid gradients.")
             # Don't step scheduler if optimizer step is skipped
             continue # Try next epoch

        # Clip gradients only if they are valid
        all_params_on_device = params_list
        if all_params_on_device:
            torch.nn.utils.clip_grad_norm_(all_params_on_device, max_norm=1.0)

        optimizer.step()
        scheduler.step() # Step scheduler after optimizer

        current_loss_item = loss.item()
        # Save best state *after* successful step only if params are valid
        if current_loss_item < best_loss:
             # Check parameters *after* step
            params_valid = not any(torch.isnan(p.data).any() or torch.isinf(p.data).any() for p in params_list)
            if params_valid:
                best_loss = current_loss_item
                best_params_state = [p.detach().clone() for p in params_list]
            else:
                print(f"Warning: NaN/Inf in params after step epoch {epoch+1}. Not saving state.")


        current_lr = optimizer.param_groups[0]['lr'] if optimizer.param_groups else 0.0

        if epoch % 50 == 0 or epoch == epochs - 1:
            print(f'--- Epoch {epoch+1}/{epochs} (LR: {current_lr:.6f}) ---')
            print(f' Loss: {current_loss_item:.4f}')
            print(f' Parameters (Natural Scale): {get_printable_params(params_list)}')


    print("\n--- Training Complete ---")
    if best_params_state is None:
        print("Training failed to find a valid model state.")
        return None, epochs_completed

    # Prepare final output (move best state to CPU)
    final_params_log_scale = torch.cat([p.cpu() for p in best_params_state])
    final_params_natural_scale = final_params_log_scale.detach().clone()
    log_indices = [0, 1, 2, 6]
    # Final check for NaN/Inf before exp
    if all(idx < len(final_params_natural_scale) for idx in log_indices):
        log_vals = final_params_natural_scale[log_indices]
        if not (torch.isnan(log_vals).any() or torch.isinf(log_vals).any()):
            final_params_natural_scale[log_indices] = torch.exp(log_vals)
        else:
            print("Warning: Invalid values in best log-params before final exp. Setting natural scale to NaN.")
            final_params_natural_scale[log_indices] = float('nan')


    final_params_rounded = [round(p.item(), 4) if not np.isnan(p.item()) else float('nan') for p in final_params_natural_scale]
    final_loss_rounded = round(best_loss, 3) if best_loss != float('inf') else float('inf')

    print(f'\nFINAL BEST STATE ACHIEVED (during training):')
    print(f'Best Loss: {final_loss_rounded}')
    print(f'Parameters Corresponding to Best Loss (Natural Scale): {final_params_rounded}')

    return final_params_rounded + [final_loss_rounded], epochs_completed




In [70]:
processed_df = spatially_filtered_days

# =========================================================================
# 6. Main Execution Script (NO Tapering, Spatially Differenced Data, EXP Kernel)
# =========================================================================
if __name__ == '__main__':
    start_time = time.time()

    # --- Configuration ---
    DAY_TO_RUN = 1
    NUM_RUNS = 1 # Only 1 run with fixed start
    EPOCHS = 700 # Use 700 as in reference
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {DEVICE}")

    # --- Grid Spacing (Must match data differencing step) ---
    DELTA_LAT, DELTA_LON = 0.044, 0.063 # From reference run_full

    # --- Column Indices for processed_df ---
    LAT_COL, LON_COL = 0, 1
    VAL_COL = 2
    TIME_COL = 3

    # --- Load Spatially Differenced Data ---


    processed_df = spatially_filtered_days
    cur_df = processed_df[DAY_TO_RUN - 1]
    if cur_df.numel() == 0 or cur_df.shape[1] <= max(LAT_COL, LON_COL, VAL_COL, TIME_COL):
        print(f"Error: Data for Day {DAY_TO_RUN} is empty or invalid.")
        exit()

    unique_times = torch.unique(cur_df[:, TIME_COL])
    time_slices_list = [cur_df[cur_df[:, TIME_COL] == t_val] for t_val in unique_times]

    # --- 1. Pre-compute Sample Periodogram (NO Tapering) ---
    print("Pre-computing sample periodogram (NO data taper)...")
    J_vec, n1, n2, p = generate_Jvector_no_taper(
        time_slices_list,
        lat_col=LAT_COL, lon_col=LON_COL, val_col=VAL_COL,
        device=DEVICE
    )

    if J_vec.numel() == 0 or n1 == 0 or n2 == 0 or p == 0:
       print(f"Error: J-vector generation failed for Day {DAY_TO_RUN}.")
       exit()

    I_sample = calculate_sample_periodogram_vectorized(J_vec)

    if torch.isnan(I_sample).any() or torch.isinf(I_sample).any():
        print("Error: NaN or Inf detected in the sample periodogram. Cannot proceed.")
        exit()

    print(f"Data grid: {n1}x{n2} spatial points, {p} time points. Sample Periodogram on {DEVICE}.")

    # --- 2. Optimization Loop ---
    all_final_results = []
    all_final_losses = []

    for i in range(NUM_RUNS):
        print(f"\n{'='*30} Initialization Run {i+1}/{NUM_RUNS} {'='*30}")

        initial_params_values = [np.log(25), np.log(1),np.log(1.5), -0.05, -0.08, 0.05, np.log(2)]
        initial_params_values = [np.log(22), np.log(1),np.log(1.5), 0, 0, 0, np.log(4.769)]
        print(f"Starting with FIXED params (log-scale for some): {[round(p, 4) for p in initial_params_values]}")

        params_list = [
            Parameter(torch.tensor([val], dtype=torch.float32))
            for val in initial_params_values
        ]

        lr_slow, lr_fast = 0.005, 0.02
        slow_indices = [0, 1, 2, 6]
        fast_indices = [3, 4, 5]

        valid_slow_indices = [idx for idx in slow_indices if idx < len(params_list)]
        valid_fast_indices = [idx for idx in fast_indices if idx < len(params_list)]

        param_groups = [
            {'params': [params_list[idx] for idx in valid_slow_indices], 'lr': lr_slow, 'name': 'slow_group'},
            {'params': [params_list[idx] for idx in valid_fast_indices], 'lr': lr_fast, 'name': 'fast_group'}
        ]

        optimizer = torch.optim.Adam(param_groups)

        T_MAX = 200
        ETA_MIN = 1e-6
        scheduler = CosineAnnealingLR(optimizer, T_max=T_MAX, eta_min=ETA_MIN)

        print(f"Starting optimization run {i+1} on device {DEVICE} (NO data taper, EXP kernel)...")
        final_results, epochs_run = run_full(
            params_list=params_list,
            optimizer=optimizer,
            scheduler=scheduler,
            I_sample=I_sample,
            n1=n1, n2=n2, p=p,
            epochs=EPOCHS,
            device=DEVICE
        )

        if final_results:
            all_final_results.append(final_results)
            all_final_losses.append(final_results[-1])
        else:
            all_final_results.append(None)
            all_final_losses.append(float('inf'))

    print(f"\n\n{'='*25} Overall Result from Run {'='*25}")
    valid_losses = [l for l in all_final_losses if l is not None and l != float('inf')]

    if not valid_losses:
        print(f"The run failed or resulted in an invalid loss for Day {DAY_TO_RUN}.")
    else:
        best_loss = valid_losses[0]
        best_run_index = 0
        best_results = all_final_results[best_run_index]
        print(f"Run Loss: {best_results[-1]}")
        print(f"Final Parameters (Natural Scale): {best_results[:-1]}")

    end_time = time.time()
    print(f"\nTotal execution time: {end_time - start_time:.2f} seconds")

Using device: cpu
Pre-computing sample periodogram (NO data taper)...
Data grid: 113x158 spatial points, 1 time points. Sample Periodogram on cpu.

Starting with FIXED params (log-scale for some): [3.091, 0.0, 0.4055, 0, 0, 0, 1.5621]
Starting optimization run 1 on device cpu (NO data taper, EXP kernel)...
--- Epoch 1/700 (LR: 0.005000) ---
 Loss: 7496.9814
 Parameters (Natural Scale): [21.8903  1.005   1.5075  0.      0.      0.      4.7452]
--- Epoch 51/700 (LR: 0.004240) ---
 Loss: 5933.7134
 Parameters (Natural Scale): [17.381   1.266   1.8984  0.      0.      0.      3.7425]
--- Epoch 101/700 (LR: 0.002461) ---
 Loss: 5405.8633
 Parameters (Natural Scale): [14.9064  1.4778  2.2111  0.      0.      0.      3.1535]
--- Epoch 151/700 (LR: 0.000706) ---
 Loss: 5352.5537
 Parameters (Natural Scale): [13.9972  1.5883  2.3171  0.      0.      0.      2.9309]
--- Epoch 201/700 (LR: 0.000001) ---
 Loss: 5352.3784
 Parameters (Natural Scale): [14.0408  1.6129  2.2636  0.      0.      0.    

## 10/20/25. Hamming tapering

In [60]:
import torch
import numpy as np
import matplotlib.pyplot as plt # Keep if plotting might be added later
import cmath
import pickle
import time # For timing
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.nn import Parameter
import torch.fft # Explicit import for fft functions

# =========================================================================
# 1. Tapering, Autocorrelation, and Covariance Functions
# =========================================================================

# --- Tapering Functions (Hamming is used) ---
def cgn_hamming(u, n1, n2):
    """Computes a 2D Hamming window."""
    u1, u2 = u
    # Ensure inputs are tensors and on the correct device
    device = u1.device if isinstance(u1, torch.Tensor) else (u2.device if isinstance(u2, torch.Tensor) else torch.device('cpu'))
    u1_tensor = u1.to(device) if isinstance(u1, torch.Tensor) else torch.tensor(u1, device=device, dtype=torch.float32)
    u2_tensor = u2.to(device) if isinstance(u2, torch.Tensor) else torch.tensor(u2, device=device, dtype=torch.float32)

    n1_eff = float(n1) if n1 > 0 else 1.0
    n2_eff = float(n2) if n2 > 0 else 1.0
    # Add epsilon to denominator to avoid potential division by zero if n=0 or 1?
    # Though meshgrid usually starts from 0 to n-1.
    hamming1 = 0.54 + 0.46 * torch.cos(2.0 * torch.pi * u1_tensor / n1_eff)
    hamming2 = 0.54 + 0.46 * torch.cos(2.0 * torch.pi * u2_tensor / n2_eff)
    return hamming1 * hamming2

def cgn_2dbartlett(u, n1, n2): # Kept for potential future use
    """Computes a 2D Bartlett window function."""
    u1, u2 = u
    device = u1.device if isinstance(u1, torch.Tensor) else (u2.device if isinstance(u2, torch.Tensor) else torch.device('cpu'))
    u1_tensor = u1.to(device) if isinstance(u1, torch.Tensor) else torch.tensor(u1, device=device, dtype=torch.float32)
    u2_tensor = u2.to(device) if isinstance(u2, torch.Tensor) else torch.tensor(u2, device=device, dtype=torch.float32)
    n1_eff = float(n1) if n1 > 0 else 1.0
    n2_eff = float(n2) if n2 > 0 else 1.0
    kernel = (1.0 - torch.abs(u1_tensor) / n1_eff) * (1.0 - torch.abs(u2_tensor) / n2_eff)
    return torch.clamp(kernel, min=0.0)

# --- NEW: Function to Calculate Taper Autocorrelation ---
def calculate_taper_autocorrelation_fft(taper_grid, n1, n2, device):
    """
    Computes the normalized taper autocorrelation function c_gn(u) using FFT.

    Args:
        taper_grid (torch.Tensor): The 2D taper function g_s on the grid [n1, n2].
        n1, n2 (int): Dimensions of the original grid.
        device: The torch device.

    Returns:
        torch.Tensor: Normalized autocorrelation c_gn(u), shifted so lag u=(0,0)
                      is at index [n1-1, n2-1]. Shape [2*n1-1, 2*n2-1].
    """
    taper_grid = taper_grid.to(device) # Ensure input is on device
    H = torch.sum(taper_grid**2)
    if H < 1e-12:
        print("Warning: Sum of squared taper weights (H) is near zero.")
        return torch.zeros((2*n1-1, 2*n2-1), device=device, dtype=taper_grid.dtype)

    # Pad for linear autocorrelation via FFT
    N1, N2 = 2 * n1 - 1, 2 * n2 - 1
    taper_fft = torch.fft.fft2(taper_grid, s=(N1, N2))
    power_spectrum = torch.abs(taper_fft)**2
    autocorr_unnormalized = torch.fft.ifft2(power_spectrum).real
    # Shift zero lag (originally at [0,0]) to the center [n1-1, n2-1]
    autocorr_shifted = torch.fft.fftshift(autocorr_unnormalized)

    # Normalize by H (value at zero lag)
    # Add small epsilon to prevent potential division by zero if H is calculated slightly differently
    c_gn_grid = autocorr_shifted / (H + 1e-12)

    return c_gn_grid # Already on device

# --- Covariance of Original Field X (EXPONENTIAL Kernel) ---
def cov_x_exponential(u1, u2, t, params):
    """
    Computes spatio-temporal autocovariance of X using EXPONENTIAL kernel.
    Expects log-scale params [0,1,2,6]. Handles device internally.
    """
    device = params.device
    u1_dev = u1.to(device) if isinstance(u1, torch.Tensor) else torch.tensor(u1, device=device, dtype=torch.float32)
    u2_dev = u2.to(device) if isinstance(u2, torch.Tensor) else torch.tensor(u2, device=device, dtype=torch.float32)
    t_dev = t.to(device) if isinstance(t, torch.Tensor) else torch.tensor(t, device=device, dtype=torch.float32)

    log_params_indices = [0, 1, 2, 6]
    if torch.isnan(params[log_params_indices]).any() or torch.isinf(params[log_params_indices]).any():
         print("Warning: NaN/Inf in log-params before exp in cov_x_exponential")
         out_shape = torch.broadcast_shapes(u1_dev.shape, u2_dev.shape, t_dev.shape)
         return torch.full(out_shape, float('nan'), device=device, dtype=torch.float32)

    sigmasq, r_lat, r_lon, nugget = torch.exp(params[log_params_indices])
    a_lat, a_lon, beta = params[3], params[4], params[5]

    if r_lat < 1e-6 or r_lon < 1e-6:
        r_lat = torch.clamp(r_lat, min=1e-6)
        r_lon = torch.clamp(r_lon, min=1e-6)

    x1 = u1_dev / r_lat - a_lat * t_dev
    x2 = u2_dev / r_lon - a_lon * t_dev
    x3 = beta * t_dev
    distance_sq = x1**2 + x2**2 + x3**2
    epsilon = 1e-12
    distance_sq_clamped = torch.clamp(distance_sq, min=0.0)
    D = torch.sqrt(distance_sq_clamped + epsilon)

    cov_smooth = sigmasq * torch.exp(-D)
    is_zero_lag = (torch.abs(u1_dev) < 1e-9) & (torch.abs(u2_dev) < 1e-9) & (torch.abs(t_dev) < 1e-9)
    final_cov = torch.where(is_zero_lag, cov_smooth + nugget, cov_smooth)

    if torch.isnan(final_cov).any(): print("Warning: NaN detected in cov_x_exponential output.")
    return final_cov

# --- Covariance of Spatially Differenced Field Y (Unchanged structure) ---
def cov_spatial_difference(u1, u2, t, params, delta1, delta2):
    """
    Calculates covariance Cov(Y(s), Y(s+u)) for Y(s) = X(s+d1) + X(s+d2) - 2X(s)
    Based on the underlying EXPONENTIAL covariance cov_x_exponential.
    """
    weights = {(0, 0): -2.0, (1, 0): 1.0, (0, 1): 1.0}
    device = params.device
    out_shape = torch.broadcast_shapes(u1.shape if isinstance(u1, torch.Tensor) else (),
                                     u2.shape if isinstance(u2, torch.Tensor) else (),
                                     t.shape if isinstance(t, torch.Tensor) else ())
    cov = torch.zeros(out_shape, device=device, dtype=torch.float32)
    u1_dev = u1.to(device) if isinstance(u1, torch.Tensor) else torch.tensor(u1, device=device, dtype=torch.float32)
    u2_dev = u2.to(device) if isinstance(u2, torch.Tensor) else torch.tensor(u2, device=device, dtype=torch.float32)
    t_dev = t.to(device) if isinstance(t, torch.Tensor) else torch.tensor(t, device=device, dtype=torch.float32)

    for (a_idx, b_idx), w_ab in weights.items():
        offset_a1 = a_idx * delta1
        offset_a2 = b_idx * delta2
        for (c_idx, d_idx), w_cd in weights.items():
            offset_c1 = c_idx * delta1
            offset_c2 = d_idx * delta2
            lag_u1 = u1_dev + (offset_a1 - offset_c1)
            lag_u2 = u2_dev + (offset_a2 - offset_c2)
            term_cov = cov_x_exponential(lag_u1, lag_u2, t_dev, params) # Calls EXPONENTIAL version
            if torch.isnan(term_cov).any():
                 print(f"Warning: NaN in term_cov within cov_spatial_difference.")
                 return torch.full_like(cov, float('nan'))
            cov += w_ab * w_cd * term_cov

    if torch.isnan(cov).any(): print("Warning: NaN in final cov_spatial_difference output.")
    return cov

# --- NEW cn_bar using Taper Autocorrelation ---
def cn_bar_tapered(u1, u2, t, params, n1, n2, taper_autocorr_grid, delta1, delta2):
    """
    Computes c_X(u) * c_gn(u) where c_X is cov_spatial_difference
    and c_gn(u) is looked up from the pre-computed taper_autocorr_grid.
    """
    device = params.device
    u1_dev = u1.to(device) if isinstance(u1, torch.Tensor) else torch.tensor(u1, device=device, dtype=torch.float32)
    u2_dev = u2.to(device) if isinstance(u2, torch.Tensor) else torch.tensor(u2, device=device, dtype=torch.float32)
    t_dev = t.to(device) if isinstance(t, torch.Tensor) else torch.tensor(t, device=device, dtype=torch.float32)

    # Calculate theoretical covariance c_X(u)
    cov_X_value = cov_spatial_difference(u1_dev, u2_dev, t_dev, params, delta1, delta2)

    # --- Get Taper Autocorrelation Value c_gn(u) from grid ---
    # Lags u1, u2 are grid indices (0..n1-1, 0..n2-1) from meshgrid
    # Center of autocorr grid (n1-1, n2-1) corresponds to lag (0,0)
    # Index for lag (u1, u2) IS (n1-1 + u1, n2-1 + u2) if u1,u2 are relative lags centered at 0
    # BUT u1_mesh_grid, u2_mesh_grid are 0..n1-1, 0..n2-1.
    # We need autocorrelation for lags u = s1 - s2, where s1, s2 are grid indices.
    # The FFT method gives autocorr C(k) = Sum_n g[n]g[n-k]^*
    # The shifted grid taper_autocorr_grid[n1-1+u1_lag, n2-1+u2_lag] gives C(u1_lag, u2_lag)
    # where u1_lag = -(n1-1)...0...(n1-1).
    # Since u1_mesh_grid represents the ABSOLUTE separation s1-s2 (0 to n1-1),
    # we need to map this to the indices of the centered autocorrelation grid.
    # A lag of 0 corresponds to index n1-1. A lag of 1 corresponds to index n1.
    # A lag of u1 corresponds to index n1-1 + u1.
    # However, FFT assumes periodicity. Need care with interpretation of lags vs indices.
    # Let's test: If u1_mesh_grid=0, we need lag 0 => index n1-1. If u1_mesh_grid=1, need lag 1 => index n1.
    # If u1_mesh_grid=u1, need lag u1 => index n1-1 + u1. This seems correct.
    # But wait, meshgrid is 0..n1-1. What lags does product_tensor represent?
    # product_tensor[i,j,q,r] = cn_bar(u1=i, u2=j, t=tq-tr). u1, u2 are interpreted as spatial lags.
    # So we need c_gn at spatial lags u1=i, u2=j.

    # Ensure u1, u2 are integer indices for lookup
    u1_idx = u1_dev.long()
    u2_idx = u2_dev.long()

    # Calculate indices into the centered autocorrelation grid
    idx1 = (n1 - 1 + u1_idx)
    idx2 = (n2 - 1 + u2_idx)

    # Clamp indices (although they should be within bounds if u1, u2 are from meshgrid 0..n-1?)
    # Max lag is n-1. Max index needed is n-1 + n-1 = 2n-2. Grid size is 2n-1. Correct.
    idx1 = torch.clamp(idx1, 0, 2 * n1 - 2)
    idx2 = torch.clamp(idx2, 0, 2 * n2 - 2)

    taper_autocorr_value = taper_autocorr_grid[idx1, idx2] # Indexing with meshes

    if torch.isnan(cov_X_value).any() or torch.isnan(taper_autocorr_value).any():
        print("Warning: NaN detected before multiplication in cn_bar_tapered.")
        out_shape = torch.broadcast_shapes(cov_X_value.shape, taper_autocorr_value.shape)
        return torch.full(out_shape, float('nan'), device=device, dtype=torch.float32)

    result = cov_X_value * taper_autocorr_value
    if torch.isnan(result).any(): print("Warning: NaN in cn_bar_tapered output.")
    return result

# --- Expected Periodogram (uses cn_bar_tapered) ---
def expected_periodogram_fft_tapered(params, n1, n2, p, taper_autocorr_grid, cov_func, delta1, delta2):
    """
    Calculates the expected periodogram I(omega_s) (a pxp matrix in time)
    using the exact taper autocorrelation c_gn(u).
    Returns shape [n1_freq, n2_freq, p, p]
    """
    device = params.device if isinstance(params, torch.Tensor) else params[0].device
    if isinstance(params, list):
        params_tensor = torch.cat([p.to(device) for p in params])
    else:
        params_tensor = params.to(device)

    product_tensor = torch.zeros((n1, n2, p, p), dtype=torch.complex64, device=device)
    t_lags = torch.arange(p, dtype=torch.float32, device=device)
    u1_mesh_grid, u2_mesh_grid = torch.meshgrid(
        torch.arange(n1, dtype=torch.float32, device=device),
        torch.arange(n2, dtype=torch.float32, device=device),
        indexing='ij'
    )

    for q in range(p):
        for r in range(p):
            t_diff = t_lags[q] - t_lags[r]
            # Calculate c_X(u) * c_gn(u) using the autocorrelation grid
            cov_times_autocorr = cn_bar_tapered(
                u1_mesh_grid, u2_mesh_grid, t_diff,
                params_tensor, n1, n2, taper_autocorr_grid, delta1, delta2
            )
            if torch.isnan(cov_times_autocorr).any():
                 print(f"Warning: NaN detected in cov_times_autocorr for t_lag {t_diff.item():.2f}.")
                 product_tensor[:, :, q, r] = float('nan')
            else:
                 product_tensor[:, :, q, r] = cov_times_autocorr.to(torch.complex64)

    if torch.isnan(product_tensor).any():
        print("Warning: NaN detected in product_tensor before FFT.")
        nan_shape = (n1, n2, p, p)
        return torch.full(nan_shape, float('nan'), dtype=torch.complex64, device=device)

    fft_result = torch.fft.fft2(product_tensor, dim=(0, 1))
    normalization_factor = 1.0 / (4.0 * cmath.pi**2)
    result = fft_result * normalization_factor

    if torch.isnan(result).any(): print("Warning: NaN in expected_periodogram_fft_tapered output.")
    return result

# =========================================================================
# 2. Data Processing (MODIFIED for HAMMING TAPERING)
# =========================================================================
def generate_Jvector_tapered(tensor_list, tapering_func, lat_col, lon_col, val_col, device):
    """
    Generates J-vector for a single component using the specified taper,
    placing result on device.
    """
    p = len(tensor_list)
    if p == 0: return torch.empty(0, 0, 0, device=device), 0, 0, 0, None # Return None for taper grid

    valid_tensors = [t for t in tensor_list if t.numel() > 0 and t.shape[1] > max(lat_col, lon_col, val_col)]
    if not valid_tensors:
         print("Warning: No valid tensors found in tensor_list.")
         return torch.empty(0, 0, 0, device=device), 0, 0, 0, None

    try:
        all_lats_cpu = torch.cat([t[:, lat_col] for t in valid_tensors])
        all_lons_cpu = torch.cat([t[:, lon_col] for t in valid_tensors])
    except IndexError:
        print(f"Error: Invalid column index. Check tensor shapes.")
        return torch.empty(0, 0, 0, device=device), 0, 0, 0, None

    all_lats_cpu = all_lats_cpu[~torch.isnan(all_lats_cpu) & ~torch.isinf(all_lats_cpu)]
    all_lons_cpu = all_lons_cpu[~torch.isnan(all_lons_cpu) & ~torch.isinf(all_lons_cpu)]
    if all_lats_cpu.numel() == 0 or all_lons_cpu.numel() == 0:
        print("Warning: No valid coordinates after NaN/Inf filtering.")
        return torch.empty(0, 0, 0, device=device), 0, 0, 0, None

    unique_lats_cpu, unique_lons_cpu = torch.unique(all_lats_cpu), torch.unique(all_lons_cpu)
    n1, n2 = len(unique_lats_cpu), len(unique_lons_cpu)
    if n1 == 0 or n2 == 0:
        print("Warning: Grid dimensions are zero.")
        return torch.empty(0, 0, 0, device=device), 0, 0, 0, None

    lat_map = {lat.item(): i for i, lat in enumerate(unique_lats_cpu)}
    lon_map = {lon.item(): i for i, lon in enumerate(unique_lons_cpu)}

    # --- Create Taper Grid ---
    u1_mesh_cpu, u2_mesh_cpu = torch.meshgrid(
        torch.arange(n1, dtype=torch.float32),
        torch.arange(n2, dtype=torch.float32),
        indexing='ij'
    )
    taper_grid = tapering_func((u1_mesh_cpu, u2_mesh_cpu), n1, n2).to(device) # Taper on device

    fft_results = []
    for tensor in tensor_list:
        data_grid = torch.zeros((n1, n2), dtype=torch.float32, device=device)
        if tensor.numel() > 0 and tensor.shape[1] > max(lat_col, lon_col, val_col):
            for row in tensor:
                lat_item, lon_item = row[lat_col].item(), row[lon_col].item()
                if not (np.isnan(lat_item) or np.isnan(lon_item)):
                    i = lat_map.get(lat_item)
                    j = lon_map.get(lon_item)
                    if i is not None and j is not None:
                        val = row[val_col]
                        val_num = val.item() if isinstance(val, torch.Tensor) else val
                        if not np.isnan(val_num) and not np.isinf(val_num):
                           data_grid[i, j] = val_num

        # --- Apply Tapering ---
        data_grid_tapered = data_grid * taper_grid # Both on device

        if torch.isnan(data_grid_tapered).any() or torch.isinf(data_grid_tapered).any():
             print("Warning: NaN/Inf detected in data_grid_tapered before FFT. Replacing with zeros.")
             data_grid_tapered = torch.nan_to_num(data_grid_tapered, nan=0.0, posinf=0.0, neginf=0.0)

        fft_results.append(torch.fft.fft2(data_grid_tapered))

    if not fft_results:
         print("Warning: No FFT results generated.")
         return torch.empty(0, 0, 0, device=device), n1, n2, 0, taper_grid # Return taper grid

    J_vector_tensor = torch.stack(fft_results, dim=2).to(device)

    # --- Normalization using Taper ---
    H = torch.sum(taper_grid**2)
    if H < 1e-12:
        print("Warning: Normalization factor H is near zero.")
        norm_factor = torch.tensor(0.0, device=device)
    else:
        norm_factor = (torch.sqrt(1.0 / H) / (2.0 * cmath.pi)).to(device)

    result = J_vector_tensor * norm_factor
    if torch.isnan(result).any(): print("Warning: NaN in J_vector output.")
    # Also return the calculated taper_grid
    return result, n1, n2, p, taper_grid


def calculate_sample_periodogram_vectorized(J_vector_tensor):
    """Calculates sample periodogram I_n = J J^H (pxp matrix for each spatial freq)."""
    if torch.isnan(J_vector_tensor).any() or torch.isinf(J_vector_tensor).any():
        print("Warning: NaN/Inf detected in J_vector_tensor input.")
        n1, n2, p = J_vector_tensor.shape
        return torch.full((n1, n2, p, p), float('nan'), dtype=torch.complex64, device=J_vector_tensor.device)

    J_col = J_vector_tensor.unsqueeze(-1)
    J_row_conj = J_vector_tensor.unsqueeze(-2).conj()
    result = J_col @ J_row_conj

    if torch.isnan(result).any(): print("Warning: NaN in periodogram matrix output.")
    return result


# =========================================================================
# 4. Likelihood Calculation (Adapted for Tapering with Autocorrelation)
# =========================================================================

def whittle_likelihood_loss_tapered(params, I_sample, n1, n2, p, taper_autocorr_grid, delta1, delta2):
    """
    ✅ Whittle Likelihood Loss using data tapering and exact taper autocorrelation c_gn.
    Models a single field (the spatially differenced one). Uses Exponential kernel.
    """
    device = I_sample.device
    params_tensor = params.to(device)

    if torch.isnan(params_tensor).any() or torch.isinf(params_tensor).any():
        print("Warning: NaN/Inf detected in input parameters to likelihood.")
        return torch.tensor(float('nan'), device=device)

    # Use expected_periodogram_fft_tapered which uses cn_bar_tapered
    I_expected = expected_periodogram_fft_tapered(
        params_tensor, n1, n2, p, taper_autocorr_grid, cov_spatial_difference, delta1, delta2
    )

    if torch.isnan(I_expected).any() or torch.isinf(I_expected).any():
        print("Warning: NaN/Inf returned from expected_periodogram calculation.")
        return torch.tensor(float('nan'), device=device)

    eye_matrix = torch.eye(p, dtype=torch.complex64, device=device)
    diag_vals = torch.abs(I_expected.diagonal(dim1=-2, dim2=-1))
    mean_diag_abs = diag_vals.mean().item() if diag_vals.numel() > 0 and not torch.isnan(diag_vals).all() else 1.0
    diag_load = max(mean_diag_abs * 1e-8, 1e-9)
    I_expected_stable = I_expected + eye_matrix * diag_load

    sign, logabsdet = torch.linalg.slogdet(I_expected_stable)
    if torch.any(sign.real <= 1e-9):
        print("Warning: Non-positive determinant encountered. Applying penalty.")
        log_det_term = torch.where(sign.real > 1e-9, logabsdet, torch.tensor(1e10, device=device))
    else:
        log_det_term = logabsdet

    if torch.isnan(I_sample).any() or torch.isinf(I_sample).any():
        print("Warning: NaN/Inf detected in I_sample input to likelihood.")
        return torch.tensor(float('nan'), device=device)

    try:
        solved_term = torch.linalg.solve(I_expected_stable, I_sample)
        trace_term = torch.einsum('...ii->...', solved_term).real
    except torch.linalg.LinAlgError as e:
        print(f"Warning: LinAlgError during solve: {e}. Applying high loss penalty.")
        return torch.tensor(float('inf'), device=device)

    if torch.isnan(trace_term).any() or torch.isinf(trace_term).any():
        print("Warning: NaN/Inf detected in trace_term. Returning NaN loss.")
        return torch.tensor(float('nan'), device=device)

    likelihood_terms = log_det_term + trace_term

    if torch.isnan(likelihood_terms).any():
        print("Warning: NaN detected in likelihood_terms before summation. Returning NaN loss.")
        return torch.tensor(float('nan'), device=device)

    total_sum = torch.sum(likelihood_terms)
    dc_term = likelihood_terms[0, 0] if n1 > 0 and n2 > 0 else torch.tensor(0.0, device=device)
    if torch.isnan(dc_term).any() or torch.isinf(dc_term).any():
        print("Warning: NaN/Inf detected in DC term. Setting to 0.")
        dc_term = torch.tensor(0.0, device=device)

    loss = total_sum - dc_term if (n1 > 1 or n2 > 1) else total_sum

    if torch.isnan(loss) or torch.isinf(loss):
         print("Warning: NaN/Inf detected in final loss. Returning Inf penalty.")
         return torch.tensor(float('inf'), device=device)

    return loss


# =========================================================================
# 5. Training Loop (CORRECTED version, adapted for tapering)
# =========================================================================
def run_full_tapered(params_list, optimizer, scheduler, I_sample, n1, n2, p, taper_autocorr_grid, epochs=600, device='cpu'):
    """Corrected training loop using parameter list and tapered likelihood."""
    best_loss = float('inf')
    params_list = [p.to(device) for p in params_list]
    best_params_state = [p.detach().clone() for p in params_list]
    epochs_completed = 0
    DELTA_LAT, DELTA_LON = 0.044, 0.063 # Grid spacing needed for cov func

    def get_printable_params(p_list):
        valid_tensors = [p for p in p_list if isinstance(p, torch.Tensor)]
        if not valid_tensors: return "Invalid params_list"
        p_cat = torch.cat([p.detach().clone().cpu() for p in valid_tensors])
        log_indices = [0, 1, 2, 6]
        if all(idx < len(p_cat) for idx in log_indices):
            log_vals = p_cat[log_indices]
            if not (torch.isnan(log_vals).any() or torch.isinf(log_vals).any()):
                 p_cat[log_indices] = torch.exp(log_vals)
            else:
                 p_cat[log_indices] = float('nan')
        return p_cat.numpy().round(4)

    I_sample_dev = I_sample.to(device)
    taper_autocorr_grid_dev = taper_autocorr_grid.to(device) # Ensure autocorr grid is on device

    for epoch in range(epochs):
        epochs_completed = epoch + 1
        optimizer.zero_grad()
        params_tensor = torch.cat(params_list) # Create tensor on device

        # Use the tapered likelihood function
        loss = whittle_likelihood_loss_tapered(
            params_tensor, I_sample_dev, n1, n2, p, taper_autocorr_grid_dev, DELTA_LAT, DELTA_LON
        )

        if torch.isnan(loss) or torch.isinf(loss):
            print(f"Loss became NaN or Inf at epoch {epoch+1}. Stopping.")
            if epoch == 0: best_params_state = None
            epochs_completed = epoch
            break

        loss.backward()

        nan_grad = False
        for param in params_list:
            if param.grad is not None and (torch.isnan(param.grad).any() or torch.isinf(param.grad).any()):
                nan_grad = True; break
        if nan_grad:
             print(f"Warning: NaN/Inf gradient at epoch {epoch+1}. Skipping step.")
             optimizer.zero_grad()
             continue

        all_params_on_device = params_list
        if all_params_on_device:
            torch.nn.utils.clip_grad_norm_(all_params_on_device, max_norm=1.0)

        optimizer.step()
        scheduler.step()

        current_loss_item = loss.item()
        if current_loss_item < best_loss:
            params_valid = not any(torch.isnan(p.data).any() or torch.isinf(p.data).any() for p in params_list)
            if params_valid:
                best_loss = current_loss_item
                best_params_state = [p.detach().clone() for p in params_list]
            else:
                print(f"Warning: NaN/Inf in params after step epoch {epoch+1}. Not saving.")

        current_lr = optimizer.param_groups[0]['lr'] if optimizer.param_groups else 0.0

        if epoch % 50 == 0 or epoch == epochs - 1:
            print(f'--- Epoch {epoch+1}/{epochs} (LR: {current_lr:.6f}) ---')
            print(f' Loss: {current_loss_item:.4f}')
            print(f' Parameters (Natural Scale): {get_printable_params(params_list)}')


    print("\n--- Training Complete ---")
    if best_params_state is None:
        print("Training failed to find a valid model state.")
        return None, epochs_completed

    final_params_log_scale = torch.cat([p.cpu() for p in best_params_state])
    final_params_natural_scale = final_params_log_scale.detach().clone()
    log_indices = [0, 1, 2, 6]
    if all(idx < len(final_params_natural_scale) for idx in log_indices):
        log_vals = final_params_natural_scale[log_indices]
        if not (torch.isnan(log_vals).any() or torch.isinf(log_vals).any()):
            final_params_natural_scale[log_indices] = torch.exp(log_vals)
        else:
            final_params_natural_scale[log_indices] = float('nan')

    final_params_rounded = [round(p.item(), 4) if not np.isnan(p.item()) else float('nan') for p in final_params_natural_scale]
    final_loss_rounded = round(best_loss, 3) if best_loss != float('inf') else float('inf')

    print(f'\nFINAL BEST STATE ACHIEVED (during training):')
    print(f'Best Loss: {final_loss_rounded}')
    print(f'Parameters Corresponding to Best Loss (Natural Scale): {final_params_rounded}')

    return final_params_rounded + [final_loss_rounded], epochs_completed


if __name__ == '__main__':
    start_time = time.time()

    # --- Configuration ---
    DAY_TO_RUN = 1
    TAPERING_FUNC = cgn_hamming # Use Hamming taper
    NUM_RUNS = 1
    EPOCHS = 700
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {DEVICE}")

    DELTA_LAT, DELTA_LON = 0.044, 0.063

    LAT_COL, LON_COL = 0, 1
    VAL_COL = 2 # Spatially differenced value
    TIME_COL = 3

    processed_df = spatially_filtered_days

    cur_df = processed_df[DAY_TO_RUN - 1]
    if cur_df.numel() == 0 or cur_df.shape[1] <= max(LAT_COL, LON_COL, VAL_COL, TIME_COL):
        print(f"Error: Data for Day {DAY_TO_RUN} is empty or invalid.")
        exit()

    unique_times = torch.unique(cur_df[:, TIME_COL])
    time_slices_list = [cur_df[cur_df[:, TIME_COL] == t_val] for t_val in unique_times]

    # --- 1. Pre-compute J-vector, Taper Grid, and Taper Autocorrelation ---
    print("Pre-computing J-vector (Hamming taper)...")
    J_vec, n1, n2, p, taper_grid = generate_Jvector_tapered( # Use tapered version
        time_slices_list,
        tapering_func=TAPERING_FUNC, # Pass Hamming
        lat_col=LAT_COL, lon_col=LON_COL, val_col=VAL_COL,
        device=DEVICE
    )

    if J_vec is None or J_vec.numel() == 0 or n1 == 0 or n2 == 0 or p == 0:
       print(f"Error: J-vector generation failed for Day {DAY_TO_RUN}.")
       exit()

    print("Pre-computing sample periodogram...")
    I_sample = calculate_sample_periodogram_vectorized(J_vec)

    print("Pre-computing Hamming taper autocorrelation...")
    taper_autocorr_grid = calculate_taper_autocorrelation_fft(taper_grid, n1, n2, DEVICE)

    if torch.isnan(I_sample).any() or torch.isinf(I_sample).any():
        print("Error: NaN/Inf in sample periodogram.")
        exit()
    if torch.isnan(taper_autocorr_grid).any() or torch.isinf(taper_autocorr_grid).any():
        print("Error: NaN/Inf in taper autocorrelation.")
        exit()

    print(f"Data grid: {n1}x{n2}, {p} time points. J-vector, Periodogram, Taper Autocorr on {DEVICE}.")

Using device: cpu
Pre-computing J-vector (Hamming taper)...
Pre-computing sample periodogram...
Pre-computing Hamming taper autocorrelation...
Data grid: 113x158, 1 time points. J-vector, Periodogram, Taper Autocorr on cpu.


# best output below 1

In [65]:
# =========================================================================
# 6. Main Execution Script (HAMMING Tapering, Spatially Differenced Data, EXP Kernel)
# =========================================================================
if __name__ == '__main__':
    start_time = time.time()

    # --- Configuration ---
    DAY_TO_RUN = 1
    TAPERING_FUNC = cgn_hamming # Use Hamming taper
    NUM_RUNS = 1
    EPOCHS = 700
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {DEVICE}")

    DELTA_LAT, DELTA_LON = 0.044, 0.063

    LAT_COL, LON_COL = 0, 1
    VAL_COL = 2 # Spatially differenced value
    TIME_COL = 3

    # --- Load Spatially Differenced Data ---
    try:
        with open("spatial_first_difference_data.pkl", 'rb') as f:
            processed_df = pickle.load(f)
        print(f"Loaded {len(processed_df)} days from spatial_first_difference_data.pkl.")
        processed_df = [
            torch.tensor(arr, dtype=torch.float32).cpu() if not isinstance(arr, torch.Tensor)
            else arr.cpu().to(torch.float32)
            for arr in processed_df
        ]
        if not processed_df: raise ValueError("'processed_df' is empty.")
    except FileNotFoundError:
        print("Error: `spatial_first_difference_data.pkl` not found.")
        exit()
    except Exception as e:
        print(f"Error loading or processing 'processed_df': {e}")
        exit()

    if DAY_TO_RUN > len(processed_df) or DAY_TO_RUN <= 0:
        print(f"Error: DAY_TO_RUN ({DAY_TO_RUN}) out of bounds.")
        exit()

    cur_df = processed_df[DAY_TO_RUN - 1]
    cur_df = cur_df[cur_df[:,3]==cur_df[:,3].min()]
    if cur_df.numel() == 0 or cur_df.shape[1] <= max(LAT_COL, LON_COL, VAL_COL, TIME_COL):
        print(f"Error: Data for Day {DAY_TO_RUN} is empty or invalid.")
        exit()

    unique_times = torch.unique(cur_df[:, TIME_COL])
    time_slices_list = [cur_df[cur_df[:, TIME_COL] == t_val] for t_val in unique_times]

    # --- 1. Pre-compute J-vector, Taper Grid, and Taper Autocorrelation ---
    print("Pre-computing J-vector (Hamming taper)...")
    J_vec, n1, n2, p, taper_grid = generate_Jvector_tapered( # Use tapered version
        time_slices_list,
        tapering_func=TAPERING_FUNC, # Pass Hamming
        lat_col=LAT_COL, lon_col=LON_COL, val_col=VAL_COL,
        device=DEVICE
    )

    if J_vec is None or J_vec.numel() == 0 or n1 == 0 or n2 == 0 or p == 0:
       print(f"Error: J-vector generation failed for Day {DAY_TO_RUN}.")
       exit()

    print("Pre-computing sample periodogram...")
    I_sample = calculate_sample_periodogram_vectorized(J_vec)

    print("Pre-computing Hamming taper autocorrelation...")
    taper_autocorr_grid = calculate_taper_autocorrelation_fft(taper_grid, n1, n2, DEVICE)

    if torch.isnan(I_sample).any() or torch.isinf(I_sample).any():
        print("Error: NaN/Inf in sample periodogram.")
        exit()
    if torch.isnan(taper_autocorr_grid).any() or torch.isinf(taper_autocorr_grid).any():
        print("Error: NaN/Inf in taper autocorrelation.")
        exit()

    print(f"Data grid: {n1}x{n2}, {p} time points. J-vector, Periodogram, Taper Autocorr on {DEVICE}.")

    # --- 2. Optimization Loop ---
    all_final_results = []
    all_final_losses = []

    for i in range(NUM_RUNS):
        print(f"\n{'='*30} Initialization Run {i+1}/{NUM_RUNS} {'='*30}")

        initial_params_values = [np.log(22), np.log(1),np.log(1.5), 0, 0, 0, np.log(4.769)]
        print(f"Starting with FIXED params (log-scale): {[round(p, 4) for p in initial_params_values]}")

        params_list = [
            Parameter(torch.tensor([val], dtype=torch.float32))
            for val in initial_params_values
        ]

        lr_slow, lr_fast = 0.005, 0.02
        slow_indices = [0, 1, 2, 6]
        fast_indices = [3, 4, 5]

        valid_slow_indices = [idx for idx in slow_indices if idx < len(params_list)]
        valid_fast_indices = [idx for idx in fast_indices if idx < len(params_list)]

        param_groups = [
            {'params': [params_list[idx] for idx in valid_slow_indices], 'lr': lr_slow, 'name': 'slow_group'},
            {'params': [params_list[idx] for idx in valid_fast_indices], 'lr': lr_fast, 'name': 'fast_group'}
        ]

        optimizer = torch.optim.Adam(param_groups)

        T_MAX = 200
        ETA_MIN = 1e-6
        scheduler = CosineAnnealingLR(optimizer, T_max=T_MAX, eta_min=ETA_MIN)

        print(f"Starting optimization run {i+1} on device {DEVICE} (Hamming taper, EXP kernel)...")
        # --- Use the TAPERED training loop ---
        final_results, epochs_run = run_full_tapered(
            params_list=params_list,
            optimizer=optimizer,
            scheduler=scheduler,
            I_sample=I_sample,
            n1=n1, n2=n2, p=p,
            taper_autocorr_grid=taper_autocorr_grid, # Pass the autocorrelation grid
            epochs=EPOCHS,
            device=DEVICE
        )

        if final_results:
            all_final_results.append(final_results)
            all_final_losses.append(final_results[-1])
        else:
            all_final_results.append(None)
            all_final_losses.append(float('inf'))

    print(f"\n\n{'='*25} Overall Result from Run {'='*25}")
    valid_losses = [l for l in all_final_losses if l is not None and l != float('inf')]

    if not valid_losses:
        print(f"The run failed or resulted in an invalid loss for Day {DAY_TO_RUN}.")
    else:
        best_loss = valid_losses[0]
        best_run_index = 0
        best_results = all_final_results[best_run_index]
        print(f"Run Loss: {best_results[-1]}")
        print(f"Final Parameters (Natural Scale): {best_results[:-1]}")

    end_time = time.time()
    print(f"\nTotal execution time: {end_time - start_time:.2f} seconds")

Using device: cpu
Loaded 31 days from spatial_first_difference_data.pkl.
Pre-computing J-vector (Hamming taper)...
Pre-computing sample periodogram...
Pre-computing Hamming taper autocorrelation...
Data grid: 113x158, 1 time points. J-vector, Periodogram, Taper Autocorr on cpu.

Starting with FIXED params (log-scale): [3.091, 0.0, 0.4055, 0, 0, 0, 1.5621]
Starting optimization run 1 on device cpu (Hamming taper, EXP kernel)...
--- Epoch 1/700 (LR: 0.005000) ---
 Loss: 7640.4780
 Parameters (Natural Scale): [21.8903  1.005   1.5075  0.      0.      0.      4.7452]
--- Epoch 51/700 (LR: 0.004240) ---
 Loss: 6121.1533
 Parameters (Natural Scale): [17.3807  1.266   1.8986  0.      0.      0.      3.7425]
--- Epoch 101/700 (LR: 0.002461) ---
 Loss: 5632.1089
 Parameters (Natural Scale): [14.9045  1.477   2.2124  0.      0.      0.      3.1535]
--- Epoch 151/700 (LR: 0.000706) ---
 Loss: 5595.5435
 Parameters (Natural Scale): [14.1552  1.5608  2.2843  0.      0.      0.      2.9785]
--- Epoc

In [None]:

# =========================================================================
# 6. Main Execution Script (HAMMING Tapering, Spatially Differenced Data, EXP Kernel)
# =========================================================================
if __name__ == '__main__':
    start_time = time.time()

    # --- Configuration ---
    DAY_TO_RUN = 1
    TAPERING_FUNC = cgn_hamming # Use Hamming taper
    NUM_RUNS = 1
    EPOCHS = 700
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {DEVICE}")

    DELTA_LAT, DELTA_LON = 0.044, 0.063

    LAT_COL, LON_COL = 0, 1
    VAL_COL = 2 # Spatially differenced value
    TIME_COL = 3

    # --- Load Spatially Differenced Data ---
    processed_df = spatially_filtered_days

    cur_df = processed_df[DAY_TO_RUN - 1]
    cur_df = cur_df[cur_df[:,3]==cur_df[:,3].min()]
    if cur_df.numel() == 0 or cur_df.shape[1] <= max(LAT_COL, LON_COL, VAL_COL, TIME_COL):
        print(f"Error: Data for Day {DAY_TO_RUN} is empty or invalid.")
        exit()

    unique_times = torch.unique(cur_df[:, TIME_COL])
    time_slices_list = [cur_df[cur_df[:, TIME_COL] == t_val] for t_val in unique_times]

    # --- 1. Pre-compute J-vector, Taper Grid, and Taper Autocorrelation ---
    print("Pre-computing J-vector (Hamming taper)...")
    J_vec, n1, n2, p, taper_grid = generate_Jvector_tapered( # Use tapered version
        time_slices_list,
        tapering_func=TAPERING_FUNC, # Pass Hamming
        lat_col=LAT_COL, lon_col=LON_COL, val_col=VAL_COL,
        device=DEVICE
    )

    if J_vec is None or J_vec.numel() == 0 or n1 == 0 or n2 == 0 or p == 0:
       print(f"Error: J-vector generation failed for Day {DAY_TO_RUN}.")
       exit()

    print("Pre-computing sample periodogram...")
    I_sample = calculate_sample_periodogram_vectorized(J_vec)

    print("Pre-computing Hamming taper autocorrelation...")
    taper_autocorr_grid = calculate_taper_autocorrelation_fft(taper_grid, n1, n2, DEVICE)

    if torch.isnan(I_sample).any() or torch.isinf(I_sample).any():
        print("Error: NaN/Inf in sample periodogram.")
        exit()
    if torch.isnan(taper_autocorr_grid).any() or torch.isinf(taper_autocorr_grid).any():
        print("Error: NaN/Inf in taper autocorrelation.")
        exit()

    print(f"Data grid: {n1}x{n2}, {p} time points. J-vector, Periodogram, Taper Autocorr on {DEVICE}.")

    # --- 2. Optimization Loop ---
    all_final_results = []
    all_final_losses = []

    for i in range(NUM_RUNS):
        print(f"\n{'='*30} Initialization Run {i+1}/{NUM_RUNS} {'='*30}")

        #initial_params_values = [np.log(21.303), np.log(1.007),np.log(1.563), 0, 0, 0, np.log(3.890)]
        initial_params_values = [np.log(21), np.log(1),np.log(1.5), 0, 0, 0, np.log(4.769)]   
  
        print(f"Starting with FIXED params (log-scale): {[round(p, 4) for p in initial_params_values]}")

        params_list = [
            Parameter(torch.tensor([val], dtype=torch.float32))
            for val in initial_params_values
        ]

        lr_slow, lr_fast = 0.005, 0.02
        slow_indices = [0, 1, 2,6]
        fast_indices = [3, 4, 5]

        valid_slow_indices = [idx for idx in slow_indices if idx < len(params_list)]
        valid_fast_indices = [idx for idx in fast_indices if idx < len(params_list)]

        param_groups = [
            {'params': [params_list[idx] for idx in valid_slow_indices], 'lr': lr_slow, 'name': 'slow_group'},
            {'params': [params_list[idx] for idx in valid_fast_indices], 'lr': lr_fast, 'name': 'fast_group'}
        ]

        optimizer = torch.optim.Adam(param_groups)

        T_MAX = 200
        ETA_MIN = 1e-6
        scheduler = CosineAnnealingLR(optimizer, T_max=T_MAX, eta_min=ETA_MIN)

        print(f"Starting optimization run {i+1} on device {DEVICE} (Hamming taper, EXP kernel)...")
        # --- Use the TAPERED training loop ---
        final_results, epochs_run = run_full_tapered(
            params_list=params_list,
            optimizer=optimizer,
            scheduler=scheduler,
            I_sample=I_sample,
            n1=n1, n2=n2, p=p,
            taper_autocorr_grid=taper_autocorr_grid, # Pass the autocorrelation grid
            epochs=EPOCHS,
            device=DEVICE
        )

        if final_results:
            all_final_results.append(final_results)
            all_final_losses.append(final_results[-1])
        else:
            all_final_results.append(None)
            all_final_losses.append(float('inf'))

    print(f"\n\n{'='*25} Overall Result from Run {'='*25}")
    valid_losses = [l for l in all_final_losses if l is not None and l != float('inf')]

    if not valid_losses:
        print(f"The run failed or resulted in an invalid loss for Day {DAY_TO_RUN}.")
    else:
        best_loss = valid_losses[0]
        best_run_index = 0
        best_results = all_final_results[best_run_index]
        print(f"Run Loss: {best_results[-1]}")
        print(f"Final Parameters (Natural Scale): {best_results[:-1]}")

    end_time = time.time()
    print(f"\nTotal execution time: {end_time - start_time:.2f} seconds")

Using device: cpu
Pre-computing J-vector (Hamming taper)...
Pre-computing sample periodogram...
Pre-computing Hamming taper autocorrelation...
Data grid: 113x158, 1 time points. J-vector, Periodogram, Taper Autocorr on cpu.

Starting with FIXED params (log-scale): [3.091, 0.0, 0.4055, 0, 0, 0, 1.5621]
Starting optimization run 1 on device cpu (Hamming taper, EXP kernel)...
--- Epoch 1/700 (LR: 0.005000) ---
 Loss: 7640.4780
 Parameters (Natural Scale): [21.8903  1.005   1.5075  0.      0.      0.      4.7452]
--- Epoch 51/700 (LR: 0.004240) ---
 Loss: 6121.1533
 Parameters (Natural Scale): [17.3807  1.266   1.8986  0.      0.      0.      3.7425]
--- Epoch 101/700 (LR: 0.002461) ---
 Loss: 5632.1089
 Parameters (Natural Scale): [14.9045  1.477   2.2124  0.      0.      0.      3.1535]
--- Epoch 151/700 (LR: 0.000706) ---
 Loss: 5595.5435
 Parameters (Natural Scale): [14.1552  1.5608  2.2843  0.      0.      0.      2.9785]
--- Epoch 201/700 (LR: 0.000001) ---
 Loss: 5595.4917
 Paramete

### whittle likelihood function for this hamming + spatial differenced data

In [None]:
# =========================================================================
# 2. Likelihood Evaluation Function (Accepts Natural Scale Parameters)
# =========================================================================

def calculate_whittle_likelihood(
    param_estimates_natural_scale: list or np.ndarray, # <<< Renamed input
    I_sample: torch.Tensor,
    taper_autocorr_grid: torch.Tensor,
    n1: int,
    n2: int,
    p: int,
    delta_lat: float,
    delta_lon: float,
    device: torch.device = torch.device('cpu')
    ) -> float:
    """
    Calculates the Whittle likelihood (negative log-likelihood) for given estimates
    provided in NATURAL SCALE.

    Assumes Hamming tapering applied to spatially differenced data,
    using exact taper autocorrelation and an underlying Exponential kernel for X.

    Args:
        param_estimates_natural_scale: List/array of 7 params in NATURAL SCALE
                                       [sigmasq, r_lat, r_lon, a_lat, a_lon, beta, nugget].
        I_sample: Pre-computed sample periodogram tensor [n1, n2, p, p].
        taper_autocorr_grid: Pre-computed taper autocorrelation grid [2n1-1, 2n2-1].
        n1, n2, p: Grid and time dimensions.
        delta_lat, delta_lon: Grid spacings used for differencing.
        device: Torch device for calculation.

    Returns:
        Scalar negative log-likelihood value (float), or float('inf')/float('nan').
    """
    start_calc_time = time.time()

    # --- Input Validation and Tensor Conversion ---
    try:
        # Convert natural scale input to tensor
        if not isinstance(param_estimates_natural_scale, torch.Tensor):
            params_natural_tensor = torch.tensor(param_estimates_natural_scale, dtype=torch.float32)
        else:
            params_natural_tensor = param_estimates_natural_scale.float()
        params_natural_tensor = params_natural_tensor.to(device)

        if params_natural_tensor.shape != (7,):
             raise ValueError(f"Expected 7 params, got {params_natural_tensor.shape}")
        if torch.isnan(params_natural_tensor).any() or torch.isinf(params_natural_tensor).any():
             raise ValueError("NaN/Inf detected in input natural scale params.")

        # <<< NEW: Convert relevant parameters BACK to log-scale >>>
        params_log_scale_tensor = params_natural_tensor.clone()
        log_indices = [0, 1, 2, 6] # sigmasq, r_lat, r_lon, nugget
        # Check for non-positive values before taking log
        if torch.any(params_natural_tensor[log_indices] <= 0):
            raise ValueError("Parameters sigmasq, r_lat, r_lon, nugget must be positive in natural scale.")
        params_log_scale_tensor[log_indices] = torch.log(params_natural_tensor[log_indices])
        # <<< END NEW SECTION >>>

        # Ensure other inputs are tensors on the correct device
        I_sample = I_sample.to(device)
        taper_autocorr_grid = taper_autocorr_grid.to(device)
        if torch.isnan(I_sample).any() or torch.isinf(I_sample).any(): raise ValueError("NaN/Inf in I_sample.")
        if torch.isnan(taper_autocorr_grid).any() or torch.isinf(taper_autocorr_grid).any(): raise ValueError("NaN/Inf in taper_autocorr_grid.")

    except Exception as e:
        print(f"Error during input validation/conversion: {e}")
        return float('nan')

    # --- Likelihood Calculation ---
    with torch.no_grad():
        # Pass the LOG-SCALE tensor to the loss function
        likelihood_value = whittle_likelihood_loss_tapered(
            params=params_log_scale_tensor, # <<< Use log-scale tensor here
            I_sample=I_sample,
            n1=n1, n2=n2, p=p,
            taper_autocorr_grid=taper_autocorr_grid,
            delta1=delta_lat,
            delta2=delta_lon
        )

    end_calc_time = time.time()

    # --- Return Scalar Value ---
    if isinstance(likelihood_value, torch.Tensor):
        if torch.isinf(likelihood_value) and likelihood_value > 0: return float('inf')
        if torch.isnan(likelihood_value): return float('nan')
        return likelihood_value.item()
    else: return float(likelihood_value)



# =========================================================================
# 6. Main Execution Script (HAMMING Tapering, Spatially Differenced Data, EXP Kernel)
# =========================================================================
if __name__ == '__main__':
    start_time = time.time()

    # --- Configuration ---
    DAY_TO_RUN = 1
    TAPERING_FUNC = cgn_hamming # Use Hamming taper
    # NUM_RUNS = 1 # Not needed if not optimizing
    # EPOCHS = 700 # Not needed if not optimizing
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {DEVICE}")

    DELTA_LAT, DELTA_LON = 0.044, 0.063

    LAT_COL, LON_COL = 0, 1
    VAL_COL = 2 # Spatially differenced value
    TIME_COL = 3

    # --- Load Spatially Differenced Data ---
    try:
        with open("spatial_first_difference_data.pkl", 'rb') as f:
            processed_df = pickle.load(f)
        print(f"Loaded {len(processed_df)} days from spatial_first_difference_data.pkl.")
        processed_df = [
            torch.tensor(arr, dtype=torch.float32).cpu() if not isinstance(arr, torch.Tensor)
            else arr.cpu().to(torch.float32)
            for arr in processed_df
        ]
        if not processed_df: raise ValueError("'processed_df' is empty.")
    except FileNotFoundError: print("Error: `spatial_first_difference_data.pkl` not found."); exit()
    except Exception as e: print(f"Error loading 'processed_df': {e}"); exit()

    if DAY_TO_RUN > len(processed_df) or DAY_TO_RUN <= 0: print(f"Error: DAY_TO_RUN out of bounds."); exit()

    cur_df = processed_df[DAY_TO_RUN - 1]
    if cur_df.numel() == 0 or cur_df.shape[1] <= max(LAT_COL, LON_COL, VAL_COL, TIME_COL): print(f"Error: Data for Day {DAY_TO_RUN} empty/invalid."); exit()

    unique_times = torch.unique(cur_df[:, TIME_COL])
    time_slices_list = [cur_df[cur_df[:, TIME_COL] == t_val] for t_val in unique_times]

    # --- 1. Pre-compute J-vector, Taper Grid, and Taper Autocorrelation ---
    print(f"Pre-computing J-vector ({TAPERING_FUNC.__name__} taper)...")
    J_vec, n1, n2, p, taper_grid = generate_Jvector_tapered(
        time_slices_list,
        tapering_func=TAPERING_FUNC,
        lat_col=LAT_COL, lon_col=LON_COL, val_col=VAL_COL,
        device=DEVICE
    )

    if J_vec is None or J_vec.numel() == 0 or n1 == 0 or n2 == 0 or p == 0: print(f"Error: J-vector generation failed."); exit()
    if taper_grid is None: print("Error: Taper grid not generated."); exit()

    print("Pre-computing sample periodogram...")
    I_sample = calculate_sample_periodogram_vectorized(J_vec) # Now defined here

    print(f"Pre-computing {TAPERING_FUNC.__name__} taper autocorrelation...")
    taper_autocorr_grid = calculate_taper_autocorrelation_fft(taper_grid, n1, n2, DEVICE) # Now defined here

    if torch.isnan(I_sample).any() or torch.isinf(I_sample).any(): print("Error: NaN/Inf in sample periodogram."); exit()
    if torch.isnan(taper_autocorr_grid).any() or torch.isinf(taper_autocorr_grid).any(): print("Error: NaN/Inf in taper autocorrelation."); exit()

    print(f"Data grid: {n1}x{n2}, {p} time points. J-vector, Periodogram, Taper Autocorr on {DEVICE}.")

    # --- 2. Define Parameter Estimates to Evaluate (NATURAL SCALE) ---
    # Example 1: Your fixed initial parameters (converted to natural scale)
    param_estimates_initial_natural = [21.303, 1.307, 1.563, 0.022, -0.144, 0.198, 4.769] 

    # Example 2: Hypothetical optimized parameters (NATURAL SCALE)
    param_estimates_optimized_natural =  [31.2594, 0.665, 1.8981, 0.0, 0.1317, -0.0, 1.9785]

    # --- 3. Calculate Likelihood Directly ---
    print("\nCalculating likelihood for initial parameters (natural scale):")
    # Convert natural to log scale for the likelihood function
    params_initial_log = torch.tensor(param_estimates_initial_natural, dtype=torch.float32)
    log_indices = [0, 1, 2, 6]
    params_initial_log[log_indices] = torch.log(params_initial_log[log_indices])

    with torch.no_grad(): # Disable gradients for evaluation
        likelihood_initial = whittle_likelihood_loss_tapered(
            params=params_initial_log.to(DEVICE), # Pass log-scale tensor
            I_sample=I_sample.to(DEVICE),
            n1=n1, n2=n2, p=p,
            taper_autocorr_grid=taper_autocorr_grid.to(DEVICE),
            delta1=DELTA_LAT, delta2=DELTA_LON
        )
    print(f"-> Whittle Likelihood ('Vecchia Optimized' Params): {likelihood_initial.item():.3f}")


    print("\nCalculating likelihood for 'optimized' parameters (natural scale):")
    # Convert natural to log scale for the likelihood function
    params_optimized_log = torch.tensor(param_estimates_optimized_natural, dtype=torch.float32)
    # Check for positivity before log
    if torch.any(params_optimized_log[log_indices] <= 0):
         print("Error: Optimized params have non-positive values for log conversion.")
         likelihood_optimized = torch.tensor(float('nan')) # Or handle appropriately
    else:
         params_optimized_log[log_indices] = torch.log(params_optimized_log[log_indices])

         with torch.no_grad():
             likelihood_optimized = whittle_likelihood_loss_tapered(
                 params=params_optimized_log.to(DEVICE), # Pass log-scale tensor
                 I_sample=I_sample.to(DEVICE),
                 n1=n1, n2=n2, p=p,
                 taper_autocorr_grid=taper_autocorr_grid.to(DEVICE),
                 delta1=DELTA_LAT, delta2=DELTA_LON
             )
    print(f"-> Whittle Likelihood ('Whittle Optimized' Params): {likelihood_optimized.item():.3f}")


    # --- 4. Compare Results ---
    likelihood_initial_val = likelihood_initial.item()
    likelihood_optimized_val = likelihood_optimized.item()
    if not (np.isnan(likelihood_initial_val) or np.isnan(likelihood_optimized_val) or np.isinf(likelihood_initial_val) or np.isinf(likelihood_optimized_val)):
        if likelihood_optimized_val < likelihood_initial_val: print(f"\n'Optimized' parameters better by {likelihood_initial_val - likelihood_optimized_val:.3f}.")
        elif likelihood_initial_val < likelihood_optimized_val: print(f"\nInitial parameters better by {likelihood_optimized_val - likelihood_initial_val:.3f}.")
        else: print("\nBoth parameter sets yield the same likelihood.")
    else: print("\nCould not compare likelihoods due to NaN or Inf results.")

    end_time = time.time()
    print(f"\nTotal execution time: {end_time - start_time:.2f} seconds")

Using device: cpu
Loaded 31 days from spatial_first_difference_data.pkl.
Pre-computing J-vector (cgn_hamming taper)...
Pre-computing sample periodogram...
Pre-computing cgn_hamming taper autocorrelation...
Data grid: 113x158, 8 time points. J-vector, Periodogram, Taper Autocorr on cpu.

Calculating likelihood for initial parameters (natural scale):
-> Whittle Likelihood ('Vecchia Optimized' Params): 57998.410

Calculating likelihood for 'optimized' parameters (natural scale):
-> Whittle Likelihood ('Whittle Optimized' Params): 41623.910

'Optimized' parameters better by 16374.500.

Total execution time: 1.68 seconds


## bartlett tapering

In [22]:
import torch
import numpy as np
import matplotlib.pyplot as plt # Keep if plotting might be added later
import cmath
import pickle
import time # For timing
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.nn import Parameter
import torch.fft # Explicit import for fft functions

# =========================================================================
# 1. Tapering, Autocorrelation, and Covariance Functions
# =========================================================================

# --- Tapering Functions (Bartlett is used) ---
def cgn_hamming(u, n1, n2): # Kept for potential future use
    """Computes a 2D Hamming window."""
    u1, u2 = u
    device = u1.device if isinstance(u1, torch.Tensor) else (u2.device if isinstance(u2, torch.Tensor) else torch.device('cpu'))
    u1_tensor = u1.to(device) if isinstance(u1, torch.Tensor) else torch.tensor(u1, device=device, dtype=torch.float32)
    u2_tensor = u2.to(device) if isinstance(u2, torch.Tensor) else torch.tensor(u2, device=device, dtype=torch.float32)
    n1_eff = float(n1) if n1 > 0 else 1.0
    n2_eff = float(n2) if n2 > 0 else 1.0
    hamming1 = 0.54 + 0.46 * torch.cos(2.0 * torch.pi * u1_tensor / n1_eff)
    hamming2 = 0.54 + 0.46 * torch.cos(2.0 * torch.pi * u2_tensor / n2_eff)
    return hamming1 * hamming2

def cgn_2dbartlett(u, n1, n2):
    """Computes a 2D Bartlett window function."""
    u1, u2 = u
    device = u1.device if isinstance(u1, torch.Tensor) else (u2.device if isinstance(u2, torch.Tensor) else torch.device('cpu'))
    u1_tensor = u1.to(device) if isinstance(u1, torch.Tensor) else torch.tensor(u1, device=device, dtype=torch.float32)
    u2_tensor = u2.to(device) if isinstance(u2, torch.Tensor) else torch.tensor(u2, device=device, dtype=torch.float32)
    n1_eff = float(n1) if n1 > 0 else 1.0
    n2_eff = float(n2) if n2 > 0 else 1.0
    kernel = (1.0 - torch.abs(u1_tensor) / n1_eff) * (1.0 - torch.abs(u2_tensor) / n2_eff)
    return torch.clamp(kernel, min=0.0)

# --- Function to Calculate Taper Autocorrelation (Unchanged) ---
def calculate_taper_autocorrelation_fft(taper_grid, n1, n2, device):
    """
    Computes the normalized taper autocorrelation function c_gn(u) using FFT.
    """
    taper_grid = taper_grid.to(device)
    H = torch.sum(taper_grid**2)
    if H < 1e-12:
        print("Warning: Sum of squared taper weights (H) is near zero.")
        return torch.zeros((2*n1-1, 2*n2-1), device=device, dtype=taper_grid.dtype)
    N1, N2 = 2 * n1 - 1, 2 * n2 - 1
    taper_fft = torch.fft.fft2(taper_grid, s=(N1, N2))
    power_spectrum = torch.abs(taper_fft)**2
    autocorr_unnormalized = torch.fft.ifft2(power_spectrum).real
    autocorr_shifted = torch.fft.fftshift(autocorr_unnormalized)
    c_gn_grid = autocorr_shifted / (H + 1e-12)
    return c_gn_grid

# --- Covariance of Original Field X (EXPONENTIAL Kernel) (Unchanged) ---
def cov_x_exponential(u1, u2, t, params):
    """
    Computes spatio-temporal autocovariance of X using EXPONENTIAL kernel.
    Expects log-scale params [0,1,2,6]. Handles device internally.
    """
    device = params.device
    u1_dev = u1.to(device) if isinstance(u1, torch.Tensor) else torch.tensor(u1, device=device, dtype=torch.float32)
    u2_dev = u2.to(device) if isinstance(u2, torch.Tensor) else torch.tensor(u2, device=device, dtype=torch.float32)
    t_dev = t.to(device) if isinstance(t, torch.Tensor) else torch.tensor(t, device=device, dtype=torch.float32)
    log_params_indices = [0, 1, 2, 6]
    if torch.isnan(params[log_params_indices]).any() or torch.isinf(params[log_params_indices]).any():
         print("Warning: NaN/Inf in log-params before exp in cov_x_exponential")
         out_shape = torch.broadcast_shapes(u1_dev.shape, u2_dev.shape, t_dev.shape)
         return torch.full(out_shape, float('nan'), device=device, dtype=torch.float32)
    sigmasq, r_lat, r_lon, nugget = torch.exp(params[log_params_indices])
    a_lat, a_lon, beta = params[3], params[4], params[5]
    if r_lat < 1e-6 or r_lon < 1e-6:
        r_lat = torch.clamp(r_lat, min=1e-6); r_lon = torch.clamp(r_lon, min=1e-6)
    x1 = u1_dev / r_lat - a_lat * t_dev
    x2 = u2_dev / r_lon - a_lon * t_dev
    x3 = beta * t_dev
    distance_sq = x1**2 + x2**2 + x3**2
    epsilon = 1e-12
    distance_sq_clamped = torch.clamp(distance_sq, min=0.0)
    D = torch.sqrt(distance_sq_clamped + epsilon)
    cov_smooth = sigmasq * torch.exp(-D)
    is_zero_lag = (torch.abs(u1_dev) < 1e-9) & (torch.abs(u2_dev) < 1e-9) & (torch.abs(t_dev) < 1e-9)
    final_cov = torch.where(is_zero_lag, cov_smooth + nugget, cov_smooth)
    if torch.isnan(final_cov).any(): print("Warning: NaN detected in cov_x_exponential output.")
    return final_cov

# --- Covariance of Spatially Differenced Field Y (Unchanged structure) ---
def cov_spatial_difference(u1, u2, t, params, delta1, delta2):
    """
    Calculates covariance Cov(Y(s), Y(s+u)) for Y(s) = X(s+d1) + X(s+d2) - 2X(s)
    Based on the underlying EXPONENTIAL covariance cov_x_exponential.
    """
    weights = {(0, 0): -2.0, (1, 0): 1.0, (0, 1): 1.0}
    device = params.device
    out_shape = torch.broadcast_shapes(u1.shape if isinstance(u1, torch.Tensor) else (),
                                     u2.shape if isinstance(u2, torch.Tensor) else (),
                                     t.shape if isinstance(t, torch.Tensor) else ())
    cov = torch.zeros(out_shape, device=device, dtype=torch.float32)
    u1_dev = u1.to(device) if isinstance(u1, torch.Tensor) else torch.tensor(u1, device=device, dtype=torch.float32)
    u2_dev = u2.to(device) if isinstance(u2, torch.Tensor) else torch.tensor(u2, device=device, dtype=torch.float32)
    t_dev = t.to(device) if isinstance(t, torch.Tensor) else torch.tensor(t, device=device, dtype=torch.float32)
    for (a_idx, b_idx), w_ab in weights.items():
        offset_a1 = a_idx * delta1; offset_a2 = b_idx * delta2
        for (c_idx, d_idx), w_cd in weights.items():
            offset_c1 = c_idx * delta1; offset_c2 = d_idx * delta2
            lag_u1 = u1_dev + (offset_a1 - offset_c1)
            lag_u2 = u2_dev + (offset_a2 - offset_c2)
            term_cov = cov_x_exponential(lag_u1, lag_u2, t_dev, params) # Calls EXPONENTIAL version
            if torch.isnan(term_cov).any():
                 print(f"Warning: NaN in term_cov within cov_spatial_difference.")
                 return torch.full_like(cov, float('nan'))
            cov += w_ab * w_cd * term_cov
    if torch.isnan(cov).any(): print("Warning: NaN in final cov_spatial_difference output.")
    return cov

# --- cn_bar using Taper Autocorrelation (Unchanged structure) ---
def cn_bar_tapered(u1, u2, t, params, n1, n2, taper_autocorr_grid, delta1, delta2):
    """
    Computes c_X(u) * c_gn(u) using the pre-computed taper_autocorr_grid.
    """
    device = params.device
    u1_dev = u1.to(device) if isinstance(u1, torch.Tensor) else torch.tensor(u1, device=device, dtype=torch.float32)
    u2_dev = u2.to(device) if isinstance(u2, torch.Tensor) else torch.tensor(u2, device=device, dtype=torch.float32)
    t_dev = t.to(device) if isinstance(t, torch.Tensor) else torch.tensor(t, device=device, dtype=torch.float32)
    cov_X_value = cov_spatial_difference(u1_dev, u2_dev, t_dev, params, delta1, delta2)
    u1_idx = u1_dev.long(); u2_idx = u2_dev.long()
    # Indices into the centered autocorrelation grid [2n1-1, 2n2-1]
    idx1 = (n1 - 1 + u1_idx); idx2 = (n2 - 1 + u2_idx)
    # Clamp indices to ensure they are within the valid range
    idx1 = torch.clamp(idx1, 0, 2 * n1 - 2); idx2 = torch.clamp(idx2, 0, 2 * n2 - 2)
    # Ensure taper_autocorr_grid is on the correct device
    taper_autocorr_value = taper_autocorr_grid.to(device)[idx1, idx2]
    if torch.isnan(cov_X_value).any() or torch.isnan(taper_autocorr_value).any():
        print("Warning: NaN detected before multiplication in cn_bar_tapered.")
        out_shape = torch.broadcast_shapes(cov_X_value.shape, taper_autocorr_value.shape)
        return torch.full(out_shape, float('nan'), device=device, dtype=torch.float32)
    result = cov_X_value * taper_autocorr_value
    if torch.isnan(result).any(): print("Warning: NaN in cn_bar_tapered output.")
    return result

# --- Expected Periodogram (uses cn_bar_tapered) (Unchanged structure) ---
def expected_periodogram_fft_tapered(params, n1, n2, p, taper_autocorr_grid, cov_func, delta1, delta2):
    """
    Calculates the expected periodogram I(omega_s) using c_gn(u).
    """
    device = params.device if isinstance(params, torch.Tensor) else params[0].device
    if isinstance(params, list): params_tensor = torch.cat([p.to(device) for p in params])
    else: params_tensor = params.to(device)
    product_tensor = torch.zeros((n1, n2, p, p), dtype=torch.complex64, device=device)
    t_lags = torch.arange(p, dtype=torch.float32, device=device)
    u1_mesh_grid, u2_mesh_grid = torch.meshgrid(
        torch.arange(n1, dtype=torch.float32, device=device),
        torch.arange(n2, dtype=torch.float32, device=device), indexing='ij'
    )
    for q in range(p):
        for r in range(p):
            t_diff = t_lags[q] - t_lags[r]
            cov_times_autocorr = cn_bar_tapered(
                u1_mesh_grid, u2_mesh_grid, t_diff,
                params_tensor, n1, n2, taper_autocorr_grid, delta1, delta2
            )
            if torch.isnan(cov_times_autocorr).any():
                 print(f"Warning: NaN in cov_times_autocorr t_lag {t_diff.item():.2f}.")
                 product_tensor[:, :, q, r] = float('nan')
            else:
                 product_tensor[:, :, q, r] = cov_times_autocorr.to(torch.complex64)
    if torch.isnan(product_tensor).any():
        print("Warning: NaN in product_tensor before FFT.")
        nan_shape = (n1, n2, p, p); return torch.full(nan_shape, float('nan'), dtype=torch.complex64, device=device)
    fft_result = torch.fft.fft2(product_tensor, dim=(0, 1))
    normalization_factor = 1.0 / (4.0 * cmath.pi**2)
    result = fft_result * normalization_factor
    if torch.isnan(result).any(): print("Warning: NaN in expected_periodogram_fft_tapered output.")
    return result

# =========================================================================
# 2. Data Processing (Uses Tapering)
# =========================================================================
def generate_Jvector_tapered(tensor_list, tapering_func, lat_col, lon_col, val_col, device):
    """
    Generates J-vector using the specified taper, returns J-vector and taper grid.
    """
    p = len(tensor_list)
    taper_grid_out = None # Initialize
    if p == 0: return torch.empty(0, 0, 0, device=device), 0, 0, 0, taper_grid_out

    valid_tensors = [t for t in tensor_list if t.numel() > 0 and t.shape[1] > max(lat_col, lon_col, val_col)]
    if not valid_tensors:
         print("Warning: No valid tensors found in tensor_list."); return torch.empty(0, 0, 0, device=device), 0, 0, 0, taper_grid_out

    try:
        all_lats_cpu = torch.cat([t[:, lat_col] for t in valid_tensors])
        all_lons_cpu = torch.cat([t[:, lon_col] for t in valid_tensors])
    except IndexError: print(f"Error: Invalid column index."); return torch.empty(0, 0, 0, device=device), 0, 0, 0, taper_grid_out

    all_lats_cpu = all_lats_cpu[~torch.isnan(all_lats_cpu) & ~torch.isinf(all_lats_cpu)]
    all_lons_cpu = all_lons_cpu[~torch.isnan(all_lons_cpu) & ~torch.isinf(all_lons_cpu)]
    if all_lats_cpu.numel() == 0 or all_lons_cpu.numel() == 0:
        print("Warning: No valid coords."); return torch.empty(0, 0, 0, device=device), 0, 0, 0, taper_grid_out

    unique_lats_cpu, unique_lons_cpu = torch.unique(all_lats_cpu), torch.unique(all_lons_cpu)
    n1, n2 = len(unique_lats_cpu), len(unique_lons_cpu)
    if n1 == 0 or n2 == 0: print("Warning: Grid dims zero."); return torch.empty(0, 0, 0, device=device), 0, 0, 0, taper_grid_out

    lat_map = {lat.item(): i for i, lat in enumerate(unique_lats_cpu)}
    lon_map = {lon.item(): i for i, lon in enumerate(unique_lons_cpu)}

    u1_mesh_cpu, u2_mesh_cpu = torch.meshgrid(torch.arange(n1, dtype=torch.float32), torch.arange(n2, dtype=torch.float32), indexing='ij')
    taper_grid_out = tapering_func((u1_mesh_cpu, u2_mesh_cpu), n1, n2).to(device) # Taper on device

    fft_results = []
    for tensor in tensor_list:
        data_grid = torch.zeros((n1, n2), dtype=torch.float32, device=device)
        if tensor.numel() > 0 and tensor.shape[1] > max(lat_col, lon_col, val_col):
            for row in tensor:
                lat_item, lon_item = row[lat_col].item(), row[lon_col].item()
                if not (np.isnan(lat_item) or np.isnan(lon_item)):
                    i = lat_map.get(lat_item); j = lon_map.get(lon_item)
                    if i is not None and j is not None:
                        val = row[val_col]; val_num = val.item() if isinstance(val, torch.Tensor) else val
                        if not np.isnan(val_num) and not np.isinf(val_num): data_grid[i, j] = val_num

        data_grid_tapered = data_grid * taper_grid_out # Apply taper
        if torch.isnan(data_grid_tapered).any() or torch.isinf(data_grid_tapered).any():
             print("Warning: NaN/Inf in tapered data grid. Replacing zeros."); data_grid_tapered = torch.nan_to_num(data_grid_tapered)
        fft_results.append(torch.fft.fft2(data_grid_tapered))

    if not fft_results: print("Warning: No FFT results."); return torch.empty(0, 0, 0, device=device), n1, n2, 0, taper_grid_out

    J_vector_tensor = torch.stack(fft_results, dim=2).to(device)
    H = torch.sum(taper_grid_out**2)
    if H < 1e-12: print("Warning: H near zero."); norm_factor = torch.tensor(0.0, device=device)
    else: norm_factor = (torch.sqrt(1.0 / H) / (2.0 * cmath.pi)).to(device)

    result = J_vector_tensor * norm_factor
    if torch.isnan(result).any(): print("Warning: NaN in J_vector output.")
    return result, n1, n2, p, taper_grid_out

def calculate_sample_periodogram_vectorized(J_vector_tensor):
    """Calculates sample periodogram I_n = J J^H."""
    if torch.isnan(J_vector_tensor).any() or torch.isinf(J_vector_tensor).any():
        print("Warning: NaN/Inf in J_vector input."); n1,n2,p=J_vector_tensor.shape; return torch.full((n1,n2,p,p), float('nan'), dtype=torch.complex64, device=J_vector_tensor.device)
    J_col = J_vector_tensor.unsqueeze(-1); J_row_conj = J_vector_tensor.unsqueeze(-2).conj()
    result = J_col @ J_row_conj
    if torch.isnan(result).any(): print("Warning: NaN in periodogram matrix output.")
    return result

# =========================================================================
# 4. Likelihood Calculation (Adapted for Tapering with Autocorrelation)
# =========================================================================

def whittle_likelihood_loss_tapered(params, I_sample, n1, n2, p, taper_autocorr_grid, delta1, delta2):
    """ Whittle Likelihood using data tapering and exact taper autocorrelation c_gn."""
    device = I_sample.device
    params_tensor = params.to(device)
    if torch.isnan(params_tensor).any() or torch.isinf(params_tensor).any(): print("Warning: NaN/Inf in likelihood params."); return torch.tensor(float('nan'), device=device)
    I_expected = expected_periodogram_fft_tapered(
        params_tensor, n1, n2, p, taper_autocorr_grid, cov_spatial_difference, delta1, delta2
    )
    if torch.isnan(I_expected).any() or torch.isinf(I_expected).any(): print("Warning: NaN/Inf from expected periodogram."); return torch.tensor(float('nan'), device=device)
    eye_matrix = torch.eye(p, dtype=torch.complex64, device=device)
    diag_vals = torch.abs(I_expected.diagonal(dim1=-2, dim2=-1))
    mean_diag_abs = diag_vals.mean().item() if diag_vals.numel() > 0 and not torch.isnan(diag_vals).all() else 1.0
    diag_load = max(mean_diag_abs * 1e-8, 1e-9)
    I_expected_stable = I_expected + eye_matrix * diag_load
    sign, logabsdet = torch.linalg.slogdet(I_expected_stable)
    if torch.any(sign.real <= 1e-9): print("Warning: Non-positive determinant."); log_det_term = torch.where(sign.real > 1e-9, logabsdet, torch.tensor(1e10, device=device))
    else: log_det_term = logabsdet
    if torch.isnan(I_sample).any() or torch.isinf(I_sample).any(): print("Warning: NaN/Inf in I_sample."); return torch.tensor(float('nan'), device=device)
    try:
        solved_term = torch.linalg.solve(I_expected_stable, I_sample)
        trace_term = torch.einsum('...ii->...', solved_term).real
    except torch.linalg.LinAlgError as e: print(f"Warning: LinAlgError: {e}."); return torch.tensor(float('inf'), device=device)
    if torch.isnan(trace_term).any() or torch.isinf(trace_term).any(): print("Warning: NaN/Inf in trace_term."); return torch.tensor(float('nan'), device=device)
    likelihood_terms = log_det_term + trace_term
    if torch.isnan(likelihood_terms).any(): print("Warning: NaN in likelihood_terms."); return torch.tensor(float('nan'), device=device)
    total_sum = torch.sum(likelihood_terms)
    dc_term = likelihood_terms[0, 0] if n1 > 0 and n2 > 0 else torch.tensor(0.0, device=device)
    if torch.isnan(dc_term).any() or torch.isinf(dc_term).any(): print("Warning: NaN/Inf DC term."); dc_term = torch.tensor(0.0, device=device)
    loss = total_sum - dc_term if (n1 > 1 or n2 > 1) else total_sum
    if torch.isnan(loss) or torch.isinf(loss): print("Warning: NaN/Inf final loss."); return torch.tensor(float('inf'), device=device)
    return loss

# =========================================================================
# 5. Training Loop (CORRECTED version, adapted for tapering)
# =========================================================================
def run_full_tapered(params_list, optimizer, scheduler, I_sample, n1, n2, p, taper_autocorr_grid, epochs=600, device='cpu'):
    """Corrected training loop using parameter list and tapered likelihood."""
    best_loss = float('inf')
    params_list = [p.to(device) for p in params_list]
    best_params_state = [p.detach().clone() for p in params_list]
    epochs_completed = 0
    DELTA_LAT, DELTA_LON = 0.044, 0.063
    def get_printable_params(p_list):
        valid_tensors = [p for p in p_list if isinstance(p, torch.Tensor)];
        if not valid_tensors: return "Invalid params_list"
        p_cat = torch.cat([p.detach().clone().cpu() for p in valid_tensors])
        log_indices = [0, 1, 2, 6]
        if all(idx < len(p_cat) for idx in log_indices):
            log_vals = p_cat[log_indices]
            if not (torch.isnan(log_vals).any() or torch.isinf(log_vals).any()): p_cat[log_indices] = torch.exp(log_vals)
            else: p_cat[log_indices] = float('nan')
        return p_cat.numpy().round(4)
    I_sample_dev = I_sample.to(device)
    taper_autocorr_grid_dev = taper_autocorr_grid.to(device)
    for epoch in range(epochs):
        epochs_completed = epoch + 1
        optimizer.zero_grad()
        params_tensor = torch.cat(params_list)
        loss = whittle_likelihood_loss_tapered(
            params_tensor, I_sample_dev, n1, n2, p, taper_autocorr_grid_dev, DELTA_LAT, DELTA_LON
        )
        if torch.isnan(loss) or torch.isinf(loss):
            print(f"Loss NaN/Inf epoch {epoch+1}. Stop.");
            if epoch == 0: best_params_state = None;
            epochs_completed = epoch; break
        loss.backward()
        nan_grad = False
        for param in params_list:
            if param.grad is not None and (torch.isnan(param.grad).any() or torch.isinf(param.grad).any()): nan_grad = True; break
        if nan_grad: print(f"Warning: NaN/Inf grad epoch {epoch+1}. Skip step."); optimizer.zero_grad(); continue
        all_params_on_device = params_list
        if all_params_on_device: torch.nn.utils.clip_grad_norm_(all_params_on_device, max_norm=1.0)
        optimizer.step()
        scheduler.step()
        current_loss_item = loss.item()
        if current_loss_item < best_loss:
            params_valid = not any(torch.isnan(p.data).any() or torch.isinf(p.data).any() for p in params_list)
            if params_valid: best_loss = current_loss_item; best_params_state = [p.detach().clone() for p in params_list]
            else: print(f"Warning: NaN/Inf params epoch {epoch+1}. Not saving.")
        current_lr = optimizer.param_groups[0]['lr'] if optimizer.param_groups else 0.0
        if epoch % 50 == 0 or epoch == epochs - 1:
            print(f'--- Epoch {epoch+1}/{epochs} (LR: {current_lr:.6f}) ---'); print(f' Loss: {current_loss_item:.4f}')
            print(f' Parameters (Natural Scale): {get_printable_params(params_list)}')
    print("\n--- Training Complete ---")
    if best_params_state is None: print("Training failed."); return None, epochs_completed
    final_params_log_scale = torch.cat([p.cpu() for p in best_params_state])
    final_params_natural_scale = final_params_log_scale.detach().clone()
    log_indices = [0, 1, 2, 6]
    if all(idx < len(final_params_natural_scale) for idx in log_indices):
        log_vals = final_params_natural_scale[log_indices]
        if not (torch.isnan(log_vals).any() or torch.isinf(log_vals).any()): final_params_natural_scale[log_indices] = torch.exp(log_vals)
        else: final_params_natural_scale[log_indices] = float('nan')
    final_params_rounded = [round(p.item(), 4) if not np.isnan(p.item()) else float('nan') for p in final_params_natural_scale]
    final_loss_rounded = round(best_loss, 3) if best_loss != float('inf') else float('inf')
    print(f'\nFINAL BEST STATE ACHIEVED:'); print(f'Best Loss: {final_loss_rounded}')
    print(f'Parameters (Natural Scale): {final_params_rounded}')
    return final_params_rounded + [final_loss_rounded], epochs_completed

# =========================================================================
# 6. Main Execution Script (BARTLETT Tapering, Spatially Differenced Data, EXP Kernel)
# =========================================================================
if __name__ == '__main__':
    start_time = time.time()

    # --- Configuration ---
    DAY_TO_RUN = 1
    TAPERING_FUNC = cgn_2dbartlett # <<< Use Bartlett taper
    NUM_RUNS = 1
    EPOCHS = 700
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {DEVICE}")

    DELTA_LAT, DELTA_LON = 0.044, 0.063

    LAT_COL, LON_COL = 0, 1
    VAL_COL = 2 # Spatially differenced value
    TIME_COL = 3

    # --- Load Spatially Differenced Data ---
    try:
        with open("spatial_first_difference_data.pkl", 'rb') as f:
            processed_df = pickle.load(f)
        print(f"Loaded {len(processed_df)} days from spatial_first_difference_data.pkl.")
        processed_df = [
            torch.tensor(arr, dtype=torch.float32).cpu() if not isinstance(arr, torch.Tensor)
            else arr.cpu().to(torch.float32)
            for arr in processed_df
        ]
        if not processed_df: raise ValueError("'processed_df' is empty.")
    except FileNotFoundError: print("Error: `spatial_first_difference_data.pkl` not found."); exit()
    except Exception as e: print(f"Error loading 'processed_df': {e}"); exit()

    if DAY_TO_RUN > len(processed_df) or DAY_TO_RUN <= 0: print(f"Error: DAY_TO_RUN out of bounds."); exit()

    cur_df = processed_df[DAY_TO_RUN - 1]
    if cur_df.numel() == 0 or cur_df.shape[1] <= max(LAT_COL, LON_COL, VAL_COL, TIME_COL): print(f"Error: Data for Day {DAY_TO_RUN} empty/invalid."); exit()

    unique_times = torch.unique(cur_df[:, TIME_COL])
    time_slices_list = [cur_df[cur_df[:, TIME_COL] == t_val] for t_val in unique_times]

    # --- 1. Pre-compute J-vector, Taper Grid, and Taper Autocorrelation ---
    print("Pre-computing J-vector (Bartlett taper)...")
    J_vec, n1, n2, p, taper_grid = generate_Jvector_tapered( # Use tapered version
        time_slices_list,
        tapering_func=TAPERING_FUNC, # Pass Bartlett
        lat_col=LAT_COL, lon_col=LON_COL, val_col=VAL_COL,
        device=DEVICE
    )

    if J_vec is None or J_vec.numel() == 0 or n1 == 0 or n2 == 0 or p == 0: print(f"Error: J-vector generation failed."); exit()

    print("Pre-computing sample periodogram...")
    I_sample = calculate_sample_periodogram_vectorized(J_vec)

    print("Pre-computing Bartlett taper autocorrelation...")
    taper_autocorr_grid = calculate_taper_autocorrelation_fft(taper_grid, n1, n2, DEVICE)

    if torch.isnan(I_sample).any() or torch.isinf(I_sample).any(): print("Error: NaN/Inf in sample periodogram."); exit()
    if torch.isnan(taper_autocorr_grid).any() or torch.isinf(taper_autocorr_grid).any(): print("Error: NaN/Inf in taper autocorrelation."); exit()

    print(f"Data grid: {n1}x{n2}, {p} time points. J-vector, Periodogram, Taper Autocorr on {DEVICE}.")

    # --- 2. Optimization Loop ---
    all_final_results = []
    all_final_losses = []

    for i in range(NUM_RUNS):
        print(f"\n{'='*30} Initialization Run {i+1}/{NUM_RUNS} {'='*30}")

        initial_params_values = [np.log(25), np.log(1),np.log(1.5), -0.05, -0.08, 0.05, np.log(2)]
        print(f"Starting with FIXED params (log-scale): {[round(p, 4) for p in initial_params_values]}")
        params_list = [Parameter(torch.tensor([val], dtype=torch.float32)) for val in initial_params_values]

        lr_slow, lr_fast = 0.005, 0.02
        slow_indices = [0, 1, 2, 6]; fast_indices = [3, 4, 5]
        valid_slow_indices = [idx for idx in slow_indices if idx < len(params_list)]
        valid_fast_indices = [idx for idx in fast_indices if idx < len(params_list)]
        param_groups = [
            {'params': [params_list[idx] for idx in valid_slow_indices], 'lr': lr_slow, 'name': 'slow_group'},
            {'params': [params_list[idx] for idx in valid_fast_indices], 'lr': lr_fast, 'name': 'fast_group'}
        ]
        optimizer = torch.optim.Adam(param_groups)
        T_MAX = 200; ETA_MIN = 1e-6
        scheduler = CosineAnnealingLR(optimizer, T_max=T_MAX, eta_min=ETA_MIN)

        print(f"Starting optimization run {i+1} on device {DEVICE} (Bartlett taper, EXP kernel)...")
        # --- Use the TAPERED training loop ---
        final_results, epochs_run = run_full_tapered(
            params_list=params_list,
            optimizer=optimizer,
            scheduler=scheduler,
            I_sample=I_sample,
            n1=n1, n2=n2, p=p,
            taper_autocorr_grid=taper_autocorr_grid, # Pass the autocorrelation grid
            epochs=EPOCHS,
            device=DEVICE
        )

        if final_results: all_final_results.append(final_results); all_final_losses.append(final_results[-1])
        else: all_final_results.append(None); all_final_losses.append(float('inf'))

    print(f"\n\n{'='*25} Overall Result from Run {'='*25}")
    valid_losses = [l for l in all_final_losses if l is not None and l != float('inf')]
    if not valid_losses: print(f"Run failed for Day {DAY_TO_RUN}.")
    else:
        best_loss = valid_losses[0]; best_run_index = 0
        best_results = all_final_results[best_run_index]
        print(f"Run Loss: {best_results[-1]}")
        print(f"Final Parameters (Natural Scale): {best_results[:-1]}")

    end_time = time.time()
    print(f"\nTotal execution time: {end_time - start_time:.2f} seconds")

Using device: cpu
Loaded 31 days from spatial_first_difference_data.pkl.
Pre-computing J-vector (Bartlett taper)...
Pre-computing sample periodogram...
Pre-computing Bartlett taper autocorrelation...
Data grid: 113x158, 8 time points. J-vector, Periodogram, Taper Autocorr on cpu.

Starting with FIXED params (log-scale): [3.2189, 0.0, 0.4055, -0.05, -0.08, 0.05, 0.6931]
Starting optimization run 1 on device cpu (Bartlett taper, EXP kernel)...
--- Epoch 1/700 (LR: 0.005000) ---
 Loss: 8180.4858
 Parameters (Natural Scale): [24.8753  1.005   1.5075 -0.03   -0.06    0.07    1.99  ]
--- Epoch 51/700 (LR: 0.004240) ---
 Loss: 6522.7622
 Parameters (Natural Scale): [2.30309e+01 1.06020e+00 1.71420e+00 2.33300e-01 1.49600e-01 4.70000e-03
 1.81810e+00]
--- Epoch 101/700 (LR: 0.002461) ---
 Loss: 6332.5908
 Parameters (Natural Scale): [2.30089e+01 1.05530e+00 1.72840e+00 2.37400e-01 1.43200e-01 1.60000e-03
 1.80290e+00]
--- Epoch 151/700 (LR: 0.000706) ---
 Loss: 6360.0117
 Parameters (Natural S