In [1]:
base_dir = '/glade/work/kjmayer/research/catalyst/S2S_ocn_lnd_atm/'

In [2]:
import xarray as xr
import numpy as np
import random

import torch
import torchinfo
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

from model.train_utils import NeuralNetwork
import utils.utils
from utils.utils import get_config
from utils.utils import prepare_device
from trainer.trainer import Trainer
import model.metrics as module_metric
from data_prep.data_loader import GetData,MakeDataset,lead_shift, concat_input

In [3]:
# torch.cuda.is_available()

In [7]:
config = get_config("exp_1")

torch.manual_seed(config["seed"])
torch.cuda.manual_seed(config["seed"])
np.random.seed(config["seed"])
random.seed(config["seed"])
torch.backends.cudnn.deterministic = True

In [8]:
LEAD = 7 # will loop over this for training eventually
trainfinames = config["data_loader"]["anommems_finames"][0:6]
valfinames = config["data_loader"]["anommems_finames"][6:8]

xtrain, xtrainmean, xtrainstd, xtrainmin, xtrainmax = GetData(dir=config["data_loader"]["base_dir"],
                                                              var=config["data_loader"]["atm_var"],
                                                              finames=trainfinames,
                                                              train=True,
                                                              climo=False)[0]
# xtrain = xtrain.stack(l=('lat','lon'))
xtrain_shift = lead_shift(xtrain, lead=LEAD, forward=False)
ytrain = xtrain.sel(lat=slice(29,61))
ytrain = ytrain.stack(l=('lat','lon'))
ytrain_shift = lead_shift(ytrain, lead=LEAD, forward=True)

xval = GetData(dir=config["data_loader"]["base_dir"],
               var=config["data_loader"]["atm_var"],
               finames=valfinames,
               train=False,
               trainmean=xtrainmean,
               trainstd=xtrainstd,
               trainmin=xtrainmin,
               trainmax=xtrainmax,
               climo=False)[0]
# 
xval_shift = lead_shift(xval, lead=LEAD, forward=False)
yval = xval.sel(lat=slice(29,61))
yval = yval.stack(l=('lat','lon'))
yval_shift = lead_shift(yval, lead=LEAD, forward=True)


# climo (same for train, val, and test --> basically a DOY encoder)
xclimo, climomin, climomax = GetData(dir = config["data_loader"]["base_dir"],
                                     var = config["data_loader"]["atm_var"],
                                     finames = config["data_loader"]["climo_finame"],
                                     train = False, # MUST ALWAYS BE False FOR CLIMO
                                     climo = True)[0]
# climo appended to same length as training
xclimo = xclimo.rename({'dayofyear': 'time'})
xclimo_train = xr.concat([xclimo]*int(xtrain.shape[0]/365),dim='mem')
xclimo_train = xclimo_train.stack(s=('mem','time')).transpose('s', 'lat', 'lon').reset_index(['s'])
xclimo_train_shift = lead_shift(xclimo_train,lead=LEAD,forward=False)

# climo appended to same length as validation
xclimo_val = xr.concat([xclimo]*int(xval.shape[0]/365),dim='mem')
xclimo_val = xclimo_val.stack(s=('mem','time')).transpose('s', 'lat', 'lon').reset_index(['s'])
xclimo_val_shift = lead_shift(xclimo_val,lead=LEAD,forward=False)

    >>> with dask.config.set(**{'array.slicing.split_large_chunks': False}):
    ...     array.reshape(shape)

To avoid creating the large chunks, set the option
    >>> with dask.config.set(**{'array.slicing.split_large_chunks': True}):
    >>> array.reshape(shape, limit='128 MiB')
  result = result._stack_once(dims, new_dim)
    >>> with dask.config.set(**{'array.slicing.split_large_chunks': False}):
    ...     array.reshape(shape)

To avoid creating the large chunks, set the option
    >>> with dask.config.set(**{'array.slicing.split_large_chunks': True}):
    >>> array.reshape(shape, limit='128 MiB')
  result = result._stack_once(dims, new_dim)
    >>> with dask.config.set(**{'array.slicing.split_large_chunks': False}):
    ...     array.reshape(shape)

To avoid creating the large chunks, set the option
    >>> with dask.config.set(**{'array.slicing.split_large_chunks': True}):
    >>> array.reshape(shape, limit='128 MiB')
  result = result._stack_once(dims, new_dim)
    >>> with 

In [9]:
Xtrain = concat_input(xtrain_shift,xclimo_train_shift,dim_name='features').values
del xtrain_shift,xclimo_train_shift
Xval = concat_input(xval_shift,xclimo_val_shift,dim_name='features').values
del xval_shift,xclimo_val_shift
Ytrain = ytrain_shift.values
del ytrain_shift
Yval = yval_shift.values
del yval_shift

In [10]:
## Prep training and validation for ANN
training_data = MakeDataset(Xtrain,Ytrain) #xtrain, ytrain need to be numpy, not xarray
val_data = MakeDataset(Xval,Yval)

#[batch_size,channels,lat,lon]
train_dataloader = DataLoader(training_data,batch_size = config["data_loader"]["batch_size"],shuffle=True)
val_dataloader  = DataLoader(val_data,batch_size = config["data_loader"]["batch_size"],shuffle=True)

  self.X = torch.tensor(torch.from_numpy(X), dtype = torch.float32)#.unsqueeze(1)
  self.y = torch.tensor(torch.from_numpy(X), dtype = torch.float32) # unsqueeze?


In [None]:
for input, output in val_dataloader:
    print(input)
    print(output)
    break

In [11]:
## create NN architecture
model = NeuralNetwork(config=config["arch_atm"])

## grab optimizer and loss 
optimizer = getattr(torch.optim, config["optimizer"]["type"])(
    model.parameters(), **config["optimizer"]["args"]
)
criterion = getattr(torch.nn, config["criterion"])()

metric_funcs = [getattr(module_metric, met) for met in config["metrics"]]

## Build the trainer
device = prepare_device(config["device"])
trainer = Trainer(
    model,
    criterion,
    metric_funcs,
    optimizer,
    max_epochs=config["trainer"]["max_epochs"],
    data=train_dataloader,
    validation_data=val_dataloader,
    device=device,
    config=config,
)

# Visualize the model
torchinfo.summary(
    model,
    input_size=(config["data_loader"]["batch_size"], 2448),
    verbose=1,
    col_names=("input_size", "output_size", "num_params"),
)

# Train the Model
# model.to(device)
# trainer.fit()

AttributeError: 'Flatten' object has no attribute 'size'

In [None]:
print(trainer.log.history.keys())

plt.figure(figsize=(20, 4))
for i, m in enumerate(("loss", *config["metrics"])):
    plt.subplot(1, 4, i + 1)
    plt.plot(trainer.log.history["epoch"], trainer.log.history[m], label=m)
    plt.plot(
        trainer.log.history["epoch"], trainer.log.history["val_" + m], label="val_" + m
    )
    plt.axvline(
        x=trainer.early_stopper.best_epoch, linestyle="--", color="k", linewidth=0.75
    )
    plt.title(m)
    plt.legend()
plt.tight_layout()
plt.show()

In [None]:
model.eval()
preds = model(x_test)