# Imports

In [1]:
import sys
print('System Version:', sys.version)

System Version: 3.10.18 | packaged by conda-forge | (main, Jun  4 2025, 14:45:41) [GCC 13.3.0]


In [2]:
#print(sys.executable) # for troubleshooting kernel issues
#print(sys.path)

In [3]:
import os
#print(os.getcwd())

In [4]:
import numpy as np
print('Numpy version', np.__version__)

Numpy version 2.2.6


In [5]:
import cupy as cp

ModuleNotFoundError: No module named 'cupy'

In [None]:
import cudf
print('Cudf version', cudf.__version__)

In [None]:
import pandas as pd
print('Pandas version', pd.__version__)

In [None]:
import xarray as xr
import cupy_xarray # This registers the .cupy namespace

In [None]:
%pip list | grep "xarray\|cupy"

In [None]:
import matplotlib
import matplotlib.pyplot as plt
print('Matplotlib version', matplotlib.__version__)

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader

print('PyTorch version', torch.__version__)

# Hardware Details

In [None]:
print(torch.cuda.device_count()) # check the number of available CUDA devices
# will print 1 on login node; 4 on GPU exclusive node; 1 on shared GPU node

In [None]:
#print(torch.cuda.get_device_properties(0)) #provides information about a specific GPU

#total_memory=40326MB, multi_processor_count=108, L2_cache_size=40MB

In [None]:
import platform
import psutil

# CPU model string as reported by the platform module
processor_name = platform.processor()
print(f"Processor Name: {processor_name}")

# Core counts: physical cores first, then logical cores
physical_cores, logical_cores = (
    psutil.cpu_count(logical=count_logical) for count_logical in (False, True)
)
print(f"Physical Cores: {physical_cores}")
print(f"Logical Cores: {logical_cores}")

# Clock speeds; nothing is printed when psutil returns a falsy result
cpu_frequency = psutil.cpu_freq()
if cpu_frequency:
    print(f"Current CPU Frequency: {cpu_frequency.current:.2f} MHz")
    print(f"Min CPU Frequency: {cpu_frequency.min:.2f} MHz")
    print(f"Max CPU Frequency: {cpu_frequency.max:.2f} MHz")

# Overall CPU utilisation, sampled over a blocking one-second interval
cpu_percent_total = psutil.cpu_percent(interval=1)
print(f"Total CPU Usage: {cpu_percent_total}%")

# Per-core utilisation (pass percpu=True); left disabled to keep the output short
# cpu_percent_per_core = psutil.cpu_percent(interval=1, percpu=True)
# print("CPU Usage per Core:")
# for i, percent in enumerate(cpu_percent_per_core):
#     print(f"  Core {i+1}: {percent}%")



# Example of one netCDF file with xarray

In [None]:
ds = xr.open_dataset("train/v3.LR.DTESTM.pm-cpu-10yr.mpassi.hist.am.timeSeriesStatsDaily.0010-01-01.nc")

In [None]:
ds.data_vars

In [None]:
day_counter = ds["timeDaily_counter"]
day_counter.shape

In [None]:
print(ds["xtime_startDaily"])

In [None]:
print(ds["xtime_startDaily"].values)

In [None]:
ice_area = ds["timeDaily_avg_iceAreaCell"]
ice_area.shape

In [None]:
ice_area.values

In [None]:
print(ds.coords)
print(ds.dims)

In [None]:
print(ds)

# Release the underlying NetCDF file handle explicitly instead of leaving it
# to garbage collection, then drop the name so it cannot be reused by accident.
ds.close()
del ds

# Example of Mesh File

In [None]:
mesh = xr.open_dataset("NC_FILE_PROCESSING/mpassi.IcoswISC30E3r5.20231120.nc")

In [None]:
mesh.data_vars

In [None]:
cellsOnCell = mesh["cellsOnCell"].values
print(mesh["cellsOnCell"].values)

In [None]:
print(mesh["cellsOnCell"].max().values)
print(mesh["cellsOnCell"].min().values)

In [None]:
# cellsOnCell comes from `.values`, i.e. it is a host NumPy array, so write it
# with NumPy directly; this cell then also works on nodes without CuPy/a GPU.
np.save('cellsOnCell.npy', cellsOnCell)

In [None]:
print(mesh.coords)
print(mesh.dims)

In [None]:
print(mesh)

In [None]:
# Close the mesh file explicitly before dropping the reference so the file
# handle is released right away rather than whenever the object is collected.
mesh.close()
del mesh

# Pre-processing + Freeboard calculation functions

In [None]:
# Constants (adjust if you use different units)
# The three densities are only used as ratios in compute_freeboard, so they
# just need to share the same unit.
D_WATER = 1023  # Density of seawater (kg/m^3)
D_ICE = 917     # Density of sea ice (kg/m^3)
D_SNOW = 330    # Density of snow (kg/m^3)

# Area threshold used by compute_freeboard: cells whose area is not strictly
# greater than this keep an ice/snow height of zero, which avoids dividing the
# volumes by a (near-)zero area.
MIN_AREA = 1e-6

def compute_freeboard(area: cp.ndarray, 
                      ice_volume: cp.ndarray, 
                      snow_volume: cp.ndarray) -> cp.ndarray:
    """
    Derive a freeboard height per cell from area, ice volume and snow volume.

    Volumes are first converted to mean thicknesses (volume / area) wherever
    the area exceeds MIN_AREA; everywhere else both thicknesses stay at zero.
    The result is the hydrostatic height of the snow surface above the
    waterline:

        h_ice * (D_WATER - D_ICE) / D_WATER + h_snow * (D_WATER - D_SNOW) / D_WATER

    Parameters
    ----------
    area : cp.ndarray
        Sea ice concentration / area (same shape as ice_volume and snow_volume).
    ice_volume : cp.ndarray
        Sea ice volume per grid cell.
    snow_volume : cp.ndarray
        Snow volume per grid cell.

    Returns
    -------
    freeboard : cp.ndarray
        Freeboard height for each cell, same shape as inputs. Cells with
        area <= MIN_AREA (or NaN area) yield 0.
    """
    # Cells with enough ice cover for the volume / area division to be meaningful
    has_ice = area > MIN_AREA

    # Thickness arrays start at zero and are filled in only where has_ice holds
    ice_thickness = cp.zeros_like(ice_volume)
    snow_depth = cp.zeros_like(snow_volume)
    ice_thickness[has_ice] = ice_volume[has_ice] / area[has_ice]
    snow_depth[has_ice] = snow_volume[has_ice] / area[has_ice]

    # Buoyancy contribution of each layer, expressed as height above the waterline
    ice_term = ice_thickness * (D_WATER - D_ICE) / D_WATER
    snow_term = snow_depth * (D_WATER - D_SNOW) / D_WATER

    return ice_term + snow_term


