In [1]:
import xarray as xr
print('Xarray version', xr.__version__)

Xarray version 2025.7.1


In [2]:
import numpy as np
print('Numpy version', np.__version__)

Numpy version 2.2.6


In [3]:
from perlmutterpath import *  # Contains the data_dir and mesh_dir variables
NUM_FEATURES = 2              # C: Number of features per cell (ex., Freeboard, Ice Area)

In [4]:
from NC_FILE_PROCESSING.patchify_utils import *

# Available patchify functions
# Registry mapping a strategy key to the patchify function that implements it.
# The function names are expected to come from the star import of
# NC_FILE_PROCESSING.patchify_utils above. PATCHIFY_TO_USE (set in the
# configuration cell) is used as a key into this dict.
PATCHIFY_FUNCTIONS = {
    "agglomerative": compute_agglomerative_patches,
    "breadth_first_bfs_basic": build_patches_from_seeds_bfs_basic,
    "breadth_first_improved_padded": build_patches_from_seeds_improved_padded,
    "dbscan": get_clusters_dbscan,
    "kmeans": cluster_patches_kmeans,
    "knn_basic": compute_knn_patches,
    "knn_disjoint": compute_disjoint_knn_patches,
    "latlon_spillover": patchify_by_latlon_spillover,
    "latitude_neighbors": patchify_by_latitude,
    "latitude_simple": patchify_by_latitude_simple,
    "latitude_spillover_redo": patchify_with_spillover,
    "lon_spilldown": patchify_by_lon_spilldown,
    "rows": get_rows_of_patches,
    "staggered_polar_descent": patchify_staggered_polar_descent,
}

# Short tag for each strategy key. The tag is appended to model_version, so it
# ends up in generated file names; changing a tag renames future artifacts.
# Keep the keys in sync with PATCHIFY_FUNCTIONS.
PATCHIFY_ABBREVIATIONS = {
    "agglomerative": "AGG",
    "breadth_first_bfs_basic": "BFSB",
    "breadth_first_improved_padded": "BPIP",  # NOTE(review): possibly meant "BFIP" (Breadth-First Improved Padded) - confirm before renaming
    "dbscan": "DBSCAN",
    "kmeans": "KM",
    "knn_basic": "KNN",
    "knn_disjoint": "DKNN",
    "latlon_spillover": "LLSO",
    "latitude_neighbors": "LAT",
    "latitude_simple": "LSIM",
    "latitude_spillover_redo": "PSO", # Uses PSO (Patchify SpillOver)
    "lon_spilldown": "LSD",
    "rows": "ROW",
    "staggered_polar_descent": "SPD",
}


# Variables for the Model

Check over these CAREFULLY!

Note: training on the login node (even with the much smaller trial dataset) risks failing with `OutOfMemoryError: CUDA out of memory.`

In [5]:
# --- Time-Related Variables:
CONTEXT_LENGTH = 7            # T: Number of historical time steps used for input
FORECAST_HORIZON = 3          # Fh: Number of future time steps to predict (ex. 1 day for next time step)

# --- Model Hyperparameters
D_MODEL = 128                 # d_model: Dimension of the transformer's internal representations (embedding dimension)
N_HEAD = 8                    # nhead: Number of attention heads (NOTE(review): presumably D_MODEL must be divisible by N_HEAD - confirm against the model definition)
NUM_TRANSFORMER_LAYERS = 4    # num_layers: Number of TransformerEncoderLayers
BATCH_SIZE = 16               # B: encoded in model_version
NUM_EPOCHS = 10               # e: encoded in model_version

# --- Performance-Related Variables:
NUM_WORKERS = 64              # Presumably the DataLoader worker count - confirm at the DataLoader call; tuning ideas are in the Notes section below

# --- Feature-Related Variables:
MAX_FREEBOARD_FOR_NORMALIZATION = 1    # Only works when you set MAX_FREEBOARD_ON too. #TODO - INCORPORATE THIS LATER

# --- Space-Related Variables:
LATITUDE_THRESHOLD = 40          # Degrees; cells with latitude >= this value are kept. Determines number of cells and patches (could use -90 to use the entire dataset).
CELLS_PER_PATCH = 256            # L: Number of cells within each patch

PATCHIFY_TO_USE = "rows"     # Key into PATCHIFY_FUNCTIONS / PATCHIFY_ABBREVIATIONS. Change this to use other patching techniques

# --- Run Settings:
TRIAL_RUN =              True   # SET THIS TO USE THE PRACTICE SET (MUCH FASTER AND SMALLER, for debugging)
PLOT_DATA_DISTRIBUTION = True   # SET THIS TO PLOT THE OUTLIERS (Results are independent of patchify used)
NORMALIZE_ON =           True    # SET THIS TO USE NORMALIZATION ON FREEBOARD (Results are independent of patchify used)
TRAINING =               True    # SET THIS TO RUN THE TRAINING LOOP (Use on full dataset for results)
EVALUATING_ON =          True    # SET THIS TO RUN THE METRICS AT THE BOTTOM (Use on full dataset for results)
MAX_FREEBOARD_ON =       False   # Use this if you want to normalize with a pre-defined maximum for outlier handling # TODO - INCORPORATE THIS LATER
MAP_WITH_CARTOPY_ON =    True   # Make sure the Cartopy library is included in the kernel

## Other Variables Dependent on Those Above ^

In [6]:
# Read the cell-centre coordinates from the mesh file. The context manager
# releases the file handle even if one of the variable lookups raises.
with xr.open_dataset(mesh_dir) as mesh:
    latCell = np.degrees(mesh["latCell"].values)   # converted to degrees to match LATITUDE_THRESHOLD
    lonCell = np.degrees(mesh["lonCell"].values)
print("Total nCells:       ", len(latCell))

# Keep only the cells at or north of the latitude threshold.
mask = latCell >= LATITUDE_THRESHOLD
masked_ncells_size = np.count_nonzero(mask)
print("Mask size:          ", masked_ncells_size)

# Integer division: any remainder cells do not form a full patch here.
NUM_PATCHES = masked_ncells_size // CELLS_PER_PATCH    # P: Approximate number of spatial patches to expect

print("cells_per_patch:    ", CELLS_PER_PATCH)
print("n_patches:          ", NUM_PATCHES)

# Fail early on a configuration that cannot produce a single full patch,
# instead of letting a patchify function fail later with an obscure error.
if NUM_PATCHES == 0:
    raise ValueError(
        f"LATITUDE_THRESHOLD={LATITUDE_THRESHOLD} keeps only {masked_ncells_size} cells, "
        f"fewer than CELLS_PER_PATCH={CELLS_PER_PATCH}; no patch can be built."
    )

# The input dimension for the patch embedding linear layer.
# Each patch at a given time step has NUM_FEATURES * CELLS_PER_PATCH features.
# This is the 'D' dimension used in the Transformer's input tensor (B, T, P, D).
PATCH_EMBEDDING_INPUT_DIM = NUM_FEATURES * CELLS_PER_PATCH # 2 * 256 = 512

DEFAULT_PATCHIFY_METHOD_FUNC = PATCHIFY_FUNCTIONS[PATCHIFY_TO_USE]

# --- Common Parameters for all functions ---
COMMON_PARAMS = {
    "latCell": latCell,
    "lonCell": lonCell,
    "cells_per_patch": CELLS_PER_PATCH,
    "num_patches": NUM_PATCHES,
    "latitude_threshold": LATITUDE_THRESHOLD,
    "seed": 42
}

# Cell adjacency table, read from a file in the current working directory.
# It is only consumed by the breadth-first strategies below.
cellsOnCell = np.load('cellsOnCell.npy')

# --- Function-specific Parameters (if any) ---
# Keyed by the same strategy keys as PATCHIFY_FUNCTIONS; an empty dict means
# the strategy takes only COMMON_PARAMS.
SPECIFIC_PARAMS = {
    "latitude_spillover_redo": {"step_deg": 5, "max_lat": 90},
    "latitude_simple": {"step_deg": 5, "max_lat": 90},
    "latitude_neighbors": {"step_deg": 5, "max_lat": 90},
    "breadth_first_improved_padded": {"cellsOnCell": cellsOnCell, "pad_to_exact_size": True},
    "breadth_first_bfs_basic": {"cellsOnCell": cellsOnCell},
    "agglomerative": {"n_neighbors": 5},
    "kmeans": {},
    "dbscan": {},
    "rows": {},
    "knn_basic": {},
    "knn_disjoint": {},
    "latlon_spillover": {},
    "lon_spilldown": {},
    "staggered_polar_descent": {},
}

Total nCells:        465044
Mask size:           53973
cells_per_patch:     256
n_patches:           210


In [7]:
# Short tags that encode the run configuration in model_version.
model_mode = "tr" if TRIAL_RUN else "fd"    # tr: selected when TRIAL_RUN is set, fd: full dataset
norm = "nT" if NORMALIZE_ON else "nF"        # nT / nF: NORMALIZE_ON true / false

# Resolve the short tag for the selected patchify strategy. The sentinel
# "UNKNWN" marks a key that has no abbreviation registered.
patching_strategy_abbr = PATCHIFY_ABBREVIATIONS.get(PATCHIFY_TO_USE, "UNKNWN")
if patching_strategy_abbr == "UNKNWN":
    raise ValueError("Check the name of the patchify function")

# Model name convention:
#   <mode>_<norm>_D<d_model>_B<batch size>_lt<latitude threshold>_P<patches>
#   _L<cells per patch>_T<context length>_Fh<forecast horizon>_e<epochs>_<patchify tag>
model_version = (
    f"{model_mode}_{norm}_D{D_MODEL}_B{BATCH_SIZE}_lt{LATITUDE_THRESHOLD}_P{NUM_PATCHES}_L{CELLS_PER_PATCH}"
    f"_T{CONTEXT_LENGTH}_Fh{FORECAST_HORIZON}_e{NUM_EPOCHS}_{patching_strategy_abbr}"
)

print(model_version)

tr_nT_D128_B16_lt40_P210_L256_T7_Fh3_e10_ROW


### Notes:

* TRY: NUM_WORKERS as 16 to 32 - profile to see if the GPU is still waiting on the CPU.
* TRY: NUM_WORKERS as 64 - the number of CPU cores available.
* TRY: NUM_WORKERS experiment with os.cpu_count() - 2
* TRY: NUM_WORKERS experiment with (logical_cores_per_gpu * num_gpus)

num_workers considerations:
* Too few workers: GPUs might become idle waiting for data.
* Too many workers: Can lead to increased CPU memory usage and context switching overhead.

# More Imports

In [8]:
import sys
print('System Version:', sys.version)

