This notebook is an iteration of the IFS model that uses presaved X and Y arrays, which allows shuffling and multiple workers in the training DataLoader.

In [1]:
from IPython.display import clear_output
import json
import os
import random
import glob
import shutil
import numpy as np
import pandas as pd
import pyarrow.parquet as pq
import time
import copy
import xarray as xr
from scipy.spatial import KDTree
from datetime import datetime
from sklearn.model_selection import train_test_split
from scipy.special import inv_boxcox
import scipy.stats as stats
import matplotlib.pyplot as plt
# from denseweight import DenseWeight
# import smogn
from collections import defaultdict

import lightning as L
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint
from lightning.pytorch.loggers import CSVLogger

import torch
import torch.nn as nn
from torch.optim import Adam, AdamW
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

from torch.utils.tensorboard import SummaryWriter
print(torch.__version__)
# from fwx.utilities import calc_windspeed

from IPython.display import clear_output
import os
import shutil
import numpy as np
import time
import xarray as xr
import glob
import pandas as pd
import pyarrow.parquet as pq

import lightning as L
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import TensorDataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
import torchvision.models as models

2.5.1


In [None]:
# PyTorch profiler for training runs; traces are written for TensorBoard.
# Import it from ``lightning.pytorch`` -- the same package the rest of this
# notebook uses via ``import lightning as L``. The standalone
# ``pytorch_lightning`` package defines separate profiler classes, and mixing
# the two packages is not supported by ``L.Trainer``.
from lightning.pytorch.profilers import PyTorchProfiler
import torch

profiler = PyTorchProfiler(
    on_trace_ready=torch.profiler.tensorboard_trace_handler(
        "/gscratch/kylabazlen/ml_out/tb_profiler_logs"
    ),
    # skip 1 step, warm up for 1 step, then record 3 steps, in a single cycle
    schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=1),
    profile_memory=True,
    record_shapes=True,
)

In [None]:
# Print the profiler's summary report.
# NOTE(review): presumably this only contains data after the profiler has been
# used by a Trainer run -- confirm this cell is meant to run after training.
print(profiler.summary())

In [None]:
# Report which compute node this kernel is running on.
import socket

hostname = socket.gethostname()
print("Running on: " + hostname)

In [None]:
# List the CUDA devices visible to PyTorch.
n_gpus = torch.cuda.device_count()
print(f'Found {n_gpus} GPUs')
for gpu_idx in range(n_gpus):
    print(torch.cuda.get_device_properties(gpu_idx).name)

In [None]:

def calc_windspeed(
    wind_v: np.ndarray | xr.DataArray,
    wind_u: np.ndarray | xr.DataArray,
) -> np.ndarray | xr.DataArray:
    """
    Return the horizontal wind speed, sqrt(u**2 + v**2).

    Parameters
    ----------
    wind_v : np.ndarray | xr.DataArray
        Northward (+) / southward (-) wind component (m/s).
    wind_u : np.ndarray | xr.DataArray
        Eastward (+) / westward (-) wind component (m/s).

    Returns
    -------
    np.ndarray | xr.DataArray
        Wind speed (m/s), computed with ``np.hypot``.

    Raises
    ------
    AssertionError
        If the two components do not have the same shape.
    """
    assert wind_u.shape == wind_v.shape, "arguments have mismatched shape"
    speed = np.hypot(wind_u, wind_v)
    return speed

In [None]:
def intersect_time_latlon(ds: xr.Dataset,
                          qc_ds: xr.Dataset,
                          ) -> tuple[xr.Dataset, xr.Dataset]:
    """
    Restrict two station datasets to their shared times and shared stations.

    Stations are matched by exact (latitude, longitude) equality, so both
    datasets must store identical coordinate values for the same station.

    Parameters
    ----------
    ds : xr.Dataset
        First dataset with dims (time, space) and coordinates (latitude, longitude).
    qc_ds : xr.Dataset
        Second dataset with dims (time, space) and coordinates (latitude, longitude).

    Returns
    -------
    (xr.Dataset, xr.Dataset)
        ``ds`` and ``qc_ds`` subset to the common times and the common
        (latitude, longitude) pairs; each keeps its own station order.
    """
    shared_times = np.intersect1d(ds.time.values, qc_ds.time.values)
    ds = ds.sel(time=shared_times)
    qc_ds = qc_ds.sel(time=shared_times)

    # (lat, lon) pair for every station, in each dataset's own order
    ds_pairs = list(zip(ds.latitude.values, ds.longitude.values))
    qc_pairs = list(zip(qc_ds.latitude.values, qc_ds.longitude.values))
    shared_pairs = set(ds_pairs).intersection(qc_pairs)

    # boolean station masks used for positional selection along "space"
    keep_ds = np.array([pair in shared_pairs for pair in ds_pairs])
    keep_qc = np.array([pair in shared_pairs for pair in qc_pairs])

    return ds.isel(space=keep_ds), qc_ds.isel(space=keep_qc)

def compute_qc_fail_mask(dset_qc: xr.Dataset, bits_to_mask: list[int]) -> np.ndarray:
    """
    Compute boolean QC failure mask for a QC dataset.

    Returns True where any specified bits fail for any variable.
    """
    bits = [b for b in bits_to_mask if b != 1]
    mask = sum(1 << (b - 1) for b in bits)

    combined_fails = None
    for var in dset_qc.data_vars:
        qc_values = np.nan_to_num(dset_qc[var].values, nan=0).astype("int32")
        # qc_values = dset_qc[var].fillna(0).astype("int32") #maybe faster? idk 
        fails = (qc_values & mask) != 0
        combined_fails = (
            fails if combined_fails is None else (combined_fails | fails)
        )

    return combined_fails

def qc_threshold(ds: xr.Dataset, 
                 threshold_variable: list[str],
                 threshold: int,
                 mask_variable: list[str] = None):


    if mask_variable is None:
        mask_variable = threshold_variable

    combined_mask = None
    for var in threshold_variable:
        mask = ds[var] > threshold
        combined_mask = mask if combined_mask is None else (combined_mask | mask)

    for mvar in mask_variable:
        if mvar in ds:
             ds[mvar] = ds[mvar].where(~combined_mask)

    return ds
#spatial checkerboard of cowy into train and testing datasets
def spatial_blocking(
    cowy: "CoWyPointDataset",
    block_size: float,
    n_folds: int,
    test_fold: int = 0,
    ):
    """
    Split dataset into spatially blocked, checkerboard train and test datasets.

    Station locations are binned into square blocks of ``block_size``. The
    grid is anchored at floor(min latitude) and floor(min longitude). Block
    (row i, column j) is assigned fold ``(i + j) % n_folds``. Observations at
    stations in fold ``test_fold`` go to the test set; all others go to the
    training set.

    Parameters
    ----------
    cowy : CoWyPointDataset
        Dataset with per-station ``lat_obs`` / ``lon_obs`` arrays and an
        ``obs_lookup`` DataFrame whose ``idx_obs`` column indexes stations.
    block_size : float
        Size of spatial blocks in coordinate units (e.g., degrees).
    n_folds : int
        Number of spatial folds. With 2 folds, blocks alternate between test
        and train like a checkerboard.
    test_fold : int, optional
        Fold id used as the test set. Defaults to 0.

    Returns
    -------
    train_ds, test_ds : CoWyPointDataset
        Shallow copies of ``cowy`` whose ``obs_lookup`` contains only the
        training / testing rows (index reset). All other attributes are shared
        with ``cowy``.

    Raises
    ------
    ValueError
        If ``block_size <= 0``, ``n_folds < 1`` or ``test_fold`` is not in
        ``[0, n_folds)``.
    """
    if block_size <= 0:
        raise ValueError(f"block_size must be positive, got {block_size}")
    if n_folds < 1:
        raise ValueError(f"n_folds must be at least 1, got {n_folds}")
    if not 0 <= test_fold < n_folds:
        raise ValueError(f"test_fold must be in [0, {n_folds}), got {test_fold}")

    # Per-station coordinates (position == idx_obs)
    lat = cowy.lat_obs
    lon = cowy.lon_obs

    # Anchor the grid at the dataset's floored min lat/lon (ensures reproducibility)
    lat_min = float(np.floor(lat.min()))
    lon_min = float(np.floor(lon.min()))

    # Block row / column of every station
    block_y = np.floor((lat - lat_min) / block_size)
    block_x = np.floor((lon - lon_min) / block_size)

    # Checkerboard fold assignment
    fold_id = (block_x + block_y) % n_folds

    train_idx = np.flatnonzero(fold_id != test_fold)
    test_idx = np.flatnonzero(fold_id == test_fold)

    lookup = cowy.obs_lookup

    train_ds = copy.copy(cowy)  # shallow copy: only obs_lookup differs
    train_ds.obs_lookup = lookup[lookup['idx_obs'].isin(train_idx)].reset_index(drop=True)

    test_ds = copy.copy(cowy)
    test_ds.obs_lookup = lookup[lookup['idx_obs'].isin(test_idx)].reset_index(drop=True)

    return train_ds, test_ds
