In [1]:
import sys
sys.path.insert(0, "../src")

In [2]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline
%reload_ext nb_black

<IPython.core.display.Javascript object>

In [3]:
import gc
from pathlib import Path
from tqdm.notebook import tqdm

import cv2
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn import metrics

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
from torch.utils.data import RandomSampler, SequentialSampler
import pytorch_lightning as pl

import transformers

from utils import visualize, radar2precipitation, seed_everything

<IPython.core.display.Javascript object>

# U-Net

## Config

In [4]:
args = dict(
    seed=42,
    dams=(6071, 6304, 7026, 7629, 7767, 8944, 11107),
    train_folds_csv=Path("../input/train_folds.csv"),
    train_data_path=Path("../input/train-128"),
    test_data_path=Path("../input/test-128"),
    rng=255.0,
    num_workers=4,
    gpus=1,
    lr=1e-4,
    max_epochs=50,
    batch_size=256,
    precision=16,
    optimizer="adamw",
    scheduler="cosine",
    accumulate_grad_batches=1,
    gradient_clip_val=5.0,
)

<IPython.core.display.Javascript object>

## Model

### Layers

#### Basic

In [5]:
class BasicBlock(nn.Module):
    def __init__(self, in_ch, out_ch):
        assert in_ch == out_ch
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.LeakyReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
        )

    def forward(self, x):
        return x + self.net(x)

<IPython.core.display.Javascript object>

In [6]:
# x = torch.randn(3, 4, 128, 128)
# block = BasicBlock(4, 4)
# block(x).shape

<IPython.core.display.Javascript object>

#### Encoder

In [6]:
class DownBlock(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.id_conv = nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=2)
        self.net = nn.Sequential(
            nn.BatchNorm2d(in_ch),
            nn.LeakyReLU(inplace=True),
            nn.MaxPool2d(2),
            nn.BatchNorm2d(in_ch),
            nn.LeakyReLU(inplace=True),
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
        )

    def forward(self, x):
        residual = x
        residual = self.id_conv(residual)
        x = self.net(x)
        return residual + x, x

<IPython.core.display.Javascript object>

In [7]:
# block = DownBlock(4, 64)
# down, across = block(x)
# down.shape, across.shape

<IPython.core.display.Javascript object>

In [8]:
class Encoder(nn.Module):
    def __init__(self, chs=[4, 64, 128, 256, 512, 1024]):
        super().__init__()
        self.blocks = nn.ModuleList(
            [DownBlock(chs[i], chs[i + 1]) for i in range(len(chs) - 1)]
        )
        self.basic = BasicBlock(chs[-1], chs[-1])

    def forward(self, x):
        feats = []
        for block in self.blocks:
            x, feat = block(x)
            feats.append(feat)
        x = self.basic(x)
        feats.append(x)
        return feats

<IPython.core.display.Javascript object>

In [9]:
# x = torch.randn(3, 4, 128, 128)
# encoder = Encoder()
# feats = encoder(x)
# for feat in feats:
#     print(feat.shape)

<IPython.core.display.Javascript object>

#### Decoder

In [10]:
class UpBlock(nn.Module):
    def __init__(self, in_ch, out_ch, bilinear=False):
        super().__init__()
        self.id_conv = nn.ConvTranspose2d(
            in_ch + in_ch, out_ch, kernel_size=2, stride=2
        )
        layers = []
        if bilinear:
            layers.append(nn.Upsample(scale_factor=2, mode="nearest"))
        else:
            layers.append(
                nn.ConvTranspose2d(in_ch + in_ch, out_ch, kernel_size=2, stride=2)
            )
        layers.extend(
            [
                nn.BatchNorm2d(out_ch),
                nn.LeakyReLU(inplace=True),
                nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.LeakyReLU(inplace=True),
                nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1, bias=False),
            ]
        )
        self.block = nn.Sequential(*layers)

    def forward(self, x, feat):
        x = torch.cat([x, feat], dim=1)
        residual = x
        residual = self.id_conv(residual)
        x = self.block(x)
        return x + residual

