In [1]:
import os
import sys
import json
import torch
import wandb 
from torchinfo import summary

# Resolve the sibling ``src`` directory (relative to the notebook's working
# directory) and put it on the import path once, so that the project modules
# imported below (loaddata, UNet, trainer, utils) can be found.
module_path = os.path.abspath(os.path.join(os.pardir, 'src'))
if not any(entry == module_path for entry in sys.path):
    sys.path.append(module_path)

from loaddata import load_cesm2_by_period
from UNet import UNet
from trainer import Trainer, EarlyStopper
from utils import DotDict

%load_ext autoreload
%autoreload 2

In [7]:
# Experiment name; outputs such as hyperparameters.json and history.csv are
# written under out_dir.
exp = 'test'
out_dir = f'../output/{exp}'
data_dir = '../data/cesm2'

# Global RNG seed so torch-driven randomness (weight initialisation, any
# DataLoader shuffling) repeats across kernel restarts; torch.manual_seed
# seeds the CPU and all CUDA generators. Some CUDA kernels may still be
# nondeterministic.
# Deliberately kept out of `hp`: `hp` is splatted into
# load_cesm2_by_period(data_dir, **hp), which may reject unknown keywords.
# TODO(review): also seed numpy/random if loaddata or trainer draw from them.
seed = 0
torch.manual_seed(seed)

# Hyperparameters (logged as the wandb config and saved to hyperparameters.json)
hp = {
    'vnames': ['sst', 'ssh', 'taux'],
    'lat_slice': (-50, 50),
    'target_vname': 'sst',
    'target_grid': '2x2',
    'target_lat_slice': (-10, 10),
    'target_lon_slice': (120, 290),
    # Other distance files used in earlier runs:
    # 'target_distance_0_-3_-6_-9.nc' and 'target_distance.nc'.
    't1_dist_f': 'target_distance_scaled_0_-3_-6_-9.nc',
    'lead': 12,
    'month': 1,
    'periods': {
        'train': (1865, 1958),
        'val': (1959, 1985),
        'test': (1986, 1998),
    },
    'batch_size': 16,
    'learning_rate': 1e-5,
    'model': 'UNet',
    'attention': False,
    'is_res': False,
    'depth': 4,
    'init_ch': 256,
    'n_sub': 188,
    'n_analog': 30,
    'n_epochs': 60,
}
# Wrap for attribute access (hp.lead, hp.n_sub, ...) used in later cells.
hp = DotDict(hp)

In [8]:
%%time
# Load the CESM2 data split by the train/val/test periods in `hp`. Besides the
# per-split datasets and dataloaders this returns the t0 library, its mask and
# the t1 library, which the trainer and test_uniform consume directly.
datasets, dataloaders, t0_library, t0_mask, t1_library = load_cesm2_by_period(
    data_dir, **hp
)

# Sanity-check the library sizes (one shape per line).
print(t0_library.shape, t1_library.shape, sep='\n')

torch.Size([9400, 3, 21, 72])
torch.Size([9400, 11, 86])
CPU times: user 2.37 s, sys: 13.7 s, total: 16.1 s
Wall time: 1min 15s


In [9]:
# dimension
x, _, _ = datasets['train'][0]
x_shape = tuple(x.shape)
n_channels = x.shape[0]
print(x_shape)

# Device
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Using {device} device')
t0_library = t0_library.to(device)
t0_mask = t0_mask.to(device)
t1_library = t1_library.to(device)

(3, 21, 72)
Using cuda device


# Define a neural network

In [10]:
# U-Net whose output has the same number of channels as its input.
model = UNet(
    in_ch=n_channels,
    out_ch=n_channels,
    init_ch=hp.init_ch,
    depth=hp.depth,
    attention=hp.attention,
    is_res=hp.is_res,
).to(device)

# Trainable parameter count (parameters with requires_grad set).
total_params = sum(
    param.numel() for param in model.parameters() if param.requires_grad
)

# Layer-by-layer summary for one batch of shape (batch_size, *x_shape).
model_stats = summary(model, input_size=(hp.batch_size, *x_shape))
print(model_stats)
print(f'Total params: {total_params:,}')

# Record the size with the other hyperparameters so it is part of the
# wandb config and hyperparameters.json written by the training cell.
hp['total_params'] = total_params

Layer (type:depth-idx)                             Output Shape              Param #
UNet                                               [16, 3, 21, 72]           --
├─Encoder: 1-1                                     [16, 256, 21, 72]         --
│    └─ModuleList: 2-7                             --                        (recursive)
│    │    └─Down: 3-1                              [16, 256, 21, 72]         598,272
│    └─MaxPool2d: 2-2                              [16, 256, 10, 36]         --
│    └─ModuleList: 2-7                             --                        (recursive)
│    │    └─Down: 3-2                              [16, 512, 10, 36]         3,542,016
│    └─MaxPool2d: 2-4                              [16, 512, 5, 18]          --
│    └─ModuleList: 2-7                             --                        (recursive)
│    │    └─Down: 3-3                              [16, 1024, 5, 18]         14,161,920
│    └─MaxPool2d: 2-6                              [16, 1024, 2, 9] 