System Version: 3.13.5 | packaged by conda-forge | (main, Jun 16 2025, 08:27:50) [GCC 13.3.0]


In [9]:
#print(sys.executable) # for troubleshooting kernel issues
#print(sys.path)

In [10]:
import os
#print(os.getcwd())

In [11]:
import cupy as cp

In [12]:
import cudf
print('Cudf version', cudf.__version__)

Cudf version 25.06.00


In [13]:
import pandas as pd
print('Pandas version', pd.__version__)

Pandas version 2.2.3


In [14]:
import matplotlib
import matplotlib.pyplot as plt
print('Matplotlib version', matplotlib.__version__)

Matplotlib version 3.10.5


In [15]:
import torch
from torch.utils.data import Dataset, DataLoader

print('PyTorch version', torch.__version__)

ImportError: /pscratch/sd/b/brelypo/conda_env/rapids_pytorch_xarray_cartopy_papermill/lib/python3.13/site-packages/torch/../../../././libcublas.so.12: undefined symbol: cublasLtGetEnvironmentMode, version libcublasLt.so.12

# Hardware Details

In [None]:
!nvidia-smi

In [None]:
if TRAINING and not torch.cuda.is_available():
    raise ValueError("There is a problem with Torch not recognizing the GPUs")
else:
    print(torch.cuda.device_count()) # check the number of available CUDA devices
    # will print 1 on login node; 4 on GPU exclusive node; 1 on shared GPU node

In [None]:
print(torch.cuda.get_device_properties(0)) #provides information about a specific GPU
#total_memory=40326MB, multi_processor_count=108, L2_cache_size=40MB

In [None]:
import psutil
import platform

# Get general CPU information
processor_name = platform.processor()
print(f"Processor Name: {processor_name}")

# Get core counts
physical_cores = psutil.cpu_count(logical=False)
logical_cores = psutil.cpu_count(logical=True)
print(f"Physical Cores: {physical_cores}")
print(f"Logical Cores: {logical_cores}")

# Get CPU frequency
cpu_frequency = psutil.cpu_freq()
if cpu_frequency:
    print(f"Current CPU Frequency: {cpu_frequency.current:.2f} MHz")
    print(f"Min CPU Frequency: {cpu_frequency.min:.2f} MHz")
    print(f"Max CPU Frequency: {cpu_frequency.max:.2f} MHz")

# Get CPU utilization (percentage)
# The interval argument specifies the time period over which to measure CPU usage.
# Setting percpu=True gives individual core utilization.
cpu_percent_total = psutil.cpu_percent(interval=1)
print(f"Total CPU Usage: {cpu_percent_total}%")

# cpu_percent_per_core = psutil.cpu_percent(interval=1, percpu=True)
# print("CPU Usage per Core:")
# for i, percent in enumerate(cpu_percent_per_core):
#     print(f"  Core {i+1}: {percent}%")



# Example of one netCDF file with xarray

In [None]:
# ds = xr.open_dataset("train/v3.LR.DTESTM.pm-cpu-10yr.mpassi.hist.am.timeSeriesStatsDaily.0010-01-01.nc")

from perlmutterpath import * # has the path to the data on Perlmutter
ds = xr.open_dataset(full_data_dir_sample, decode_times=True)

In [None]:
ds.data_vars

In [None]:
day_counter = ds["timeDaily_counter"].shape[0]
print(day_counter)

In [None]:
print(ds["xtime_startDaily"])

In [None]:
print(ds["xtime_startDaily"].values)

In [None]:
ice_area = ds["timeDaily_avg_iceAreaCell"]
print(ds['timeDaily_avg_iceAreaCell'].attrs['long_name'])
print(f"Shape of ice area variable: {ice_area.shape}")

In [None]:
# Inspect the daily-average ice volume variable: its description and shape.
ice_volume = ds["timeDaily_avg_iceVolumeCell"]
print(ice_volume.attrs['long_name'])
print(f"Shape of ice volume variable: {ice_volume.shape}")

In [None]:
print(ds.coords)
print(ds.dims)

In [None]:
print(ds)
ds.close()

# Example of Mesh File

In [None]:
mesh = xr.open_dataset("NC_FILE_PROCESSING/mpassi.IcoswISC30E3r5.20231120.nc")

In [None]:
mesh.data_vars

In [None]:
print(mesh["latCell"].attrs['long_name'])
print(mesh["lonCell"].attrs['long_name'])

In [None]:
cellsOnCell = mesh["cellsOnCell"].values
print(mesh["cellsOnCell"].attrs['long_name'])
print(mesh["cellsOnCell"].values)

In [None]:
print(cellsOnCell.shape[1])

In [None]:
print(mesh["cellsOnCell"].max().values)
print(mesh["cellsOnCell"].min().values)

In [None]:
#cp.save('cellsOnCell.npy', cellsOnCell) 

In [None]:
#landIceMask = mesh["landIceMask"].values
#cp.save('landIceMask.npy', landIceMask)

In [None]:
print(mesh.coords)
print(mesh.dims)

In [None]:
print(mesh)

In [None]:
mesh.close()

# Pre-processing + Freeboard calculation functions

In [None]:
# Physical constants (adjust if you use different units)
D_WATER = 1023  # Density of seawater (kg/m^3)
D_ICE = 917     # Density of sea ice (kg/m^3)
D_SNOW = 330    # Density of snow (kg/m^3)

# Cells whose area is at or below this value get zero ice and snow height.
MIN_AREA = 1e-6

def compute_freeboard(area: cp.ndarray, 
                      ice_volume: cp.ndarray, 
                      snow_volume: cp.ndarray) -> cp.ndarray:
    """
    Derive freeboard from area, ice volume and snow volume.

    Heights are obtained as volume / area wherever ``area > MIN_AREA`` and are
    left at zero elsewhere. The result is then

        h_ice * (D_WATER - D_ICE) / D_WATER + h_snow * (D_WATER - D_SNOW) / D_WATER

    Parameters
    ----------
    area : cp.ndarray
        Sea ice concentration / area (same shape as ice_volume and snow_volume).
    ice_volume : cp.ndarray
        Sea ice volume per grid cell.
    snow_volume : cp.ndarray
        Snow volume per grid cell.

    Returns
    -------
    freeboard : cp.ndarray
        Freeboard value for each cell, same shape as the inputs.
    """
    # Only divide where the area is large enough to avoid blow-ups.
    has_ice = area > MIN_AREA

    def _height(volume):
        # Zero everywhere, volume / area on the valid cells.
        height = cp.zeros_like(volume)
        height[has_ice] = volume[has_ice] / area[has_ice]
        return height

    ice_term = _height(ice_volume) * (D_WATER - D_ICE) / D_WATER
    snow_term = _height(snow_volume) * (D_WATER - D_SNOW) / D_WATER
    return ice_term + snow_term


In [None]:
def check_freeboard_outliers(freeboard_data: cp.ndarray, times_array: cp.ndarray = None, post_norm=False):
    """
    Log summary statistics and IQR-based outlier counts for freeboard data.

    All findings are reported through ``logging``; nothing is returned.

    Parameters
    ----------
    freeboard_data : cp.ndarray
        CuPy array of freeboard values, expected with shape (Time, nCells).
    times_array : array-like, optional
        Timestamps for the time axis of ``freeboard_data``. It must have one
        entry per time step and be indexable with an integer; otherwise the
        dates of the extremes are skipped with a warning.
    post_norm : bool, default False
        Only selects the "pre-norm" / "post-norm" label used in the log lines.
    """

    logging.info("--- Checking for Freeboard Outliers ---")

    norm_state = "post-norm" if post_norm else "pre-norm"

    flat_freeboard = freeboard_data.flatten()
    data_shape = freeboard_data.shape # (Time, nCells)
    logging.info(f"Data Shape is {data_shape}")
    # Denominator for every percentage below: ALL samples (Time * nCells),
    # not just the number of cells.
    total_elements = flat_freeboard.size
    logging.info(f"Freeboard samples total {total_elements}")

    if total_elements == 0:
        # min/max/percentile are undefined on empty input.
        logging.warning("Freeboard data is empty, cannot check for outliers.")
        return

    # --- Absolute Extremes ---
    abs_min_val = flat_freeboard.min()
    abs_max_val = flat_freeboard.max()

    logging.info(f"Freeboard Absolute Minimum Value ({norm_state}): {abs_min_val:.4f}")
    logging.info(f"Freeboard Absolute Maximum Value ({norm_state}): {abs_max_val:.4f}")

    # The dates can only be resolved when there is exactly one timestamp per time step.
    if times_array is not None and data_shape[0] == len(times_array):

        # Find index of absolute min/max in the flattened array
        min_flat_idx = cp.argmin(flat_freeboard)
        max_flat_idx = cp.argmax(flat_freeboard)

        # Unravel the flat index back to (time_idx, cell_idx)
        min_time_idx_cupy, min_cell_idx_cupy = cp.unravel_index(min_flat_idx, data_shape)
        max_time_idx_cupy, max_cell_idx_cupy = cp.unravel_index(max_flat_idx, data_shape)

        # Extract plain Python integers so times_array can be indexed with them.
        min_time_idx = min_time_idx_cupy.item()
        min_cell_idx = min_cell_idx_cupy.item()
        max_time_idx = max_time_idx_cupy.item()
        max_cell_idx = max_cell_idx_cupy.item()

        logging.info(f"Date of Absolute Minimum ({abs_min_val:.4f}): {pd.to_datetime(times_array[min_time_idx]).strftime('%Y-%m-%d %H:%M:%S')} (Cell index in masked data: {min_cell_idx})")
        logging.info(f"Date of Absolute Maximum ({abs_max_val:.4f}): {pd.to_datetime(times_array[max_time_idx]).strftime('%Y-%m-%d %H:%M:%S')} (Cell index in masked data: {max_cell_idx})")

    else:
        logging.warning("Cannot pinpoint dates of extremes: times_array not provided or shape mismatch.")

    # --- Zeros ----
    count_zero = cp.sum(flat_freeboard == 0).item()
    percent_zero = (count_zero / total_elements) * 100
    logging.info(f"Percentage of Freeboard values exactly 0: {percent_zero:.2f}% ({count_zero} points)")

    # --- IQR Outlier Detection ---
    Q1 = cp.percentile(flat_freeboard, 25)
    Q3 = cp.percentile(flat_freeboard, 75)
    IQR = Q3 - Q1

    lower_bound = Q1 - 1.5 * IQR
    upper_bound = Q3 + 1.5 * IQR

    outliers_low = flat_freeboard[flat_freeboard < lower_bound]
    outliers_high = flat_freeboard[flat_freeboard > upper_bound]

    num_outliers = len(outliers_low) + len(outliers_high)

    logging.info(f"{norm_state}")
    logging.info(f"Freeboard Q1: {Q1:.4f}")
    logging.info(f"Freeboard Q3: {Q3:.4f}")
    logging.info(f"Freeboard IQR: {IQR:.4f}")
    logging.info(f"Freeboard Lower Bound (Q1 - 1.5*IQR): {lower_bound:.4f}")
    logging.info(f"Freeboard Upper Bound (Q3 + 1.5*IQR): {upper_bound:.4f}")
    logging.info(f"Number of low outliers: {len(outliers_low)}")
    logging.info(f"Number of high outliers: {len(outliers_high)}")
    logging.info(f"Total outliers: {num_outliers} ({num_outliers / total_elements * 100:.2f}% of total elements)")

    if num_outliers > 0:
        logging.warning("Potential outliers detected in Freeboard data!")
        logging.info(f"Sample low outliers (first 10): {outliers_low[:10]}")
        logging.info(f"Sample high outliers (first 10): {outliers_high[:10]}")
        # [-10:] is the trailing ten values; [:-10] would be everything BUT those.
        logging.info(f"Sample high outliers (last 10): {outliers_high[-10:]}")
    else:
        logging.info("No significant outliers detected in Freeboard data based on IQR method.")