<IPython.core.display.Javascript object>

In [11]:
# x = torch.randn(3, 1024, 4, 4)
# feat = torch.randn(3, 1024, 4, 4)
# block = UpBlock(1024, 512)
# block(x, feat).shape

<IPython.core.display.Javascript object>

In [12]:
class Decoder(nn.Module):
    def __init__(self, chs=[1024, 512, 256, 128, 64]):
        super().__init__()
        self.blocks = nn.ModuleList(
            [UpBlock(chs[i], chs[i + 1]) for i in range(len(chs) - 1)]
        )

    def forward(self, x, feats):
        for block, feat in zip(self.blocks, feats):
            x = block(x, feat)
        return x

<IPython.core.display.Javascript object>

In [13]:
# x = torch.randn(3, 4, 128, 128)
# encoder = Encoder()
# feats = encoder(x)
# for feat in feats:
#     print(feat.shape)

<IPython.core.display.Javascript object>

In [14]:
# decoder = Decoder()
# x = torch.randn(3, 1024, 4, 4)
# feats = list(reversed(feats))[1:]
# decoder(x, feats).shape

<IPython.core.display.Javascript object>

In [15]:
class UNet(pl.LightningModule):
    def __init__(
        self,
        lr=args["lr"],
        enc_chs=[4, 64, 128, 256, 512, 1024],
        dec_chs=[1024, 512, 256, 128, 64],
        num_train_steps=None,
    ):
        super().__init__()
        self.lr = lr
        self.num_train_steps = num_train_steps
        self.criterion = nn.SmoothL1Loss()

        self.tail = BasicBlock(4, enc_chs[0])
        self.encoder = Encoder(enc_chs)
        self.decoder = Decoder(dec_chs)
        self.head = nn.Sequential(
            nn.ConvTranspose2d(dec_chs[-1], 32, kernel_size=2, stride=2, bias=False),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(inplace=True),
            nn.Conv2d(32, 1, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        x = self.tail(x)
        #         print("after tail:", x.shape)
        feats = self.encoder(x)
        feats = feats[::-1]
        x = self.decoder(feats[0], feats[1:])
        #         print("after decoder:", x.shape)
        x = self.head(x)

        return x

    def shared_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = self.criterion(y_hat, y)

        return loss, y, y_hat

    def training_step(self, batch, batch_idx):
        loss, y, y_hat = self.shared_step(batch, batch_idx)
        self.log("train_loss", loss)
        for i, param_group in enumerate(self.optimizer.param_groups):
            self.log(f"lr/lr{i}", param_group["lr"])

        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss, y, y_hat = self.shared_step(batch, batch_idx)

        return {"loss": loss, "y": y.detach(), "y_hat": y_hat.detach()}

    def validation_epoch_end(self, outputs):
        avg_loss = torch.stack([x["loss"] for x in outputs]).mean()
        self.log("val_loss", avg_loss)

        crop = T.CenterCrop(120)

        y = torch.cat([x["y"] for x in outputs])
        y = crop(y)
        y = y.detach().cpu().numpy()
        y = y.reshape(-1, 120 * 120)

        y_hat = torch.cat([x["y_hat"] for x in outputs])
        y_hat = crop(y_hat)
        y_hat = y_hat.detach().cpu().numpy()
        y_hat = y_hat.reshape(-1, 120 * 120)

        y = args["rng"] * y[:, args["dams"]]
        y = y.clip(0, 255)
        y_hat = args["rng"] * y_hat[:, args["dams"]]
        y_hat = y_hat.clip(0, 255)

        y_true = radar2precipitation(y)
        y_true = np.where(y_true >= 0.1, 1, 0)
        y_pred = radar2precipitation(y_hat)
        y_pred = np.where(y_pred >= 0.1, 1, 0)

        y *= y_true
        y_hat *= y_true
        mae = metrics.mean_absolute_error(y, y_hat)
        self.log("mae", mae)

        tn, fp, fn, tp = metrics.confusion_matrix(
            y_true.ravel(), y_pred.ravel()
        ).ravel()
        csi = tp / (tp + fn + fp)
        self.log("csi", csi)

        comp_metric = mae / (csi + 1e-12)
        self.log("comp_metric", comp_metric)

        print(
            f"Epoch {self.current_epoch} | MAE/CSI: {comp_metric} | MAE: {mae} | CSI: {csi} | Loss: {avg_loss}"
        )

    def configure_optimizers(self):
        #         self.optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        self.optimizer = transformers.AdamW(self.parameters(), lr=self.lr)
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer, T_max=self.num_train_steps
        )

        return [self.optimizer], [{"scheduler": self.scheduler, "interval": "step"}]

<IPython.core.display.Javascript object>

In [16]:
# m = UNet()
# x = torch.randn(3, 4, 128, 128)
# m(x).shape

<IPython.core.display.Javascript object>

## Dataset

In [17]:
class NowcastingDataset(torch.utils.data.Dataset):
    def __init__(self, paths, test=False):
        self.paths = paths
        self.test = test

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        path = self.paths[idx]
        data = np.load(path)
        x = data[:, :, :4]
        x = x / args["rng"]
        x = x.astype(np.float32)
        x = torch.tensor(x, dtype=torch.float)
        x = x.permute(2, 0, 1)
        if self.test:
            return x
        else:
            y = data[:, :, 4]
            y = y / args["rng"]
            y = y.astype(np.float32)
            y = torch.tensor(y, dtype=torch.float)
            y = y.unsqueeze(-1)
            y = y.permute(2, 0, 1)

            return x, y

<IPython.core.display.Javascript object>

In [18]:
class NowcastingDataModule(pl.LightningDataModule):
    def __init__(
        self,
        train_df=None,
        val_df=None,
        batch_size=args["batch_size"],
        num_workers=args["num_workers"],
    ):
        super().__init__()
        self.train_df = train_df
        self.val_df = val_df
        self.batch_size = batch_size
        self.num_workers = num_workers

    def setup(self, stage="train"):
        if stage == "train":
            train_paths = [
                args["train_data_path"] / fn for fn in self.train_df.filename.values
            ]
            val_paths = [
                args["train_data_path"] / fn for fn in self.val_df.filename.values
            ]
            self.train_dataset = NowcastingDataset(train_paths)
            self.val_dataset = NowcastingDataset(val_paths)
        else:
            test_paths = list(args["test_data_path"].glob("*.npy"))
            self.test_dataset = NowcastingDataset(test_paths, test=True)

    def train_dataloader(self):
        return torch.utils.data.DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            sampler=RandomSampler(self.train_dataset),
            pin_memory=True,
            num_workers=self.num_workers,
            drop_last=True,
        )

    def val_dataloader(self):
        return torch.utils.data.DataLoader(
            self.val_dataset,
            batch_size=2 * self.batch_size,
            sampler=SequentialSampler(self.val_dataset),
            pin_memory=True,
            num_workers=self.num_workers,
        )

    def test_dataloader(self):
        return torch.utils.data.DataLoader(
            self.test_dataset,
            batch_size=2 * self.batch_size,
            sampler=SequentialSampler(self.test_dataset),
            pin_memory=True,
            num_workers=self.num_workers,
        )

