In [1]:
import os
from pathlib import Path

import xarray as xr
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm

import cartopy.crs as ccrs
import cartopy.feature as cfeature
import dask.array as da
import gcsfs
import torch

import deepsensor.torch  # noqa: F401 (enables torch backend)
from deepsensor.data import DataProcessor, TaskLoader, construct_circ_time_ds
from deepsensor.model import ConvNP
from deepsensor.train import Trainer, set_gpu_default_device

# Local package utilities
from deepsensor_greatlakes.utils import (
    standardize_dates,
    generate_random_coordinates,
    apply_mask_to_prediction,
)
from deepsensor_greatlakes.preprocessor import (
    SeasonalCycleProcessor,
    list_saved_seasonal_cycles,
)
from deepsensor_greatlakes.model import save_model, load_convnp_model

# Ask DeepSensor (helper imported from deepsensor.train) to select its default
# GPU device before any model or tensor is created.
# NOTE(review): if set_gpu_default_device() raises on GPU-less hosts, the "CPU"
# label below can never be printed — confirm against the DeepSensor docs.
set_gpu_default_device()

if torch.cuda.is_available():
    device_label = torch.cuda.get_device_name(0)
else:
    device_label = "CPU"
print("Using device:", device_label)


Using device: Tesla T4


In [2]:
# ===============================================================
# Project / GCS configuration
# ===============================================================
from pathlib import Path
import numpy as np
import torch

# ---------------------------------------------------------------
# Google Cloud project and bucket
# ---------------------------------------------------------------
PROJECT_ID = "great-lakes-osd"
LOCATION   = "us-central1"

# Bare bucket name; the gs:// URL is derived from it on the next line.
BUCKET_NAME = "great-lakes-osd"
BUCKET = "gs://" + BUCKET_NAME

# Core datasets in GCS (all rooted at BUCKET)
BATHYMETRY_PATH        = BUCKET + "/context/gl_bathy_depth_3arcsec.nc"
LAKEMASK_PATH          = BUCKET + "/context/lakemask.nc"
ICE_CONCENTRATION_PATH = BUCKET + "/ice/ice_concentration_glsea3_838x1181_1995-2022.zarr"
GLSEA_PATH             = BUCKET + "/GLSEA_combined.zarr"

# ---------------------------------------------------------------
# Local folders for configs / models
# ---------------------------------------------------------------

# DATAPROCESSOR_DIR holds the DataProcessor normalisation config;
# MODEL_DIR_LOCAL receives the best ConvNP checkpoint.
DATAPROCESSOR_DIR = Path("./deepsensor_config")
MODEL_DIR_LOCAL   = Path("./models/convnp_sst")
for _local_dir in (DATAPROCESSOR_DIR, MODEL_DIR_LOCAL):
    # Created eagerly, so both folders exist from this cell onwards.
    _local_dir.mkdir(parents=True, exist_ok=True)

# (Optional) GCS location for the model folder, for a later sync/upload.
MODEL_DIR_GCS = f"{BUCKET}/models/convnp_sst"

# ---------------------------------------------------------------
# Experiment mode: 'debug', 'scaled', or 'full'
# ---------------------------------------------------------------

MODE = "scaled"   # options: "debug", "scaled", "full"
print(f"Running in MODE = {MODE!r}")

# ---------------------------------------------------------------
# Time ranges + hyperparameters, one preset per mode
# ---------------------------------------------------------------

# Values that are identical in every mode.
_SHARED_PRESET = {
    "LEARNING_RATE": 5e-5,
    "BATCH_SIZE": 1,          # try 2 only if GPU memory is clearly free
    "SEED": 42,
}

