In [1]:
import json
import os
import sys

import numpy as np
import torch
import xarray as xr

# Make the project's ../src directory (resolved against the kernel's current
# working directory) importable, registering it at most once.
module_path = os.path.abspath('../src')
if module_path not in sys.path:
    sys.path.append(module_path)

from loaddata import load_cesm2_by_period
from UNet import UNet
from testing import test
from utils import DotDict
from analog import *

%load_ext autoreload
%autoreload 2

In [2]:
# Parameters
exp = 'test'                  # experiment name; selects the ../output/<exp> directory
out_dir = f'../output/{exp}'
data_dir = '../data/cesm2'
test_data = 'test'            # split key used for hp.periods / datasets / dataloaders

# Hyperparameters written by the training run of this experiment.
with open(f'{out_dir}/hyperparameters.json', 'r') as f:
    hp = DotDict(json.load(f))

# The analog library always comes from the training period; targets come
# from whichever split is being evaluated.
periods = {'library': hp.periods['train'], 'target': hp.periods[test_data]}

# Evaluating on the training split itself is flagged as in-sample.
insample = test_data == 'train'

In [3]:
# Load datasets/dataloaders and the analog-library tensors for the U-Net.
# NOTE(review): shuffle=False presumably keeps loader order equal to dataset
# order, which the 'sample' coordinates attached in the evaluation loop rely on.
(datasets, dataloaders,
 t0_library, t0_mask,
 t1_library) = load_cesm2_by_period(data_dir, **hp, shuffle=False)

# Axis 0 of the first training sample's input is used as the channel count
# (it sizes both in_ch and out_ch of the U-Net).
x, _, _ = datasets['train'][0]
n_channels = x.shape[0]

# Pick the compute device, then move the library and mask tensors onto it once.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
print(f'Using {device} device')
t0_library, t0_mask, t1_library = (
    tensor.to(device) for tensor in (t0_library, t0_mask, t1_library)
)

Using cuda device


In [4]:
# Evaluate every saved checkpoint on the `test_data` split and write the
# selected analog indices and per-variable weights to netCDF, one file per epoch.

# Number of leading analog columns of `ma_idx` written to disk.
N_ANALOG_SAVE = 100

# Land mask as a numpy index. It does not change across epochs, so convert it once.
land_mask = t0_mask.detach().cpu().numpy()

for epoch in range(hp.n_epochs):
    ckpt_path = f'{out_dir}/model_epoch{epoch}.pt'
    # Not every epoch has a checkpoint on disk; skip the missing ones.
    if not os.path.exists(ckpt_path):
        continue

    # Load model. map_location lets a checkpoint saved on a GPU be restored on
    # a CPU-only kernel (the device fallback chosen above).
    # torch.load unpickles the file: only load trusted checkpoints.
    model = UNet(
        in_ch=n_channels, out_ch=n_channels,
        init_ch=hp.init_ch, depth=hp.depth,
        attention=hp.attention, is_res=hp.is_res,
    ).to(device)
    model.load_state_dict(torch.load(ckpt_path, map_location=device))
    model.eval()

    # Test acc, get analog indices
    loss, acc, mse, ma_idx, weights = test(
        model, device, dataloaders[test_data],
        t0_library, t0_mask, hp.n_sub, t1_library,
        n_analog=hp.n_analog, insample=insample,
    )
    print(f'{exp} epoch {epoch:3d}: M_MSE = {acc:.3f}, MSE = {mse:.3f}')

    # Mask land: masked points are set to NaN for every entry along axis 0 (samples).
    weights[:, land_mask] = np.nan

    # To xarray, labelled with the evaluated dataset's coordinates.
    eval_ds = datasets[test_data]
    ma_idx = xr.DataArray(
        ma_idx[:, :N_ANALOG_SAVE], dims=['sample', 'analog'],
        coords={'sample': eval_ds.t0_ds['sample']},
    )
    # NOTE(review): the variable count comes from hp.vnames but the names from
    # the dataset's vnames; this assumes the two lists agree. Confirm.
    weights = xr.Dataset(
        {eval_ds.vnames[i]: (('sample', 'lat', 'lon'), weights[:, i])
         for i in range(len(hp.vnames))},
        coords={dim: eval_ds.t0_ds[dim] for dim in ['sample', 'lat', 'lon']},
    )

    # Save. NOTE(review): the float round-trip presumably lets unstack fill
    # gaps with NaN. Any NaN would not survive the cast back to int, so confirm
    # the unstacked grid is complete.
    ma_idx.astype(float).unstack().astype(int).to_netcdf(
        f'{out_dir}/ma_index_{test_data}_epoch{epoch}.nc')
    weights.unstack().to_netcdf(f'{out_dir}/weight_{test_data}_epoch{epoch}.nc')

test epoch   9: M_MSE = 0.907, MSE = 0.578
test epoch  19: M_MSE = 0.847, MSE = 0.551
test epoch  29: M_MSE = 0.811, MSE = 0.537
test epoch  39: M_MSE = 0.803, MSE = 0.544
test epoch  49: M_MSE = 0.797, MSE = 0.538
test epoch  51: M_MSE = 0.793, MSE = 0.536
test epoch  59: M_MSE = 0.793, MSE = 0.538