# Train the model

In [11]:
%%time
# Start a wandb run for this experiment; `hp` (including total_params) is
# logged as the run config.
wandb.init(
    project='MA-UNet', config=hp,
    name=exp,
)

try:
    # Save hyperparameters next to the other outputs of this experiment.
    os.makedirs(out_dir, exist_ok=True)
    with open(f'{out_dir}/hyperparameters.json', 'w') as f:
        json.dump(hp, f)

    # Train. EarlyStopper also receives out_dir (presumably to write the
    # best checkpoint there -- TODO confirm in trainer.py).
    optimizer = torch.optim.Adam(model.parameters(), lr=hp.learning_rate)
    early_stopper = EarlyStopper(patience=10, min_delta=0.01, out_dir=out_dir)
    trainer = Trainer(
        hp.n_epochs, model, device, optimizer, dataloaders,
        t0_library, t0_mask, hp.n_sub, t1_library, hp.n_analog,
        early_stopper,
    )
    history = trainer()

    # Save the per-epoch history, cast to float before writing.
    history.astype('float').to_csv(f'{out_dir}/history.csv')
except BaseException:
    # Any failure -- including a KeyboardInterrupt during the long training
    # loop -- would otherwise skip wandb.finish() and leave the run open.
    # Close it explicitly as failed, then propagate the error.
    wandb.finish(exit_code=1)
    raise
else:
    # Close wandb run
    wandb.finish()

Epoch   0, train: 0.00808,   0.576, val:  0.0077,   0.601
Epoch   1, train: 0.00732,   0.559, val:  0.0075,   0.599
Epoch   2, train: 0.00694,   0.549, val: 0.00719,   0.593
Epoch   3, train: 0.00652,   0.539, val: 0.00693,   0.587
Epoch   4, train: 0.00616,   0.527, val: 0.00672,   0.584
Epoch   5, train: 0.00581,   0.514, val: 0.00706,   0.573
Epoch   6, train: 0.00547,   0.495, val:  0.0064,   0.578
Epoch   7, train: 0.00511,   0.470, val: 0.00631,   0.567
Epoch   8, train: 0.00479,   0.443, val: 0.00628,   0.563
Epoch   9, train: 0.00441,   0.408, val: 0.00608,   0.559
Epoch  10, train: 0.00411,   0.378, val: 0.00606,   0.556
Epoch  11, train: 0.00377,   0.344, val: 0.00595,   0.546
Epoch  12, train: 0.00348,   0.316, val: 0.00596,   0.546
Epoch  13, train: 0.00321,   0.291, val: 0.00594,   0.539
Epoch  14, train:   0.003,   0.272, val: 0.00617,   0.535
Epoch  15, train: 0.00282,   0.257, val: 0.00597,   0.528
Epoch  16, train: 0.00265,   0.242, val: 0.00589,   0.528
Epoch  17, tra

VBox(children=(Label(value='0.011 MB of 0.011 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
train_loss,█▇▆▆▅▅▄▄▃▃▃▃▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_mse,██▇▇▇▆▅▅▄▃▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_loss,█▇▅▅▄▃▂▂▂▂▂▂▁▁▂▁▁▄▁▂▂▁▂▁▂▂▁▂▂▂▂▂▂▂▂▂▃▃▃▃
val_mse,██▇▇▆▅▅▄▄▃▂▂▂▃▂▂▂▂▂▁▁▂▁▂▁▂▂▁▁▁▂▁▁▂▁▁▁▁▁▁

0,1
best_epoch,51.0
best_mse,0.51367
train_loss,0.00102
train_mse,0.1368
val_loss,0.00616
val_mse,0.5176


CPU times: user 1h 1min 36s, sys: 10.1 s, total: 1h 1min 47s
Wall time: 1h 2min 38s


# Test with uniform weights

In [6]:
from testing import test_uniform

In [32]:
# Evaluate every split with test_uniform and print: loss, mean_mse, mse.
# NOTE(review): `model` holds whatever weights the Trainer left in it. The
# EarlyStopper was given out_dir, so the best-epoch weights may only exist as
# a checkpoint there -- confirm whether Trainer restores them before trusting
# these numbers.
model.eval()

# eval() only switches layer behaviour (e.g. dropout / batch norm); no_grad()
# additionally stops autograd from recording these forward passes, which are
# never backpropagated here.
with torch.no_grad():
    for key, dataloader in dataloaders.items():
        print(key)

        # insample is True only for the training split (presumably so
        # test_uniform can treat samples that are also in the library
        # specially -- TODO confirm in testing.py).
        loss, mean_mse, mse = test_uniform(
            model, device, dataloader,
            t0_library, t0_mask, hp.n_sub, t1_library,
            n_analog=hp.n_analog, insample=(key == 'train'),
        )

        print(f'{loss:7.3g}, {mean_mse:7.3f}, {mse:7.3f}')

train
0.00584,   0.794,   0.369
val
 0.0059,   0.834,   0.416
test
0.00601,   0.866,   0.449
