# Training the ConvNP Model 

In this notebook, we will preprocess Great Lakes data using an existing data processor, generate tasks for model training, and set up a training loop to train a **ConvNP** model using DeepSensor. We will:
1. Load and preprocess temporal and static datasets like **SST**, **Ice Concentration**, **Lake Mask**, and **Bathymetry**.
2. Load and use an existing **DataProcessor** to handle data normalization.
3. Generate tasks using **TaskLoader** and train the **ConvNP** model.
4. Monitor validation performance and track model training losses and RMSE (Root Mean Squared Error).

This notebook also adds a binary ice indicator, derived from GLSEA3 sea surface temperature, as a context set.

Let's begin by importing the necessary packages and defining helper functions.

## Step 1: Import Packages and Define Helper Functions

We import the libraries required for:
- Data manipulation and visualization (`xarray`, `pandas`, `matplotlib`).
- Geospatial operations (`cartopy`).
- Efficient computation with Dask (`dask`).
- DeepSensor for data processing and model training (`deepsensor`).

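Additionally, we import local helper functions from the `deepsensor_greatlakes` package:
- `standardize_dates` converts the `time` dimension to date-only precision (`datetime64[D]`).
- `generate_random_coordinates` samples random lake locations.
- `apply_mask_to_prediction` masks a prediction to the lakes.
- `save_model` and `load_convnp_model` are custom save and load functions, used because DeepSensor's defaults appear to be broken in this environment.

Finally, we define `transform_ice`, which converts sea surface temperature into a binary ice indicator.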
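The source of `standardize_dates` is not shown in this notebook. As a rough illustration only, truncating timestamps to day precision with NumPy looks like the sketch below. The function name `to_date_only` is hypothetical, and the real helper works on an xarray object's `time` coordinate rather than on a bare array.

```python
import numpy as np


def to_date_only(times):
    """Truncate timestamps to day precision (datetime64[D]).

    Hypothetical helper illustrating the idea behind `standardize_dates`.
    """
    return np.asarray(times, dtype="datetime64[ns]").astype("datetime64[D]")


stamps = np.array(["2000-01-01T12:00", "2000-01-02T00:30"], dtype="datetime64[m]")
print(to_date_only(stamps))
```

Dropping the time-of-day component matters because task dates are looked up by exact match in each dataset's `time` index.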


In [1]:
import xarray as xr
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import dask.array as da
import gcsfs
import os
import wandb
import sys 
sys.path.append(os.path.abspath(".."))
import deepsensor.torch
from deepsensor.data import DataProcessor, TaskLoader, construct_circ_time_ds
from deepsensor.data.sources import get_era5_reanalysis_data, get_earthenv_auxiliary_data, \
    get_gldas_land_mask
from deepsensor.model import ConvNP
from deepsensor.train import Trainer, set_gpu_default_device

# Local package utilities
from deepsensor_greatlakes.utils import standardize_dates, generate_random_coordinates, apply_mask_to_prediction
from deepsensor_greatlakes.model import save_model, load_convnp_model


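def transform_ice(sst):
    """Convert SST (deg C) to a binary ice indicator: 1 where SST <= 0.2, 0 otherwise, NaN kept."""
    sst = xr.DataArray(sst)  # Ensure input is always an xarray.DataArray (avoids shadowing dask.array as `da`)
    nan_mask = sst.isnull()  # Remember where the input is missing (land / no data)
    transformed = xr.where(sst > 0.2, 0, 1)  # Cold water (<= 0.2 deg C) is flagged as likely ice
    transformed = transformed.where(~nan_mask, np.nan)  # Restore NaNs so land is not flagged as ice
    return transformed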
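`transform_ice` flags cells at or below 0.2 °C as likely ice (1), warmer water as 0, and keeps NaN over land. Because the test is `> 0.2`, a reading of exactly 0.2 °C counts as ice. The same thresholding in plain NumPy, using a hypothetical `binary_ice_indicator` function, is sketched below for illustration:

```python
import numpy as np


def binary_ice_indicator(sst, threshold=0.2):
    """1.0 where SST <= threshold (deg C), 0.0 where warmer, NaN where SST is missing."""
    sst = np.asarray(sst, dtype=float)
    flag = np.where(sst > threshold, 0.0, 1.0)    # NaN > threshold is False, so NaN gets 1.0 here...
    return np.where(np.isnan(sst), np.nan, flag)  # ...and is restored to NaN in this step


sst = np.array([np.nan, -0.1, 0.2, 0.3, 5.0])
print(binary_ice_indicator(sst))
```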

In [2]:
set_gpu_default_device()

## Step 2: Data Inventory and Preprocessing

In this section, we load the required environmental datasets for model training:
- **Ice Concentration**: A dataset of ice cover over time in the Great Lakes.
- **GLSEA and GLSEA3 (Sea Surface Temperature)**: Datasets of sea surface temperature; GLSEA3 is also used to derive the binary ice indicator.
- **Bathymetry**: A dataset representing the underwater topography of the lakes.
- **Lake Mask**: A binary mask indicating water presence.

These datasets are loaded from storage and preprocessed by converting time into date-only format and handling missing data.


### User Inputs - Select Training and Validation Ranges

In [3]:
# Training/data config (adapted for Great Lakes)
#data_range = ("2009-01-01", "2022-12-31")
#train_range = ("2009-01-01", "2021-12-31")
#val_range = ("2022-01-01", "2022-12-31")
#date_subsample_factor = 10

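# Previous successes (longer record):
# data_range = ("2000-01-01", "2020-12-31")
# train_range = ("2000-01-01", "2015-12-31")
# val_range = ("2016-01-01", "2020-12-31")
# date_subsample_factor = 20

# Short demo run: two years of training, one year of validation
# NOTE: every task date must exist in all context datasets. The binary ice indicator comes
# from GLSEA3, whose record starts on 2006-12-11 (see the printout in Step 3), so these
# 2000-2002 ranges raise a KeyError during task generation (Step 7). Move the ranges inside
# the GLSEA3 record before running.
data_range = ("2000-01-01", "2002-12-31")
train_range = ("2000-01-01", "2001-12-31")
val_range = ("2002-01-01", "2002-12-31")
date_subsample_factor = 30  # use every 30th day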
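With `date_subsample_factor = 30`, task generation uses every 30th day of each range. The standard-library sketch below counts the resulting dates. It mirrors `pd.date_range(start, end)[::step]` without pandas, and `subsampled_dates` is a hypothetical helper. The 25 training dates match the `0/25` progress bar recorded in Step 7.

```python
from datetime import date, timedelta


def subsampled_dates(start, end, step):
    """Every `step`-th day from start to end inclusive, like pd.date_range(start, end)[::step]."""
    n_days = (end - start).days + 1
    return [start + timedelta(days=i) for i in range(0, n_days, step)]


train = subsampled_dates(date(2000, 1, 1), date(2001, 12, 31), 30)
val = subsampled_dates(date(2002, 1, 1), date(2002, 12, 31), 30)
print(len(train), len(val))  # 25 13
```

Each epoch therefore sees only 25 training tasks, which is worth keeping in mind when judging the demo's predictions.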

In [4]:
# Path to the Zarr stores (NOTE: This won't work on U-M HPC. Paths must be changed)
#bathymetry_path = 'gs://great-lakes-osd/context/interpolated_bathymetry.nc'
#mask_path = 'gs://great-lakes-osd/context/lakemask.nc'
#ice_concentration_path = 'gs://great-lakes-osd/ice_concentration.zarr'
#glsea_path = 'gs://great-lakes-osd/GLSEA_combined.zarr'
#glsea3_path = 'gs://great-lakes-osd/GLSEA3_combined.zarr'

# Path to the files on U-M HPC
bathymetry_path = '/nfs/turbo/seas-dannes/SST-sensor-placement-input/bathymetry/interpolated_bathymetry.nc'
mask_path = '/nfs/turbo/seas-dannes/SST-sensor-placement-input/masks/lakemask.nc'
ice_concentration_path = '/nfs/turbo/seas-dannes/SST-sensor-placement-input/NSIDC/ice_concentration.zarr'
glsea_path = '/nfs/turbo/seas-dannes/SST-sensor-placement-input/GLSEA_combined.zarr'
glsea3_path = '/nfs/turbo/seas-dannes/SST-sensor-placement-input/GLSEA3_combined.zarr'

# Paths to saved configurations
# Change model_path for new location
deepsensor_folder = '../deepsensor_config/'
model_path = '../saved_models/example_cover/'

## Step 3: Loading Temporal Datasets (Ice Concentration and GLSEA)

In this section, we load the **Ice Concentration**, **GLSEA**, and **GLSEA3** datasets, all stored in Zarr format. They provide time series of ice cover and surface water temperature.

We perform the following preprocessing:
1. Replace invalid land values (denoted by `-1`) in the ice concentration data with `NaN`.
2. Standardize the time dimension to date-only precision.
3. Drop the unneeded **CRS** variable.
4. Derive the binary ice indicator from GLSEA3 SST with `transform_ice`.
Let’s load and preprocess the data now.


In [5]:
# Open the Zarr stores
ice_concentration_raw = xr.open_zarr(ice_concentration_path, chunks={'time': 366, 'lat': 200, 'lon': 200})
glsea_raw = xr.open_zarr(glsea_path, chunks={'time': 366, 'lat': 200, 'lon': 200})
glsea3_raw = xr.open_zarr(glsea3_path, chunks={'time': 366, 'lat': 200, 'lon': 200})

# Replace -1 (land value) with NaN
ice_concentration_raw = ice_concentration_raw.where(ice_concentration_raw != -1, float('nan'))

# Convert all times to date-only format, removing the time component
ice_concentration_raw = standardize_dates(ice_concentration_raw)
glsea_raw = standardize_dates(glsea_raw)
glsea3_raw = standardize_dates(glsea3_raw)

# Drop CRS - not needed
glsea_raw = glsea_raw.drop_vars('crs')
glsea3_raw = glsea3_raw.drop_vars('crs')

# Apply transform_ice function while keeping it as xarray
ice_mask = xr.apply_ufunc(
    transform_ice,
    glsea3_raw["sst"],  # Ensure function applies only to 'sst' DataArray
    dask="allowed",  # Ensures correct Dask processing
    output_dtypes=[glsea3_raw["sst"].dtype],
    keep_attrs=True,  # Preserve metadata
)
ice_mask = standardize_dates(ice_mask)
ice_mask = ice_mask.rename("binary_ice_indicator")
print(ice_mask)

<xarray.DataArray 'binary_ice_indicator' (time: 6226, lat: 838, lon: 1181)> Size: 49GB
dask.array<where, shape=(6226, 838, 1181), dtype=float64, chunksize=(366, 200, 200), chunktype=numpy.ndarray>
Coordinates:
  * lat      (lat) float64 7kB 38.87 38.89 38.9 38.92 ... 50.58 50.59 50.61
  * lon      (lon) float64 9kB -92.42 -92.41 -92.39 ... -75.91 -75.9 -75.88
  * time     (time) datetime64[s] 50kB 2006-12-11 2006-12-12 ... 2023-12-31
Attributes:
    grid_mapping:   crs
    long_name:      Temperature
    standard_name:  sea_water_temperature
    units:          Celsius


## Step 4: Loading Static Datasets (Bathymetry and Lake Mask)

Next, we load two static datasets:
- **Bathymetry**: The underwater features of the Great Lakes.
- **Lake Mask**: A binary mask indicating water bodies within the lakes.

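These datasets are loaded from NetCDF files. Because they are static, they need no time handling; the only optional step renames the bathymetry variable when reading the copy stored on GCP.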


In [6]:
# Open the NetCDF files using xarray 
bathymetry_raw = xr.open_dataset(bathymetry_path)
lakemask_raw = xr.open_dataset(mask_path)

# Name the bathymetry variable (only needed if reading from GCP)
#bathymetry_raw = bathymetry_raw.rename({'__xarray_dataarray_variable__': 'bathymetry'})

## Step 5: Initialize the Data Processor

The **DataProcessor** from DeepSensor normalizes the datasets for model training. Each variable is scaled with either mean/standard-deviation standardization (`mean_std`, used here for SST) or **min-max scaling** (`min_max`, used for bathymetry and the lake mask), and the latitude/longitude coordinates are rescaled as well.

Rather than fitting a new **DataProcessor**, we load the one fitted in the previous notebook from `deepsensor_folder` and apply it to the datasets.


In [7]:
data_processor = DataProcessor(deepsensor_folder)
print(data_processor)

DataProcessor with normalisation params:
{'bathymetry': {'method': 'min_max',
                'params': {'max': 316.62872313037894,
                           'min': 9.999999999999998}},
 'coords': {'time': {'name': 'time'},
            'x1': {'map': (38.8749871947229, 55.4132976408956), 'name': 'lat'},
            'x2': {'map': (-92.4199507342304, -75.8816402880577),
                   'name': 'lon'}},
 'mask': {'method': 'min_max', 'params': {'max': 1.0, 'min': 0.0}},
 'sst': {'method': 'mean_std',
         'params': {'mean': 7.873531818389893, 'std': 6.944828510284424}}}
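The printed parameters can be read as simple formulas. The sketch below uses the values above to show `mean_std` normalization of SST and its inverse, plus a coordinate mapping. The coordinate convention (range minimum to 0, maximum to 1) is an assumption about DeepSensor's behavior, and all function names are hypothetical. Note that the latitude and longitude ranges both span about 16.54 degrees, which suggests both coordinates are scaled by the same factor.

```python
# Normalization parameters copied from the DataProcessor printout above
SST_MEAN, SST_STD = 7.873531818389893, 6.944828510284424
LAT_RANGE = (38.8749871947229, 55.4132976408956)
LON_RANGE = (-92.4199507342304, -75.8816402880577)


def mean_std_norm(x, mean=SST_MEAN, std=SST_STD):
    """Standardize: subtract the mean and divide by the standard deviation."""
    return (x - mean) / std


def mean_std_unnorm(z, mean=SST_MEAN, std=SST_STD):
    """Invert mean_std_norm, returning a value in the original units (deg C for SST)."""
    return z * std + mean


def coord_norm(x, lo, hi):
    """Assumed coordinate mapping: lo -> 0 and hi -> 1, linear in between."""
    return (x - lo) / (hi - lo)


print(mean_std_norm(20.0))           # a 20 deg C reading in normalized units
print(coord_norm(45.0, *LAT_RANGE))  # latitude 45 N in normalized units
print(LAT_RANGE[1] - LAT_RANGE[0], LON_RANGE[1] - LON_RANGE[0])  # both spans ~16.54 degrees
```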


In [8]:
glsea = data_processor(glsea_raw)
# Normalize the static bathymetry and lake mask datasets
aux_ds, lakemask_ds = data_processor([bathymetry_raw, lakemask_raw], method="min_max")

In [9]:
# Alternative (not used here): normalize the NSIDC ice concentration instead
# _ = data_processor(ice_concentration_raw.sel(time=slice("2009-01-01", "2009-12-31")))
# ice_concentration = data_processor(ice_concentration_raw, method="min_max")

# Rechunk so that each date is a single chunk covering the whole grid
ice_mask = ice_mask.chunk({"time": 1, "lat": 838, "lon": 1181})
print(ice_mask)

# Fit normalization parameters for the ice indicator on a short window (Jan-Apr 2008)
# rather than on the full ~49 GB array, then apply them to the whole record
_ = data_processor(ice_mask.sel(time=slice("2008-01-01", "2008-04-01")))
ice_ds = data_processor(ice_mask)


<xarray.DataArray 'binary_ice_indicator' (time: 6226, lat: 838, lon: 1181)> Size: 49GB
dask.array<rechunk-merge, shape=(6226, 838, 1181), dtype=float64, chunksize=(1, 838, 1181), chunktype=numpy.ndarray>
Coordinates:
  * lat      (lat) float64 7kB 38.87 38.89 38.9 38.92 ... 50.58 50.59 50.61
  * lon      (lon) float64 9kB -92.42 -92.41 -92.39 ... -75.91 -75.9 -75.88
  * time     (time) datetime64[s] 50kB 2006-12-11 2006-12-12 ... 2023-12-31
Attributes:
    grid_mapping:   crs
    long_name:      Temperature
    standard_name:  sea_water_temperature
    units:          Celsius

In [10]:
data_processor.config

{'coords': {'time': {'name': 'time'},
  'x1': {'name': 'lat', 'map': (38.8749871947229, 55.4132976408956)},
  'x2': {'name': 'lon', 'map': (-92.4199507342304, -75.8816402880577)}},
 'sst': {'method': 'mean_std',
  'params': {'mean': 7.873531818389893, 'std': 6.944828510284424}},
 'bathymetry': {'method': 'min_max',
  'params': {'min': 9.999999999999998, 'max': 316.62872313037894}},
 'mask': {'method': 'min_max', 'params': {'min': 0.0, 'max': 1.0}},
 'binary_ice_indicator': {'method': 'mean_std',
  'params': {'mean': 0.6185053417359287, 'std': 0.4857535218400889}}}
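The fitted parameters for `binary_ice_indicator` can be sanity-checked. For a 0/1 variable with mean p, the population standard deviation is sqrt(p(1 - p)), and the printed values satisfy this. After `mean_std` normalization, the indicator takes only two values. The parameters were fitted on January to April 2008 only (see the cell above). The mean is presumably the fraction of valid, non-NaN cells flagged as ice, so it reflects winter conditions rather than a year-round ice frequency.

```python
import math

# Normalization parameters fitted for 'binary_ice_indicator' (printed above)
p = 0.6185053417359287   # mean of the 0/1 indicator
std = 0.4857535218400889

# For a 0/1 variable, the population standard deviation is sqrt(p * (1 - p))
print(math.isclose(math.sqrt(p * (1 - p)), std, rel_tol=1e-6))

# After mean_std normalization the indicator takes just two values
no_ice = (0 - p) / std
ice = (1 - p) / std
print(round(no_ice, 3), round(ice, 3))
```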

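## Step 6: Add Day-of-Year Features

DeepSensor's `construct_circ_time_ds` encodes the day of year as cosine (`cos_D`) and sine (`sin_D`) features, so that dates at the end of one year sit next to dates at the start of the next. We build a daily date range spanning the GLSEA record and add these features to the auxiliary dataset.

In [11]:
dates = pd.date_range(glsea_raw.time.values.min(), glsea_raw.time.values.max(), freq="D")  # every day covered by GLSEA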
dates = pd.to_datetime(dates).normalize()  # This will set all times to 00:00:00

In [12]:
doy_ds = construct_circ_time_ds(dates, freq="D")
aux_ds["cos_D"] = standardize_dates(doy_ds["cos_D"])
aux_ds["sin_D"] = standardize_dates(doy_ds["sin_D"])
aux_ds
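The circular encoding can be sketched with the standard library. This sketch assumes a period of 365.25 days and the convention angle = 2π · doy / period; DeepSensor's `construct_circ_time_ds` may differ in detail. `circ_doy` and `dist` are hypothetical helpers. The point is that 31 December lands close to 1 January on the unit circle, whereas a raw day-of-year number would put them 364 apart.

```python
import math

PERIOD = 365.25  # assumed period in days


def circ_doy(day_of_year, period=PERIOD):
    """Encode day of year as a point on the unit circle: (cos, sin)."""
    angle = 2 * math.pi * day_of_year / period
    return math.cos(angle), math.sin(angle)


def dist(a, b):
    """Euclidean distance between two (cos, sin) points."""
    return math.hypot(a[0] - b[0], a[1] - b[1])


jan1, dec31, jul1 = circ_doy(1), circ_doy(365), circ_doy(182)
print(dist(jan1, dec31) < dist(jan1, jul1))  # True: the year end wraps round to the year start
```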

## Step 7: Task Generation for Model Training

In this section, we use **TaskLoader** to generate tasks. A task consists of context data (the inputs: sea surface temperature, bathymetry, day-of-year features, the binary ice indicator, and the lake mask) and target data (what we want the model to predict, here sea surface temperature).

We generate tasks for training by sampling from the datasets. Each task represents a training example that the model will learn from.
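Each task's context points come from the local helper `generate_random_coordinates(lakemask_raw, 500, data_processor)`, whose source is not shown in this notebook. Below is a hypothetical NumPy stand-in (`sample_lake_points`). It assumes the helper draws distinct water cells from the lake mask and returns a (2, N) array of [lat; lon]. The real helper presumably also normalizes the coordinates with the DataProcessor, which this sketch omits.

```python
import numpy as np


def sample_lake_points(mask, lats, lons, n, seed=None):
    """Draw n distinct grid cells where mask == 1 and return a (2, n) array of [lat; lon]."""
    rng = np.random.default_rng(seed)
    rows, cols = np.nonzero(mask == 1)                   # indices of water cells
    pick = rng.choice(rows.size, size=n, replace=False)  # n distinct water cells
    return np.stack([lats[rows[pick]], lons[cols[pick]]])


lats = np.linspace(41.0, 49.0, 9)
lons = np.linspace(-92.0, -76.0, 17)
mask = np.zeros((lats.size, lons.size), dtype=int)
mask[2:6, 3:12] = 1  # a rectangular "lake" of 36 water cells
pts = sample_lake_points(mask, lats, lons, n=10, seed=0)
print(pts.shape)  # (2, 10)
```

Drawing a new set on every call is what gives each task (and each epoch) different context locations.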


In [15]:
task_loader = TaskLoader(context=[glsea, aux_ds, ice_ds, lakemask_ds], target=glsea)
task_loader

TaskLoader(4 context sets, 1 target sets)
Context variable IDs: (('sst',), ('bathymetry', 'cos_D', 'sin_D'), ('binary_ice_indicator',), ('mask',))
Target variable IDs: (('sst',),)

Context data dimensions: (1, 3, 1, 1)
Target data dimensions: (1,)

In [16]:
from tqdm import tqdm

# Define how Tasks are generated
def gen_tasks(dates, progress=True):
    tasks = []
    for date in tqdm(dates, disable=not progress):
        # Draw a fresh set of 500 random lake locations for this task's context points,
        # then sample the target (SST) at every grid point
        random_lake_points = generate_random_coordinates(lakemask_raw, 500, data_processor)
        task = task_loader(date, context_sampling=random_lake_points, target_sampling="all")
        
        # Remove NaNs from the target data (Y_t) in the task 
        # Target data cannot have NaNs
        task = task.remove_target_nans()
        
        # Append the processed task to the list
        tasks.append(task)
        
    return tasks
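Gridded SST targets contain NaN over land, and `task.remove_target_nans()` drops those points before training. The core idea, shown for a single target set stored as off-grid (2, N) coordinate and (1, N) value arrays, is a boolean mask over target locations. This is an illustration, not DeepSensor's implementation, and `drop_nan_targets` is hypothetical.

```python
import numpy as np


def drop_nan_targets(X_t, Y_t):
    """Keep only target locations whose observed values are all finite.

    X_t: (2, N) target coordinates; Y_t: (n_vars, N) target values.
    """
    keep = ~np.isnan(Y_t).any(axis=0)  # True where no target variable is missing
    return X_t[:, keep], Y_t[:, keep]


X_t = np.array([[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]])
Y_t = np.array([[1.5, np.nan, -0.3, np.nan]])
X_kept, Y_kept = drop_nan_targets(X_t, Y_t)
print(X_kept.shape, Y_kept)
```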

In [17]:
# Generate training and validation tasks
train_dates = pd.date_range(train_range[0], train_range[1])[::date_subsample_factor]
val_dates = pd.date_range(val_range[0], val_range[1])[::date_subsample_factor]

# Standardize the dates so they are datetime64[D] (date only, no time)
train_dates = pd.to_datetime(train_dates).normalize()  # This will set the time to 00:00:00
val_dates = pd.to_datetime(val_dates).normalize()      # This will set the time to 00:00:00

# Generate the tasks
train_tasks = gen_tasks(train_dates)
val_tasks = gen_tasks(val_dates)

  0%|          | 0/25 [00:00<?, ?it/s]


KeyError: "not all values found in index 'time'. Try setting the `method` keyword argument (example: method='nearest')."
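This `KeyError` means at least one requested date is missing from the `time` index of a context or target dataset. A likely culprit is the binary ice indicator: per the printout in Step 3, its record starts on 2006-12-11, while the demo ranges start in 2000. A quick coverage check, sketched below with a hypothetical `missing_dates` helper and made-up date lists, can confirm this before generating tasks.

```python
from datetime import date, timedelta


def missing_dates(requested, available):
    """Return the requested dates that are absent from a dataset's time index."""
    available = set(available)
    return [d for d in requested if d not in available]


# Illustrative record starting on 2006-12-11 (the start date printed for the
# binary ice indicator); a task for 2000 cannot be built from it
ice_start = date(2006, 12, 11)
ice_times = [ice_start + timedelta(days=i) for i in range(400)]
requested = [date(2000, 1, 1), date(2007, 1, 1)]
print(missing_dates(requested, ice_times))
```

In the notebook, the same check could be run on `train_dates` against `pd.DatetimeIndex(ice_ds.time.values)`, since both hold normalized timestamps.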

In [None]:
# Visualizes task at index 10 --> sanity check
train_tasks[10]

In [None]:
fig = deepsensor.plot.task(val_tasks[2], task_loader)
plt.show()
# To Save in batch job
# plt.savefig("your_name.png")
#plt.close()

## Step 8: Model Setup and Training

We now set up the **ConvNP** model, a convolutional neural process model from **DeepSensor**. We pass the **DataProcessor** and **TaskLoader** to the model so that it can infer suitable defaults from the data, such as the number of context and target variables.

The model is then trained for a set number of epochs, and we monitor its performance by tracking the training loss and validation RMSE (Root Mean Squared Error).

At the end of the training loop, we save the best-performing model.
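The validation metric is the root mean squared error, RMSE = sqrt(mean((prediction - truth)^2)). The `compute_val_rmse` function in this section computes it after un-normalizing back to °C. The loop saves the model whenever the validation RMSE improves, which amounts to keeping the epoch with the lowest RMSE. The sketch below shows both ideas with plain Python lists; `rmse` and `best_epoch` are hypothetical helpers, not part of DeepSensor.

```python
import math


def rmse(pred, true):
    """Root mean squared error between two equal-length sequences."""
    if len(pred) != len(true) or not pred:
        raise ValueError("pred and true must be non-empty and the same length")
    return math.sqrt(sum((p - t) ** 2 for p, t in zip(pred, true)) / len(pred))


def best_epoch(val_rmses):
    """Index of the epoch with the lowest validation RMSE (the earliest one on ties)."""
    return min(range(len(val_rmses)), key=lambda i: val_rmses[i])


print(rmse([10.0, 12.0], [13.0, 8.0]))   # sqrt((9 + 16) / 2)
print(best_epoch([1.9, 1.4, 1.6, 1.4]))  # 1
```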


In [None]:
# Set up model
model = ConvNP(data_processor, task_loader)

In [None]:
# Define the Trainer and training loop
trainer = Trainer(model, lr=5e-6)

In [None]:
# Monitor validation performance
def compute_val_rmse(model, val_tasks):
    errors = []
    target_var_ID = task_loader.target_var_IDs[0][0]  # assuming 1st target set and 1D
    for task in val_tasks:
        mean = data_processor.map_array(model.mean(task), target_var_ID, unnorm=True)
        true = data_processor.map_array(task["Y_t"][0], target_var_ID, unnorm=True)
        errors.extend(np.abs(mean - true))
    return np.sqrt(np.mean(np.concatenate(errors) ** 2))

In [None]:
from tqdm import tqdm

losses = []
val_rmses = []
val_rmse_best = np.inf

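n_epochs = 50

# Weights & Biases run (optional); the config records the settings this run actually uses
run = wandb.init(
    project="deepsensor-greatlakes",
    name="wandb_tester",
    config={
        "contexts": "sst, bathymetry, cos_D, sin_D, binary_ice_indicator, mask",
        "sampling": "dynamic-500",
        "years": f"{train_range[0]} to {val_range[1]}",
        "epochs": n_epochs,
    })

for epoch in tqdm(range(n_epochs), desc="Training Epochs"):
    # Regenerate training tasks every epoch so the random context points change
    train_tasks = gen_tasks(train_dates)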
    batch_losses = trainer(train_tasks)
    loss_mean = np.mean(batch_losses)
    losses.append(loss_mean)

    try:
        run.log({"loss": loss_mean, "epoch": epoch})
    except Exception as e:
        print("error logging loss:", e)

    val_rmse = compute_val_rmse(model, val_tasks)
    val_rmses.append(val_rmse)
    run.log({"val_rmse": val_rmse, "epoch": epoch})

    if val_rmse < val_rmse_best:
        val_rmse_best = val_rmse
        save_model(model, model_path)

        try:
            artifact = wandb.Artifact("trained_model", type="model")
            artifact.add_dir(model_path)
            run.log_artifact(artifact)
        except Exception as e:
            print("error logging model artifact:", e)

# Plot training losses and validation RMSE in Weights and Biases
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

axes[0].plot(losses)
axes[0].set_xlabel('Epoch')
axes[0].set_title('Training Loss')

axes[1].plot(val_rmses)
axes[1].set_xlabel('Epoch')
axes[1].set_title('Validation RMSE')

plt.tight_layout()
try:
    wandb.log({"loss_vs_rmse": wandb.Image(fig)})
except Exception as e:
    print("error logging plot:", e)

run.finish()


In [None]:
# To load it later:
# Assuming you have data_processor and task_loader instantiated in your notebook
loaded_model = load_convnp_model(model_path, data_processor, task_loader)
print("Model loaded successfully with custom function!")

## Step 9: Prediction

Now that we have a trained model, we can use it to make a prediction. Notice that we get both a mean and a standard deviation from this prediction.

In [None]:
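# Choose a date covered by every dataset (ideally inside val_range)
date = "2018-02-14"
# Draw random lake context points; `random_lake_points` otherwise only exists inside gen_tasks
random_lake_points = generate_random_coordinates(lakemask_raw, 500, data_processor)
test_task = task_loader(date, context_sampling=random_lake_points, target_sampling="all")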
prediction_ds = loaded_model.predict(test_task, X_t=glsea_raw)
prediction_ds

In [None]:
prediction_ds_masked = apply_mask_to_prediction(prediction_ds['sst'], lakemask_raw)
prediction_ds_masked

Note that the prediction contains both a mean and a standard deviation at every location: ConvNP outputs a Gaussian predictive distribution, so, as with a Gaussian process, each prediction comes with an uncertainty estimate.
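Because each location has a Gaussian predictive mean and standard deviation, an approximate 95% interval is mean ± 1.96 × std, assuming the Gaussian marginals hold. `gaussian_interval` is a hypothetical helper. With xarray objects, the same arithmetic applies elementwise, for example `gaussian_interval(prediction_ds_masked['mean'], prediction_ds_masked['std'])`.

```python
def gaussian_interval(mean, std, z=1.96):
    """Central interval mean +/- z * std; z = 1.96 covers about 95% of a Gaussian."""
    return mean - z * std, mean + z * std


lo, hi = gaussian_interval(4.0, 0.5)
print(lo, hi)
```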

In [None]:
plt.figure(figsize=(15, 6))

plt.subplot(1, 2, 1)
prediction_ds_masked['mean'].plot(cmap='viridis', cbar_kwargs={'label': 'Predicted mean SST (°C)'})
plt.title(f'Masked predicted mean SST, {date}')
plt.xlabel('Longitude')
plt.ylabel('Latitude')

plt.subplot(1, 2, 2)
prediction_ds_masked['std'].plot(cmap='plasma', cbar_kwargs={'label': 'Predicted SST std (°C)'})
plt.title(f'Masked predicted SST standard deviation, {date}')
plt.xlabel('Longitude')
plt.ylabel('Latitude')

plt.tight_layout()
plt.show()
# For Batch Job saving:
# plt.savefig("your_name.png")
#plt.close()

The predictions above look unrealistic, most likely because the model was trained on only two years of data (every 30th day). DeepSensor's models are data hungry, so a much longer training record is needed for sensible maps.

# Conclusion

In this notebook, we:
1. Loaded and preprocessed several Great Lakes datasets for training a **ConvNP** model.
2. Generated tasks using **TaskLoader** and visualized data to perform sanity checks.
3. Trained the **ConvNP** model and monitored its performance.

Next, we will explore the active learning component of **DeepSensor**.