def plot_freeboard_distribution(freeboard_data: cp.ndarray, prefix: str = ""):
    """
    Draw a histogram and a boxplot of the freeboard values and save them as
    one PNG in the current working directory.

    Parameters
    ----------
    freeboard_data : cp.ndarray
        CuPy array of freeboard values (any shape; it is flattened and copied
        to host memory for plotting).
    prefix : str
        Label used in the log line, the figure title and the start of the
        output file name (``<prefix>_<timestamp>.png``).
    """
    logging.info(f"--- Plotting Freeboard Distribution ({prefix}) ---")

    values_gpu = freeboard_data.flatten()
    fig, (hist_ax, box_ax) = plt.subplots(1, 2, figsize=(15, 6))

    # matplotlib needs host memory, so pull the data off the GPU.
    values_host = values_gpu.get()

    # --- Histogram ---
    hist_ax.hist(values_host, bins=50, color='skyblue', edgecolor='black', alpha=0.7)
    hist_ax.set_title('Distribution of Freeboard (Histogram)')
    hist_ax.set_xlabel('Freeboard Value')
    hist_ax.set_ylabel('Frequency')
    hist_ax.grid(True, linestyle='--', alpha=0.6)
    hist_ax.set_xlim(0, 1.8)

    # --- Boxplot ---
    box_style = dict(
        boxprops=dict(facecolor='lightcoral'),
        medianprops=dict(color='black'),
        whiskerprops=dict(color='gray'),
        capprops=dict(color='gray'),
        flierprops=dict(marker='o', markersize=5, markerfacecolor='red', alpha=0.5),
    )
    box_ax.boxplot(values_host, vert=True, patch_artist=True, **box_style)
    box_ax.set_title('Distribution of Freeboard (Boxplot)')
    box_ax.set_ylabel('Freeboard Value')
    box_ax.set_ylim(0, 0.3)
    box_ax.set_xticks([])

    plt.suptitle(f'Freeboard Data Distribution and Outlier Visualization {prefix}', fontsize=16)
    # The rect leaves headroom for the suptitle.
    plt.tight_layout(rect=[0, 0.03, 1, 0.95])

    # Timestamped file name in the current working directory.
    out_dir = os.getcwd()
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    filename = os.path.join(out_dir, f"{prefix}_{timestamp}.png")

    plt.savefig(filename, dpi=300)  # high-resolution output
    plt.close(fig)  # release the figure's memory

    logging.info(f"Freeboard distribution plot saved to: {filename}")

In [None]:
def analyze_ice_area_imbalance(ice_area_data: cp.ndarray):
    """
    Log what share of the ice_area values is exactly 0, exactly 1, or strictly
    between the two, and warn about values outside [0, 1].

    Parameters
    ----------
    ice_area_data : cp.ndarray
        The CuPy array of ice_area values (can be multi-dimensional).
    """
    logging.info("--- Analyzing Ice Area Imbalance ---")

    flat_ice_area = ice_area_data.flatten()
    total_elements = flat_ice_area.size

    # Nothing to report (and nothing to divide by) on empty input.
    if total_elements == 0:
        logging.warning("Ice Area data is empty, cannot analyze imbalance.")
        return

    def _tally(selection):
        # Number of True entries as a plain Python scalar.
        return cp.sum(selection).item()

    def _share(count):
        return (count / total_elements) * 100

    count_zero = _tally(flat_ice_area == 0)
    count_one = _tally(flat_ice_area == 1)
    count_between = _tally((flat_ice_area > 0) & (flat_ice_area < 1))

    percent_zero = _share(count_zero)
    percent_one = _share(count_one)
    percent_between = _share(count_between)

    logging.info(f"Total Ice Area data points: {total_elements}")
    logging.info(f"Percentage of values == 0: {percent_zero:.2f}% ({count_zero} points)")
    logging.info(f"Percentage of values == 1: {percent_one:.2f}% ({count_one} points)")
    logging.info(f"Percentage of values between 0 and 1 (exclusive): {percent_between:.2f}% ({count_between} points)")

    # Sanity check: anything outside [0, 1] is reported in the log and on stdout.
    count_invalid = _tally((flat_ice_area < 0) | (flat_ice_area > 1))
    if count_invalid > 0:
        logging.warning(f"Found {count_invalid} ice_area values outside the [0, 1] range!")
        print(f"Found {count_invalid} ice_area values outside the [0, 1] range!")

# The original function with minimal changes.
def plot_ice_area_imbalance(ice_area_data: cp.ndarray, prefix: str = ""):
    """
    Save a bar chart showing how the ice_area values are distributed over six
    categories: exactly 0, four quarter-wide bins inside (0, 1), and exactly 1.

    The interior bins are half-open on the right -- (0, 0.25), [0.25, 0.5),
    [0.5, 0.75), [0.75, 1) -- so values that sit exactly on 0.25, 0.5 or 0.75
    are counted once instead of being dropped from every bin.

    Parameters
    ----------
    ice_area_data : cp.ndarray
        The CuPy array of ice_area values to plot (can be multi-dimensional).
    prefix : str
        Start of the output file name
        (``<prefix>_SIC_imbalance_<timestamp>.png`` in the current working directory).
    """
    logging.info("--- Plotting Ice Area Imbalance Chart ---")

    flat_ice_area = ice_area_data.flatten()
    total_elements = flat_ice_area.size # Use .size for the total count

    if total_elements == 0:
        logging.warning("Ice Area data is empty, cannot plot imbalance.")
        return

    # Use .item() to get the scalar counts for calculation
    count_zero = cp.sum(flat_ice_area == 0).item()
    count_00_to_25_percent = cp.sum((flat_ice_area > 0) & (flat_ice_area < 0.25)).item()
    count_25_to_50_percent = cp.sum((flat_ice_area >= 0.25) & (flat_ice_area < 0.5)).item()
    count_50_to_75_percent = cp.sum((flat_ice_area >= 0.5) & (flat_ice_area < 0.75)).item()
    count_75_to_99_percent = cp.sum((flat_ice_area >= 0.75) & (flat_ice_area < 1)).item()
    count_one = cp.sum(flat_ice_area == 1).item()

    categories = ['Exactly 0', '>0 - 0.25','0.25 - 0.5','0.5 - 0.75','0.75 - <1', 'Exactly 1']
    percentages = [
        (count_zero / total_elements) * 100,
        (count_00_to_25_percent / total_elements) * 100,
        (count_25_to_50_percent / total_elements) * 100,
        (count_50_to_75_percent / total_elements) * 100,
        (count_75_to_99_percent / total_elements) * 100,
        (count_one / total_elements) * 100,
    ]

    fig, ax = plt.subplots(figsize=(10, 7))

    # One colour per category, dark (no ice) to white (full cover).
    bar_colors = ['black', 'gray', 'silver', 'lightgrey', 'whitesmoke', 'white']
    bars = ax.bar(categories, percentages, color=bar_colors, edgecolor='black')

    ax.set_title('Distribution of Ice Area Values', fontsize=16)
    ax.set_xlabel('Value Category', fontsize=12)
    ax.set_ylabel('Percentage of Data (%)', fontsize=12)
    # Keep the usual 0-80 scale, but grow it when a bar (plus its label) would be clipped.
    ax.set_ylim(0, max(80, max(percentages) + 5))
    ax.grid(axis='y', linestyle='--', alpha=0.7)

    # Add percentage labels on top of the bars
    for bar in bars:
        height = bar.get_height()
        ax.annotate(f'{height:.2f}%',
                    xy=(bar.get_x() + bar.get_width() / 2, height),
                    xytext=(0, 3),  # 3 points vertical offset
                    textcoords="offset points",
                    ha='center', va='bottom', fontsize=10)

    plt.tight_layout()

    # Save to the current working directory
    current_directory = os.getcwd()
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    filename = os.path.join(current_directory, f"{prefix}_SIC_imbalance_{timestamp}.png")

    plt.savefig(filename, dpi=300) # dpi=300 for high-quality image
    plt.close(fig)
    logging.info(f"Ice Area imbalance chart saved to: {filename}")

In [None]:
def normalize_freeboard(freeboard, min_val=-0.2, max_val=1.2):
    """
    Min-max scale freeboard from [min_val, max_val] onto [0, 1].

    Values that fall outside the range after scaling are clipped to 0 or 1.
    """
    value_range = max_val - min_val
    scaled = (freeboard - min_val) / value_range
    return cp.clip(scaled, 0, 1)

# Custom Pytorch Dataset
Example from NERSC of using ERA5 Dataset:

https://github.com/NERSC/dl-at-scale-training/blob/main/utils/data_loader.py

# __ init __ - masks and loads the data into tensors

In [None]:
date_string = "2023-10-26 14:30:00"
new_time = cudf.to_datetime(date_string)

print(type(new_time))

In [None]:
import os
import time
from datetime import datetime
from datetime import timedelta

from torch.utils.data import Dataset
from typing import List, Union, Callable, Tuple, Dict, Any
from NC_FILE_PROCESSING.patchify_utils import *
from perlmutterpath import * # Contains the data_dir and mesh_dir variables

import logging

# Set level to logging.INFO to see the statements
logging.basicConfig(filename=f'Dataset_{model_version}.log', filemode='w', level=logging.INFO)

