# Tutorial 3: Evaluating the trained model on the test data

### Outline

* Imports, including library code from previous steps
* Loading the trained model using the hyperparameters and weights file referenced by the experiment config
* Setting up the datapipe for the test data
* Code for "undoing" (inverting) the ETL pipeline, i.e. recovering the spatiotemporal layout of the predictions
* Running the trained model in eval mode
* Some basic metrics and analysis (a time series at a single location and time-averaged maps)

In [None]:
import torch
import xarray as xr
import matplotlib.pyplot as plt
import warnings

from tqdm.autonotebook import tqdm
from src.models import create_lstm_model
from src.utils import load_experiment
from src.datapipes import make_data_pipeline, merge_data, select_region

# NOTE: this silences *all* warnings for the session, including deprecation
# warnings from xarray/torch; narrow the filter if something looks off.
warnings.filterwarnings('ignore')

# Set FORCE_CPU = False to run inference on the GPU when CUDA is available.
# It defaults to True, so the notebook runs on CPU even on a CUDA machine.
FORCE_CPU = True
DEVICE = torch.device(
    'cpu' if FORCE_CPU or not torch.cuda.is_available() else 'cuda:0'
)
DTYPE = torch.float32

In [None]:
# Rebuild the model from the experiment's hyperparameters, then load the
# trained weights referenced by the same config file.
config_file = '../experiments/tutorial/tutorial.yml'
config = load_experiment(config_file)
model = create_lstm_model(**config['model_config'])
# map_location remaps stored tensors onto DEVICE, so a checkpoint saved from
# a GPU run can still be loaded on a CPU-only machine (and vice versa).
state_dict = torch.load(config['weights_file'], map_location=DEVICE)
model.load_state_dict(state_dict)
model.to(DEVICE)
# Switch to evaluation mode (affects layers such as dropout/batch-norm).
model.eval()

In [None]:
# Load the merged dataset, restrict it to the configured test period, and
# crop it to the configured region(s).
ds = merge_data()
data_cfg = config['data_config']
test_period_ds = ds.sel(time=data_cfg['test_period'])
test_data = select_region(test_period_ds, data_cfg['regions'])

In [None]:
# Keep a copy of the original mask for later analysis, then neutralise the
# test set so that no grid cell is dropped by the data pipeline:
#   * every NaN in the dataset is replaced by the placeholder value 1.0, and
#   * the mask is set to 1 everywhere.
true_mask = test_data['mask'].copy()
test_data = test_data.fillna(1.0)
# Assign a new variable instead of writing through `.values[:]`. An in-place
# write into `.values` may not persist when the variable is lazily backed
# (e.g. dask). After this line the mask holds no NaNs, so the second
# fillna the original did here was a no-op and is omitted.
test_data['mask'] = xr.ones_like(test_data['mask'])

In [None]:
# Target shape of one batch of model output: the full (lat, lon) grid times
# the number of predicted time steps per sample.
n_lat, n_lon = (len(test_data[dim]) for dim in ('lat', 'lon'))
actual_shape = (n_lat, n_lon, config['data_config']['output_sequence_length'])
        

In [None]:
# Make each batch span the whole test grid, so one batch can later be
# reshaped to `actual_shape`.
config['data_config']['batch_dims'] = {
    dim: len(test_data[dim]) for dim in ('lat', 'lon')
}

In [None]:
# Build the inference pipeline over the test data. These overrides presumably
# disable sample filtering, so that every (lat, lon) cell yields a prediction
# (consistent with the mask being set to 1 above). TODO confirm against
# make_data_pipeline.
pipeline_overrides = dict(min_samples=0, preload=True, filter_mask=False)
pipe = make_data_pipeline(
    ds=test_data,
    **pipeline_overrides,
    **config['data_config'],
)

In [None]:
# Run the model over every batch of the test pipeline. Each output is moved
# back to CPU and reshaped to `actual_shape`, i.e.
# (lat, lon, output_sequence_length).
predictions = []
with torch.no_grad():  # inference only: no autograd graph needed
    # Iterate the pipeline directly (no enumerate) so tqdm can use its length
    # for a progress total when the pipeline provides one. Targets are unused.
    for x, _ in tqdm(pipe):
        yhat = model(x.to(DEVICE)).cpu()
        predictions.append(yhat.reshape(actual_shape))

In [None]:
# Display the preprocessed test dataset (NaNs filled with 1.0, mask set to 1).
test_data

In [None]:
# Stitch the per-batch predictions back into a (lat, lon, time) grid. Each
# element of `predictions` was reshaped to (lat, lon, output_sequence_length),
# so concatenating along axis 2 gives each grid cell's full series.
# No .squeeze(): it would drop any size-1 lat/lon/time axis and break the
# three dimension names below.
swe_pred = xr.DataArray(
    torch.concat(predictions, dim=2).cpu().numpy(),
    dims=('lat', 'lon', 'time'),
)
# Re-attach the non-time coordinates (lat/lon, ...) from the test data. The
# 'time' coordinate is left off because the prediction series is presumably
# offset from test_data's time axis. TODO confirm.
swe_pred = swe_pred.assign_coords(test_data.drop_vars('time').coords)

In [None]:
# Display the data config, including the batch_dims entry set earlier.
config['data_config']

In [None]:
# Number of leading test time steps that are skipped when comparing with the
# predictions. Presumably these steps are consumed as input context via the
# time overlap. TODO confirm.
overlap = config['data_config']['input_overlap']
start_time = overlap['time']

In [None]:
# Test data from `start_time` onwards. This is presumably the span the
# predictions cover; check its time length against swe_pred's.
test_data.isel(time=slice(start_time, None))

In [None]:
# Display the assembled (lat, lon, time) prediction DataArray.
swe_pred

In [None]:
# Compare observed and predicted SWE at the grid cell nearest to `loc`. Both
# series are plotted against the integer time index: the observations have
# their 'time' coordinate dropped, and swe_pred never had one.
loc = {'lat': 47, 'lon': 245}

# NOTE(review): the predictions are multiplied by this factor before
# plotting. The reason is not visible in this notebook. Confirm whether it
# corrects a known scaling bias or is a debugging leftover.
PRED_SCALE = 3

swe_obs_at_loc = (
    test_data['swe']
    .drop_vars('time')
    .isel(time=slice(start_time, None))
    .sel(**loc, method='nearest')
)
swe_obs_at_loc.plot(label='observed')
(PRED_SCALE * swe_pred).sel(**loc, method='nearest').plot(
    label=f'predicted x {PRED_SCALE}'
)
plt.legend()

In [None]:
# Map of predicted SWE averaged over time steps 120-399 of the prediction
# series.
pred_window_mean = swe_pred.isel(time=slice(120, 400)).mean(dim='time')
pred_window_mean.plot(vmin=0, vmax=0.1, cmap='turbo')

In [None]:
# Map of observed SWE over the same window as the prediction map. The test
# data is first aligned to the predictions by skipping `start_time` steps,
# then steps 120-399 are averaged.
swe_obs_aligned = test_data['swe'].isel(time=slice(start_time, None))
obs_window_mean = swe_obs_aligned.isel(time=slice(120, 400)).mean(dim='time')
obs_window_mean.plot(vmin=0, vmax=0.1, cmap='turbo')