In [1]:
import sys
sys.path.insert(0, "../src")

In [2]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline
%reload_ext nb_black

<IPython.core.display.Javascript object>

In [3]:
import gc
import warnings
from pathlib import Path
from tqdm.notebook import tqdm

import cv2
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn import metrics

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler

import pytorch_lightning as pl

import torchvision.transforms as T
import albumentations as A
from albumentations.pytorch import ToTensorV2

from transformers import AdamW, get_cosine_schedule_with_warmup

import optim
import loss
from utils import visualize, radar2precipitation, seed_everything

<IPython.core.display.Javascript object>

In [4]:
# Suppress every Python warning for the rest of the session. This keeps the
# training logs short, but it also hides deprecation notices from the
# libraries imported above.
warnings.simplefilter("ignore")

<IPython.core.display.Javascript object>

# U-Net

## Config

In [5]:
# Single source of experiment settings; the dataset, model and training code
# below all read from this dictionary.
args = {
    # Seed handed to the seeding helpers in the "Train" section.
    "seed": 42,
    # Column indices selected from the flattened, centre-cropped frames when
    # validation metrics are computed.
    "dams": (6071, 6304, 7026, 7629, 7767, 8944, 11107),
    # Input / output locations.
    "train_folds_csv": Path("../input/train_folds.csv"),
    "train_data_path": Path("../input/train"),
    "test_data_path": Path("../input/test"),
    "model_dir": Path("../models"),
    "output_dir": Path("../output"),
    # Divisor used to scale raw frame values before they reach the model.
    "rng": 255.0,
    # Loader / trainer settings.
    "num_workers": 4,
    "gpus": 1,
    "lr": 1e-3,
    "max_epochs": 30,
    "batch_size": 256,
    "precision": 16,
    "optimizer": "adamw",
    "scheduler": "cosine",
    "accumulate_grad_batches": 1,
    "gradient_clip_val": 5.0,
    "warmup_epochs": 1,
}

# The training and validation pipelines are identical at the moment: pad each
# frame to at least 128x128 and convert it to a channel-first tensor. They are
# kept as two separate objects so they can diverge later.
args.update(
    trn_tfms=A.Compose(
        [
            A.PadIfNeeded(min_height=128, min_width=128, always_apply=True, p=1),
            ToTensorV2(always_apply=True, p=1),
        ]
    ),
    val_tfms=A.Compose(
        [
            A.PadIfNeeded(min_height=128, min_width=128, always_apply=True, p=1),
            ToTensorV2(always_apply=True, p=1),
        ]
    ),
)

<IPython.core.display.Javascript object>

## Dataset

In [6]:
class NowcastingDataset(Dataset):
    """Dataset that serves one ``.npy`` file per item.

    Every file is loaded with ``np.load`` and passed through ``tfms`` as
    ``image=...``. The first four channels of the transformed tensor become
    the model input; unless ``test`` is set, the fifth channel becomes the
    target. Both are divided by ``args["rng"]``.

    Args:
        paths: Indexable collection of file paths.
        tfms: Callable invoked as ``tfms(image=array)["image"]``. When
            ``None``, a pad-to-128 + ``ToTensorV2`` pipeline is used.
        test: When true, ``__getitem__`` returns only the input tensor.
    """

    def __init__(self, paths, tfms=None, test=False):
        self.paths = paths
        self.test = test
        self.tfms = (
            tfms
            if tfms is not None
            else A.Compose(
                [
                    A.PadIfNeeded(
                        min_height=128, min_width=128, always_apply=True, p=1
                    ),
                    ToTensorV2(always_apply=True, p=1),
                ]
            )
        )

    def __len__(self):
        """Number of files in the dataset."""
        return len(self.paths)

    def __getitem__(self, idx):
        """Return ``x`` (test mode) or ``(x, y)`` for the file at ``idx``."""
        frames = self.tfms(image=np.load(self.paths[idx]))["image"]
        scale = args["rng"]

        inputs = frames[:4, :, :] / scale
        if self.test:
            return inputs

        # Fifth channel is the target; add a leading channel axis so it has
        # the same rank as the model output.
        target = (frames[4, :, :] / scale).unsqueeze(0)
        return inputs, target

<IPython.core.display.Javascript object>

In [7]:
class NowcastingDataModule(pl.LightningDataModule):
    """Lightning data module wrapping ``NowcastingDataset`` for one CV fold.

    Args:
        train_df: Frame with a ``filename`` column listing training files
            (relative to ``args["train_data_path"]``).
        val_df: Frame with a ``filename`` column listing validation files.
        batch_size: Training batch size; validation and test loaders use
            twice this value.
        num_workers: Worker processes per ``DataLoader``.
        test: Stored on the instance; not read by any method of this class.
    """

    def __init__(
        self,
        train_df=None,
        val_df=None,
        batch_size=args["batch_size"],
        num_workers=args["num_workers"],
        test=False,
    ):
        super().__init__()
        self.train_df = train_df
        self.val_df = val_df
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.test = test

    def setup(self, stage="train"):
        """Build the datasets for ``stage``.

        ``"train"`` (this notebook's own name) and ``"fit"`` (the name
        Lightning passes from ``Trainer.fit``) both build the train and
        validation datasets. Any other value builds the test dataset from
        every ``*.npy`` file under ``args["test_data_path"]``, in sorted order.
        """
        # Accept Lightning's "fit" as well: previously it fell through to the
        # test branch and left train_dataset / val_dataset undefined.
        if stage in ("train", "fit"):
            train_paths = [
                args["train_data_path"] / fn for fn in self.train_df.filename.values
            ]
            val_paths = [
                args["train_data_path"] / fn for fn in self.val_df.filename.values
            ]
            self.train_dataset = NowcastingDataset(train_paths, tfms=args["trn_tfms"])
            self.val_dataset = NowcastingDataset(val_paths, tfms=args["val_tfms"])
        else:
            test_paths = list(sorted(args["test_data_path"].glob("*.npy")))
            self.test_dataset = NowcastingDataset(test_paths, test=True)

    def train_dataloader(self):
        """Shuffled training loader; the last incomplete batch is dropped."""
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            sampler=RandomSampler(self.train_dataset),
            pin_memory=True,
            num_workers=self.num_workers,
            drop_last=True,
        )

    def val_dataloader(self):
        """Sequential validation loader using twice the training batch size."""
        return DataLoader(
            self.val_dataset,
            batch_size=2 * self.batch_size,
            sampler=SequentialSampler(self.val_dataset),
            pin_memory=True,
            num_workers=self.num_workers,
        )

    def test_dataloader(self):
        """Sequential test loader using twice the training batch size."""
        return DataLoader(
            self.test_dataset,
            batch_size=2 * self.batch_size,
            sampler=SequentialSampler(self.test_dataset),
            pin_memory=True,
            num_workers=self.num_workers,
        )

<IPython.core.display.Javascript object>

## Model

### Basic

