# Tutorial 4: Running the model on downscaled climate projections

### Outline:

* Imports, including library code from previous steps
* More in-depth exploration of the downscaled CMIP6 data
* Loading the trained model
* Setting up an inference pipeline for multiple projections
* Analysis

## Setup and configuration

At this point you should be familiar with the setup routine, importing packages and setting devices and datatypes.

In [None]:
import os
import time
import torch
import intake
import regionmask
import xarray as xr
import numpy as np
import pandas as pd
from glob import glob
import matplotlib.pyplot as plt
import warnings

from torch import nn
from tqdm.autonotebook import tqdm

from src.datapipes import (
    get_static_data, 
    select_region, 
    make_data_pipeline, 
    scale_means, 
    scale_stds
)
from src.utils import load_experiment
from src.models import create_lstm_model
from src.inference import (
    unmask,
    gen_inference_data_pipeline,
    run_model
)

warnings.filterwarnings('ignore')
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
DTYPE = torch.float16

## Getting the model imported and loaded up

Just as we did in the previous portion of the tutorial, we use the experiment file to load our hyperparameters and settings. From there we can use `create_lstm_model` to construct the model structure. To this structure we append a final `ReLU` layer, which clips the small negative SWE predictions we saw before to zero. The `LeakyReLU` used during training kept these negative values from getting too large, but at this point we can simply chop them off. Because `ReLU` has no parameters, appending it does not change the keys that `load_state_dict` expects. Next we load the weights and move the model to the chosen device and data type. Finally, we set the model to evaluation mode.

In [None]:
config_file = '../experiments/tutorial/tutorial.yml'
config = load_experiment(config_file)
model = create_lstm_model(**config['model_config'])
model = model.append(nn.ReLU())
model.load_state_dict(torch.load(config['weights_file'], map_location=DEVICE))
model.to(DEVICE).to(DTYPE)
model.eval()
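As a minimal sketch of why appending the `ReLU` before loading the weights is safe, the snippet below uses a small hypothetical `nn.Sequential` of `Linear` layers rather than the tutorial's actual model. It assumes, as the cell above implies, that `create_lstm_model` returns an `nn.Sequential`-like container that supports `append`.

```python
import torch
from torch import nn

# Hypothetical stand-in for the trained network (not the tutorial's LSTM model)
trained = nn.Sequential(nn.Linear(4, 8), nn.LeakyReLU(), nn.Linear(8, 1))
state = trained.state_dict()  # what would be saved to the weights file

# Rebuild the same structure and append a ReLU to clip negative outputs
model = nn.Sequential(nn.Linear(4, 8), nn.LeakyReLU(), nn.Linear(8, 1))
model.append(nn.ReLU())  # ReLU has no parameters, so no state_dict keys are added

# Strict loading succeeds: no missing or unexpected keys
result = model.load_state_dict(state)
model.eval()

with torch.no_grad():
    out = model(torch.randn(16, 4))  # every output is now >= 0
```

The same reasoning would not hold for an appended layer with learnable parameters: its keys would be missing from the saved weights and strict loading would fail.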

## Data preparation

Unfortunately, we still have a bit of work to do to make the climate projections usable with our model. Since we are now running in forward mode on completely new data, it's worth first seeing what everything looks like.

In [None]:
cat = intake.open_esm_datastore(
  'https://cpdataeuwest.blob.core.windows.net/cp-cmip/version1/catalogs/global-downscaled-cmip6.json'
)
# cat

In [None]:
pr_data = cat.search(timescale='day', variable_id='pr')
# pr_data

In [None]:
cat_subset = cat.search(
    method="GARD-SV",
    source_id="CanESM5",
    experiment_id="ssp245",
    variable_id=['tasmin', 'tasmax', 'pr'],
    timescale='day',
)
dsets = cat_subset.to_dataset_dict()
met_ds = list(dsets.values())[0]  # optionally: .chunk({'time': 1, 'lat': 48, 'lon': 48})

Next we can open the static data, which contains the elevation, slope, aspect, and mask variables. These are on a slightly different grid than the CMIP6 data: the static data has longitudes running from 0 to 360, while the CMIP6 data has longitudes from -180 to 180. Otherwise the grids are identical, which makes it easy to line them up. To do that, we roll the longitude dimension by half the total number of grid cells (720 of the 1440 longitudes) and then use the `assign_coords` method to replace the longitude labels with those of the CMIP6 data we just opened. Finally, we merge everything into a single `ds`.

In [None]:
static_data = get_static_data()
static_data = static_data.roll(lon=720)
static_data = static_data.assign_coords({'lon': met_ds['lon']})
ds = xr.merge([met_ds.squeeze(), static_data])
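The roll-and-relabel trick can be checked on a tiny hypothetical grid. This sketch uses four longitude cells at 90 degree spacing instead of the real 1440; rolling by half the cell count moves the 180..360 half to the front, after which the labels can be swapped for the -180..180 ones.

```python
import numpy as np
import xarray as xr

# Hypothetical toy static field on a 0..360 longitude grid (4 cells, 90 degree spacing)
lon_0_360 = np.array([0.0, 90.0, 180.0, 270.0])
static = xr.DataArray([10, 20, 30, 40], dims='lon', coords={'lon': lon_0_360})

# Target grid in the -180..180 convention with the same spacing
lon_pm180 = np.array([-180.0, -90.0, 0.0, 90.0])

# Roll the data (but not the coordinates) by half the number of cells, then relabel
shifted = static.roll(lon=static.sizes['lon'] // 2, roll_coords=False)
shifted = shifted.assign_coords(lon=lon_pm180)
```

After relabeling, the value that sat at 270 degrees is found at -90 degrees, as expected for the same physical location.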

You may note that `ds` still holds the global data, so we use the `select_region` function to clip out just the region(s) specified in the configuration. Next comes another preprocessing step that makes the CMIP6 data match the ERA5 training data: converting precipitation from mm/day to mm/s. This is a simple division by the number of seconds in a day, 86400. Finally, we copy out the mask, just as we did in the previous section of the tutorial, and set the mask in the test data to all ones so that we can run the model in forward mode over the full domain at once.

In [None]:
test_data = select_region(ds, config['data_config']['regions'])
# Convert precipitation from mm/day to mm/s (86400 seconds per day) to match ERA5
test_data['pr'] = test_data['pr'] / 86400
# Record the true mask, then fill missing values and put in an all-ones dummy mask
true_mask = test_data['mask'].copy()
test_data = test_data.fillna(1.0)
# xr.ones_like works for both NumPy- and dask-backed arrays
test_data['mask'] = xr.ones_like(test_data['mask'])
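The unit conversion and the mask bookkeeping can be sketched on a hypothetical 2x2 dataset with one missing (ocean) cell. Reapplying the true mask with `.where` after prediction is an assumption for illustration; the tutorial's own `unmask` helper may handle this differently.

```python
import numpy as np
import xarray as xr

SECONDS_PER_DAY = 24 * 60 * 60  # 86400

# Hypothetical toy dataset: 2x2 grid with one missing (ocean) cell
pr_mm_day = xr.DataArray([[8.64, 0.0], [np.nan, 17.28]], dims=('lat', 'lon'))
mask = xr.DataArray([[1.0, 1.0], [np.nan, 1.0]], dims=('lat', 'lon'))
toy = xr.Dataset({'pr': pr_mm_day, 'mask': mask})

# Convert precipitation from mm/day to mm/s
toy['pr'] = toy['pr'] / SECONDS_PER_DAY

# Keep the true mask, fill gaps, and use an all-ones dummy mask for the model run
true_mask = toy['mask'].copy()
toy = toy.fillna(1.0)
toy['mask'] = xr.ones_like(toy['mask'])

# Stand-in for model output on the full grid, then hide cells outside the true mask
fake_prediction = xr.full_like(toy['pr'], 0.5)
masked_prediction = fake_prediction.where(true_mask == 1)
```

For example, 8.64 mm/day becomes 1e-4 mm/s, and the cell that was missing in the original mask comes back as NaN in the masked prediction.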

Next up, we perform the same tricks we used during inference last time: recording the actual shape of the data and telling the config that we're running the whole domain per batch by setting `batch_dims`. Last but not least, we set `output_var` to `pr`, since there is no `swe` in the CMIP6 data and the pipeline needs some variable there. With that all set, we can make the data pipeline and get to running the model!

To keep the analysis tractable, we then subset the predictions into three 21-year periods:
* 2015-2035: We will call this "2020s"
* 2040-2060: We will call this "2050s"
* 2070-2090: We will call this "2080s"

In [None]:
swe_pred = run_model(model, test_data, config)
swe_2020s = swe_pred.sel(time=slice('2015', '2035'))
swe_2050s = swe_pred.sel(time=slice('2040', '2060'))
swe_2080s = swe_pred.sel(time=slice('2070', '2090'))
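The period subsetting relies on xarray's label-based slicing, where year strings are inclusive of the whole end year. A sketch on a hypothetical daily series with a standard calendar follows; the `PERIODS` dictionary is an illustrative name, not part of the tutorial code. CMIP6 models often use non-standard calendars (for example a 365-day "noleap" calendar), in which case the day counts differ.

```python
import numpy as np
import pandas as pd
import xarray as xr

# Hypothetical daily series on a standard (Gregorian) calendar
times = pd.date_range('2015-01-01', '2100-12-31', freq='D')
swe = xr.DataArray(np.zeros(len(times)), dims='time', coords={'time': times})

# Period labels mapped to inclusive start and end years
PERIODS = {'2020s': ('2015', '2035'), '2050s': ('2040', '2060'), '2080s': ('2070', '2090')}
subsets = {label: swe.sel(time=slice(start, end)) for label, (start, end) in PERIODS.items()}
```

The 2020s subset runs from 2015-01-01 through 2035-12-31: 21 years of 365 days plus 5 leap days, or 7670 days.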

### Analyzing the projections

In [None]:
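def day_of_wateryear(ds):
    """Return the day of the water year (October 1 = 1) for each time step."""
    # 92 is the number of days from October 1 through December 31;
    # the modulus is 366 in leap years so September 30 maps to the last day
    result = (ds.time.dt.dayofyear + 92 - 1) % (365 + ds.time.dt.is_leap_year) + 1
    result.name = 'dowy'
    return result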
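A quick check of the water-year formula on a few hypothetical dates, with the helper copied so the snippet runs on its own. Note that the modulus uses the leap-year status of the calendar year, which still gives October 1 = 1 in every year and gives the water year that contains February 29 a 366th day.

```python
import pandas as pd
import xarray as xr

def day_of_wateryear(ds):
    # Same formula as the notebook helper: October 1 maps to day 1
    result = (ds.time.dt.dayofyear + 92 - 1) % (365 + ds.time.dt.is_leap_year) + 1
    result.name = 'dowy'
    return result

times = pd.to_datetime(['2019-10-01', '2019-12-31', '2020-01-01',
                        '2020-02-29', '2020-09-30', '2021-09-30'])
da = xr.DataArray(range(len(times)), dims='time', coords={'time': times})
dowy = day_of_wateryear(da)
```

The resulting days are 1, 92, 93, 152, 366, and 365: water year 2020 (which contains February 29) ends on day 366, while water year 2021 ends on day 365.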

In [None]:
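def plot_quantile_spread(da, quantiles, ax):
    """Plot day-of-wateryear quantile bands of `da` on `ax`.

    Expects five quantiles ordered low to high: the outer band spans
    quantiles[0] to quantiles[-1], the inner band quantiles[1] to
    quantiles[-2], and the line is the middle quantile.
    """
    da_wy = da.groupby(day_of_wateryear(da)).quantile(quantiles)
    # The grouped result (not the input time series) carries the 'dowy' coordinate
    dowy = da_wy.dowy
    ax.fill_between(
        dowy, da_wy.sel(quantile=quantiles[0]), da_wy.sel(quantile=quantiles[-1]),
        alpha=0.25, color='grey'
    )
    ax.fill_between(
        dowy, da_wy.sel(quantile=quantiles[1]), da_wy.sel(quantile=quantiles[-2]),
        alpha=0.25, color='grey'
    )
    ax.plot(dowy, da_wy.sel(quantile=quantiles[2]), color='grey')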
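To see what `plot_quantile_spread` actually draws, here is a sketch of the grouped-quantile step on hypothetical data: three water years (October 2000 to September 2003, with no February 29) where every value is simply the index of its water year. Each day of the water year therefore has exactly three samples, 0, 1, and 2.

```python
import numpy as np
import pandas as pd
import xarray as xr

def day_of_wateryear(ds):
    # Same formula as the notebook helper: October 1 maps to day 1
    result = (ds.time.dt.dayofyear + 92 - 1) % (365 + ds.time.dt.is_leap_year) + 1
    result.name = 'dowy'
    return result

# Hypothetical series covering three complete water years
times = pd.date_range('2000-10-01', '2003-09-30', freq='D')
water_year = np.where(times.month >= 10, times.year + 1, times.year)
# Each value is the index of its water year: 0, 1 or 2
da = xr.DataArray((water_year - 2001).astype(float), dims='time', coords={'time': times})

quantiles = [0.1, 0.25, 0.5, 0.75, 0.9]
da_wy = da.groupby(day_of_wateryear(da)).quantile(quantiles)
```

The result has a `dowy` dimension of length 365 and a `quantile` dimension; the median on every day is 1.0 and, with linear interpolation, the 10th percentile is 0.2.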


In [None]:
loc = {'lat': slice(38, 40), 'lon': slice(252 - 360, 254 - 360)}
swe_loc_20 = swe_2020s.sel(**loc).mean(dim=['lat', 'lon'])
swe_loc_50 = swe_2050s.sel(**loc).mean(dim=['lat', 'lon'])
swe_loc_80 = swe_2080s.sel(**loc).mean(dim=['lat', 'lon'])

quantiles = [0.1, 0.25, 0.5, 0.75, 0.9]

fig, axes = plt.subplots(1, 3, figsize=(12,4), sharey=True)
plot_quantile_spread(swe_loc_20, quantiles, axes[0])
axes[0].set_title('2020s')
plot_quantile_spread(swe_loc_50, quantiles, axes[1])
axes[1].set_title('2050s')
plot_quantile_spread(swe_loc_80, quantiles, axes[2])
axes[2].set_title('2080s')

axes[0].set_ylabel('SWE [m]')
axes[1].set_xlabel('Day of Wateryear')
plt.suptitle('Southern Rockies', fontsize=16)
plt.tight_layout()
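The regional bounding boxes in this analysis (such as `slice(252-360, 254-360)` above) are written in 0..360 degrees and shifted by 360, which matches the -180..180 convention only for longitudes above 180 degrees. A more general conversion uses modular arithmetic; `to_pm180` below is a hypothetical helper, not part of `src`.

```python
import numpy as np

def to_pm180(lon):
    """Convert longitudes from the 0..360 convention to -180..180."""
    return ((np.asarray(lon, dtype=float) + 180.0) % 360.0) - 180.0

# Bounding boxes used in this notebook, written in 0..360 degrees
boxes = {
    'Southern Rockies': (252, 254),
    'Northern Cascades': (238, 240),
    'Central Sierra Nevada': (239.75, 240.25),
}
converted = {name: tuple(to_pm180(b).tolist()) for name, b in boxes.items()}
```

For these boxes the result is identical to subtracting 360, for example (-108.0, -106.0) for the Southern Rockies, but the helper also leaves longitudes already below 180 unchanged.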

In [None]:
swe_q_20 = swe_loc_20.groupby(day_of_wateryear(swe_loc_20)).quantile(quantiles)
swe_q_50 = swe_loc_50.groupby(day_of_wateryear(swe_loc_50)).quantile(quantiles)
swe_q_80 = swe_loc_80.groupby(day_of_wateryear(swe_loc_80)).quantile(quantiles)

fig, axes = plt.subplots(1, 3, figsize=(12, 4), sharey=True)
dowy = swe_q_20.dowy
for ax, swe_q, title in zip(axes, [swe_q_20, swe_q_50, swe_q_80], ['2020s', '2050s', '2080s']):
    # Outer band: quantiles[0] to quantiles[-1]; inner band: quantiles[1] to quantiles[-2]
    ax.fill_between(dowy, swe_q.sel(quantile=quantiles[0]), swe_q.sel(quantile=quantiles[-1]), alpha=0.25, color='grey')
    ax.fill_between(dowy, swe_q.sel(quantile=quantiles[1]), swe_q.sel(quantile=quantiles[-2]), alpha=0.25, color='grey')
    # Median line
    ax.plot(dowy, swe_q.sel(quantile=quantiles[2]), color='grey')
    ax.set_title(title)

In [None]:
loc = {'lat': slice(47, 49), 'lon': slice(238 - 360, 240 - 360)}
swe_loc_20 = swe_2020s.sel(**loc).mean(dim=['lat', 'lon'])
swe_loc_50 = swe_2050s.sel(**loc).mean(dim=['lat', 'lon'])
swe_loc_80 = swe_2080s.sel(**loc).mean(dim=['lat', 'lon'])

quantiles = [0.1, 0.25, 0.5, 0.75, 0.9]
swe_q_20 = swe_loc_20.groupby(day_of_wateryear(swe_loc_20)).quantile(quantiles)
swe_q_50 = swe_loc_50.groupby(day_of_wateryear(swe_loc_50)).quantile(quantiles)
swe_q_80 = swe_loc_80.groupby(day_of_wateryear(swe_loc_80)).quantile(quantiles)

fig, axes = plt.subplots(1, 3, figsize=(12, 4), sharey=True)
dowy = swe_q_20.dowy
for ax, swe_q, title in zip(axes, [swe_q_20, swe_q_50, swe_q_80], ['2020s', '2050s', '2080s']):
    # Outer band: quantiles[0] to quantiles[-1]; inner band: quantiles[1] to quantiles[-2]
    ax.fill_between(dowy, swe_q.sel(quantile=quantiles[0]), swe_q.sel(quantile=quantiles[-1]), alpha=0.25, color='grey')
    ax.fill_between(dowy, swe_q.sel(quantile=quantiles[1]), swe_q.sel(quantile=quantiles[-2]), alpha=0.25, color='grey')
    # Median line
    ax.plot(dowy, swe_q.sel(quantile=quantiles[2]), color='grey')
    ax.set_title(title)

axes[0].set_ylabel('SWE [m]')
axes[1].set_xlabel('Day of Wateryear')
plt.suptitle('Northern Cascades', fontsize=16)
plt.tight_layout()

In [None]:
loc = {'lat': slice(37.5, 38.5), 'lon': slice(239.75 - 360, 240.25 - 360)}
swe_loc_20 = swe_2020s.sel(**loc).mean(dim=['lat', 'lon'])
swe_loc_50 = swe_2050s.sel(**loc).mean(dim=['lat', 'lon'])
swe_loc_80 = swe_2080s.sel(**loc).mean(dim=['lat', 'lon'])

quantiles = [0.1, 0.25, 0.5, 0.75, 0.9]
swe_q_20 = swe_loc_20.groupby(day_of_wateryear(swe_loc_20)).quantile(quantiles)
swe_q_50 = swe_loc_50.groupby(day_of_wateryear(swe_loc_50)).quantile(quantiles)
swe_q_80 = swe_loc_80.groupby(day_of_wateryear(swe_loc_80)).quantile(quantiles)

fig, axes = plt.subplots(1, 3, figsize=(12, 4), sharey=True)
dowy = swe_q_20.dowy
for ax, swe_q, title in zip(axes, [swe_q_20, swe_q_50, swe_q_80], ['2020s', '2050s', '2080s']):
    # Outer band: quantiles[0] to quantiles[-1]; inner band: quantiles[1] to quantiles[-2]
    ax.fill_between(dowy, swe_q.sel(quantile=quantiles[0]), swe_q.sel(quantile=quantiles[-1]), alpha=0.25, color='grey')
    ax.fill_between(dowy, swe_q.sel(quantile=quantiles[1]), swe_q.sel(quantile=quantiles[-2]), alpha=0.25, color='grey')
    # Median line
    ax.plot(dowy, swe_q.sel(quantile=quantiles[2]), color='grey')
    ax.set_title(title)

axes[0].set_ylabel('SWE [m]')
axes[1].set_xlabel('Day of Wateryear')
plt.suptitle('Central Sierra Nevada', fontsize=16)
plt.tight_layout()

In [None]:
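# Assumes `df_20` is a pandas DataFrame of snow-pillow observations with a
# daily DatetimeIndex; it is not loaded anywhere in this notebook.
pt = pd.pivot_table(
    df_20,
    index=df_20.index.dayofyear,
    columns=df_20.index.year,
    values='Tuolumne Meadows Pillow SWE [mm]',
    aggfunc='mean'
)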
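The pivot table reshapes a daily series into a day-of-year by year grid. Since the observation data is not loaded here, this sketch uses a hypothetical DataFrame with a made-up `swe_mm` column. Calendar day of year shifts by one after February 28 in leap years, which is one reason the `day_of_wateryear` helper may be preferable for comparing seasons.

```python
import numpy as np
import pandas as pd

# Hypothetical daily observations spanning two calendar years (2020 is a leap year)
idx = pd.date_range('2019-01-01', '2020-12-31', freq='D')
obs = pd.DataFrame({'swe_mm': np.arange(len(idx), dtype=float)}, index=idx)

# Rows: day of year, columns: calendar year, values: mean SWE for that day
pt = pd.pivot_table(
    obs,
    index=obs.index.dayofyear,
    columns=obs.index.year,
    values='swe_mm',
    aggfunc='mean',
)
```

The table has 366 rows and one column per year; day 366 is missing (NaN) for the non-leap year 2019.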


In [None]:
swe_seas_mean_2020s = swe_2020s.groupby(swe_2020s['time'].dt.season).mean()
swe_seas_mean_2050s = swe_2050s.groupby(swe_2050s['time'].dt.season).mean()
swe_seas_mean_2080s = swe_2080s.groupby(swe_2080s['time'].dt.season).mean()
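`dt.season` labels each day as DJF, MAM, JJA, or SON by calendar month, so a DJF mean over a period pools every December, January, and February day in it, including the January and February at the start of the period and the December at its end. A minimal check on a hypothetical one-year series whose value is the month number:

```python
import numpy as np
import pandas as pd
import xarray as xr

# Hypothetical daily series for 2001 whose value is the calendar month
times = pd.date_range('2001-01-01', '2001-12-31', freq='D')
da = xr.DataArray(np.asarray(times.month, dtype=float), dims='time', coords={'time': times})

seasonal = da.groupby(da['time'].dt.season).mean()
```

The DJF mean is weighted by days, (31*1 + 28*2 + 31*12) / 90 = 5.1, and the MAM mean is (31*3 + 30*4 + 31*5) / 92 = 4.0.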

In [None]:
from matplotlib.colors import LogNorm, SymLogNorm, PowerNorm

In [None]:
fig, axes = plt.subplots(1, 3, figsize=(24, 6))

periods = {'2020s': swe_seas_mean_2020s, '2050s': swe_seas_mean_2050s, '2080s': swe_seas_mean_2080s}
for ax, (label, seas_mean) in zip(axes, periods.items()):
    # Square-root color stretch (gamma=0.5) keeps shallow snowpacks visible
    seas_mean.sel(season='DJF').plot(norm=PowerNorm(gamma=0.5, vmin=1e-9, vmax=0.15), cmap='turbo', ax=ax)
    ax.set_title(label)
plt.suptitle('Winter mean SWE [m]', fontsize=16)

In [None]:
fig, axes = plt.subplots(1, 3, figsize=(24, 6))

periods = {'2020s': swe_seas_mean_2020s, '2050s': swe_seas_mean_2050s, '2080s': swe_seas_mean_2080s}
for ax, (label, seas_mean) in zip(axes, periods.items()):
    # Square-root color stretch (gamma=0.5) keeps shallow snowpacks visible
    seas_mean.sel(season='MAM').plot(norm=PowerNorm(gamma=0.5, vmin=1e-9, vmax=0.2), cmap='turbo', ax=ax)
    ax.set_title(label)
plt.suptitle('Spring mean SWE [m]', fontsize=16)
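The seasonal maps use `PowerNorm(gamma=0.5, ...)`, a square-root color stretch that spreads out the low end of the SWE range so shallow snowpacks stay visible next to deep ones. This small sketch uses `vmin=0.0` instead of the notebook's `1e-9` so the arithmetic is exact, and shows how values map onto the [0, 1] color range:

```python
import numpy as np
from matplotlib.colors import PowerNorm

# Square-root stretch: normalized = ((x - vmin) / (vmax - vmin)) ** 0.5
norm = PowerNorm(gamma=0.5, vmin=0.0, vmax=0.16)
values = np.array([0.0, 0.01, 0.04, 0.16])
normalized = np.asarray(norm(values))
```

A quarter of the color range is already used by values up to 1/16 of `vmax`, and half of it by values up to 1/4 of `vmax`.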