In [None]:
# Switch the kernel's working directory to the project folder (presumably so the
# local `pipeline` and `models` modules imported below can be found — confirm).
%cd ~/perpetual_day/

In [None]:
import os
from pathlib import Path
from shutil import rmtree

import torch
import lightning as L
import pyearthtools.data as petdata
import pyearthtools.pipeline as petpipe
import pyearthtools.training as pettrain
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

from pipeline import full_pipeline, filter_dates
from models import UNet

In [None]:
# Candidate timestamps: 30-minute steps between 2020-01-01 00:00 and 2022-01-01 00:00.
date_range = petpipe.iterators.DateRange('20200101T0000', '20220101T0000', interval='30 minutes')

# Spatial bounding box handed to full_pipeline.
# NOTE(review): the ordering looks like [lat_min, lat_max, lon_min, lon_max] — confirm
# against pipeline.full_pipeline.
bbox = [-35.27, -34, 150, 151.27]

# Cache directory handed to full_pipeline. Set PERPETUAL_DAY_CACHE to run the
# notebook on another machine or account; the default is the original location.
cachedir = os.environ.get("PERPETUAL_DAY_CACHE", "/scratch/nf33/mr3857/cache2")

In [None]:
%%time
# Build the data pipeline from the configured date range, bounding box and cache
# directory. The result is iterated later to yield (features, targets) samples.
fullpipe = full_pipeline(date_range, bbox, cachedir)

In [None]:
%%time
# Collect the dates that filter_dates keeps for this pipeline.
# n_jobs=12 presumably sets the number of parallel workers — confirm in pipeline.filter_dates.
good_dates = filter_dates(fullpipe, n_jobs=12)

In [None]:
# Show how many dates filter_dates returned.
len(good_dates)

In [None]:
# Hold out 10% of the usable dates for validation; shuffled with a fixed seed
# so the split is the same on every run.
train_dates, valid_dates = train_test_split(
    good_dates, test_size=0.1, shuffle=True, random_state=42
)

# Wrap each list of dates in an iterator the data module can consume.
train_split = petpipe.iterators.Predefined(train_dates)
valid_split = petpipe.iterators.Predefined(valid_dates)

# Batching / worker settings passed through to the data module.
loader_options = dict(
    batch_size=32,
    num_workers=6,
    shuffle=True,
    multiprocessing_context="forkserver",
    persistent_workers=True,
)

# Lightning data module serving samples from the pipeline for both splits.
dm = pettrain.data.lightning.PipelineLightningDataModule(
    fullpipe,
    train_split=train_split,
    valid_split=valid_split,
    **loader_options,
)

In [None]:
# Draw a single (features, targets) sample from the pipeline; its shapes are
# used further down to size the network.
sample_iter = iter(fullpipe)
features, targets = next(sample_iter)

# Report the array shapes and dtypes of that sample.
print(f"features shape and type: {features.shape}, {features.dtype}")
print(f"targets shape and type: {targets.shape}, {targets.dtype}")

In [None]:
%%time
model = UNet(
    chan_in=features.shape[0],
    chan_out=targets.shape[0],
    sample_size=features.shape[1],
    learning_rate=1e-4,
)
trainer = L.Trainer(max_epochs=1, precision="16-mixed")
trainer.fit(model, dm)

In [None]:
# Rebuild the target branch of the full pipeline without its cache step, so it
# can be used to undo the transforms on predictions.
cached_targetpipe = fullpipe.steps[0].sub_pipelines[1]
target_steps = cached_targetpipe.steps
# Keep every step except the second-to-last one.
kept_steps = [*target_steps[:-2], target_steps[-1]]
targetpipe = petpipe.Pipeline(*kept_steps)

# Fetch one sample through the rebuilt pipeline (result discarded);
# presumably this initialises the state that `undo` relies on — confirm.
_ = targetpipe["20200301T0000"]

In [None]:
def plot_preds(target, pred, field="channel_0003_scaled_radiance"):
    """Plot the target, the prediction and their difference side by side.

    Parameters
    ----------
    target, pred
        Containers indexable by ``field`` whose items support subtraction and
        a ``.plot(ax=...)`` method (presumably xarray Datasets — confirm).
    field : str, optional
        Name of the variable to display. Defaults to the channel that was
        previously hard-coded, so existing two-argument calls are unchanged.
    """
    _, axes = plt.subplots(1, 3, figsize=(12, 3), layout="constrained")

    # Only the displayed field is differenced, not every variable.
    panels = (
        ("target", target[field]),
        ("prediction", pred[field]),
        ("prediction - target", pred[field] - target[field]),
    )
    for ax, (label, data) in zip(axes, panels):
        data.plot(ax=ax)
        # Keep any title the plotting call generated and put the panel label
        # above it, so each panel is identifiable on its own.
        auto_title = ax.get_title()
        ax.set_title(f"{label}\n{auto_title}" if auto_title else label)

In [None]:
# Compare predictions with targets for the first three samples served by the
# data module after dm.train() (presumably the training split — confirm).
dm.train()

# Move the model to the GPU once and switch it to inference mode so layers such
# as dropout / batch-norm (if present) behave deterministically.
model.cuda().eval()

for i in range(3):
    # Loop-local names: the notebook-level `features` / `targets` sample that
    # sized the network must not be overwritten here.
    sample_features, sample_targets = dm[i]
    batch = torch.from_numpy(sample_features).unsqueeze(0).cuda()
    target_ds = targetpipe.undo(sample_targets).isel(time=0)

    # No autograd graph is needed for plotting predictions.
    with torch.no_grad():
        pred = model(batch)
    pred_ds = targetpipe.undo(pred.squeeze(0).cpu().numpy()).isel(time=0)

    plot_preds(target_ds, pred_ds)

In [None]:
# Compare predictions with targets for the first three samples served by the
# data module after dm.eval() (presumably the validation split — confirm).
dm.eval()

# Move the model to the GPU once and switch it to inference mode so layers such
# as dropout / batch-norm (if present) behave deterministically.
model.cuda().eval()

for i in range(3):
    # Loop-local names: the notebook-level `features` / `targets` sample that
    # sized the network must not be overwritten here.
    sample_features, sample_targets = dm[i]
    batch = torch.from_numpy(sample_features).unsqueeze(0).cuda()
    target_ds = targetpipe.undo(sample_targets).isel(time=0)

    # No autograd graph is needed for plotting predictions.
    with torch.no_grad():
        pred = model(batch)
    pred_ds = targetpipe.undo(pred.squeeze(0).cpu().numpy()).isel(time=0)

    plot_preds(target_ds, pred_ds)