In [8]:
class BasicBlock(nn.Module):
    """Residual block computing ``x + conv(leaky_relu(bn(conv(x))))``.

    Both convolutions are 3x3 with padding 1, so the spatial size is kept.
    ``in_ch`` must equal ``out_ch`` because the shortcut is a plain identity.
    """

    def __init__(self, in_ch, out_ch):
        # The identity shortcut in forward() needs matching channel counts.
        assert in_ch == out_ch
        super().__init__()
        conv_in = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=False)
        norm = nn.BatchNorm2d(out_ch)
        act = nn.LeakyReLU(inplace=True)
        conv_out = nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1)
        self.net = nn.Sequential(conv_in, norm, act, conv_out)

    def forward(self, x):
        """Return the input plus the output of the convolutional branch."""
        branch = self.net(x)
        return x + branch

<IPython.core.display.Javascript object>

### Encoder

In [9]:
class DownBlock(nn.Module):
    """Encoder stage that halves the resolution and maps ``in_ch`` to ``out_ch``.

    The main branch is BN -> LeakyReLU -> 2x2 max-pool -> BN -> LeakyReLU ->
    3x3 conv. The shortcut is a 1x1 convolution with stride 2.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.id_conv = nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=2)
        stages = [
            nn.BatchNorm2d(in_ch),
            nn.LeakyReLU(inplace=True),
            nn.MaxPool2d(2),
            nn.BatchNorm2d(in_ch),
            nn.LeakyReLU(inplace=True),
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Return ``(shortcut + main, main)``.

        The first element is passed on to the next stage; the second is the
        main-branch output without the shortcut added.
        """
        shortcut = self.id_conv(x)
        main = self.net(x)
        return shortcut + main, main


class Encoder(nn.Module):
    """Stack of ``DownBlock`` stages followed by one ``BasicBlock``.

    ``chs`` lists the channel count before the first stage and after every
    stage, so ``len(chs) - 1`` stages are created.
    """

    def __init__(self, chs=[4, 64, 128, 256, 512, 1024]):
        super().__init__()
        stages = [DownBlock(c_in, c_out) for c_in, c_out in zip(chs[:-1], chs[1:])]
        self.blocks = nn.ModuleList(stages)
        self.basic = BasicBlock(chs[-1], chs[-1])

    def forward(self, x):
        """Return the per-stage features, shallowest first.

        The list holds the second output of every ``DownBlock`` and, as its
        last element, the output of the final ``BasicBlock``.
        """
        skips = []
        for stage in self.blocks:
            x, skip = stage(x)
            skips.append(skip)
        skips.append(self.basic(x))
        return skips

<IPython.core.display.Javascript object>

### Decoder

In [10]:
class UpBlock(nn.Module):
    """Decoder stage: concatenate a skip feature, upsample x2, refine.

    ``forward`` concatenates ``x`` and ``feat`` along the channel axis, so the
    block's input has ``in_ch + in_ch`` channels. The result has ``out_ch``
    channels at twice the spatial resolution.

    Args:
        in_ch: Channel count of each of ``x`` and ``feat``.
        out_ch: Channel count of the output.
        bilinear: When true, the main branch upsamples with bilinear
            interpolation followed by a 1x1 convolution instead of a
            transposed convolution. The shortcut always uses a transposed
            convolution.
    """

    def __init__(self, in_ch, out_ch, bilinear=False):
        super().__init__()
        self.id_conv = nn.ConvTranspose2d(
            in_ch + in_ch, out_ch, kernel_size=2, stride=2
        )
        if bilinear:
            # nn.Upsample keeps the channel count, so a 1x1 conv is needed to
            # reach out_ch before the BatchNorm2d(out_ch) below. Previously
            # this path used mode="nearest" with no channel reduction and
            # failed in forward(). The pair is nested so the following layers
            # keep the same indices as in the transposed-conv variant.
            upsample = nn.Sequential(
                nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False),
                nn.Conv2d(in_ch + in_ch, out_ch, kernel_size=1),
            )
        else:
            upsample = nn.ConvTranspose2d(
                in_ch + in_ch, out_ch, kernel_size=2, stride=2
            )
        self.block = nn.Sequential(
            upsample,
            nn.BatchNorm2d(out_ch),
            nn.LeakyReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.LeakyReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1, bias=False),
        )

    def forward(self, x, feat):
        """Return ``block(cat(x, feat)) + id_conv(cat(x, feat))``."""
        x = torch.cat([x, feat], dim=1)
        residual = self.id_conv(x)
        x = self.block(x)
        return x + residual


class Decoder(nn.Module):
    """Chain of ``UpBlock`` stages.

    ``chs`` lists the channel count before the first stage and after every
    stage, so ``len(chs) - 1`` stages are created.
    """

    def __init__(self, chs=[1024, 512, 256, 128, 64]):
        super().__init__()
        stages = [UpBlock(c_in, c_out) for c_in, c_out in zip(chs[:-1], chs[1:])]
        self.blocks = nn.ModuleList(stages)

    def forward(self, x, feats):
        """Run ``x`` through the stages, pairing each with one skip feature.

        Stages and ``feats`` are paired with ``zip``, so any surplus entries
        in the longer of the two are ignored.
        """
        for stage, skip in zip(self.blocks, feats):
            x = stage(x, skip)
        return x

<IPython.core.display.Javascript object>

### U-Net