<IPython.core.display.Javascript object>

## Train

In [None]:
seed_everything(args["seed"])
pl.seed_everything(args["seed"])

df = pd.read_csv(args["train_folds_csv"])

for fold in range(5):
    train_df = df[df.fold != fold]
    val_df = df[df.fold == fold]

    datamodule = NowcastingDataModule(
        train_df, val_df, batch_size=args["batch_size"], num_workers=args["num_workers"]
    )
    datamodule.setup()

    num_train_steps = (
        int(
            np.ceil(
                len(train_df) // args["batch_size"] / args["accumulate_grad_batches"]
            )
        )
        * args["max_epochs"]
    )

    model = UNet(num_train_steps=num_train_steps)

    trainer = pl.Trainer(
        gpus=args["gpus"],
        max_epochs=args["max_epochs"],
        precision=args["precision"],
        progress_bar_refresh_rate=50,
        #         accumulate_grad_batches=args["accumulate_grad_batches"],
        #         gradient_clip_val=args["gradient_clip_val"],
        auto_lr_find=True,
    )

    # learning rate finder
    #     lr_finder = trainer.tuner.lr_find(model, datamodule=datamodule)
    #     fig = lr_finder.plot(suggest=True)
    #     fig.show()

    trainer.fit(model, datamodule)
    trainer.save_checkpoint(f"unet_fold{fold}_bs{args['batch_size']}_epoch{args['max_epochs']}.ckpt")

    del datamodule, model, trainer
    gc.collect()
    torch.cuda.empty_cache()

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Using native 16bit precision.

  | Name      | Type         | Params