_MODE_PRESETS = {
    # Tiny, safe, for smoke tests
    "debug": {
        "DATA_RANGE": ("2009-01-01", "2010-12-31"),
        "TRAIN_RANGE": ("2009-01-01", "2009-12-31"),
        "VAL_RANGE": ("2010-01-01", "2010-12-31"),
        "DATE_SUBSAMPLE_FACTOR": 30,   # ~monthly
        "N_RANDOM_POINTS": 50,         # fewer buoys per task
        "N_EPOCHS": 10,
        "INTERNAL_DENSITY": 200,       # keep ConvNP cheap
    },
    # First serious run on a T4, still conservative.
    # Matches the ice record: 1995–2022
    "scaled": {
        "DATA_RANGE": ("1995-01-01", "2022-12-31"),
        "TRAIN_RANGE": ("1995-01-01", "2016-12-31"),
        "VAL_RANGE": ("2017-01-01", "2022-12-31"),
        "DATE_SUBSAMPLE_FACTOR": 7,    # weekly sampling
        "N_RANDOM_POINTS": 150,        # N context buoys per task
        "N_EPOCHS": 50,
        "INTERNAL_DENSITY": 400,       # bigger than debug, still T4-friendly
    },
    # Ambitious run; may need a bigger GPU or at least monitoring
    "full": {
        "DATA_RANGE": ("1995-01-01", "2022-12-31"),
        "TRAIN_RANGE": ("1995-01-01", "2018-12-31"),
        "VAL_RANGE": ("2019-01-01", "2022-12-31"),
        "DATE_SUBSAMPLE_FACTOR": 3,    # every 3rd day
        "N_RANDOM_POINTS": 200,
        "N_EPOCHS": 75,
        "INTERNAL_DENSITY": 600,       # increase cautiously; main GPU knob
    },
}

if MODE not in _MODE_PRESETS:
    raise ValueError(f"Unknown MODE={MODE!r}")

_preset = {**_SHARED_PRESET, **_MODE_PRESETS[MODE]}

# Bind the selected preset to the module-level names the other cells read.
DATA_RANGE            = _preset["DATA_RANGE"]
TRAIN_RANGE           = _preset["TRAIN_RANGE"]
VAL_RANGE             = _preset["VAL_RANGE"]
DATE_SUBSAMPLE_FACTOR = _preset["DATE_SUBSAMPLE_FACTOR"]

N_RANDOM_POINTS  = _preset["N_RANDOM_POINTS"]
N_EPOCHS         = _preset["N_EPOCHS"]
LEARNING_RATE    = _preset["LEARNING_RATE"]
BATCH_SIZE       = _preset["BATCH_SIZE"]
INTERNAL_DENSITY = _preset["INTERNAL_DENSITY"]
SEED             = _preset["SEED"]

# ---------------------------------------------------------------
# DataProcessor: LOAD an existing one, or BUILD a new one?
# ---------------------------------------------------------------
USE_EXISTING_DATAPROCESSOR = True   # Set to False to rebuild from scratch

# ---------------------------------------------------------------
# Target variable and seeding
# ---------------------------------------------------------------

# Name of the GLSEA variable to predict
TARGET_VAR_NAME = "sst"    # change to "sst_anom" etc. if needed

# Seed numpy's legacy global RNG and torch for reproducibility-ish behaviour
torch.manual_seed(SEED)
np.random.seed(SEED)

print("Config OK.")
_config_summary = (
    f"TRAIN_RANGE={TRAIN_RANGE}, VAL_RANGE={VAL_RANGE}, "
    f"DATE_SUBSAMPLE_FACTOR={DATE_SUBSAMPLE_FACTOR}, "
    f"N_RANDOM_POINTS={N_RANDOM_POINTS}, INTERNAL_DENSITY={INTERNAL_DENSITY}"
)
print(_config_summary)


Running in MODE = 'scaled'
Config OK.


In [3]:
# ===============================================================
# Load temporal datasets (ice concentration, GLSEA)
# ===============================================================

print("Opening Zarr datasets from GCS...")
ice_concentration_raw = xr.open_zarr(ICE_CONCENTRATION_PATH)
glsea_raw = xr.open_zarr(GLSEA_PATH)

# Both records are trimmed to the configured window; a label-based slice on
# `time` includes both end points.
t0, t1 = DATA_RANGE
data_window = slice(t0, t1)

