In [None]:
# Switch the kernel's working directory to the project checkout — presumably so
# that the local `pipeline` and `models` modules imported in the next cell
# resolve (TODO confirm; a relative/configurable path would be more portable).
%cd ~/perpetual_day/

In [None]:
import os
from pathlib import Path
from shutil import rmtree

import torch
import lightning as L
import pyearthtools.pipeline as petpipe
import pyearthtools.training as pettrain
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

from pipeline import full_pipeline
from models import CNN

In [None]:
# --- Experiment configuration -------------------------------------------------
# NOTE(review): absolute, user-specific scratch path; consider making this
# configurable before sharing the notebook.
cachedir = "/scratch/nf33/mr3857/cache"

# Spatial extent handed to the pipeline builder.
# Presumably [lat_min, lat_max, lon_min, lon_max] in degrees — TODO confirm
# against `pipeline.full_pipeline`.
bbox = [-35, -34, 150, 151]

# Timestamps to iterate over: 2020-01-01 00:00 to 02:00 in 30-minute steps.
date_range = petpipe.iterators.DateRange(
    '20200101T0000',
    '20200101T0200',
    interval='30 minutes',
)

In [None]:
# Assemble the data pipeline from the project-local builder using the
# configuration defined above.
# clean_cache=False: presumably keeps any existing cache in `cachedir`
# instead of wiping it — verify in `pipeline.full_pipeline`.
fullpipe = full_pipeline(
    date_range,
    bbox,
    cachedir,
    clean_cache=False,
)

In [None]:
# Randomly hold out 20% of the pipeline's samples for validation. The fixed
# random_state keeps the partition identical across re-runs.
train_dates, valid_dates = train_test_split(
    fullpipe.iterator.samples, test_size=0.2, shuffle=True, random_state=42
)

# Wrap each list of samples in a Predefined iterator for the data module.
train_split, valid_split = (
    petpipe.iterators.Predefined(dates) for dates in (train_dates, valid_dates)
)

# Lightning data module that serves batches from the pipeline.
# The loader-style options (workers, forkserver context, persistent workers)
# are presumably forwarded to the underlying DataLoader — TODO confirm.
dm = pettrain.data.lightning.PipelineLightningDataModule(
    fullpipe,
    valid_split=valid_split,
    train_split=train_split,
    shuffle=True,
    batch_size=4,
    num_workers=8,
    persistent_workers=True,
    multiprocessing_context="forkserver",
)

In [None]:
%%time
# Seed Python, NumPy and torch (and the dataloader workers) before the model is
# built, so weight initialisation and batch shuffling are repeatable. The
# train/validation split is already seeded with the same value (42).
L.seed_everything(42, workers=True)

# CNN taking 9 input channels and producing 1 output channel.
model = CNN(chan_in=9, chan_out=1)

# Single-epoch run with 16-bit mixed precision.
trainer = L.Trainer(max_epochs=1, precision="16-mixed")
trainer.fit(model, dm)

In [None]:
def plot_preds(target, pred):
    """Show a target, the matching prediction and their difference side by side.

    Only the first entry along the leading axis of each input is drawn
    (``target[0]`` and ``pred[0]``). A new 1x3 figure is created on every call.

    Args:
        target: Array-like ground truth. ``target[0]`` must be an image that
            ``imshow`` can render (assumed 2-D — TODO confirm against the
            pipeline output).
        pred: Array-like prediction, indexed the same way as ``target`` and
            compatible with it under subtraction.

    Returns:
        None. The figure is left to the active matplotlib backend to display.
    """
    target = target[0]
    pred = pred[0]
    _, axes = plt.subplots(1, 3, figsize=(12, 5))

    # Title every panel so the figure is readable without looking at the code.
    panels = (
        ("Target", target),
        ("Prediction", pred),
        ("Prediction - target", pred - target),
    )
    for ax, (title, image) in zip(axes, panels):
        ax.imshow(image)
        ax.set_title(title)

In [None]:
# Qualitative check: plot predictions for the first three samples served by
# the data module after `dm.train()` (presumably the training split — confirm).
dm.train()

# Move the model to the GPU once (loop-invariant) and switch it to evaluation
# mode so layers such as dropout / batch-norm behave deterministically and do
# not update their running statistics during inspection.
model.cuda()
model.eval()

for i in range(3):
    features, targets = dm[i]
    features = torch.from_numpy(features).cuda()
    # No autograd graph is needed for inference; calling the module (rather
    # than .forward) also runs any registered hooks.
    with torch.no_grad():
        preds = model(features)
    preds = preds.cpu().numpy()
    plot_preds(targets, preds)

In [None]:
# Qualitative check: plot predictions for the first three samples served by
# the data module after `dm.eval()` (presumably the validation split — confirm).
# NOTE(review): the validation split is a 20% hold-out (test_size=0.2) of a
# short date range, so it may hold fewer than 3 samples — confirm that
# dm[1] and dm[2] are valid here.
dm.eval()

# Move the model to the GPU once (loop-invariant) and switch it to evaluation
# mode so layers such as dropout / batch-norm behave deterministically and do
# not update their running statistics during inspection.
model.cuda()
model.eval()

for i in range(3):
    features, targets = dm[i]
    features = torch.from_numpy(features).cuda()
    # No autograd graph is needed for inference; calling the module (rather
    # than .forward) also runs any registered hooks.
    with torch.no_grad():
        preds = model(features)
    preds = preds.cpu().numpy()
    plot_preds(targets, preds)