# Imports

In [2]:
import sys
print('System Version:', sys.version)

System Version: 3.10.18 | packaged by conda-forge | (main, Jun  4 2025, 14:45:41) [GCC 13.3.0]


In [3]:
print(sys.executable)

/global/homes/b/brelypo/.conda/envs/sic_sie_env/bin/python


In [4]:
import numpy as np
print('Numpy version', np.__version__)

Numpy version 2.2.6


In [5]:
import pandas as pd
print('Pandas version', pd.__version__)

Pandas version 2.3.0


In [6]:
import xarray as xr
print('Xarray version', xr.__version__)

Xarray version 2025.6.0


In [7]:
import matplotlib
import matplotlib.pyplot as plt
print('Matplotlib version', matplotlib.__version__)

Matplotlib version 3.10.3


In [8]:
import torch
from torch.utils.data import Dataset, DataLoader

print('PyTorch version', torch.__version__)

PyTorch version 2.5.1


# Example of one netCDF file with xarray

In [9]:
ds = xr.open_dataset("train/v3.LR.DTESTM.pm-cpu-10yr.mpassi.hist.am.timeSeriesStatsDaily.0010-01-01.nc")

In [10]:
ds.data_vars

Data variables:
    timeDaily_counter             (Time) int32 124B ...
    xtime_startDaily              (Time) |S64 2kB ...
    xtime_endDaily                (Time) |S64 2kB ...
    timeDaily_avg_iceAreaCell     (Time, nCells) float32 58MB ...
    timeDaily_avg_iceVolumeCell   (Time, nCells) float32 58MB ...
    timeDaily_avg_snowVolumeCell  (Time, nCells) float32 58MB ...
    timeDaily_avg_uVelocityGeo    (Time, nVertices) float32 117MB ...
    timeDaily_avg_vVelocityGeo    (Time, nVertices) float32 117MB ...

In [11]:
day_counter = ds["timeDaily_counter"]
day_counter.shape

(31,)

In [12]:
print(ds["xtime_startDaily"])

<xarray.DataArray 'xtime_startDaily' (Time: 31)> Size: 2kB
[31 values with dtype=|S64]
Dimensions without coordinates: Time


In [13]:
print(ds["xtime_startDaily"].values)

[b'0010-01-01_00:00:00' b'0010-01-02_00:00:00' b'0010-01-03_00:00:00'
 b'0010-01-04_00:00:00' b'0010-01-05_00:00:00' b'0010-01-06_00:00:00'
 b'0010-01-07_00:00:00' b'0010-01-08_00:00:00' b'0010-01-09_00:00:00'
 b'0010-01-10_00:00:00' b'0010-01-11_00:00:00' b'0010-01-12_00:00:00'
 b'0010-01-13_00:00:00' b'0010-01-14_00:00:00' b'0010-01-15_00:00:00'
 b'0010-01-16_00:00:00' b'0010-01-17_00:00:00' b'0010-01-18_00:00:00'
 b'0010-01-19_00:00:00' b'0010-01-20_00:00:00' b'0010-01-21_00:00:00'
 b'0010-01-22_00:00:00' b'0010-01-23_00:00:00' b'0010-01-24_00:00:00'
 b'0010-01-25_00:00:00' b'0010-01-26_00:00:00' b'0010-01-27_00:00:00'
 b'0010-01-28_00:00:00' b'0010-01-29_00:00:00' b'0010-01-30_00:00:00'
 b'0010-01-31_00:00:00']


In [14]:
ice_area = ds["timeDaily_avg_iceAreaCell"]
ice_area.shape

(31, 465044)

In [15]:
ice_area.values

array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]], shape=(31, 465044), dtype=float32)

In [16]:
print(ds.coords)
print(ds.dims)

Coordinates:
    *empty*


In [17]:
print(ds)

