In [1]:
# Automatically re-import edited modules (e.g. the `aibedo` package) before each cell is executed
%load_ext autoreload
%autoreload 2

# Reload a model from the wandb cloud and use it to predict on arbitrary data

In [2]:
import os

# If the kernel was started inside the `notebooks/` or `examples/` folder, move one level up
# (presumably the repository root) so that relative paths resolve as expected.
if os.path.basename(os.getcwd()) in {"notebooks", "examples"}:
    os.chdir(os.pardir)

# Data requirements:
Please make sure the following files are in the folder that DATA_DIR points to (in addition to the input/output data files):
- `ymonmean.1980_2010.compress.isosph5.CMIP6.historical.ensmean.Output.PrecipCon.nc`
- `ymonstd.1980_2010.compress.isosph5.CMIP6.historical.ensmean.Output.PrecipCon.nc`

In [3]:
# Directory holding the data used for prediction, as well as the CMIP6 mean/std statistics.
# Set the AIBEDO_DATA_DIR environment variable to point elsewhere without editing this cell.
DATA_DIR = os.environ.get("AIBEDO_DATA_DIR", "C:/Users/salva/PycharmProjects/Data/aibedo")
# Input data filename (isosph is an order 6 icosahedron, isosph5 of order 5, etc.)
filename_input = "isosph5.denorm_nonorm.CESM2.historical.r1i1p1f1.Input.Exp8.nc"
# Output data filename is inferred from the input filename, do not edit!
# E.g., for the input above: "isosph5.denorm_nonorm.CESM2.historical.r1i1p1f1.Output.nc"
# Fail early: without the expected suffix, str.replace would silently return the *input*
# filename, and the input data would then be loaded as "ground truth".
assert "Input.Exp8.nc" in filename_input, (
    f"Expected an input filename containing 'Input.Exp8.nc', got {filename_input!r}"
)
filename_output = filename_input.replace("Input.Exp8.nc", "Output.nc")

In [4]:
import xarray as xr
import numpy as np
from typing import *
import wandb
import torch
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
from aibedo.models import BaseModel
from aibedo.utilities.wandb_api import reload_checkpoint_from_wandb, get_run_ids_for_hyperparams

## First, reload the model

In [5]:
# Run on the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# key=value config overrides applied when reloading the checkpoint from wandb
overrides = [
    f"datamodule.data_dir={DATA_DIR}",
    "++model.use_auxiliary_vars=False",
]

### Use the following cell to get the wandb run ID corresponding to a (hyper-)parameter combination

If multiple run IDs are returned, the search was not narrow enough (or the matches are runs of the same configuration with different random seeds).

In [6]:
# Valid values for 'datamodule/esm_for_training' (the ESM(s) the model was trained on):
#   - 5 ESMs: ["MRI-ESM2-0", "CESM2", "GFDL-ESM4", "MPI-ESM1-2-LR", "CESM2-WACCM"]
#   - 3 ESMs: ["MRI-ESM2-0", "CESM2", "GFDL-ESM4"]
#   - 1 ESM:  "CESM2"

example_hyperparams = {}
example_hyperparams['datamodule/time_lag'] = 3  # one of 0, 1, 2, 3, 4
example_hyperparams['datamodule/esm_for_training'] = ["MRI-ESM2-0", "CESM2", "GFDL-ESM4"]  # or "CESM2" (a string, *not* a list)
example_hyperparams['model/name'] = 'MLP'  # or 'FNO'
example_hyperparams['datamodule/output_vars'] = ['tas_nonorm', 'ps_nonorm', 'pr_nonorm']  # or the same names with '_pre'
example_hyperparams['model/window'] = 1  # keep this value
run_ids = get_run_ids_for_hyperparams(example_hyperparams)
run_ids

['1kxicry8']

### Reload model(s):

In [None]:
# wandb run ID returned by the hyperparameter search above
run_id_mlp_lag3_nonorm1 = "1kxicry8"
reloaded_mlp = reload_checkpoint_from_wandb(
    run_id_mlp_lag3_nonorm1,
    override_key_value=overrides,
    try_local_recovery=False,
)["model"]

# Pre-process the data to be used for the ML model