-------------------------------------------
0 | criterion | SmoothL1Loss | 0     
1 | tail      | BasicBlock   | 300   
2 | encoder   | Encoder      | 25 M  
3 | decoder   | Decoder      | 17 M  
4 | head      | Sequential   | 8 K   


HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

Epoch 0 | MAE/CSI: 18053432464599.61 | MAE: 18.05343246459961 | CSI: 0.0 | Loss: 0.011555060744285583


HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

  z = np.power(10.0, dbz / 10.0)


Epoch 0 | MAE/CSI: 4.579871016522517 | MAE: 3.356309652328491 | CSI: 0.7328393398450657 | Loss: 0.0015933009563013911


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

  z = np.power(10.0, dbz / 10.0)


Epoch 1 | MAE/CSI: 4.3457863606748175 | MAE: 3.272869348526001 | CSI: 0.7531132634908084 | Loss: 0.0015471124788746238


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

  z = np.power(10.0, dbz / 10.0)


Epoch 2 | MAE/CSI: 4.294591652152133 | MAE: 3.2539639472961426 | CSI: 0.7576887887958337 | Loss: 0.0012963797198608518


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

  z = np.power(10.0, dbz / 10.0)


Epoch 3 | MAE/CSI: 4.154932101723256 | MAE: 3.154283046722412 | CSI: 0.7591659669745312 | Loss: 0.0012544452911242843


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

  z = np.power(10.0, dbz / 10.0)


Epoch 4 | MAE/CSI: 3.7860146000298136 | MAE: 2.913952112197876 | CSI: 0.7696621434505677 | Loss: 0.00123176712077111


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

  z = np.power(10.0, dbz / 10.0)


Epoch 5 | MAE/CSI: 4.065141708011735 | MAE: 3.1184253692626953 | CSI: 0.7671135702631766 | Loss: 0.0012020657304674387


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

  z = np.power(10.0, dbz / 10.0)


Epoch 6 | MAE/CSI: 3.61714756904823 | MAE: 2.808803081512451 | CSI: 0.7765243269430411 | Loss: 0.0011883970582857728


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

  z = np.power(10.0, dbz / 10.0)


Epoch 7 | MAE/CSI: 3.7737966527452103 | MAE: 2.9232699871063232 | CSI: 0.7746230801747217 | Loss: 0.0011574724921956658


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 8 | MAE/CSI: 3.5785213933996576 | MAE: 2.792635440826416 | CSI: 0.7803880803880804 | Loss: 0.0011816049227491021


  z = np.power(10.0, dbz / 10.0)


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

  z = np.power(10.0, dbz / 10.0)


Epoch 9 | MAE/CSI: 3.7593860248772217 | MAE: 2.9188804626464844 | CSI: 0.776424778761062 | Loss: 0.0011313384165987372


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

  z = np.power(10.0, dbz / 10.0)