#not used
def add_raster_variable(
    ds: xr.Dataset,
    path: str,
) -> xr.Dataset:
    """
    Add every variable of a raster file, sampled at the station locations.

    Performs vectorized nearest-neighbor lookup to extract values from a
    gridded raster dataset at each meteorological station location. The
    raster file is closed before returning.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset containing latitude and longitude of met station
        indexed by space. Modified in place.
    path : str
        Path (or glob pattern) to the raster netCDF file(s) with
        ``latitude`` / ``longitude`` coordinates.

    Returns
    -------
    xarray.Dataset
        ``ds`` itself, with one added variable per data variable in the raster
        file, using the raster's variable names.
    """
    with xr.open_mfdataset(path, compat='no_conflicts') as raster_ds:
        for var_name in raster_ds.data_vars:
            vals = raster_ds[var_name].interp(
                latitude=ds["latitude"],
                longitude=ds["longitude"],
                method="nearest",
            )
            # Load the sampled values now: a lazy result would still read from
            # the file, which is closed when the ``with`` block exits.
            ds[var_name] = vals.load()

    return ds

In [None]:
def _hypsometric_equation(p1, p2, t1, t2):
    r = 287 #specific gas constant J/kg/K
    g = 9.81 #m/s2
    t_mean = (t1+t2)/2
    return (r * t_mean) /g * np.log(p1 / p2)

def compute_elr(p_surface, t_surface, p_upper, t_upper, z_surface=2.0):
    """
    Compute environmental lapse rate between surface and an upper level.

    The layer thickness ``dz`` comes from the hypsometric equation. The lapse
    rate is ``(t_upper - t_surface) / dz``, so it is negative when temperature
    decreases with height.

    Parameters
    ----------
    p_surface : float or array
        Surface pressure (Pa).
    t_surface : float or array
        Surface temperature (K).
    p_upper : float or array
        Upper-level pressure (Pa).
    t_upper : float or array
        Upper-level temperature (K).
    z_surface : float, optional
        Surface height (m). Default is 2. Not used: the lapse rate depends only
        on the layer thickness. Kept for backward compatibility.

    Returns
    -------
    elr : float or array
        Lapse rate (K/m). Where ``p_surface == p_upper`` the thickness is zero
        and the division is undefined.
    """
    # Height difference using hypsometric equation
    dz = _hypsometric_equation(p_surface, p_upper, t_surface, t_upper)

    # Lapse rate K/m
    return (t_upper - t_surface) / dz

