In [None]:
# Make ~/perpetual_day/ the kernel's working directory.
# NOTE(review): the local `pipeline` and `models` imports further down are
# presumably resolved from this directory — confirm before moving the notebook.
%cd ~/perpetual_day/

In [None]:
import os
from pathlib import Path
from shutil import rmtree

import torch
import lightning as L
import pyearthtools.pipeline as petpipe
import pyearthtools.training as pettrain
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

from pipeline import full_pipeline
from models import UNet

In [None]:
# Timestamps stepped every 30 minutes between the two bounds below.
date_range = petpipe.iterators.DateRange(
    '20200101T0000',
    '20200103T0000',
    interval='30 minutes',
)

# Spatial subset handed to full_pipeline.
# NOTE(review): presumably [lat_min, lat_max, lon_min, lon_max] — confirm
# against pipeline.full_pipeline.
bbox = [-35.27, -34, 150, 151.27]

# Cache location handed to full_pipeline.
cachedir = Path("/scratch/nf33/yl4436/cache")

In [None]:
# Build the data pipeline for the configured dates and bounding box.
# NOTE(review): clean_cache=True presumably discards whatever is already stored
# under cachedir — confirm in pipeline.full_pipeline before pointing it at a
# shared directory.
fullpipe = full_pipeline(date_range, bbox, cachedir, clean_cache=True)

In [None]:
# Hold out 20% of the samples exposed by the pipeline's iterator for
# validation. The fixed random_state keeps the shuffled split identical
# between runs.
train_dates, valid_dates = train_test_split(
    fullpipe.iterator.samples, test_size=0.2, shuffle=True, random_state=42
)
train_split, valid_split = (
    petpipe.iterators.Predefined(dates) for dates in (train_dates, valid_dates)
)

# Lightning data module built on the pipeline, with one predefined iterator
# per split.
dm = pettrain.data.lightning.PipelineLightningDataModule(
    fullpipe,
    train_split=train_split,
    valid_split=valid_split,
    batch_size=4,
    shuffle=True,
    num_workers=8,
    persistent_workers=True,
    multiprocessing_context="forkserver",
)

In [None]:
# Sanity check: pull the first sample and report the feature/target shapes.
# NOTE(review): dm.train() presumably points indexing at the training split —
# confirm in PipelineLightningDataModule.
dm.train()
features, targets = dm[0]
print(f"{features.shape} {targets.shape}")

In [None]:
%%time
features, targets = next(iter(fullpipe))
print(f"features shape and type: {features.shape}, {features.dtype}")
print(f"targets shape and type: {targets.shape}, {targets.dtype}")
model = UNet(chan_in=9, chan_out=1, sample_size=features.shape[1])
trainer = L.Trainer(max_epochs=10, precision="16-mixed")
trainer.fit(model, dm)

In [None]:
def plot_preds(sample, pred):
    """Show index 0 of a sample, of a prediction, and their difference.

    Args:
        sample: Reference field, indexed as ``sample[0]``. Either a
            ``torch.Tensor`` (on any device, with or without autograd
            history) or an array that already lives on the host.
        pred: Model output to compare against ``sample``; same accepted
            types. ``pred[0]`` and ``sample[0]`` must be subtractable.

    Returns:
        None. A new 1x3 matplotlib figure is created as a side effect.
    """

    def _first_slice(data):
        # Only tensors need the detach/cpu/numpy round-trip; host arrays
        # have no .cpu() and are indexed directly.
        if isinstance(data, torch.Tensor):
            return data[0].detach().cpu().numpy()
        return data[0]

    sample_img = _first_slice(sample)
    pred_img = _first_slice(pred)

    _, axes = plt.subplots(1, 3, figsize=(12, 5))
    axes[0].imshow(sample_img)
    axes[0].set_title("sample")
    axes[1].imshow(pred_img)
    axes[1].set_title("prediction")
    axes[2].imshow(pred_img - sample_img)
    axes[2].set_title("prediction - sample")

In [None]:
# Visual check of the model on the first three samples served after dm.train().
dm.train()
model.cuda()
model.eval()  # inference behaviour for layers such as dropout / batch-norm
for i in range(3):
    features, targets = dm[i]
    # as_tensor accepts both NumPy arrays and tensors, so this works whichever
    # type dm[i] yields; unsqueeze adds the batch axis the model is fed with.
    features = torch.as_tensor(features).unsqueeze(0).cuda()
    with torch.no_grad():  # plotting only, no autograd graph needed
        # Call the module itself rather than .forward() so hooks still run.
        preds = model(features).squeeze(0)
    # plot_preds does the .cpu()/.numpy() conversion itself, so hand it tensors.
    plot_preds(torch.as_tensor(targets), preds)


In [None]:
# Visual check of the model on the first three samples served after dm.eval().
# NOTE(review): dm.eval() presumably points indexing at the validation split —
# confirm in PipelineLightningDataModule.
dm.eval()
model.cuda()
model.eval()  # inference behaviour for layers such as dropout / batch-norm
for i in range(3):
    features, targets = dm[i]
    # as_tensor accepts both NumPy arrays and tensors, so this works whichever
    # type dm[i] yields; unsqueeze adds the batch axis the model is fed with.
    features = torch.as_tensor(features).unsqueeze(0).cuda()
    with torch.no_grad():  # plotting only, no autograd graph needed
        # Call the module itself rather than .forward() so hooks still run.
        preds = model(features).squeeze(0)
    # plot_preds does the .cpu()/.numpy() conversion itself, so hand it tensors.
    plot_preds(torch.as_tensor(targets), preds)