Epoch 10 | MAE/CSI: 3.537799221788761 | MAE: 2.764934539794922 | CSI: 0.7815408298929396 | Loss: 0.0011271697003394365


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 11 | MAE/CSI: 4.176114563603725 | MAE: 3.2004687786102295 | CSI: 0.7663747557356879 | Loss: 0.0011378307826817036


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 12 | MAE/CSI: 4.033352344612596 | MAE: 3.1121833324432373 | CSI: 0.7716120652330783 | Loss: 0.001107610878534615


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 13 | MAE/CSI: 3.6702073925815 | MAE: 2.86871600151062 | CSI: 0.781622315759435 | Loss: 0.0010806667851284146


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 14 | MAE/CSI: 3.359812260798782 | MAE: 2.65893816947937 | CSI: 0.7913948646773075 | Loss: 0.001116615254431963


  z = np.power(10.0, dbz / 10.0)


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 15 | MAE/CSI: 3.777346978087415 | MAE: 2.942171096801758 | CSI: 0.7788988181031997 | Loss: 0.0010497045004740357


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 16 | MAE/CSI: 3.730972666019791 | MAE: 2.9134128093719482 | CSI: 0.7808721934369602 | Loss: 0.0010338622378185391


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 17 | MAE/CSI: 3.216006483334444 | MAE: 2.5546412467956543 | CSI: 0.7943520201314134 | Loss: 0.0010403376072645187


  z = np.power(10.0, dbz / 10.0)


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 18 | MAE/CSI: 3.6045317048942183 | MAE: 2.8333561420440674 | CSI: 0.786053882725832 | Loss: 0.0010198085801675916


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 19 | MAE/CSI: 3.2532478057212075 | MAE: 2.5799641609191895 | CSI: 0.7930426192492238 | Loss: 0.0010090331779792905


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 20 | MAE/CSI: 3.2199748341813246 | MAE: 2.5563158988952637 | CSI: 0.7938931297709924 | Loss: 0.0010016037849709392


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 21 | MAE/CSI: 3.856811757635469 | MAE: 3.0003974437713623 | CSI: 0.7779475982532751 | Loss: 0.0010271386709064245


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 22 | MAE/CSI: 3.1347424155752437 | MAE: 2.4959778785705566 | CSI: 0.7962306140899923 | Loss: 0.0010296452092006803


  z = np.power(10.0, dbz / 10.0)


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 23 | MAE/CSI: 3.1336774350690413 | MAE: 2.504286289215088 | CSI: 0.7991525423728814 | Loss: 0.000995357520878315


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 24 | MAE/CSI: 3.155359118626466 | MAE: 2.518603563308716 | CSI: 0.7981987053194484 | Loss: 0.0009966425132006407


  z = np.power(10.0, dbz / 10.0)


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 25 | MAE/CSI: 3.2987531262931515 | MAE: 2.6152637004852295 | CSI: 0.7928037050231564 | Loss: 0.0009848386980593204


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 26 | MAE/CSI: 3.3408334380625853 | MAE: 2.6483819484710693 | CSI: 0.7927309150747657 | Loss: 0.0009853563969954848


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 27 | MAE/CSI: 3.151360169059713 | MAE: 2.5201938152313232 | CSI: 0.7997162114224903 | Loss: 0.0009801144478842616


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 28 | MAE/CSI: 3.383971888869794 | MAE: 2.6815693378448486 | CSI: 0.7924325100516945 | Loss: 0.0009856465039774776


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 29 | MAE/CSI: 3.3458963516517897 | MAE: 2.650679349899292 | CSI: 0.7922180101566412 | Loss: 0.0009886184707283974


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 30 | MAE/CSI: 3.540172589897291 | MAE: 2.7885289192199707 | CSI: 0.7876816308826718 | Loss: 0.00101062364410609


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 31 | MAE/CSI: 3.1403178747400826 | MAE: 2.5151991844177246 | CSI: 0.8009377664109122 | Loss: 0.0009796030353754759


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 32 | MAE/CSI: 3.137336144450593 | MAE: 2.5105366706848145 | CSI: 0.8002128414331323 | Loss: 0.0009821136482059956


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 33 | MAE/CSI: 3.22167231289145 | MAE: 2.5630111694335938 | CSI: 0.7955530297648646 | Loss: 0.0009937712457031012


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 34 | MAE/CSI: 3.107360140259587 | MAE: 2.48958420753479 | CSI: 0.8011894647408666 | Loss: 0.000986725091934204


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 35 | MAE/CSI: 3.3972837724475897 | MAE: 2.691316843032837 | CSI: 0.7921966557095899 | Loss: 0.0010110668372362852


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 36 | MAE/CSI: 3.2956298785612326 | MAE: 2.6264376640319824 | CSI: 0.7969455796945579 | Loss: 0.000985836493782699


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 37 | MAE/CSI: 3.276042547174778 | MAE: 2.6103878021240234 | CSI: 0.7968113248016014 | Loss: 0.0009991289116442204


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 38 | MAE/CSI: 3.282018963884671 | MAE: 2.615131139755249 | CSI: 0.7968056152413694 | Loss: 0.0009931615786626935


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 39 | MAE/CSI: 3.1812205203792945 | MAE: 2.543889284133911 | CSI: 0.7996582656984195 | Loss: 0.000992935849353671


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 40 | MAE/CSI: 3.307319243115345 | MAE: 2.6357016563415527 | CSI: 0.7969299189441217 | Loss: 0.0009866819018498063


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 41 | MAE/CSI: 3.143233505972508 | MAE: 2.513828754425049 | CSI: 0.7997588310398638 | Loss: 0.0009999492904171348


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 42 | MAE/CSI: 3.215217112553536 | MAE: 2.5696022510528564 | CSI: 0.7992002285061411 | Loss: 0.0009870363865047693


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 43 | MAE/CSI: 3.21669262308394 | MAE: 2.570596933364868 | CSI: 0.7991428571428572 | Loss: 0.000990581582300365


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 44 | MAE/CSI: 3.2109438338309406 | MAE: 2.5652711391448975 | CSI: 0.7989149832250696 | Loss: 0.0009919735603034496


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 45 | MAE/CSI: 3.2196984066052785 | MAE: 2.5702412128448486 | CSI: 0.7982863263120314 | Loss: 0.000992590212263167


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 46 | MAE/CSI: 3.206927726907789 | MAE: 2.562476634979248 | CSI: 0.799044086174918 | Loss: 0.000992824207060039


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 47 | MAE/CSI: 3.2243089758598296 | MAE: 2.5741519927978516 | CSI: 0.7983577293823635 | Loss: 0.0009943352779373527


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 48 | MAE/CSI: 3.167185464203879 | MAE: 2.532801866531372 | CSI: 0.7997011526967411 | Loss: 0.0009948504157364368


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 49 | MAE/CSI: 3.2167388425628727 | MAE: 2.569028615951538 | CSI: 0.7986438258386866 | Loss: 0.0009935392299667



GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Using native 16bit precision.

  | Name      | Type         | Params
-------------------------------------------
0 | criterion | SmoothL1Loss | 0     
1 | tail      | BasicBlock   | 300   
2 | encoder   | Encoder      | 25 M  
3 | decoder   | Decoder      | 17 M  
4 | head      | Sequential   | 8 K   


HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

Epoch 0 | MAE/CSI: 14487452507019.043 | MAE: 14.487452507019043 | CSI: 0.0 | Loss: 0.009541328065097332


HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 0 | MAE/CSI: 5.392028101629733 | MAE: 3.834322214126587 | CSI: 0.7111094641666049 | Loss: 0.0021661436185240746


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 1 | MAE/CSI: 4.158585763428692 | MAE: 3.119844675064087 | CSI: 0.7502177068214804 | Loss: 0.0013823203044012189


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 2 | MAE/CSI: 4.3142486810508425 | MAE: 3.232961654663086 | CSI: 0.749368405409422 | Loss: 0.0013024784857407212


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 3 | MAE/CSI: 3.7116308049028257 | MAE: 2.8598763942718506 | CSI: 0.770517474553349 | Loss: 0.0013511937577277422


  z = np.power(10.0, dbz / 10.0)


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

  z = np.power(10.0, dbz / 10.0)