In [None]:
class CoWyPointDataset(Dataset):
    """This object returns a single observation from the training or validation data.

    Each item is one row of ``self.obs_lookup``: one station at one valid time
    for one forecast file. ``__getitem__`` returns ``(inputs, target)`` as
    float32 arrays:

    - ``inputs`` is the forecast feature vector at the station's nearest grid
      cell (``self.vars_hrrr`` order), followed by the station's nearest
      terrain features (``self.vars_topo`` order).
    - ``target`` holds the observed MADIS variables (``self.vars_madis``).

    Attributes and methods keep their "hrrr" names, but the forecast files are
    only required to have 1-D latitude/longitude grids (this notebook passes
    IFS files).
    """
    def __init__(self,
                 source_madis_fps,
                 source_hrrr_fps,
                 source_topo_fps,
                 dist_lim = 0.25, #dist_lim=0.027,
                 shuffle=True,
                 obs_meta_cache = None):
        """
        ***Shuffle here and NOT in the DataLoader object!!!***
        ***Only use num_workers=0 with the DataLoader object!!!***

        Shuffling happens inside ``__getitem__`` (when ``idx == 0``). It keeps
        all observations of a timestep together so that the per-timestep cache
        in ``get_cached_arrs`` is reused.

        Parameters
        ----------
        source_madis_fps : str
            MADIS observation netCDF (opened with h5netcdf). Must contain
            ``windspeed_10m`` and per-station ``latitude`` / ``longitude``.
        source_hrrr_fps : iterable of str
            Forecast netCDF files, each with a single ``valid_time``. They must
            provide u10, v10, sp, t2m, t_500hPa, step and time.
        source_topo_fps : str or list of str
            Terrain file(s) passed to ``xr.open_mfdataset``. Every 2-D variable
            becomes a feature.
        dist_lim : float
            Max station-to-grid-cell distance (decimal degrees) for a station
            to be used.
            ~ 3km in decimal degrees = 0.027
            ~ 0.25 degrees is IFS resolution ~ 25km
        shuffle : bool
            Flag to shuffle timesteps
        obs_meta_cache : None, str or os.PathLike
            None builds ``obs_lookup`` with ``set_obs_lookup``. A path loads a
            CSV previously written from ``obs_lookup``, keeping only rows whose
            ``fp_hrrr`` is one of the opened forecast files.

        Raises
        ------
        ValueError
            If ``source_hrrr_fps`` is empty.
        TypeError
            If ``obs_meta_cache`` is neither None nor a path.
        """
        # --- station observations -------------------------------------------
        self.dset_madis_open = xr.open_dataset(source_madis_fps, engine='h5netcdf')
        self.dset_madis = self.dset_madis_open[['windspeed_10m']]

        # --- forecast files, with derived features added per file -----------
        self.dsets_hrrr = {fp: self._open_nwp_file(fp) for fp in source_hrrr_fps}
        if not self.dsets_hrrr:
            raise ValueError("source_hrrr_fps is empty; at least one forecast file is required")

        # --- terrain ----------------------------------------------------------
        self.dset_topo = xr.open_mfdataset(
            source_topo_fps,
            chunks={'latitude': 1000, 'longitude': 1000},
            engine="h5netcdf"
            )

        # only 2-D terrain variables become features
        self.vars_topo = [v for v in self.dset_topo.data_vars if self.dset_topo[v].ndim == 2]

        self.dist_lim = dist_lim
        self.shuffle = shuffle

        # per-timestep cache filled by get_cached_arrs
        self.cache_madis = None
        self.cache_hrrr = None
        self.cache_timestamp = None

        # One forecast file supplies the grid and the variable list
        # (assumes all files share the same grid and variables).
        dset_hrrr = next(iter(self.dsets_hrrr.values()))

        # "valid_time" timestamp -> forecast filepath. A later file with the same
        # valid_time replaces an earlier one.
        # TODO: map each valid time to a list of files to support several lead
        # times per valid time. set_obs_lookup already accepts a list here;
        # get_cached_arrs does not.
        self.times_hrrr = {pd.to_datetime(dset['valid_time'].values): fp
                           for fp, dset in self.dsets_hrrr.items()}

        # keep only observation times that have a forecast file
        valid_times = np.array(list(self.times_hrrr.keys()), dtype="datetime64[ns]")
        mask = np.isin(self.dset_madis["time"].values.astype("datetime64[ns]"), valid_times)
        self.dset_madis = self.dset_madis.isel(time=mask)
        self.ti_madis = pd.to_datetime(self.dset_madis["time"].values)

        self.vars_hrrr = list(dset_hrrr.keys())
        self.vars_madis = list(self.dset_madis.keys())

        # load the (time-filtered) observations into memory once
        self.madis_arrays = {var: self.dset_madis[var].values for var in self.vars_madis}

        self.lat_obs = self.dset_madis['latitude'].values
        self.lon_obs = self.dset_madis['longitude'].values
        self.lat_hrrr = dset_hrrr['latitude'].values
        self.lon_hrrr = dset_hrrr['longitude'].values

        # Terrain comes with 1-D lat/lon. Make it 2-D like the forecast grid.
        lat1d_topo = self.dset_topo['latitude'].values
        lon1d_topo = self.dset_topo['longitude'].values
        self.lon_topo, self.lat_topo = np.meshgrid(lon1d_topo, lat1d_topo)
        self.shape_topo = self.lat_topo.shape

        coords_hrrr = np.vstack((self.lat_hrrr.flatten(), self.lon_hrrr.flatten())).T
        coords_obs = np.vstack((self.lat_obs, self.lon_obs)).T

        # self.nn_ind gives the flat forecast-grid index for each station index
        self.tree = KDTree(coords_hrrr)
        self.nn_dist, self.nn_ind = self.tree.query(coords_obs)

        self.size_hrrr = self.lat_hrrr.size
        self.shape_hrrr = self.lat_hrrr.shape
        self.ind_hrrr = np.arange(self.size_hrrr).reshape(self.shape_hrrr)

        # nearest terrain cell per station, split into (row, column) indices
        coords_topo = np.vstack((self.lat_topo.flatten(), self.lon_topo.flatten())).T
        self.tree_topo = KDTree(coords_topo)
        self.nn_dist_topo, self.nn_ind_topo = self.tree_topo.query(coords_obs)
        self.idy_topo_pre, self.idx_topo_pre = np.unravel_index(
            self.nn_ind_topo, self.shape_topo
        )
        self.idy_topo_pre = self.idy_topo_pre.astype(np.int32)
        self.idx_topo_pre = self.idx_topo_pre.astype(np.int32)

        # Pointwise selection of every terrain variable at the stations,
        # giving an array of shape (n_stations, n_topo_vars) in float32.
        idy = xr.DataArray(self.idy_topo_pre, dims="nobs")
        idx = xr.DataArray(self.idx_topo_pre, dims="nobs")
        topo_list = [self.dset_topo[v].isel(latitude=idy, longitude=idx) for v in self.vars_topo]
        print("end topo list")
        self.topo_per_obs = (
            xr.concat(topo_list, dim="var")
            .compute()
            .to_numpy()
            .T
            .astype(np.float32)
        )

        # stations close enough to a forecast grid cell to be used
        self.idx_obs_inbounds = np.where(self.nn_dist < self.dist_lim)[0]

        self.obs_lookup = None
        if obs_meta_cache is None:
            self.set_obs_lookup()
        elif isinstance(obs_meta_cache, (str, os.PathLike)):
            # Filter on the opened files. source_hrrr_fps may be a one-shot
            # iterator that was already consumed above.
            self.obs_lookup = self._read_obs_lookup(obs_meta_cache, set(self.dsets_hrrr))
        else:
            raise TypeError(
                f"obs_meta_cache must be None or a CSV path, got {type(obs_meta_cache).__name__}"
            )

        #add obs number column
        self.obs_lookup['obs_number'] = np.arange(len(self.obs_lookup))

    @staticmethod
    def _open_nwp_file(fp):
        """Open one forecast file, add derived features and make lat/lon 2-D.

        Adds ``ws_10`` (from u10/v10), ``lead_time_hrs`` (step in hours),
        ``init_z`` (1.0 for 12Z initializations, else 0.0) and ``elr`` (surface
        to 500 hPa lapse rate). It then turns the 1-D latitude/longitude
        dimensions into 2-D coordinates on ``(y, x)`` dimensions.
        """
        ds = xr.open_dataset(fp, decode_cf=False)
        # Drop the "dtype" attribute of "step" before CF decoding.
        # NOTE(review): presumably it conflicts with timedelta decoding -- confirm.
        if "step" in ds and "dtype" in ds["step"].attrs:
            ds["step"].attrs.pop("dtype")
        ds = xr.decode_cf(ds)
        ds['ws_10'] = calc_windspeed(wind_v=ds['v10'], wind_u=ds['u10'])

        grid_shape = (ds.sizes['latitude'], ds.sizes['longitude'])

        # forecast lead time (timedelta -> hours) broadcast over the grid
        step_hours = ds.step.values / np.timedelta64(1, 'h')
        ds['lead_time_hrs'] = xr.DataArray(
            np.full(grid_shape, step_hours, dtype='float32'),
            dims=['latitude', 'longitude']
        )

        #encode initialization time as 0 or 1
        init_hour = pd.Timestamp(ds['time'].values).hour
        init_z = 1.0 if init_hour == 12 else 0.0
        ds['init_z'] = xr.DataArray(
            np.full(grid_shape, init_z, dtype='float32'),
            dims=['latitude', 'longitude']
        )

        #remove "surface"
        ds = ds.drop_vars("surface", errors="ignore")

        # environmental lapse rate between the surface (sp, t2m) and 500 hPa
        ds["elr"] = compute_elr(
            p_surface=ds["sp"].astype("float32"),
            t_surface=ds["t2m"].astype("float32"),
            p_upper=50000.0,
            t_upper=ds["t_500hPa"].astype("float32"),
        ).astype("float32")

        # Cyclical time encoding (see encode_time) is currently disabled.

        # 2-D latitude/longitude on (y, x), replacing the 1-D dimension coordinates
        lat_2d, lon_2d = np.meshgrid(ds['latitude'].values, ds['longitude'].values, indexing='ij')
        ds = ds.assign_coords({
            "latitude": (("y", "x"), lat_2d),
            "longitude": (("y", "x"), lon_2d)
        })
        ds = ds.swap_dims({"latitude": "y", "longitude": "x"})
        ds = ds.drop_dims(["lat", "lon"], errors="ignore")
        return ds

    @staticmethod
    def _read_obs_lookup(csv_fp, valid_fps):
        """Load a cached obs_lookup CSV, keeping rows whose ``fp_hrrr`` is in ``valid_fps``.

        Index columns are cast to int32, coordinates to float32, and
        ``timestamp`` is parsed back to pandas Timestamps.
        """
        df = pd.read_csv(csv_fp)
        df = df[df['fp_hrrr'].isin(valid_fps)].reset_index(drop=True)

        for col in ('idx_obs', 'idt_madis', 'idy_hrrr', 'idx_hrrr', 'idy_topo', 'idx_topo'):
            df[col] = df[col].astype(np.int32)
        for col in ('latitude', 'longitude'):
            df[col] = df[col].astype(np.float32)

        df['timestamp'] = pd.to_datetime(df['timestamp'])
        return df

    def encode_time(self,
                    ds: "xr.Dataset") -> "xr.Dataset":
        """Add cyclical time encodings (hour and day of year) as 2D spatial grids.

        Adds ``hour_sin``, ``hour_cos``, ``doy_sin`` and ``doy_cos`` on the
        dims of the first 2-D data variable. Not called anywhere in this class.

        Raises
        ------
        ValueError
            If ``ds`` has no 2-D data variable.
        """

        # Get shape and dims from first 2D variable
        spatial_vars = [v for v in ds.data_vars if ds[v].ndim == 2]
        if not spatial_vars:
            raise ValueError("No 2D variables found in dataset")

        sample_var = ds[spatial_vars[0]]
        ny, nx = sample_var.shape
        spatial_dims = list(sample_var.dims)

        # Extract time
        timestamp = pd.to_datetime(ds['time'].values)
        hour = timestamp.hour
        doy = timestamp.timetuple().tm_yday

        # Add cyclical encodings (sin/cos pairs preserve periodicity)
        ds['hour_sin'] = (spatial_dims, np.full((ny, nx), np.sin(2*np.pi*hour/24), dtype=np.float32))
        ds['hour_cos'] = (spatial_dims, np.full((ny, nx), np.cos(2*np.pi*hour/24), dtype=np.float32))
        ds['doy_sin'] = (spatial_dims, np.full((ny, nx), np.sin(2*np.pi*doy/366), dtype=np.float32))
        ds['doy_cos'] = (spatial_dims, np.full((ny, nx), np.cos(2*np.pi*doy/366), dtype=np.float32))

        return ds

    def set_obs_lookup(self):
        """Build ``self.obs_lookup``: one row per usable (station, timestep, forecast file).

        A station is used at a timestep when its ``windspeed_10m`` is not NaN
        and it lies within ``dist_lim`` of the forecast grid. Each row stores:

        - the station index and MADIS time index;
        - the timestamp and forecast file path;
        - the nearest forecast (y, x) and terrain (y, x) indices;
        - the station latitude / longitude.

        Work is vectorized per timestep. This can still take a while for full
        years of data with lots of obs.
        """
        time_index = list(self.ti_madis.copy())

        # First position of each timestamp in ti_madis (repeated timestamps map
        # to their first occurrence), built once instead of per timestep.
        first_idt = {}
        for pos, ts in enumerate(time_index):
            first_idt.setdefault(ts, pos)

        inbounds = self.idx_obs_inbounds

        data = {
            'idx_obs': [],
            'idt_madis': [],
            'timestamp': [],
            'fp_hrrr': [],
            'idy_hrrr': [],
            'idx_hrrr': [],
            'idy_topo': [],
            'idx_topo': [],
            'latitude':[],
            'longitude':[],
        }

        for timestamp in time_index:
            fp_hrrr = self.times_hrrr.get(timestamp, None)
            if fp_hrrr is None:
                continue
            idt_madis = first_idt[timestamp]

            # stations with a valid observation that are also near the grid (sorted)
            obs_arr_ws = self.madis_arrays['windspeed_10m'][idt_madis]
            idx_obs_notnan = np.flatnonzero(~np.isnan(obs_arr_ws))
            idx_obs_arr = np.intersect1d(idx_obs_notnan, inbounds)
            if idx_obs_arr.size == 0:
                continue
            n_obs = idx_obs_arr.size

            # flat nearest-cell index -> (row, col) on the forecast grid
            idy_hrrr_arr, idx_hrrr_arr = np.unravel_index(
                self.nn_ind[idx_obs_arr], self.shape_hrrr
            )
            idy_topo_arr = self.idy_topo_pre[idx_obs_arr]
            idx_topo_arr = self.idx_topo_pre[idx_obs_arr]
            lat_arr = self.lat_obs[idx_obs_arr]
            lon_arr = self.lon_obs[idx_obs_arr]

            # one set of rows per forecast file for this valid time
            if not isinstance(fp_hrrr, list):
                fp_hrrr = [fp_hrrr]
            for fp in fp_hrrr:
                data['idx_obs'].extend(idx_obs_arr)
                data['idt_madis'].extend([idt_madis] * n_obs)
                data['timestamp'].extend([timestamp] * n_obs)
                data['fp_hrrr'].extend([fp] * n_obs)
                data['idy_hrrr'].extend(idy_hrrr_arr)
                data['idx_hrrr'].extend(idx_hrrr_arr)
                data['idy_topo'].extend(idy_topo_arr)
                data['idx_topo'].extend(idx_topo_arr)
                data['latitude'].extend(lat_arr)
                data['longitude'].extend(lon_arr)

        self.obs_lookup = pd.DataFrame(data)

        for col in ['idx_obs', 'idt_madis', 'idy_hrrr', 'idx_hrrr']:
            self.obs_lookup[col] = self.obs_lookup[col].astype(int)

    def shuffle_timesteps(self):
        """Shuffle obs lookup while keeping timesteps together for efficient IO

        The order of the timesteps is randomized with ``np.random``. Rows that
        share a ``timestamp`` stay contiguous, so consecutive items mostly reuse
        the timestep cached by ``get_cached_arrs``.
        """
        unique_times = self.obs_lookup['timestamp'].unique()
        random_order = pd.Series(np.random.permutation(unique_times), index=unique_times)

        self.obs_lookup['random_order'] = self.obs_lookup['timestamp'].map(random_order)

        self.obs_lookup = self.obs_lookup.sort_values('random_order')
        self.obs_lookup = self.obs_lookup.drop('random_order', axis=1)
        self.obs_lookup = self.obs_lookup.reset_index(drop=True)

    def get_cached_arrs(self, idt_madis, timestamp):
        """Get a single timestep of forecast and obs data cached as numpy arrays.

        Rows of one timestep are grouped together, so all data for the
        requested timestep is cached and reused until another timestep is
        requested.

        Returns
        -------
        cache_hrrr : np.ndarray
            float32 array of shape (y, x, n_hrrr_vars), variables stacked in
            ``self.vars_hrrr`` order.
        cache_madis : np.ndarray
            float32 array of shape (n_stations, n_madis_vars) for ``idt_madis``.

        NOTE: the cache key is the forecast ``valid_time`` only. The file is
        looked up from ``timestamp`` via ``times_hrrr``, not from the row's
        ``fp_hrrr``, so multiple files per valid time are not supported here.
        """
        fp_hrrr = self.times_hrrr.get(timestamp, None)
        dset_hrrr = self.dsets_hrrr[fp_hrrr]
        valid_time = dset_hrrr['valid_time'].values

        if (self.cache_madis is None or self.cache_hrrr is None
                or self.cache_timestamp != valid_time):
            self.cache_timestamp = valid_time.copy()
            self.cache_hrrr = np.dstack(
                [dset_hrrr[var].values for var in self.vars_hrrr]
            ).astype(np.float32)
            self.cache_madis = np.vstack(
                [self.madis_arrays[var][idt_madis] for var in self.vars_madis]
            ).T.astype(np.float32)

        return self.cache_hrrr, self.cache_madis

    def __len__(self):
        """Get the number of observations per epoch"""
        return len(self.obs_lookup)

    def _get_row(self, row, check_coords=True):
        """Return ``(inputs, target)`` for row label ``row`` of ``obs_lookup`` (no NaN handling)."""
        idx_obs = self.obs_lookup.at[row, 'idx_obs']
        timestamp = self.obs_lookup.at[row, 'timestamp']
        idt_madis = self.obs_lookup.at[row, 'idt_madis']
        idy_hrrr = self.obs_lookup.at[row, 'idy_hrrr']
        idx_hrrr = self.obs_lookup.at[row, 'idx_hrrr']

        arr_hrrr, arr_madis = self.get_cached_arrs(idt_madis, timestamp)

        if check_coords:
            # sanity check that the row's grid indices point near the station
            assert np.abs(self.lat_obs[idx_obs] - self.lat_hrrr[idy_hrrr, idx_hrrr]) < self.dist_lim
            assert np.abs(self.lon_obs[idx_obs] - self.lon_hrrr[idy_hrrr, idx_hrrr]) < self.dist_lim

        hrrr_inputs = arr_hrrr[idy_hrrr, idx_hrrr]
        topo_inputs = self.topo_per_obs[idx_obs]
        inputs = np.concatenate([hrrr_inputs, topo_inputs]).astype(np.float32)
        target = arr_madis[idx_obs]
        return inputs, target

    def __getitem__(self, idx, check_coords=True):
        """Return the ``(inputs, target)`` float32 arrays for row ``idx`` of ``obs_lookup``.

        When ``idx == 0`` and ``self.shuffle`` is set, timesteps are reshuffled
        first (once per epoch when items are requested in order).

        If the row's target contains NaN, the following rows are tried
        instead, wrapping around to the start. The wrap-around does not
        re-enter ``__getitem__(0)``, so it never triggers a reshuffle
        mid-epoch.

        Raises
        ------
        KeyError
            If ``idx`` is not a row label of ``obs_lookup``.
        AssertionError
            If ``check_coords`` and the row's grid cell is ``dist_lim`` or more
            away from the station in latitude or longitude.
        RuntimeError
            If no row of ``obs_lookup`` has a non-NaN target.
        """
        if idx == 0 and self.shuffle:
            self.shuffle_timesteps()

        inputs, target = self._get_row(idx, check_coords)

        n_rows = len(self)
        offset = 1
        while np.isnan(target).any():
            if offset >= n_rows:
                raise RuntimeError("obs_lookup contains no observation with a non-NaN target")
            inputs, target = self._get_row((idx + offset) % n_rows, check_coords)
            offset += 1

        return inputs, target

