# Imports

In [76]:
import sys
print(f"System Version: {sys.version}")

System Version: 3.10.18 | packaged by conda-forge | (main, Jun  4 2025, 14:45:41) [GCC 13.3.0]


In [77]:
interpreter_path = sys.executable  # confirms which conda env the kernel runs in
print(interpreter_path)

/global/homes/b/brelypo/.conda/envs/sic_sie_env/bin/python


In [78]:
import numpy as np
print(f"Numpy version {np.__version__}")

Numpy version 2.2.6


In [79]:
import pandas as pd
print(f"Pandas version {pd.__version__}")

Pandas version 2.3.0


In [80]:
import xarray as xr
print(f"Xarray version {xr.__version__}")

Xarray version 2025.6.0


In [81]:
import matplotlib
import matplotlib.pyplot as plt
print(f"Matplotlib version {matplotlib.__version__}")

Matplotlib version 3.10.3


In [82]:
import torch
from torch.utils.data import Dataset, DataLoader

print(f"PyTorch version {torch.__version__}")

PyTorch version 2.5.1


# Custom PyTorch Dataset

In [129]:
# One month of daily-averaged sea-ice output, opened for interactive inspection.
sample_file = "train/v3.LR.DTESTM.pm-cpu-10yr.mpassi.hist.am.timeSeriesStatsDaily.0010-01-01.nc"
ds = xr.open_dataset(sample_file)

In [130]:
data_variables = ds.data_vars
data_variables

Data variables:
    timeDaily_counter             (Time) int32 124B ...
    xtime_startDaily              (Time) |S64 2kB ...
    xtime_endDaily                (Time) |S64 2kB ...
    timeDaily_avg_iceAreaCell     (Time, nCells) float32 58MB ...
    timeDaily_avg_iceVolumeCell   (Time, nCells) float32 58MB ...
    timeDaily_avg_snowVolumeCell  (Time, nCells) float32 58MB ...
    timeDaily_avg_uVelocityGeo    (Time, nVertices) float32 117MB ...
    timeDaily_avg_vVelocityGeo    (Time, nVertices) float32 117MB ...

In [131]:
day_counter = ds.data_vars["timeDaily_counter"]
day_counter.shape

(31,)

In [134]:
start_times = ds.data_vars["xtime_startDaily"]
print(start_times)

<xarray.DataArray 'xtime_startDaily' (Time: 31)> Size: 2kB
[31 values with dtype=|S64]
Dimensions without coordinates: Time


In [135]:
# Raw byte strings of the form b'YYYY-MM-DD_hh:mm:ss', one per day.
raw_start_bytes = ds.data_vars["xtime_startDaily"].values
print(raw_start_bytes)

[b'0010-01-01_00:00:00' b'0010-01-02_00:00:00' b'0010-01-03_00:00:00'
 b'0010-01-04_00:00:00' b'0010-01-05_00:00:00' b'0010-01-06_00:00:00'
 b'0010-01-07_00:00:00' b'0010-01-08_00:00:00' b'0010-01-09_00:00:00'
 b'0010-01-10_00:00:00' b'0010-01-11_00:00:00' b'0010-01-12_00:00:00'
 b'0010-01-13_00:00:00' b'0010-01-14_00:00:00' b'0010-01-15_00:00:00'
 b'0010-01-16_00:00:00' b'0010-01-17_00:00:00' b'0010-01-18_00:00:00'
 b'0010-01-19_00:00:00' b'0010-01-20_00:00:00' b'0010-01-21_00:00:00'
 b'0010-01-22_00:00:00' b'0010-01-23_00:00:00' b'0010-01-24_00:00:00'
 b'0010-01-25_00:00:00' b'0010-01-26_00:00:00' b'0010-01-27_00:00:00'
 b'0010-01-28_00:00:00' b'0010-01-29_00:00:00' b'0010-01-30_00:00:00'
 b'0010-01-31_00:00:00']


In [123]:
ice_area = ds.data_vars["timeDaily_avg_iceAreaCell"]
ice_area.shape

(31, 465044)

In [124]:
# Coordinates and dimension lengths of the sample file.
# `Dataset.sizes` is xarray's stable name -> length mapping; `Dataset.dims`
# is deprecated as a mapping and slated to return only a set of names.
print(ds.coords)
print(dict(ds.sizes))

Coordinates:
    *empty*


In [125]:
print(str(ds))

<xarray.Dataset> Size: 407MB
Dimensions:                       (Time: 31, nCells: 465044, nVertices: 942873)
Dimensions without coordinates: Time, nCells, nVertices
Data variables:
    timeDaily_counter             (Time) int32 124B ...
    xtime_startDaily              (Time) |S64 2kB ...
    xtime_endDaily                (Time) |S64 2kB ...
    timeDaily_avg_iceAreaCell     (Time, nCells) float32 58MB ...
    timeDaily_avg_iceVolumeCell   (Time, nCells) float32 58MB ...
    timeDaily_avg_snowVolumeCell  (Time, nCells) float32 58MB ...
    timeDaily_avg_uVelocityGeo    (Time, nVertices) float32 117MB ...
    timeDaily_avg_vVelocityGeo    (Time, nVertices) float32 117MB ...
