# Model Hyperparameter Constants / Defaults

Check over these CAREFULLY!

Note that if you use the login node for training (even for the trial dataset that is much smaller), you run the risk of getting the error: # OutOfMemoryError: CUDA out of memory.

In [1]:
import xarray as xr
print('Xarray version', xr.__version__)

Xarray version 2025.6.1


In [2]:
import numpy as np
print('Numpy version', np.__version__)

Numpy version 2.2.6


In [3]:
from perlmutterpath import * # Contains the data_dir and mesh_dir variables

# --- Space Constants:
NUM_FEATURES = 2              # C: Number of features per cell (ex., Freeboard, Ice Area)
LATITUDE_THRESHOLD = 40       # Determines number of cells and patches

# Load the mesh and data to plot.
mesh = xr.open_dataset(mesh_dir)
latCell = np.degrees(mesh["latCell"].values)
lonCell = np.degrees(mesh["lonCell"].values)
mesh.close()
print("Total nCells:       ", len(latCell))

mask = latCell >= LATITUDE_THRESHOLD
masked_ncells_size = np.count_nonzero(mask)
print("Mask size:          ", masked_ncells_size)

CELLS_PER_PATCH = 256                                  # L: Number of cells within each patch
NUM_PATCHES = masked_ncells_size // CELLS_PER_PATCH    # P: Number of spatial patches

print("cells_per_patch:    ", CELLS_PER_PATCH)
print("n_patches:          ", NUM_PATCHES)

Total nCells:        465044
Mask size:           53973
cells_per_patch:     256
n_patches:           210


In [4]:
# --- Time Constants:
CONTEXT_LENGTH = 7            # T: Number of historical time steps used for input
FORECAST_HORIZON = 3          # Number of future time steps to predict (ex. 1 day for next time step)

# Model Constants
D_MODEL = 128                 # d_model: Dimension of the transformer's internal representations (embedding dimension)
N_HEAD = 8                    # nhead: Number of attention heads
NUM_TRANSFORMER_LAYERS = 4    # num_layers: Number of TransformerEncoderLayers
BATCH_SIZE = 16
NUM_EPOCHS = 10

# The input dimension for the patch embedding linear layer.
# Each patch at a given time step has NUM_FEATURES * CELLS_PER_PATCH features.
# This is the 'D' dimension used in the Transformer's input tensor (B, T, P, D).
PATCH_EMBEDDING_INPUT_DIM = NUM_FEATURES * CELLS_PER_PATCH # 2 * 256 = 512

# Performance-related
NUM_WORKERS = 64

In [5]:
TRIAL_RUN =              False # TODO - SET THIS TO USE THE PRACTICE SET (MUCH FASTER AND SMALLER)
TRAINING =               True  # TODO - SET THIS TO RUN THE TRAINING LOOP
PLOT_DATA_DISTRIBUTION = True  # TODO - SET THIS TO PLOT THE OUTLIERS (use on the full dataset)
NORMALIZE_ON =           True  # TODO - SET THIS TO USE NORMALIZATION ON FREEBOARD
EVALUATING_ON =          True  # TODO - SET THIS TO RUN THE METRICS AT THE BOTTOM

if TRIAL_RUN:
    model_mode = "tr" # Training Dataset
else:
    model_mode = "fd" # Full Dataset

if NORMALIZE_ON:
    norm = "nT"
else:
    norm = "nF"

# Model nome convention - fd:full data, etc.
model_version = f"{model_mode}_{norm}_D{D_MODEL}_B{BATCH_SIZE}_lt{LATITUDE_THRESHOLD}_P{NUM_PATCHES}_L{CELLS_PER_PATCH}_T{CONTEXT_LENGTH}_Fh{FORECAST_HORIZON}_e{NUM_EPOCHS}"
print(model_version)

fd_nT_D128_B16_lt40_P210_L256_T7_Fh3_e10


### Notes:

TRY: NUM_WORKERS as 16 to 32 - profile to see if the GPU is still waiting on the CPU.

TRY: NUM_WORKERS as 64 - the number of CPU cores available.

TRY: NUM_WORKERS experiment with os.cpu_count() - 2

TRY: NUM_WORKERS experiment with (logical_cores_per_gpu * num_gpus)

num_workers considerations:
Too few workers: GPUs might become idle waiting for data.
Too many workers: Can lead to increased CPU memory usage and context switching overhead.

# More Imports

In [6]:
import sys
print('System Version:', sys.version)

System Version: 3.10.18 | packaged by conda-forge | (main, Jun  4 2025, 14:45:41) [GCC 13.3.0]


In [7]:
#print(sys.executable) # for troubleshooting kernel issues
#print(sys.path)

In [8]:
import os
#print(os.getcwd())

In [9]:
import pandas as pd
print('Pandas version', pd.__version__)

Pandas version 2.3.1


In [10]:
import matplotlib
import matplotlib.pyplot as plt
print('Matplotlib version', matplotlib.__version__)

Matplotlib version 3.10.3


In [11]:
import torch
from torch.utils.data import Dataset, DataLoader

print('PyTorch version', torch.__version__)

PyTorch version 2.5.1


# Hardware Details

In [12]:
if TRAINING and not torch.cuda.is_available():
    raise ValueError("There is a problem with Torch not recognizing the GPUs")
else:
    print(torch.cuda.device_count()) # check the number of available CUDA devices
    # will print 1 on login node; 4 on GPU exclusive node; 1 on shared GPU node

4


In [13]:
#print(torch.cuda.get_device_properties(0)) #provides information about a specific GPU
#total_memory=40326MB, multi_processor_count=108, L2_cache_size=40MB

In [14]:
import psutil
import platform

# Get general CPU information
processor_name = platform.processor()
print(f"Processor Name: {processor_name}")

# Get core counts
physical_cores = psutil.cpu_count(logical=False)
logical_cores = psutil.cpu_count(logical=True)
print(f"Physical Cores: {physical_cores}")
print(f"Logical Cores: {logical_cores}")

# Get CPU frequency
cpu_frequency = psutil.cpu_freq()
if cpu_frequency:
    print(f"Current CPU Frequency: {cpu_frequency.current:.2f} MHz")
    print(f"Min CPU Frequency: {cpu_frequency.min:.2f} MHz")
    print(f"Max CPU Frequency: {cpu_frequency.max:.2f} MHz")

# Get CPU utilization (percentage)
# The interval argument specifies the time period over which to measure CPU usage.
# Setting percpu=True gives individual core utilization.
cpu_percent_total = psutil.cpu_percent(interval=1)
print(f"Total CPU Usage: {cpu_percent_total}%")

# cpu_percent_per_core = psutil.cpu_percent(interval=1, percpu=True)
# print("CPU Usage per Core:")
# for i, percent in enumerate(cpu_percent_per_core):
#     print(f"  Core {i+1}: {percent}%")



Processor Name: x86_64
Physical Cores: 64
Logical Cores: 128
Current CPU Frequency: 2498.48 MHz
Min CPU Frequency: 1500.00 MHz
Max CPU Frequency: 2450.00 MHz
Total CPU Usage: 0.3%


# Example of one netCDF file with xarray

In [15]:
# ds = xr.open_dataset("train/v3.LR.DTESTM.pm-cpu-10yr.mpassi.hist.am.timeSeriesStatsDaily.0010-01-01.nc")

from perlmutterpath import * # has the path to the data on Perlmutter
ds = xr.open_dataset(full_data_dir_sample, decode_times=True)

In [16]:
ds.data_vars

Data variables:
    timeDaily_counter             (Time) int32 124B ...
    xtime_startDaily              (Time) |S64 2kB ...
    xtime_endDaily                (Time) |S64 2kB ...
    timeDaily_avg_iceAreaCell     (Time, nCells) float32 58MB ...
    timeDaily_avg_iceVolumeCell   (Time, nCells) float32 58MB ...
    timeDaily_avg_snowVolumeCell  (Time, nCells) float32 58MB ...
    timeDaily_avg_uVelocityGeo    (Time, nVertices) float32 117MB ...
    timeDaily_avg_vVelocityGeo    (Time, nVertices) float32 117MB ...

In [17]:
day_counter = ds["timeDaily_counter"].shape[0]
print(day_counter)

31


In [18]:
print(ds["xtime_startDaily"])

<xarray.DataArray 'xtime_startDaily' (Time: 31)> Size: 2kB
[31 values with dtype=|S64]
Dimensions without coordinates: Time


In [19]:
print(ds["xtime_startDaily"].values)

[b'2024-12-01_00:00:00' b'2024-12-02_00:00:00' b'2024-12-03_00:00:00'
 b'2024-12-04_00:00:00' b'2024-12-05_00:00:00' b'2024-12-06_00:00:00'
 b'2024-12-07_00:00:00' b'2024-12-08_00:00:00' b'2024-12-09_00:00:00'
 b'2024-12-10_00:00:00' b'2024-12-11_00:00:00' b'2024-12-12_00:00:00'
 b'2024-12-13_00:00:00' b'2024-12-14_00:00:00' b'2024-12-15_00:00:00'
 b'2024-12-16_00:00:00' b'2024-12-17_00:00:00' b'2024-12-18_00:00:00'
 b'2024-12-19_00:00:00' b'2024-12-20_00:00:00' b'2024-12-21_00:00:00'
 b'2024-12-22_00:00:00' b'2024-12-23_00:00:00' b'2024-12-24_00:00:00'
 b'2024-12-25_00:00:00' b'2024-12-26_00:00:00' b'2024-12-27_00:00:00'
 b'2024-12-28_00:00:00' b'2024-12-29_00:00:00' b'2024-12-30_00:00:00'
 b'2024-12-31_00:00:00']