In [None]:
# Input data locations.
# MADIS station observations (a single netCDF file, despite the plural name).
source_madis_fps = '/project/cowy-nvhackathon/cowy-wildfire/data/observations/cowy_madis_metar_mesonet_2024.nc'
# Forecast files: IFS netCDFs (CoWyPointDataset calls them "hrrr"). Sorted so
# the file order is deterministic.
source_hrrr_fps = sorted(glob.glob('/project/cowy-nvhackathon/cowy-wildfire/data/nwp/ifs/*.nc'))
# Terrain rasters as a glob pattern (CoWyPointDataset opens it with xr.open_mfdataset).
source_topo_fps  = '/project/cowy-nvhackathon/cowy-wildfire/data/terrain_data/terrain_990m/*_reprojected_wgs84_cowy_990m.nc'

In [None]:
# Build the point dataset from scratch (obs_meta_cache=None), then cache its
# observation lookup table. Later runs can pass this CSV path as
# obs_meta_cache to skip set_obs_lookup.
# NOTE(review): this writes to test_2/ while the arrays below go to test_1/ --
# confirm the directories are intended to differ.
obs_lookup = '/gscratch/kylabazlen/ml_out/test_2/obs_lookup.csv'
cowy = CoWyPointDataset(source_madis_fps,
                        source_hrrr_fps,
                        source_topo_fps,
                        obs_meta_cache=None)

cowy.obs_lookup.to_csv(obs_lookup, index=False)

