In [1]:
import json
import os
import sys

import numpy as np
import torch
import xarray as xr

module_path = os.path.abspath(os.path.join('../src'))
if module_path not in sys.path:
    sys.path.append(module_path)

from loaddata import load_cesm2_by_period, load_real
from UNet import UNet
from testing import test
from utils import DotDict
from analog import *

%load_ext autoreload
%autoreload 2

In [2]:
# Parameters
exp = 'test'                      # experiment name; checkpoints/outputs live in out_dir
out_dir = f'../output/{exp}'
library_dir = '../data/cesm2'     # data passed to load_cesm2_by_period (analog library)

data_dir = '../data/real'         # data passed to load_real (evaluation data)
test_data = 'real'                # tag used in the saved output file names
t1_dist_f = 'target_distance.nc'  # passed to load_real; the library load uses hp.t1_dist_f instead
test_period = (1987, 2020)        # stored as periods['test'] before calling load_real

# Hyperparameters saved at training time. DotDict allows attribute access
# (hp.vnames, hp.lead, ...) as used in the cells below. JSON is UTF-8 by spec,
# so decode explicitly instead of relying on the platform's locale encoding.
with open(f'{out_dir}/hyperparameters.json', 'r', encoding='utf-8') as f:
    hp = DotDict(json.load(f))

# Load data for the UNet: CESM2 analog library and real test data

In [3]:
# Analog library: loaded for the training period only and in a fixed order
# (shuffle=False). The first two return values are not needed here.
periods = {'train': hp.periods['train']}
_, _, t0_library, t0_mask_library, t1_library = load_cesm2_by_period(
    library_dir, hp.vnames, hp.lat_slice,
    hp.target_vname, hp.target_grid, hp.target_lat_slice, hp.target_lon_slice,
    hp.t1_dist_f, hp.lead, hp.month, periods, hp.batch_size,
    shuffle=False,
)

# Real data: the same `periods` dict is extended with the test period, so
# load_real receives both the 'train' and the 'test' entries.
periods.update(test=test_period)
dataset, dataloader, t0_mask_data = load_real(
    data_dir, hp.vnames, hp.lat_slice,
    hp.target_vname, hp.target_grid, hp.target_lat_slice, hp.target_lon_slice,
    t1_dist_f, hp.lead, hp.month, periods, hp.batch_size,
)

# A grid point is masked if either source masks it (logical OR), which leaves
# only the grid points where both the library and the real data exist.
t0_mask = t0_mask_library | t0_mask_data

# Channel count, taken from the first element returned by dataset[0].
x, _, _ = dataset[0]
n_channels = x.shape[0]

# Pick a device and move the library tensors and the combined mask onto it.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f'Using {device} device')
t0_library, t0_mask, t1_library = [
    tensor.to(device) for tensor in (t0_library, t0_mask, t1_library)
]

Using cuda device


# Evaluate saved checkpoints on the real test data

In [4]:
# Number of leading analog indices per sample written to the ma_index files.
n_analog_save = 100

# t0_mask does not change across epochs; convert it to a NumPy boolean index once.
t0_mask_np = t0_mask.cpu().detach().numpy()

for epoch in range(hp.n_epochs):
    # Only epochs with a saved checkpoint are evaluated.
    ckpt_path = f'{out_dir}/model_epoch{epoch}.pt'
    if not os.path.exists(ckpt_path):
        continue

    # Rebuild the architecture from the saved hyperparameters and load the weights.
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine
    # (and vice versa) instead of failing inside torch.load.
    model = UNet(
        in_ch=n_channels, out_ch=n_channels,
        init_ch=hp.init_ch, depth=hp.depth,
        attention=hp.attention, is_res=hp.is_res,
    ).to(device)
    model.load_state_dict(torch.load(ckpt_path, map_location=device))
    model.eval()

    # Evaluate on the real data; also returns analog indices and per-grid weights.
    loss, acc, mse, ma_idx, weights = test(
        model, device, dataloader,
        t0_library, t0_mask, hp.n_sub, t1_library,
        n_analog=hp.n_analog, insample=False,
    )
    print(f'{exp} epoch {epoch:3d}: M_MSE = {acc:.3f}, MSE = {mse:.3f}')

    # Set masked grid points (land / missing in either source) to NaN.
    weights[:, t0_mask_np] = np.nan

    # Wrap results in xarray: analog indices per year, and one (year, lat, lon)
    # weight map per input variable.
    ma_idx_da = xr.DataArray(
        ma_idx[:, :n_analog_save], dims=['sample', 'analog'],
        coords={'sample': dataset.t0_ds.sample},
    ).rename({'sample': 'year'})
    weights_ds = xr.Dataset(
        {
            dataset.vnames[i]: (('year', 'lat', 'lon'), weights[:, i])
            for i in range(len(hp.vnames))
        },
        coords=dataset.t0_ds.coords,
    ).rename({'sample': 'year'})

    # Save one file pair per evaluated epoch.
    ma_idx_da.to_netcdf(f'{out_dir}/ma_index_{test_data}_epoch{epoch}.nc')
    weights_ds.to_netcdf(f'{out_dir}/weight_{test_data}_epoch{epoch}.nc')

test epoch   9: M_MSE = 0.892, MSE = 0.374
test epoch  19: M_MSE = 0.861, MSE = 0.411
test epoch  29: M_MSE = 0.845, MSE = 0.421
test epoch  39: M_MSE = 0.819, MSE = 0.418
test epoch  49: M_MSE = 0.825, MSE = 0.417
test epoch  51: M_MSE = 0.786, MSE = 0.392
test epoch  59: M_MSE = 0.813, MSE = 0.414