Attributes: (12/490)
    case:                                                         v3.LR.DTEST...
    source_id:                                                    9741e0bba2
    realm:                                                        seaIce
    product:                                                     

In [126]:
import os
from typing import List, Union, Callable, Tuple

In [127]:
from torch.utils.data import Dataset
from datetime import timedelta
from NC_FILE_PROCESSING.nc_utility_functions import *

In [140]:
class DailyNetCDFDataset(Dataset):
    """
    PyTorch Dataset that concatenates a directory of NetCDF files along their
    'Time' dimension and yields daily data *plus* its timestamp.

    Each item is ``(data_tensor, timestamp)``:

    * ``data_tensor`` is a float32 tensor holding the requested variable(s)
      for one day. The leading variable axis is squeezed away when only one
      variable is requested.
    * ``timestamp`` is an ``np.datetime64[s]`` parsed from ``xtime_startDaily``.

    Parameters
    ----------
    data_dir : str
        Directory containing the NetCDF files. Files are read in sorted
        filename order, which defines the time order.
    variable_names : str | List[str]
        Variable(s) to extract.  Default "timeDaily_avg_iceAreaCell".
    transform : Callable | None
        Optional transform applied to the data tensor *only*.
    decode_time : bool
        Currently unused. Timestamps are always parsed from the
        ``xtime_startDaily`` byte strings. Kept for backward compatibility.
    drop_missing : bool
        Currently unused. Kept for backward compatibility.
    verbose : bool
        If True (default), print loading progress and one line per fetched item.

    Raises
    ------
    FileNotFoundError
        If ``data_dir`` contains no ``*.nc`` files.
    KeyError
        If a requested variable is not present in the concatenated dataset.

    Notes
    -----
    NOTE(review): the default DataLoader ``collate_fn`` most likely cannot
    batch ``np.datetime64`` scalars, so iterating a DataLoader over this
    dataset may need a custom ``collate_fn``. Confirm before training.
    """

    def __init__(
        self,
        data_dir: str,
        variable_names: Union[str, List[str]] = "timeDaily_avg_iceAreaCell",
        transform: Callable = None,
        decode_time: bool = True,
        drop_missing: bool = True,
        verbose: bool = True,
    ):
        self.data_dir = data_dir
        self.transform = transform
        self.verbose = verbose
        # Copy so later mutation of the caller's list cannot change this dataset.
        self.variable_names = (
            [variable_names] if isinstance(variable_names, str) else list(variable_names)
        )

        # --- 1. Gather NetCDF files (sorted for deterministic time order) -----
        self.file_paths = sorted(
            os.path.join(data_dir, name)
            for name in os.listdir(data_dir)
            if name.endswith(".nc")
        )
        if not self.file_paths:
            raise FileNotFoundError(f"No *.nc files found in {data_dir!r}")
        self._log(f"Found {len(self.file_paths)} NetCDF files:")
        for path in self.file_paths:
            self._log(f"  - {path}")

        # --- 2. Load & concatenate along time ---------------------------------
        self._log("Loading datasets with xarray.open_mfdataset...")
        self.ds = xr.open_mfdataset(
            self.file_paths,
            combine="nested",
            concat_dim="Time",       # the dimension in these files is "Time" (capital T)
            decode_times=False,
            parallel=False,
        )

        # --- 3. Parse per-day timestamps from the concatenated dataset --------
        # Reading xtime from self.ds instead of re-opening every file keeps
        # self.times index-aligned with the data and leaves no handles open.
        self.times = self._parse_xtime(self.ds["xtime_startDaily"].values)
        self._log(f"Parsed {len(self.times)} total dates")
        self._log(f"Unique days: {len(np.unique(self.times))}")
        self._log(f"First 35 days: {self.times[:35]}")
        self._log(f"Dataset dimensions: {dict(self.ds.sizes)}")
        self._log(f"Dataset variables: {list(self.ds.data_vars)}")

        # --- 4. Sub-select requested variables --------------------------------
        missing = [v for v in self.variable_names if v not in self.ds]
        if missing:
            raise KeyError(f"Variable(s) {missing} not found in dataset.")
        self._log(f"Subsetting variables: {self.variable_names}")
        self.subset = self.ds[self.variable_names]

    def _log(self, message: str) -> None:
        """Print ``message`` only when the dataset was built with verbose=True."""
        if self.verbose:
            print(message)

    @staticmethod
    def _parse_xtime(raw) -> np.ndarray:
        """Convert xtime byte strings ("YYYY-MM-DD_hh:mm:ss") to datetime64[s].

        Parameters
        ----------
        raw : iterable of bytes or str
            Timestamps such as b"0010-01-01_00:00:00".

        Returns
        -------
        np.ndarray
            1-D array with dtype ``datetime64[s]``. It is empty for empty input.
        """
        # Local import: the notebook only imports `timedelta` from datetime.
        from datetime import datetime

        parsed = []
        for item in raw:
            text = item.decode("utf-8") if isinstance(item, bytes) else str(item)
            parsed.append(datetime.strptime(text.strip(), "%Y-%m-%d_%H:%M:%S"))
        return np.array(parsed, dtype="datetime64[s]")

    def __len__(self) -> int:
        """Number of days available (one per entry of ``self.times``)."""
        return len(self.times)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, np.datetime64]:
        """Return ``(data_tensor, timestamp)`` for day ``idx``."""
        # to_array() stacks the selected variables along a new leading axis.
        daily = self.subset.isel(Time=idx).to_array().values
        data_tensor = torch.as_tensor(daily, dtype=torch.float32)

        # Single-variable requests drop the length-1 variable axis.
        if data_tensor.shape[0] == 1:
            data_tensor = data_tensor.squeeze(0)

        if self.transform is not None:
            data_tensor = self.transform(data_tensor)

        self._log(f"Fetched index {idx}: Time={self.times[idx]}, shape={data_tensor.shape}")
        return data_tensor, self.times[idx]

    def time_to_dataframe(self) -> pd.DataFrame:
        """Return a DataFrame of time features you can merge with predictions.

        The columns are ``time``, ``year``, ``month``, ``day`` and ``doy``
        (day of year), with one row per entry of ``self.times``.
        """
        t = pd.to_datetime(self.times)
        return pd.DataFrame(
            {
                "time": t,
                "year": t.year,
                "month": t.month,
                "day": t.day,
                "doy": t.dayofyear,
            }
        )