In [None]:
## RUN ONCE
# Build and save the X (inputs) and Y (targets) arrays for every observation.
# Shuffling is switched off while iterating. With shuffle=True, cowy[0]
# reorders cowy.obs_lookup, which would break the row correspondence with the
# obs_lookup CSV saved above.
# Result: X_all[i] and Y_all[i] correspond to cowy.obs_lookup.iloc[i], i.e.
# obs_number == i as long as obs_lookup has not been reordered since
# construction. A row whose target is NaN would be replaced by the next valid
# row; set_obs_lookup only keeps non-NaN observations.
_shuffle_flag = cowy.shuffle
cowy.shuffle = False
try:
    samples = [cowy[i] for i in range(len(cowy))]
finally:
    cowy.shuffle = _shuffle_flag

X_all = np.vstack([x for x, _ in samples]).astype(np.float32)  # shape: (n_samples, n_features)
Y_all = np.array([y for _, y in samples], dtype=np.float32)    # shape: (n_samples, n_madis_vars)
del samples

# NOTE(review): these statistics include the rows later used for validation
# and testing -- confirm this is intended.
mean = np.nanmean(X_all, axis = 0)
sd = np.nanstd(X_all, axis = 0)

np.save("/gscratch/kylabazlen/ml_out/test_1/mean.npy", mean)
np.save("/gscratch/kylabazlen/ml_out/test_1/sd.npy", sd)

np.save("/gscratch/kylabazlen/ml_out/test_1/X_all.npy", X_all)
np.save("/gscratch/kylabazlen/ml_out/test_1/Y_all.npy", Y_all)


In [None]:
def balance_train_classes(train_ds, thresholds=(11.17,), random_state=42):
    """
    Balance training dataset by upsampling minority classes to match majority class.

    Bins observations into wind speed ranges. Each bin is upsampled (rows
    drawn with replacement) to the size of the largest bin, typically the
    low wind speeds. The combined table is then shuffled.

    Parameters
    ----------
    train_ds : CoWyPointDataset
        The training dataset. Its ``obs_lookup`` table needs ``idt_madis`` and
        ``idx_obs`` columns, and ``dset_madis["windspeed_10m"]`` is indexed
        as [time, station].
    thresholds : sequence of float
        Thresholds defining wind speed bins. E.g., [5, 10] creates bins
        [0-5), [5-10) and [10+). 11.17 m/s = 25 mph, the wind speed for a red
        flag warning.
    random_state : int
        Seed for the upsampling and the final shuffle.

    Returns
    -------
    train_ds : CoWyPointDataset
        The same object, with ``obs_lookup`` replaced by the balanced table.
        Rows whose wind speed falls in no bin (NaN or negative) are dropped.

    Raises
    ------
    ValueError
        If ``obs_lookup`` is non-empty but no row falls inside any bin.
    """
    print("Starting class balancing on training data...")

    obs_lookup = train_ds.obs_lookup.copy()
    print(f"Initial training obs_lookup size: {len(obs_lookup)}")

    if obs_lookup.empty:
        print("Training obs_lookup is empty; nothing to balance.")
        return train_ds

    # Observed wind speed (the target) for every lookup row
    obs_ws_all = train_ds.dset_madis["windspeed_10m"].values
    obs_ws = obs_ws_all[obs_lookup["idt_madis"].values, obs_lookup["idx_obs"].values]

    # Bin edges: [0, t1), [t1, t2), ..., [tn, inf)
    edges = [0] + sorted(thresholds) + [np.inf]

    bins = {}
    for lo, hi in zip(edges[:-1], edges[1:]):
        mask = (obs_ws >= lo) & (obs_ws < hi)
        label = f"{lo}-{hi}" if hi != np.inf else f">{lo}"
        bins[label] = obs_lookup[mask]
        print(f"Bin {label}: {len(bins[label])} samples")

    n_unbinned = len(obs_lookup) - sum(len(b) for b in bins.values())
    if n_unbinned:
        print(f"Warning: {n_unbinned} samples fall outside all bins (NaN or negative wind speed) and are dropped")

    # Target size = largest bin (typically low wind speeds)
    max_size = max(len(b) for b in bins.values())
    print(f"\nTarget size for all bins: {max_size}")
    if max_size == 0:
        raise ValueError("No training samples fall inside any wind speed bin.")

    balanced_bins = []
    for label, bin_samples in bins.items():
        current_size = len(bin_samples)

        if current_size == 0:
            print(f"Warning: Bin {label} is empty, skipping")
            continue

        if current_size < max_size:
            # Randomly sample with replacement to reach max_size
            upsampled = bin_samples.sample(
                n=max_size,
                replace=True,
                random_state=random_state,
            ).reset_index(drop=True)
            print(f"Bin {label}: upsampled from {current_size} to {len(upsampled)}")
        else:
            # Already at max (this is the majority class)
            upsampled = bin_samples.reset_index(drop=True)
            print(f"Bin {label}: kept at {len(upsampled)}")

        balanced_bins.append(upsampled)

    updated_obs_lookup = pd.concat(balanced_bins, ignore_index=True)

    # Shuffle so the bins are interleaved
    updated_obs_lookup = updated_obs_lookup.sample(frac=1, random_state=random_state).reset_index(drop=True)

    print(f"\nUpdated training obs_lookup size: {len(updated_obs_lookup)}")
    print(f"Expansion factor: {len(updated_obs_lookup) / len(obs_lookup):.2f}x")

    train_ds.obs_lookup = updated_obs_lookup

    print("Class balancing completed.")
    return train_ds

In [None]:
# Recompute mean and std on balanced training set.
# NOTE(review): this needs train_ds after balance_train_classes, which is
# produced by the split cell further down -- run that cell first.
# Shuffling is switched off while iterating so that X_balanced[i] and
# Y_balanced[i] correspond to train_ds.obs_lookup.iloc[i]. With shuffle=True,
# train_ds[0] would reorder obs_lookup.
_shuffle_flag = train_ds.shuffle
train_ds.shuffle = False
try:
    samples = [train_ds[i] for i in range(len(train_ds))]
finally:
    train_ds.shuffle = _shuffle_flag

X_balanced = np.vstack([x for x, _ in samples]).astype(np.float32)  # (n_samples, n_features)
Y_balanced = np.array([y for _, y in samples], dtype=np.float32)    # (n_samples, n_madis_vars)
del samples

mean_balanced = np.nanmean(X_balanced, axis=0)
std_balanced = np.nanstd(X_balanced, axis=0)

# Save or use these for normalization
np.save("/gscratch/kylabazlen/ml_out/test_1/mean_balanced.npy", mean_balanced)
np.save("/gscratch/kylabazlen/ml_out/test_1/std_balanced.npy", std_balanced)

np.save("/gscratch/kylabazlen/ml_out/test_1/X_all_balanced.npy", X_balanced)
np.save("/gscratch/kylabazlen/ml_out/test_1/Y_all_balanced.npy", Y_balanced)

In [None]:
#2° x 2° spatial blocking
# Split the full `cowy` dataset into train and test datasets with spatial_blocking
# (block_size=2, n_folds=2); the exact grouping rule lives in spatial_blocking.
train_ds, test_ds = spatial_blocking(cowy, block_size=2, n_folds=2)

# 80/20 train/val split by index
# train_idx and val_idx are integer row positions into train_ds.obs_lookup (the
# training split's lookup table, not cowy.obs_lookup), drawn from 0 .. len(train_ds) - 1.
# Assumes len(train_ds) == len(train_ds.obs_lookup) — TODO confirm.
# NOTE(review): this split is random per row, not spatially blocked, so rows from the
# same spatial block can land in both train and val; val scores may be optimistic.
train_idx, val_idx = train_test_split(
    np.arange(len(train_ds)),
    test_size=0.2,
    random_state=42
)

# Deep-copy BEFORE trimming train_ds so val_ds starts from the full training lookup.
val_ds = copy.deepcopy(train_ds)

train_ds.obs_lookup = train_ds.obs_lookup.iloc[train_idx].reset_index(drop=True)
val_ds.obs_lookup   = val_ds.obs_lookup.iloc[val_idx].reset_index(drop=True)

#balance training class
# Only the training split is balanced; val_ds keeps its natural distribution.
# balance_train_classes upsamples smaller bins (sampling with replacement) to the size
# of the largest bin; thresholds=[11.1] presumably sets the bin edge(s) — confirm in
# balance_train_classes.
train_ds = balance_train_classes(train_ds, thresholds=[11.1])

In [None]:
save_dir = '/gscratch/kylabazlen/ml_out/test_1/'  # base output folder for this experiment