In [20]:
ice_area = ds["timeDaily_avg_iceAreaCell"]
ice_area.shape

(31, 465044)

In [21]:
ice_area.values

array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]], shape=(31, 465044), dtype=float32)

In [22]:
print(ds.coords)
print(ds.dims)

Coordinates:
    *empty*


In [23]:
print(ds)
ds.close()

<xarray.Dataset> Size: 407MB
Dimensions:                       (Time: 31, nCells: 465044, nVertices: 942873)
Dimensions without coordinates: Time, nCells, nVertices
Data variables:
    timeDaily_counter             (Time) int32 124B ...
    xtime_startDaily              (Time) |S64 2kB b'2024-12-01_00:00:00' ... ...
    xtime_endDaily                (Time) |S64 2kB ...
    timeDaily_avg_iceAreaCell     (Time, nCells) float32 58MB 0.0 0.0 ... 0.0
    timeDaily_avg_iceVolumeCell   (Time, nCells) float32 58MB ...
    timeDaily_avg_snowVolumeCell  (Time, nCells) float32 58MB ...
    timeDaily_avg_uVelocityGeo    (Time, nVertices) float32 117MB ...
    timeDaily_avg_vVelocityGeo    (Time, nVertices) float32 117MB ...
Attributes: (12/490)
    case:                                                         v3.LR.histo...
    source_id:                                                    399d430113
    realm:                                                        seaIce
    product:              

# Example of Mesh File

In [24]:
mesh = xr.open_dataset("NC_FILE_PROCESSING/mpassi.IcoswISC30E3r5.20231120.nc")

In [25]:
mesh.data_vars

Data variables:
    edgesOnEdge        (nEdges, maxEdges2) int32 68MB ...
    weightsOnEdge      (nEdges, maxEdges2) float64 135MB ...
    cellsOnEdge        (nEdges, TWO) int32 11MB ...
    verticesOnEdge     (nEdges, TWO) int32 11MB ...
    angleEdge          (nEdges) float64 11MB ...
    dcEdge             (nEdges) float64 11MB ...
    dvEdge             (nEdges) float64 11MB ...
    indexToEdgeID      (nEdges) int32 6MB ...
    latEdge            (nEdges) float64 11MB ...
    lonEdge            (nEdges) float64 11MB ...
    nEdgesOnEdge       (nEdges) int32 6MB ...
    xEdge              (nEdges) float64 11MB ...
    yEdge              (nEdges) float64 11MB ...
    zEdge              (nEdges) float64 11MB ...
    fEdge              (nEdges) float64 11MB ...
    cellsOnVertex      (nVertices, vertexDegree) int32 11MB ...
    edgesOnVertex      (nVertices, vertexDegree) int32 11MB ...
    kiteAreasOnVertex  (nVertices, vertexDegree) float64 23MB ...
    areaTriangle       (nVertices)

In [26]:
cellsOnCell = mesh["cellsOnCell"].values
print(mesh["cellsOnCell"].values)

[[     5      4      0      0      0      0]
 [    12     11      9      8      0      3]
 [     4     13     12      2      0      0]
 ...
 [465043      0 465040 465041      0      0]
 [     0 465042      0 465044      0      0]
 [     0      0      0      0 465043      0]]


In [27]:
print(cellsOnCell.shape[1])

6


In [28]:
print(mesh["cellsOnCell"].max().values)
print(mesh["cellsOnCell"].min().values)

465044
0


In [29]:
#np.save('cellsOnCell.npy', cellsOnCell) 

In [30]:
#landIceMask = mesh["landIceMask"].values
#np.save('landIceMask.npy', landIceMask)

In [31]:
print(mesh.coords)
print(mesh.dims)

Coordinates:
    *empty*


In [32]:
print(mesh)

<xarray.Dataset> Size: 509MB
Dimensions:            (nEdges: 1408196, maxEdges2: 12, TWO: 2,
                        nVertices: 942873, vertexDegree: 3, nCells: 465044,
                        maxEdges: 6, Time: 1)
Dimensions without coordinates: nEdges, maxEdges2, TWO, nVertices,
                                vertexDegree, nCells, maxEdges, Time
Data variables: (12/40)
    edgesOnEdge        (nEdges, maxEdges2) int32 68MB ...
    weightsOnEdge      (nEdges, maxEdges2) float64 135MB ...
    cellsOnEdge        (nEdges, TWO) int32 11MB ...
    verticesOnEdge     (nEdges, TWO) int32 11MB ...
    angleEdge          (nEdges) float64 11MB ...
    dcEdge             (nEdges) float64 11MB ...
    ...                 ...
    nEdgesOnCell       (nCells) int32 2MB ...
    xCell              (nCells) float64 4MB ...
    yCell              (nCells) float64 4MB ...
    zCell              (nCells) float64 4MB ...
    fCell              (nCells) float64 4MB ...
    landIceMask        (Time, nCells) 

In [33]:
mesh.close()

# Pre-processing + Freeboard calculation functions

In [34]:
# Constants (adjust if you use different units)
D_WATER = 1023  # Density of seawater (kg/m^3)
D_ICE = 917     # Density of sea ice (kg/m^3)
D_SNOW = 330    # Density of snow (kg/m^3)

MIN_AREA = 1e-6

def compute_freeboard(area: np.ndarray, 
                      ice_volume: np.ndarray, 
                      snow_volume: np.ndarray) -> np.ndarray:
    """
    Compute sea ice freeboard from ice and snow volume and area.
    
    Parameters
    ----------
    area : np.ndarray
        Sea ice concentration / area (same shape as ice_volume and snow_volume).
    ice_volume : np.ndarray
        Sea ice volume per grid cell.
    snow_volume : np.ndarray
        Snow volume per grid cell.
    
    Returns
    -------
    freeboard : np.ndarray
        Freeboard height for each cell, same shape as inputs.
    """
    # Initialize arrays
    height_ice = np.zeros_like(ice_volume)
    height_snow = np.zeros_like(snow_volume)

    # Valid mask: avoid dividing by very small or zero area
    valid = area > MIN_AREA

    # Safely compute heights where valid
    height_ice[valid] = ice_volume[valid] / area[valid]
    height_snow[valid] = snow_volume[valid] / area[valid]

    # Compute freeboard using the physical formula
    freeboard = (
        height_ice * (D_WATER - D_ICE) / D_WATER +
        height_snow * (D_WATER - D_SNOW) / D_WATER
    )

    return freeboard


In [35]:
def check_freeboard_outliers(freeboard_data: np.ndarray):
    """
    Checks for bad outliers in the freeboard data using the IQR method.
    Logs the findings.
    
    Parameters
    ----------
    freeboard_data : np.ndarray
        The flattened NumPy array of freeboard values to check.
    """
    logging.info("--- Checking for Freeboard Outliers ---")
    
    flat_freeboard = freeboard_data.flatten()
    total_elements = len(flat_freeboard)

    count_zero = np.sum(flat_freeboard == 0)
    percent_zero = (count_zero / total_elements) * 100
    logging.info(f"Percentage of Freeboard values exactly 0: {percent_zero:.2f}% ({count_zero} points)")

    Q1 = np.percentile(flat_freeboard, 25)
    Q3 = np.percentile(flat_freeboard, 75)
    IQR = Q3 - Q1
    
    lower_bound = Q1 - 1.5 * IQR
    upper_bound = Q3 + 1.5 * IQR
    
    outliers_low = flat_freeboard[flat_freeboard < lower_bound]
    outliers_high = flat_freeboard[flat_freeboard > upper_bound]

    num_outliers = len(outliers_low) + len(outliers_high)
    
    logging.info(f"Freeboard Q1: {Q1:.4f}")
    logging.info(f"Freeboard Q3: {Q3:.4f}")
    logging.info(f"Freeboard IQR: {IQR:.4f}")
    logging.info(f"Freeboard Lower Bound (Q1 - 1.5*IQR): {lower_bound:.4f}")
    logging.info(f"Freeboard Upper Bound (Q3 + 1.5*IQR): {upper_bound:.4f}")
    logging.info(f"Number of low outliers: {len(outliers_low)}")
    logging.info(f"Number of high outliers: {len(outliers_high)}")
    logging.info(f"Total outliers: {num_outliers} ({num_outliers / total_elements * 100:.2f}% of total elements)")

    if num_outliers > 0:
        logging.warning("Potential outliers detected in Freeboard data!")
        logging.info(f"Sample low outliers (first 10): {outliers_low[:10]}")
        logging.info(f"Sample high outliers (first 10): {outliers_high[:10]}")
        logging.info(f"Sample high outliers (last 10): {outliers_high[:-10]}")
    else:
        logging.info("No significant outliers detected in Freeboard data based on IQR method.")

