# Tutorial 3: Evaluating the trained model on the test data

### Outline

* Imports, including library code from previous steps
* Loading the trained model using hyperparameters and weights file
* Setting up the datapipe for the test data
* Some functions for "undoing/inverting" the ETL pipeline (aka recovering spatiotemporal relations)
* Running the trained model in eval mode
* Some basic metrics and analysis

## Setup and configuration

As you might expect, we start with some standard imports. We also import helpers for building LSTM-based models and for loading saved experiments from the `src.models` and `src.utils` modules, which are included with the tutorial, along with the `src.datapipes` module from the earlier steps. Finally, we set the device and data type as usual.

In [None]:
import torch
import xarray as xr
import matplotlib.pyplot as plt
import warnings

from tqdm.autonotebook import tqdm
from src.models import create_lstm_model
from src.utils import load_experiment
from src.datapipes import make_data_pipeline, merge_data, select_region

warnings.filterwarnings('ignore')
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
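# float16 is fast on GPUs, but many CPU kernels lack float16 support, so fall back to float32 there
DTYPE = torch.float16 if DEVICE.type == 'cuda' else torch.float32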

## Loading the setup from our saved experiment

With the defaults out of the way, we can see the benefit of having even minimal MLOps infrastructure in place: thanks to the `save_experiment` and `load_experiment` functions, reloading the previously saved experiment takes a single call. We then instantiate a model with the same structure and load the trained weights into it. Now we are ready to think about applying this model to new data!

In [None]:
config_file = '../experiments/tutorial/tutorial.yml'
config = load_experiment(config_file)
model = create_lstm_model(**config['model_config'])
model.load_state_dict(torch.load(config['weights_file'], map_location=DEVICE))
model.to(DTYPE).to(DEVICE)
model.eval()
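The rest of the notebook only reads a handful of keys from the loaded config. As a sketch, assuming `load_experiment` returns a plain nested dict (the layout and every value below are hypothetical placeholders, not the tutorial's real settings), a quick check that those keys exist can fail fast before any data is touched:

```python
# Hypothetical layout of the experiment config: only the keys this notebook
# reads are shown, and every value is a placeholder, not a real setting.
example_config = {
    "model_config": {},                    # kwargs passed to create_lstm_model
    "weights_file": "path/to/weights.pt",  # placeholder path
    "data_config": {
        "test_period": "placeholder",      # passed to ds.sel(time=...)
        "regions": "placeholder",          # passed to select_region
        "output_sequence_length": 1,
        "input_overlap": {"time": 0},
    },
}

# Keys the notebook indexes into; None means "no nested keys checked".
REQUIRED_KEYS = {
    "model_config": None,
    "weights_file": None,
    "data_config": ["test_period", "regions", "output_sequence_length", "input_overlap"],
}


def missing_config_keys(config):
    """Return dotted paths of required keys that are absent from `config`."""
    missing = []
    for key, nested in REQUIRED_KEYS.items():
        if key not in config:
            missing.append(key)
            continue
        for sub in nested or []:
            if sub not in config[key]:
                missing.append(f"{key}.{sub}")
    return missing


print(missing_config_keys(example_config))            # []
print(missing_config_keys({"weights_file": "w.pt"}))  # ['model_config', 'data_config']
```

Calling `missing_config_keys(config)` right after `load_experiment` surfaces a typo in the YAML file immediately rather than halfway through building the datapipe.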

## Data plumbing for model inference

You're probably eager to see how the trained model performs, but the unfortunate reality of deep-learning models is that there is often a rift between the training and the application data workflows. In our case it's not too onerous, but it does take a little work to make inference efficient. As always, the first step is simply to open the data, and we can reuse the earlier data processing functions to get there.

In [None]:
ds = merge_data()
test_data = select_region(
    ds.sel(time=config['data_config']['test_period']),
    config['data_config']['regions']
)

Beyond that, we don't want to filter out missing data, because that would leave us with ragged arrays that need complex logic to reconstruct. Instead, we simply fill the missing values. This might look like a hack to save a few lines of code, but it is actually a sensible optimization: trained models like this one run far faster than numerically solving the underlying differential equations, so it is cheaper to run a forward pass on some irrelevant data than to actively filter out the missing regions.

In [None]:
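true_mask = test_data['mask'].copy()  # keep the real mask for later analysis
test_data = test_data.fillna(1.0)     # fill missing values so every grid cell can be batched
test_data['mask'].values[:] = 1.0     # mark every grid cell as valid for the datapipe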
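The fill-then-mask idea does not depend on xarray. A minimal NumPy sketch on synthetic data (not the tutorial's dataset, and with an identity function standing in for the model) shows the round trip: remember where values were missing, fill so every cell can go through the model, then put the gaps back so filled cells never enter a statistic.

```python
import numpy as np

# Synthetic (lat, lon, time) field with a few missing values.
field = np.arange(2 * 3 * 4, dtype=float).reshape(2, 3, 4)
field[0, 1, :] = np.nan  # an entire grid cell missing (e.g. outside the domain)
field[1, 2, 3] = np.nan  # a single missing time step

# 1. Remember which values were valid before filling.
valid = ~np.isnan(field)

# 2. Fill so the array is dense and can be batched without ragged shapes.
filled = np.where(valid, field, 1.0)

# 3. Run the "model"; an identity copy stands in for a real forward pass.
prediction = filled.copy()

# 4. Restore the gaps so filled cells cannot leak into any statistics.
restored = np.where(valid, prediction, np.nan)

print(np.isnan(filled).any())                      # False
print(np.array_equal(np.isnan(restored), ~valid))  # True
```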

Now, with the "filled" data, we record the shape of each prediction as a `(lat, lon, time)` tuple, where the time length is the model's output sequence length, so that we can reconstruct the spatial layout later.

In [None]:
actual_shape = (
    len(test_data['lat']),
    len(test_data['lon']),
    config['data_config']['output_sequence_length']
)
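Why a single shape tuple is enough to undo the flattening: if the datapipe stacks `lat` and `lon` into one sample dimension in row-major order (an assumption about `make_data_pipeline` worth checking against its stacking code), a plain `reshape` recovers the grid exactly. A NumPy sketch on random data:

```python
import numpy as np

n_lat, n_lon, n_time = 3, 4, 5
grid = np.random.default_rng(0).normal(size=(n_lat, n_lon, n_time))

# Flatten space into a single "sample" axis: (lat * lon, time).
flat = grid.reshape(n_lat * n_lon, n_time)

# In row-major order, sample k is grid cell (k // n_lon, k % n_lon).
k = 6
assert np.array_equal(flat[k], grid[k // n_lon, k % n_lon])

# Reshaping with the recorded (lat, lon, time) tuple undoes the flattening.
actual_shape = (n_lat, n_lon, n_time)
recovered = flat.reshape(actual_shape)
print(np.array_equal(recovered, grid))  # True
```

If the pipeline stacked `lon` before `lat` instead, the same reshape would run without error but scramble the map, so it is worth checking once against a known field.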

Next, we set the dimensions to iterate over, in place of the ones we used at train time. All we are doing here is setting the `batch_dims` keyword that goes into the `make_data_pipeline` function we developed in the `datapipes` module. Effectively, this tells the pipeline to run the full spatial domain on every forward pass of the network.

In [None]:
config['data_config']['batch_dims'] = {
    'lat': len(test_data['lat']),
    'lon': len(test_data['lon'])
}
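To make "the full domain on every forward pass" concrete: assuming the datapipe tiles the grid into complete, non-overlapping blocks of size `batch_dims` (an assumption about `make_data_pipeline`, which may handle edges or overlaps differently), the number of spatial batches per time window is a simple product. The helper below is hypothetical and the grid size is made up:

```python
def n_spatial_batches(grid_shape, batch_dims):
    """Number of complete, non-overlapping spatial tiles (hypothetical helper)."""
    n = 1
    for dim, size in grid_shape.items():
        n *= size // batch_dims[dim]
    return n


grid_shape = {"lat": 120, "lon": 200}  # hypothetical grid size

# Training-style tiles: many small patches per time window.
print(n_spatial_batches(grid_shape, {"lat": 30, "lon": 40}))  # 20
# batch_dims equal to the grid: one patch, i.e. the whole domain per pass.
print(n_spatial_batches(grid_shape, grid_shape))              # 1
```

With a single tile, each forward pass sees `lat * lon` samples at once, which is exactly what `actual_shape` expects when we reshape the output.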

And finally, we build the data pipeline with our same old `make_data_pipeline` function. We set a few extra parameters: `min_samples=0` so that no locations are filtered out, `preload=True` to load the data up front and save computational cost, and `filter_mask=False` to include data outside of the masked region so that the predictions are spatiotemporally complete. The rest of the configuration comes from what we recorded in the `config`.

In [None]:
pipe = make_data_pipeline(
    ds=test_data, 
    min_samples=0, 
    preload=True,
    filter_mask=False,
    **config['data_config']
)

## Running the model in forward/inference mode

With this in place, we are ready to run the trained model on the test dataset. We loop over the new `pipe` object, move each element onto the `DEVICE`, and run it through the model. Running the forward pass under `torch.no_grad()` disables gradient tracking, so PyTorch does not record the operations needed for a backward pass, which saves memory and computation. We then move the predictions back onto the CPU (as `float32`) from whatever `DEVICE` they were computed on, and reshape them to `actual_shape`, which unflattens each batch so that the spatial relations are recovered.

In [None]:
predictions = []
with torch.no_grad():  # inference only: no autograd graph needed
    for x, _ in tqdm(pipe):  # the targets are not needed here
        x = x.to(DTYPE).to(DEVICE)
        yhat = model(x).cpu().float()
        predictions.append(yhat.reshape(actual_shape))  # unflatten to (lat, lon, time)
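To see concretely what `torch.no_grad()` changes, here is a tiny stand-in model (a single `torch.nn.LSTM` plus a linear head, not the tutorial's `create_lstm_model`) run on random data shaped `(samples, time, features)` with `batch_first=True`; that layout is an assumption about the real model. Outputs computed inside the context do not require gradients, so no autograd graph is kept, while the numbers themselves are unchanged:

```python
import torch

torch.manual_seed(0)


class TinyLSTM(torch.nn.Module):
    """Hypothetical stand-in: LSTM encoder + linear head, one value per time step."""

    def __init__(self, n_features=3, hidden=8):
        super().__init__()
        self.lstm = torch.nn.LSTM(n_features, hidden, batch_first=True)
        self.head = torch.nn.Linear(hidden, 1)

    def forward(self, x):
        out, _ = self.lstm(x)  # (samples, time, hidden)
        return self.head(out)  # (samples, time, 1)


model = TinyLSTM().eval()
n_lat, n_lon, n_time = 4, 5, 6
x = torch.randn(n_lat * n_lon, n_time, 3)  # the whole grid in one forward pass

y_tracked = model(x)        # autograd records the graph
with torch.no_grad():
    y_infer = model(x)      # no graph is recorded

print(y_tracked.requires_grad)             # True
print(y_infer.requires_grad)               # False
print(torch.allclose(y_tracked, y_infer))  # True: same values, less bookkeeping

# (lat * lon, time, 1) -> (lat, lon, time), as done in the inference loop
print(y_infer.reshape(n_lat, n_lon, n_time).shape)  # torch.Size([4, 5, 6])
```

`torch.inference_mode()` is a stricter alternative to `torch.no_grad()` for pure inference; either works for a loop like this one.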

## Putting the data back together
The cell above does most of the work, but it's often convenient to repackage the resulting predictions from a list of PyTorch tensors into an xarray `DataArray`, so that we can locate the data by time and space coordinates that humans can easily interpret. To do so, we simply reuse the coordinates from the truth/reference dataset from ERA5.

In [None]:
start_time = config['data_config']['input_overlap']['time'] 
swe_true = test_data['swe'].isel(time=slice(start_time, -7))

swe_pred = xr.DataArray(
    torch.cat(predictions, dim=2).squeeze().numpy(),  # (lat, lon, time) on the CPU
    dims=('lat', 'lon', 'time'), coords=swe_true.coords
)
swe_pred

## Getting to the analysis and quantifying model performance

At this point we've assembled the predictions and the reference data into matching formats, and all that's left to do is some analysis. Model analysis is very problem- and domain-specific, but we will cover some basic analytics here. First, if you have applied your model to the `WNA` (Western North America) region, you can look at timeseries of spatial averages over a few subregions we've picked out. These cover major portions of prominent mountain ranges: the Southern Rocky Mountains, the Northern Cascade Mountains, and the Central Sierra Nevada Mountains. Note that the cells below multiply the predictions by 3 before plotting them.

In [None]:
loc = {'lat': slice(40, 38), 'lon': slice(252, 254)}  # latitude is stored north-to-south
swe_true.sel(**loc).mean(dim=['lat', 'lon']).plot(label='truth')
(3 * swe_pred).sel(**loc).mean(dim=['lat', 'lon']).plot(label='3 × prediction')
plt.title('Southern Rockies')
plt.legend()

In [None]:
loc = {'lat': slice(49, 47), 'lon': slice(238, 240)}
swe_true.sel(**loc).mean(dim=['lat', 'lon']).plot(label='truth')
(3 * swe_pred).sel(**loc).mean(dim=['lat', 'lon']).plot(label='3 × prediction')
plt.title('Northern Cascades')
plt.legend()

In [None]:
loc = {'lat': slice(38.5, 37.5), 'lon': slice(239.75, 240.25)}
swe_true.sel(**loc).mean(dim=['lat', 'lon']).plot(label='truth')
(3 * swe_pred).sel(**loc).mean(dim=['lat', 'lon']).plot(label='3 × prediction')
plt.title('Central Sierra Nevada')
plt.legend()
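The regional timeseries above are plain spatial means over a lat/lon box. A NumPy sketch of the same computation on a synthetic grid (the helper `box_mean_timeseries` is hypothetical, not part of `src`) selects cells by coordinate value, so it works whether latitude is stored ascending or descending, and ignores NaNs:

```python
import numpy as np


def box_mean_timeseries(field, lats, lons, lat_range, lon_range):
    """Unweighted spatial mean of a (lat, lon, time) array over a lat/lon box.

    Cells are selected by coordinate value, so the result does not depend on
    whether `lats` is stored ascending or descending. NaNs are ignored.
    """
    lat_lo, lat_hi = sorted(lat_range)
    lon_lo, lon_hi = sorted(lon_range)
    lat_sel = (lats >= lat_lo) & (lats <= lat_hi)
    lon_sel = (lons >= lon_lo) & (lons <= lon_hi)
    box = field[lat_sel][:, lon_sel]     # (n_box_lat, n_box_lon, time)
    return np.nanmean(box, axis=(0, 1))  # one value per time step


# Synthetic grid with latitude stored north-to-south, like the cells above.
lats = np.array([41.0, 40.0, 39.0, 38.0, 37.0])
lons = np.array([250.0, 251.0, 252.0, 253.0, 254.0, 255.0])
field = np.zeros((lats.size, lons.size, 3))
field[1:4, 2:5, :] = [1.0, 2.0, 3.0]     # signal only inside the box

print(box_mean_timeseries(field, lats, lons, (40, 38), (252, 254)))  # [1. 2. 3.]
```

xarray's label-based `.sel(lat=slice(40, 38))` instead follows the stored order: on a descending latitude coordinate, `slice(38, 40)` selects nothing and the mean comes out as NaN. An unweighted mean is reasonable for boxes this small; for larger regions you may want to weight by cos(latitude).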

In [None]:
swe_pred.isel(time=slice(120, 400)).max(dim='time').plot(vmax=0.1, vmin=0, cmap='turbo')
plt.title('Predicted maximum SWE (time indices 120-399)')

In [None]:
swe_true.isel(time=slice(120, 400)).max(dim='time').plot(vmax=0.1, vmin=0, cmap='turbo')
plt.title('True maximum SWE (time indices 120-399)')
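The plots give a visual comparison; a few standard scores make it quantitative. A minimal NumPy sketch (the helper `skill_scores` is hypothetical) computes bias, RMSE and the Pearson correlation only where both arrays are finite. To apply it to this notebook, you would first need both arrays in the same dimension order, e.g. `swe_true.transpose(*swe_pred.dims).values`, and the 1.0 fill values set back to NaN (for example from the `true_mask` saved earlier, assuming it marks valid cells) so that filled cells do not enter the scores.

```python
import numpy as np


def skill_scores(pred, true):
    """Bias, RMSE and Pearson r over points where both inputs are finite."""
    pred = np.asarray(pred, dtype=float).ravel()
    true = np.asarray(true, dtype=float).ravel()
    ok = np.isfinite(pred) & np.isfinite(true)
    p, t = pred[ok], true[ok]
    err = p - t
    return {
        "bias": float(err.mean()),                   # mean(pred - true)
        "rmse": float(np.sqrt(np.mean(err ** 2))),   # root-mean-square error
        "pearson_r": float(np.corrcoef(p, t)[0, 1]), # linear correlation
        "n": int(ok.sum()),                          # number of scored points
    }


# Toy values; the NaN in `true` is skipped rather than scored.
true = np.array([0.0, 0.1, 0.2, np.nan, 0.4])
pred = np.array([0.0, 0.2, 0.2, 0.3, 0.3])
scores = skill_scores(pred, true)
print(scores)
```

Bias tells you whether the model systematically over- or under-predicts, which is worth checking given the factor of 3 applied in the regional plots.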