Epoch 4 | MAE/CSI: 3.539873240049237 | MAE: 2.7341103553771973 | CSI: 0.7723752151462995 | Loss: 0.001252486719749868


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 5 | MAE/CSI: 4.01579733783149 | MAE: 3.0554652214050293 | CSI: 0.7608614091693554 | Loss: 0.0011789751006290317


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 6 | MAE/CSI: 3.7320073356218004 | MAE: 2.864356517791748 | CSI: 0.7675109559533536 | Loss: 0.0011457751970738173


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

  z = np.power(10.0, dbz / 10.0)


Epoch 7 | MAE/CSI: 3.401637135420141 | MAE: 2.644503355026245 | CSI: 0.7774207682196582 | Loss: 0.0011405773693695664


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 8 | MAE/CSI: 4.595668000327307 | MAE: 3.438323736190796 | CSI: 0.7481662591687042 | Loss: 0.0012028244091197848


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

  z = np.power(10.0, dbz / 10.0)


Epoch 9 | MAE/CSI: 3.4759753314680486 | MAE: 2.6975338459014893 | CSI: 0.7760509177027827 | Loss: 0.0010866763768717647


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 10 | MAE/CSI: 3.2317734726219225 | MAE: 2.534553289413452 | CSI: 0.7842608124863248 | Loss: 0.0011001526145264506


  z = np.power(10.0, dbz / 10.0)


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

  z = np.power(10.0, dbz / 10.0)


Epoch 11 | MAE/CSI: 3.2738988791916976 | MAE: 2.5705955028533936 | CSI: 0.7851786501985002 | Loss: 0.0010762200690805912


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 12 | MAE/CSI: 3.2065835705106345 | MAE: 2.5200467109680176 | CSI: 0.785897718101108 | Loss: 0.001055945991538465


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 13 | MAE/CSI: 3.6612566362170322 | MAE: 2.834601640701294 | CSI: 0.7742155009451795 | Loss: 0.00104250549338758


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 14 | MAE/CSI: 3.0992773782456213 | MAE: 2.4434335231781006 | CSI: 0.7883881385789783 | Loss: 0.0010305154137313366


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 15 | MAE/CSI: 3.639291737971471 | MAE: 2.8226284980773926 | CSI: 0.7755983035443805 | Loss: 0.0010273642838001251


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 16 | MAE/CSI: 3.2597624066956175 | MAE: 2.5623531341552734 | CSI: 0.786055182699478 | Loss: 0.0010038955369964242


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 17 | MAE/CSI: 3.2570115885248803 | MAE: 2.5611789226531982 | CSI: 0.7863585538576221 | Loss: 0.0010010426631197333


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 18 | MAE/CSI: 3.406129191555971 | MAE: 2.6659343242645264 | CSI: 0.7826873774694616 | Loss: 0.000992716639302671


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 19 | MAE/CSI: 3.144901228592783 | MAE: 2.488124370574951 | CSI: 0.791161371921732 | Loss: 0.0009813575306907296


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 20 | MAE/CSI: 3.0428135418560394 | MAE: 2.419461727142334 | CSI: 0.795139660665333 | Loss: 0.0009844209998846054


  z = np.power(10.0, dbz / 10.0)


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 21 | MAE/CSI: 3.4147496972918714 | MAE: 2.6745433807373047 | CSI: 0.7832326283987915 | Loss: 0.0009928239742293954


  z = np.power(10.0, dbz / 10.0)


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 22 | MAE/CSI: 3.465787560107332 | MAE: 2.704352617263794 | CSI: 0.7802995914661824 | Loss: 0.0009944615885615349


  z = np.power(10.0, dbz / 10.0)


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 23 | MAE/CSI: 3.4667651828205974 | MAE: 2.7036044597625732 | CSI: 0.7798637395912188 | Loss: 0.0010142156388610601


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

  z = np.power(10.0, dbz / 10.0)