# --- Ice concentration ---
# -1 is the land sentinel in the raw product: blank it out as NaN, then put the
# time axis through the project helper and trim to the window.
ice_is_valid = ice_concentration_raw != -1
ice_concentration_raw = ice_concentration_raw.where(ice_is_valid, float("nan"))
ice_concentration_raw = standardize_dates(ice_concentration_raw).sel(time=data_window)

# --- GLSEA SST ---
# The CRS variable is not needed downstream; drop it when the store has one.
if "crs" in glsea_raw:
    glsea_raw = glsea_raw.drop_vars("crs")
glsea_raw = standardize_dates(glsea_raw).sel(time=data_window)

print("ice_concentration:", ice_concentration_raw)
print("glsea:", glsea_raw)


Opening Zarr datasets from GCS...


NameError: name 'ICE_CONCENTRATION_PATH' is not defined

In [None]:
# ===============================================================
# Load static datasets (bathymetry + lake mask) from GCS via gcsfs
# ===============================================================

print("Opening bathymetry + lake mask from GCS using gcsfs...")

fs = gcsfs.GCSFileSystem()

# Derive the object paths from the config constants instead of re-typing them,
# so editing BATHYMETRY_PATH / LAKEMASK_PATH in the config cell takes effect here.
# The "gs://" scheme is stripped to keep the "bucket/path" form this cell has
# always handed to gcsfs.
bathy_path_gcsfs = BATHYMETRY_PATH.split("://", 1)[-1]
mask_path_gcsfs  = LAKEMASK_PATH.split("://", 1)[-1]

# IMPORTANT: Do NOT close these file objects — xarray reads from them lazily,
# so they must stay open for as long as the datasets are in use.
bathy_fobj = fs.open(bathy_path_gcsfs, "rb")
mask_fobj  = fs.open(mask_path_gcsfs,  "rb")

bathymetry_raw = xr.open_dataset(bathy_fobj)
lakemask_raw   = xr.open_dataset(mask_fobj)

# ===============================================================
# Clean up bathymetry + lake mask datasets
# ===============================================================

# --- Clean bathymetry_raw ---
# Drop the CRS variable if present; it is not used downstream.
if "crs" in bathymetry_raw:
    bathymetry_raw = bathymetry_raw.drop_vars("crs")

# Rename Band1 -> depth so later cells can refer to a meaningful variable name.
if "Band1" in bathymetry_raw:
    bathymetry_raw = bathymetry_raw.rename({"Band1": "depth"})

print("bathymetry_raw:", bathymetry_raw)
print("lakemask_raw:", lakemask_raw)


In [None]:
# ===============================================================
# Construct circular day-of-year features
# ===============================================================

from deepsensor.data import construct_circ_time_ds

# Build a date coordinate covering the full GLSEA dataset
# Daily calendar running from the first to the last GLSEA timestamp.
glsea_times = glsea_raw.time.values
dates = pd.date_range(start=glsea_times.min(), end=glsea_times.max(), freq="D")

# DeepSensor helper that builds the circular (cos/sin) day features; the result
# carries its own time index matching `dates`.
doy_ds = construct_circ_time_ds(dates, freq="D")

# Put both features through the same date standardisation as the other datasets
# so their time coordinates line up.
cos_D, sin_D = (standardize_dates(doy_ds[name]) for name in ("cos_D", "sin_D"))

print("Circular time features constructed:")
print("cos_D:", cos_D)
print("sin_D:", sin_D)


In [None]:
# ===============================================================
# DataProcessor Setup: Register all variables with correct scaling
# ===============================================================
#
# DeepSensor's DataProcessor learns normalization the FIRST time it
# sees each dataset. We therefore either:
#   - LOAD an existing, saved DataProcessor (preferred), or
#   - BUILD a new one and save it for future runs.
#
# Scaling strategy:
#   - SST (glsea) ....................... mean_std   (Gaussian-ish field)
#   - Ice Concentration ................. min_max    (bounded 0–1)
#   - Bathymetry ........................ min_max    (static absolute field)
#   - Lake Mask ......................... min_max    (binary 0/1, preserved)
#
# After registration, we process the full datasets with consistent scaling.
# ===============================================================

print("=== DataProcessor setup ===")

