In [1]:
import xarray as xr
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import fsspec

from deepsensor.data import DataProcessor, TaskLoader
from deepsensor.model import ConvNP
from deepsensor.train import Trainer



In [2]:
# Locations of the datasets in GCS (from the JSON file)
glsea_path = 'gs://great-lakes-osd/zarr_experimental/glsea'
glsea3_path = 'gs://great-lakes-osd/zarr_experimental/glsea3'
bathymetry_path = 'gs://great-lakes-osd/context/interpolated_bathymetry.nc'
lakemask_path = 'gs://great-lakes-osd/context/lakemask.nc'

# Value substituted for missing SST anomalies before normalisation.
NAN_FILL_VALUE = -0.009

# A single GCS filesystem handle, used for every read below so that all of
# them share the same project configuration.
fs = fsspec.filesystem('gcs', project='great-lakes-osd')


def open_remote_netcdf(remote_path):
    """Download a NetCDF file from GCS and open it with the netcdf4 engine.

    The netcdf4 engine reads from a file path, not from an fsspec mapper, so
    the file is first copied into the current working directory under its
    original file name and then opened from there.

    Parameters
    ----------
    remote_path : str
        ``gs://`` URL of the NetCDF file.

    Returns
    -------
    xarray.Dataset
        Dataset backed by the downloaded local copy.
    """
    local_path = remote_path.rsplit('/', 1)[-1]
    fs.get(remote_path, local_path)
    return xr.open_dataset(local_path, engine='netcdf4')


# Open the Zarr datasets for GLSEA and GLSEA3 through the configured filesystem.
# NOTE(review): the recorded run of this cell ended in
# "GroupNotFoundError: group not found at path ''" -- confirm that both paths
# point at the root of a Zarr group and that the credentials in use can read
# the bucket.
glsea = xr.open_zarr(fs.get_mapper(glsea_path), consolidated=False)
glsea3 = xr.open_zarr(fs.get_mapper(glsea3_path), consolidated=False)

# Open the bathymetry and lake mask NetCDF files
bathymetry = open_remote_netcdf(bathymetry_path)
lakemask = open_remote_netcdf(lakemask_path)

# Preprocess the SST data: anomalies are SST minus the day-of-year climatology
sst_by_dayofyear = glsea.sst.groupby('time.dayofyear')
climatology = sst_by_dayofyear.mean('time')

# Missing values are replaced with a small constant
sst_anomalies = (sst_by_dayofyear - climatology).fillna(NAN_FILL_VALUE)

# Process data (e.g., normalize) using DataProcessor
data_processor = DataProcessor(x1_name="lat", x2_name="lon")
sst_anomalies_ds = data_processor(sst_anomalies)
bathymetry_ds = data_processor(bathymetry)
lakemask_ds = data_processor(lakemask)


GroupNotFoundError: group not found at path ''

In [None]:
# Build the TaskLoader: the context sets are the SST anomalies, the lake mask
# and the bathymetry; the target set is the SST anomalies.
task_loader = TaskLoader(
    context=[sst_anomalies_ds, lakemask_ds, bathymetry_ds],
    target=sst_anomalies_ds,
)

# One task per day: training covers 2007-2014, validation covers 2015-2016
train_tasks = [
    task_loader(day)
    for day in pd.date_range('2007-01-01', '2014-12-31', freq='D')
]
val_tasks = [
    task_loader(day)
    for day in pd.date_range('2015-01-01', '2016-12-31', freq='D')
]

# Define the model (ConvNP) from the data processor and task loader
model = ConvNP(data_processor, task_loader)

In [None]:
# Number of passes over the training tasks (example value, kept small)
N_EPOCHS = 10


def compute_val_rmse(model, val_tasks):
    """Return the RMSE of the model's mean prediction over validation tasks.

    Predictions and targets are mapped back to original units with the
    notebook-level ``data_processor`` before the error is computed; the
    variable to un-normalise is looked up on the notebook-level
    ``task_loader``.

    Parameters
    ----------
    model : ConvNP
        Model whose ``mean(task)`` prediction is evaluated.
    val_tasks : list
        Validation tasks.
        # assumes each task holds its target values in task["Y_t"][0] -- TODO confirm

    Returns
    -------
    float
        Root-mean-square error over all target points of all tasks, or NaN
        when ``val_tasks`` is empty.
    """
    # First variable of the first target set (a single 1D target is assumed)
    target_var_ID = task_loader.target_var_IDs[0][0]

    squared_errors = []
    for task in val_tasks:
        pred = data_processor.map_array(model.mean(task), target_var_ID, unnorm=True)
        true = data_processor.map_array(task["Y_t"][0], target_var_ID, unnorm=True)
        squared_errors.append(np.ravel((pred - true) ** 2))

    if not squared_errors:
        return float("nan")
    return np.sqrt(np.mean(np.concatenate(squared_errors)))


# Initialize the trainer
trainer = Trainer(model, lr=5e-5)

# Per-epoch history of mean training loss and validation RMSE
losses = []
val_rmses = []

for epoch in range(N_EPOCHS):
    print(f"Epoch {epoch + 1}")

    # One training pass over all training tasks; keep the mean batch loss
    batch_losses = trainer(train_tasks)
    losses.append(np.mean(batch_losses))

    # Validation RMSE after this epoch
    val_rmse = compute_val_rmse(model, val_tasks)
    val_rmses.append(val_rmse)

    print(f"Training loss: {losses[-1]}")
    print(f"Validation RMSE: {val_rmses[-1]}")

# Plot the training loss and validation RMSE side by side
fig, axes = plt.subplots(1, 2, figsize=(12, 4))
axes[0].plot(losses)
axes[1].plot(val_rmses)
axes[0].set_title("Training Loss")
axes[1].set_title("Validation RMSE")
axes[0].set_xlabel("Epoch")
axes[1].set_xlabel("Epoch")
axes[0].set_ylabel("Loss")
axes[1].set_ylabel("RMSE")
plt.tight_layout()
plt.show()