In [16]:
class UNet(pl.LightningModule):
    """Residual U-Net Lightning module: 4-channel input, 1-channel sigmoid output.

    The network is ``tail`` (residual block) -> ``encoder`` -> ``decoder`` ->
    ``head`` (transposed conv, BN, LeakyReLU, 3x3 conv, sigmoid). Training
    minimises the L1 loss between prediction and target.

    Args:
        lr: Learning rate handed to the optimizer.
        enc_chs: Channel schedule for ``Encoder``.
        dec_chs: Channel schedule for ``Decoder``.
        num_train_steps: Optimizer steps per epoch. Required when
            ``args["scheduler"] == "cosine"``.
    """

    def __init__(
        self,
        lr=args["lr"],
        enc_chs=[4, 64, 128, 256, 512, 1024],
        dec_chs=[1024, 512, 256, 128, 64],
        num_train_steps=None,
    ):
        super().__init__()
        self.lr = lr
        self.num_train_steps = num_train_steps
        self.criterion = nn.L1Loss()

        self.tail = BasicBlock(4, enc_chs[0])
        self.encoder = Encoder(enc_chs)
        self.decoder = Decoder(dec_chs)
        self.head = nn.Sequential(
            nn.ConvTranspose2d(dec_chs[-1], 32, kernel_size=2, stride=2, bias=False),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(inplace=True),
            nn.Conv2d(32, 1, kernel_size=3, padding=1),
            # Sigmoid keeps predictions in [0, 1], the same range as targets
            # divided by args["rng"].
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return the single-channel prediction for input batch ``x``."""
        x = self.tail(x)
        feats = self.encoder(x)
        # Deepest feature first: it seeds the decoder, the rest act as skips.
        feats = feats[::-1]
        x = self.decoder(feats[0], feats[1:])
        x = self.head(x)
        return x

    def shared_step(self, batch, batch_idx):
        """Run the model on ``batch`` and return ``(loss, y, y_hat)``."""
        x, y = batch
        y_hat = self(x)
        loss = self.criterion(y_hat, y)
        return loss, y, y_hat

    def training_step(self, batch, batch_idx):
        """Compute the training loss and log it with the current learning rates."""
        loss, y, y_hat = self.shared_step(batch, batch_idx)
        self.log("train_loss", loss)
        for i, param_group in enumerate(self.optimizer.param_groups):
            self.log(f"lr/lr{i}", param_group["lr"])
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        """Return loss, targets and predictions for epoch-level aggregation."""
        loss, y, y_hat = self.shared_step(batch, batch_idx)
        return {"loss": loss, "y": y.detach(), "y_hat": y_hat.detach()}

    def _dam_values(self, frames):
        """Return the ``args["dams"]`` pixels of ``frames``, rescaled by ``args["rng"]``.

        ``frames`` is centre-cropped to 120x120, flattened per sample and
        indexed with ``args["dams"]``. The result is a float32 NumPy array
        with one row per sample.
        """
        # 120 is presumably the frame size before PadIfNeeded pads to 128
        # -- TODO confirm against the raw data.
        frames = T.CenterCrop(120)(frames)
        # Cast to float32 first: under 16-bit precision the predictions are
        # half precision, and rescaling in fp16 would round the values.
        values = frames.detach().cpu().float().numpy()
        values *= args["rng"]
        values = values.reshape(len(values), -1)
        return values[:, args["dams"]]

    def validation_epoch_end(self, outputs):
        """Aggregate validation outputs and log loss, MAE, CSI and MAE/CSI.

        Values are thresholded at ``radar2precipitation(...) >= 0.1`` to get
        binary labels. MAE is weighted by the true labels, CSI is
        ``tp / (tp + fn + fp)``, and the composite metric is
        ``mae / (csi + 1e-12)``.
        """
        avg_loss = torch.stack([x["loss"] for x in outputs]).mean()
        self.log("val_loss", avg_loss)

        y = self._dam_values(torch.cat([x["y"] for x in outputs]))
        y_hat = self._dam_values(torch.cat([x["y_hat"] for x in outputs]))

        y_true = radar2precipitation(y)
        y_true = np.where(y_true >= 0.1, 1, 0)
        y_true = y_true.ravel()
        y_pred = radar2precipitation(y_hat)
        y_pred = np.where(y_pred >= 0.1, 1, 0)
        y_pred = y_pred.ravel()

        y = y.ravel()
        y_hat = y_hat.ravel()
        if y_true.any():
            mae = metrics.mean_absolute_error(y, y_hat, sample_weight=y_true)
        else:
            # All weights are zero, so the weighted MAE is undefined and
            # scikit-learn would raise. Report NaN instead of crashing.
            mae = float("nan")
        self.log("mae", mae)

        # labels=[0, 1] forces a 2x2 matrix even when only one class occurs;
        # without it the four-way unpacking fails.
        tn, fp, fn, tp = metrics.confusion_matrix(
            y_true, y_pred, labels=[0, 1]
        ).ravel()
        csi_denominator = tp + fn + fp
        csi = tp / csi_denominator if csi_denominator > 0 else 0.0
        self.log("csi", csi)

        comp_metric = mae / (csi + 1e-12)
        self.log("comp_metric", comp_metric)

        print(
            f"Epoch {self.current_epoch} | MAE/CSI: {comp_metric} | MAE: {mae} | CSI: {csi} | Loss: {avg_loss}"
        )

    def configure_optimizers(self):
        """Build the optimizer and scheduler named in ``args``.

        Raises:
            ValueError: If ``args["optimizer"]`` is not a supported name, or
                if the cosine scheduler is requested without
                ``num_train_steps``.
        """
        # optimizer
        if args["optimizer"] == "adam":
            self.optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        elif args["optimizer"] == "adamw":
            self.optimizer = AdamW(self.parameters(), lr=self.lr)
        elif args["optimizer"] == "radam":
            self.optimizer = optim.RAdam(self.parameters(), lr=self.lr)
        elif args["optimizer"] == "ranger":
            self.optimizer = optim.RAdam(self.parameters(), lr=self.lr)
            self.optimizer = optim.Lookahead(self.optimizer)
        else:
            raise ValueError(f"Unsupported optimizer: {args['optimizer']!r}")

        # scheduler
        if args["scheduler"] == "cosine":
            if self.num_train_steps is None:
                raise ValueError(
                    "num_train_steps must be provided when using the cosine scheduler"
                )
            steps_per_epoch = int(self.num_train_steps)
            self.scheduler = get_cosine_schedule_with_warmup(
                self.optimizer,
                num_warmup_steps=steps_per_epoch * args["warmup_epochs"],
                num_training_steps=steps_per_epoch * args["max_epochs"],
            )
            return [self.optimizer], [{"scheduler": self.scheduler, "interval": "step"}]
        elif args["scheduler"] == "step":
            self.scheduler = torch.optim.lr_scheduler.StepLR(
                self.optimizer, step_size=10, gamma=0.5
            )
            return [self.optimizer], [
                {"scheduler": self.scheduler, "interval": "epoch"}
            ]
        elif args["scheduler"] == "plateau":
            self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
                self.optimizer, mode="min", factor=0.1, patience=3, verbose=True
            )
            return [self.optimizer], [
                {
                    "scheduler": self.scheduler,
                    "interval": "epoch",
                    "reduce_on_plateau": True,
                    "monitor": "comp_metric",
                }
            ]
        else:
            # Any other scheduler name means "no scheduler".
            self.scheduler = None
            return [self.optimizer]

<IPython.core.display.Javascript object>

## Train

In [12]:
# Seed twice with the same value: once through the project helper from
# `utils` (presumably seeds the Python/NumPy/torch RNGs -- confirm in utils)
# and once through Lightning's own seeding utility. The value of the last call
# is what the cell displays.
seed_everything(args["seed"])
pl.seed_everything(args["seed"])

42

<IPython.core.display.Javascript object>

In [13]:
# Fold assignment table. The training code below reads its `fold` column to
# split train/validation and its `filename` column to locate the .npy files.
df = pd.read_csv(args["train_folds_csv"])

<IPython.core.display.Javascript object>

In [14]:
def train_fold(df, fold, lr_find=False):
    """Train a ``UNet`` with ``fold`` held out for validation and save it.

    Args:
        df: Frame with ``fold`` and ``filename`` columns.
        fold: Value of ``df.fold`` used for validation; all other rows train.
        lr_find: When true, run Lightning's learning-rate finder, show its
            plot and return without training or saving.

    Returns:
        None. The checkpoint is written under ``args["model_dir"]``.
    """
    train_df = df[df.fold != fold]
    val_df = df[df.fold == fold]

    datamodule = NowcastingDataModule(train_df, val_df)
    datamodule.setup()

    # The train loader drops the last incomplete batch (floor division); each
    # optimizer step consumes `accumulate_grad_batches` batches (ceil).
    batches_per_epoch = len(train_df) // args["batch_size"]
    num_train_steps = int(
        np.ceil(batches_per_epoch / args["accumulate_grad_batches"])
    )
    model = UNet(num_train_steps=num_train_steps)

    # NOTE(review): args["gradient_clip_val"] is defined in the config but is
    # not passed here, so no gradient clipping is applied. Left unchanged to
    # keep training dynamics identical -- confirm whether that is intended.
    trainer = pl.Trainer(
        gpus=args["gpus"],
        max_epochs=args["max_epochs"],
        precision=args["precision"],
        # Must match the value used for num_train_steps above, otherwise the
        # LR schedule and the actual number of optimizer steps diverge.
        accumulate_grad_batches=args["accumulate_grad_batches"],
        progress_bar_refresh_rate=50,
        benchmark=True,
    )

    if lr_find:
        lr_finder = trainer.tuner.lr_find(model, datamodule=datamodule)
        fig = lr_finder.plot(suggest=True)
        fig.show()
        return

    # Create the output directory up front so a missing folder cannot cause a
    # failure at save time, after the whole training run has finished.
    args["model_dir"].mkdir(parents=True, exist_ok=True)

    print(f"Training fold {fold}...")
    trainer.fit(model, datamodule)

    checkpoint = (
        args["model_dir"]
        / f"unet_sigmoid_fold{fold}_bs{args['batch_size']}_epochs{args['max_epochs']}_lr{model.lr}_{args['optimizer']}_{args['scheduler']}.ckpt"
    )
    trainer.save_checkpoint(checkpoint)
    print("Model saved at", checkpoint)

    # Release Python and CUDA memory before the next fold starts.
    del model, trainer, datamodule
    gc.collect()
    torch.cuda.empty_cache()

<IPython.core.display.Javascript object>

In [16]:
# AdamW bs256 lr 1e-3
# Train one model per held-out fold. Assumes `df.fold` takes the values 0-4
# -- TODO confirm against train_folds.csv.
for fold in range(5):
    train_fold(df, fold)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Using native 16bit precision.


Training fold 0...



  | Name      | Type       | Params
-----------------------------------------
0 | criterion | L1Loss     | 0     
1 | tail      | BasicBlock | 300   
2 | encoder   | Encoder    | 25 M  
3 | decoder   | Decoder    | 17 M  
4 | head      | Sequential | 8 K   


HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

Epoch 0 | MAE/CSI: 110793664383561.66 | MAE: 110.79366438356165 | CSI: 0.0 | Loss: 0.05267500877380371


HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 0 | MAE/CSI: 33.186943419342214 | MAE: 25.55001192118522 | CSI: 0.7698814439856134 | Loss: 0.017848094925284386


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 1 | MAE/CSI: 25.185173461181805 | MAE: 19.69605205396773 | CSI: 0.7820494897245911 | Loss: 0.01364449504762888


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 2 | MAE/CSI: 25.54220371042292 | MAE: 20.054110962616964 | CSI: 0.7851362862010222 | Loss: 0.013360547833144665


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 3 | MAE/CSI: 27.026754305478295 | MAE: 21.24946493606421 | CSI: 0.7862381363244176 | Loss: 0.013475954532623291


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 4 | MAE/CSI: 21.443058880888415 | MAE: 17.222089883581003 | CSI: 0.8031545302946081 | Loss: 0.012746231630444527


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 5 | MAE/CSI: 21.810521936074984 | MAE: 17.524096834325515 | CSI: 0.8034698521046644 | Loss: 0.012218066491186619


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 6 | MAE/CSI: 19.459475168673396 | MAE: 15.815077867269974 | CSI: 0.8127186231985448 | Loss: 0.011925801634788513


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 7 | MAE/CSI: 21.077778418376518 | MAE: 17.037970858700216 | CSI: 0.8083380762663631 | Loss: 0.011935080401599407


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 8 | MAE/CSI: 20.73141515057428 | MAE: 16.773482073236565 | CSI: 0.8090852434041964 | Loss: 0.012599549256265163


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 9 | MAE/CSI: 20.89835070821418 | MAE: 17.01616066608685 | CSI: 0.8142346208869814 | Loss: 0.011802570894360542


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 10 | MAE/CSI: 20.213408750832656 | MAE: 16.462153962294565 | CSI: 0.8144175069727526 | Loss: 0.011618967168033123


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 11 | MAE/CSI: 24.208863723571998 | MAE: 19.503725665599507 | CSI: 0.805643994211288 | Loss: 0.012568147853016853


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 12 | MAE/CSI: 20.601691043830602 | MAE: 16.781879929672506 | CSI: 0.81458749643163 | Loss: 0.01178012229502201


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 13 | MAE/CSI: 19.976155204575853 | MAE: 16.334156712039647 | CSI: 0.8176827094474153 | Loss: 0.011617383919656277


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 14 | MAE/CSI: 18.39588910524068 | MAE: 15.092483317838292 | CSI: 0.8204269568857262 | Loss: 0.01175840012729168


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 15 | MAE/CSI: 19.09062582033396 | MAE: 15.576906377459213 | CSI: 0.8159452981813056 | Loss: 0.011697824113070965


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 16 | MAE/CSI: 18.493273251072264 | MAE: 15.189089688174784 | CSI: 0.8213305174234424 | Loss: 0.011621751822531223


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 17 | MAE/CSI: 18.927566425968354 | MAE: 15.492399944517754 | CSI: 0.818509870515814 | Loss: 0.011545676738023758


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 18 | MAE/CSI: 19.334057504131742 | MAE: 15.774070578664828 | CSI: 0.8158696422245838 | Loss: 0.011600039899349213


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 19 | MAE/CSI: 19.14047650865582 | MAE: 15.601810012202344 | CSI: 0.8151212957069099 | Loss: 0.011548931710422039


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 20 | MAE/CSI: 20.027654187947494 | MAE: 16.23505613828575 | CSI: 0.8106319385140905 | Loss: 0.011748154647648335


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 21 | MAE/CSI: 19.016997137613103 | MAE: 15.494138983728332 | CSI: 0.814752133135886 | Loss: 0.011649723164737225


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 22 | MAE/CSI: 19.549548452653777 | MAE: 15.862644750554823 | CSI: 0.8114072194021432 | Loss: 0.01184056606143713


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 23 | MAE/CSI: 19.62327869559657 | MAE: 15.935058043144943 | CSI: 0.8120487045164944 | Loss: 0.011686836369335651


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 24 | MAE/CSI: 19.20136738616643 | MAE: 15.653670722243882 | CSI: 0.8152372905224616 | Loss: 0.011736424639821053


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 25 | MAE/CSI: 19.269590556557628 | MAE: 15.65051845909309 | CSI: 0.8121873899260303 | Loss: 0.011749816127121449


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 26 | MAE/CSI: 19.560944262355147 | MAE: 15.85134794678203 | CSI: 0.8103569916748977 | Loss: 0.011810777708888054


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 27 | MAE/CSI: 19.463234309122434 | MAE: 15.779038946990088 | CSI: 0.8107100133699247 | Loss: 0.011808572337031364


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 28 | MAE/CSI: 19.463016053011348 | MAE: 15.769864921350168 | CSI: 0.8102477477477478 | Loss: 0.01181867253035307


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 29 | MAE/CSI: 19.419930651638133 | MAE: 15.728437331303983 | CSI: 0.8099121265377855 | Loss: 0.011822505854070187

Model saved at ../models/unet_fold0_bs256_epochs30_lr0.001_adamw_cosine.ckpt


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Using native 16bit precision.

  | Name      | Type       | Params
-----------------------------------------
0 | criterion | L1Loss     | 0     
1 | tail      | BasicBlock | 300   
2 | encoder   | Encoder    | 25 M  
3 | decoder   | Decoder    | 17 M  
4 | head      | Sequential | 8 K   


Training fold 1...


HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

Epoch 0 | MAE/CSI: 109838714384134.84 | MAE: 109.83871438413485 | CSI: 0.0 | Loss: 0.050759207457304


HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 0 | MAE/CSI: 30.438610328033292 | MAE: 22.877340869081632 | CSI: 0.7515895312731037 | Loss: 0.01447451300919056


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 1 | MAE/CSI: 27.81257920736288 | MAE: 21.42626403857756 | CSI: 0.7703803332586117 | Loss: 0.013589969836175442


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 2 | MAE/CSI: 23.791732404542913 | MAE: 18.607595477371124 | CSI: 0.7821034282393957 | Loss: 0.013189456425607204


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 3 | MAE/CSI: 24.67189650396396 | MAE: 19.398163979907117 | CSI: 0.7862453531598513 | Loss: 0.01269851066172123


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 4 | MAE/CSI: 22.37970634801444 | MAE: 17.825067209957396 | CSI: 0.7964835164835165 | Loss: 0.01245537493377924


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 5 | MAE/CSI: 21.38870789855043 | MAE: 17.124827474286082 | CSI: 0.8006480595036454 | Loss: 0.012010117061436176


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 6 | MAE/CSI: 23.121047769023082 | MAE: 18.387481908225936 | CSI: 0.7952702702702703 | Loss: 0.01195940189063549


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 7 | MAE/CSI: 21.380010649024133 | MAE: 17.249816201340995 | CSI: 0.8068198133524767 | Loss: 0.012651579454541206


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 8 | MAE/CSI: 19.884120596258796 | MAE: 16.060755238895975 | CSI: 0.8077176539503551 | Loss: 0.011617396026849747


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 9 | MAE/CSI: 19.95838481240207 | MAE: 16.128404994784848 | CSI: 0.8081017149615612 | Loss: 0.011446814052760601


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 10 | MAE/CSI: 21.23726514737613 | MAE: 17.11664908782062 | CSI: 0.8059723777528929 | Loss: 0.011620122008025646


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 11 | MAE/CSI: 19.230807556137442 | MAE: 15.613997821907722 | CSI: 0.8119262686347948 | Loss: 0.01138119213283062


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 12 | MAE/CSI: 19.26334314043933 | MAE: 15.615153630625755 | CSI: 0.810614934114202 | Loss: 0.011543345637619495


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 13 | MAE/CSI: 19.264658003703257 | MAE: 15.678122278859526 | CSI: 0.8138282172373081 | Loss: 0.011346523649990559


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 14 | MAE/CSI: 19.5614310438213 | MAE: 15.858158574006369 | CSI: 0.8106849922411882 | Loss: 0.011353997513651848


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 15 | MAE/CSI: 20.325657045944222 | MAE: 16.421114476880287 | CSI: 0.8079007945347887 | Loss: 0.011452319100499153


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 16 | MAE/CSI: 19.160887430973464 | MAE: 15.560784978478654 | CSI: 0.8121119146760173 | Loss: 0.011521076783537865


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 17 | MAE/CSI: 18.913051188853814 | MAE: 15.386936254420686 | CSI: 0.8135618151062735 | Loss: 0.01132373046129942


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 18 | MAE/CSI: 20.261271536845946 | MAE: 16.398104336047624 | CSI: 0.8093324402768475 | Loss: 0.011397392489016056


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 19 | MAE/CSI: 18.88648488743959 | MAE: 15.356261966135346 | CSI: 0.813082056170712 | Loss: 0.011438331566751003


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 20 | MAE/CSI: 18.90888501053057 | MAE: 15.309875574315502 | CSI: 0.8096656976744186 | Loss: 0.011600780300796032


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 21 | MAE/CSI: 19.057435755200128 | MAE: 15.442373928848985 | CSI: 0.8103070175438597 | Loss: 0.011565647087991238


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 22 | MAE/CSI: 19.747915922078203 | MAE: 15.972321902462133 | CSI: 0.8088105076741441 | Loss: 0.01151563972234726


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 23 | MAE/CSI: 18.986097731680978 | MAE: 15.399541466936103 | CSI: 0.8110956598111688 | Loss: 0.01157230231910944


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 24 | MAE/CSI: 18.99107928711418 | MAE: 15.39025627194611 | CSI: 0.8103939770484614 | Loss: 0.011596854776144028


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 25 | MAE/CSI: 19.7236355264424 | MAE: 15.961558913109142 | CSI: 0.8092604880926049 | Loss: 0.011602360755205154


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 26 | MAE/CSI: 19.66751364177057 | MAE: 15.879075442501195 | CSI: 0.8073758448427858 | Loss: 0.01169213280081749


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 27 | MAE/CSI: 19.442847596632387 | MAE: 15.730006844393438 | CSI: 0.8090382217005355 | Loss: 0.011666052974760532


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 28 | MAE/CSI: 19.545673056666146 | MAE: 15.805726503701734 | CSI: 0.8086560364464692 | Loss: 0.011664430610835552


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 29 | MAE/CSI: 19.661343826279275 | MAE: 15.894221094064639 | CSI: 0.8083995292733157 | Loss: 0.01167358923703432

Model saved at ../models/unet_fold1_bs256_epochs30_lr0.001_adamw_cosine.ckpt


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Using native 16bit precision.

  | Name      | Type       | Params
-----------------------------------------
0 | criterion | L1Loss     | 0     
1 | tail      | BasicBlock | 300   
2 | encoder   | Encoder    | 25 M  
3 | decoder   | Decoder    | 17 M  
4 | head      | Sequential | 8 K   


Training fold 2...


HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

Epoch 0 | MAE/CSI: 98624364269755.4 | MAE: 98.6243642697554 | CSI: 0.0 | Loss: 0.08764511346817017


HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 0 | MAE/CSI: 28.118997965134586 | MAE: 21.552152280986306 | CSI: 0.7664623151821133 | Loss: 0.01427386049181223


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 1 | MAE/CSI: 27.128638812721636 | MAE: 21.151175594950804 | CSI: 0.7796622506915126 | Loss: 0.013792337849736214


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 2 | MAE/CSI: 24.618840091334018 | MAE: 19.40673530777922 | CSI: 0.7882879630298216 | Loss: 0.01312659028917551


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 3 | MAE/CSI: 23.996545896105665 | MAE: 19.048209403667897 | CSI: 0.793789634813033 | Loss: 0.012730807065963745


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 4 | MAE/CSI: 22.49951004667157 | MAE: 18.03575311320954 | CSI: 0.8016064827978391 | Loss: 0.012839207425713539


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 5 | MAE/CSI: 21.361131217621754 | MAE: 17.257766442326197 | CSI: 0.8079050807977871 | Loss: 0.012023642659187317


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 6 | MAE/CSI: 20.236411459874972 | MAE: 16.415489828220917 | CSI: 0.811185810327564 | Loss: 0.012115568853914738


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 7 | MAE/CSI: 24.470472404637587 | MAE: 19.564411922558676 | CSI: 0.7995110024449877 | Loss: 0.012515711598098278


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 8 | MAE/CSI: 19.02272187040231 | MAE: 15.550396087375951 | CSI: 0.817464303652149 | Loss: 0.011686594225466251


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 9 | MAE/CSI: 20.080601901201696 | MAE: 16.387970643444902 | CSI: 0.816109533173112 | Loss: 0.011634819209575653


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 10 | MAE/CSI: 19.03831631311525 | MAE: 15.581505716571792 | CSI: 0.8184287654585746 | Loss: 0.011825410649180412


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 11 | MAE/CSI: 19.32006970412333 | MAE: 15.80681914816354 | CSI: 0.8181553891997965 | Loss: 0.01143832691013813


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 12 | MAE/CSI: 18.683346398615083 | MAE: 15.279302381427371 | CSI: 0.8178033022254128 | Loss: 0.011568550951778889


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 13 | MAE/CSI: 19.403732510254446 | MAE: 15.866799598526859 | CSI: 0.8177189409368636 | Loss: 0.011433429084718227


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 14 | MAE/CSI: 18.861351107256343 | MAE: 15.467400527873899 | CSI: 0.8200579290369298 | Loss: 0.011417574249207973


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 15 | MAE/CSI: 20.403804237083232 | MAE: 16.587705867603038 | CSI: 0.8129712319742333 | Loss: 0.011530996300280094


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 16 | MAE/CSI: 18.508092655544406 | MAE: 15.191096051399349 | CSI: 0.8207812838472092 | Loss: 0.011424791999161243


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 17 | MAE/CSI: 18.44665647834475 | MAE: 15.127479769290078 | CSI: 0.8200662156326471 | Loss: 0.011457541026175022


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 18 | MAE/CSI: 18.96570753645215 | MAE: 15.53061988275734 | CSI: 0.8188790137614679 | Loss: 0.011679432354867458


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 19 | MAE/CSI: 18.822043787124592 | MAE: 15.420485982148985 | CSI: 0.8192779783393502 | Loss: 0.011392543092370033


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 20 | MAE/CSI: 18.606130005732833 | MAE: 15.197414212351436 | CSI: 0.8167960885821111 | Loss: 0.01142452098429203


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 21 | MAE/CSI: 18.9835783172211 | MAE: 15.499320162112285 | CSI: 0.8164593578247035 | Loss: 0.011478066444396973


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 22 | MAE/CSI: 19.555789183906093 | MAE: 15.923755880672903 | CSI: 0.8142732431252728 | Loss: 0.01156239677220583


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 23 | MAE/CSI: 18.89104865410489 | MAE: 15.40757357276467 | CSI: 0.8156018151696319 | Loss: 0.011537961661815643


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 24 | MAE/CSI: 18.59823612059139 | MAE: 15.182996494158246 | CSI: 0.816367552045944 | Loss: 0.011565683409571648


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 25 | MAE/CSI: 19.234773276351596 | MAE: 15.645243910983037 | CSI: 0.8133833285261033 | Loss: 0.011604744009673595


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 26 | MAE/CSI: 18.789052965351058 | MAE: 15.296363151240856 | CSI: 0.8141103854159191 | Loss: 0.01161937415599823


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 27 | MAE/CSI: 18.882758069128442 | MAE: 15.359764472001991 | CSI: 0.8134280180761781 | Loss: 0.011645403690636158


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 28 | MAE/CSI: 18.913308360419204 | MAE: 15.388432096423116 | CSI: 0.8136298421807747 | Loss: 0.011661795899271965


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 29 | MAE/CSI: 18.950830060861367 | MAE: 15.392199249279548 | CSI: 0.8122176812217681 | Loss: 0.011679957620799541

Model saved at ../models/unet_fold2_bs256_epochs30_lr0.001_adamw_cosine.ckpt


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Using native 16bit precision.

  | Name      | Type       | Params
-----------------------------------------
0 | criterion | L1Loss     | 0     
1 | tail      | BasicBlock | 300   
2 | encoder   | Encoder    | 25 M  
3 | decoder   | Decoder    | 17 M  
4 | head      | Sequential | 8 K   


Training fold 3...


HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

Epoch 0 | MAE/CSI: 110557260156424.86 | MAE: 110.55726015642486 | CSI: 0.0 | Loss: 0.05172676593065262


HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 0 | MAE/CSI: 53.56887701828267 | MAE: 35.32899085851282 | CSI: 0.6595059076262084 | Loss: 0.01728362962603569


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 1 | MAE/CSI: 30.639784592662963 | MAE: 23.430682562051953 | CSI: 0.7647143370463528 | Loss: 0.014653380028903484


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 2 | MAE/CSI: 24.957249398984533 | MAE: 19.493277948112066 | CSI: 0.7810667608618863 | Loss: 0.013412871398031712


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 3 | MAE/CSI: 22.953791365850908 | MAE: 18.058455785756223 | CSI: 0.7867308497279196 | Loss: 0.013156878761947155


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 4 | MAE/CSI: 23.78838192093069 | MAE: 18.79893090212847 | CSI: 0.7902568137921168 | Loss: 0.012770530767738819


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 5 | MAE/CSI: 24.134156968826158 | MAE: 19.308000326101173 | CSI: 0.8000279583420703 | Loss: 0.013565506786108017


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 6 | MAE/CSI: 21.513528786478364 | MAE: 17.24819258159191 | CSI: 0.8017370256994376 | Loss: 0.012290913611650467


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 7 | MAE/CSI: 21.362476504765034 | MAE: 17.063856162282722 | CSI: 0.7987770593196514 | Loss: 0.01256471686065197


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 8 | MAE/CSI: 24.453781933161846 | MAE: 19.512796349196865 | CSI: 0.7979459538203802 | Loss: 0.012635443359613419


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 9 | MAE/CSI: 20.900274204023304 | MAE: 16.887015755273936 | CSI: 0.8079805839103433 | Loss: 0.0120097566395998


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 10 | MAE/CSI: 20.906133797377368 | MAE: 16.826452392597048 | CSI: 0.8048572039028559 | Loss: 0.012089049443602562


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 11 | MAE/CSI: 19.764185822147727 | MAE: 16.106189117183263 | CSI: 0.8149179157744468 | Loss: 0.011762270703911781


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 12 | MAE/CSI: 19.216470309349777 | MAE: 15.633324322908592 | CSI: 0.8135377658446973 | Loss: 0.011789782904088497


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 13 | MAE/CSI: 21.358521733793662 | MAE: 17.035517334854887 | CSI: 0.7975981459372147 | Loss: 0.012352569960057735


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 14 | MAE/CSI: 18.955754078822256 | MAE: 15.441341389797334 | CSI: 0.8145991621103458 | Loss: 0.011553915217518806


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 15 | MAE/CSI: 18.41964258573334 | MAE: 15.062246454066061 | CSI: 0.8177274007321881 | Loss: 0.011606993153691292


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 16 | MAE/CSI: 18.945060050135833 | MAE: 15.475827797763376 | CSI: 0.8168793214056597 | Loss: 0.011503880843520164


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 17 | MAE/CSI: 19.05648451538784 | MAE: 15.521529688907957 | CSI: 0.814501209620037 | Loss: 0.011547097936272621


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 18 | MAE/CSI: 19.16718569318381 | MAE: 15.615780001194492 | CSI: 0.8147142857142857 | Loss: 0.011518046259880066


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 19 | MAE/CSI: 18.73081169605486 | MAE: 15.293072094853354 | CSI: 0.8164660636706788 | Loss: 0.01148569118231535


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 20 | MAE/CSI: 19.141530793836786 | MAE: 15.653774430905004 | CSI: 0.817791147400086 | Loss: 0.011520149186253548


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 21 | MAE/CSI: 18.813476282761126 | MAE: 15.291689353583903 | CSI: 0.8128050937389458 | Loss: 0.011663438752293587


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 22 | MAE/CSI: 18.42512713407604 | MAE: 14.991494951336136 | CSI: 0.813644043931287 | Loss: 0.011668611317873001


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 23 | MAE/CSI: 18.5111411504437 | MAE: 15.066295467950644 | CSI: 0.8139041966935142 | Loss: 0.011603855527937412


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 24 | MAE/CSI: 18.97355917639315 | MAE: 15.4504381284336 | CSI: 0.814314171883893 | Loss: 0.011588823981583118


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 25 | MAE/CSI: 18.991307442213472 | MAE: 15.43785156363545 | CSI: 0.8128904031800114 | Loss: 0.011663117446005344


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 26 | MAE/CSI: 18.63918385797683 | MAE: 15.163947710759158 | CSI: 0.8135521290140048 | Loss: 0.01164956297725439


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 27 | MAE/CSI: 18.906251519195056 | MAE: 15.356309645176488 | CSI: 0.8122344944774851 | Loss: 0.01169898733496666


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 28 | MAE/CSI: 19.010790761517644 | MAE: 15.43363629861466 | CSI: 0.8118355776045358 | Loss: 0.011704476550221443


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 29 | MAE/CSI: 18.91723245279753 | MAE: 15.368827583863563 | CSI: 0.8124247361337394 | Loss: 0.011702695861458778

Model saved at ../models/unet_fold3_bs256_epochs30_lr0.001_adamw_cosine.ckpt


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Using native 16bit precision.

  | Name      | Type       | Params
-----------------------------------------
0 | criterion | L1Loss     | 0     
1 | tail      | BasicBlock | 300   
2 | encoder   | Encoder    | 25 M  
3 | decoder   | Decoder    | 17 M  
4 | head      | Sequential | 8 K   


Training fold 4...


HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

Epoch 0 | MAE/CSI: 110804066395808.14 | MAE: 110.80406639580814 | CSI: 0.0 | Loss: 0.046886492520570755


HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 0 | MAE/CSI: 29.46492178703207 | MAE: 22.387287602854418 | CSI: 0.7597945707997066 | Loss: 0.014288208447396755


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 1 | MAE/CSI: 24.204093119843044 | MAE: 19.084029852953837 | CSI: 0.7884629165173773 | Loss: 0.013465666212141514


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 2 | MAE/CSI: 23.3908121436228 | MAE: 18.574004691365896 | CSI: 0.7940726716667881 | Loss: 0.0127839595079422


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 3 | MAE/CSI: 21.65168614373402 | MAE: 17.37842964056247 | CSI: 0.8026363177987467 | Loss: 0.012596615590155125


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 4 | MAE/CSI: 21.677014943572697 | MAE: 17.44700715090798 | CSI: 0.8048620714753622 | Loss: 0.012286808341741562


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 5 | MAE/CSI: 21.918280019408275 | MAE: 17.60216383659147 | CSI: 0.803081438004402 | Loss: 0.012095422483980656


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 6 | MAE/CSI: 20.329710330708902 | MAE: 16.422185408844385 | CSI: 0.8077923955472025 | Loss: 0.012070290744304657


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 7 | MAE/CSI: 21.331304906024766 | MAE: 17.233019202263723 | CSI: 0.8078745898651112 | Loss: 0.011999445036053658


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 8 | MAE/CSI: 20.6731214146294 | MAE: 16.683493317061913 | CSI: 0.8070137538705264 | Loss: 0.012145860120654106


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 9 | MAE/CSI: 19.466089391431066 | MAE: 15.881236795690588 | CSI: 0.8158411520837879 | Loss: 0.011607903987169266


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 10 | MAE/CSI: 19.81252870088382 | MAE: 16.179648879536806 | CSI: 0.8166372462488968 | Loss: 0.011471263132989407


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 11 | MAE/CSI: 19.550455453315504 | MAE: 15.932699609748498 | CSI: 0.8149528612146459 | Loss: 0.011466726660728455


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 12 | MAE/CSI: 18.289922608388945 | MAE: 15.044766262064371 | CSI: 0.8225713462092822 | Loss: 0.01155020110309124


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 13 | MAE/CSI: 19.916873297085072 | MAE: 16.239037439054368 | CSI: 0.8153407011637268 | Loss: 0.011535759083926678


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 14 | MAE/CSI: 19.381075628919593 | MAE: 15.88246965193991 | CSI: 0.8194833948339484 | Loss: 0.011257984675467014


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 15 | MAE/CSI: 20.13362559236226 | MAE: 16.44119687934336 | CSI: 0.8166038850727528 | Loss: 0.01147684920579195


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 16 | MAE/CSI: 17.862983125481282 | MAE: 14.748776933633795 | CSI: 0.8256614715476622 | Loss: 0.01122660469263792


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 17 | MAE/CSI: 17.659409365863613 | MAE: 14.61047819083576 | CSI: 0.8273480662983426 | Loss: 0.011134890839457512


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 18 | MAE/CSI: 17.875598933953043 | MAE: 14.754400889379596 | CSI: 0.8253933724893047 | Loss: 0.011305585503578186


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 19 | MAE/CSI: 18.201254790410943 | MAE: 15.010341899683212 | CSI: 0.8246872027511524 | Loss: 0.011105424724519253


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 20 | MAE/CSI: 18.577237816975604 | MAE: 15.28320825097125 | CSI: 0.8226846424384525 | Loss: 0.011204416863620281


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 21 | MAE/CSI: 18.38351714846262 | MAE: 15.161192424176441 | CSI: 0.8247166361974406 | Loss: 0.011202105320990086


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 22 | MAE/CSI: 18.325339605405098 | MAE: 15.064050353614043 | CSI: 0.8220338983050848 | Loss: 0.011181995272636414


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 23 | MAE/CSI: 17.875720986784252 | MAE: 14.731736891221365 | CSI: 0.8241198719813791 | Loss: 0.011183995753526688


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 24 | MAE/CSI: 18.44947563426877 | MAE: 15.130209495701461 | CSI: 0.8200888629907495 | Loss: 0.011271136812865734


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 25 | MAE/CSI: 18.53381123761257 | MAE: 15.19639837323524 | CSI: 0.8199284096719994 | Loss: 0.011240239255130291


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 26 | MAE/CSI: 18.0624360317926 | MAE: 14.856623087439964 | CSI: 0.8225149177703391 | Loss: 0.01123881060630083


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 27 | MAE/CSI: 18.33757410027865 | MAE: 15.043335066826108 | CSI: 0.8203557888597258 | Loss: 0.0112459110096097


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 28 | MAE/CSI: 18.386441478880272 | MAE: 15.091226520452597 | CSI: 0.8207801676995989 | Loss: 0.011267893016338348


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 29 | MAE/CSI: 18.40432508550994 | MAE: 15.094416588039563 | CSI: 0.8201559425781535 | Loss: 0.011272534728050232

Model saved at ../models/unet_fold4_bs256_epochs30_lr0.001_adamw_cosine.ckpt


<IPython.core.display.Javascript object>

In [None]:
# AdamW bs256 lr 1e-3 sigmoid
# The line above labels this experiment run (optimizer / batch size / learning rate).
# NOTE(review): "sigmoid" presumably refers to the model's output activation —
# confirm against the UNet head definition.
# Calls `train_fold` once per fold index 0..4, sequentially, on the fold-annotated
# frame `df`; both names are defined in earlier cells.
for fold in range(5):
    train_fold(df, fold)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Using native 16bit precision.


Training fold 0...



  | Name      | Type       | Params
-----------------------------------------
0 | criterion | L1Loss     | 0     
1 | tail      | BasicBlock | 300   
2 | encoder   | Encoder    | 25 M  
3 | decoder   | Decoder    | 17 M  
4 | head      | Sequential | 8 K   


HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

Epoch 0 | MAE/CSI: 167.45782510687246 | MAE: 27.286654537671232 | CSI: 0.16294642857142858 | Loss: 0.4579598009586334


HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 0 | MAE/CSI: 28.362450125306825 | MAE: 22.10657577075336 | CSI: 0.7794311024984428 | Loss: 0.020207127556204796


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 1 | MAE/CSI: 23.82231910706382 | MAE: 18.83105605581214 | CSI: 0.7904795486600846 | Loss: 0.014244887046515942


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 2 | MAE/CSI: 23.14310842963703 | MAE: 18.453300962246768 | CSI: 0.7973561986423723 | Loss: 0.013162474147975445


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 3 | MAE/CSI: 22.18791287315471 | MAE: 17.807568912908806 | CSI: 0.8025797205302759 | Loss: 0.012457935139536858


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 4 | MAE/CSI: 23.763760209997628 | MAE: 19.19425320030403 | CSI: 0.8077111126632674 | Loss: 0.013415777124464512


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 5 | MAE/CSI: 21.892119289753165 | MAE: 17.683018462607592 | CSI: 0.8077344284736482 | Loss: 0.013465486466884613


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 6 | MAE/CSI: 22.04263144863797 | MAE: 17.724101042939697 | CSI: 0.8040828103585083 | Loss: 0.01200629211962223


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 7 | MAE/CSI: 21.346540554610538 | MAE: 17.230123555188904 | CSI: 0.8071623367303 | Loss: 0.011888713575899601


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 8 | MAE/CSI: 19.32465185410075 | MAE: 15.690600147796456 | CSI: 0.8119473647566664 | Loss: 0.01206459105014801


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 9 | MAE/CSI: 20.710088985666317 | MAE: 16.81065647000772 | CSI: 0.8117133867276888 | Loss: 0.011814854107797146


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 10 | MAE/CSI: 24.743819129015346 | MAE: 19.77880616991625 | CSI: 0.7993433053630062 | Loss: 0.012558660469949245


## Inference

In [None]:
def inference(checkpoints):
    """Ensemble several UNet checkpoints on the test set and build a submission frame.

    Every checkpoint is loaded in turn, run over the test dataloader, and its
    predictions are cropped, rescaled by ``args["rng"]``, clipped to [0, 255]
    and rounded to ``uint8``. The per-checkpoint predictions are then averaged
    with equal weight and rounded again.

    Args:
        checkpoints: Iterable of checkpoint paths (``str`` or ``Path``) that
            ``UNet.load_from_checkpoint`` can read.

    Returns:
        A ``pd.DataFrame`` with a ``file_name`` column followed by one ``uint8``
        column per flattened output pixel (``"0"`` .. ``"14399"``), one row per
        test sample.

    Raises:
        ValueError: If ``checkpoints`` is empty (the average would be undefined
            and the result would silently be an all-zero submission).
    """
    checkpoints = list(checkpoints)
    if not checkpoints:
        raise ValueError("inference() needs at least one checkpoint")

    # The 4:124 crop keeps a 120x120 window, i.e. 14400 values per sample.
    # NOTE(review): presumably this undoes the PadIfNeeded(128, 128) applied by
    # the transforms — confirm the padding is symmetric.
    crop = slice(4, 124)
    n_pixels = 14400

    datamodule = NowcastingDataModule()
    datamodule.setup("test")

    test_paths = datamodule.test_dataset.paths
    test_filenames = [path.name for path in test_paths]
    final_preds = np.zeros((len(datamodule.test_dataset), n_pixels))

    for checkpoint in checkpoints:
        print(f"Inference from {checkpoint}")
        model = UNet.load_from_checkpoint(str(checkpoint))
        model.cuda()
        model.eval()
        preds = []
        with torch.no_grad():
            for batch in tqdm(datamodule.test_dataloader()):
                batch = batch.cuda()
                imgs = model(batch)
                imgs = imgs.detach().cpu().numpy()
                imgs = imgs[:, 0, crop, crop]
                imgs = args["rng"] * imgs
                imgs = imgs.clip(0, 255)
                imgs = imgs.round()
                preds.append(imgs)

        preds = np.concatenate(preds)
        preds = preds.astype(np.uint8)
        preds = preds.reshape(-1, n_pixels)
        # Equal-weight average over the NUMBER of checkpoints (the original divided
        # by len(checkpoint), i.e. the length of a single path).
        final_preds += preds / len(checkpoints)

        # Release the model before loading the next one to keep GPU memory flat.
        del model
        gc.collect()
        torch.cuda.empty_cache()

    final_preds = final_preds.round()
    final_preds = final_preds.astype(np.uint8)

    # Build all pixel columns in one shot instead of inserting 14400 columns
    # one by one, then put the file names in front.
    subm = pd.DataFrame(final_preds, columns=[str(i) for i in range(n_pixels)])
    subm.insert(0, "file_name", test_filenames)

    return subm

In [None]:
# Checkpoint and submission file names are derived from the training hyper-parameters.
# The learning rate is read from `args` rather than `model.lr`, which relied on a
# `model` global that this cell does not define; f"{args['lr']}" renders as "0.001",
# matching the checkpoint names printed by the training cells.
run_name = f"bs{args['batch_size']}_epochs{args['max_epochs']}_lr{args['lr']}_{args['optimizer']}_{args['scheduler']}"
checkpoints = [args["model_dir"] / f"unet_fold{fold}_{run_name}.ckpt" for fold in range(5)]
output_path = args["output_dir"] / f"unet_{run_name}.csv"

# Run the ensemble; the original cell wrote `subm` without ever computing it.
subm = inference(checkpoints)

# Make sure the destination exists before writing the submission.
args["output_dir"].mkdir(parents=True, exist_ok=True)
subm.to_csv(output_path, index=False)
subm.head()