In [None]:
# Change into the project directory. Presumably this is where the local
# `pipeline`, `datamodule` and `models` modules imported below live; the
# path is user/home-directory specific.
%cd ~/perpetual_day/

In [None]:
import os
from pathlib import Path
from shutil import rmtree

import lightning as L
import matplotlib.pyplot as plt
import pyearthtools.pipeline as petpipe
import torch

from pipeline import filter_day_time, features_pipeline, target_pipeline
from datamodule import PetDataModule
from models import CNN

In [None]:
# Timestamps from 2020-01-01 00:00 to 2020-01-03 00:00 every 30 minutes
# (whether the end point is included is up to DateRange).
date_range = petpipe.iterators.DateRange(
    '20200101T0000',
    '20200103T0000',
    interval='30 minutes',
)

# Study-area bounding box; presumably [lat_min, lat_max, lon_min, lon_max]
# (around Sydney) -- TODO confirm the order the pipelines expect.
bbox = [-35, -34, 150, 151]
# Root of the on-disk cache: the feature/target pipelines write into
# subfolders of it and it is also handed to PetDataModule. The default is a
# user-specific scratch path; set PET_CACHE_DIR to run anywhere else.
cache_dir = Path(os.environ.get("PET_CACHE_DIR", "/scratch/nf33/mr3857/cache"))

In [None]:
# Keep only the timestamps that filter_day_time accepts for this bounding box.
valid_range = filter_day_time(date_range, bbox)

# Feature and target pipelines, each caching under its own subfolder.
featpipe = features_pipeline(bbox, cache_dir / "features")
targetpipe = target_pipeline(bbox, cache_dir / "targets")

# Joint pipeline driven by the filtered timestamps; presumably each step
# yields a (features, targets) pair -- TODO confirm in pyearthtools docs.
fullpipe = petpipe.Pipeline(
    (featpipe, targetpipe),
    iterator=valid_range,
)

Uncomment the following cell to delete the cache folder, forcing the preprocessing step to run again.

In [None]:
# Deletes the whole cache tree: the features/ and targets/ pipeline caches and
# whatever PetDataModule stores under cache_dir.
# NOTE: rmtree raises FileNotFoundError if cache_dir does not exist yet; pass
# ignore_errors=True to make this a no-op in that case.
# rmtree(cache_dir)

In [None]:
dm = PetDataModule(
    fullpipe,
    cache_dir=cache_dir,
    # presumably the number of parallel preprocessing jobs -- TODO confirm
    n_jobs=6,
    # presumably fractions of the samples held out for validation / test
    val_split=0.1,
    test_split=0.2,
    # names match torch DataLoader options; presumably forwarded to it
    batch_size=16,
    num_workers=4,
    persistent_workers=True,
    multiprocessing_context="forkserver",
)

In [None]:
%%time
model = CNN(chan_in=9, chan_out=1)
trainer = L.Trainer(max_epochs=10, precision="16-mixed")
trainer.fit(model, dm)

In [None]:
def plot_preds(sample, pred):
    """Show target, prediction and prediction error side by side.

    Parameters
    ----------
    sample : torch.Tensor
        Ground truth for a single example (the ``targets`` half of a dataset
        item). Only ``sample[0]`` is drawn -- presumably the first channel of
        a (C, H, W) tensor; TODO confirm against the datasets.
    pred : torch.Tensor
        Model output for the same example; only ``pred[0]`` is drawn.

    Returns
    -------
    matplotlib.figure.Figure
        The created figure, so callers can save or close it.
    """
    import numpy as np

    # detach() first so the device-to-host copy carries no autograd graph.
    target_img = sample[0].detach().cpu().numpy()
    pred_img = pred[0].detach().cpu().numpy()
    error_img = pred_img - target_img

    # Shared colour range for target and prediction so the two panels are
    # directly comparable; ignore NaN/inf when computing it.
    both = np.concatenate([target_img.ravel(), pred_img.ravel()])
    finite = both[np.isfinite(both)]
    vmin, vmax = (finite.min(), finite.max()) if finite.size else (None, None)

    # Symmetric range centred on zero for the error map; fall back to 1 for a
    # perfect (all-zero) or all-NaN error so the colour scale is not degenerate.
    abs_err = np.abs(error_img[np.isfinite(error_img)])
    err_lim = float(abs_err.max()) if abs_err.size else 0.0
    err_lim = err_lim or 1.0

    panels = (
        (target_img, "Target", {"vmin": vmin, "vmax": vmax}),
        (pred_img, "Prediction", {"vmin": vmin, "vmax": vmax}),
        (error_img, "Prediction - Target",
         {"cmap": "RdBu_r", "vmin": -err_lim, "vmax": err_lim}),
    )

    fig, axes = plt.subplots(1, 3, figsize=(12, 5))
    for ax, (img, title, imshow_kwargs) in zip(axes, panels):
        im = ax.imshow(img, **imshow_kwargs)
        ax.set_title(title)
        fig.colorbar(im, ax=ax, shrink=0.7)
    fig.tight_layout()
    return fig

In [None]:
# Predictions on the first few training samples.
# eval() disables training-only layer behaviour (e.g. dropout, batch-norm
# statistic updates) and no_grad() stops autograd from recording the forward pass.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device).eval()
with torch.no_grad():
    for i in range(3):
        features, targets = dm.train_ds[i]
        preds = model(features.to(device))
        plot_preds(targets, preds)

In [None]:
# Predictions on the first few validation samples.
# eval() disables training-only layer behaviour (e.g. dropout, batch-norm
# statistic updates) and no_grad() stops autograd from recording the forward pass.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device).eval()
with torch.no_grad():
    for i in range(3):
        features, targets = dm.val_ds[i]
        preds = model(features.to(device))
        plot_preds(targets, preds)

In [None]:
# Predictions on the first few test samples.
# eval() disables training-only layer behaviour (e.g. dropout, batch-norm
# statistic updates) and no_grad() stops autograd from recording the forward pass.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device).eval()
with torch.no_grad():
    for i in range(3):
        features, targets = dm.test_ds[i]
        preds = model(features.to(device))
        plot_preds(targets, preds)