In [None]:
# Persist each split's obs lookup table so later runs can rebuild the splits from
# the pre-saved X/Y arrays.
for split_ds, csv_name in (
    (train_ds, "train_obs_lookup.csv"),
    (val_ds, "val_obs_lookup.csv"),
    (test_ds, "test_obs_lookup.csv"),
):
    split_ds.obs_lookup.to_csv(os.path.join(save_dir, csv_name), index=False)
del split_ds, csv_name

In [None]:
# Save the feature order (cowy.vars_hrrr), one variable name per line; it is read back
# later to locate feature columns in the X arrays.
with open("/gscratch/kylabazlen/ml_out/test_1/vars_hrrr.txt", "w") as out_file:
    out_file.writelines(f"{var_name}\n" for var_name in cowy.vars_hrrr)

## Running with presaved arrays ##

In [3]:
save_dir = '/gscratch/kylabazlen/ml_out/test_1/outputs/'  # run outputs (config.json, ...)

In [4]:
class FixedNorm(nn.Module):
    """Standardize inputs with a fixed (non-trainable) mean and standard deviation.

    ``forward`` returns ``(x - mean) / std``. Features whose std is exactly zero
    (constant features) are divided by 1 instead, so they map to ``x - mean`` rather
    than NaN/inf, which would otherwise propagate through every later layer.
    """

    def __init__(self, mean, std):
        super().__init__()
        # Register as buffers so they move with the model (.to(device)) and are saved in
        # the state_dict, but are not trained. Cast to the default float dtype (float32
        # unless changed) so float64 NumPy statistics do not promote the output to
        # float64, which nn.Linear layers with float32 weights reject.
        dtype = torch.get_default_dtype()
        self.register_buffer('mean', torch.as_tensor(mean, dtype=dtype))
        self.register_buffer('std', torch.as_tensor(std, dtype=dtype))

    def forward(self, x):
        # Guard against division by zero for constant features.
        safe_std = torch.where(self.std == 0, torch.ones_like(self.std), self.std)
        return (x - self.mean) / safe_std

In [5]:
class PointCorrectionModel(L.LightningModule):
    """Fully connected model that takes per-observation features as input and
    predicts point observations for bias correction.

    Inputs are standardized by a fixed ``FixedNorm`` layer, then passed through 19
    hidden ``nn.Linear`` + ``nn.LeakyReLU`` layers of width ``config['n_filters']``
    (an ``nn.Dropout`` follows hidden layers 4, 8, 12 and 16) and a final linear
    output layer.
    """

    # Hidden Linear+LeakyReLU layers per stage; a Dropout follows every stage except
    # the last. This reproduces the original hand-written nn.Sequential exactly (same
    # module order, so the same indices and state_dict keys such as "model.1.weight"),
    # keeping previously saved checkpoints loadable.
    _STAGE_DEPTHS = (4, 4, 4, 4, 3)

    def __init__(self, n_inputs, n_outputs, mean, std, config, dense_weight_obj=None):
        """
        Parameters
        ----------
        n_inputs : int
            Number of input channels. This will be axis=1 in the input tensor
        n_outputs : int
            Number of output channels. This will be axis=1 in the output tensor
        mean : array-like or tensor
            Per-feature mean passed to FixedNorm for normalization
        std : array-like or tensor
            Per-feature standard deviation passed to FixedNorm for normalization
        config : dict
            Hyperparameters. Keys read: 'n_filters' (hidden width), 'drop_out_p',
            'l1_lambda', 'learning_rate', 'eps', 'weight_decay'.
        dense_weight_obj : object, optional
            Stored as ``self.dw``; not used by the methods of this class.
        """
        super().__init__()
        self.n_inputs = n_inputs
        self.config = config
        self.dw = dense_weight_obj
        self.n_filters = config['n_filters']

        layers = [FixedNorm(mean, std)]
        in_features = n_inputs
        last_stage = len(self._STAGE_DEPTHS) - 1
        for stage, depth in enumerate(self._STAGE_DEPTHS):
            for _ in range(depth):
                layers += [nn.Linear(in_features=in_features, out_features=self.n_filters),
                           nn.LeakyReLU()]
                in_features = self.n_filters
            if stage < last_stage:
                layers.append(nn.Dropout(p=self.config['drop_out_p']))
        layers.append(nn.Linear(in_features=self.n_filters, out_features=n_outputs))
        self.model = nn.Sequential(*layers)

        self.save_hyperparameters()

    def training_step(self, batch, batch_idx):
        """Train on a batch with MSE plus an L1 penalty on all trainable parameters."""
        x, y = batch
        y_hat = self(x)
        loss = self.loss(y_hat, y)

        # L1 regularization on ALL model parameters (weights and biases). FixedNorm's
        # mean/std are buffers, not parameters, so they are excluded.
        l1_norm = sum(p.abs().sum() for p in self.parameters())
        l1_penalty = self.config['l1_lambda'] * l1_norm
        total_loss = loss + l1_penalty

        self.log("train_loss", total_loss, on_step=True, on_epoch=True, prog_bar=True)
        # "train_loss" includes the L1 term; log the plain MSE too so training error is
        # directly comparable to "validation_loss" (which has no penalty).
        self.log("train_mse", loss, on_step=False, on_epoch=True)
        return total_loss

    def validation_step(self, batch, batch_idx):
        """Validate on a batch with MSE loss (no L1 penalty)."""
        x, y = batch
        y_hat = self(x)  # forward pass
        loss = self.loss(y_hat, y)
        self.log("validation_loss", loss, on_step=True, on_epoch=True)
        return loss

    def configure_optimizers(self):
        """Get AdamW optimizer"""
        optimizer = AdamW(self.parameters(),
                          lr=self.config['learning_rate'],
                          eps=self.config['eps'],
                          weight_decay=self.config['weight_decay'])
        return optimizer

    def forward(self, X):
        """Forward pass through the model"""
        X = self.model(X)
        return X

    def loss(self, y_pred, y_true):
        """Mean squared error between predictions and targets.

        The L1 parameter penalty is added separately in ``training_step``.
        """
        mse_loss = nn.functional.mse_loss(y_pred, y_true)
        return mse_loss

In [17]:
config = {
    'batch_size': 2500,
    'learning_rate': 0.0001,
    'eps': 0.00001,
    'max_epochs': 50,
    'n_filters': 640,
    'weight_decay': 0.0001,  # AdamW decoupled weight decay ("L2"-like regularization)
    'drop_out_p': 0.25,
    'l1_lambda': 0.000001,  # scale of the L1 parameter penalty added to the training loss
}

# `save_dir` (the run's "outputs" folder) may not exist on a fresh run; create it so
# writing config.json cannot fail with FileNotFoundError.
os.makedirs(save_dir, exist_ok=True)
with open(os.path.join(save_dir, "config.json"), "w") as f:
    json.dump(config, f, indent=2)

In [18]:
# Pre-saved feature (X) and target (Y) arrays for every observation, wrapped as tensors
# that share memory with the loaded NumPy arrays (torch.from_numpy).
X_all = torch.from_numpy(np.load("/gscratch/kylabazlen/ml_out/test_1/X_all.npy"))
Y_all = torch.from_numpy(np.load("/gscratch/kylabazlen/ml_out/test_1/Y_all.npy"))

# Normalization statistics handed to the model (mean=mean, std=sd).
mean = np.load("/gscratch/kylabazlen/ml_out/test_1/mean.npy")
sd = np.load("/gscratch/kylabazlen/ml_out/test_1/sd.npy")

In [19]:
# Read the obs lookup table of each split. NOTE: `train_ds` temporarily holds the
# training lookup DataFrame here; it is rebound to a TensorDataset afterwards.
train_ds, test_ds_obs_lookup, val_obs_lookup = (
    pd.read_csv(csv_path)
    for csv_path in (
        '/gscratch/kylabazlen/ml_out/test_1/train_obs_lookup.csv',
        '/gscratch/kylabazlen/ml_out/test_1/test_obs_lookup.csv',
        '/gscratch/kylabazlen/ml_out/test_1/val_obs_lookup.csv',
    )
)

In [20]:
# One TensorDataset per split: each lookup's 'obs_number' column selects rows of the
# pre-saved X_all / Y_all. The row-index arrays are all taken (while `train_ds` is still
# the training lookup DataFrame) before any name is rebound.
train_ds, val_ds, test_ds = [
    TensorDataset(X_all[rows], Y_all[rows])
    for rows in (
        train_ds['obs_number'].values,
        val_obs_lookup['obs_number'].values,
        test_ds_obs_lookup['obs_number'].values,
    )
]