Epoch 24 | MAE/CSI: 3.3635030555243413 | MAE: 2.640925884246826 | CSI: 0.7851712457659014 | Loss: 0.0009805120062083006


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 25 | MAE/CSI: 3.0347875782549485 | MAE: 2.413892984390259 | CSI: 0.7954075605434141 | Loss: 0.0009923680918291211


  z = np.power(10.0, dbz / 10.0)


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

  z = np.power(10.0, dbz / 10.0)


Epoch 26 | MAE/CSI: 3.1692318757620845 | MAE: 2.50390887260437 | CSI: 0.7900680577368933 | Loss: 0.0009772846242412925


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 27 | MAE/CSI: 3.2020519076581904 | MAE: 2.5177001953125 | CSI: 0.7862771335117454 | Loss: 0.000999476877041161


  z = np.power(10.0, dbz / 10.0)


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

  z = np.power(10.0, dbz / 10.0)


Epoch 28 | MAE/CSI: 3.0848241435457164 | MAE: 2.443657398223877 | CSI: 0.7921545230815125 | Loss: 0.0009700111113488674


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 29 | MAE/CSI: 3.1461110961318504 | MAE: 2.488356351852417 | CSI: 0.7909308590242442 | Loss: 0.0009753701160661876


  z = np.power(10.0, dbz / 10.0)


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 30 | MAE/CSI: 3.1462543540290406 | MAE: 2.4851224422454834 | CSI: 0.7898669855029143 | Loss: 0.0009774373611435294


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 31 | MAE/CSI: 3.0840243109285304 | MAE: 2.4406659603118896 | CSI: 0.7913899873162725 | Loss: 0.0009708466241136193


  z = np.power(10.0, dbz / 10.0)


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 32 | MAE/CSI: 3.122035965640138 | MAE: 2.4696598052978516 | CSI: 0.7910414333706607 | Loss: 0.0009853055234998465


  z = np.power(10.0, dbz / 10.0)


## Inference

In [None]:
datamodule = NowcastingDataModule()
datamodule.setup("test")

final_preds = np.zeros((len(datamodule.test_dataset), 120, 120))

for fold in range(5):
    model = UNet.load_from_checkpoint(f"unet_fold{fold}_bs{args['batch_size']}_epoch{args['max_epochs']}")
    model.cuda()
    model.eval()
    preds = []
    with torch.no_grad():
        for batch in tqdm(datamodule.test_dataloader()):
            batch = batch.cuda()
            imgs = model(batch)
            imgs = imgs.detach().cpu().numpy()
            imgs = imgs[:, 0, 4:124, 4:124]
            imgs = args["rng"] * imgs
            imgs = imgs.clip(0, 255)
            imgs = imgs.round()
            preds.append(imgs)

    preds = np.concatenate(preds)
    preds = preds.astype(np.uint8)
    final_preds += preds
    
    del model
    gc.collect()
    torch.cuda.empty_cache()
    
final_preds = final_preds.reshape(-1, 14400)

In [None]:
test_paths = datamodule.test_dataset.paths
test_filenames = [path.name for path in test_paths]

In [None]:
subm = pd.DataFrame({"file_name"}: test_filenames)
for i in tqdm(range(14400)):
    subm[str(i)] = final_preds[:, i]

In [None]:
subm.to_csv(f"unet_bs{args['batch_size']}_epoch{args['max_epochs']}_lr{model.lr}.csv", index=False)
subm.head()