<xarray.Dataset> Size: 407MB
Dimensions:                       (Time: 31, nCells: 465044, nVertices: 942873)
Dimensions without coordinates: Time, nCells, nVertices
Data variables:
    timeDaily_counter             (Time) int32 124B ...
    xtime_startDaily              (Time) |S64 2kB b'0010-01-01_00:00:00' ... ...
    xtime_endDaily                (Time) |S64 2kB ...
    timeDaily_avg_iceAreaCell     (Time, nCells) float32 58MB 0.0 0.0 ... 0.0
    timeDaily_avg_iceVolumeCell   (Time, nCells) float32 58MB ...
    timeDaily_avg_snowVolumeCell  (Time, nCells) float32 58MB ...
    timeDaily_avg_uVelocityGeo    (Time, nVertices) float32 117MB ...
    timeDaily_avg_vVelocityGeo    (Time, nVertices) float32 117MB ...
Attributes: (12/490)
    case:                                                         v3.LR.DTEST...
    source_id:                                                    9741e0bba2
    realm:                                                        seaIce
    product:              

# Freeboard calculation functions

In [18]:
import numpy as np

# Constants (adjust if you use different units)
D_WATER = 1023  # Density of seawater (kg/m^3)
D_ICE = 917     # Density of sea ice (kg/m^3)
D_SNOW = 330    # Density of snow (kg/m^3)

MIN_AREA = 1e-6

def compute_freeboard(area: np.ndarray, 
                      ice_volume: np.ndarray, 
                      snow_volume: np.ndarray) -> np.ndarray:
    """
    Compute sea ice freeboard from ice and snow volume and area.
    
    Parameters
    ----------
    area : np.ndarray
        Sea ice concentration / area (same shape as ice_volume and snow_volume).
    ice_volume : np.ndarray
        Sea ice volume per grid cell.
    snow_volume : np.ndarray
        Snow volume per grid cell.
    
    Returns
    -------
    freeboard : np.ndarray
        Freeboard height for each cell, same shape as inputs.
    """
    # Initialize arrays
    height_ice = np.zeros_like(ice_volume)
    height_snow = np.zeros_like(snow_volume)

    # Valid mask: avoid dividing by very small or zero area
    valid = area > MIN_AREA

    # Safely compute heights where valid
    height_ice[valid] = ice_volume[valid] / area[valid]
    height_snow[valid] = snow_volume[valid] / area[valid]

    # Compute freeboard using the physical formula
    freeboard = (
        height_ice * (D_WATER - D_ICE) / D_WATER +
        height_snow * (D_WATER - D_SNOW) / D_WATER
    )

    return freeboard


# Custom Pytorch Dataset

Example from NERSC of using ERA5 Dataset:
https://github.com/NERSC/dl-at-scale-training/blob/main/utils/data_loader.py

In [19]:
import os
import time

from torch.utils.data import Dataset
from datetime import timedelta
from typing import List, Union, Callable, Tuple
from NC_FILE_PROCESSING.nc_utility_functions import *
from perlmutterpath import *

# __ init __ - masks and loads the data into tensors

