In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os
# Make sure we're in the right directory: when the notebook is started from the
# `notebooks/` folder, step one level up so that the relative paths used below
# (e.g. DATA_DIR) are resolved from the parent folder instead.
if os.path.basename(os.getcwd()) == "notebooks":
    os.chdir("..")
import torch
import xarray as xr
import numpy as np
from typing import *  # NOTE(review): wildcard import; this notebook appears to use only List and Dict — consider explicit imports
import cartopy.crs as ccrs
import matplotlib.pyplot as plt
from aibedo.models.MLP import AIBEDO_MLP
from aibedo.models.base_model import BaseModel


# Paths and filenames
## ***Please edit the cell below to use your own paths and desired filenames***

In [None]:
# Define the path where the data and model checkpoint is stored
DATA_DIR = "../../data"
# Input data filename (isosph is an order 6 icosahedron, isosph5 of order 5, etc.)
filename_input = "compress.isosph.CESM2.historical.r1i1p1f1.Input.Exp8_fixed.nc"
# Output data filename (inferred from the input filename), do not edit!
# E.g.: "compress.isosph.CESM2.historical.r1i1p1f1.Output.PrecipCon.nc"
filename_output = filename_input.replace("Input.Exp8_fixed", "Output.PrecipCon")
# str.replace is a silent no-op when the pattern is missing, which would make the
# "ground truth" file identical to the input file -- fail loudly instead.
assert filename_output != filename_input, (
    f"Cannot infer the output filename: '{filename_input}' does not contain 'Input.Exp8_fixed'"
)
# Time index of the snapshot used as model input (standard Python indexing):
# -1 is the last timestep, so -10 is the 10th-from-last one (9 timesteps before the last).
prediction_timestep = -10
# Filename (inside DATA_DIR) of the ML model checkpoint to be reloaded
CKPT = 'epoch023_seed15.ckpt'

### Some constants that we will use later on *(do not edit)*

In [None]:
# Suffix `_pre` marks variables that were pre-processed (e.g. deseasonalized, detrended).
# The order of this list is the channel order of the model input.
VARS_INPUT = [
    'crelSurf_pre', 'crel_pre', 'cresSurf_pre', 'cres_pre',
    'netTOAcs_pre', 'lsMask', 'netSurfcs_pre',
]
# Variables predicted by the model (one plot column each).
VARS_OUTPUT = ['tas', 'ps', 'pr']
# Human-readable names, used as plot titles.
output_var_clean_name = dict(
    tas='Air Temperature',
    ps="Surface Pressure",
    pr="Precipitation",
)

# (Re-)Loading
### Load the pre-processed data

In [None]:
# Open the pre-processed predictors and the matching ground-truth targets
ds_input, ds_output = (
    xr.open_dataset(f"{DATA_DIR}/{fname}") for fname in (filename_input, filename_output)
)
# Expected layout: (time, pixel-in-icosahedron)
ds_input['crel_pre'].values.shape

### Load the model

In [None]:
# Get the appropriate device (GPU or CPU) to use
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Load the trained model checkpoint (its weights, hyperparameters, etc.).
# NOTE: torch.load unpickles the file, which can execute arbitrary code -- only load trusted checkpoints.
saved_model = torch.load(f"{DATA_DIR}/{CKPT}", map_location=device)
saved_model['hyper_parameters']['datamodule_config']['data_dir'] = DATA_DIR   # Update the data directory
# Get the appropriate architecture to use based on the hyperparameters.
# The override is merged into the dict (rather than passed as a separate keyword) so that
# a checkpoint which already stores `use_auxiliary_vars` does not raise
# "got multiple values for keyword argument".
model = AIBEDO_MLP(**{**saved_model['hyper_parameters'], 'use_auxiliary_vars': False})
# The input tensors are moved to `device` later on, so the parameters must live there too.
# load_state_dict copies weights into the existing parameters, so moving before reloading is fine.
model = model.to(device)
saved_model['hyper_parameters']

### Reload the checkpoint (model weights)

In [None]:
model.load_state_dict(saved_model['state_dict'], strict=True)  # strict: every checkpoint key must match a model parameter

### Reload the checkpoint (model weights) — duplicate of the cell above, not needed

In [None]:
# Duplicate cell: the checkpoint weights are already loaded by the previous cell,
# so reloading them here is redundant -- this cell intentionally does nothing.

### Reload the checkpoint again — duplicate cell, not needed (pre-processing starts in the next section)

In [None]:
# Duplicate cell: the checkpoint weights are already loaded above, so reloading them
# again is redundant -- this cell intentionally does nothing.

# Pre-process the data to be used for the ML model