In [None]:
def normalize_freeboard(freeboard, min_val=-0.2, max_val=1.2):
    """
    Min-max scale freeboard into [0, 1].

    Values are mapped with (freeboard - min_val) / (max_val - min_val) and then
    clipped, so anything below min_val becomes 0 and anything above max_val
    becomes 1.

    Parameters
    ----------
    freeboard : cp.ndarray
        Freeboard values to scale.
    min_val, max_val : scalar
        Lower / upper bound of the range that is mapped onto [0, 1].

    Returns
    -------
    cp.ndarray
        Scaled values, same shape as `freeboard`.

    Raises
    ------
    ValueError
        If max_val is not strictly greater than min_val (including NaN bounds).
        The scaling would otherwise divide by zero and silently fill the result
        with NaN/inf.
    """
    # `not (a > b)` also rejects NaN bounds, which a plain `a <= b` would let through
    if not max_val > min_val:
        raise ValueError(
            f"max_val ({max_val}) must be strictly greater than min_val ({min_val})"
        )
    return cp.clip((freeboard - min_val) / (max_val - min_val), 0, 1)

# Custom Pytorch Dataset

Example from NERSC of using ERA5 Dataset:
https://github.com/NERSC/dl-at-scale-training/blob/main/utils/data_loader.py

## Constants

TRY: NUM_WORKERS as 16 to 32 - profile to see if the GPU is still waiting on the CPU.

TRY: NUM_WORKERS as 64 - the number of CPU cores available.

TRY: NUM_WORKERS experiment with os.cpu_count() - 2

TRY: NUM_WORKERS experiment with (logical_cores_per_gpu * num_gpus)

`num_workers` considerations:

- Too few workers: GPUs might sit idle waiting for data.
- Too many workers: increased CPU memory usage and context-switching overhead.


In [None]:
NUM_WORKERS = 64
BATCH_SIZE = 16

# `__init__` - masks and loads the data into tensors

In [None]:
import os
import time
from datetime import datetime
from datetime import timedelta

from torch.utils.data import Dataset
from typing import List, Union, Callable, Tuple
from NC_FILE_PROCESSING.patchify_utils import *
from perlmutterpath import * # Contains the data_dir and mesh_dir variables

import logging

# Set level to logging.INFO to see the statements
logging.basicConfig(filename='DailyNetCDFDataset.log', filemode='w', level=logging.INFO)