def plot_freeboard_distribution(freeboard_data: np.ndarray, prefix: str = ""):
    """
    Plots the distribution of the freeboard variable using a histogram and a boxplot,
    and saves the plot as a PNG file.
    
    Parameters
    ----------
    freeboard_data : np.ndarray
        The flattened NumPy array of freeboard values to plot.
    save_path : str
        The directory where the plot PNG file will be saved.
    """
    logging.info(f"--- Plotting Freeboard Distribution ({prefix}) ---")
    
    flat_freeboard = freeboard_data.flatten()
    fig, axes = plt.subplots(1, 2, figsize=(15, 6))

    # --- Histogram ---
    axes[0].hist(flat_freeboard, bins=50, color='skyblue', edgecolor='black', alpha=0.7)
    axes[0].set_title('Distribution of Freeboard (Histogram)')
    axes[0].set_xlabel('Freeboard Value')
    axes[0].set_ylabel('Frequency')
    axes[0].grid(True, linestyle='--', alpha=0.6)
    axes[0].set_xlim(0, 1.8)

    # --- Boxplot ---
    axes[1].boxplot(flat_freeboard, vert=True, patch_artist=True, boxprops=dict(facecolor='lightcoral'),
                    medianprops=dict(color='black'), whiskerprops=dict(color='gray'),
                    capprops=dict(color='gray'), flierprops=dict(marker='o', markersize=5, markerfacecolor='red', alpha=0.5))
    axes[1].set_title('Distribution of Freeboard (Boxplot)')
    axes[1].set_ylabel('Freeboard Value')
    axes[1].set_ylim(0, 0.3)
    axes[1].set_xticks([])

    plt.suptitle(f'Freeboard Data Distribution and Outlier Visualization {prefix}', fontsize=16)
    plt.tight_layout(rect=[0, 0.03, 1, 0.95]) # Adjust rect to leave space for suptitle
    
    # Save to the current working directory
    current_directory = os.getcwd()
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    filename = os.path.join(current_directory, f"{prefix}_{timestamp}.png")
    
    plt.savefig(filename, dpi=300) # dpi=300 for high-quality image
    plt.close(fig) 
    
    logging.info(f"Freeboard distribution plot saved to: {filename}")

In [36]:
def analyze_ice_area_imbalance(ice_area_data: np.ndarray):
    """
    Measures and logs the percentage of ice_area data points that are 0, 1, or between 0 and 1.
    
    Parameters
    ----------
    ice_area_data : np.ndarray
        The NumPy array of ice_area values (can be multi-dimensional).
    """
    logging.info("--- Analyzing Ice Area Imbalance ---")

    flat_ice_area = ice_area_data.flatten()
    total_elements = len(flat_ice_area)

    if total_elements == 0:
        logging.warning("Ice Area data is empty, cannot analyze imbalance.")
        return

    count_zero = np.sum(flat_ice_area == 0)
    count_one = np.sum(flat_ice_area == 1)
    count_between = np.sum((flat_ice_area > 0) & (flat_ice_area < 1))

    percent_zero = (count_zero / total_elements) * 100
    percent_one = (count_one / total_elements) * 100
    percent_between = (count_between / total_elements) * 100

    logging.info(f"Total Ice Area data points: {total_elements}")
    logging.info(f"Percentage of values == 0: {percent_zero:.2f}% ({count_zero} points)")
    logging.info(f"Percentage of values == 1: {percent_one:.2f}% ({count_one} points)")
    logging.info(f"Percentage of values between 0 and 1 (exclusive): {percent_between:.2f}% ({count_between} points)")
    
    # Optional check for values outside [0, 1] range, if any
    count_invalid = np.sum((flat_ice_area < 0) | (flat_ice_area > 1))
    if count_invalid > 0:
        logging.warning(f"Found {count_invalid} ice_area values outside the [0, 1] range!")


def plot_ice_area_imbalance(ice_area_data: np.ndarray, prefix: str = ""):
    """
    Creates a bar chart to visualize the imbalance of ice_area values (0, 1, or between 0-1).
    Saves the chart as a PNG file.
    
    Parameters
    ----------
    ice_area_data : np.ndarray
        The NumPy array of ice_area values to plot (can be multi-dimensional).
    save_path : str
        The directory where the plot PNG file will be saved.
    """
    logging.info("--- Plotting Ice Area Imbalance Chart ---")

    flat_ice_area = ice_area_data.flatten()
    total_elements = len(flat_ice_area)

    if total_elements == 0:
        logging.warning("Ice Area data is empty, cannot plot imbalance.")
        return

    count_zero = np.sum(flat_ice_area == 0)
    count_00_to_25_percent = np.sum((flat_ice_area > 0) & (flat_ice_area < 0.25))
    count_25_to_50_percent = np.sum((flat_ice_area > 0.25) & (flat_ice_area < 0.5))
    count_59_to_75_percent = np.sum((flat_ice_area > 0.5) & (flat_ice_area < 0.75))
    count_75_to_99_percent = np.sum((flat_ice_area > 0.75) & (flat_ice_area < 1))
    count_one = np.sum(flat_ice_area == 1)
    
    categories = ['Exactly 0', '>0 - 0.25','0.25 - 0.5','0.5 - 0.75','0.75 - <1', 'Exactly 1']
    percentages = [
        (count_zero / total_elements) * 100,
        (count_00_to_25_percent / total_elements) * 100,
        (count_25_to_50_percent / total_elements) * 100,
        (count_59_to_75_percent / total_elements) * 100,
        (count_75_to_99_percent / total_elements) * 100,
        (count_one / total_elements) * 100,
    ]

    fig, ax = plt.subplots(figsize=(10, 7))

    bars = ax.bar(categories, percentages, color=['black','gray','silver','lightgrey','whitesmoke','white','red'], edgecolor='black')
     
    ax.set_title('Distribution of Ice Area Values', fontsize=16)
    ax.set_xlabel('Value Category', fontsize=12)
    ax.set_ylabel('Percentage of Data (%)', fontsize=12)
    ax.set_ylim(0, 80)
    ax.grid(axis='y', linestyle='--', alpha=0.7)

    # Add percentage labels on top of the bars
    for bar in bars:
        height = bar.get_height()
        ax.annotate(f'{height:.2f}%',
                    xy=(bar.get_x() + bar.get_width() / 2, height),
                    xytext=(0, 3),  # 3 points vertical offset
                    textcoords="offset points",
                    ha='center', va='bottom', fontsize=10)

    plt.tight_layout()

    # Save to the current working directory
    current_directory = os.getcwd()
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    filename = os.path.join(current_directory, f"{prefix}_SIC_imbalance_{timestamp}.png")
    
    plt.savefig(filename, dpi=300) # dpi=300 for high-quality image
    plt.close(fig) 
    logging.info(f"Ice Area imbalance chart saved to: {filename}")

In [37]:
def normalize_freeboard(freeboard, min_val=-0.2, max_val=1.2):
    return np.clip((freeboard - min_val) / (max_val - min_val), 0, 1)

# Custom Pytorch Dataset
Example from NERSC of using ERA5 Dataset:

https://github.com/NERSC/dl-at-scale-training/blob/main/utils/data_loader.py

# __ init __ - masks and loads the data into tensors

In [38]:
import os
import time
from datetime import datetime
from datetime import timedelta

from torch.utils.data import Dataset
from typing import List, Union, Callable, Tuple
from NC_FILE_PROCESSING.patchify_utils import patchify_by_latlon_spillover
from perlmutterpath import * # Contains the data_dir and mesh_dir variables

import logging

# Set level to logging.INFO to see the statements
logging.basicConfig(filename='DailyNetCDFDataset.log', filemode='w', level=logging.INFO)