In [21]:
# Training batches are reshuffled every epoch and loaded by 4 worker processes.
train_dataloader = DataLoader(
    train_ds,
    batch_size=config['batch_size'],
    shuffle=True,
    num_workers=4,
)

# Validation/test keep their stored order (no shuffle) so outputs stay aligned with the
# matching obs lookup tables; both load in the main process.
val_dataloader = DataLoader(val_ds, batch_size=config['batch_size'], shuffle=False, num_workers=0)
test_dataloader = DataLoader(test_ds, batch_size=config['batch_size'], shuffle=False, num_workers=0)

# One (x, y) sample, used to size the model's inputs and outputs.
xsample, ysample = train_ds[0]

In [22]:
# No checkpoint to resume from: trainer.fit(ckpt_path=None) starts from fresh weights.
ckpt_path = None

# Input/output widths come from one sample; normalization uses the loaded mean/sd.
# dense_weight_obj is left at its default (None).
model = PointCorrectionModel(
    n_inputs=len(xsample),
    n_outputs=len(ysample),
    mean=mean,
    std=sd,
    config=config,
)

In [23]:
# Loggers: TensorBoard plus CSV, both writing under the same run directory.
log_root = "/gscratch/kylabazlen/ml_out/test_1/lightning_logs"
tb_logger = TensorBoardLogger(log_root, name="PointCorrectionModel")
# Reuse the TensorBoard version so CSV logs and checkpoints share its version_<n> folder.
logger_version = tb_logger.version
csv_logger = CSVLogger(log_root, name="PointCorrectionModel", version=logger_version)

# patience=500 is far above config["max_epochs"] (50), so this never stops training early.
early_stop_callback = EarlyStopping(
    monitor="validation_loss", min_delta=0.00, patience=500, verbose=False
)
# Track the checkpoint with the best (lowest, ModelCheckpoint's default mode) validation_loss.
model_ckpt_callback = ModelCheckpoint(
    monitor="validation_loss",
    dirpath=f"{log_root}/PointCorrectionModel/version_{logger_version}/",
    verbose=False,
)

# Lightning trainer
trainer = L.Trainer(
    accelerator="auto",
    profiler=None,
    max_epochs=config["max_epochs"],
    logger=[csv_logger, tb_logger],
    log_every_n_steps=10,
    callbacks=[early_stop_callback, model_ckpt_callback],
)

trainer.fit(
    model=model,
    train_dataloaders=train_dataloader,
    val_dataloaders=val_dataloader,
    ckpt_path=ckpt_path,
)

/project/cowy-nvhackathon/software/cowy-wildfire-envs/cowy-wildfire-env/lib/python3.11/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /project/cowy-nvhackathon/software/cowy-wildfire-env ...
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
/project/cowy-nvhackathon/software/cowy-wildfire-envs/cowy-wildfire-env/lib/python3.11/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:881: Checkpoint directory /cluster/medbow/gscratch/kylabazlen/ml_out/test_1/lightning_logs/PointCorrectionModel/version_1 exists and is not empty.

  | Name  | Type       | Params | Mode  | FLOPs
-----------------------------------------------------
0 | model | Sequential | 7.4 M  | train | 0    
-----------------------------------------------------
7.4 M     Trainable params
0       

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/project/cowy-nvhackathon/software/cowy-wildfire-envs/cowy-wildfire-env/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:434: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Epoch 49: 100%|██████████| 1184/1184 [04:42<00:00,  4.20it/s, v_num=1, train_loss_step=1.360, train_loss_epoch=1.350]

`Trainer.fit` stopped: `max_epochs=50` reached.


Epoch 49: 100%|██████████| 1184/1184 [04:42<00:00,  4.20it/s, v_num=1, train_loss_step=1.360, train_loss_epoch=1.350]


In [None]:
# %load_ext tensorboard
# %tensorboard --logdir='/gscratch/kylabazlen/ml_out/lightning_logs'

In [None]:
def forward_pass(model, dataloader):
    """Run ``model`` over every batch of ``dataloader`` without tracking gradients.

    Puts ``model`` in eval mode (it stays in eval mode afterwards) and moves each input
    batch to the device of the model's parameters.

    Returns
    -------
    (preds, obs) : tuple of numpy.ndarray
        Model outputs and the loader's targets, each concatenated along dim 0 in the
        loader's iteration order (so both share any shuffling the loader applies).
    """
    preds = []
    obs = []
    model.eval()
    # Match the model's device so this also works when the model lives on a GPU;
    # fall back to CPU for a model without parameters.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    with torch.no_grad():
        for x, y in dataloader:
            y_hat = model(x.to(device))
            # .cpu() is required before .numpy() when running on an accelerator.
            preds.append(y_hat.cpu())
            obs.append(y.cpu())

    # Concatenate as tensors, then convert to numpy
    preds = torch.cat(preds, dim=0).numpy()
    obs = torch.cat(obs, dim=0).numpy()

    return preds, obs

In [None]:
# ## WITH OLD MODEL RUN
# full_ckpt = ckpt_path  # Use the manually set checkpoint path
# version_dir = os.path.dirname(full_ckpt)

#  WITH NEW Checkpoint
full_ckpt = model_ckpt_callback.best_model_path
version_dir = os.path.dirname(full_ckpt)

# After trainer.fit, `model` holds the LAST-epoch weights, not the best checkpoint.
# Load the selected checkpoint into `model` so the forward passes saved under
# "best_ckpt_results" really come from that checkpoint.
if full_ckpt:
    # weights_only=False: Lightning checkpoints also store hyperparameters (here NumPy
    # mean/std arrays). Only load checkpoint files this workflow wrote itself.
    best_state = torch.load(full_ckpt, map_location="cpu", weights_only=False)
    model.load_state_dict(best_state["state_dict"])
    del best_state
    print(f"Loaded weights from {full_ckpt}")
else:
    print("Warning: no checkpoint path set; evaluating the in-memory model weights.")

In [None]:
# Results for the evaluated checkpoint live in a subfolder of its version directory.
save_dir = os.path.join(version_dir, "best_ckpt_results")
os.makedirs(save_dir, exist_ok=True)  # no error if it already exists

In [None]:
# Load previously saved forward-pass results for each split, or compute and save them.
# save_dir = os.path.join(version_dir, "best_ckpt_results")
os.makedirs(save_dir, exist_ok=True)

# One .npy file per array. Reuse the cache only when ALL of them exist; checking a
# single file would crash on np.load after a partially completed save.
result_paths = {
    name: os.path.join(save_dir, f"{name}.npy")
    for name in ("pred_train", "obs_train", "pred_val", "obs_val", "pred_test", "obs_test")
}

if all(os.path.exists(path) for path in result_paths.values()):
    print("Found saved forward-pass results → loading...")

    pred_train = np.load(result_paths["pred_train"])
    obs_train = np.load(result_paths["obs_train"])

    pred_val = np.load(result_paths["pred_val"])
    obs_val = np.load(result_paths["obs_val"])

    pred_test = np.load(result_paths["pred_test"])
    obs_test = np.load(result_paths["obs_test"])

else:
    print("No saved results → computing and saving...")

    # train_dataloader shuffles, which would save pred_train/obs_train in a random order
    # unrelated to train_obs_lookup.csv. Predict over the same dataset in its stored
    # order instead; the bin metrics do not depend on row order.
    # NOTE: caches written by earlier versions of this cell are in shuffled order;
    # delete them to regenerate aligned files.
    train_eval_loader = DataLoader(train_dataloader.dataset,
                                   batch_size=train_dataloader.batch_size,
                                   shuffle=False)

    pred_train, obs_train = forward_pass(model, train_eval_loader)
    pred_val, obs_val = forward_pass(model, val_dataloader)
    pred_test, obs_test = forward_pass(model, test_dataloader)

    np.save(result_paths["pred_train"], pred_train)
    np.save(result_paths["obs_train"], obs_train)

    np.save(result_paths["pred_val"], pred_val)
    np.save(result_paths["obs_val"], obs_val)

    np.save(result_paths["pred_test"], pred_test)
    np.save(result_paths["obs_test"], obs_test)

    print("Saved results to:", save_dir)

In [None]:
def extract_data(dataloader):
    """Gather every batch of ``dataloader`` into NumPy arrays.

    Returns ``(inputs, obs)``: the x and y tensors of all batches concatenated along
    dim 0, in the loader's iteration order.
    """
    batches = list(dataloader)
    features = torch.cat([xb for xb, _ in batches], dim=0)
    targets = torch.cat([yb for _, yb in batches], dim=0)
    return features.numpy(), targets.numpy()