class DailyNetCDFDataset(Dataset):
    """
    PyTorch Dataset that concatenates a directory of month-wise NetCDF files
    along their 'Time' dimension and serves sliding windows of daily data.

    Sample `idx` consists of `context_length` consecutive input days starting
    at day `idx`, followed by `forecast_horizon` target days. All data is
    loaded, masked, patchified and normalised once in `__init__`.

    Parameters
    ----------
    data_dir : str
        Directory containing NetCDF files
    mesh_dir : str
        Location of the mesh file, passed to `load_mesh_radians`.
    transform : Callable | None
        Optional transform applied to the input tensor *only*.
    decode_time : bool
        Intended to let xarray convert CF-style time coordinates to
        np.datetime64. Currently accepted but not used.
    drop_missing : bool
        Intended to drop days where a requested variable is missing.
        Currently accepted but not used.
    latitude_threshold
        The minimum latitude (degrees) to use for Arctic data
    context_length
        The number of days to fetch for input in the prediction step
    forecast_horizon
        The number of days to predict in the future

    Attributes
    ----------
    times
        Result of `cudf.to_datetime` over every `xtime_startDaily` entry.
    cell_mask
        Boolean CuPy array selecting cells with latitude >= latitude_threshold.
    freeboard, ice_area
        CuPy arrays of shape (T, masked_nCells); freeboard is min-max
        normalised to [0, 1].
    freeboard_min, freeboard_max
        Normalisation bounds taken over the whole concatenated record.
    indices_per_patch_id
        One list of masked-domain cell indices per patch.
    """
    def __init__(
        self,
        data_dir: str = data_dir,
        mesh_dir: str = mesh_dir,
        transform: Callable = None,
        decode_time: bool = True,
        drop_missing: bool = True,
        latitude_threshold = 40,
        context_length = 7,
        forecast_horizon = 1
    ):

        """ __init__ needs to

        Handle the raw data:
        1) Gather the sorted daily data from each netCDF file (1 file = 1 month of daily data)
            The netCDF files contain nCells worth of data per day for each feature (ice area, ice volume, etc.)
            nCells = 465044 with the IcoswISC30E3r5 mesh
        2) Store the datetime information from each nCells array from the daily data
        3) Extract raw data

        Perform pre-processing:
        4) Apply a mask to nCells to look just at regions in certain latitudes
            nCells >= 40 degrees is 53973 cells
            nCells >= 50 degrees is 35623 cells
        5) Derive Freeboard from ice area, snow volume, and ice volume
        6) Custom patchify and store patch_ids so the data loader can use them
        7) Concatenate the data across Time
        8) Normalize the data (Ice area is already between 0 and 1; Freeboard is not) """

        start_time = time.time()
        self.transform = transform
        self.context_length = context_length
        self.forecast_horizon = forecast_horizon

        # --- 1. Gather files (sorted for deterministic order) ---------
        self.data_dir = data_dir
        self.mesh_dir = mesh_dir
        self.file_paths = sorted(
            os.path.join(data_dir, f)
            for f in os.listdir(data_dir)
            if f.endswith(".nc")
        )

        logging.info(f"Found {len(self.file_paths)} NetCDF files:")

        if not self.file_paths:
            raise FileNotFoundError(f"No *.nc files found in {data_dir!r}")

        logging.info("Loading datasets with xarray.open_dataset...")

        # --- 2. Store a list of datetimes from each file -> helps with retrieving 1 day's data later
        # This happens on the CPU
        all_times = []
        for path in self.file_paths:
            # The context manager closes the file as soon as the strings are read
            with xr.open_dataset(path) as ds:
                xtime_strs = ds["xtime_startDaily"].str.decode("utf-8").values

            # "0010-01-01_00:00:00" -> "0010-01-01 00:00:00" -> datetime.datetime
            all_times.extend(
                datetime.strptime(s.replace("_", " "), "%Y-%m-%d %H:%M:%S")
                for s in xtime_strs
            )

        # NOTE(review): the example timestamp above is in model year 0010; confirm
        # that cudf.to_datetime can represent such early years at its resolution.
        self.times = cudf.to_datetime(all_times)

        # Checking the dates
        logging.info(f"Parsed {len(self.times)} total dates")
        logging.info(f"First few: {str(self.times[:5])}")

        # Stats on how many dates there are
        logging.info(f"Total days collected: {len(self.times)}")
        logging.info(f"Unique days: {len(self.times.unique())}")
        logging.info(f"First 35 days: {self.times[:35]}")
        logging.info(f"First days 360 to 400 days: {self.times[360:401]}")

        # Load the mesh file and create a mask. Latitudes and Longitudes are in radians.
        latCell, lonCell = load_mesh_radians(self.mesh_dir)
        latCell = cp.degrees(cp.array(latCell))
        lonCell = cp.degrees(cp.array(lonCell))
        logging.info(f"latCell type: {type(latCell)}")

        self.cell_mask = latCell >= latitude_threshold
        logging.info(f"Mask size: {cp.count_nonzero(self.cell_mask)}")

        # Map full-mesh cell index -> position inside the masked arrays.
        # The indices are copied to the host in one transfer; iterating the
        # device array element by element would convert one scalar at a time.
        masked_cell_ids = cp.asnumpy(cp.where(self.cell_mask)[0]).tolist()
        self.full_to_masked = {
            full_idx: new_idx for new_idx, full_idx in enumerate(masked_cell_ids)
        }

        # --- 3. Extract raw data
        freeboard_chunks = []
        ice_area_chunks = []

        for path in self.file_paths:
            with xr.open_dataset(path) as ds:
                area = cp.array(ds["timeDaily_avg_iceAreaCell"].values)
                ice_volume = cp.array(ds["timeDaily_avg_iceVolumeCell"].values)
                snow_volume = cp.array(ds["timeDaily_avg_snowVolumeCell"].values)

            # --- 4. Apply a mask to the nCells
            area = area[:, self.cell_mask]
            ice_volume = ice_volume[:, self.cell_mask]
            snow_volume = snow_volume[:, self.cell_mask]

            # --- 5. Derive Freeboard from ice area, snow volume and ice volume
            freeboard_chunks.append(compute_freeboard(area, ice_volume, snow_volume))
            ice_area_chunks.append(area)

        # --- 6. Custom patchify function
        self.full_nCells_patch_ids, self.indices_per_patch_id = patchify_by_latlon_spillover(
            latCell, lonCell, k=256, max_patches=140, lat_threshold=latitude_threshold)

        # Convert full-domain patch indices to masked-domain indices
        self.indices_per_patch_id = [
            [self.full_to_masked[i] for i in patch if i in self.full_to_masked]
            for patch in self.indices_per_patch_id
        ]

        # --- 7. Concatenate the data across Time
        self.freeboard = cp.concatenate(freeboard_chunks, axis=0)  # (T, nCells)
        self.ice_area = cp.concatenate(ice_area_chunks, axis=0)    # (T, nCells)

        # The per-file pieces are no longer needed -- free the memory
        del freeboard_chunks, ice_area_chunks

        logging.info(f"Freeboard {self.freeboard.shape}")
        logging.info(f"Ice Area  {self.ice_area.shape}")

        if len(self.times) != self.freeboard.shape[0]:
            logging.warning(
                f"Timestamp count ({len(self.times)}) differs from the number of "
                f"data days ({self.freeboard.shape[0]})"
            )

        # --- 8. Normalize the data (Area is already between 0 and 1; Freeboard is not)
        # Bounds come from the whole record: with bounds from a single day, any
        # later value outside that day's range would be clipped to 0 or 1.
        self.freeboard_min = self.freeboard.min()
        self.freeboard_max = self.freeboard.max()

        logging.info(f"Freeboard min: {self.freeboard_min}" )
        logging.info(f"Freeboard max: {self.freeboard_max}")

        # Normalise the array that get_patch_tensor actually reads
        self.freeboard = normalize_freeboard(
            self.freeboard, min_val=self.freeboard_min, max_val=self.freeboard_max)

        logging.info("=== Normalized Freeboard ===")
        logging.info("End of __init__")

        end_time = time.time()
        logging.info(f"Elapsed time: {end_time - start_time} seconds")

    def __len__(self) -> int:
        """
        Returns the total number of possible starting indices (idx) for a valid sequence.
        A valid sequence needs `self.context_length` days for input and `self.forecast_horizon` days for target.

        ex) If the total number of days is 365, the context_length is 7 and the forecast_horizon is 3, then
        last valid starting index = total days - (context length + forecast horizon) + 1
        365 - (7 + 3) + 1 = 365 - 10 + 1 = 356 valid starting indices
        """
        required_length = self.context_length + self.forecast_horizon

        # Never negative: 0 means there is not enough data for even one sample
        return max(0, len(self.freeboard) - required_length + 1)

    def get_patch_tensor(self, day_idx: int) -> torch.Tensor:

        """
        Retrieves the feature data for a specific day, organized into patches.

        This method extracts 'freeboard' and 'ice_area' data for a given day
        and then gathers it according to the pre-defined patches. Each patch
        will contain its own set of feature values.

        Parameters
        ----------
        day_idx : int
            The integer index of the day to retrieve data for, relative to the
            concatenated dataset's time dimension.

        Returns
        -------
        torch.Tensor
            A float32 tensor containing the feature data organized by patches
            for the specified day.
            Shape: (num_patches, num_features, patch_size)
            Where:
            - num_patches: Total number of patches (ex., 140).
            - num_features: The number of features per cell (currently 2: freeboard, ice_area).
            - patch_size: The number of cells within each patch. torch.stack
              requires this to be the same for every patch.

        """

        freeboard_day = self.freeboard[day_idx]  # (nCells,)
        ice_area_day = self.ice_area[day_idx]    # (nCells,)
        features = cp.stack([freeboard_day, ice_area_day], axis=0)  # (2, nCells)

        patch_tensors = [
            torch.tensor(features[:, patch_indices], dtype=torch.float32)  # (2, patch_size)
            for patch_indices in self.indices_per_patch_id
        ]

        return torch.stack(patch_tensors)  # (num_patches, num_features, patch_size)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, int, int, int, int]:

        """
        Build the input window and the target window that start at day `idx`.

        Features are: [freeboard, ice_area] over masked cells; the target is
        ice_area only.

        Returns
        -------
        input_tensor : torch.Tensor
            (context_length, num_patches, num_features, patch_size), float32,
            after `self.transform` if one was given.
        target_tensor : torch.Tensor
            (forecast_horizon, num_patches, patch_size), float32.
        start_idx, end_idx : int
            Input days are [start_idx, end_idx) -- end_idx is exclusive.
        target_start, target_end : int
            Target days are [target_start, target_end) -- target_end is exclusive.

        Raises
        ------
        IndexError
            If the requested window does not fit inside the dataset.
        """

        start_time = time.time()
        logging.info("Calling __getitem__")

        # Support Python-style negative indices instead of silently wrapping days
        if idx < 0:
            idx += len(self)

        start_idx = idx
        end_idx = idx + self.context_length
        target_start = end_idx                      # targets begin right after the inputs
        target_end = target_start + self.forecast_horizon

        if idx < 0 or target_end > len(self.freeboard):
            raise IndexError("Requested time window exceeds dataset")

        # Build input tensor: (context_length, num_patches, num_features, patch_size)
        input_seq = [self.get_patch_tensor(i) for i in range(start_idx, end_idx)]
        input_tensor = torch.stack(input_seq)

        if self.transform is not None:
            input_tensor = self.transform(input_tensor)

        # Build target tensor: (forecast_horizon, num_patches, patch_size)
        # float32 keeps the targets in the same dtype as the inputs
        target_seq = self.ice_area[target_start:target_end]  # (forecast_horizon, nCells)
        target_patches = []
        for day in target_seq:
            patch_day = [
                torch.tensor(day[patch_indices], dtype=torch.float32)
                for patch_indices in self.indices_per_patch_id
            ]
            target_patches.append(torch.stack(patch_day))  # (num_patches, patch_size)

        target_tensor = torch.stack(target_patches)

        # The messages below index self.times, so skip building them entirely
        # when INFO logging is switched off
        if logging.getLogger().isEnabledFor(logging.INFO):
            logging.info(f"Input  tensor shape {input_tensor.shape}")
            logging.info(f"Target tensor shape {target_tensor.shape}")

            logging.info("input_tensor should be of shape (context_length, num_patches, num_features, patch_size)")
            logging.info("target_tensor should be of shape (forecast_horizon, num_patches, patch_size)")

            logging.info(f"Fetched start index {start_idx}: Time={self.times[start_idx]}")
            logging.info(f"Fetched end   index {end_idx}: Time={self.times[end_idx]}")

            # target_end is exclusive, so the last target day is target_end - 1
            logging.info(f"Fetched target start index {target_start}: Time={self.times[target_start]}")
            logging.info(f"Fetched target end   index {target_end - 1}: Time={self.times[target_end - 1]}")

            end_time = time.time()
            logging.info(f"Elapsed time: {end_time - start_time} seconds")

        return input_tensor, target_tensor, start_idx, end_idx, target_start, target_end

    def __repr__(self):
        """ Format the string representation of the data """
        return (
            f"<DailyNetCDFDataset: {len(self)} days, "
            f"{len(self.freeboard[0])} cells/day, "
            f"{len(self.file_paths)} files loaded>"
        )

    def time_to_dataframe(self) -> pd.DataFrame:
        """Return a DataFrame of time features you can merge with predictions."""
        # Move cuDF-backed times to the host explicitly before handing them to
        # pandas (cuDF objects generally do not convert implicitly); anything
        # without a to_pandas method is passed through unchanged.
        host_times = self.times.to_pandas() if hasattr(self.times, "to_pandas") else self.times
        t = pd.to_datetime(host_times)            # pandas Timestamp index
        return pd.DataFrame(
            {
                "time": t,
                "year": t.year,
                "month": t.month,
                "day": t.day,
                "doy": t.dayofyear,
            }
        )