class DailyNetCDFDataset(Dataset):
    """
    PyTorch Dataset that concatenates a directory of month-wise NetCDF files
    along their 'Time' dimension and yields daily data *plus* its timestamp.

    Parameters
    ----------
    data_dir : str
        Directory containing NetCDF files
    transform : Callable | None
        Optional - transform applied to the data tensor *only*.
    latitude_threshold
        The minimum latitude to use for Artic data
    context_length
        The number of days to fetch for input in the prediction step
    forecast_horizon
        The number of days to predict in the future
    plot_outliers_and_imbalance
        Optional - check outliers and imbalance on the variables Ice Area and Freeboard
    trial_run
        Optional - use the data in the trial directory instead of the full dataset
        Specify the name of the trial director in perlmutterpath.py
    
    """
    def __init__(
        self,
        data_dir: str = data_dir,
        mesh_dir: str = mesh_dir,
        transform: Callable = None,
        latitude_threshold: int = LATITUDE_THRESHOLD,
        context_length: int = CONTEXT_LENGTH,
        forecast_horizon: int = FORECAST_HORIZON,
        normalize_on: bool = NORMALIZE_ON,
        plot_outliers_and_imbalance: bool = PLOT_DATA_DISTRIBUTION, # set FALSE FOR FINAL
        trial_run: bool = TRIAL_RUN, # Use the trial data directory
        num_patches: int = NUM_PATCHES,
        cells_per_patch: int = CELLS_PER_PATCH
        
    ):

        """ __init__ needs to 

        Handle the raw data:
        1) Gather the sorted daily data from each netCDF file (1 file = 1 month of daily data)
            The netCDF files contain nCells worth of data per day for each feature (ice area, ice volume, etc.)
            nCells = 465044 with the IcoswISC30E3r5 mesh
        2) Load the mesh and initialize the cell mask
        3) Store a list of datetimes from each file 
        4) Extract raw data
        
        Perform pre-processing:
        5) Apply a mask to nCells to look just at regions in certain latitudes
            nCells >= 40 degrees is 53973 cells
            nCells >= 50 degrees is 35623 cells
        6) Derive Freeboard from ice area, snow volume, and ice volume
        7) Custom patchify and store patch_ids so the data loader can use them
        8) Optional: Plot the outliers and data imbalance for Ice Area and Freeboard
        9) Optional: Normalize the data (Ice area is already between 0 and 1; Freeboard is not) """

        start_time = time.time()
        self.transform = transform
        self.latitude_threshold = latitude_threshold
        self.context_length = context_length
        self.forecast_horizon = forecast_horizon
        self.normalize_on = normalize_on
        self.plot_outliers_and_imbalance = plot_outliers_and_imbalance
        self.trial_run = trial_run
        self.num_patches = num_patches
        self.cells_per_patch = cells_per_patch

        # --- 1. Gather files (sorted for deterministic order) ---------
        if self.trial_run:
            # USE THIS FOR PRACTICE (SMALLER CHUNK OF DATA)
            self.data_dir = trial_data_dir
            self.file_paths = sorted(
                [
                    os.path.join(trial_data_dir, f)
                    for f in os.listdir(trial_data_dir)
                    if f.endswith(".nc")
                ]
            )

        else:
            # USE THE FULL DATASET (OR JUST A CERTAIN CENTURY, LIKE timeSeriesStatsDaily.20--)
            self.data_dir = data_dir
            self.file_paths = sorted(
                [
                    os.path.join(data_dir, f)
                    for f in os.listdir(data_dir)
                    if f.startswith("v3.LR.historical_0051.mpassi.hist.am.timeSeriesStatsDaily.") and f.endswith(".nc")
                ]
            )
        
        logging.info(f"Found {len(self.file_paths)} NetCDF files:")
        if not self.file_paths:
            raise FileNotFoundError(f"No *.nc files found in {data_dir!r}")
        
        # --- 2. Load the mesh file. Latitudes and Longitudes are in radians. ---
        mesh = xr.open_dataset(mesh_dir)
        latCell = np.degrees(mesh["latCell"].values)
        lonCell = np.degrees(mesh["lonCell"].values)
        mesh.close()
        
        # Initialize the cell mask
        self.cell_mask = latCell >= latitude_threshold        
        masked_ncells_size = np.count_nonzero(self.cell_mask)
        logging.info(f"Mask size: {masked_ncells_size}")

        self.full_to_masked = {
            full_idx: new_idx
            for new_idx, full_idx in enumerate(np.where(self.cell_mask)[0])
        }

        # Also store reverse mapping: masked -> full for recovery of data later
        self.masked_to_full = {
            v: k for k, v in self.full_to_masked.items()
        }

        logging.info(f"=== Extracting raw data and times in a single loop === ")

        all_times_list = []
        ice_area_all_list = []
        ice_volume_all_list = []
        snow_volume_all_list = []
        
        for i, path in enumerate(self.file_paths):
            ds = xr.open_dataset(path)

            # --- 3. Store a list of datetimes from each file -> helps with retrieving 1 day's data later
            # Extract times from byte string format
            xtime_byte_array = ds["xtime_startDaily"].values
            xtime_unicode_array = xtime_byte_array.astype(str)
            xtime_cleaned_array = np.char.replace(xtime_unicode_array, "_", " ")
            times_array = np.asarray(xtime_cleaned_array, dtype='datetime64[s]')
            all_times_list.append(times_array)

            # --- 4. Extract raw data
            ice_area = ds["timeDaily_avg_iceAreaCell"].values
            ice_volume = ds["timeDaily_avg_iceVolumeCell"].values
            snow_volume = ds["timeDaily_avg_snowVolumeCell"].values

            # --- 5. Apply a mask to the nCells
            ice_area = ice_area[:, self.cell_mask]
            ice_volume = ice_volume[:, self.cell_mask]
            snow_volume = snow_volume[:, self.cell_mask]

            # Append masked data to lists
            ice_area_all_list.append(ice_area)
            ice_volume_all_list.append(ice_volume)
            snow_volume_all_list.append(snow_volume)

            ds.close() # Close dataset after processing

        # --- Concatenate all collected data into single NumPy arrays after the loop
        self.times = np.concatenate(all_times_list, axis=0)
        self.ice_area = np.concatenate(ice_area_all_list, axis=0)
        ice_volume_combined = np.concatenate(ice_volume_all_list, axis=0)
        snow_volume_combined = np.concatenate(snow_volume_all_list, axis=0)

        # Checking the dates
        logging.info(f"Parsed {len(self.times)} total dates")
        logging.info(f"First few: {str(self.times[:5])}")

        # Stats on how many dates there are
        logging.info(f"Total days collected: {len(self.times)}")
        logging.info(f"Unique days: {len(np.unique(self.times))}")
        logging.info(f"First 35 days: {self.times[:35]}")
        logging.info(f"Last 35 days: {self.times[-35:]}")

        logging.info(f"Shape of combined ice_area array: {self.ice_area.shape}")
        logging.info(f"Elapsed time for combined data/time loading: {time.time() - start_time} seconds")
        
        # --- 6. Derive Freeboard from ice area, snow volume and ice volume
        logging.info(f"=== Calculating Freeboard === ")
        self.freeboard = compute_freeboard(self.ice_area, ice_volume_combined, snow_volume_combined)
        logging.info(f"Elapsed time for freeboard calculation: {time.time() - start_time} seconds")
        
        logging.info(f"=== Patchifying === ")
        
        # --- 7. Custom patchify function
        #     Returns 
        # full_nCells_patch_ids : np.ndarray
        #     Array of shape (nCells,) giving patch ID or -1 if unassigned.
        # indices_per_patch_id : List[np.ndarray]
        #     List of patches, each a list of cell indices (np.ndarray of ints) that correspond with nCells array.
        # patch_latlons : np.ndarray
        #     Array of shape (n_patches, 2) containing (latitude, longitude) for one
        #     representative cell per patch (the first cell added to the patch)
        self.full_nCells_patch_ids, self.indices_per_patch_id, self.patch_latlons, self.algorithm = patchify_by_latlon_spillover(
            latCell, lonCell, k=self.cells_per_patch, max_patches=self.num_patches, latitude_threshold=self.latitude_threshold)
        
        # Convert full-domain patch indices to masked-domain indices
        # This ensures there's no out of bounds problem,
        # like index 296237 is out of bounds for axis 1 with size 53973
        self.indices_per_patch_id = [
            [self.full_to_masked[i] for i in patch if i in self.full_to_masked]
            for patch in self.indices_per_patch_id
        ]
        logging.info(f"Elapsed time for patchifying with the {self.algorithm} algorithm: {time.time() - start_time} seconds")

        # --- 8. Optional --- OUTLIER DETECTION AND DATA IMBALANCE CHECK ---
        prefix = ""
        if self.trial_run:
            prefix = "trial"
            
        if self.plot_outliers_and_imbalance:
            logging.info(f"=== Plotting Outliers and Imbalance === ")
            check_freeboard_outliers(self.freeboard)
            plot_freeboard_distribution(self.freeboard, f"{prefix}_fb_pre_norm")
            analyze_ice_area_imbalance(self.ice_area)
            plot_ice_area_imbalance(self.ice_area, prefix)

        # --- 9. Optional --- Normalize the data (Area is already between 0 and 1; Freeboard is not)
        if self.normalize_on:
            logging.info(f"=== Normalizing Freeboard === ")
    
            self.freeboard_min = self.freeboard.min()
            self.freeboard_max = self.freeboard.max()
    
            logging.info(f"Freeboard min (pre-norm): {self.freeboard_min} meters" )
            logging.info(f"Freeboard max (pre-norm): {self.freeboard_max} meters")
    
            self.freeboard = normalize_freeboard(
                self.freeboard, min_val=self.freeboard_min, max_val=self.freeboard_max)
    
            logging.info(f"Freeboard Shape: {self.freeboard.shape}")
            logging.info(f"Ice Area Shape:  {self.ice_area.shape}")
    
            logging.info("=== Normalized Freeboard ===")
            freeboard_min_after_norm = self.freeboard.min()
            freeboard_max_after_norm  = self.freeboard.max()
    
            logging.info(f"Freeboard min (post-norm): {freeboard_min_after_norm}" )
            logging.info(f"Freeboard max (post-norm): {freeboard_max_after_norm}")

            if self.plot_outliers_and_imbalance:
                check_freeboard_outliers(self.freeboard)
                plot_freeboard_distribution(self.freeboard, f"{prefix}_fb_post_norm")

        logging.info("End of __init__")
        end_time = time.time()
        logging.info(f"Elapsed time: {end_time - start_time} seconds")
        print(f"Elapsed time for __init__: {end_time - start_time} seconds")
        print(f"In minutes:                {(end_time - start_time)//60} minutes")

    def __len__(self):
        """
        Returns the total number of possible starting indices (idx) for a valid sequence.
        A valid sequence needs `self.context_length` days for input and `self.forecast_horizon` days for target.
        
        ex) If the total number of days is 365, the context_length is 7 and the forecast_horizon is 3, then
        
        365 - (7 + 3) + 1 = 365 - 10 + 1 = 356 valid starting indices
        """
        required_length = self.context_length + self.forecast_horizon
        if len(self.freeboard) < required_length:
            return 0 # Not enough raw data to form even one sample

        # The number of valid starting indices
        return len(self.freeboard) - required_length + 1

    def get_patch_tensor(self, day_idx: int) -> torch.Tensor:
        
        """
        Retrieves the feature data for a specific day, organized into patches.

        This method extracts 'freeboard' and 'ice_area' data for a given day
        and then reshapes it according to the pre-defined patches. Each patch
        will contain its own set of feature values.

        Parameters
        ----------
        day_idx : int
            The integer index of the day to retrieve data for, relative to the
            concatenated dataset's time dimension.

        Returns
        -------
        torch.Tensor
            A tensor containing the feature data organized by patches for the
            specified day.
            Shape: (num_patches, num_features, patch_size)
            Where:
            - num_patches: Total number of patches (ex., 140).
            - num_features: The number of features per cell (currently 2: freeboard, ice_area).
            - patch_size: The number of cells within each patch.
            
        """
        
        freeboard_day = self.freeboard[day_idx]  # (nCells,)
        ice_area_day = self.ice_area[day_idx]    # (nCells,)
        features = np.stack([freeboard_day, ice_area_day], axis=0)  # (2, nCells)
        patch_tensors = []

        for patch_indices in self.indices_per_patch_id:
            patch = features[:, patch_indices]  # (2, patch_size)
            patch_tensors.append(torch.tensor(patch, dtype=torch.float32))

        return torch.stack(patch_tensors)  # (context_length, num_patches, num_features, patch_size)
    
    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, np.datetime64]:

        """__ getitem __ needs to 
        
        1. Given an input of a certain date id, get the input and the target tensors
        2. Return all the patches for the input and the target
           Features are: [freeboard, ice_area] over masked cells. 
           
        """
        # Start with the id of the day in question
        start_idx = idx

        # end_idx is the exclusive end of the input sequence,
        # and the inclusive start of the target sequence.
        end_idx = idx + self.context_length
        target_start = end_idx

        # the target sequence ends after forecast horizon
        target_end = end_idx + self.forecast_horizon

        if target_end > len(self.freeboard):
            raise IndexError(
                f"Requested time window exceeds dataset. "
                f"Problematic idx: {idx}, "
                f"Context Length: {self.context_length}, "
                f"Forecast Horizon: {self.forecast_horizon}, "
                f"Calculated target_end: {target_end}, "
                f"Actual dataset length (len(self.freeboard)): {len(self.freeboard)}"
            )

        # Build input tensor
        input_seq = [self.get_patch_tensor(i) for i in range(start_idx, end_idx)]
        input_tensor = torch.stack(input_seq)
    
        # Build target tensor: shape (forecast_horizon, num_patches)
        target_seq = self.ice_area[end_idx:target_end]
        target_patches = []
        for day in target_seq:
            patch_day = [
                torch.tensor(day[patch_indices]) for patch_indices in self.indices_per_patch_id
            ]
            
            # After stacking, patch_day_tensor will be (num_patches, CELLS_PER_PATCH)
            patch_day_tensor = torch.stack(patch_day)  # (num_patches,)
            target_patches.append(patch_day_tensor)

        # Final target tensor shape: (forecast_horizon, num_patches, CELLS_PER_PATCH)
        target_tensor = torch.stack(target_patches)  # (forecast_horizon, num_patches)
        
        return input_tensor, target_tensor, start_idx, end_idx, target_start, target_end

    def __repr__(self):
        """ Format the string representation of the data """
        return (
            f"<DailyNetCDFDataset: {len(self)} days, "
            f"{len(self.freeboard[0])} cells/day, "
            f"{len(self.file_paths)} files loaded>"
        )

    def time_to_dataframe(self) -> pd.DataFrame:
            """Return a DataFrame of time features you can merge with predictions."""
            t = pd.to_datetime(self.times)            # pandas Timestamp index
            return pd.DataFrame(
                {
                    "time": t,
                    "year": t.year,
                    "month": t.month,
                    "day": t.day,
                    "doy": t.dayofyear,
                }
            )