# train_dataloader shuffles, so iterating it would save x_train_bl / obs_train_bl in a
# random order that matches neither train_obs_lookup.csv nor any other pass over the
# data. Read the same dataset in its stored order instead.
train_eval_loader = DataLoader(train_dataloader.dataset,
                               batch_size=train_dataloader.batch_size,
                               shuffle=False)
x_train, obs_train_bl = extract_data(train_eval_loader)
x_val, obs_val_bl = extract_data(val_dataloader)
x_test, obs_test_bl = extract_data(test_dataloader)

In [None]:
save_directory = "/gscratch/kylabazlen/ml_out/test_1"

# Baseline inputs and their matching observations for each split.
for out_name, out_array in (
    ("x_train_bl.npy", x_train),
    ("x_val_bl.npy", x_val),
    ("x_test_bl.npy", x_test),
    ("obs_train_bl.npy", obs_train_bl),
    ("obs_val_bl.npy", obs_val_bl),
    ("obs_test_bl.npy", obs_test_bl),
):
    np.save(os.path.join(save_directory, out_name), out_array)
del out_name, out_array

In [None]:
def _error_stats(diff, obs):
    """Return Count/MBE/RMSE/R2 for one subset of errors ``diff`` (pred - obs) and its ``obs``."""
    if diff.size == 0:
        return {"Count": 0, "MBE": np.nan, "RMSE": np.nan, "R2": np.nan}
    ss_res = np.sum(diff ** 2)
    ss_tot = np.sum((obs - obs.mean()) ** 2)
    return {
        "Count": int(diff.size),
        "MBE": diff.mean(),
        "RMSE": np.sqrt(np.mean(diff ** 2)),
        # R2 is undefined when the subset's observations have zero variance.
        "R2": 1 - (ss_res / ss_tot) if ss_tot > 0 else np.nan,
    }


def evaluate_bins(pred, obs, thresholds=(5, 10)):
    """Compute bias/error statistics overall and within bins of the observed value.

    Parameters
    ----------
    pred, obs : array_like
        Predictions and matching observations. Both are flattened, so any shapes with
        the same number of elements work.
    thresholds : sequence of float, default (5, 10)
        Interior bin edges (sorted internally). Bins are ``[0, t1), [t1, t2), ...,
        [tk, inf)`` on ``obs``; observations below 0 fall in no bin but still count
        toward ``"all"``.

    Returns
    -------
    dict
        ``{"all": stats, "<lo>-<hi>": stats, ..., "><tk>": stats}``. Each ``stats``
        dict has ``"Count"``, ``"MBE"`` (mean of pred - obs), ``"RMSE"`` and ``"R2"``
        (``1 - SS_res / SS_tot`` within the subset). MBE/RMSE/R2 are NaN for an empty
        subset; R2 is NaN when the subset's observations have zero variance.
    """
    pred = np.asarray(pred).flatten()
    obs = np.asarray(obs).flatten()
    diff = pred - obs

    edges = [0, *sorted(thresholds), np.inf]

    results = {"all": _error_stats(diff, obs)}
    for lo, hi in zip(edges[:-1], edges[1:]):
        mask = (obs >= lo) & (obs < hi)
        label = f"{lo}-{hi}" if hi != np.inf else f">{lo}"
        results[label] = _error_stats(diff[mask], obs[mask])

    return results

def print_stats_table(results):
    """Print the Count/MBE/RMSE/R² statistics produced by ``evaluate_bins()``.

    Every bin except ``"all"`` is printed in the dict's order, then a separator, then
    the ``"all"`` row labelled ``All``. Expected input::

        {'all': {'MBE': ..., 'RMSE': ..., 'R2': ..., 'Count': ...},
         '0-5': {...},
         ...}
    """
    def format_row(label, s):
        # Fixed-width row: 8-char label, then four right-aligned 10-char columns.
        return f"{label:<8} {s['Count']:>10} {s['MBE']:>10.5f} {s['RMSE']:>10.5f} {s['R2']:>10.5f}"

    rule = "-" * 50
    print(f"{'Bin':<8} {'Count':>10} {'MBE':>10} {'RMSE':>10} {'R²':>10}")
    print(rule)

    for label, bin_stats in results.items():
        if label != "all":
            print(format_row(label, bin_stats))

    print(rule)
    print(format_row("All", results["all"]))

In [None]:
thresholds = [5, 10]
# Score the model's predictions on each split and print a table per split.
split_results = {}
for split_title, split_pred, split_obs in (
    ("Train Split", pred_train, obs_train),
    ("Validation Split", pred_val, obs_val),
    ("Test Split", pred_test, obs_test),
):
    split_results[split_title] = evaluate_bins(split_pred, split_obs, thresholds=thresholds)
    print(split_title)
    print_stats_table(split_results[split_title])
train_stats, val_stats, test_stats = split_results.values()

In [None]:
# Feature names in column order, as written by the "save feature order" step.
with open("/gscratch/kylabazlen/ml_out/test_1/vars_hrrr.txt", "r") as feature_file:
    vars_hrrr = list(map(str.strip, feature_file))

# Column index of the "ws_10" feature in the X arrays.
ws_idx = vars_hrrr.index("ws_10")
print(ws_idx)

In [None]:
print("Baseline IFS Stats 72-f96")
thresholds = [5,10]
# Baseline: score the raw "ws_10" input column (index ws_idx) directly against the
# observations, per split.
baseline_results = {}
for split_title, split_x, split_obs in (
    ("Train Split", x_train, obs_train_bl),
    ("Validation Split", x_val, obs_val_bl),
    ("Test Split", x_test, obs_test_bl),
):
    baseline_results[split_title] = evaluate_bins(split_x[:, ws_idx], split_obs, thresholds=thresholds)
    print(split_title)
    print_stats_table(baseline_results[split_title])
train_stats, val_stats, test_stats = baseline_results.values()

### Plots

In [None]:
# import matplotlib.pyplot as plt
# import cartopy.crs as ccrs
# import geopandas as gpd
# import cartopy.io.shapereader as shpreader
# import numpy as np

# dem = xr.open_dataset('/projects/cowy/datasets/terrain_data/terrain_990m/conus_elev_reprojected_wgs84_cowy_990m.nc')
# # Data
# train_lon = train_ds.obs_lookup['longitude']
# train_lat = train_ds.obs_lookup['latitude']
# test_lon  = test_ds.obs_lookup['longitude']
# test_lat  = test_ds.obs_lookup['latitude']

# # Load states
# states_shp = shpreader.natural_earth(
#     resolution="10m",
#     category="cultural",
#     name="admin_1_states_provinces"
# )
# gdf_states = gpd.read_file(states_shp)

# # Figure
# fig, ax = plt.subplots(figsize=(8, 6), subplot_kw={'projection': ccrs.PlateCarree()})

# # DEM
# im = ax.pcolormesh(
#     dem['longitude'],
#     dem['latitude'],
#     dem['HGT'],
#     cmap='cubehelix',
#     shading='auto',
#     transform=ccrs.PlateCarree()
# )
# plt.colorbar(im, ax=ax, label="Elevation (m)")

# # Scatter points
# ax.scatter(train_lon, train_lat, s=15, c='blue', alpha=.5, edgecolors='lightgray',linewidth=0.2, label='Train', transform=ccrs.PlateCarree())
# ax.scatter(test_lon, test_lat, s=15, c='red', alpha=.5, edgecolors='lightgray',linewidth=0.2, label='Test', transform=ccrs.PlateCarree())

# # Plot states
# gdf_states.plot(ax=ax, facecolor='none', edgecolor='black', linewidth=0.8, transform=ccrs.PlateCarree())

# # Gridlines
# gl = ax.gridlines(draw_labels=True, linewidth=0.5, color='gray', alpha=0.5, linestyle='--')
# gl.top_labels = False
# gl.right_labels = False

# # Map extent
# ax.set_extent([
#     test_lon.min(), test_lon.max(),
#     test_lat.min(),  test_lat.max()
# ], crs=ccrs.PlateCarree())

# # This sets 1 degree latitude = 1 degree longitude in screen space
# lon_range = dem['longitude'].max() - dem['longitude'].min()
# lat_range = dem['latitude'].max() - dem['latitude'].min()
# ax.set_aspect(lon_range / lat_range)

# # Title & legend
# ax.set_title("Training and Test Set Observations")
# ax.legend(loc="upper right", framealpha=1, facecolor="white", edgecolor="black")

# plt.tight_layout()
# plt.show()
