In [9]:
import os
import sys
import json
import torch
import wandb 
from torchinfo import summary

# Put the sibling `src` directory on the import path (resolved relative to the
# current working directory) before the project-local imports that follow.
# The membership check keeps re-runs of this cell from adding it twice.
module_path = os.path.abspath('../src')
if sys.path.count(module_path) == 0:
    sys.path.append(module_path)

from ml_only.loaddata import load_cesm2_by_period
from ml_only.UNet_dynamic import UNet
from ml_only.trainer import Trainer, EarlyStopper
from utils import DotDict

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [10]:
# Experiment name and the locations derived from it
exp = 'ML_test'
out_dir = f'../output/{exp}'   # run artefacts are written here
data_dir = '../data/cesm2'     # passed to the data loader

# Seed PyTorch's global RNG so that weight initialisation and any
# torch-driven shuffling start from the same state on every run.
# The seed is deliberately kept outside `hp`: `hp` is unpacked as keyword
# arguments into the data loader, and adding a key could change that call.
# NOTE(review): this does not make CUDA kernels deterministic by itself.
seed = 42
torch.manual_seed(seed)

# Hyperparameters. The mapping is built directly as a DotDict so that `hp`
# never changes type; later cells read entries both as attributes
# (e.g. `hp.depth`) and as items.
hp = DotDict({
    # Data selection (consumed by the data loader via `**hp`)
    'vnames': ['sst', 'ssh', 'taux'],
    'lat_slice': (-50, 50),
    'target_vname': 'sst',
    'target_grid': '2x2',
    'target_lat_slice': (-10, 10),
    'target_lon_slice': (120, 290),
    'lead': 12,
    'month': 1,
    'periods': {
        'train': (1865, 1958),
        'val': (1959, 1985),
        'test': (1986, 1998),
    },
    # Optimisation
    'batch_size': 16,
    'learning_rate': 1e-5,
    # Architecture
    'model': 'UNet',
    'attention': False,
    'is_res': False,
    'depth': 4,
    'init_ch': 256,
    # Training length
    'n_epochs': 10,
})

In [33]:
%%time
# Load the data: `data_dir` is passed positionally and every entry of `hp` is
# forwarded as a keyword argument. The call returns three objects:
#   datasets    - indexable by split name ('train' is used below); each item
#                 unpacks into an (input, target) pair
#   dataloaders - handed to the Trainer
#   t1_wgt      - a tensor-like object that is moved to the compute device and
#                 handed to the Trainer (presumably a loss weighting - TODO confirm)
# NOTE(review): the model cell later adds 'total_params' to `hp`, so re-running
# this cell after that one forwards the extra key as well - confirm the loader
# tolerates unknown keyword arguments.
datasets, dataloaders, t1_wgt = load_cesm2_by_period(data_dir, **hp)

CPU times: user 255 ms, sys: 865 ms, total: 1.12 s
Wall time: 1.1 s


In [34]:
# Take the first training sample to read off the per-sample shapes
x, y = datasets['train'][0]
x_shape, y_shape = tuple(x.shape), tuple(y.shape)
n_channels = x_shape[0]  # size of the input's leading axis
print(f'Input shape = {x_shape}')
print(f'Output shape = {y_shape}')

# Prefer the GPU when one is visible to PyTorch, otherwise run on the CPU
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f'Using {device} device')

# Keep the weights on the same device as the model
t1_wgt = t1_wgt.to(device)

Input shape = (3, 21, 72)
Output shape = (11, 86)
Using cuda device


# Define a neural network

In [35]:
# Build the U-Net from the hyperparameters and the sample shapes measured
# above, then move it to the selected device.
model = UNet(
    in_ch=n_channels,
    out_ch=1,
    init_ch=hp.init_ch,
    depth=hp.depth,
    in_shape=x_shape,
    out_shape=y_shape,
    attention=hp.attention,
    is_res=hp.is_res,
)
model = model.to(device)

# Layer-by-layer summary for one batch-sized input
batched_input_size = (hp.batch_size,) + x_shape
model_stats = summary(model, input_size=batched_input_size)

# Only parameters with requires_grad=True are counted, i.e. the trainable ones
total_params = sum(
    map(torch.Tensor.numel, filter(lambda p: p.requires_grad, model.parameters()))
)

print(model_stats)
print(f'Total params: {total_params:,}')

# Record the count alongside the other hyperparameters
hp['total_params'] = total_params

# Compile model (torch 2.x)
# model = torch.compile(model)

Layer (type:depth-idx)                             Output Shape              Param #
UNet                                               [16, 11, 86]              --
├─Encoder: 1-1                                     [16, 256, 21, 86]         --
│    └─ModuleList: 2-7                             --                        (recursive)
│    │    └─Down: 3-1                              [16, 256, 21, 86]         598,272
│    └─MaxPool2d: 2-2                              [16, 256, 10, 43]         --
│    └─ModuleList: 2-7                             --                        (recursive)
│    │    └─Down: 3-2                              [16, 512, 10, 43]         3,542,016
│    └─MaxPool2d: 2-4                              [16, 512, 5, 21]          --
│    └─ModuleList: 2-7                             --                        (recursive)
│    │    └─Down: 3-3                              [16, 1024, 5, 21]         14,161,920
│    └─MaxPool2d: 2-6                              [16, 1024, 2, 10]

# Train the model

In [36]:
%%time
# Initialize wandb
wandb.init(
    project='ML-ENSO', config=hp,
    name=exp,
)

# Save hyperparameters
os.makedirs(out_dir, exist_ok=True)
# torch.save(hp, f'{out_dir}/hyperparameters.pt')
with open(f'{out_dir}/hyperparameters.json', 'w') as f:
    json.dump(hp, f)

# Train
optimizer = torch.optim.Adam(model.parameters(), lr=hp.learning_rate) 
early_stopper = EarlyStopper(patience=10, min_delta=0.01, out_dir=out_dir)
trainer = Trainer(
    hp.n_epochs, model, device, optimizer, dataloaders,
    t1_wgt, early_stopper,
)
history = trainer()

# Save history
history.astype('float').to_csv(f'{out_dir}/history.csv')

# Close wandb run
wandb.finish()

0,1
train_loss,
val_loss,


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01666834522038698, max=1.0)…

Epoch   0, train:   0.654, val:   0.624
Epoch   1, train:   0.507, val:   0.566
Epoch   2, train:   0.426, val:    0.57
Epoch   3, train:   0.324, val:    0.67
Epoch   4, train:   0.233, val:   0.579
Epoch   5, train:   0.189, val:   0.682
Epoch   6, train:   0.157, val:    0.56
Epoch   7, train:   0.137, val:   0.648
Epoch   8, train:   0.128, val:   0.634
Epoch   9, train:   0.113, val:    0.59


0,1
train_loss,█▆▅▄▃▂▂▁▁▁
val_loss,▅▁▂▇▂█▁▆▅▃

0,1
best_epoch,6.0
best_mse,0.56008
train_loss,0.11349
val_loss,0.59042


CPU times: user 6min 4s, sys: 663 ms, total: 6min 5s
Wall time: 6min 24s