class DailyNetCDFDataset(Dataset):
    """
    PyTorch Dataset that concatenates a directory of month-wise NetCDF files
    along their 'Time' dimension and yields daily data *plus* its timestamp.

    Parameters
    ----------
    data_dir : str
        Directory containing NetCDF files
    transform : Callable | None
        Optional - transform applied to the data tensor *only*.
    latitude_threshold
        The minimum latitude to use for Arctic data
    context_length
        The number of days to fetch for input in the prediction step
    forecast_horizon
        The number of days to predict in the future
    plot_outliers_and_imbalance
        Optional - check outliers and imbalance on the variables Ice Area and Freeboard
    trial_run
        Optional - use the data in the trial directory instead of the full dataset
        Specify the name of the trial director in perlmutterpath.py
    num_patches
        How many patches to use for the patchify function
    cells_per_patch
        How many cells to have in each patch for patchify
    patchify_func : Callable
        The patchify function to use (ex., patchify_by_latlon_spillover).
    patchify_func_key : str
        The string key identifying the patchify function (e.g., "latlon_spillover")
        used to look up its specific parameters.

    """
    def __init__(
        self,
        data_dir: str = data_dir,
        mesh_dir: str = mesh_dir,
        transform: Callable = None,
        latitude_threshold: int = LATITUDE_THRESHOLD,
        context_length: int = CONTEXT_LENGTH,
        forecast_horizon: int = FORECAST_HORIZON,
        normalize_on: bool = NORMALIZE_ON,
        plot_outliers_and_imbalance: bool = PLOT_DATA_DISTRIBUTION, # set FALSE FOR FINAL
        trial_run: bool = TRIAL_RUN, # Use the trial data directory
        num_patches: int = NUM_PATCHES,
        cells_per_patch: int = CELLS_PER_PATCH,
        patchify_func: Callable = DEFAULT_PATCHIFY_METHOD_FUNC, # Default patchify function
        patchify_func_key: str = PATCHIFY_TO_USE, # Key to look up specific params
    ):

        """ __init__ needs to 

        Handle the raw data:
        1) Gather the sorted daily data from each netCDF file (1 file = 1 month of daily data)
            The netCDF files contain nCells worth of data per day for each feature (ice area, ice volume, etc.)
            nCells = 465044 with the IcoswISC30E3r5 mesh
        2) Load the mesh and initialize the cell mask
        3) Store a list of datetimes from each file 
        4) Extract raw data
        
        Perform pre-processing:
        5) Apply a mask to nCells to look just at regions in certain latitudes
            nCells >= 40 degrees is 53973 cells
            nCells >= 50 degrees is 35623 cells
        6) Derive Freeboard from ice area, snow volume, and ice volume
        7) Custom patchify and store patch_ids so the data loader can use them
        8) Optional: Plot the outliers and data imbalance for Ice Area and Freeboard
        9) Optional: Normalize the data (Ice area is already between 0 and 1; Freeboard is not) """

        start_time = time.time()
        self.data_dir = data_dir
        self.mesh_dir = mesh_dir
        self.transform = transform
        self.latitude_threshold = latitude_threshold
        self.context_length = context_length
        self.forecast_horizon = forecast_horizon
        self.normalize_on = normalize_on
        self.plot_outliers_and_imbalance = plot_outliers_and_imbalance
        self.trial_run = trial_run
        self.num_patches = num_patches
        self.cells_per_patch = cells_per_patch
        self.patchify_func = patchify_func # Store the specified patchify function
        self.patchify_func_key = patchify_func_key # Store the key for looking up specific params

        # --- 1. Gather files (sorted for deterministic order) ---------
        if self.trial_run:
            
            # USE THIS FOR PRACTICE (SMALLER CHUNK OF DATA)-
            self.file_paths = sorted(
                [
                    os.path.join(self.data_dir, f)
                    for f in os.listdir(self.data_dir)

                    # GET 4 YEAR SUBSET 2020 - 2024
                    if f.startswith("v3.LR.historical_0051.mpassi.hist.am.timeSeriesStatsDaily.202") and f.endswith(".nc")
                ]
            )

        else:
            # USE THE FULL DATASET 
            self.data_dir = data_dir
            self.file_paths = sorted(
                [
                    os.path.join(data_dir, f)
                    for f in os.listdir(data_dir)

                    # GET ALL - 1850 TO 2024
                    if f.startswith("v3.LR.historical_0051.mpassi.hist.am.timeSeriesStatsDaily.") and f.endswith(".nc")
                ]
            )
        
        logging.info(f"Found {len(self.file_paths)} NetCDF files:")
        if not self.file_paths:
            raise FileNotFoundError(f"No *.nc files found in {data_dir!r}")

        # --- 2. Load the mesh file. Latitudes and Longitudes are in radians. ---
        latCell, lonCell = load_mesh_radians(self.mesh_dir)
        latCell = cp.array(latCell)
        lonCell = cp.array(lonCell)
        latCell = cp.degrees(latCell)
        lonCell = cp.degrees(lonCell)
        
        # Initialize the cell mask
        self.cell_mask = latCell >= latitude_threshold        # CuPy boolean array
        masked_ncells_size = cp.count_nonzero(self.cell_mask)
        logging.info(f"Mask size: {masked_ncells_size}")

        self.full_to_masked = {
            int(full_idx): new_idx
            for new_idx, full_idx in enumerate(cp.where(self.cell_mask)[0])
        }

        # Also store reverse mapping: masked -> full for recovery of data later
        self.masked_to_full = {
            v: k for k, v in self.full_to_masked.items()
        }

        logging.info(f"=== Extracting raw data and times in a single loop === ")

        all_times_list = []
        ice_area_all_list = []
        ice_volume_all_list = []
        snow_volume_all_list = []
        
        for i, path in enumerate(self.file_paths):
            ds = xr.open_dataset(path)

            # --- 3. Store a list of datetimes from each file -> helps with retrieving 1 day's data later        
            # Decode byte strings and fix the format
            xtime_strs = ds["xtime_startDaily"].str.decode("utf-8").values
            xtime_strs = [s.replace("_", " ") for s in xtime_strs]  # "0010-01-01_00:00:00" → "0010-01-01 00:00:00"
        
            # Convert to datetime.datetime objects
            times = [datetime.strptime(s, "%Y-%m-%d %H:%M:%S") for s in xtime_strs]
            all_times_list.extend(times)
            
            # --- 4. Extract raw data
            ice_area = cp.array(ds["timeDaily_avg_iceAreaCell"].values)
            ice_volume = cp.array(ds["timeDaily_avg_iceVolumeCell"].values)
            snow_volume = cp.array(ds["timeDaily_avg_snowVolumeCell"].values)

            # --- 5. Apply a mask to the nCells
            ice_area = ice_area[:, self.cell_mask]
            ice_volume = ice_volume[:, self.cell_mask]
            snow_volume = snow_volume[:, self.cell_mask]

            # Append masked data to lists
            ice_area_all_list.append(ice_area)
            ice_volume_all_list.append(ice_volume)
            snow_volume_all_list.append(snow_volume)

            ds.close() # Close dataset after processing

        # --- Concatenate all collected data into single NumPy arrays after the loop
        self.ice_area_cupy = cp.concatenate(ice_area_all_list, axis=0)
        ice_volume_combined = cp.concatenate(ice_volume_all_list, axis=0)
        snow_volume_combined = cp.concatenate(snow_volume_all_list, axis=0)

        # TODO - CHECK IF CONFLICTS
        self.times = cudf.to_datetime(all_times_list)
        
        # Checking the dates
        logging.info(f"Parsed {len(self.times)} total dates")
        logging.info(f"First few: {str(self.times[:5])}")

        # Stats on how many dates there are
        logging.info(f"Total days collected: {len(self.times)}")
        logging.info(f"Unique days: {len(self.times.unique())}") # Use .unique() for cudf Series
        logging.info(f"First 35 days: {self.times[:35]}")
        logging.info(f"Last 35 days: {self.times[-35:]}")

        logging.info(f"Shape of combined ice_area array: {self.ice_area_cupy.shape}")
        logging.info(f"Elapsed time for combined data/time loading: {time.time() - start_time} seconds")
        
        # --- 6. Derive Freeboard from ice area, snow volume and ice volume (store it on GPU)
        logging.info(f"=== Calculating Freeboard === ")
        self.freeboard_cupy = compute_freeboard(self.ice_area_cupy, ice_volume_combined, snow_volume_combined)
        logging.info(f"Elapsed time for freeboard calculation: {time.time() - start_time} seconds")
        
        logging.info(f"=== Patchifying === ")

        # Get the parameters for the patchification function
        patchify_call_params = COMMON_PARAMS.copy()
        
        # Retrieve only the specific parameters for the chosen patchify function
        patchify_call_params.update(SPECIFIC_PARAMS.get(self.patchify_func_key, {}))
        
        # --- 7. Use the dynamic patchify function
        #     Returns 
        # full_nCells_patch_ids : cp.ndarray
        #     Array of shape (nCells,) giving patch ID or -1 if unassigned.
        # indices_per_patch_id : List[cp.ndarray]
        #     List of patches, each a list of cell indices (cp.ndarray of ints) that correspond with nCells array.
        # patch_latlons : cp.ndarray
        #     Array of shape (n_patches, 2) containing (latitude, longitude) for one
        #     representative cell per patch (the first cell added to the patch)
        self.full_nCells_patch_ids, self.indices_per_patch_id, self.patch_latlons, self.algorithm = self.patchify_func(**patchify_call_params)

        # Convert full-domain patch indices to masked-domain indices
        # This ensures there's no out of bounds problem,
        # like index 296237 is out of bounds for axis 1 with size 53973
        self.indices_per_patch_id = [
            [self.full_to_masked[i] for i in patch if i in self.full_to_masked]
            for patch in self.indices_per_patch_id
        ]

        print(type(self.indices_per_patch_id[0]))
        print(type(self.indices_per_patch_id[0][0]))

        
        logging.info(f"Elapsed time for patchifying with the {self.algorithm} algorithm: {time.time() - start_time} seconds")

        # --- 8. Optional --- OUTLIER DETECTION AND DATA IMBALANCE CHECK ---
        if self.trial_run:
            prefix = "trial"
        else:
            prefix = "prod"
            
        if self.plot_outliers_and_imbalance:
            logging.info(f"=== Plotting Outliers and Imbalance === ")
            check_freeboard_outliers(self.freeboard_cupy, self.times, post_norm=False)  
            plot_freeboard_distribution(self.freeboard_cupy, f"{prefix}_fb_pre_norm")
            analyze_ice_area_imbalance(self.ice_area_cupy)
            plot_ice_area_imbalance(self.ice_area_cupy, prefix)

        # --- 9. Optional --- Normalize the data (Area is already between 0 and 1; Freeboard is not)
        if self.normalize_on:
            logging.info(f"=== Normalizing Freeboard === ")
    
            self.freeboard_min = self.freeboard_cupy.min()
            self.freeboard_max = self.freeboard_cupy.max()
    
            logging.info(f"Freeboard min (pre-norm): {self.freeboard_min} meters" )
            logging.info(f"Freeboard max (pre-norm): {self.freeboard_max} meters")
    
            self.freeboard_cupy = normalize_freeboard(
                self.freeboard_cupy, min_val=self.freeboard_min, max_val=self.freeboard_max)
    
            logging.info(f"Freeboard Shape: {self.freeboard_cupy.shape}")
            logging.info(f"Ice Area Shape:  {self.ice_area_cupy.shape}")
    
            logging.info("=== Normalized Freeboard ===")
            freeboard_min_after_norm = self.freeboard_cupy.min()
            freeboard_max_after_norm  = self.freeboard_cupy.max()
    
            logging.info(f"Freeboard min (post-norm): {freeboard_min_after_norm}" )
            logging.info(f"Freeboard max (post-norm): {freeboard_max_after_norm}")

            if self.plot_outliers_and_imbalance:
                check_freeboard_outliers(self.freeboard_cupy, self.times, post_norm=True)
                plot_freeboard_distribution(self.freeboard_cupy, f"{prefix}_fb_post_norm")

        
        # Transfer the preprocessed CuPy arrays to NumPy arrays on the CPU
        # so that the __getitem__ method can be simple and fast.
        self.freeboard = self.freeboard_cupy.get()
        self.ice_area = self.ice_area_cupy.get()

        del self.freeboard_cupy
        del self.ice_area_cupy
        
        logging.info("End of __init__")
        end_time = time.time()
        logging.info(f"Elapsed time for DailyNetCDFDataset __init__: {end_time - start_time} seconds")
        print(f"Elapsed time for DailyNetCDFDataset __init__: {end_time - start_time} seconds")
        print(f"In minutes:                {(end_time - start_time)//60} minutes")

    def __len__(self):
        """
        Returns the total number of possible starting indices (idx) for a valid sequence.
        A valid sequence needs `self.context_length` days for input and `self.forecast_horizon` days for target.
        
        ex) If the total number of days is 365, the context_length is 7 and the forecast_horizon is 3, then
        
        365 - (7 + 3) + 1 = 365 - 10 + 1 = 356 valid starting indices
        """
        required_length = self.context_length + self.forecast_horizon
        if len(self.freeboard) < required_length:
            return 0 # Not enough raw data to form even one sample

        # The number of valid starting indices
        return len(self.freeboard) - required_length + 1

   
    def get_patch_tensor(self, day_idx: int) -> torch.Tensor:
        
        """
        Retrieves the feature data for a specific day, organized into patches.

        This method extracts 'freeboard' and 'ice_area' data for a given day
        and then reshapes it according to the pre-defined patches. Each patch
        will contain its own set of feature values.

        Parameters
        ----------
        day_idx : int
            The integer index of the day to retrieve data for, relative to the
            concatenated dataset's time dimension.

        Returns
        -------
        torch.Tensor
            A tensor containing the feature data organized by patches for the
            specified day.
            Shape: (num_patches, num_features, patch_size)
            Where:
            - num_patches: Total number of patches (ex., 140).
            - num_features: The number of features per cell (currently 2: freeboard, ice_area).
            - patch_size: The number of cells within each patch.
            
        """
        
        freeboard_day = self.freeboard[day_idx]  # (nCells,)
        ice_area_day = self.ice_area[day_idx]    # (nCells,)
        features = np.stack([freeboard_day, ice_area_day], axis=0)  # (2, nCells)
        patch_tensors = []

        for patch_indices in self.indices_per_patch_id:
            patch = features[:, patch_indices]  # (2, patch_size)
            patch_tensors.append(torch.from_numpy(patch).float())
            
        return torch.stack(patch_tensors)  # (context_length, num_patches, num_features, patch_size)
    
    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, np.datetime64]:

        """__ getitem __ needs to 
        
        1. Given an input of a certain date id, get the input and the target tensors
        2. Return all the patches for the input and the target
           Features are: [freeboard, ice_area] over masked cells. 
           
        """
        # Start with the id of the day in question
        start_idx = idx

        # end_idx is the exclusive end of the input sequence,
        # and the inclusive start of the target sequence.
        end_idx = idx + self.context_length
        target_start = end_idx

        # the target sequence ends after forecast horizon
        target_end = end_idx + self.forecast_horizon

        if target_end > len(self.freeboard):
            raise IndexError(
                f"Requested time window exceeds dataset. "
                f"Problematic idx: {idx}, "
                f"Context Length: {self.context_length}, "
                f"Forecast Horizon: {self.forecast_horizon}, "
                f"Calculated target_end: {target_end}, "
                f"Actual dataset length (len(self.freeboard)): {len(self.freeboard)}"
            )

        # Build input tensor
        input_seq = [self.get_patch_tensor(i) for i in range(start_idx, end_idx)]
        input_tensor = torch.stack(input_seq)
    
        # Build target tensor: shape (forecast_horizon, num_patches)
        target_seq = self.ice_area[end_idx:target_end]
        target_patches = []

        for day in target_seq:
            patch_day = [ # Use torch.from_numpy() and then cast to float
                torch.from_numpy(day[patch_indices]).float() for patch_indices in self.indices_per_patch_id
            ]
            
            # After stacking, patch_day_tensor will be (num_patches, CELLS_PER_PATCH)
            patch_day_tensor = torch.stack(patch_day)  # (num_patches,)
            target_patches.append(patch_day_tensor)

        # Final target tensor shape: (forecast_horizon, num_patches, CELLS_PER_PATCH)
        target_tensor = torch.stack(target_patches)  # (forecast_horizon, num_patches)
        
        return input_tensor, target_tensor, start_idx, end_idx, target_start, target_end


    def __repr__(self):
        """ Format the string representation of the data """
        return (
            f"<DailyNetCDFDataset: {len(self)} days, "
            f"{len(self.freeboard[0])} cells/day, "
            f"{len(self.file_paths)} files loaded, "
            f"Patchify Algorithm: {self.algorithm}>" # Added algorithm to repr
        )

    def time_to_dataframe(self) -> pd.DataFrame:
            """Return a DataFrame of time features you can merge with predictions."""
            t = pd.to_datetime(self.times)            # pandas Timestamp index
            return pd.DataFrame(
                {
                    "time": t,
                    "year": t.year,
                    "month": t.month,
                    "day": t.day,
                    "doy": t.dayofyear,
                }
            )