In [None]:
def concat_variables_into_channel_dim(data: "xr.Dataset", variables: List[str]) -> np.ndarray:
    """Concatenate xarray variables into a new last (channel) dimension of a numpy array.

    Args:
        data: dataset from which each variable is read (via ``data[var].values``).
        variables: names of the variables to stack; their order defines the channel order.

    Returns:
        A float32 array of shape ``(*var_shape, len(variables))``, where ``var_shape`` is the
        two-dimensional shape shared by all variables.

    Raises:
        AssertionError: if ``variables`` is empty or any variable is not two-dimensional.
    """
    assert len(variables) > 0, "At least one input variable is required"
    arrays = [np.asarray(data[var].values) for var in variables]
    # Validate every variable (not only the first) so a bad one gets a clear, named error
    for var, arr in zip(variables, arrays):
        assert arr.ndim == 2, (
            f"Each input data variable must have two dimensions, but '{var}' has shape {arr.shape}"
        )
    # Stacking along a new last axis == concatenating arrays that were each expanded with a trailing axis
    return np.stack(arrays, axis=-1).astype(np.float32)

def get_month_of_output_data(output_xarray: "xr.Dataset") -> np.ndarray:
    """Get the 0-indexed month (0-11) of the single output snapshot, repeated for every grid cell.

    Args:
        output_xarray: dataset with an ``ncells`` entry (its length is the number of grid cells)
            and a ``time`` coordinate holding exactly one timestep.

    Returns:
        A float32 array of shape ``(1, n_gridcells, 1)``: a batch dimension, one value per grid
        cell, and a dummy channel/feature dimension.
    """
    n_gridcells = len(output_xarray['ncells'])
    # `.dt.month` works for both numpy datetime64 and cftime time coordinates. Calling `.month` on
    # `.item()` breaks for datetime64[ns], whose `.item()` is a plain int.
    # `.item()` also fails loudly unless exactly one timestep is present.
    # Subtract 1 because we want 0-indexed months.
    month_of_snapshot = int(output_xarray.coords['time'].dt.month.item()) - 1
    return np.full((1, n_gridcells, 1), month_of_snapshot, dtype=np.float32)

In [None]:
def get_pytorch_model_data(input_xarray: xr.Dataset, output_xarray: xr.Dataset, input_vars: List[str]) -> torch.Tensor:
    """Build the model-ready input tensor from the raw xarray datasets.

    The last (channel/feature) dimension holds ``input_vars`` in the given order, followed by one
    extra channel with the 0-indexed month of the output snapshot. That month channel is *not*
    used by the model itself; it is carried along so the predictions can be denormalized back
    to their original scale.

    Returns:
        A float32 tensor moved to the notebook-level ``device`` (CPU or GPU).
    """
    features = concat_variables_into_channel_dim(input_xarray, input_vars)
    month_channel = get_month_of_output_data(output_xarray)
    features_with_month = np.concatenate((features, month_channel), axis=-1)
    return torch.from_numpy(features_with_month).float().to(device)

In [None]:
def predict_with_aibedo_model(aibedo_model: "BaseModel", input_tensor: torch.Tensor) -> Dict[str, torch.Tensor]:
    """
    Predict with the AiBEDO model.

    Note: this switches ``aibedo_model`` to evaluation mode (``aibedo_model.eval()``) before predicting.

    Args:
        aibedo_model: the (reloaded) AiBEDO model to predict with.
        input_tensor: the model input, e.g. as built by ``get_pytorch_model_data``.

    Returns:
        A dictionary of output-variable -> prediction-tensor key->value pairs for each variable {var}.
        Keys with name {var} (e.g. 'pr') are in denormalized scale. Keys with name {var}_pre or {var}_nonorm are raw predictions of the ML model.
        To only get the raw predictions, please use aibedo_model.raw_predict(input_tensor)
    """
    # Put the *given* model (not a notebook global) into eval mode; for torch modules this
    # disables dropout and makes batch-norm use its running statistics.
    aibedo_model.eval()
    with torch.no_grad():  # No need to track the gradients during inference
        # return_normalized_outputs=True also returns {var}_nonorm (or {var}_pre)
        prediction = aibedo_model.predict(input_tensor, return_normalized_outputs=True)
    return prediction

# Prediction code
#### Select below which model to use for prediction:

In [None]:
# Select which model to use for prediction: the cells below read `model`, so reassign it here to switch models
model = reloaded_mlp

### Load the actual data and process it

In [None]:
# Open the input data and the matching ground-truth data (in that order), then build the model input tensor
ds_input, ds_output = (
    xr.open_dataset(f"{DATA_DIR}/{fname}")
    for fname in (filename_input, filename_output)
)
input_ml = get_pytorch_model_data(
    ds_input, ds_output, input_vars=model.main_input_vars
)

### Get AiBEDO predictions

In [None]:
# Run the selected model on the prepared input tensor and show which prediction keys are returned
predictions_ml = predict_with_aibedo_model(aibedo_model=model, input_tensor=input_ml)
predictions_ml.keys()