# 0. Attempt to LOAD an existing DataProcessor, or BUILD a new one.
#    DATAPROCESSOR_DIR is created unconditionally by the config cell, so the
#    directory existing says nothing about whether a processor was ever saved
#    there. Look for the saved config file itself instead.
#    NOTE(review): the file name comes from `DataProcessor.config_fname`, with
#    DeepSensor's documented default as a fallback — confirm for your version.
_dp_config_name = getattr(DataProcessor, "config_fname", "data_processor_config.json")
_dp_config_path = Path(DATAPROCESSOR_DIR) / _dp_config_name

USING_LOADED_PROCESSOR = bool(USE_EXISTING_DATAPROCESSOR and _dp_config_path.is_file())

if USING_LOADED_PROCESSOR:
    print(f"Loading existing DataProcessor from: {DATAPROCESSOR_DIR}")
    data_processor = DataProcessor(DATAPROCESSOR_DIR)

else:
    if USE_EXISTING_DATAPROCESSOR:
        # Requested a load, but nothing has been saved yet (e.g. first run).
        print(f"No saved DataProcessor config found at {_dp_config_path}; building instead.")

    print("Building a NEW DataProcessor...")
    data_processor = DataProcessor(x1_name="lat", x2_name="lon")

    # The first call on each dataset fixes its scaling method; the outputs are
    # discarded here because every dataset is processed again below.

    # 1. Register SST (glsea) with MEAN-STD scaling
    print("  Registering SST (glsea) with mean_std scaling...")
    _ = data_processor(glsea_raw, method="mean_std")

    # 2. Register Ice Concentration with MIN-MAX scaling
    print("  Registering Ice Concentration with min_max scaling...")
    _ = data_processor(ice_concentration_raw, method="min_max")

    # 3. Register Bathymetry with MIN-MAX scaling
    print("  Registering bathymetry with min_max scaling...")
    _ = data_processor(bathymetry_raw, method="min_max")

    # 4. Register Lake Mask with MIN-MAX scaling
    print("  Registering lake mask with min_max scaling...")
    _ = data_processor(lakemask_raw, method="min_max")

    print("\nRegistration complete.")
    print("Registered variables:")
    for name, cfg in data_processor.config.items():
        print(f"  - {name}: {cfg}")

    # Save the learned scaling once so later runs can load it
    data_processor.save(DATAPROCESSOR_DIR)
    print(f"\nSaved new DataProcessor to: {DATAPROCESSOR_DIR}")

if USING_LOADED_PROCESSOR:
    # NOTE(review): loaded parameters were fitted on whatever data range was
    # active when the processor was saved; rebuild if DATA_RANGE has changed.
    print("Using previously-saved normalization parameters.")

# ---------------------------------------------------------------
# Process all datasets with learned scaling
# ---------------------------------------------------------------
print("\nApplying DataProcessor scaling to all datasets...")

glsea             = data_processor(glsea_raw)
ice_concentration = data_processor(ice_concentration_raw)
bathymetry        = data_processor(bathymetry_raw)
lakemask          = data_processor(lakemask_raw)

print("\nDataProcessor processing complete.")
print(f"  glsea processed vars:             {list(glsea.data_vars)}")
print(f"  ice_concentration processed vars: {list(ice_concentration.data_vars)}")
print(f"  bathymetry processed vars:        {list(bathymetry.data_vars)}")
print(f"  lakemask processed vars:          {list(lakemask.data_vars)}")


In [None]:
# ===============================================================
# TaskLoader setup and task generation
# ===============================================================
#
# We:
#   1. Build an auxiliary context dataset (static + time features)
#   2. Instantiate a TaskLoader with:
#        - context = [glsea, ice_concentration, bathymetry, aux_ds]
#        - target  = glsea (SST field)
#   3. Define a gen_tasks(...) helper to:
#        - sample N random lake points per date
#        - build tasks
#        - remove NaNs from targets
#   4. Generate train/val tasks
# ===============================================================