# DataLoader

In [None]:
from torch.utils.data import DataLoader
from torch.utils.data import Subset

print("===== Making the Dataset Class ===== ")

# OPTION 1: LOADING FROM ONE BIG FOLDER:
dataset = DailyNetCDFDataset(data_dir)

# Chronological 70% / 15% / 15% split over the sample (window) indices.
# NOTE(review): windows next to a split boundary share days with the
# neighbouring split; leave a gap of context_length + forecast_horizon - 1
# samples between splits if that overlap matters for evaluation.
total_days = len(dataset)
train_end = int(total_days * 0.7)
val_end = int(total_days * 0.85)

train_set = Subset(dataset, range(0, train_end))
val_set   = Subset(dataset, range(train_end, val_end))
test_set  = Subset(dataset, range(val_end, total_days))

print("Training data length:   ", len(train_set))
print("Validation data length: ", len(val_set))
print("Testing data length:    ", len(test_set))

total_days = len(train_set) + len(val_set) + len(test_set)
print("Total days = ", total_days)

# OPTION 2: LOADING FROM SEPARATE FOLDERS:
# train_dataset = DailyNetCDFDataset(data_dir="/train", mesh_dir=mesh_dir)
# val_dataset   = DailyNetCDFDataset(data_dir="/valid", mesh_dir=mesh_dir)
# test_dataset  = DailyNetCDFDataset(data_dir="/test",  mesh_dir=mesh_dir)