In [None]:
!sqs

# DataLoader

In [None]:
from torch.utils.data import DataLoader
from torch.utils.data import Subset

print(f"===== Making the Dataset Class: TRIAL_RUN MODE IS {TRIAL_RUN} ===== ")

# Load all the data from one folder
dataset = DailyNetCDFDataset(data_dir)

# Patch locations for positional embedding
PATCH_LATLONS_TENSOR = torch.tensor(dataset.patch_latlons, dtype=torch.float32)

print("========== SPLITTING THE DATASET ===================")
# DIFFERENT SUBSET OPTIONS FOR TRAINING / VALIDATION / TESTING for the trial data vs. full dataset
if TRIAL_RUN:
    # Chronological 70 / 15 / 15 split over the valid window start indices
    total_days = len(dataset)
    train_end = int(total_days * 0.7)
    val_end = int(total_days * 0.85)

    train_set = Subset(dataset, range(0, train_end))
    val_set   = Subset(dataset, range(train_end, val_end))
    test_set  = Subset(dataset, range(val_end, total_days))

else:
    # --- Custom Splitting by Year ---

    # Convert dataset.times to pandas DatetimeIndex for easier year-based filtering
    all_times_pd = pd.to_datetime(dataset.times)
    years = np.asarray(all_times_pd.year)

    # Define the start and end years for each set - keep this for the full dataset
    train_start_year = 1850
    train_end_year = 2011
    val_start_year = 2012
    val_end_year = 2017
    test_start_year = 2018
    test_end_year = 2024

    # Only the first len(dataset) days can start a complete
    # (context_length + forecast_horizon) window; later days would make
    # dataset[idx] raise IndexError, so they are excluded from every split.
    is_valid_start = np.arange(len(years)) < len(dataset)

    # Get the boolean masks for each set
    train_mask = is_valid_start & (years >= train_start_year) & (years <= train_end_year)
    val_mask = is_valid_start & (years >= val_start_year) & (years <= val_end_year)
    test_mask = is_valid_start & (years >= test_start_year) & (years <= test_end_year)

    # Get the integer indices where the masks are True.
    # The masks are host (NumPy) arrays derived from pandas, so use np.where, not cp.where.
    train_indices = np.where(train_mask)[0].tolist()
    val_indices = np.where(val_mask)[0].tolist()
    test_indices = np.where(test_mask)[0].tolist()

    # Create Subsets using the obtained indices
    train_set = Subset(dataset, train_indices)
    val_set   = Subset(dataset, val_indices)
    test_set  = Subset(dataset, test_indices)

    train_end = train_indices[-1]
    val_end = val_indices[-1]

print("Training data length:   ", len(train_set))
print("Validation data length: ", len(val_set))
print("Testing data length:    ", len(test_set))

total_days = len(train_set) + len(val_set) + len(test_set)
print("Total days = ", total_days)