# 1. Auxiliary context dataset: the processed lake mask plus the circular
#    day-of-year pair. Bathymetry ("depth") is not folded in here; it is passed
#    to the TaskLoader as its own context set below.
aux_vars = {
    "mask": lakemask["mask"],
    "cos_D": cos_D,
    "sin_D": sin_D,
}
aux_ds = xr.Dataset(data_vars=aux_vars)

print("aux_ds:", aux_ds)

# 2. TaskLoader: four context sets, with the SST field doubling as the target.
context_sets = [glsea, ice_concentration, bathymetry, aux_ds]
task_loader = TaskLoader(context=context_sets, target=glsea)

print("\nTaskLoader initialised:")
print(task_loader)


# 3. Task generation helper
def gen_tasks(dates, N=100, progress=True, lakemask_raw=None, data_processor=None):
    """
    Build one DeepSensor task per date, each with its own random context points.

    Relies on the module-level `task_loader` being defined before the call.

    Parameters
    ----------
    dates : iterable
        Dates to build tasks for; each one is passed straight to `task_loader`.
    N : int
        Number of random points requested from `generate_random_coordinates`
        for every task.
    progress : bool
        Show a tqdm progress bar when True.
    lakemask_raw : xr.Dataset
        Raw lake mask handed to `generate_random_coordinates`. Required.
    data_processor : deepsensor.data.DataProcessor
        Processor handed to `generate_random_coordinates`. Required.

    Returns
    -------
    list
        One task per input date, in input order, with target NaNs removed.

    Raises
    ------
    ValueError
        If `lakemask_raw` or `data_processor` is not supplied.
    """
    missing_inputs = lakemask_raw is None or data_processor is None
    if missing_inputs:
        raise ValueError("You must pass both `lakemask_raw` and `data_processor`.")

    def _task_for(date):
        # Fresh random context locations for every date.
        context_xy = generate_random_coordinates(lakemask_raw, N, data_processor)
        # Context at the random points; target "all" -> every target location.
        sampled = task_loader(date, context_sampling=context_xy, target_sampling="all")
        # Drop NaN targets (e.g. land, missing data).
        return sampled.remove_target_nans()

    date_iter = tqdm(dates, disable=not progress, desc="Generating tasks")
    return [_task_for(date) for date in date_iter]


# 4. Train/validation date lists.
#    Every DATE_SUBSAMPLE_FACTOR-th day of each configured range is kept, then
#    normalised to midnight so the dates carry no time-of-day component.
#    (TRAIN_RANGE, VAL_RANGE and DATE_SUBSAMPLE_FACTOR come from the config cell.)
train_days_all = pd.date_range(start=TRAIN_RANGE[0], end=TRAIN_RANGE[1])
val_days_all   = pd.date_range(start=VAL_RANGE[0], end=VAL_RANGE[1])

train_dates = pd.to_datetime(train_days_all[::DATE_SUBSAMPLE_FACTOR]).normalize()
val_dates   = pd.to_datetime(val_days_all[::DATE_SUBSAMPLE_FACTOR]).normalize()

print(f"\nNumber of training dates:   {len(train_dates)}")
print(f"Number of validation dates: {len(val_dates)}")


# 5. Generate training and validation tasks.
#    The number of context points comes from the config (N_RANDOM_POINTS) rather
#    than a hard-coded 100, so training/validation tasks match both the config
#    printout and the single-date prediction cell, which samples N_RANDOM_POINTS.
train_tasks = gen_tasks(
    train_dates,
    N=N_RANDOM_POINTS,  # number of random lake points per task
    lakemask_raw=lakemask_raw,
    data_processor=data_processor,
)

val_tasks = gen_tasks(
    val_dates,
    N=N_RANDOM_POINTS,
    lakemask_raw=lakemask_raw,
    data_processor=data_processor,
)

print(f"\nGenerated {len(train_tasks)} training tasks and {len(val_tasks)} validation tasks.")


# 6. Quick sanity check plot of one task (the third training task)
fig = deepsensor.plot.task(train_tasks[2], task_loader)
plt.show()


In [None]:
# ---------------------------------------------------------------
# 1. Device + model + trainer
# ---------------------------------------------------------------