print("===== Printing Dataset ===== ")
print(dataset)                 # calls __repr__ → see how many files & days loaded

# One sample: input tensor, target tensor and the day indices that bound them
input_tensor, target_tensor, start_idx, end_idx, target_start, target_end = dataset[0]

print(f"Fetched start index {start_idx}: Time={dataset.times[start_idx]}")
print(f"Fetched end   index {end_idx}: Time={dataset.times[end_idx]}")

# The target is sliced as ice_area[end_idx:target_end], so it covers days
# end_idx .. target_end - 1 (target_end itself is exclusive).
print(f"Fetched target start index {end_idx}: Time={dataset.times[end_idx]}")
print(f"Fetched target end   index {target_end - 1}: Time={dataset.times[target_end - 1]}")

print("===== Starting DataLoader ====")
# wrap in a DataLoader
# 1. Use pinned memory for faster async transfer to GPUs (CPU samples only)
# 2. Use a prefetch factor so that the GPU is fed w/o a ton of CPU memory use
# 3. Use shuffle=False to preserve time order (especially for forecasting)
#
# The dataset keeps its arrays in CuPy, so samples are expected to already be
# CUDA tensors. Only CPU tensors can be pinned, and worker processes should not
# touch a CUDA context created in the parent, so in that case everything is
# loaded in the main process. prefetch_factor is only valid with num_workers > 0.
if input_tensor.is_cuda:
    loader_options = dict(num_workers=0, pin_memory=False)
else:
    loader_options = dict(num_workers=NUM_WORKERS, pin_memory=True, prefetch_factor=2)


def make_loader(subset):
    """Build a DataLoader over `subset` with the shared batch/ordering settings."""
    return DataLoader(subset, batch_size=BATCH_SIZE, shuffle=False, **loader_options)


train_loader = make_loader(train_set)
val_loader   = make_loader(val_set)
test_loader  = make_loader(test_set)

print("input_tensor should be of shape (context_length, num_patches, num_features, patch_size)")
print("target_tensor should be of shape (forecast_horizon, num_patches, patch_size)")

# Model Hyperparameter Constants / Defaults

In [None]:
# These constants describe the tensors produced by DailyNetCDFDataset and must
# stay in sync with how that dataset is built:
#   - CONTEXT_LENGTH / FORECAST_HORIZON mirror the dataset's constructor
#     defaults (context_length=7, forecast_horizon=1).
#   - NUM_PATCHES / CELLS_PER_PATCH mirror the max_patches=140 and k=256
#     arguments the dataset passes to patchify_by_latlon_spillover
#     (presumably k is the number of cells per patch -- confirm in patchify_utils).
CONTEXT_LENGTH = 7         # T: Number of historical time steps used for input
FORECAST_HORIZON = 1       # Number of future time steps to predict (ex. 1 day for next time step)
NUM_PATCHES = 140          # P: Number of spatial patches
NUM_FEATURES = 2           # C: Number of features per cell (ex., Freeboard, Ice Area)
CELLS_PER_PATCH = 256      # L: Number of cells within each patch
D_MODEL = 128              # d_model: Dimension of the transformer's internal representations (embedding dimension)
N_HEAD = 8                 # nhead: Number of attention heads
NUM_TRANSFORMER_LAYERS = 4 # num_layers: Number of TransformerEncoderLayers

# The input dimension for the patch embedding linear layer.
# Each patch at a given time step has NUM_FEATURES * CELLS_PER_PATCH features.
# This is the 'D' dimension used in the Transformer's input tensor (B, T, P, D).
PATCH_EMBEDDING_INPUT_DIM = NUM_FEATURES * CELLS_PER_PATCH # 2 * 256 = 512

# Transformer Class
<!-- outputs = model(features)
model.train()
model.eval() -->