In [39]:
!sqs

JOBID            ST USER      NAME          NODES TIME_LIMIT       TIME  SUBMIT_TIME          QOS             START_TIME           FEATURES       NODELIST(REASON
41060061         R  brelypo   jupyter       1        6:00:00    1:11:02  2025-07-25T22:09:34  gpu_jupyter     2025-07-25T22:09:36  gpu&a100       nid001013      


# DataLoader

In [40]:
from torch.utils.data import DataLoader
from torch.utils.data import Subset

print(f"===== Making the Dataset Class: TRIAL_RUN MODE IS {TRIAL_RUN} ===== ")

# load all the data from one folder
dataset = DailyNetCDFDataset(data_dir)

# Patch locations for positional embedding
PATCH_LATLONS_TENSOR = torch.tensor(dataset.patch_latlons, dtype=torch.float32)

# TODO: PLAY AROUND WITH DIFFERENT SUBSETS OF TIME FOR TESTING
total_days = len(dataset)
train_end = int(total_days * 0.7)
val_end = int(total_days * 0.85)

train_set = Subset(dataset, range(0, train_end))
val_set   = Subset(dataset, range(train_end, val_end))
test_set  = Subset(dataset, range(val_end, total_days))

print("Training data length:   ", len(train_set))
print("Validation data length: ", len(val_set))
print("Testing data length:    ", len(test_set))

total_days = len(train_set) + len(val_set) + len(test_set)
print("Total days = ", total_days)

print("Number of training batches", len(train_set)/BATCH_SIZE)
print("Number of training batches", len(val_set)/BATCH_SIZE)

print("Number of test batches after drop_last incomplete batch", int(len(test_set)/BATCH_SIZE))
print("Number of test days to drop after drop_last incomplete batch", len(test_set)//BATCH_SIZE)

print("===== Printing Dataset ===== ")
print(dataset)                 # calls __repr__ → see how many files & days loaded

# sample is tensor, ts is np.datetime64
input_tensor, target_tensor, start_idx, end_idx, target_start, target_end = dataset[0]

print(f"Fetched start index {start_idx}: Time={dataset.times[start_idx]}")
print(f"Fetched end   index {end_idx}: Time={dataset.times[end_idx]}")

print(f"Fetched target start index {target_start}: Time={dataset.times[target_start]}")
print(f"Fetched target end   index {target_end}: Time={dataset.times[target_end]}")

print("===== Starting DataLoader ====")
# wrap in a DataLoader
# 1. Use pinned memory for faster asynch transfer to GPUs)
# 2. Use a prefetch factor so that the GPU is fed w/o a ton of CPU memory use
# 3. Use shuffle=False to preserve time order (especially for forecasting)
# 4. Use drop_last=True to prevent it from testing on incomplete batches
train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=False,
                          num_workers=NUM_WORKERS, pin_memory=True, prefetch_factor=2)
val_loader   = DataLoader(val_set, batch_size=BATCH_SIZE, shuffle=False,
                          num_workers=NUM_WORKERS, pin_memory=True, prefetch_factor=2)
test_loader  = DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=False,
                          num_workers=NUM_WORKERS, pin_memory=True, prefetch_factor=2, drop_last=True)

print("input_tensor should be of shape (context_length, num_patches, num_features, patch_size)")
print(f"actual input_tensor.shape = {input_tensor.shape}")
print("target_tensor should be of shape (forecast_horizon, num_patches, patch_size)")
print(f"actual target_tensor.shape = {target_tensor.shape}")