print("Initialising ConvNP model and Trainer...")

# INTERNAL_DENSITY is the mode-dependent knob set in the config cell.
model = ConvNP(data_processor, task_loader, internal_density=INTERNAL_DENSITY)

# The Trainer is given only the model and the learning rate.
trainer = Trainer(model, lr=LEARNING_RATE)

# ID of the first variable in the first target set; used later to map values
# back to physical SST units.
first_target_set = task_loader.target_var_IDs[0]
target_var_ID = first_target_set[0]

print("Model and Trainer ready.")
print(f"Training for {N_EPOCHS} epochs with lr={LEARNING_RATE:g}, batch_size={BATCH_SIZE}")

# ---------------------------------------------------------------
# 2. Helper: validation RMSE in physical SST units
# ---------------------------------------------------------------

def compute_val_rmse(model, val_tasks, task_loader, data_processor, target_var_ID):
    """
    Compute the RMSE over all validation tasks in un-normalised target units.

    Squared errors are folded into a running float64 sum instead of being kept
    for every target point of every task, so peak memory stays at one task's
    worth of errors regardless of how many validation tasks there are.

    Parameters
    ----------
    model : object
        Must provide `mean(task)` returning the normalised mean prediction.
    val_tasks : iterable
        Tasks to evaluate; `task["Y_t"][0]` is used as the normalised truth.
    task_loader : object
        Unused; kept so existing call sites keep working.
    data_processor : object
        Must provide `map_array(array, var_ID, unnorm=True)`.
    target_var_ID : str
        Variable ID passed to `map_array` for both prediction and truth.

    Returns
    -------
    float
        Root-mean-square error pooled over every target value of every task
        (NaN if the tasks contain no target values at all).

    Raises
    ------
    ValueError
        If `val_tasks` is empty.
    """
    sum_sq_err = 0.0
    n_values = 0
    n_tasks = 0

    for task in val_tasks:
        n_tasks += 1
        with torch.no_grad():
            # Mean prediction in normalised space
            mean_norm = model.mean(task)

            # Map prediction and truth back to physical units
            mean = data_processor.map_array(mean_norm, target_var_ID, unnorm=True)
            true = data_processor.map_array(task["Y_t"][0], target_var_ID, unnorm=True)

        # Accumulate in float64 so low-precision inputs do not lose accuracy
        # when summed across many tasks.
        sq_err = np.square(np.asarray(mean - true, dtype=np.float64))
        sum_sq_err += float(sq_err.sum())
        n_values += sq_err.size

        # Release cached GPU blocks between tasks to keep peak VRAM low.
        torch.cuda.empty_cache()

    if n_tasks == 0:
        raise ValueError("compute_val_rmse needs at least one validation task; got none.")
    if n_values == 0:
        # Tasks were supplied but none had any target values.
        return float("nan")

    return float(np.sqrt(sum_sq_err / n_values))

# ---------------------------------------------------------------
# 3. Training loop (now using Trainer.__call__ with batch_size)
# ---------------------------------------------------------------

# Per-epoch history, plus the best validation score seen so far.
losses        = []
val_rmses     = []
best_val_rmse = np.inf

print("\n=== Starting training loop ===")

for epoch in tqdm(range(1, N_EPOCHS + 1), desc="Training epochs"):
    # IMPORTANT: pass batch_size here to avoid OOM
    batch_losses = trainer(
        train_tasks,
        batch_size=BATCH_SIZE,      # this is the DeepSensor-supported batching
        progress_bar=False,
        tqdm_notebook=False,
    )
    # One number per epoch: the mean of the per-batch losses.
    epoch_loss = float(np.mean(batch_losses))
    losses.append(epoch_loss)

    # Validation RMSE in physical SST units
    val_rmse = compute_val_rmse(
        model,
        val_tasks,
        task_loader,
        data_processor,
        target_var_ID,
    )
    val_rmses.append(val_rmse)

    print(f"Epoch {epoch:03d} | loss = {epoch_loss:.4f} | val RMSE = {val_rmse:.4f}")

    # Checkpoint only on strict improvement, so MODEL_DIR_LOCAL always holds the
    # best model so far. (A NaN RMSE never compares as smaller, so it never saves.)
    if val_rmse < best_val_rmse:
        best_val_rmse = val_rmse
        save_model(model, MODEL_DIR_LOCAL)
        print(f"  ↳ New best model saved to {MODEL_DIR_LOCAL} (RMSE = {best_val_rmse:.4f})")

    # Release cached GPU blocks between epochs.
    torch.cuda.empty_cache()