In [None]:
import torch
import torch.nn as nn
import time

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class IceForecastTransformer(nn.Module):

    """
    A Transformer-based model for forecasting ice conditions based on sequences of
    historical patch data.

    Every (time step, patch) pair is embedded as one token, all tokens of a
    sample attend to each other in a Transformer encoder, the encoded tokens are
    averaged over time per patch, and an MLP head predicts one value per cell of
    the patch for every forecast day.

    Parameters
    ----------
    input_patch_features_dim : int
        The length of the flattened feature vector of a single patch at one
        time step, i.e. num_features * cells_per_patch (ex. 2 * 256 = 512).
        This is the input dimension for the patch embedding layer.
    num_patches : int
        The total number of geographical patches that the `nCells` data was divided into
        (ex., 140 patches).
    context_length : int
        The number of historical days (time steps) to use as input for the transformer.
    forecast_horizon : int
        The number of future days to predict for each patch.
    d_model : int, optional
        The dimension of the model's hidden states (embedding dimension).
        This is the size of the vectors that flow through the Transformer encoder.
        Defaults to D_MODEL.
    nhead : int, optional
        The number of attention heads in the multi-head attention mechanism within
        each Transformer encoder layer. Defaults to N_HEAD.
    num_layers : int, optional
        The number of Transformer encoder layers in the model.
        Defaults to NUM_TRANSFORMER_LAYERS.
    cells_per_patch : int, optional
        The number of cells predicted per patch. Defaults to CELLS_PER_PATCH.

    Attributes
    ----------
    patch_embed : nn.Linear
        Linear layer to project input patch features into the `d_model` hidden space.
    encoder : nn.TransformerEncoder
        The Transformer encoder module composed of `num_layers` encoder layers.
    mlp_head : nn.Sequential
        LayerNorm followed by a linear layer producing
        forecast_horizon * cells_per_patch outputs per patch. With
        forecast_horizon == 1 its parameter shapes equal those of a head with
        cells_per_patch outputs, so checkpoints saved that way still load.
    """

    def __init__(self,
                 input_patch_features_dim: int, # D: The flat feature dimension of a single patch (ex., 512)
                 num_patches: int,              # P: Number of spatial patches
                 context_length: int,           # T: Number of historical time steps
                 forecast_horizon: int,         # Number of future time steps to predict (usually 1)
                 d_model: int = D_MODEL,        # d_model: Transformer's embedding dimension
                 nhead: int = N_HEAD,           # nhead: Number of attention heads
                 num_layers: int = NUM_TRANSFORMER_LAYERS, # num_layers: Number of TransformerEncoderLayers
                 cells_per_patch: int = CELLS_PER_PATCH    # L: Number of cells predicted per patch
                ):

        super().__init__()

        # The transformer should
        # 1. Accept a sequence of days (ex. 7 days of patches).
        #    The context_length parameter says how many days to use for input.
        # 2. Encode each patch with the transformer
        # 3. Output the patches for regression (ex. predict the 8th day)
        #    The forecast_horizon parameter says how many days to use for the output prediction

        self.context_length = context_length
        self.forecast_horizon = forecast_horizon
        self.num_patches = num_patches
        self.d_model = d_model
        self.input_patch_features_dim = input_patch_features_dim
        self.cells_per_patch = cells_per_patch

        print("Calling IceForecastTransformer __init__")
        start_time = time.time()

        # Patch embedding layer: projects the raw patch features (512)
        # into d_model (128) hidden space dimension
        self.patch_embed = nn.Linear(input_patch_features_dim, d_model)

        # Transformer Encoder
        # batch_first=True means input/output tensors are (batch, sequence, features)
        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, batch_first=True)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # Output MLP head:
        # one prediction for every cell of the patch, for every forecast day
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(d_model),
            nn.Linear(d_model, forecast_horizon * cells_per_patch)
        )

        # TODO: IMPLEMENT a simple positional embedding or a standard sine/cosine one.
        # Without it the encoder cannot tell time steps (or patches) apart by position.
        # self.pos_encoder = PositionalEncoding(d_model)

        end_time = time.time()
        print(f"Elapsed time: {end_time - start_time:.2f} seconds")
        print("End of __init__")

    def forward(self, x):
        """
        Predict per-cell values for every patch and forecast day.

        B = Batch size
        T = Time (context_length)
        P = Patch count
        D = Patch Dimension (cells per patch * feature count)
        H = forecast_horizon
        L = cells_per_patch

        x: Tensor of shape (B, T, P, D), ex. (16, 7, 140, 512)
        Output: Tensor of shape (B, H, P, L), ex. (16, 1, 140, 256)
        """

        logging.info("Calling forward")

        logging.info("Expected x shape: (B, T, P, D) ex., (16, 7, 140, 512)")
        logging.info("Actual   x shape: %s", x.shape)

        B, T, P, D = x.shape

        # Flatten time and patches for the Transformer Encoder: (B, T * P, D)
        # Each (Time, Patch) combination becomes a single token in the sequence.
        # reshape (unlike view) also accepts non-contiguous input.
        x = x.reshape(B, T * P, D)

        # Project patch features to the transformer's d_model dimension
        x = self.patch_embed(x)  # Output: (B, T * P, d_model) ex., (16, 980, 128)
        logging.info("Expected patch embedding dimensions: (B, T * P, d_model) ex., (16, 980, 128)")
        logging.info("Actual   patch embedding dimensions: %s", x.shape)

        # TODO: Add positional encoding HERE
        # x = self.pos_encoder(x)

        # Apply transformer encoder layers
        x = self.encoder(x)      # Output: (B, T * P, d_model) ex., (16, 980, 128)

        # Separate time and patches again: (B, T, P, d_model) ex., (16, 7, 140, 128)
        x = x.reshape(B, T, P, self.d_model)

        # Mean pooling over the time (context_length) dimension for each patch.
        # This aggregates information from all historical time steps for each patch's final prediction.
        x = x.mean(dim=1)  # Output: (B, P, d_model) ex., (16, 140, 128)

        # TODO: SOMEHOW SAVE ATTENTION TO MAP LATER

        # The MLP head outputs H * L values for each of the P patches
        x = self.mlp_head(x)  # Output: (B, P, H * L) ex., (16, 140, 256)

        # Split the forecast days out and move them in front of the patches so the
        # output lines up with the target 'y': (B, H, P, L).
        # For H == 1 this is the same tensor as inserting a singleton dim at position 1.
        x = x.reshape(B, P, self.forecast_horizon, self.cells_per_patch)
        x = x.permute(0, 2, 1, 3).contiguous()

        logging.info("Expected output dimensions: (B, H, P, CELLS_PER_PATCH) ex., (16, 1, 140, 256)")
        logging.info("Actual   output dimensions: %s", x.shape)

        return x



# Training Loop

In [None]:
# import torch
# import torch.nn as nn
# import torch.optim as optim
# from torch.utils.data import DataLoader
# from torch import Tensor
# import torch.nn.functional as F

# # Set level to logging.INFO to see the statements
# logging.basicConfig(filename='IceForecastTransformer.log', filemode='w', level=logging.INFO)