In [None]:
def concat_variables_into_channel_dim(data: xr.Dataset, variables: List[str]) -> np.ndarray:
    """Concatenate xarray variables into numpy channel dimension (last).

    Args:
        data: Dataset (e.g. a single time snapshot) containing every name in ``variables``.
            Each of those variables must be one-dimensional.
        variables: Names of the variables to stack, in channel order.

    Returns:
        float32 array of shape (1, n_points, len(variables)): a leading batch dimension,
        one row per point, one channel per variable.
    """
    assert len(variables) > 0, "At least one input variable is required"
    for var in variables:
        # Check *every* variable: a multi-dimensional one would otherwise be silently flattened.
        assert data[var].ndim == 1, (
            f"Each input data variable must have only one dimension, but '{var}' has shape {data[var].shape}"
        )
    data_ml = np.stack([data[var].values for var in variables], axis=-1)  # (n_points, n_vars)
    return np.expand_dims(data_ml, axis=0).astype(np.float32)  # Add a batch dimension

def get_month_of_output_data(output_xarray: xr.Dataset) -> np.ndarray:
    """Get the 0-indexed month (0-11) of a single-timestep snapshot, repeated for every grid cell.

    Args:
        output_xarray: Dataset holding exactly one time step (scalar ``time`` coordinate)
            and an ``ncells`` dimension.

    Returns:
        float32 array of shape (1, n_gridcells, 1): a batch dimension, one entry per
        grid cell/pixel, and a dummy channel/feature dimension.
    """
    n_gridcells = len(output_xarray['ncells'])
    # `.dt.month` works for both numpy datetime64 and cftime time coordinates, whereas
    # `.item().month` fails for datetime64[ns] (whose `.item()` is a plain int of nanoseconds).
    # `.item()` requires a single timestep; subtract 1 because we want 0-indexed months.
    month_of_snapshot = int(output_xarray['time'].dt.month.item()) - 1
    return np.full((1, n_gridcells, 1), month_of_snapshot, dtype=np.float32)

In [None]:
def get_pytorch_model_data(input_xarray: xr.Dataset, output_xarray: xr.Dataset, timestep_i: int) -> torch.Tensor:
    """Build the model-ready input tensor for one time step.

    The channels are the ``VARS_INPUT`` variables followed by one extra trailing channel
    holding the 0-indexed month of the target snapshot. The model does *not* use that
    channel; it is carried along because the month is needed to denormalize predictions.

    Args:
        input_xarray: Pre-processed input dataset with a ``time`` dimension.
        output_xarray: Ground-truth dataset with a ``time`` dimension; only its time stamp
            and ``ncells`` size are read here.
        timestep_i: Positional time index (negative values count from the end).

    Returns:
        float32 tensor of shape (1, n_cells, len(VARS_INPUT) + 1), placed on ``device``.
    """
    features = concat_variables_into_channel_dim(input_xarray.isel(time=timestep_i), VARS_INPUT)
    month_channel = get_month_of_output_data(output_xarray.isel(time=timestep_i))
    stacked = np.concatenate((features, month_channel), axis=-1)
    return torch.from_numpy(stacked).float().to(device)

In [None]:
# Ground-truth snapshot at the same time index as the model input
snapshot_target = ds_output.isel(time=prediction_timestep)
snapshot_input_ml = get_pytorch_model_data(ds_input, ds_output, timestep_i=prediction_timestep)
# (batch-dimension, icosahedron-grid-dimension, feature-dimension)
snapshot_input_ml.shape

# Prediction with the AiBEDO model
###### ***Note:*** Always use the ```model.predict(input_tensor)``` method instead of ```model(input_tensor)```: `predict` returns the predictions denormalized to their original scale (e.g. temperature in kelvin).

In [None]:
def predict_with_aibedo_model(aibedo_model: BaseModel, input_tensor: torch.Tensor) -> Dict[str, torch.Tensor]:
    """
    Predict with the AiBEDO model.

    Args:
        aibedo_model: The model to run; it is switched to evaluation mode before predicting.
        input_tensor: Model-ready input (e.g. as built by ``get_pytorch_model_data``).

    Returns:
        A dictionary of output-variable -> prediction-tensor key->value pairs.
        Each prediction-tensor has been denormalized to original scale (e.g. temperature in kelvin)
    """
    # Use the argument, not the notebook-global `model`, so the model that is actually
    # passed in is the one put into eval mode.
    aibedo_model.eval()
    with torch.no_grad():  # No need to track the gradients during inference
        prediction = aibedo_model.predict(input_tensor)
    return prediction

In [None]:
# Run inference; the result maps each output variable name to its denormalized prediction
snapshot_prediction = predict_with_aibedo_model(aibedo_model=model, input_tensor=snapshot_input_ml)
snapshot_prediction.keys()

# Post-processing and plotting