print("\nTraining complete.")
print(f"Best validation RMSE: {best_val_rmse:.4f}")

# ---------------------------------------------------------------
# 4. Plot training diagnostics
# ---------------------------------------------------------------

# Plot against 1-based epoch numbers so the x-axis agrees with the
# "Epoch 001 ..." lines printed by the training loop (a bare plot(losses)
# would start the axis at 0).
epoch_numbers = np.arange(1, len(losses) + 1)

fig, axes = plt.subplots(1, 2, figsize=(12, 4))

diagnostic_panels = (
    (axes[0], losses, "Training loss"),
    (axes[1], val_rmses, "Validation RMSE"),
)
for panel_ax, history, label in diagnostic_panels:
    panel_ax.plot(epoch_numbers, history)
    panel_ax.set_xlabel("Epoch")
    panel_ax.set_ylabel(label)
    panel_ax.set_title(label)

plt.tight_layout()
plt.show()


In [None]:
# ===============================================================
# Single prediction on one validation date
# ===============================================================

from deepsensor_greatlakes.model import load_convnp_model

# 1. Round-trip the best checkpoint through disk (confirms save/load works)
print(f"Loading best model from {MODEL_DIR_LOCAL} ...")
best_model = load_convnp_model(MODEL_DIR_LOCAL, data_processor, task_loader)
print("Model loaded.")

# 2. Prediction date: the first validation date
PRED_DATE = val_dates[0]
print(f"Making prediction for date: {PRED_DATE}")

# 3. Random context locations for this date (N_RANDOM_POINTS of them)
context_points = generate_random_coordinates(lakemask_raw, N_RANDOM_POINTS, data_processor)

# 4. Task for this date: context at the sampled points, all targets
sampling_spec = {"context_sampling": context_points, "target_sampling": "all"}
test_task = task_loader(PRED_DATE, **sampling_spec)

# 5. Predict. X_t specifies where output is wanted; here the raw GLSEA dataset
#    is passed so predictions land on its grid.
pred_ds = best_model.predict(test_task, X_t=glsea_raw)

print("Prediction dataset:")
print(pred_ds)


In [None]:
# ===============================================================
# Mask prediction to lake and plot mean + std
# ===============================================================

# Pull the target variable's entry out of the prediction object
sst_pred = pred_ds[TARGET_VAR_NAME]   # "sst"

# Restrict the prediction to the lake using the project helper
sst_pred_masked = apply_mask_to_prediction(sst_pred, lakemask_raw)

# Date string shared by both panel titles
pred_day = np.array(PRED_DATE).astype('datetime64[D]')

# (field, colormap, colour-bar label, title heading) for each panel.
# The colour-bar labels use a real degree sign; the previous text rendered as
# mojibake ("Â°C") on the figure.
pred_panels = (
    ("mean", "viridis", "Predicted SST (°C)", "Predicted mean SST"),
    ("std", "plasma", "Predictive std (°C)", "Predictive uncertainty (std)"),
)

# Explicit figure/axes so each field is drawn on a known panel
pred_fig, pred_axes = plt.subplots(1, 2, figsize=(15, 6))

for pred_ax, (field, cmap, cbar_label, heading) in zip(pred_axes, pred_panels):
    sst_pred_masked[field].plot(
        ax=pred_ax,
        cmap=cmap,
        cbar_kwargs={"label": cbar_label},
    )
    pred_ax.set_title(f"{heading}\n{pred_day}")
    pred_ax.set_xlabel("Longitude")
    pred_ax.set_ylabel("Latitude")

pred_fig.tight_layout()
plt.show()