# model = IceForecastTransformer(
#     input_patch_features_dim=PATCH_EMBEDDING_INPUT_DIM,
#     num_patches=NUM_PATCHES,
#     context_length=CONTEXT_LENGTH,
#     forecast_horizon=FORECAST_HORIZON
# ).to(device)

# print("\n--- Model Architecture ---")
# print(model)
# print("--------------------------\n")

# optimizer = optim.Adam(model.parameters(), lr=1e-4)
# criterion = nn.MSELoss()
# num_epochs = 100

# start_time = time.time()
# logging.info(" ===============================")
# logging.info(" =      STARTING EPOCHS        =")
# logging.info(" ===============================")

# for epoch in range(num_epochs):
#     model.train()
#     total_loss = 0

#     for batch_idx, (x, y) in enumerate(train_loader):  
        
#         # x: (B, context_length, num_patches, input_patch_features_dim), y: (B, forecast_horizon, num_patches)
#         x = x.to(device) # Move to GPU if available
#         y = y.to(device) # y is (B, forecast_horizon, num_patches) ex., (16, 1, 140)

#         logging.info("Expected x shape is (B, T, P, C, L) ex., (16, 7, 140, 2, 256)")
#         logging.info("Actual   x shape is ", x.shape)

#         logging.info("Expected y shape is (B, forecast_horizon, P, L) ex., (16, 1, 140, 256)")
#         logging.info("Actual   y shape is ", y.shape)

#         # Reshape x for transformer input
#         B, T, P, C, L = x.shape
#         x_reshaped_for_transformer_D = x.view(B, T, P, C * L)

#         logging.info("Expected reshaped x is (B, T, P, D_input)")
#         logging.info("Actual   reshaped x is ", x_reshaped_for_transformer_D.shape)  # should now be (B, T, P, 512)
    
#         # Run through transformer
#         y_pred = model(x_reshaped_for_transformer_D) # y_pred is (B, forecast_horizon, num_patches) ex., (16, 1, 140)

#         logging.info("Expected y_pred shape is (B, forecast_horizon , P, L)")
#         logging.info("Actual   y_pred shape is ", y_pred.shape)
        
#         # Compute loss
#         loss = criterion(y_pred, y) # DIRECTLY compare y_pred and y
    
#         optimizer.zero_grad()
#         loss.backward()
#         optimizer.step()

#         total_loss += loss.item()

#     print(f"Epoch {epoch+1}/{num_epochs} - Train Loss: {total_loss:.4f}")

#     # --- Validation loop ---
#     model.eval()
#     val_loss = 0
#     with torch.no_grad():
#         for x_val, y_val in val_loader:
#             x_val = x_val.to(device)
#             y_val = y_val.to(device)

#             # Extract dimensions from x_val for reshaping
#             # x_val before reshaping: (B_val, T_val, P_val, C_val, L_val)
#             B_val, T_val, P_val, C_val, L_val = x_val.shape
            
#             # Reshape x_val for transformer input
#             x_val_reshaped_for_transformer_input = x_val.view(B_val, T_val, P_val, C_val * L_val)

#             # Model output is (B, forecast_horizon, P, L)
#             y_val_pred = model(x_val_reshaped_for_transformer_input) 

#             # Compute validation loss (y_val_pred and y_val should have identical shapes)
#             val_loss += criterion(y_val_pred, y_val).item() # y_val is (B, forecast_horizon, P, L)
    
#     print(f"Validation Loss: {val_loss:.4f}")

# end_time = time.time()
# print("===============================================")
# print(f"Elapsed time for TRAINING: {end_time - start_time:.2f} seconds")
# print("===============================================")

TODO: Add Positional Encoding to represent time steps.

TODO OPTION: Try temporal attention only (ex., Informer, Time Series Transformer).

# Save the Model

In [None]:
# # Define the path where to save the model
# PATH = "sea_ice_concentration_model_2.pth"

# # Save the model's state_dict
# torch.save(model.state_dict(), PATH)

# print("Saved model")

# Re-Load the Model

In [None]:
import torch
import torch.nn as nn

# Path of the checkpoint (a state_dict file) to restore.
# NOTE(review): the commented-out save cell writes "sea_ice_concentration_model_2.pth";
# confirm that the un-suffixed file below is the checkpoint you intend to evaluate.
PATH = "sea_ice_concentration_model.pth"

# Choose the target device *before* loading so the checkpoint tensors can be
# remapped onto it. Without map_location, a checkpoint saved from CUDA tensors
# cannot be deserialized on a CPU-only node.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Instantiate the model (must have the same architecture as when it was saved).
# The constructor arguments must match the ones used for the saved model, so the
# global constants (D_MODEL, N_HEAD, etc.) have to be consistent with training.
loaded_model = IceForecastTransformer(
    input_patch_features_dim=PATCH_EMBEDDING_INPUT_DIM,
    num_patches=NUM_PATCHES,
    context_length=CONTEXT_LENGTH,
    forecast_horizon=FORECAST_HORIZON,
    d_model=D_MODEL,
    nhead=N_HEAD,
    num_layers=NUM_TRANSFORMER_LAYERS
)

# Load the saved state_dict.
# weights_only=True restricts unpickling to tensors/primitive containers (safer);
# map_location=device places the loaded tensors on the device selected above.
state_dict = torch.load(PATH, map_location=device, weights_only=True)
loaded_model.load_state_dict(state_dict)

# Move the model to the selected device (CPU or GPU)
loaded_model.to(device)

# Set the model to evaluation mode
loaded_model.eval()

print("Model loaded successfully!")

# Make a Single Prediction

In [None]:
# Turn off the logging for this part (suppresses records at INFO level and below)
# https://docs.python.org/3/library/logging.html#logrecord-attributes
logging.disable(level=logging.INFO)

# Load one batch for demonstration.
# The test loader yields the input/target tensors plus the time indices they cover.
data_iter = iter(test_loader)
sample_x, sample_y, start_idx, end_idx, target_start, target_end = next(data_iter)