In [142]:
from torch.utils.data import DataLoader

# Training-data directory. Override it with the SIC_SIE_TRAIN_DIR environment
# variable when running under another account or machine.
data_dir = os.environ.get(
    "SIC_SIE_TRAIN_DIR",
    "/global/u2/b/brelypo/python_model_visualization/Predicting_SIC_SIE/train",
)
dataset = DailyNetCDFDataset(
    data_dir,
    variable_names=["timeDaily_avg_iceAreaCell", "timeDaily_avg_iceVolumeCell"],
)

# The class defines no __repr__, so summarize what was loaded explicitly.
print(f"{len(dataset.file_paths)} files, {len(dataset)} days loaded")

# Smoke-test one item. With two variables requested, the leading variable
# axis is kept (size 2) rather than squeezed.
sample, ts = dataset[0]
assert sample.shape[0] == 2, sample.shape

# NOTE(review): the default collate_fn likely cannot batch the np.datetime64
# timestamps returned by __getitem__. A custom collate_fn may be needed
# before iterating this loader; confirm.
loader = DataLoader(dataset, batch_size=8, shuffle=False)

# Calendar features (year / month / day / day-of-year) to merge with predictions.
df_time = dataset.time_to_dataframe()


Found 12 NetCDF files:
  - /global/u2/b/brelypo/python_model_visualization/Predicting_SIC_SIE/train/v3.LR.DTESTM.pm-cpu-10yr.mpassi.hist.am.timeSeriesStatsDaily.0010-01-01.nc
  - /global/u2/b/brelypo/python_model_visualization/Predicting_SIC_SIE/train/v3.LR.DTESTM.pm-cpu-10yr.mpassi.hist.am.timeSeriesStatsDaily.0010-02-01.nc
  - /global/u2/b/brelypo/python_model_visualization/Predicting_SIC_SIE/train/v3.LR.DTESTM.pm-cpu-10yr.mpassi.hist.am.timeSeriesStatsDaily.0010-03-01.nc
  - /global/u2/b/brelypo/python_model_visualization/Predicting_SIC_SIE/train/v3.LR.DTESTM.pm-cpu-10yr.mpassi.hist.am.timeSeriesStatsDaily.0010-04-01.nc
  - /global/u2/b/brelypo/python_model_visualization/Predicting_SIC_SIE/train/v3.LR.DTESTM.pm-cpu-10yr.mpassi.hist.am.timeSeriesStatsDaily.0010-05-01.nc
  - /global/u2/b/brelypo/python_model_visualization/Predicting_SIC_SIE/train/v3.LR.DTESTM.pm-cpu-10yr.mpassi.hist.am.timeSeriesStatsDaily.0010-06-01.nc
  - /global/u2/b/brelypo/python_model_visualization/Predicting_SI