# The train / validation loaders keep their last partial batch, so round up.
print("Number of training batches", (len(train_set) + BATCH_SIZE - 1)//BATCH_SIZE)
print("Number of validation batches", (len(val_set) + BATCH_SIZE - 1)//BATCH_SIZE)

# The test loader uses drop_last=True: only full batches run, the remainder is dropped.
print("Number of test batches after drop_last incomplete batch", len(test_set)//BATCH_SIZE)
print("Number of test days to drop after drop_last incomplete batch", len(test_set) % BATCH_SIZE)

print("===== Printing Dataset ===== ")
print(dataset)                 # calls __repr__ → see how many files & days loaded

print("===== Sample at dataset[0] ===== ")
input_tensor, target_tensor, start_idx, end_idx, target_start, target_end = dataset[0]

# end_idx and target_end are exclusive bounds, so the last day actually used is at index - 1.
print(f"Fetched start index {start_idx}: Time={dataset.times[start_idx]}")
print(f"Fetched end   index {end_idx} (exclusive): last input day Time={dataset.times[end_idx - 1]}")

print(f"Fetched target start index {target_start}: Time={dataset.times[target_start]}")
print(f"Fetched target end   index {target_end} (exclusive): last target day Time={dataset.times[target_end - 1]}")

def print_set_dates(dataset_subset, set_name):
    """
    Print and log the first and last start dates of a split.

    Parameters
    ----------
    dataset_subset : torch.utils.data.Subset
        Split to describe. Its `.indices` are positions in the wrapped
        dataset, and its `.dataset.times` supplies the dates.
    set_name : str
        Label used in the messages (ex. "Training").

    Notes
    -----
    Each item of a split is the *start* of a `context_length + forecast_horizon`
    window, so both dates are window start dates. The last window itself
    extends `context_length + forecast_horizon - 1` days past the reported
    end date.
    """
    if len(dataset_subset) == 0:
        print(f"{set_name} set: No data available.")
        return

    # Read the dates from the dataset this subset wraps instead of a notebook global,
    # so the function reports correctly for any Subset it is given.
    times = dataset_subset.dataset.times

    # Global indices of the first and last elements in the subset
    first_global_idx = dataset_subset.indices[0]
    last_global_idx = dataset_subset.indices[-1]

    start_date = times[first_global_idx]
    end_date = times[last_global_idx]

    print(f"{set_name} set start date: {start_date}")
    print(f"{set_name} set end date: {end_date}")
    logging.info(f"{set_name} set start date: {start_date}")
    logging.info(f"{set_name} set end date: {end_date}")

print("===== Start and End Dates for Each Set =====")
print_set_dates(train_set, "Training")
print_set_dates(val_set, "Validation")
print_set_dates(test_set, "Testing")

print("===== Starting DataLoader ====")
# Wrap each split in a DataLoader
# 1. Use pinned memory for faster async transfer to GPUs
# 2. Use a prefetch factor so that the GPU is fed w/o a ton of CPU memory use
# 3. Use shuffle=False to preserve time order (especially for forecasting)
# 4. Use drop_last=True (test only) to prevent it from testing on incomplete batches
loader_kwargs = dict(batch_size=BATCH_SIZE, shuffle=False,
                     num_workers=NUM_WORKERS, pin_memory=True)

# prefetch_factor is only valid with worker processes; passing it with
# num_workers == 0 makes the DataLoader constructor raise ValueError.
if NUM_WORKERS > 0:
    loader_kwargs["prefetch_factor"] = 2

train_loader = DataLoader(train_set, **loader_kwargs)
val_loader   = DataLoader(val_set, **loader_kwargs)
test_loader  = DataLoader(test_set, drop_last=True, **loader_kwargs)

print("input_tensor should be of shape (context_length, num_patches, num_features, patch_size)")
print(f"actual input_tensor.shape = {input_tensor.shape}")
print("target_tensor should be of shape (forecast_horizon, num_patches, patch_size)")
print(f"actual target_tensor.shape = {target_tensor.shape}")

# Transformer Class

In [None]:
import torch
import torch.nn as nn
import time

# Prefer the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class IceForecastTransformer(nn.Module):
    """
    A Transformer-based model for forecasting ice conditions based on sequences of
    historical patch data.

    Every (day, patch) pair of the input is one token. Tokens are linearly
    embedded into `d_model` dimensions, passed through the encoder stack,
    averaged over the time axis, and decoded by an MLP head into
    `forecast_horizon * cells_per_patch` values per patch.

    The transformer should
    1. Accept a sequence of days (ex. 7 days of patches).
       The context_length parameter says how many days to use for input.
    2. Encode each patch with the transformer.
    3. Output the patches for regression (ex. predict the 8th day).
       The forecast_horizon parameter says how many days to predict.

    NOTE(review): no positional or temporal encoding is added to the tokens
    here, so the encoder has no signal for day order or patch location.
    Confirm this is intended (the notebook builds PATCH_LATLONS_TENSOR
    "for positional embedding" but this class does not consume it).

    Parameters
    ----------
    input_patch_features_dim : int
        Length of the flattened feature vector of one patch (ex. 512).
        This is the input dimension for the patch embedding layer.
    num_patches : int
        The total number of geographical patches that the `nCells` data was
        divided into. Stored for reference; `forward` takes the patch count
        from the input shape.
    context_length : int, optional
        The number of historical days (time steps) used as input.
        Stored for reference; `forward` takes it from the input shape.
    forecast_horizon : int, optional
        The number of future days to predict for each patch.
    d_model : int, optional
        The dimension of the model's hidden states (embedding dimension).
    nhead : int, optional
        The number of attention heads in each Transformer encoder layer.
    num_layers : int, optional
        The number of Transformer encoder layers in the model.
    cells_per_patch : int, optional
        Number of cells predicted per patch. When omitted, the notebook
        constant CELLS_PER_PATCH is read once at construction time.

    The defaults of the other parameters are the notebook constants named in
    the signature, captured when this class is defined.

    Attributes
    ----------
    patch_embed : nn.Linear
        Linear layer to project input patch features into the `d_model` hidden space.
    encoder : nn.TransformerEncoder
        The Transformer encoder module composed of `num_layers` encoder layers.
    mlp_head : nn.Sequential
        LayerNorm + Linear head producing `cells_per_patch * forecast_horizon`
        values for each patch.
    """

    def __init__(self,
                 input_patch_features_dim: int = PATCH_EMBEDDING_INPUT_DIM, # D: The flat feature dimension of a single patch (ex., 512)
                 num_patches: int = NUM_PATCHES,  # P: Number of spatial patches
                 context_length: int = CONTEXT_LENGTH, # T: Number of historical time steps
                 forecast_horizon: int = FORECAST_HORIZON, # Number of future time steps to predict (usually 1)
                 d_model: int = D_MODEL,        # d_model: Transformer's embedding dimension
                 nhead: int = N_HEAD,           # nhead: Number of attention heads
                 num_layers: int = NUM_TRANSFORMER_LAYERS, # num_layers: Number of TransformerEncoderLayers
                 cells_per_patch: "int | None" = None  # L: Cells per patch (None -> CELLS_PER_PATCH)
                ):

        super().__init__()

        self.context_length = context_length
        self.forecast_horizon = forecast_horizon
        self.num_patches = num_patches
        self.d_model = d_model
        self.input_patch_features_dim = input_patch_features_dim

        # Capture the patch size on the instance so the head built here and the
        # reshape in forward() always agree, even if the notebook constant is
        # changed after the model has been constructed.
        self.cells_per_patch = CELLS_PER_PATCH if cells_per_patch is None else cells_per_patch

        print("Calling IceForecastTransformer __init__")
        start_time = time.time()

        # Patch embedding layer: projects the raw patch features (ex. 512)
        # into the d_model (ex. 128) hidden space dimension
        self.patch_embed = nn.Linear(input_patch_features_dim, d_model)

        # Transformer Encoder
        # batch_first=True means input/output tensors are (batch, sequence, features)
        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, batch_first=True)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # Output MLP head: (B, P, cells_per_patch * forecast_horizon)
        # Make a prediction for every cell per patch
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(d_model),
            nn.Linear(d_model, self.cells_per_patch * forecast_horizon)
        )

        end_time = time.time()
        print(f"Elapsed time: {end_time - start_time:.2f} seconds")
        print("End of __init__")

    def forward(self, x):
        """
        Predict every cell of every patch for each forecast day.

        B = Batch size
        T = Time (context_length)
        P = Patch count
        D = Patch Dimension (cells per patch * feature count)

        x: Tensor of shape (B, T, P, D), ex. (16, 7, 140, 512)
        Output: Tensor of shape (B, forecast_horizon, P, cells_per_patch)
        """

        B, T, P, D = x.shape

        # Flatten time and patches for the Transformer Encoder: (B, T * P, D)
        # Each (Time, Patch) combination becomes a single token in the sequence.
        # reshape (rather than view) also accepts non-contiguous inputs.
        x = x.reshape(B, T * P, D)   # ex., (16, 7 * 140 = 980, 512)

        # Project patch features to the transformer's d_model dimension
        x = self.patch_embed(x)  # Output: (B, T * P, d_model) ex., (16, 980, 128)

        # Apply transformer encoder layers
        x = self.encoder(x)      # Output: (B, T * P, d_model) ex., (16, 980, 128)

        # Separate time and patches again: (B, T, P, d_model) ex., (16, 7, 140, 128)
        x = x.reshape(B, T, P, self.d_model)

        # Mean pooling over the time (context_length) dimension for each patch.
        # This aggregates information from all historical time steps for each patch's final prediction.
        x = x.mean(dim=1)  # Output: (B, P, d_model) ex., (16, 140, 128)

        # Apply MLP head to predict values for each cell in each patch
        x = self.mlp_head(x)  # (B, P, cells_per_patch * forecast_horizon) ex., (16, 140, 768)

        # Split the last dimension into forecast horizon and cells, then move the
        # horizon axis forward: (B, P, FH, L) -> (B, FH, P, L)
        x = x.reshape(B, P, self.forecast_horizon, self.cells_per_patch)
        x = x.permute(0, 2, 1, 3)

        return x


In [None]:
!sqs

# Training Loop

In [None]:
# Train the model for NUM_EPOCHS epochs, running a full validation pass after each
# epoch. Everything assigned here (model, optimizer, losses, timers) stays in the
# kernel namespace; the save cell reads `model`.
if TRAINING:
    import torch
    import torch.nn as nn
    import torch.optim as optim
    from torch.utils.data import DataLoader
    from torch import Tensor
    import torch.nn.functional as F

    import logging

    # Set level to logging.INFO to see the statements
    # NOTE(review): logging.basicConfig does nothing if the root logger already has
    # handlers (e.g. configured by an earlier cell) - confirm this file is actually written.
    logging.basicConfig(filename='IceForecastTransformerInstance.log', filemode='w', level=logging.INFO)

    # Build the model with the notebook-constant defaults and move it to the selected device
    model = IceForecastTransformer().to(device)

    print("\n--- Model Architecture ---")
    print(model)
    print("--------------------------\n")

    logging.info("\n--- Model Architecture ---")
    logging.info(str(model)) # Log the full model structure
    logging.info(f"Total model parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
    logging.info("--------------------------\n")

    # Adam with a fixed learning rate; mean squared error over every predicted cell
    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    criterion = nn.MSELoss()

    logging.info("== TIMER IS STARTING FOR TRAINING ==")
    start_time = time.time()
    logging.info("===============================")
    logging.info("       STARTING EPOCHS       ")
    logging.info("===============================")
    logging.info(f"Number of epochs: {NUM_EPOCHS}")
    logging.info(f"Learning Rate: {optimizer.param_groups[0]['lr']}")

    for epoch in range(NUM_EPOCHS):
        model.train()
        total_loss = 0

        # Each batch is the dataset's 6-tuple; only the two tensors are used for training
        for batch_idx, (input_tensor, target_tensor, start_idx, end_idx, target_start, target_end) in enumerate(train_loader):  

            # Move input and target to the device
            # B = batch, T = context_length, P = num_patches, C = num_features, L = cells per patch
            x = input_tensor.to(device)  # Shape: (B, T, P, C, L)
            y = target_tensor.to(device)  # Shape: (B, forecast_horizon, P, L)

            # Reshape x for transformer input: merge features and cells into one
            # flat vector per patch -> (B, T, P, C * L)
            B, T, P, C, L = x.shape
            x_reshaped_for_transformer_D = x.view(B, T, P, C * L)

            # Run through transformer
            y_pred = model(x_reshaped_for_transformer_D) # y_pred is (B, forecast_horizon, P, CELLS_PER_PATCH)

            # Compute loss
            loss = criterion(y_pred, y) # DIRECTLY compare y_pred and y (shapes must match)

            # Standard optimization step: clear old gradients, backpropagate, update weights
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        # Mean of the per-batch losses over the epoch
        avg_train_loss = total_loss / len(train_loader)
        logging.info(f"Epoch {epoch+1}/{NUM_EPOCHS} - Train Loss: {avg_train_loss:.4f}")
        print(f"Epoch {epoch+1}/{NUM_EPOCHS} - Train Loss: {avg_train_loss:.4f}") # Keep print for immediate console feedback

        # --- Validation loop ---
        model.eval()
        val_loss = 0

        # No gradients are needed while scoring the validation set
        with torch.no_grad():
            for batch in val_loader:
                # Unpack the full tuple
                x_val, y_val, start_idx, end_idx, target_start, target_end = batch

                # Move to GPU if available
                x_val = x_val.to(device)
                y_val = y_val.to(device)

                # Extract dimensions from x_val for reshaping
                # x_val before reshaping: (B_val, T_val, P_val, C_val, L_val)
                B_val, T_val, P_val, C_val, L_val = x_val.shape

                # Reshape x_val for transformer input
                x_val_reshaped_for_transformer_input = x_val.view(B_val, T_val, P_val, C_val * L_val)

                # Model output is (B, forecast_horizon, P, L)
                y_val_pred = model(x_val_reshaped_for_transformer_input) 

                # Compute validation loss (y_val_pred and y_val should have identical shapes)
                val_loss += criterion(y_val_pred, y_val).item() # y_val is (B, forecast_horizon, P, L)

        # Mean of the per-batch validation losses
        avg_val_loss = val_loss / len(val_loader)
        logging.info(f"Epoch {epoch+1}/{NUM_EPOCHS} - Validation Loss: {avg_val_loss:.4f}")
        print(f"Validation Loss: {avg_val_loss:.4f}") # Keep print for immediate console feedback

    end_time = time.time()
    elapsed_time = end_time - start_time

    logging.info("===============================================")
    logging.info(f"Elapsed time for TRAINING: {elapsed_time:.2f} seconds")
    logging.info("===============================================")
    print("===============================================")
    print(f"Elapsed time for TRAINING: {elapsed_time:.2f} seconds")
    print("===============================================")

In [None]:
!sqs

TODO OPTION: Try temporal attention only (ex., Informer, Time Series Transformer).

# Save the Model

In [None]:
# Define the path where to save or load the model.
# PATH is set outside the `if TRAINING` guard so the re-load cell can use it
# even when this run does not train.
PATH = f"SIC_model_{model_version}.pth"

if TRAINING:

    # Save the model's state_dict (the weights only, not the whole module object)
    torch.save(model.state_dict(), PATH)
    print(f"Saved model at {PATH}")

# === BELOW - CAN BE USED ANY TIME FROM A .PTH FILE

Make sure to run the cells that contain the constants (or use "Run All"), but comment out the "save" cell and the training-loop cell first.

# Re-Load the Model

In [None]:
# Rebuild the model from the checkpoint at PATH so evaluation can run without retraining.
if EVALUATING_ON:

    import torch
    import torch.nn as nn

    # Evaluation is expected to run on a GPU node; fail early with a clear message otherwise
    if not torch.cuda.is_available():
        raise ValueError("There is a problem with Torch not recognizing the GPUs")

    # Resolve the target device first so the checkpoint can be mapped straight onto it
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Instantiate the model (must have the same architecture as when it was saved)
    # Create an identical instance of the original __init__ parameters
    loaded_model = IceForecastTransformer()

    # Load the saved state_dict (weights_only=True helps ensure safety of pickle files).
    # map_location remaps the stored tensors onto `device`; without it torch.load tries
    # to restore them onto the device they were saved from.
    loaded_model.load_state_dict(torch.load(PATH, map_location=device, weights_only=True))

    # Move the model to the appropriate device (CPU or GPU)
    loaded_model.to(device)

    # Set the model to evaluation mode
    loaded_model.eval()

    print("Model loaded successfully!")

# Metrics

In [None]:
if EVALUATING_ON:

    # TODO: MAKE THIS MORE EFFICIENT AND ALSO ADD IN SIE PREDICTION AND CONFUSION MATRIX
    import contextlib
    import io

    from scipy.stats import entropy

    def jensen_shannon_distance(p, q):
        """Return the Jensen-Shannon distance between two histograms.

        Both inputs are rescaled to sum to 1, a small epsilon is added to every
        bin to avoid log(0), and the square root of the Jensen-Shannon
        divergence is returned (scipy's `entropy` with its default log base).
        """
        # Ensure distributions sum to 1
        p = p / p.sum()
        q = q / q.sum()

        m = 0.5 * (p + q)
        # Add a small epsilon to avoid log(0)
        epsilon = 1e-10
        divergence = 0.5 * (entropy(p + epsilon, m + epsilon) + entropy(q + epsilon, m + epsilon))
        return cp.sqrt(divergence) # The distance is the square root of the JS divergence

    # Buffer that collects everything printed during the evaluation so it can be
    # written to the metrics file afterwards.
    captured_output = io.StringIO()

    # redirect_stdout puts back whatever sys.stdout was on entry, even if the
    # evaluation raises. Assigning sys.__stdout__ by hand is wrong inside a
    # notebook: the kernel installs its own stream as sys.stdout, so "restoring"
    # sys.__stdout__ makes every later print() disappear from the notebook.
    with contextlib.redirect_stdout(captured_output):

        # Accumulators for errors
        all_abs_errors = [] # To store absolute errors for each cell in each patch
        all_mse_errors = [] # To store squared errors for each cell in each patch

        # Accumulators for histogram data
        all_predicted_values_flat = []
        all_actual_values_flat = []

        print("\nStarting evaluation and metric calculation...")
        print("==================")
        print(f"DEBUG: Batch Size: {BATCH_SIZE} Days")
        print(f"DEBUG: Context Length: {CONTEXT_LENGTH} Days")
        print(f"DEBUG: Forecast Horizon: {FORECAST_HORIZON} Days")
        print(f"DEBUG: Number of batches in test_loader (with drop_last=True): {len(test_loader)} Batches")
        print("==================")
        print(f"DEBUG: len(test_set): {len(test_set)} Days")
        print(f"DEBUG: len(dataset) for splitting: {len(dataset)} Days")
        print(f"DEBUG: train_end: {train_end}")
        print(f"DEBUG: val_end: {val_end}")
        print(f"DEBUG: range for test_set: {range(val_end, total_days)}")
        print("==================")

        # Iterate over the test_loader
        # Targets are expected as (B, forecast_horizon, P, CELLS_PER_PATCH) to match the model's output.
        for i, (sample_x, sample_y, start_idx, end_idx, target_start, target_end) in enumerate(test_loader):
            print(f"Processing batch {i+1}/{len(test_loader)}")

            # Move to device and apply initial reshape as done in training
            sample_x = sample_x.to(device)
            sample_y = sample_y.to(device) # Actual target values

            # Merge the last two axes of x so the model receives a 4-D input
            B_sample, T_sample, P_sample, C_sample, L_sample = sample_x.shape
            sample_x_reshaped = sample_x.view(B_sample, T_sample, P_sample, C_sample * L_sample)

            # Perform inference
            with torch.no_grad(): # Essential for inference to disable gradient calculations
                predicted_y_patches = loaded_model(sample_x_reshaped)

            # Element-wise comparison below requires identical shapes
            if predicted_y_patches.shape != sample_y.shape:
                print(f"Shape mismatch: Predicted {predicted_y_patches.shape}, Actual {sample_y.shape}")
                continue # Skip this batch if shapes are incompatible

            # Per-element errors; the averaging happens after all batches are collected
            diff = predicted_y_patches - sample_y
            abs_error_batch = torch.abs(diff)
            mse_error_batch = diff ** 2

            # Accumulate errors on the CPU so device memory is not held across batches
            all_abs_errors.append(abs_error_batch.cpu())
            all_mse_errors.append(mse_error_batch.cpu())

            # Collect data for histograms (flatten all values)
            all_predicted_values_flat.append(predicted_y_patches.cpu().numpy().flatten())
            all_actual_values_flat.append(sample_y.cpu().numpy().flatten())

        # Concatenate all accumulated tensors
        if all_abs_errors and all_mse_errors:
            combined_abs_errors = torch.cat(all_abs_errors, dim=0) # Expected shape: (Total_Samples, FH, P, CPC)
            combined_mse_errors = torch.cat(all_mse_errors, dim=0) # Expected shape: (Total_Samples, FH, P, CPC)

            # Average over the sample axis (0) and the forecast-horizon axis (1),
            # leaving one value per cell per patch.
            mean_abs_error_per_cell_patch = combined_abs_errors.mean(dim=(0, 1))
            mean_mse_per_cell_patch = combined_mse_errors.mean(dim=(0, 1))

            print("\n--- Error Metrics (Averaged per Cell per Patch) ---")
            print(f"Mean Absolute Error (shape {mean_abs_error_per_cell_patch.shape}):")
            # print(mean_abs_error_per_cell_patch) # Uncomment to see the full tensor
            print(f"Overall Mean Absolute Error:            {mean_abs_error_per_cell_patch.mean().item():.4f}")

            print(f"\nMean Squared Error (shape {mean_mse_per_cell_patch.shape}):")
            # print(mean_mse_per_cell_patch) # Uncomment to see the full tensor

            mse = mean_mse_per_cell_patch.mean().item()
            print(f"Overall Mean Squared Error:             {mse:.4f}")

            rmse = cp.sqrt(mse)
            print(f"Overall Root Mean Squared Error (RMSE): {rmse}")

        else:
            print("No data processed for error metrics. Check test_loader and data availability.")

        # --- Histogram and Jensen-Shannon Distance ---

        # Concatenate all flattened values
        if all_predicted_values_flat and all_actual_values_flat:
            final_predicted_values = cp.concatenate(all_predicted_values_flat)
            final_actual_values = cp.concatenate(all_actual_values_flat)

            print(f"\nTotal predicted values collected: {len(final_predicted_values)}")
            print(f"Total actual values collected: {len(final_actual_values)}")

            # --- Visualize the distribution of the predicted and actual values

            analyze_ice_area_imbalance(final_predicted_values)
            plot_ice_area_imbalance(final_predicted_values, "predicted")

            analyze_ice_area_imbalance(final_actual_values)
            plot_ice_area_imbalance(final_actual_values, "actual")

            # Define bins for the histogram (e.g., for ice concentration between 0 and 1)
            # Adjust bins based on the expected range of your data
            bins = cp.linspace(0, 1, 51) # 50 bins from 0 to 1

            # Compute histograms
            hist_predicted, _ = cp.histogram(final_predicted_values, bins=bins, density=True)
            hist_actual, _ = cp.histogram(final_actual_values, bins=bins, density=True)

            # density=True makes the histogram integrate to 1; rescale so the bin
            # values themselves sum to 1 before comparing them as distributions.
            hist_predicted = hist_predicted / hist_predicted.sum()
            hist_actual = hist_actual / hist_actual.sum()

            # Calculate Jensen-Shannon Distance
            jsd = jensen_shannon_distance(hist_actual, hist_predicted)
            print(f"\nJensen-Shannon Distance between actual and predicted histograms: {jsd:.4f}")

            # Plotting Histograms
            plt.figure(figsize=(10, 6))
            plt.hist(final_actual_values, bins=bins, alpha=0.7, label='Actual Data', color='skyblue', density=True)
            plt.hist(final_predicted_values, bins=bins, alpha=0.7, label='Predicted Data', color='salmon', density=True)
            plt.title('Distribution of Actual vs. Predicted Ice Concentration Values')
            plt.xlabel('Ice Concentration Value')
            plt.ylabel('Probability Density')
            plt.legend()
            plt.grid(axis='y', alpha=0.75)
            plt.savefig(f"SIE_Distribution_Actual_vs_Predicted_model_{model_version}.png")
            plt.close()

            # When reading the histograms, look for overlap:
            # High Overlap: predictions are close to actual values. Decent model.
            # Low Overlap: predictions differ from actual values, issues with the model.

        else:
            print("No data collected for histogram analysis. Check test_loader and data availability.")

        print("\nEvaluation complete.")

    # stdout is back to the notebook stream here; write the captured log to the file
    with open(f'Metrics_{PATH}.txt', 'w') as f:
        f.write(captured_output.getvalue())

    print(f"Metrics saved to Metrics_{PATH}.txt")

# Make a Single Prediction

In [None]:
if EVALUATING_ON:

    import os

    # Load one batch
    data_iter = iter(test_loader)
    sample_x, sample_y, start_idx, end_idx, target_start, target_end = next(data_iter)

    print(f"Shape of sample_x {sample_x.shape}")
    print(f"Shape of sample_y {sample_y.shape}")

    print(f"Fetched sample_x start index {start_idx}: Time={dataset.times[start_idx]}")
    print(f"Fetched sample_x end   index {end_idx}:   Time={dataset.times[end_idx]}")

    # The "start" line must report target_start; it previously printed
    # target_end for both the start and the end of the target window.
    print(f"Fetched sample_y (target) start index {target_start}: Time={dataset.times[target_start]}")
    print(f"Fetched sample_y (target) end   index {target_end}: Time={dataset.times[target_end]}")

    # Move to device and apply initial reshape as done in training
    sample_x = sample_x.to(device)
    sample_y = sample_y.to(device) # Keep sample_y for actual comparison

    # Merge the last two axes of x so the model receives a 4-D input
    B_sample, T_sample, P_sample, C_sample, L_sample = sample_x.shape
    sample_x_reshaped = sample_x.view(B_sample, T_sample, P_sample, C_sample * L_sample)

    print(f"Sample x for inference shape (reshaped): {sample_x_reshaped.shape}")

    # Perform inference
    with torch.no_grad(): # Essential for inference to disable gradient calculations
        predicted_y_patches = loaded_model(sample_x_reshaped)

    print(f"Predicted y patches shape: {predicted_y_patches.shape}")
    print(f"Expected shape: (B, forecast_horizon, NUM_PATCHES, CELLS_PER_PATCH) ex., (16, {loaded_model.forecast_horizon}, 140, 256)")

    # Option 1: Select a specific day from the forecast horizon (ex., the first day)
    # Indexing axis 1 with 0 drops the forecast axis, leaving one slice per batch item.
    predicted_for_day_0 = predicted_y_patches[:, 0, :, :].cpu()
    print(f"Predicted ice area for Day 0 (specific day) shape: {predicted_for_day_0.shape}")

    # Slice the targets the same way so both arrays line up
    actual_for_day_0 = sample_y[:, 0, :, :].cpu()
    print(f"Actual ice area for Day 0 (specific day) shape: {actual_for_day_0.shape}")

    # Save predictions so that I can use cartopy by switching kernels for the next jupyter cell.
    # Create the output directory first so the save does not fail on a fresh checkout.
    # NOTE(review): the mapping cell loads 'patches/SIC_predicted_{model_version}_day0.npy'
    # and 'patches/SIC_actual_{model_version}_day0.npy', which do not match the
    # names written here - confirm which naming is current and align the two.
    os.makedirs('patches', exist_ok=True)
    cp.save(f'patches/ice_area_patches_predicted_{PATH}_day0.npy', predicted_for_day_0)
    cp.save(f'patches/ice_area_patches_actual_{PATH}_day0.npy', actual_for_day_0)

    # Option 2: Iterate through all forecast days
    all_predicted_ice_areas = []
    all_actual_ice_areas = []

    for day_idx in range(loaded_model.forecast_horizon):
        predicted_day = predicted_y_patches[:, day_idx, :, :].cpu()
        all_predicted_ice_areas.append(predicted_day)

        actual_day = sample_y[:, day_idx, :, :].cpu()
        all_actual_ice_areas.append(actual_day)

        print(f"Processing forecast day {day_idx}: Predicted shape {predicted_day.shape}, Actual shape {actual_day.shape}")

        # Save each day's prediction/actual data if needed
        # cp.save(f'patches/ice_area_patches_predicted_day{day_idx}.npy', predicted_day)
        # cp.save(f'patches/ice_area_patches_actual_day{day_idx}.npy', actual_day)


# Recover nCells from Patches for Visualization

In [None]:
if MAP_WITH_CARTOPY_ON:

    ########################################
    # SWAP KERNELS IN THE JUPYTER NOTEBOOK #
    ########################################

    from MAP_ANIMATION_GENERATION.map_gen_utility_functions import *
    from NC_FILE_PROCESSING.nc_utility_functions import *
    from NC_FILE_PROCESSING.patchify_utils import *

    import numpy as np

    # Day-0 prediction and target arrays written by the single-prediction cell.
    # NOTE(review): that cell saves 'patches/ice_area_patches_predicted_{PATH}_day0.npy'
    # and 'patches/ice_area_patches_actual_{PATH}_day0.npy', which do not match
    # the names loaded here - confirm which naming is current and align the two.
    predicted_ice_area_patches = cp.load(f'patches/SIC_predicted_{model_version}_day0.npy')
    actual_y_ice_area_patches = cp.load(f'patches/SIC_actual_{model_version}_day0.npy')

    # Number of patches is the length of the second axis of the loaded array
    NUM_PATCHES = len(predicted_ice_area_patches[0])
    print("NUM_PATCHES is", NUM_PATCHES)

    latCell, lonCell = load_mesh(perlmutterpathMesh)
    TOTAL_GRID_CELLS = len(lonCell)
    cell_mask = latCell >= LATITUDE_THRESHOLD

    # The loaded arrays are indexed below as [batch item][patch], so they are
    # assumed to be (B, NUM_PATCHES, CELLS_PER_PATCH) - TODO confirm against the
    # cell that saves them.

    # Rebuild the original patch-to-cell mapping
    # indices_per_patch_id = [
    #     [idx_cell_0_0, ..., idx_cell_0_255],
    #     [idx_cell_1_0, ..., idx_cell_1_255],
    #     ...
    # ]
    # NOTE(review): this always uses patchify_by_latlon_spillover with k=256;
    # presumably it has to match the patching used to build the model inputs - verify.
    full_nCells_patch_ids, indices_per_patch_id, patch_latlons = patchify_by_latlon_spillover(
                latCell, lonCell, k=256, max_patches=NUM_PATCHES, LATITUDE_THRESHOLD=LATITUDE_THRESHOLD)

    # Select one sample from the batch for visualization.
    # Each selection has one row per patch for that single sample.
    SAMPLE_INDEX = 2 # Third item in the batch (0-based index 2)
    sample_predicted_cells_per_patch = predicted_ice_area_patches[SAMPLE_INDEX]
    # Take the targets from the *actual* array; this previously re-read the
    # predictions, so the "actual" map was a copy of the predicted map.
    sample_actual_cells_per_patch = actual_y_ice_area_patches[SAMPLE_INDEX]

    # Initialize empty arrays for the full grid (nCells); cells not covered by
    # any patch stay NaN.
    recovered_predicted_grid = cp.full(TOTAL_GRID_CELLS, cp.nan)
    recovered_actual_grid = cp.full(TOTAL_GRID_CELLS, cp.nan)

    # Populate the full grid using the patch data and mapping
    for patch_idx in range(NUM_PATCHES):
        cell_indices_in_patch = indices_per_patch_id[patch_idx]

        # Scatter this patch's values back to their positions in the full grid
        recovered_predicted_grid[cell_indices_in_patch] = sample_predicted_cells_per_patch[patch_idx]
        recovered_actual_grid[cell_indices_in_patch] = sample_actual_cells_per_patch[patch_idx]

    # Cells still unfilled after every patch has been written. Computed once
    # here instead of on every loop iteration; the final values are the same.
    nan_mask = cp.isnan(recovered_predicted_grid)
    nan_count = cp.sum(nan_mask)

    print(f"Recovered predicted grid shape: {recovered_predicted_grid.shape}")
    print(f"Recovered actual grid shape: {recovered_actual_grid.shape}")

    fig, northMap = generate_axes_north_pole()
    generate_map_north_pole(fig, northMap, latCell, lonCell, recovered_predicted_grid, f"model {model_version} ice area recovered")

    fig, northMap = generate_axes_north_pole()
    generate_map_north_pole(fig, northMap, latCell, lonCell, recovered_actual_grid, f"model {model_version} ice area actual")

In [None]:
!sqs