print(f"Fetched sample_x start index {start_idx}: Time={dataset.times[start_idx]}")
print(f"Fetched sample_x end   index {end_idx}:   Time={dataset.times[end_idx]}")

# The "start" line must report target_start (it previously printed target_end twice).
print(f"Fetched sample_y (target) start index {target_start}: Time={dataset.times[target_start]}")
print(f"Fetched sample_y (target) end   index {target_end}: Time={dataset.times[target_end]}")

# Move to device and apply initial reshape as done in training
sample_x = sample_x.to(device)
sample_y = sample_y.to(device) # Keep sample_y for actual comparison

# Initial reshape of x for the Transformer model:
# (B, T, P, C, L) -> (B, T, P, C * L), i.e. the last two axes are flattened together.
B_sample, T_sample, P_sample, C_sample, L_sample = sample_x.shape
sample_x_reshaped = sample_x.view(B_sample, T_sample, P_sample, C_sample * L_sample)

print(f"Sample x for inference shape (reshaped): {sample_x_reshaped.shape}")

# Perform inference
with torch.no_grad(): # Essential for inference to disable gradient calculations
    predicted_y_patches = loaded_model(sample_x_reshaped)

print(f"Predicted y patches shape: {predicted_y_patches.shape}")
print("Expected shape: (B, 1, NUM_PATCHES, CELLS_PER_PATCH) e.g., (16, 1, 140, 256)")

# Squeeze the forecast_horizon dimension (since it's 1) and bring the results to host memory
predicted_ice_area_patches = predicted_y_patches.squeeze(1).cpu() # Shape: (B, NUM_PATCHES, CELLS_PER_PATCH)
actual_y_ice_area_patches = sample_y.squeeze(1).cpu() # Shape: (B, NUM_PATCHES, CELLS_PER_PATCH).

# Persist the patches as .npy files so the visualization cell can read them back.
# The tensors are already on the CPU, so NumPy (imported as np at the top of the
# notebook) writes them directly; this avoids depending on CuPy, which is not
# importable in every kernel.
np.save('ice_area_patches_predicted.npy', predicted_ice_area_patches.numpy())
np.save('ice_area_patches_actual.npy', actual_y_ice_area_patches.numpy())


# Recover nCells from Patches for Visualization

In [None]:
########################################
# SWAP KERNELS IN THE JUPYTER NOTEBOOK #
########################################
# After a kernel swap nothing from earlier cells is in scope, so this cell only
# relies on what it imports itself and on the .npy files written by the
# "Make a Single Prediction" cell. All array work is done with NumPy.

# NOTE(review): the star imports presumably provide load_mesh, perlmutterpathMesh,
# patchify_by_latlon_spillover, generate_axes_north_pole and generate_map_north_pole;
# replace them with explicit imports once the owning module of each name is confirmed.
from MAP_ANIMATION_GENERATION.map_gen_utility_functions import *
from NC_FILE_PROCESSING.nc_utility_functions import *
from NC_FILE_PROCESSING.patchify_utils import *

import numpy as np

LAT_THRESHOLD = 40

# Patches saved by the inference cell; each array is (B, NUM_PATCHES, CELLS_PER_PATCH)
# because the forecast-horizon axis was squeezed out before saving.
predicted_ice_area_patches = np.load("ice_area_patches_predicted.npy")
actual_y_ice_area_patches = np.load("ice_area_patches_actual.npy")

NUM_PATCHES = len(predicted_ice_area_patches[0])
print("NUM_PATCHES is", NUM_PATCHES)

latCell, lonCell = load_mesh(perlmutterpathMesh)
TOTAL_GRID_CELLS = len(lonCell) 

# Load the original patch-to-cell mapping
# indices_per_patch_id = [
#     [idx_cell_0_0, ..., idx_cell_0_255],
#     [idx_cell_1_0, ..., idx_cell_1_255],
#     ...
# ]
# NOTE(review): k and max_patches must match the patch layout used to build the
# dataset the model was trained on -- confirm against the dataset definition.
full_nCells_patch_ids, indices_per_patch_id = patchify_by_latlon_spillover(
            latCell, lonCell, k=256, max_patches=140, lat_threshold=LAT_THRESHOLD)

# Select one sample from the batch for visualization (e.g., the first one)
# Output is (NUM_PATCHES, CELLS_PER_PATCH) for this single sample
sample_predicted_cells_per_patch = predicted_ice_area_patches[0] # First item in batch
# Use the *actual* patches here; this previously re-used the predicted array,
# which made the "actual" map an exact copy of the predicted one.
sample_actual_cells_per_patch = actual_y_ice_area_patches[0] # First item in batch

# Initialize the full-grid (nCells) arrays with -1 as the "no patch covers this cell"
# sentinel. They take the dtype of the patch data: an integer grid would truncate
# the floating-point patch values on assignment.
recovered_predicted_grid = np.full(TOTAL_GRID_CELLS, -1, dtype=sample_predicted_cells_per_patch.dtype)
recovered_actual_grid = np.full(TOTAL_GRID_CELLS, -1, dtype=sample_actual_cells_per_patch.dtype)

# Populate the full grid using the patch data and mapping
for patch_idx in range(NUM_PATCHES):
    cell_indices_in_patch = indices_per_patch_id[patch_idx]
    
    # For predicted values
    recovered_predicted_grid[cell_indices_in_patch] = sample_predicted_cells_per_patch[patch_idx]

    # For actual values
    recovered_actual_grid[cell_indices_in_patch] = sample_actual_cells_per_patch[patch_idx]

print(f"Recovered predicted grid shape: {recovered_predicted_grid.shape}")
print(f"Recovered actual grid shape: {recovered_actual_grid.shape}")

fig, northMap = generate_axes_north_pole()
generate_map_north_pole(fig, northMap, latCell, lonCell, recovered_predicted_grid, "ice area recovered")

fig, northMap = generate_axes_north_pole()
generate_map_north_pole(fig, northMap, latCell, lonCell, recovered_actual_grid, "ice area actual")