In [20]:
class DailyNetCDFDataset(Dataset):
    """
    PyTorch Dataset that concatenates a directory of month-wise NetCDF files
    along their 'Time' dimension and yields daily data *plus* its timestamp.

    Parameters
    ----------
    data_dir : str
        Directory containing NetCDF files
    transform : Callable | None
        Optional transform applied to the data tensor *only*.
    decode_time : bool
        Let xarray convert CF-style time coordinates to np.datetime64.
    drop_missing : bool
        If True, drops any days where one of the requested variables is missing.
    """
    def __init__(
        self,
        data_dir: str,
        transform: Callable = None,
        decode_time: bool = True,
        drop_missing: bool = True,
        cell_mask=None
    ):

        """ __init__ needs to 
        1) Gather the sorted daily data from each netCDF file (1 file = 1 month of daily data)
            The netCDF files contain nCells worth of data per day for each feature
            nCells = 465044 with the IcoswISC30E3r5 mesh
        2) Store the datetime information from each nCells array from the daily data
        3) Apply a mask to nCells to look just at regions above 40 degrees north (TODO: IMPLEMENT THE MASK)
        4) Patchify and store patch_ids so the data loader can use them (TODO: IMPLEMENT THIS)
        5) Perform pre-processing (calculate Freeboard from ice area, ice volume, and snow volume
        6) Normalize the data (TODO: IMPLEMENT THIS)
        7) Concatenate the data across Time """

        start_time = time.time()
        
        self.transform = transform

        # --- 1. Gather files (sorted for deterministic order) ---------
        self.data_dir = data_dir
        self.file_paths = sorted(
            [
                os.path.join(data_dir, f)
                for f in os.listdir(data_dir)
                if f.endswith(".nc")
            ]
        )
        print(f"Found {len(self.file_paths)} NetCDF files:")
        # for f in self.file_paths:
        #     print(f"  - {f}")     # Print all the file names in the folder

        if not self.file_paths:
            raise FileNotFoundError(f"No *.nc files found in {data_dir!r}")

        # Open all the netCDF files and concatenate them along Time dimension
        print("Loading datasets with xarray.open_mfdataset...")
        
        ds = xr.open_mfdataset(
            self.file_paths,
            combine="nested",
            concat_dim="Time", # Use the NetCDF's Time dimension for concatenation
            decode_times=False,
            parallel=False,
        )

        print("Finished loading full dataset into a local variable.")

        print(f"Dataset dimensions: {ds.dims}")
        print(f"Dataset variables: {list(ds.data_vars)}")
        
        # --- 2. Store a list of datetimes from each file -> helps with retrieving 1 day's data later
        all_times = []
        for path in self.file_paths:
            ds = xr.open_dataset(path)
        
            # Decode byte strings and fix the format
            xtime_strs = ds["xtime_startDaily"].str.decode("utf-8").values
            xtime_strs = [s.replace("_", " ") for s in xtime_strs]  # "0010-01-01_00:00:00" → "0010-01-01 00:00:00"
        
            # Convert to datetime.datetime objects
            times = [datetime.strptime(s, "%Y-%m-%d %H:%M:%S") for s in xtime_strs]
            all_times.extend(times)
        
        # Store in self.times
        self.times = all_times
        self.times = np.array(self.times, dtype='datetime64[s]')

        # Checking the dates
        print(f"Parsed {len(self.times)} total dates")
        print("First few:", self.times[:5])

        print(f"Total days collected: {len(self.times)}")
        print(f"Unique days: {len(np.unique(self.times))}")
        print(f"First 35 days: {self.times[:35]}")
        print(f"First days 360 to 400 days: {self.times[360:401]}")

        # --- 3. Apply a mask to the nCells
        # TODO: MASK DATA

        # --- 4. Get patch IDs
        # TODO: implement this
        # self.patch_ids = ???

        # --- 5. Derive Freeboard from ice area, snow volume and ice volume
        self.freeboard_all = []
        self.ice_area_all = []

        for path in self.file_paths:
            ds = xr.open_dataset(path)

            # Extract raw data
            area = ds["timeDaily_avg_iceAreaCell"].values
            ice_volume = ds["timeDaily_avg_iceVolumeCell"].values
            snow_volume = ds["timeDaily_avg_snowVolumeCell"].values

            # Optional mask
            if cell_mask is not None:
                area = area[:, cell_mask]
                ice_volume = ice_volume[:, cell_mask]
                snow_volume = snow_volume[:, cell_mask]
            
            freeboard = compute_freeboard(area, ice_volume, snow_volume)

            # These will be deleted later to save space
            self.freeboard_all.append(freeboard) 
            self.ice_area_all.append(area)

        
        # --- 6. Normalize the data (TODO: IMPLEMENT THIS)

        
        # --- 7. Concatenate the data across Time

        # Concatenate across time
        self.freeboard = np.concatenate(self.freeboard_all, axis=0)  # (T, nCells)
        self.ice_area = np.concatenate(self.ice_area_all, axis=0)    # (T, nCells)

        # Discard the lists that are not needed anymore -- save space
        del self.freeboard_all, self.ice_area_all

        print("Freeboard", self.freeboard.shape)
        print("Ice Area", self.ice_area.shape)

        print("End of __init__")

        end_time = time.time()
        elapsed_time = end_time - start_time
        print(f"Elapsed time: {elapsed_time} seconds")

    def __len__(self) -> int:
        """ Returns how many time steps? (Days for Daily data) """
        
        print("Calling __len__")
        return len(self.times)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, np.datetime64]:
        """__ getitem __ needs to 
        
        1. Select one time step (ex. 1 day). 
        It currently returns (features, timestamp) for a single day.
        2. TODO: Return a set of patches for one time step
        Features are: [freeboard, ice_area] over masked cells. """

        start_time = time.time()
    
        # 1. Select timestep (day)
        print("Calling __getitem__")
    
        freeboard_day = self.freeboard[idx]  # shape: (nCells,)
        ice_area_day = self.ice_area[idx]    # shape: (nCells,)
        print("Freeboard shape", freeboard_day.shape)
        print("Ice Area shape", ice_area_day.shape)
        
        features = np.stack([freeboard_day, ice_area_day], axis=0)  # shape: (2, nCells)
        data_tensor = torch.as_tensor(features, dtype=torch.float32)
    
        if self.transform:
            data_tensor = self.transform(data_tensor)
            
        print(f"Fetched index {idx}: Time={self.times[idx]}, shape={data_tensor.shape}")

        end_time = time.time()
        elapsed_time = end_time - start_time
        print(f"Elapsed time: {elapsed_time} seconds")
        
        return data_tensor, self.times[idx] # TODO: RETURN PATCHES INSTEAD OF ALL DATA PER DAY

    def __repr__(self):
        """ Format the string representation of the data """
        return (
            f"<DailyNetCDFDataset: {len(self)} days, "
            f"{len(self.freeboard[0])} cells/day, "
            f"{len(self.file_paths)} files loaded>"
        )

    def time_to_dataframe(self) -> pd.DataFrame:
            """Return a DataFrame of time features you can merge with predictions."""
            t = pd.to_datetime(self.times)            # pandas Timestamp index
            return pd.DataFrame(
                {
                    "time": t,
                    "year": t.year,
                    "month": t.month,
                    "day": t.day,
                    "doy": t.dayofyear,
                }
            )