===== Making the Dataset Class: TRIAL_RUN MODE IS False ===== 
Built 210 patches of size ~256
Cluster sizes:
min size 256
max size 411284
smallest count (np.int64(0), 256)
max count (np.int64(-1), 411284)
number of patches: 211
Elapsed time for __init__: 3281.9779698848724 seconds
In minutes:                54.0 minutes
Training data length:    44706
Validation data length:  9580
Testing data length:     9580
Total days =  63866
Number of training batches 2794.125
Number of training batches 598.75
Number of test batches after drop_last incomplete batch 598
Number of test days to drop after drop_last incomplete batch 598
===== Printing Dataset ===== 
<DailyNetCDFDataset: 63866 days, 53973 cells/day, 2100 files loaded>
Fetched start index 0: Time=1850-01-01T00:30:00
Fetched end   index 7: Time=1850-01-08T00:00:00
Fetched target start index 7: Time=1850-01-08T00:00:00
Fetched target end   index 10: Time=1850-01-11T00:00:00
===== Starting DataLoader ====
input_tensor should be of shape (co

# Transformer Class

In [41]:
import torch
import torch.nn as nn
import time

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class IceForecastTransformer(nn.Module):
    
    """
    A Transformer-based model for forecasting ice conditions based on sequences of
    historical patch data.

    Parameters
    ----------
    input_patch_features_dim : int
        The dimensionality of the feature vector for each individual patch (ex. 2 features).
        This is the input dimension for the patch embedding layer.
    num_patches : int
        The total number of geographical patches that the `nCells` data was divided into.
        (ex., 256 patches).
    context_length : int, optional
        The number of historical days (time steps) to use as input for the transformer.
        Defaults to 7.
    forecast_horizon : int, optional
        The number of future days to predict for each patch.
        Defaults to 1.
    d_model : int, optional
        The dimension of the model's hidden states (embedding dimension).
        This is the size of the vectors that flow through the Transformer encoder.
        Defaults to 128.
    nhead : int, optional
        The number of attention heads in the multi-head attention mechanism within
        each Transformer encoder layer. Defaults to 8.
    num_layers : int, optional
        The number of Transformer encoder layers in the model. Defaults to 4.

    Attributes
    ----------
    patch_embed : nn.Linear
        Linear layer to project input patch features into the `d_model` hidden space.
    encoder : nn.TransformerEncoder
        The Transformer encoder module composed of `num_layers` encoder layers.
    mlp_head : nn.Sequential
        A multi-layer perceptron head for outputting predictions for each patch.
    """
    
    def __init__(self,
                 input_patch_features_dim: int = PATCH_EMBEDDING_INPUT_DIM, # D: The flat feature dimension of a single patch (ex., 512)
                 num_patches: int = NUM_PATCHES,  # P: Number of spatial patches
                 context_length: int = CONTEXT_LENGTH, # T: Number of historical time steps
                 forecast_horizon: int = FORECAST_HORIZON, # Number of future time steps to predict (usually 1)
                 d_model: int = D_MODEL,        # d_model: Transformer's embedding dimension
                 nhead: int = N_HEAD,           # nhead: Number of attention heads
                 num_layers: int = NUM_TRANSFORMER_LAYERS # num_layers: Number of TransformerEncoderLayers
                ):
        
        super().__init__()

        """
        The transformer should
        1. Accept a sequence of days (ex. 7 days of patches). 
           The context_length parameter says how many days to use for input.
        2. Encode each patch with the transformer.
        3. Output the patches for regression (ex. predict the 8th day).
           The forecast_horizon parameter says how many days to use for the output prediction.
        
        """

        self.context_length = context_length
        self.forecast_horizon = forecast_horizon
        self.num_patches = num_patches
        self.d_model = d_model
        self.input_patch_features_dim = input_patch_features_dim
   
        print("Calling IceForecastTransformer __init__")
        start_time = time.time()

        # Patch embedding layer: projects the raw patch features (512)
        # into d_model (128) hidden space dimension
        self.patch_embed = nn.Linear(input_patch_features_dim, d_model)

        # Transformer Encoder
        # batch_first=True means input/output tensors are (batch, sequence, features)
        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, batch_first=True)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # Output MLP head: (B, P, CELLS_PER_PATCH * forecast_horizon)
        # Make a prediction for every cell per patch
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(d_model),
            nn.Linear(d_model, CELLS_PER_PATCH * forecast_horizon)
        )

        end_time = time.time()
        print(f"Elapsed time: {end_time - start_time:.2f} seconds")
        print("End of __init__")

    def forward(self, x):
        """
        B = Batch size
        T = Time (context_length)
        P = Patch count
        D = Patch Dimension (cells per patch * feature count)
        x: Tensor of shape (B, T, P, D)
        Output: Tensor of shape (batch_size, forecast_horizon, num_patches)
        Output: (B, forecast_horizon, P)
        """
        
        # Initial input x shape from DataLoader / pre-processing:
        # (B, T, P, D) i.e., (Batch_Size, Context_Length, Num_Patches, Input_Patch_Features_Dim)
        # Example: (16, 7, 140, 512)
        
        B, T, P, D = x.shape

        # Flatten time and patches for the Transformer Encoder:
        # Each (Time, Patch) combination becomes a single token in the sequence.
        # Output shape: (B, T * P, D)
        # Example: (16, 7 * 140 = 980, 512)
        
        # Flatten time and patches for the Transformer Encoder: (B, T * P, D)
        # This treats each patch at each time step as a distinct token
        x = x.view(B, T * P, D)

        # Project patch features to the transformer's d_model dimension
        x = self.patch_embed(x)  # Output: (B, T * P, d_model) ex., (16, 980, 128)
        
        # Apply transformer encoder layers
        x = self.encoder(x)      # Output: (B, T * P, d_model) ex., (16, 980, 128)

        # Reshape back to separate time and patches: (B, T, P, d_model) ex., (16, 7, 140, 128)
        x = x.view(B, T, P, self.d_model) 

        # Mean pooling over the time (context_length) dimension for each patch.
        # This aggregates information from all historical time steps for each patch's final prediction.        
        x = x.mean(dim=1)  # Output: (B, P, d_model) ex., (16, 140, 128)

        # TODO: SOMEHOW SAVE ATTENTION TO MAP LATER

        # Apply MLP head to predict values for each cell in each patch
        # The MLP head outputs (B, P, CELLS_PER_PATCH * forecast_horizon)
        x = self.mlp_head(x) # ex. (16, 140, 256 * 3) = (16, 140, 768)

        # Reshape the output to (B, forecast_horizon, P, CELLS_PER_PATCH)
        # Explicitly reshape the last dimension to seperate the forecast horizon out
        x = x.view(B, P, self.forecast_horizon, CELLS_PER_PATCH) # Reshape into forecast_horizon and CELLS_PER_PATCH
        x = x.permute(0, 2, 1, 3) # Permute to (B, forecast_horizon, P, CELLS_PER_PATCH)

        return x


In [42]:
!sqs

