In [1]:
# Absolute base path for the S2S_ocn_lnd_atm work tree; this is machine-specific,
# so adjust it when running the notebook on another system.
base_dir = "/glade/work/kjmayer/research/catalyst/S2S_ocn_lnd_atm/"

In [3]:
import random

import matplotlib.pyplot as plt
import numpy
import torch
import torch.nn.functional as F
import torchinfo
import xarray
from torch.utils.data import DataLoader, TensorDataset

from model.train_utils import NeuralNetwork
import utils.utils
from trainer.trainer import Trainer
import model.metrics as module_metric
from data_prep.data_loader import MakeDataset

In [None]:
# Display whether PyTorch can see a CUDA device in this kernel.
torch.cuda.is_available()

In [None]:
## Load the experiment configuration and seed every RNG in use
# NOTE(review): only `import utils.utils` is visible in the import cell, so this
# call relies on the `utils` package itself exposing `get_config` — confirm.
config = utils.get_config("exp_foo")

seed = config["seed"]

# Seed Python, NumPy and PyTorch (CPU + CUDA) from the same configured value.
# The file imports NumPy as `numpy` (no `np` alias), so the full name is used here.
random.seed(seed)
numpy.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

# Ask cuDNN to use deterministic algorithms.
torch.backends.cudnn.deterministic = True

In [4]:
## TODO: load the data here — the cells below expect xtrain, ytrain, xval, yval and x_test to be defined



In [None]:
## Prep training and validation for ANN
# !!!NOTE: currently assuming that xtrain, ytrain, xval and yval are numpy!!!
batch_size = config["data_loader"]["batch_size"]

# Wrap both splits with the same dataset class. TensorDataset only accepts
# torch tensors, so handing it the numpy arrays assumed above would fail at
# construction time.
# NOTE(review): assumes MakeDataset treats the validation arrays the same way
# it treats the training arrays — confirm against data_prep.data_loader.
training_data = MakeDataset(xtrain, ytrain)
val_data = MakeDataset(xval, yval)

train_dataloader = DataLoader(training_data, batch_size=batch_size, shuffle=True)
# Sample order has no benefit for validation; keep it fixed so evaluation is repeatable.
val_dataloader = DataLoader(val_data, batch_size=batch_size, shuffle=False)

In [None]:
## Resolve the compute device first
# NOTE(review): only `import utils.utils` is visible in the import cell, so this
# call relies on the `utils` package itself exposing `prepare_device` — confirm.
device = utils.prepare_device(config["device"])

## create NN architecture
model = NeuralNetwork(config=config["arch"])
# Move the model before the optimizer and trainer are built: the optimizer class
# is chosen from config, and optimizers that allocate per-parameter state at
# construction time would otherwise hold that state on the wrong device.
model.to(device)

## grab optimizer and loss
optimizer_cls = getattr(torch.optim, config["optimizer"]["type"])
optimizer = optimizer_cls(model.parameters(), **config["optimizer"]["args"])
criterion = getattr(torch.nn, config["criterion"])()

# Look up each configured metric by name in model.metrics.
metric_funcs = [getattr(module_metric, met) for met in config["metrics"]]

## Build the trainer
trainer = Trainer(
    model,
    criterion,
    metric_funcs,
    optimizer,
    max_epochs=config["trainer"]["max_epochs"],
    data=train_dataloader,
    validation_data=val_dataloader,
    device=device,
    config=config,
)

# Visualize the model
# NOTE(review): input_size assumes the network takes a single input feature
# per sample — confirm against config["arch"].
torchinfo.summary(
    model,
    input_size=(config["data_loader"]["batch_size"], 1),
    verbose=1,
    col_names=("input_size", "output_size", "num_params"),
)

# Train the Model
trainer.fit()

In [None]:
## Plot training vs. validation curves for the loss and every configured metric
print(trainer.log.history.keys())

history = trainer.log.history
curve_names = ("loss", *config["metrics"])

# The layout used to be fixed at four panels, which raises a ValueError as soon
# as more than three metrics are configured. Grow the grid when needed, but keep
# four columns as the minimum so existing figures look unchanged.
n_cols = max(4, len(curve_names))

plt.figure(figsize=(20, 4))
for panel, name in enumerate(curve_names, start=1):
    plt.subplot(1, n_cols, panel)
    plt.plot(history["epoch"], history[name], label=name)
    plt.plot(history["epoch"], history["val_" + name], label="val_" + name)
    # Dashed vertical line at the epoch the early stopper recorded as best.
    plt.axvline(
        x=trainer.early_stopper.best_epoch, linestyle="--", color="k", linewidth=0.75
    )
    plt.title(name)
    plt.legend()
plt.tight_layout()
plt.show()

In [None]:
## Run the trained model on the test inputs
model.eval()
# NOTE(review): x_test is not defined in any visible cell; presumably it is a
# tensor on the same device as the model — confirm before running.
# Inference does not need gradients, so skip building the autograd graph.
with torch.no_grad():
    preds = model(x_test)