In [None]:
def get_predictions_xarray(targets_ds: xr.Dataset, predictions_ds: Dict[str, torch.Tensor]) -> xr.Dataset:
    """Add the torch tensor predictions to a copy of the xarray targets dataset, plus per-cell errors.

    For every ``var`` in ``predictions_ds`` the returned dataset gains:
      - ``{var}_pred``: the prediction along the ``ncells`` dimension,
      - ``{var}_bias``: prediction - target,
      - ``{var}_mae``: |prediction - target| (a per-cell absolute error, not averaged).

    Args:
        targets_ds: Ground-truth snapshot containing each predicted variable along ``ncells``.
        predictions_ds: Mapping of output-variable name -> prediction tensor
            (as returned by the model's ``predict``).

    Returns:
        A new dataset; ``targets_ds`` itself is not modified.
    """
    return_ds = targets_ds.copy()
    for var, pred in predictions_ds.items():
        pred_key = f"{var}_pred"
        # Presumably (batch=1, ncells[, 1]) -- TODO confirm. reshape(-1) (instead of squeeze())
        # always yields a 1-D array, even when there is a single grid cell.
        return_ds[pred_key] = ('ncells', pred.detach().cpu().numpy().reshape(-1))
        # compute the error
        diff_err = return_ds[pred_key] - return_ds[var]
        return_ds[f'{var}_mae'] = np.abs(diff_err)
        return_ds[f'{var}_bias'] = diff_err
    return return_ds

In [None]:
# Merge targets, predictions and per-cell errors into a single dataset for plotting
snapshot_postprocessed = get_predictions_xarray(targets_ds=snapshot_target, predictions_ds=snapshot_prediction)

##### A plotting script

In [None]:
def single_snapshot_plotting(postprocessed_xarray: xr.Dataset,
                             **kwargs
                             ):
    """Plot targets, predictions and errors of every output variable on maps.

    The figure has one column per variable in ``VARS_OUTPUT`` and four rows: targets,
    AiBEDO predictions, bias and absolute error. The prediction row reuses the colour
    limits of the target row so that the two rows are directly comparable.

    Args:
        postprocessed_xarray: Dataset with ``lat``/``lon`` and, for every ``var`` in
            ``VARS_OUTPUT``, the variables ``var``, ``{var}_pred``, ``{var}_bias`` and
            ``{var}_mae`` (e.g. the output of ``get_predictions_xarray``).
        **kwargs: Extra keyword arguments forwarded to every ``xr.plot.scatter`` call
            (e.g. ``robust=True`` or the marker size ``s``).
    """
    proj = ccrs.PlateCarree()
    plot_kwargs = dict(
        ds=postprocessed_xarray,
        x='lon',
        y='lat',
        transform=proj, subplot_kws={'projection': proj},
        cbar_kwargs={'shrink': 0.8,  # make cbar smaller/larger
                     'pad': 0.01,  # padding between right-most subplot and cbar
                     'fraction': 0.05}, **kwargs
    )
    row_labels = ['Targets', "AiBEDO", 'Bias', "MAE"]
    # Derive the grid size from the data instead of hard-coding 4x3, so that changing
    # VARS_OUTPUT neither indexes past the axes array nor leaves empty columns.
    nrows, ncols = len(row_labels), len(VARS_OUTPUT)
    fig, axs = plt.subplots(nrows, ncols, sharex=True, sharey=True,
                            squeeze=False,  # always a 2-D axes array, even for a single variable
                            subplot_kw={'projection': proj},
                            gridspec_kw={'wspace': 0.07, 'hspace': 0,
                                         'top': 1., 'bottom': 0., 'left': 0., 'right': 1.},
                            figsize=(ncols * 12, nrows * 6)  # <- adjust figsize but keep ratio ncols/nrows
                            )

    for j, var in enumerate(VARS_OUTPUT):
        p_target = xr.plot.scatter(hue=var, ax=axs[0, j], **plot_kwargs)
        # Reuse the target colour limits so that targets and predictions share one scale
        p_preds = xr.plot.scatter(hue=f'{var}_pred', ax=axs[1, j], vmin=p_target.colorbar.vmin, vmax=p_target.colorbar.vmax, **plot_kwargs)
        p_bias = xr.plot.scatter(hue=f'{var}_bias', ax=axs[2, j], **plot_kwargs)
        p_mae = xr.plot.scatter(hue=f'{var}_mae', ax=axs[3, j], **plot_kwargs)

        # Set title
        axs[0, j].set_title(output_var_clean_name[var], size=30)

        # Edit colorbar
        for p, label in zip([p_target, p_preds, p_bias, p_mae], row_labels):
            p.colorbar.set_label(label, size=25)
            p.colorbar.ax.tick_params(labelsize=18)

    for ax in axs.flat:
        ax.coastlines(linewidth=0.5)


## Plotting the results
**Legend:**

    Each column is a different (denormalized) output variable.
    
    Rows:
    - First row: Targets
    - Second row: AiBEDO model predictions
    - Third row: Bias (AiBEDO - Targets)
    - Fourth row: Absolute error per grid cell (|AiBEDO - Targets|), labelled "MAE" on the colorbar
   
    Note the 1e-5 multiplier shown on the precipitation error colorbars!
 

In [None]:
# robust=True lets xarray choose outlier-resistant colour limits; s=2 sets the scatter marker size
single_snapshot_plotting(postprocessed_xarray=snapshot_postprocessed, robust=True, s=2)