JOBID            ST USER      NAME          NODES TIME_LIMIT       TIME  SUBMIT_TIME          QOS             START_TIME           FEATURES       NODELIST(REASON
41060061         R  brelypo   jupyter       1        6:00:00    2:05:49  2025-07-25T22:09:34  gpu_jupyter     2025-07-25T22:09:36  gpu&a100       nid001013      


# Training Loop

In [43]:
if TRAINING:
    import torch
    import torch.nn as nn
    import torch.optim as optim
    from torch.utils.data import DataLoader
    from torch import Tensor
    import torch.nn.functional as F
    
    import logging
    
    # Set level to logging.INFO to see the statements
    logging.basicConfig(filename='IceForecastTransformerInstance.log', filemode='w', level=logging.INFO)
    
    model = IceForecastTransformer().to(device)
    
    print("\n--- Model Architecture ---")
    print(model)
    print("--------------------------\n")
    
    logging.info("\n--- Model Architecture ---")
    logging.info(str(model)) # Log the full model structure
    logging.info(f"Total model parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
    logging.info("--------------------------\n")
    
    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    criterion = nn.MSELoss()
    
    start_time = time.time()
    logging.info("===============================")
    logging.info("       STARTING EPOCHS       ")
    logging.info("===============================")
    logging.info(f"Number of epochs: {NUM_EPOCHS}")
    logging.info(f"Learning Rate: {optimizer.param_groups[0]['lr']}")
    
    for epoch in range(NUM_EPOCHS):
        model.train()
        total_loss = 0
    
        for batch_idx, (input_tensor, target_tensor, start_idx, end_idx, target_start, target_end) in enumerate(train_loader):  
    
            # Move input and target to the device
            # x: (B, context_length, num_patches, input_patch_features_dim), y: (B, forecast_horizon, num_patches)
            x = input_tensor.to(device)  # Shape: (B, T, P, C, L)
            y = target_tensor.to(device)  # Shape: (B, forecast_horizon, P, L)
    
            # Reshape x for transformer input
            B, T, P, C, L = x.shape
            x_reshaped_for_transformer_D = x.view(B, T, P, C * L)
    
            # Run through transformer
            y_pred = model(x_reshaped_for_transformer_D) # y_pred is (B, forecast_horizon, num_patches) ex., (16, 1, 140)
            
            # Compute loss
            loss = criterion(y_pred, y) # DIRECTLY compare y_pred and y
        
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    
            total_loss += loss.item()
    
        avg_train_loss = total_loss / len(train_loader)
        logging.info(f"Epoch {epoch+1}/{NUM_EPOCHS} - Train Loss: {avg_train_loss:.4f}")
        print(f"Epoch {epoch+1}/{NUM_EPOCHS} - Train Loss: {avg_train_loss:.4f}") # Keep print for immediate console feedback
    
        # --- Validation loop ---
        model.eval()
        val_loss = 0
        
        with torch.no_grad():
            for batch in val_loader:
                # Unpack the full tuple
                x_val, y_val, start_idx, end_idx, target_start, target_end = batch
        
                # Move to GPU if available
                x_val = x_val.to(device)
                y_val = y_val.to(device)
    
                # Extract dimensions from x_val for reshaping
                # x_val before reshaping: (B_val, T_val, P_val, C_val, L_val)
                B_val, T_val, P_val, C_val, L_val = x_val.shape
                
                # Reshape x_val for transformer input
                x_val_reshaped_for_transformer_input = x_val.view(B_val, T_val, P_val, C_val * L_val)
    
                # Model output is (B, forecast_horizon, P, L)
                y_val_pred = model(x_val_reshaped_for_transformer_input) 
    
                # Compute validation loss (y_val_pred and y_val should have identical shapes)
                val_loss += criterion(y_val_pred, y_val).item() # y_val is (B, forecast_horizon, P, L)
        
        avg_val_loss = val_loss / len(val_loader)
        logging.info(f"Epoch {epoch+1}/{NUM_EPOCHS} - Validation Loss: {avg_val_loss:.4f}")
        print(f"Validation Loss: {avg_val_loss:.4f}") # Keep print for immediate console feedback
    
    end_time = time.time()
    elapsed_time = end_time - start_time
    
    logging.info("===============================================")
    logging.info(f"Elapsed time for TRAINING: {elapsed_time:.2f} seconds")
    logging.info("===============================================")
    print("===============================================")
    print(f"Elapsed time for TRAINING: {elapsed_time:.2f} seconds")
    print("===============================================")

Calling IceForecastTransformer __init__
Elapsed time: 0.03 seconds
End of __init__

--- Model Architecture ---
IceForecastTransformer(
  (patch_embed): Linear(in_features=512, out_features=128, bias=True)
  (encoder): TransformerEncoder(
    (layers): ModuleList(
      (0-3): 4 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
        )
        (linear1): Linear(in_features=128, out_features=2048, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=2048, out_features=128, bias=True)
        (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (mlp_head): Sequential(
    (0): LayerNorm((128,), eps=1e-05, elementwise_affine=True)

In [44]:
!sqs

JOBID            ST USER      NAME          NODES TIME_LIMIT       TIME  SUBMIT_TIME          QOS             START_TIME           FEATURES       NODELIST(REASON
41060061         R  brelypo   jupyter       1        6:00:00    2:56:16  2025-07-25T22:09:34  gpu_jupyter     2025-07-25T22:09:36  gpu&a100       nid001013      


TODO OPTION: Try temporal attention only (ex., Informer, Time Series Transformer).

# Save the Model

In [45]:
# Define the path where to save or load the model
if TRAINING:
    PATH = f"SIC_model_{model_version}.pth"
    
    # Save the model's state_dict
    torch.save(model.state_dict(), PATH)
    print(f"Saved model at {PATH}")

Saved model at SIC_model_fd_nT_D128_B16_lt40_P210_L256_T7_Fh3_e10.pth


# === BELOW - CAN BE USED ANY TIME FROM A .PTH FILE

Make sure and run the cells that contain constants or run all, but comment out the "save" and the training loop cell.

# Re-Load the Model

In [46]:
if EVALUATING_ON:
    
    import torch
    import torch.nn as nn
    
    if not torch.cuda.is_available():
        raise ValueError("There is a problem with Torch not recognizing the GPUs")
    
    # Instantiate the model (must have the same architecture as when it was saved)
    # Create an identical instance of the original __init__ parameters
    loaded_model = IceForecastTransformer()
    
    # Load the saved state_dict (weights_only=True helps ensure safety of pickle files)
    loaded_model.load_state_dict(torch.load(PATH, weights_only=True))
    
    # Set the model to evaluation mode
    loaded_model.eval()
    
    # Move the model to the appropriate device (CPU or GPU)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    loaded_model.to(device)
    
    print("Model loaded successfully!")

Calling IceForecastTransformer __init__
Elapsed time: 0.01 seconds
End of __init__
Model loaded successfully!


# Metrics

In [47]:
if EVALUATING_ON:

    import io
    
    # Create a string buffer to capture output
    captured_output = io.StringIO()
    
    # Redirect stdout to the buffer
    sys.stdout = captured_output
    
    from scipy.stats import entropy
    
    # Accumulators for errors
    all_abs_errors = [] # To store absolute errors for each cell in each patch
    all_mse_errors = [] # To store MSE for each cell in each patch
    
    # Accumulators for histogram data
    all_predicted_values_flat = []
    all_actual_values_flat = []
    
    print("\nStarting evaluation and metric calculation...")
    print("==================")
    print(f"DEBUG: Batch Size: {BATCH_SIZE} Days")
    print(f"DEBUG: Context Length: {CONTEXT_LENGTH} Days")
    print(f"DEBUG: Forecast Horizon: {FORECAST_HORIZON} Days")
    print(f"DEBUG: Number of batches in test_loader (with drop_last=True): {len(test_loader)} Batches")
    print("==================")
    print(f"DEBUG: len(test_set): {len(test_set)} Days")
    print(f"DEBUG: len(dataset) for splitting: {len(dataset)} Days") # Should be 356
    print(f"DEBUG: train_end: {train_end}")
    print(f"DEBUG: val_end: {val_end}") # Should be 302
    print(f"DEBUG: range for test_set: {range(val_end, total_days)}") # Should be range(302, 356)
    print("==================")
    
    # Iterate over the test_loader
    # (B, forecast_horizon, P, CELLS_PER_PATCH) to match the model's output.
    for i, (sample_x, sample_y, start_idx, end_idx, target_start, target_end) in enumerate(test_loader):
        print(f"Processing batch {i+1}/{len(test_loader)}")
        
        # Move to device and apply initial reshape as done in training
        sample_x = sample_x.to(device)
        sample_y = sample_y.to(device) # Actual target values
    
        # Initial reshape of x for the Transformer model
        B_sample, T_sample, P_sample, C_sample, L_sample = sample_x.shape
        sample_x_reshaped = sample_x.view(B_sample, T_sample, P_sample, C_sample * L_sample)
    
        # Perform inference
        with torch.no_grad(): # Essential for inference to disable gradient calculations
            predicted_y_patches = loaded_model(sample_x_reshaped)
    
        # Ensure predicted_y_patches and sample_y have the same shape for comparison
        # Expected shape: (B, forecast_horizon, NUM_PATCHES, CELLS_PER_PATCH)
        if predicted_y_patches.shape != sample_y.shape:
            print(f"Shape mismatch: Predicted {predicted_y_patches.shape}, Actual {sample_y.shape}")
            continue # Skip this batch if shapes are incompatible
    
        # Calculate errors for each cell in each patch, across the forecast horizon and batch
        # The errors will implicitly be averaged over the batch when we take the mean later
        diff = predicted_y_patches - sample_y
        abs_error_batch = torch.abs(diff)
        mse_error_batch = diff ** 2
    
        # Accumulate errors (move to CPU for storage if memory is a concern)
        all_abs_errors.append(abs_error_batch.cpu())
        all_mse_errors.append(mse_error_batch.cpu())
    
        # Collect data for histograms (flatten all values)
        all_predicted_values_flat.append(predicted_y_patches.cpu().numpy().flatten())
        all_actual_values_flat.append(sample_y.cpu().numpy().flatten())
    
    # Concatenate all accumulated tensors
    if all_abs_errors and all_mse_errors:
        combined_abs_errors = torch.cat(all_abs_errors, dim=0) # Shape: (Total_Samples, FH, P, CPC)
        combined_mse_errors = torch.cat(all_mse_errors, dim=0) # Shape: (Total_Samples, FH, P, CPC)
    
        # Calculate average MSE and Absolute Error for each cell in each patch
        # Average over batch size and forecast horizon
        # Resulting shape: (NUM_PATCHES, CELLS_PER_PATCH)
        mean_abs_error_per_cell_patch = combined_abs_errors.mean(dim=(0, 1)) # Average over batch and forecast horizon
        mean_mse_per_cell_patch = combined_mse_errors.mean(dim=(0, 1)) # Average over batch and forecast horizon
    
        print("\n--- Error Metrics (Averaged per Cell per Patch) ---")
        print(f"Mean Absolute Error (shape {mean_abs_error_per_cell_patch.shape}):")
        # print(mean_abs_error_per_cell_patch) # Uncomment to see the full tensor
        print(f"Overall Mean Absolute Error:            {mean_abs_error_per_cell_patch.mean().item():.4f}")
    
        print(f"\nMean Squared Error (shape {mean_mse_per_cell_patch.shape}):")
        # print(mean_mse_per_cell_patch) # Uncomment to see the full tensor
    
        mse = mean_mse_per_cell_patch.mean().item()
        print(f"Overall Mean Squared Error:             {mse:.4f}")
    
        rmse = np.sqrt(mse)
        print(f"Overall Root Mean Squared Error (RMSE): {rmse}")
        
    else:
        print("No data processed for error metrics. Check test_loader and data availability.")
    
    # --- Histogram and Jensen-Shannon Distance ---
    
    # Concatenate all flattened values
    if all_predicted_values_flat and all_actual_values_flat:
        final_predicted_values = np.concatenate(all_predicted_values_flat)
        final_actual_values = np.concatenate(all_actual_values_flat)
    
        print(f"\nTotal predicted values collected: {len(final_predicted_values)}")
        print(f"Total actual values collected: {len(final_actual_values)}")
    
        # Define bins for the histogram (e.g., for ice concentration between 0 and 1)
        # Adjust bins based on the expected range of your data
        bins = np.linspace(0, 1, 51) # 50 bins from 0 to 1
    
        # Compute histograms
        hist_predicted, _ = np.histogram(final_predicted_values, bins=bins, density=True)
        hist_actual, _ = np.histogram(final_actual_values, bins=bins, density=True)
    
        # Normalize histograms to sum to 1 (they are already density=True, but re-normalize for safety)
        hist_predicted = hist_predicted / hist_predicted.sum()
        hist_actual = hist_actual / hist_actual.sum()
    
        # Jensen-Shannon Distance function
        def jensen_shannon_distance(p, q):
            """Calculates the Jensen-Shannon distance between two probability distributions."""
            # Ensure distributions sum to 1
            p = p / p.sum()
            q = q / q.sum()
    
            m = 0.5 * (p + q)
            # Add a small epsilon to avoid log(0)
            epsilon = 1e-10
            jsd = 0.5 * (entropy(p + epsilon, m + epsilon) + entropy(q + epsilon, m + epsilon))
            return np.sqrt(jsd) # JSD is the square root of JS divergence
    
        # Calculate Jensen-Shannon Distance
        jsd = jensen_shannon_distance(hist_actual, hist_predicted)
        print(f"\nJensen-Shannon Distance between actual and predicted histograms: {jsd:.4f}")
    
        # Plotting Histograms
        plt.figure(figsize=(10, 6))
        plt.hist(final_actual_values, bins=bins, alpha=0.7, label='Actual Data', color='skyblue', density=True)
        plt.hist(final_predicted_values, bins=bins, alpha=0.7, label='Predicted Data', color='salmon', density=True)
        plt.title('Distribution of Actual vs. Predicted Ice Concentration Values')
        plt.xlabel('Ice Concentration Value')
        plt.ylabel('Probability Density')
        plt.legend()
        plt.grid(axis='y', alpha=0.75)
        plt.savefig(f"SIE_Distribution_Actual_vs_Predicted_model_{model_version}.png")
        plt.close()
    
        # When reading the histograms, look for overlap:
        # High Overlap: predictions are close to actual values. Decent model.
        # Low Overlap: predictions differ from actual values, issues with the model. 
    
    else:
        print("No data collected for histogram analysis. Check test_loader and data availability.")
    
    print("\nEvaluation complete.")
    
    # Restore stdout
    sys.stdout = sys.__stdout__
    
    # Now, write the captured output to the file
    with open(f'Metrics_{PATH}.txt', 'w') as f:
        f.write(captured_output.getvalue())
    
    print(f"Metrics saved to Metrics_{PATH}.txt")

Metrics saved to Metrics_SIC_model_fd_nT_D128_B16_lt40_P210_L256_T7_Fh3_e10.pth.txt
Shape of sample_x torch.Size([16, 7, 210, 2, 256])
Shape of sample_y torch.Size([16, 3, 210, 256])
Fetched sample_x start index tensor([54286, 54287, 54288, 54289, 54290, 54291, 54292, 54293, 54294, 54295,
        54296, 54297, 54298, 54299, 54300, 54301]): Time=['1998-09-24T00:00:00' '1998-09-25T00:00:00' '1998-09-26T00:00:00'
 '1998-09-27T00:00:00' '1998-09-28T00:00:00' '1998-09-29T00:00:00'
 '1998-09-30T00:00:00' '1998-10-01T00:00:00' '1998-10-02T00:00:00'
 '1998-10-03T00:00:00' '1998-10-04T00:00:00' '1998-10-05T00:00:00'
 '1998-10-06T00:00:00' '1998-10-07T00:00:00' '1998-10-08T00:00:00'
 '1998-10-09T00:00:00']
Fetched sample_x end   index tensor([54293, 54294, 54295, 54296, 54297, 54298, 54299, 54300, 54301, 54302,
        54303, 54304, 54305, 54306, 54307, 54308]):   Time=['1998-10-01T00:00:00' '1998-10-02T00:00:00' '1998-10-03T00:00:00'
 '1998-10-04T00:00:00' '1998-10-05T00:00:00' '1998-10-06T00:0

# Make a Single Prediction

In [48]:
if EVALUATING_ON:
    # Turn off the logging for this part
    # https://docs.python.org/3/library/logging.html#logrecord-attributes
    logging.disable(level=logging.INFO)
    
    # Load one batch
    data_iter = iter(test_loader)
    sample_x, sample_y, start_idx, end_idx, target_start, target_end = next(data_iter)
    
    print(f"Shape of sample_x {sample_x.shape}")
    print(f"Shape of sample_y {sample_y.shape}")   
    
    print(f"Fetched sample_x start index {start_idx}: Time={dataset.times[start_idx]}")
    print(f"Fetched sample_x end   index {end_idx}:   Time={dataset.times[end_idx]}")
    
    print(f"Fetched sample_y (target) start index {target_end}: Time={dataset.times[target_end]}")
    print(f"Fetched sample_y (target) end   index {target_end}: Time={dataset.times[target_end]}")
    
    # Move to device and apply initial reshape as done in training
    sample_x = sample_x.to(device)
    sample_y = sample_y.to(device) # Keep sample_y for actual comparison
    
    # Initial reshape of x for the Transformer model
    B_sample, T_sample, P_sample, C_sample, L_sample = sample_x.shape
    sample_x_reshaped = sample_x.view(B_sample, T_sample, P_sample, C_sample * L_sample)
    
    print(f"Sample x for inference shape (reshaped): {sample_x_reshaped.shape}")
    
    # Perform inference
    with torch.no_grad(): # Essential for inference to disable gradient calculations
        predicted_y_patches = loaded_model(sample_x_reshaped)
    
    print(f"Predicted y patches shape: {predicted_y_patches.shape}")
    print(f"Expected shape: (B, forecast_horizon, NUM_PATCHES, CELLS_PER_PATCH) ex., (16, {loaded_model.forecast_horizon}, 140, 256)")
                     
    # Option 1: Select a specific day from the forecast horizon (ex., the first day)
    # This is the shape (B, NUM_PATCHES, CELLS_PER_PATCH) for that specific day.
    predicted_for_day_0 = predicted_y_patches[:, 0, :, :].cpu()
    print(f"Predicted ice area for Day 0 (specific day) shape: {predicted_for_day_0.shape}")
    
    # Ensure sample_y has the same structure
    actual_for_day_0 = sample_y[:, 0, :, :].cpu()
    print(f"Actual ice area for Day 0 (specific day) shape: {actual_for_day_0.shape}")
    
    # Save predictions so that I can use cartopy by switching kernels for the next jupyter cell
    np.save(f'patches/ice_area_patches_predicted_{PATH}_day0.npy', predicted_for_day_0)
    np.save(f'patches/ice_area_patches_actual_{PATH}_day0.npy', actual_for_day_0)

    # Option 2: Iterate through all forecast days
    all_predicted_ice_areas = []
    all_actual_ice_areas = []
    
    for day_idx in range(loaded_model.forecast_horizon):
        predicted_day = predicted_y_patches[:, day_idx, :, :].cpu()
        all_predicted_ice_areas.append(predicted_day)
    
        actual_day = sample_y[:, day_idx, :, :].cpu()
        all_actual_ice_areas.append(actual_day)
    
        print(f"Processing forecast day {day_idx}: Predicted shape {predicted_day.shape}, Actual shape {actual_day.shape}")
    
        # Save each day's prediction/actual data if needed
        # np.save(f'patches/ice_area_patches_predicted_day{day_idx}.npy', predicted_day)
        # np.save(f'patches/ice_area_patches_actual_day{day_idx}.npy', actual_day)


# Recover nCells from Patches for Visualization

In [49]:
if EVALUATING_ON:

    ########################################
    # SWAP KERNELS IN THE JUPYTER NOTEBOOK #
    ########################################
    
    from MAP_ANIMATION_GENERATION.map_gen_utility_functions import *
    from NC_FILE_PROCESSING.nc_utility_functions import *
    from NC_FILE_PROCESSING.patchify_utils import *
    
    import numpy as np
    
    predicted_ice_area_patches = np.load(f'patches/SIC_predicted_{model_version}_day0.npy')
    actual_y_ice_area_patches = np.load(f'patches/SIC_actual_{model_version}_day0.npy')
    
    NUM_PATCHES = len(predicted_ice_area_patches[0])
    print("NUM_PATCHES is", NUM_PATCHES)
    
    latCell, lonCell = load_mesh(perlmutterpathMesh)
    TOTAL_GRID_CELLS = len(lonCell) 
    cell_mask = latCell >= LATITUDE_THRESHOLD
    
    # Extract Freeboard (index 0) and Ice Area (index 1) for predicted and actual
    # Predicted output is (B, 1, NUM_PATCHES, CELLS_PER_PATCH)
    # Assuming the model predicts ice area, which is the second feature (index 1)
    # if the output of the model aligns with the order of features *within* the original patch_dim.
    
    # Load the original patch-to-cell mapping
    # indices_per_patch_id = [
    #     [idx_cell_0_0, ..., idx_cell_0_255],
    #     [idx_cell_1_0, ..., idx_cell_1_255],
    #     ...
    # ]
    
    full_nCells_patch_ids, indices_per_patch_id, patch_latlons = patchify_by_latlon_spillover(
                latCell, lonCell, k=256, max_patches=NUM_PATCHES, LATITUDE_THRESHOLD=LATITUDE_THRESHOLD)
    
    # Select one sample from the batch for visualization (ex., the first one)
    # Output is (NUM_PATCHES, CELLS_PER_PATCH) for this single sample
    sample_predicted_cells_per_patch = predicted_ice_area_patches[2] # First item in batch
    sample_actual_cells_per_patch = predicted_ice_area_patches[2] # First item in batch
    
    # Initialize empty arrays for the full grid (nCells)
    recovered_predicted_grid = np.full(TOTAL_GRID_CELLS, np.nan)
    recovered_actual_grid = np.full(TOTAL_GRID_CELLS, np.nan)
    
    # Populate the full grid using the patch data and mapping
    for patch_idx in range(NUM_PATCHES):
        cell_indices_in_patch = indices_per_patch_id[patch_idx]
        
        # For predicted values
        recovered_predicted_grid[cell_indices_in_patch] = sample_predicted_cells_per_patch[patch_idx]
        nan_mask = np.isnan(recovered_predicted_grid)
        nan_count = np.sum(nan_mask)
    
        # For actual values
        recovered_actual_grid[cell_indices_in_patch] = sample_actual_cells_per_patch[patch_idx]
    
    print(f"Recovered predicted grid shape: {recovered_predicted_grid.shape}")
    print(f"Recovered actual grid shape: {recovered_actual_grid.shape}")
    
    fig, northMap = generate_axes_north_pole()
    generate_map_north_pole(fig, northMap, latCell, lonCell, recovered_predicted_grid, f"model {model_version} ice area recovered")
    
    fig, northMap = generate_axes_north_pole()
    generate_map_north_pole(fig, northMap, latCell, lonCell, recovered_actual_grid, f"model {model_version} ice area actual")

ModuleNotFoundError: No module named 'cartopy'

In [None]:
!sqs