# DataLoader

In [21]:
from torch.utils.data import DataLoader

print("===== Making the Dataset Class ===== ")
dataset = DailyNetCDFDataset(data_dir)

print("===== Printing Dataset ===== ")
print(dataset)                 # calls __repr__ → see how many files & days loaded
sample, ts = dataset[0]        # sample is tensor, ts is np.datetime64

# wrap in a DataLoader
loader = DataLoader(dataset, batch_size=8, shuffle=False)

# quickly get engineered time-features
# df_time = dataset.time_to_dataframe()


===== Making the Dataset Class ===== 
Found 12 NetCDF files:
Loading datasets with xarray.open_mfdataset...
Finished loading full dataset into a local variable.
Dataset variables: ['timeDaily_counter', 'xtime_startDaily', 'xtime_endDaily', 'timeDaily_avg_iceAreaCell', 'timeDaily_avg_iceVolumeCell', 'timeDaily_avg_snowVolumeCell', 'timeDaily_avg_uVelocityGeo', 'timeDaily_avg_vVelocityGeo']
Parsed 365 total dates
First few: ['0010-01-01T00:00:00' '0010-01-02T00:00:00' '0010-01-03T00:00:00'
 '0010-01-04T00:00:00' '0010-01-05T00:00:00']
Total days collected: 365
Unique days: 365
First 35 days: ['0010-01-01T00:00:00' '0010-01-02T00:00:00' '0010-01-03T00:00:00'
 '0010-01-04T00:00:00' '0010-01-05T00:00:00' '0010-01-06T00:00:00'
 '0010-01-07T00:00:00' '0010-01-08T00:00:00' '0010-01-09T00:00:00'
 '0010-01-10T00:00:00' '0010-01-11T00:00:00' '0010-01-12T00:00:00'
 '0010-01-13T00:00:00' '0010-01-14T00:00:00' '0010-01-15T00:00:00'
 '0010-01-16T00:00:00' '0010-01-17T00:00:00' '0010-01-18T00:00:00'
 

# Transformer

In [22]:
# import torch
# import torch.nn as nn

# # THIS ONE IS AN OPTION FOR IF I IMPLEMENT PATCHES

# class PatchTransformer(nn.Module):
#     def __init__(self, patch_dim, num_patches, d_model=128, nhead=8, num_layers=4):
#         super().__init__()

#         print("Calling __init__")
#         start_time = time.time()
        
#         self.patch_embed = nn.Linear(patch_dim, d_model)  # input projection

#         self.pos_embed = nn.Parameter(torch.randn(1, num_patches, d_model))  # learnable positional encoding

#         encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, batch_first=True)
#         self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

#         self.mlp_head = nn.Sequential(
#             nn.LayerNorm(d_model),
#             nn.Linear(d_model, 1)  # regression or classification
#         )

#         end_time = time.time()
#         elapsed_time = end_time - start_time
#         print(f"Elapsed time: {elapsed_time} seconds")
#         print("End of __init__")

#     def forward(self, x):
#         """
#         x: (batch_size, num_patches, patch_dim)
#         """

#         print("Calling forward")
#         start_time = time.time()
        
#         x = self.patch_embed(x) + self.pos_embed
#         x = self.encoder(x)  # (batch_size, num_patches, d_model)

#         # Option 1: Mean over tokens
#         x = x.mean(dim=1)  # (batch_size, d_model)

#         # attn: shape (num_layers, num_heads, num_tokens, num_tokens)
#         attn = self.transformer(..., output_attentions=True)

#         end_time = time.time()
#         elapsed_time = end_time - start_time
#         print(f"Elapsed time: {elapsed_time} seconds")
#         print("End of __init__")

#         return self.mlp_head(x)  # output shape: (batch_size, 1)


In [23]:
# import torch.nn as nn

# class IceForecastTransformer(nn.Module):
#     def __init__(self, input_dim, model_dim, n_heads, n_layers, output_dim):
#         super().__init__()
#         self.input_proj = nn.Linear(input_dim, model_dim)
#         encoder_layer = nn.TransformerEncoderLayer(d_model=model_dim, nhead=n_heads)
#         self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)
#         self.head = nn.Linear(model_dim, output_dim)

#     def forward(self, x):
#         # x shape: (batch, seq_len, input_dim)
#         x = self.input_proj(x)
#         x = x.permute(1, 0, 2)  # (seq_len, batch, model_dim)
#         x = self.transformer(x)
#         x = x[-1]  # use final token for prediction
#         # attn: shape (num_layers, num_heads, num_tokens, num_tokens)
#         attn = self.transformer(..., output_attentions=True)
#         return self.head(x)


In [24]:
# from torch.utils.data import DataLoader
# import torch.optim as optim

# num_epochs = 100

# model = PatchTransformer(patch_dim=2, num_patches=100)  # TODO: adjust
# optimizer = optim.Adam(model.parameters(), lr=1e-4)
# criterion = nn.MSELoss()

# # This version expects patching to be done in the Dataset
# for epoch in range(num_epochs):
#     for x, _ in loader:  # x: (B, num_patches, patch_dim)
#         y_pred = model(x)
#         loss = criterion(y_pred, targets)  # you still need to define targets

#         optimizer.zero_grad()
#         loss.backward()
#         optimizer.step()

TODO: Add Positional Encoding to represent time steps.

TODO: Use patch embedding (like in Vision Transformers).

TODO OPTION: Try temporal attention only (e.g., Informer, Time Series Transformer).