In [1]:
import sys
sys.path.insert(0, "../src")

In [2]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline
%reload_ext nb_black

<IPython.core.display.Javascript object>

In [3]:
import gc
from pathlib import Path
from tqdm.notebook import tqdm
import warnings

import cv2
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn import metrics

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler

import torchvision.transforms as T
import albumentations as A
from albumentations.pytorch import ToTensorV2

import pytorch_lightning as pl
from transformers import (
    AdamW,
    get_linear_schedule_with_warmup,
    get_cosine_schedule_with_warmup,
)

import optim
import loss
from utils import visualize, radar2precipitation, seed_everything

<IPython.core.display.Javascript object>

In [4]:
warnings.simplefilter("ignore")

<IPython.core.display.Javascript object>

In [5]:
args = dict(
    seed=42,
    dams=(6071, 6304, 7026, 7629, 7767, 8944, 11107),
    train_folds_csv=Path("../input/train_folds.csv"),
    train_data_path=Path("../input/train"),
    test_data_path=Path("../input/test"),
    model_dir=Path("../models"),
    output_dir=Path("../output"),
    rng=255.0,
    num_workers=4,
    gpus=1,
    lr=5e-4,
    max_epochs=50,
    batch_size=128,
    precision=16,
    optimizer="adamw",
    scheduler="cosine",
    warmup_epochs=1,
    accumulate_grad_batches=1,
    gradient_clip_val=5.0,
)

<IPython.core.display.Javascript object>

# 🔥 Baseline ⚡️

## Sketch

In [6]:
df = pd.read_csv(args["train_folds_csv"])
df.head()

Unnamed: 0,filename,fold
0,train_60668.npy,0
1,train_60911.npy,0
2,train_14956.npy,0
3,train_24086.npy,0
4,train_09805.npy,0


<IPython.core.display.Javascript object>

In [7]:
fn = df.loc[0, "filename"]
fn

'train_60668.npy'

<IPython.core.display.Javascript object>

In [8]:
path = args["train_data_path"] / fn
path

PosixPath('../input/train/train_60668.npy')

<IPython.core.display.Javascript object>

In [9]:
data = np.load(path)
data.shape

(120, 120, 5)

<IPython.core.display.Javascript object>

In [10]:
tfms = A.Compose(
    [
        A.PadIfNeeded(min_height=128, min_width=128, always_apply=True, p=1),
        A.VerticalFlip(p=0.5),
        A.HorizontalFlip(p=0.5),
        A.RandomRotate90(p=0.5),
        A.Transpose(p=0.5),
        ToTensorV2(always_apply=True, p=1),
    ]
)

<IPython.core.display.Javascript object>

In [11]:
augmented = tfms(image=data)

<IPython.core.display.Javascript object>

In [12]:
image = augmented["image"]
image.shape

torch.Size([5, 128, 128])

<IPython.core.display.Javascript object>

In [13]:
image[0].min(), image[0].max()

(tensor(0, dtype=torch.uint8), tensor(220, dtype=torch.uint8))

<IPython.core.display.Javascript object>

## Dataset

In [14]:
class NowcastingDataset(Dataset):
    def __init__(self, paths, tfms=None, test=False):
        self.paths = paths
        if tfms is not None:
            self.tfms = tfms
        else:
            self.tfms = A.Compose(
                [
                    A.PadIfNeeded(
                        min_height=128, min_width=128, always_apply=True, p=1
                    ),
                    ToTensorV2(always_apply=True, p=1),
                ]
            )
        self.test = test

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        path = self.paths[idx]
        data = np.load(path)

        augmented = self.tfms(image=data)
        data = augmented["image"]

        x = data[:4, :, :]
        x = x / args["rng"]
        if self.test:
            return x
        else:
            y = data[4, :, :]
            y = y / args["rng"]
            y = y.unsqueeze(0)

            return x, y

<IPython.core.display.Javascript object>

In [15]:
class NowcastingDataModule(pl.LightningDataModule):
    def __init__(
        self,
        train_df=None,
        val_df=None,
        batch_size=args["batch_size"],
        num_workers=args["num_workers"],
        test=False,
    ):
        super().__init__()
        self.train_df = train_df
        self.val_df = val_df
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.test = test

    def setup(self, stage="train"):
        if stage == "train":
            train_paths = [
                args["train_data_path"] / fn for fn in self.train_df.filename.values
            ]
            val_paths = [
                args["train_data_path"] / fn for fn in self.val_df.filename.values
            ]
            self.train_dataset = NowcastingDataset(train_paths)
            self.val_dataset = NowcastingDataset(val_paths)
        else:
            test_paths = list(sorted(args["test_data_path"].glob("*.npy")))
            self.test_dataset = NowcastingDataset(test_paths, test=True)

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            sampler=RandomSampler(self.train_dataset),
            pin_memory=True,
            num_workers=self.num_workers,
            drop_last=True,
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset,
            batch_size=2 * self.batch_size,
            sampler=SequentialSampler(self.val_dataset),
            pin_memory=True,
            num_workers=self.num_workers,
        )

    def test_dataloader(self):
        return DataLoader(
            self.test_dataset,
            batch_size=2 * self.batch_size,
            sampler=SequentialSampler(self.test_dataset),
            pin_memory=True,
            num_workers=self.num_workers,
        )

<IPython.core.display.Javascript object>

In [16]:
# df = pd.read_csv(args["train_folds_csv"])

# fold = 0
# train_df = df[df["fold"] != fold]
# val_df = df[df["fold"] == fold]

# datamodule = NowcastingDataModule(train_df, val_df)
# datamodule.setup()

# for batch in datamodule.train_dataloader():
#     xs, ys = batch
#     idx = np.random.randint(len(xs))
#     x, y = xs[idx], ys[idx]
#     print(x.shape, y.shape)
#     x = x.detach().cpu().numpy().transpose(1, 2, 0)
#     y = y.unsqueeze(-1)
#     y = y.detach().cpu().numpy()
#     visualize(x, y)
#     break

<IPython.core.display.Javascript object>

## Model

### Layers

In [17]:
class Block(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.net(x)


class Encoder(nn.Module):
    def __init__(self, chs=[4, 64, 128]):
        super().__init__()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.blocks = nn.ModuleList(
            [Block(chs[i], chs[i + 1]) for i in range(len(chs) - 1)]
        )
        self.conv = nn.Conv2d(128, 512, kernel_size=3, padding=1)

    def forward(self, x):
        ftrs = []
        for block in self.blocks:
            x = block(x)
            ftrs.append(x)
            x = self.pool(x)
        x = self.conv(x)
        ftrs.append(x)
        return ftrs


class Decoder(nn.Module):
    def __init__(self, chs=[512, 128, 64]):
        super().__init__()
        self.tr_convs = nn.ModuleList(
            [
                nn.ConvTranspose2d(chs[i], chs[i + 1], kernel_size=2, stride=2)
                for i in range(len(chs) - 1)
            ]
        )
        self.blocks = nn.ModuleList(
            [Block(2 * chs[i + 1], chs[i + 1]) for i in range(len(chs) - 1)]
        )

    def forward(self, x, ftrs):
        for i, ftr in enumerate(ftrs):
            x = self.tr_convs[i](x)
            x = torch.cat([ftr, x], dim=1)
            x = self.blocks[i](x)
        return x

<IPython.core.display.Javascript object>

### Model

In [18]:
class Baseline(pl.LightningModule):
    def __init__(
        self,
        lr=args["lr"],
        enc_chs=[4, 64, 128],
        dec_chs=[512, 128, 64],
        num_train_steps=None,
    ):
        super().__init__()
        self.lr = lr
        self.num_train_steps = num_train_steps
        self.criterion = nn.L1Loss()
        self.encoder = Encoder(enc_chs)
        self.decoder = Decoder(dec_chs)
        self.out = nn.Sequential(
            nn.Conv2d(64, 1, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        ftrs = self.encoder(x)
        ftrs = ftrs[::-1]
        x = self.decoder(ftrs[0], ftrs[1:])
        out = self.out(x)
        return out

    def shared_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = self.criterion(y_hat, y)
        return loss, y, y_hat

    def training_step(self, batch, batch_idx):
        loss, y, y_hat = self.shared_step(batch, batch_idx)
        self.log("train_loss", loss)

        for i, param_group in enumerate(self.optimizer.param_groups):
            self.log(f"lr/lr{i}", param_group["lr"])

        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss, y, y_hat = self.shared_step(batch, batch_idx)

        return {"loss": loss, "y": y.detach(), "y_hat": y_hat.detach()}

    def validation_epoch_end(self, outputs):
        avg_loss = torch.stack([x["loss"] for x in outputs]).mean()
        self.log("val_loss", avg_loss)

        y = torch.cat([x["y"] for x in outputs])
        y_hat = torch.cat([x["y_hat"] for x in outputs])

        crop = T.CenterCrop(120)
        y = crop(y)
        y_hat = crop(y_hat)

        batch_size = len(y)
        y = y.reshape(batch_size, -1)
        y = y.detach().cpu().numpy()
        y *= args["rng"]
        y_hat = y_hat.reshape(batch_size, -1)
        y_hat = y_hat.detach().cpu().numpy()
        y_hat *= args["rng"]

        y = y[:, args["dams"]]
        y_hat = y_hat[:, args["dams"]]

        y_true = radar2precipitation(y)
        y_true = np.where(y_true >= 0.1, 1.0, 0.0)
        y_pred = radar2precipitation(y_hat)
        y_pred = np.where(y_pred >= 0.1, 1.0, 0.0)

        y_true = y_true.ravel()
        y_pred = y_pred.ravel()

        tn, fp, fn, tp = metrics.confusion_matrix(y_true, y_pred).ravel()
        csi = tp / (tp + tn + tp)
        self.log("csi", csi)

        mae = metrics.mean_absolute_error(y, y_hat, sample_weight=y_true)
        self.log("mae", mae)

        comp_metric = mae / (csi + 1e-12)
        self.log("comp_metric", comp_metric)

        print(
            f"Epoch {self.current_epoch} | MAE/CSI: {comp_metric} | MAE: {mae} | CSI: {csi} | Loss: {avg_loss}"
        )

    def configure_optimizers(self):
        if args["optimizer"] == "adam":
            self.optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        elif args["optimizer"] == "adamw":
            self.optimizer = AdamW(self.parameters(), lr=self.lr)
        elif args["optimizer"] == "radam":
            self.optimizer = optim.RAdam(self.parameters(), lr=self.lr)
        elif args["optimizer"] == "ranger":
            self.optimizer = optim.RAdam(self.parameters(), lr=self.lr)
            self.optimizer = optim.Lookahead(self.optimizer)

        if args["scheduler"] == "cosine":
            self.scheduler = get_cosine_schedule_with_warmup(
                self.optimizer,
                num_warmup_steps=self.num_train_steps * args["warmup_epochs"],
                num_training_steps=self.num_train_steps * args["max_epochs"],
            )
            return [self.optimizer], [{"scheduler": self.scheduler, "interval": "step"}]
        elif args["scheduler"] == "step":
            self.scheduler = torch.optim.lr_scheduler.StepLR(
                self.optimizer, step_size=10, gamma=0.5
            )
            return [self.optimizer], [
                {"scheduler": self.scheduler, "interval": "epoch"}
            ]
        elif args["scheduler"] == "plateau":
            self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
                self.optimizer, mode="min", factor=0.1, patience=3, verbose=True
            )
            return [self.optimizer], [
                {
                    "scheduler": self.scheduler,
                    "interval": "epoch",
                    "reduce_on_plateau": True,
                    "monitor": "comp_metric",
                }
            ]
        else:
            self.scheduler = None
            return [self.optimizer]

<IPython.core.display.Javascript object>

## Train

In [19]:
seed_everything(args["seed"])
pl.seed_everything(args["seed"])

42

<IPython.core.display.Javascript object>

In [20]:
df = pd.read_csv(args["train_folds_csv"])

<IPython.core.display.Javascript object>

In [14]:
def train_fold(df, fold):
    train_df = df[df.fold != fold]
    val_df = df[df.fold == fold]

    datamodule = NowcastingDataModule(train_df, val_df)
    datamodule.setup()

    num_train_steps = np.ceil(
        len(train_df) // args["batch_size"] / args["accumulate_grad_batches"]
    )
    model = Baseline(num_train_steps=num_train_steps)

    trainer = pl.Trainer(
        gpus=args["gpus"],
        max_epochs=args["max_epochs"],
        precision=args["precision"],
        progress_bar_refresh_rate=50,
        benchmark=True,
    )

    print(f"Training fold {fold}...")
    trainer.fit(model, datamodule)
    
    checkpoint = (
        args["model_dir"]
        / f"baseline_fold{fold}_bs{args['batch_size']}_epochs{args['max_epochs']}_lr{model.lr}_{args['optimizer']}.ckpt"
    )
    trainer.save_checkpoint(checkpoint)
    print("Model saved at", checkpoint)
    
    del model, trainer, datamodule
    gc.collect()
    torch.cuda.empty_cache()    

<IPython.core.display.Javascript object>

In [14]:
# AdamW
fold = 0
train_fold(df, fold)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Using native 16bit precision.


Training fold0...



  | Name      | Type       | Params
-----------------------------------------
0 | criterion | L1Loss     | 0     
1 | encoder   | Encoder    | 666 K 
2 | decoder   | Decoder    | 664 K 
3 | out       | Sequential | 577   


HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

Epoch 0 | MAE/CSI: 1000000000000.0 | MAE: 1.0 | CSI: 0.0 | Loss: 0.07034780830144882


HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 0 | MAE/CSI: 0.8953259817369302 | MAE: 0.10786948176583493 | CSI: 0.12048067850736674 | Loss: 0.014457848854362965


  z *= np.power(10.0, dbz_max / 10.0)


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 1 | MAE/CSI: 1.0508558724080708 | MAE: 0.12422264875239923 | CSI: 0.11821092883716593 | Loss: 0.012260455638170242


  z *= np.power(10.0, dbz_max / 10.0)


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 2 | MAE/CSI: 0.897048855974516 | MAE: 0.10771593090211132 | CSI: 0.12007810966348786 | Loss: 0.012970300391316414


  z *= np.power(10.0, dbz_max / 10.0)


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 3 | MAE/CSI: 0.9234593276009212 | MAE: 0.11047984644913628 | CSI: 0.1196369381369846 | Loss: 0.011573512107133865


  z *= np.power(10.0, dbz_max / 10.0)


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 4 | MAE/CSI: 1.0946424615395505 | MAE: 0.12859884836852206 | CSI: 0.1174802301991471 | Loss: 0.011782950721681118


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 5 | MAE/CSI: 0.9862818881438672 | MAE: 0.11715930902111324 | CSI: 0.11878886799859507 | Loss: 0.011379786767065525


  z *= np.power(10.0, dbz_max / 10.0)


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 6 | MAE/CSI: 1.2269688941892183 | MAE: 0.1421880998080614 | CSI: 0.11588565975895616 | Loss: 0.011886500753462315


  z *= np.power(10.0, dbz_max / 10.0)


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 7 | MAE/CSI: 1.0524613342678637 | MAE: 0.12406909788867562 | CSI: 0.1178847087754828 | Loss: 0.011312616057693958




HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 8 | MAE/CSI: 0.8743689630123128 | MAE: 0.10502879078694817 | CSI: 0.12011953217579474 | Loss: 0.011175304651260376


  z *= np.power(10.0, dbz_max / 10.0)


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 9 | MAE/CSI: 1.0495967204363328 | MAE: 0.1237619961612284 | CSI: 0.11791385562707277 | Loss: 0.011026876978576183


  z *= np.power(10.0, dbz_max / 10.0)


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 10 | MAE/CSI: 0.9960218536200095 | MAE: 0.11808061420345489 | CSI: 0.11855223233636758 | Loss: 0.011095545254647732


  z *= np.power(10.0, dbz_max / 10.0)


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 11 | MAE/CSI: 0.8782308244192888 | MAE: 0.1054126679462572 | CSI: 0.1200284310392781 | Loss: 0.010937790386378765


  z *= np.power(10.0, dbz_max / 10.0)


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 12 | MAE/CSI: 0.8742338140149913 | MAE: 0.10502879078694817 | CSI: 0.12013810161805627 | Loss: 0.011085813865065575


  z *= np.power(10.0, dbz_max / 10.0)


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 13 | MAE/CSI: 0.7914715377970224 | MAE: 0.09589251439539348 | CSI: 0.12115724926951726 | Loss: 0.011124382726848125


  z *= np.power(10.0, dbz_max / 10.0)


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 14 | MAE/CSI: 0.8122845658503062 | MAE: 0.0981957773512476 | CSI: 0.12088839487876168 | Loss: 0.010922062210738659


  z *= np.power(10.0, dbz_max / 10.0)


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 15 | MAE/CSI: 0.832724012562721 | MAE: 0.10042226487523992 | CSI: 0.12059489501852615 | Loss: 0.010781347751617432


  z *= np.power(10.0, dbz_max / 10.0)


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 16 | MAE/CSI: 0.9140697496979593 | MAE: 0.1092514395393474 | CSI: 0.11952199443700422 | Loss: 0.010753633454442024


  z *= np.power(10.0, dbz_max / 10.0)


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 17 | MAE/CSI: 1.0276737057476453 | MAE: 0.12138195777351247 | CSI: 0.11811332438848178 | Loss: 0.010888705030083656


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 18 | MAE/CSI: 0.90410702954464 | MAE: 0.10817658349328214 | CSI: 0.11965019622384969 | Loss: 0.010728368535637856


  z *= np.power(10.0, dbz_max / 10.0)


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 19 | MAE/CSI: 0.9914165414434577 | MAE: 0.11754318618042227 | CSI: 0.1185608483073051 | Loss: 0.010847923345863819


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 20 | MAE/CSI: 0.9764170742744235 | MAE: 0.11593090211132438 | CSI: 0.11873092468860844 | Loss: 0.01069142110645771


  z *= np.power(10.0, dbz_max / 10.0)


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 21 | MAE/CSI: 0.8930430185237244 | MAE: 0.10694817658349329 | CSI: 0.11975702666529393 | Loss: 0.01066147442907095


  z *= np.power(10.0, dbz_max / 10.0)


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 22 | MAE/CSI: 0.8887563310135709 | MAE: 0.10648752399232246 | CSI: 0.1198163324136227 | Loss: 0.010664239525794983


  z *= np.power(10.0, dbz_max / 10.0)


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 23 | MAE/CSI: 0.840922715282812 | MAE: 0.10126679462571977 | CSI: 0.12042342629646013 | Loss: 0.010658983141183853


  z *= np.power(10.0, dbz_max / 10.0)


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 24 | MAE/CSI: 0.8959247022302149 | MAE: 0.1072552783109405 | CSI: 0.11971461222472743 | Loss: 0.010633627884089947


  z *= np.power(10.0, dbz_max / 10.0)


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 25 | MAE/CSI: 0.8951603121976484 | MAE: 0.1071785028790787 | CSI: 0.11973107097996417 | Loss: 0.010618860833346844


  z *= np.power(10.0, dbz_max / 10.0)


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 26 | MAE/CSI: 0.9008450087222826 | MAE: 0.10779270633397313 | CSI: 0.11965732760839795 | Loss: 0.010613663122057915


  z *= np.power(10.0, dbz_max / 10.0)


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 27 | MAE/CSI: 0.8874501850055542 | MAE: 0.10633397312859885 | CSI: 0.11981965289358286 | Loss: 0.010609721764922142


  z *= np.power(10.0, dbz_max / 10.0)


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 28 | MAE/CSI: 0.8923343991380076 | MAE: 0.10687140115163148 | CSI: 0.11976608909616909 | Loss: 0.010615041479468346


  z *= np.power(10.0, dbz_max / 10.0)


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 29 | MAE/CSI: 0.8853361925398303 | MAE: 0.10610364683301343 | CSI: 0.11984559958826557 | Loss: 0.01060927752405405

Model saved at ../models/baseline_bs128_epochs30_lr0.0005_adamw.ckpt


  z *= np.power(10.0, dbz_max / 10.0)


<IPython.core.display.Javascript object>

In [13]:
# RAdam
fold = 0
train_fold(df, fold)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Using native 16bit precision.


Training fold0...



  | Name      | Type       | Params
-----------------------------------------
0 | criterion | L1Loss     | 0     
1 | encoder   | Encoder    | 666 K 
2 | decoder   | Decoder    | 664 K 
3 | out       | Sequential | 577   


HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

Epoch 0 | MAE/CSI: 1000000000000.0 | MAE: 1.0 | CSI: 0.0 | Loss: 0.07034780830144882


HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

	addcmul_(Number value, Tensor tensor1, Tensor tensor2)
Consider using one of the following signatures instead:
	addcmul_(Tensor tensor1, Tensor tensor2, *, Number value) (Triggered internally at  /opt/conda/conda-bld/pytorch_1603729047590/work/torch/csrc/utils/python_arg_parser.cpp:882.)
  exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 0 | MAE/CSI: 1.1287884420798677 | MAE: 0.1327447216890595 | CSI: 0.11759929207225027 | Loss: 0.016403084620833397


  z *= np.power(10.0, dbz_max / 10.0)


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 1 | MAE/CSI: 1.2616119527130563 | MAE: 0.14610364683301344 | CSI: 0.11580712002415686 | Loss: 0.013337209820747375




HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 2 | MAE/CSI: 0.8670467608914513 | MAE: 0.10456813819577736 | CSI: 0.12060265133497404 | Loss: 0.012833792716264725


  z *= np.power(10.0, dbz_max / 10.0)


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 3 | MAE/CSI: 1.0554725725706757 | MAE: 0.12460652591170826 | CSI: 0.11805756885483537 | Loss: 0.012007031589746475


  z *= np.power(10.0, dbz_max / 10.0)


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 4 | MAE/CSI: 1.3001223097119974 | MAE: 0.14963531669865643 | CSI: 0.11509326128747337 | Loss: 0.012559030205011368


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 5 | MAE/CSI: 0.9668908031778127 | MAE: 0.11516314779270634 | CSI: 0.11910667410760423 | Loss: 0.011626843363046646


  z *= np.power(10.0, dbz_max / 10.0)


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 6 | MAE/CSI: 1.142794785587942 | MAE: 0.13358925143953934 | CSI: 0.11689697321262094 | Loss: 0.011799349449574947


  z *= np.power(10.0, dbz_max / 10.0)


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 7 | MAE/CSI: 1.1083582171365651 | MAE: 0.12998080614203456 | CSI: 0.11727328234794937 | Loss: 0.011587237007915974




HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 8 | MAE/CSI: 0.8388044949660333 | MAE: 0.10119001919385796 | CSI: 0.12063600016487366 | Loss: 0.011390800587832928


  z *= np.power(10.0, dbz_max / 10.0)


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 9 | MAE/CSI: 0.9891935240347431 | MAE: 0.11738963531669866 | CSI: 0.11867206210256834 | Loss: 0.011127580888569355


  z *= np.power(10.0, dbz_max / 10.0)


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 10 | MAE/CSI: 0.9411532947274672 | MAE: 0.11224568138195777 | CSI: 0.11926397326539663 | Loss: 0.01120688859373331


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 11 | MAE/CSI: 0.8536757691390952 | MAE: 0.10272552783109405 | CSI: 0.12033318918473672 | Loss: 0.011052107438445091


  z *= np.power(10.0, dbz_max / 10.0)


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 12 | MAE/CSI: 0.9470463135836465 | MAE: 0.1128598848368522 | CSI: 0.11917039665023411 | Loss: 0.011067020706832409


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 13 | MAE/CSI: 0.8655408572526657 | MAE: 0.10403071017274472 | CSI: 0.12019156496215047 | Loss: 0.011069140397012234


  z *= np.power(10.0, dbz_max / 10.0)


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 14 | MAE/CSI: 0.8477291753956104 | MAE: 0.10211132437619962 | CSI: 0.12045276644831707 | Loss: 0.011056198738515377


  z *= np.power(10.0, dbz_max / 10.0)


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 15 | MAE/CSI: 0.8269470306746735 | MAE: 0.09980806142034548 | CSI: 0.12069462458567518 | Loss: 0.010897000320255756




HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 16 | MAE/CSI: 0.9089498997078406 | MAE: 0.10871401151631478 | CSI: 0.1196039644763141 | Loss: 0.010795501060783863




HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 17 | MAE/CSI: 1.0144721688973737 | MAE: 0.12 | CSI: 0.11828811442842548 | Loss: 0.010921232402324677


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 18 | MAE/CSI: 0.9739822739190495 | MAE: 0.11570057581573896 | CSI: 0.11879125412541254 | Loss: 0.010769343003630638




HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 19 | MAE/CSI: 1.012934577465377 | MAE: 0.11984644913627639 | CSI: 0.11831608062501935 | Loss: 0.010915917344391346


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 20 | MAE/CSI: 0.9746930990857315 | MAE: 0.11577735124760077 | CSI: 0.1187833907465088 | Loss: 0.010718860663473606




HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 21 | MAE/CSI: 0.8865415622549162 | MAE: 0.10625719769673704 | CSI: 0.11985585585585586 | Loss: 0.010684728622436523




HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 22 | MAE/CSI: 0.9431093784188942 | MAE: 0.11239923224568138 | CSI: 0.11917942374104427 | Loss: 0.010712772607803345


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 23 | MAE/CSI: 0.8567591311220998 | MAE: 0.10303262955854127 | CSI: 0.12025857188442496 | Loss: 0.01069499459117651


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 24 | MAE/CSI: 0.9041815313239039 | MAE: 0.10817658349328214 | CSI: 0.11964033741541441 | Loss: 0.01064450852572918


  z *= np.power(10.0, dbz_max / 10.0)


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 25 | MAE/CSI: 0.9142298319658559 | MAE: 0.1092514395393474 | CSI: 0.11950106605415761 | Loss: 0.010635210201144218


  z *= np.power(10.0, dbz_max / 10.0)


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 26 | MAE/CSI: 0.9199470836152046 | MAE: 0.10986564299424184 | CSI: 0.11942604629124133 | Loss: 0.01062960084527731


  z *= np.power(10.0, dbz_max / 10.0)


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 27 | MAE/CSI: 0.8950865803184656 | MAE: 0.1071785028790787 | CSI: 0.11974093370950802 | Loss: 0.0106242336332798


  z *= np.power(10.0, dbz_max / 10.0)


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 28 | MAE/CSI: 0.8978677068424368 | MAE: 0.10748560460652591 | CSI: 0.11971207315566174 | Loss: 0.010626324452459812


  z *= np.power(10.0, dbz_max / 10.0)


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 29 | MAE/CSI: 0.8928407437743435 | MAE: 0.10694817658349329 | CSI: 0.11978415784487374 | Loss: 0.010623933747410774

Model saved at ../models/baseline_bs128_epochs30_lr0.0005_radam.ckpt


  z *= np.power(10.0, dbz_max / 10.0)


<IPython.core.display.Javascript object>

In [14]:
# Ranger
fold = 0
train_fold(df, fold)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Using native 16bit precision.


Training fold0...



  | Name      | Type       | Params
-----------------------------------------
0 | criterion | L1Loss     | 0     
1 | encoder   | Encoder    | 666 K 
2 | decoder   | Decoder    | 664 K 
3 | out       | Sequential | 577   


HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

Epoch 0 | MAE/CSI: 1000000000000.0 | MAE: 1.0 | CSI: 0.0 | Loss: 0.07034780830144882


HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 0 | MAE/CSI: 1.1385640966037704 | MAE: 0.13389635316698656 | CSI: 0.11760106748952318 | Loss: 0.01573796570301056


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 1 | MAE/CSI: 1.085314053724478 | MAE: 0.12806142034548945 | CSI: 0.1179948051948052 | Loss: 0.013925119303166866


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 2 | MAE/CSI: 0.9167696697228916 | MAE: 0.11001919385796545 | CSI: 0.12000745387912293 | Loss: 0.013134698383510113


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 3 | MAE/CSI: 1.0299586717682772 | MAE: 0.12199616122840691 | CSI: 0.11844762762949383 | Loss: 0.012236752547323704


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 4 | MAE/CSI: 1.013925980609321 | MAE: 0.12023032629558542 | CSI: 0.11857899747506105 | Loss: 0.012843919917941093


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 5 | MAE/CSI: 1.011833885875258 | MAE: 0.12 | CSI: 0.11859654205510776 | Loss: 0.011795842088758945


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 6 | MAE/CSI: 1.0590260203151582 | MAE: 0.12491362763915546 | CSI: 0.11795142446162284 | Loss: 0.011630662716925144


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 7 | MAE/CSI: 0.9600188360970295 | MAE: 0.1143953934740883 | CSI: 0.1191595301798498 | Loss: 0.011449274607002735


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 8 | MAE/CSI: 0.9020290172989984 | MAE: 0.10809980806142035 | CSI: 0.11984072129321105 | Loss: 0.011327717453241348


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 9 | MAE/CSI: 0.9770176426018803 | MAE: 0.11616122840690979 | CSI: 0.11889368558031933 | Loss: 0.011341550387442112


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 10 | MAE/CSI: 0.9844928312476287 | MAE: 0.11692898272552783 | CSI: 0.11877078127258835 | Loss: 0.011181168258190155


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 11 | MAE/CSI: 0.8644004373845067 | MAE: 0.10395393474088292 | CSI: 0.12026131668160787 | Loss: 0.011257469654083252


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 12 | MAE/CSI: 0.8851591085350253 | MAE: 0.10618042226487524 | CSI: 0.119956312080122 | Loss: 0.01108333095908165


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 13 | MAE/CSI: 1.013253430079364 | MAE: 0.1199232245681382 | CSI: 0.11835461988787131 | Loss: 0.011116106063127518


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 14 | MAE/CSI: 1.0395499227504847 | MAE: 0.12268714011516314 | CSI: 0.11801947884283692 | Loss: 0.01116656232625246


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 15 | MAE/CSI: 0.8960707942438373 | MAE: 0.1073320537428023 | CSI: 0.11978077450061296 | Loss: 0.010934371501207352


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 16 | MAE/CSI: 0.9623145456213387 | MAE: 0.11447216890595009 | CSI: 0.11895504378048906 | Loss: 0.011090419255197048


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 17 | MAE/CSI: 0.896153875733789 | MAE: 0.1073320537428023 | CSI: 0.11976966975009785 | Loss: 0.010892595164477825


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 18 | MAE/CSI: 0.9479289803918486 | MAE: 0.11293666026871402 | CSI: 0.11914042360122915 | Loss: 0.010916203260421753


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 19 | MAE/CSI: 0.8959230938172563 | MAE: 0.1073320537428023 | CSI: 0.1198005213646152 | Loss: 0.010992695577442646


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 20 | MAE/CSI: 0.9846853003130904 | MAE: 0.11685220729366602 | CSI: 0.1186695965254351 | Loss: 0.010856914333999157


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 21 | MAE/CSI: 0.9188840579993098 | MAE: 0.10978886756238004 | CSI: 0.1194806533051677 | Loss: 0.01079531665891409


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 22 | MAE/CSI: 0.9567146873342918 | MAE: 0.11385796545105566 | CSI: 0.11900932112513404 | Loss: 0.010844394564628601


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 23 | MAE/CSI: 0.9026124925539706 | MAE: 0.10802303262955854 | CSI: 0.11967819359889573 | Loss: 0.010828354395925999


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 24 | MAE/CSI: 0.9244720777907447 | MAE: 0.11040307101727448 | CSI: 0.11942282916774027 | Loss: 0.010767893865704536


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 25 | MAE/CSI: 0.9259653880900643 | MAE: 0.11055662188099807 | CSI: 0.11939606307327631 | Loss: 0.010765370912849903


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 26 | MAE/CSI: 0.9082203937558084 | MAE: 0.10863723608445297 | CSI: 0.11961549953122264 | Loss: 0.010760427452623844


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 27 | MAE/CSI: 0.9217818510578139 | MAE: 0.11009596928982726 | CSI: 0.11943820456278466 | Loss: 0.010757229290902615


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 28 | MAE/CSI: 0.9232172476669039 | MAE: 0.11024952015355087 | CSI: 0.11941882631768767 | Loss: 0.010758249089121819


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 29 | MAE/CSI: 0.9124960947270578 | MAE: 0.1090978886756238 | CSI: 0.11955984174085064 | Loss: 0.010757893323898315

Model saved at ../models/baseline_bs128_epochs30_lr0.0005_ranger.ckpt


<IPython.core.display.Javascript object>

In [16]:
# AdamW bs50
fold = 0
train_fold(df, fold)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Using native 16bit precision.


Training fold0...



  | Name      | Type       | Params
-----------------------------------------
0 | criterion | L1Loss     | 0     
1 | encoder   | Encoder    | 666 K 
2 | decoder   | Decoder    | 664 K 
3 | out       | Sequential | 577   


HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

Epoch 0 | MAE/CSI: 1000000000000.0 | MAE: 1.0 | CSI: 0.0 | Loss: 0.07034780830144882


HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 0 | MAE/CSI: 1.431734138641048 | MAE: 0.16314779270633398 | CSI: 0.11395117871517432 | Loss: 0.014841477386653423


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 1 | MAE/CSI: 1.0666386134475698 | MAE: 0.12591170825335893 | CSI: 0.11804533106608879 | Loss: 0.01233157329261303


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 2 | MAE/CSI: 0.8780003852209621 | MAE: 0.1056429942418426 | CSI: 0.12032226411196612 | Loss: 0.012936452403664589


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 3 | MAE/CSI: 0.9439850315398074 | MAE: 0.1127063339731286 | CSI: 0.11939419610111883 | Loss: 0.011566683650016785


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 4 | MAE/CSI: 1.106750432758379 | MAE: 0.12982725527831093 | CSI: 0.11730490581660112 | Loss: 0.01176687702536583


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 5 | MAE/CSI: 1.0036502619730607 | MAE: 0.11900191938579655 | CSI: 0.11856911106748365 | Loss: 0.011306156404316425


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 6 | MAE/CSI: 1.2770520777797807 | MAE: 0.1472552783109405 | CSI: 0.11530874963664299 | Loss: 0.011934855952858925


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 7 | MAE/CSI: 1.073235025239742 | MAE: 0.12629558541266794 | CSI: 0.11767747272633267 | Loss: 0.011313681490719318


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 8 | MAE/CSI: 0.8434373491310988 | MAE: 0.10165067178502879 | CSI: 0.12051952867501647 | Loss: 0.011227131821215153


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 9 | MAE/CSI: 1.0523852115853143 | MAE: 0.12406909788867562 | CSI: 0.11789323578647157 | Loss: 0.011084862984716892


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 10 | MAE/CSI: 0.9726492106175404 | MAE: 0.11562380038387716 | CSI: 0.118875128998968 | Loss: 0.011067209765315056


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 11 | MAE/CSI: 0.8857018008780639 | MAE: 0.10625719769673704 | CSI: 0.11996949491410139 | Loss: 0.010920040309429169


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 12 | MAE/CSI: 0.9229223356706401 | MAE: 0.11024952015355087 | CSI: 0.1194569855897087 | Loss: 0.010936064645648003


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 13 | MAE/CSI: 0.942111649744198 | MAE: 0.11232245681381958 | CSI: 0.11922414593150954 | Loss: 0.010872152633965015


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 14 | MAE/CSI: 0.9676800877047635 | MAE: 0.11508637236084453 | CSI: 0.1189301855253111 | Loss: 0.010937022045254707


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 15 | MAE/CSI: 0.8352548220955233 | MAE: 0.10072936660268714 | CSI: 0.1205971685971686 | Loss: 0.010871059261262417


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 16 | MAE/CSI: 0.9581314429908468 | MAE: 0.11401151631477927 | CSI: 0.11899360692926376 | Loss: 0.0107908695936203


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 17 | MAE/CSI: 0.8868929365209549 | MAE: 0.10633397312859885 | CSI: 0.11989493742596694 | Loss: 0.010807439684867859


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 18 | MAE/CSI: 1.0813794255930738 | MAE: 0.127063339731286 | CSI: 0.11750116261044799 | Loss: 0.01099312026053667


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 19 | MAE/CSI: 0.8605000659886831 | MAE: 0.1034932821497121 | CSI: 0.1202710914727724 | Loss: 0.010955577716231346


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 20 | MAE/CSI: 0.9247674515027908 | MAE: 0.11040307101727448 | CSI: 0.11938468513023409 | Loss: 0.010693227872252464


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 21 | MAE/CSI: 0.7480731419030813 | MAE: 0.09105566218809981 | CSI: 0.12172026649119921 | Loss: 0.011536278761923313


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 22 | MAE/CSI: 0.9189219325547425 | MAE: 0.10978886756238004 | CSI: 0.11947572875557708 | Loss: 0.010655155405402184


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 23 | MAE/CSI: 0.8743660610106783 | MAE: 0.10495201535508637 | CSI: 0.12003212388287138 | Loss: 0.010694632306694984


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 24 | MAE/CSI: 0.8596956033848641 | MAE: 0.10333973128598849 | CSI: 0.12020502475323953 | Loss: 0.010617670603096485


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 25 | MAE/CSI: 0.9320987589153724 | MAE: 0.11117082533589251 | CSI: 0.11926936311375765 | Loss: 0.010588045231997967


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 26 | MAE/CSI: 0.8901725335752227 | MAE: 0.10664107485604607 | CSI: 0.11979820858643056 | Loss: 0.01059663761407137


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 27 | MAE/CSI: 0.9442341384933582 | MAE: 0.11247600767754319 | CSI: 0.11911876841910023 | Loss: 0.010609036311507225


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 28 | MAE/CSI: 0.8273726514056701 | MAE: 0.09980806142034548 | CSI: 0.12063253631836701 | Loss: 0.010630769655108452


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 29 | MAE/CSI: 0.9478179216236259 | MAE: 0.1128598848368522 | CSI: 0.11907338135427294 | Loss: 0.01066933386027813


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 30 | MAE/CSI: 0.8845922772176007 | MAE: 0.10602687140115163 | CSI: 0.11985959422318754 | Loss: 0.010545131750404835


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 31 | MAE/CSI: 0.8893270175124783 | MAE: 0.10656429942418426 | CSI: 0.11982577536142346 | Loss: 0.010544568300247192


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 32 | MAE/CSI: 0.8942856105404078 | MAE: 0.1071017274472169 | CSI: 0.11976232892934743 | Loss: 0.010560829192399979


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 33 | MAE/CSI: 0.8121483608195074 | MAE: 0.0981190019193858 | CSI: 0.12081413526411058 | Loss: 0.010571216233074665


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 34 | MAE/CSI: 0.9277556442673833 | MAE: 0.11071017274472168 | CSI: 0.11933117672511487 | Loss: 0.010578982532024384


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 35 | MAE/CSI: 0.9421531481549642 | MAE: 0.11224568138195777 | CSI: 0.11913740520936367 | Loss: 0.010580139234662056


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 36 | MAE/CSI: 0.8190910415457837 | MAE: 0.09888675623800385 | CSI: 0.12072742982338844 | Loss: 0.01055662240833044


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 37 | MAE/CSI: 0.9201081769333753 | MAE: 0.10986564299424184 | CSI: 0.11940513707800367 | Loss: 0.010542426258325577


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 38 | MAE/CSI: 0.9013507058044742 | MAE: 0.10786948176583493 | CSI: 0.11967537282689297 | Loss: 0.010510878637433052


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 39 | MAE/CSI: 0.8654870949149819 | MAE: 0.10395393474088292 | CSI: 0.12011032325124268 | Loss: 0.01051416713744402


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 40 | MAE/CSI: 0.9463158106190775 | MAE: 0.1127063339731286 | CSI: 0.11910012778762522 | Loss: 0.010519065894186497


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 41 | MAE/CSI: 0.8654692808571053 | MAE: 0.10395393474088292 | CSI: 0.12011279549641339 | Loss: 0.010517258197069168


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 42 | MAE/CSI: 0.9022928392247763 | MAE: 0.10794625719769674 | CSI: 0.11963550247116969 | Loss: 0.010511299595236778


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 43 | MAE/CSI: 0.8773881200914918 | MAE: 0.10525911708253359 | CSI: 0.11996870560622594 | Loss: 0.010500743053853512


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 44 | MAE/CSI: 0.8986881268806691 | MAE: 0.10756238003838772 | CSI: 0.11968821754754476 | Loss: 0.010506429709494114


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 45 | MAE/CSI: 0.9135933108757834 | MAE: 0.10917466410748561 | CSI: 0.11950028837439236 | Loss: 0.010512260720133781


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 46 | MAE/CSI: 0.8986140990898661 | MAE: 0.10756238003838772 | CSI: 0.11969807745775453 | Loss: 0.010502759367227554


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 47 | MAE/CSI: 0.8943869105405711 | MAE: 0.1071017274472169 | CSI: 0.11974876441515651 | Loss: 0.010501205921173096


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 48 | MAE/CSI: 0.8986696199329683 | MAE: 0.10756238003838772 | CSI: 0.11969068237280805 | Loss: 0.010503779165446758


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 49 | MAE/CSI: 0.8781667057234775 | MAE: 0.10533589251439539 | CSI: 0.11994976788232509 | Loss: 0.010497772134840488

Model saved at ../models/baseline_bs128_epochs50_lr0.0005_adamw.ckpt


<IPython.core.display.Javascript object>

In [14]:
# AdamW bs50 lr 1e-4
fold = 0
train_fold(df, fold)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Using native 16bit precision.


Training fold0...



  | Name      | Type       | Params
-----------------------------------------
0 | criterion | L1Loss     | 0     
1 | encoder   | Encoder    | 666 K 
2 | decoder   | Decoder    | 664 K 
3 | out       | Sequential | 577   


HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

Epoch 0 | MAE/CSI: 1000000000000.0 | MAE: 1.0 | CSI: 0.0 | Loss: 0.07034780830144882


HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 0 | MAE/CSI: 1.1213412779907173 | MAE: 0.13205374280230325 | CSI: 0.1177640967946915 | Loss: 0.015904391184449196


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 1 | MAE/CSI: 1.0232372097056714 | MAE: 0.12153550863723608 | CSI: 0.11877549749307091 | Loss: 0.013309536501765251


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 2 | MAE/CSI: 1.0139998351280197 | MAE: 0.12046065259117082 | CSI: 0.11879750707745274 | Loss: 0.012951242737472057


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 3 | MAE/CSI: 1.1047516526727235 | MAE: 0.12990403071017276 | CSI: 0.11758663623158332 | Loss: 0.012385123409330845


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 4 | MAE/CSI: 1.117690074479474 | MAE: 0.13120921305182343 | CSI: 0.11739319874680997 | Loss: 0.012682433240115643


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 5 | MAE/CSI: 0.9326559243791501 | MAE: 0.11155470249520154 | CSI: 0.1196097076942159 | Loss: 0.011910993605852127


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 6 | MAE/CSI: 1.027332020181361 | MAE: 0.1216122840690979 | CSI: 0.11837680679572474 | Loss: 0.011943171732127666


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 7 | MAE/CSI: 1.0422545127997724 | MAE: 0.12314779270633397 | CSI: 0.11815520220150837 | Loss: 0.011781273409724236


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 8 | MAE/CSI: 0.976940484692145 | MAE: 0.11623800383877159 | CSI: 0.11898166332458189 | Loss: 0.011739981360733509


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 9 | MAE/CSI: 0.9288001062140463 | MAE: 0.1110172744721689 | CSI: 0.11952762895750106 | Loss: 0.011434326879680157


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 10 | MAE/CSI: 0.994726636645222 | MAE: 0.11808061420345489 | CSI: 0.11870659722222222 | Loss: 0.011498453095555305


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 11 | MAE/CSI: 0.86822094408707 | MAE: 0.10441458733205375 | CSI: 0.12026269124499979 | Loss: 0.011420509777963161


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 12 | MAE/CSI: 0.9243556771232797 | MAE: 0.11047984644913628 | CSI: 0.11952092596222288 | Loss: 0.011264418251812458


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 13 | MAE/CSI: 0.8254324608893246 | MAE: 0.09973128598848369 | CSI: 0.12082307240523024 | Loss: 0.01140605192631483


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 14 | MAE/CSI: 0.9816671665505846 | MAE: 0.11662188099808062 | CSI: 0.11879981827943667 | Loss: 0.011349151842296124


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 15 | MAE/CSI: 0.9067617427372024 | MAE: 0.10856046065259117 | CSI: 0.1197232476129591 | Loss: 0.011154781095683575


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 16 | MAE/CSI: 0.9926448922389418 | MAE: 0.11777351247600767 | CSI: 0.11864616782480304 | Loss: 0.011162293143570423


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 17 | MAE/CSI: 0.9942763377520618 | MAE: 0.11792706333973128 | CSI: 0.11860592358594774 | Loss: 0.011070771142840385


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 18 | MAE/CSI: 1.0252305224848342 | MAE: 0.12122840690978887 | CSI: 0.11824502319238835 | Loss: 0.011332811787724495


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 19 | MAE/CSI: 0.8948916276541378 | MAE: 0.1072552783109405 | CSI: 0.11985281233572806 | Loss: 0.011093420907855034


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 20 | MAE/CSI: 1.013906822535492 | MAE: 0.12 | CSI: 0.11835407093809695 | Loss: 0.01113690622150898


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 21 | MAE/CSI: 0.8771994920876025 | MAE: 0.10533589251439539 | CSI: 0.1200820263391109 | Loss: 0.011256896890699863


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 22 | MAE/CSI: 0.9377515842154777 | MAE: 0.11186180422264876 | CSI: 0.11928724632898861 | Loss: 0.010979496873915195


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 23 | MAE/CSI: 0.8964470593831949 | MAE: 0.10740882917466411 | CSI: 0.11981614312804023 | Loss: 0.010982400737702847


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 24 | MAE/CSI: 0.9242243449999964 | MAE: 0.11040307101727448 | CSI: 0.11945483974061588 | Loss: 0.0109475776553154


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 25 | MAE/CSI: 0.9463299270633719 | MAE: 0.11278310940499041 | CSI: 0.11917948083289502 | Loss: 0.010897496715188026


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 26 | MAE/CSI: 0.9341393703215161 | MAE: 0.11147792706333973 | CSI: 0.11933757488889118 | Loss: 0.010901767760515213


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 27 | MAE/CSI: 0.9795402145427297 | MAE: 0.1163147792706334 | CSI: 0.11874426138180767 | Loss: 0.010949915274977684


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 28 | MAE/CSI: 0.8924269999687918 | MAE: 0.10694817658349329 | CSI: 0.11983969174659757 | Loss: 0.010875929147005081


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 29 | MAE/CSI: 1.0142418426017923 | MAE: 0.12 | CSI: 0.1183149767230612 | Loss: 0.010970874689519405


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 30 | MAE/CSI: 0.9095483358412575 | MAE: 0.10879078694817658 | CSI: 0.11960968170717885 | Loss: 0.010860883630812168


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 31 | MAE/CSI: 0.9272108389625103 | MAE: 0.11071017274472168 | CSI: 0.1194012926635673 | Loss: 0.010840944945812225


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 32 | MAE/CSI: 0.9187514970552954 | MAE: 0.10978886756238004 | CSI: 0.11949789242613186 | Loss: 0.010872705839574337


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 33 | MAE/CSI: 0.8528154604418748 | MAE: 0.10264875239923224 | CSI: 0.12036455383347923 | Loss: 0.01089971512556076


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 34 | MAE/CSI: 0.9545016953524549 | MAE: 0.11362763915547025 | CSI: 0.11904393644115857 | Loss: 0.010870699770748615


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 35 | MAE/CSI: 0.9730808299782855 | MAE: 0.11562380038387716 | CSI: 0.1188224007922181 | Loss: 0.010838640853762627


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 36 | MAE/CSI: 0.8882988317697552 | MAE: 0.10648752399232246 | CSI: 0.11987804124348489 | Loss: 0.010831070132553577


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 37 | MAE/CSI: 0.9609859280455686 | MAE: 0.11431861804222648 | CSI: 0.11895972116237342 | Loss: 0.010835222899913788


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 38 | MAE/CSI: 0.9067524560855617 | MAE: 0.10848368522072936 | CSI: 0.11963980300438913 | Loss: 0.01081210095435381


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 39 | MAE/CSI: 0.9167002517731152 | MAE: 0.10955854126679462 | CSI: 0.1195140298630504 | Loss: 0.010801968164741993


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 40 | MAE/CSI: 0.9488315770034639 | MAE: 0.11301343570057582 | CSI: 0.11910800445379191 | Loss: 0.010811405256390572


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 41 | MAE/CSI: 0.9066310051806491 | MAE: 0.10848368522072936 | CSI: 0.11965582976969447 | Loss: 0.010802964679896832


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 42 | MAE/CSI: 0.9316570332554117 | MAE: 0.11117082533589251 | CSI: 0.11932591218305504 | Loss: 0.010797062888741493


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 43 | MAE/CSI: 0.8995883642404501 | MAE: 0.10771593090211132 | CSI: 0.11973913312246938 | Loss: 0.010799948126077652


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 44 | MAE/CSI: 0.9295530052067085 | MAE: 0.1109404990403071 | CSI: 0.11934822265967206 | Loss: 0.010795538313686848


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 45 | MAE/CSI: 0.9423350900232739 | MAE: 0.11232245681381958 | CSI: 0.1191958762886598 | Loss: 0.010799865238368511


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 46 | MAE/CSI: 0.933066581537057 | MAE: 0.11132437619961612 | CSI: 0.11931021687144389 | Loss: 0.010794575326144695


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 47 | MAE/CSI: 0.9401394272301328 | MAE: 0.11209213051823416 | CSI: 0.11922926245901977 | Loss: 0.010793674737215042


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 48 | MAE/CSI: 0.9387956837220837 | MAE: 0.11193857965451055 | CSI: 0.11923635951303488 | Loss: 0.01079593040049076


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 49 | MAE/CSI: 0.9216488734398317 | MAE: 0.11009596928982726 | CSI: 0.1194554373814824 | Loss: 0.010792912915349007

Model saved at ../models/baseline_bs128_epochs50_lr0.0001_adamw.ckpt


<IPython.core.display.Javascript object>

In [16]:
# AdamW bs50 lr 1e-3
fold = 0
train_fold(df, fold)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Using native 16bit precision.


Training fold0...



  | Name      | Type       | Params
-----------------------------------------
0 | criterion | L1Loss     | 0     
1 | encoder   | Encoder    | 666 K 
2 | decoder   | Decoder    | 664 K 
3 | out       | Sequential | 577   


HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

Epoch 0 | MAE/CSI: 1000000000000.0 | MAE: 1.0 | CSI: 0.0 | Loss: 0.07034780830144882


HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 0 | MAE/CSI: 1.138320426835042 | MAE: 0.13381957773512476 | CSI: 0.11755879502756099 | Loss: 0.01589236781001091


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 1 | MAE/CSI: 1.002306027157473 | MAE: 0.11930902111324376 | CSI: 0.11903452426660578 | Loss: 0.013353673741221428


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 2 | MAE/CSI: 1.0154380687146456 | MAE: 0.12061420345489443 | CSI: 0.11878046251166649 | Loss: 0.01299264281988144


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 3 | MAE/CSI: 1.114900060171623 | MAE: 0.130978886756238 | CSI: 0.11748038360941587 | Loss: 0.012409305199980736


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 4 | MAE/CSI: 1.1488281191112761 | MAE: 0.1344337811900192 | CSI: 0.11701818483766503 | Loss: 0.01272483542561531


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 5 | MAE/CSI: 0.9354791495688765 | MAE: 0.11186180422264876 | CSI: 0.11957701491611623 | Loss: 0.011894077993929386


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 6 | MAE/CSI: 1.0315937112542246 | MAE: 0.12207293666026871 | CSI: 0.11833431643434437 | Loss: 0.011928667314350605


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 7 | MAE/CSI: 1.007229160999509 | MAE: 0.11946257197696737 | CSI: 0.11860515620637235 | Loss: 0.011697919107973576


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 8 | MAE/CSI: 0.9667908785376472 | MAE: 0.11516314779270634 | CSI: 0.11911898462047296 | Loss: 0.011760309338569641


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 9 | MAE/CSI: 0.9315854624066506 | MAE: 0.11132437619961612 | CSI: 0.11949990708430551 | Loss: 0.011442412622272968


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 10 | MAE/CSI: 0.9982712355163732 | MAE: 0.11846449136276392 | CSI: 0.11866964322625986 | Loss: 0.011561795137822628


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 11 | MAE/CSI: 0.8730246773867592 | MAE: 0.10495201535508637 | CSI: 0.12021655065738593 | Loss: 0.011409180238842964


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 12 | MAE/CSI: 0.9235669002860872 | MAE: 0.11040307101727448 | CSI: 0.11953987413597442 | Loss: 0.011275683529675007


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 13 | MAE/CSI: 0.81997550304274 | MAE: 0.09911708253358925 | CSI: 0.12087810204691314 | Loss: 0.011408326216042042


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 14 | MAE/CSI: 0.9859672869815994 | MAE: 0.11708253358925144 | CSI: 0.11874890286339745 | Loss: 0.011363031342625618


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 15 | MAE/CSI: 0.9031850700262466 | MAE: 0.10817658349328214 | CSI: 0.11977233358079684 | Loss: 0.011150655336678028


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 16 | MAE/CSI: 0.9805180073263742 | MAE: 0.116468330134357 | CSI: 0.11878244893324938 | Loss: 0.011165974661707878


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 17 | MAE/CSI: 0.9950718900787031 | MAE: 0.1180038387715931 | CSI: 0.11858825472525884 | Loss: 0.011075479909777641


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 18 | MAE/CSI: 1.0223605930731856 | MAE: 0.12092130518234165 | CSI: 0.11827657090912848 | Loss: 0.011321539990603924


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 19 | MAE/CSI: 0.9021475974230198 | MAE: 0.10802303262955854 | CSI: 0.11973986622280396 | Loss: 0.011085733771324158


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 20 | MAE/CSI: 1.0066560658663206 | MAE: 0.11923224568138195 | CSI: 0.11844387544395804 | Loss: 0.011128334328532219


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 21 | MAE/CSI: 0.8885852706487493 | MAE: 0.10656429942418426 | CSI: 0.11992579996908333 | Loss: 0.011248618364334106


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 22 | MAE/CSI: 0.9406901157859805 | MAE: 0.11216890595009597 | CSI: 0.11924108063518252 | Loss: 0.010979549959301949


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 23 | MAE/CSI: 0.8930616191910925 | MAE: 0.10702495201535508 | CSI: 0.11984050116430028 | Loss: 0.010991481132805347


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 24 | MAE/CSI: 0.9242719859212941 | MAE: 0.11040307101727448 | CSI: 0.11944868252855552 | Loss: 0.010939003899693489


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 25 | MAE/CSI: 0.9343801849915703 | MAE: 0.11147792706333973 | CSI: 0.11930681841611514 | Loss: 0.010901150293648243


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 26 | MAE/CSI: 0.9321022845520779 | MAE: 0.11124760076775432 | CSI: 0.11935128001567155 | Loss: 0.010904178023338318


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 27 | MAE/CSI: 0.9866086956438601 | MAE: 0.11708253358925144 | CSI: 0.11867170247456298 | Loss: 0.01094411313533783


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 28 | MAE/CSI: 0.9017261190518742 | MAE: 0.10794625719769674 | CSI: 0.11971069143510649 | Loss: 0.01087891310453415


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 29 | MAE/CSI: 0.9998807513369863 | MAE: 0.11846449136276392 | CSI: 0.11847861977876836 | Loss: 0.010963994078338146


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 30 | MAE/CSI: 0.9089405350759588 | MAE: 0.10871401151631478 | CSI: 0.11960519673195207 | Loss: 0.010861601680517197


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 31 | MAE/CSI: 0.9265165003474861 | MAE: 0.11063339731285989 | CSI: 0.11940790830000413 | Loss: 0.0108433673158288


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 32 | MAE/CSI: 0.924424436869447 | MAE: 0.11040307101727448 | CSI: 0.119428983714698 | Loss: 0.010874624364078045


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 33 | MAE/CSI: 0.8675104462257137 | MAE: 0.10426103646833014 | CSI: 0.12018418558655074 | Loss: 0.010894048027694225


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 34 | MAE/CSI: 0.9538233214104781 | MAE: 0.11355086372360844 | CSI: 0.11904811003650011 | Loss: 0.010875188745558262


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 35 | MAE/CSI: 0.9766951757220987 | MAE: 0.11600767754318618 | CSI: 0.11877572494042646 | Loss: 0.010839647613465786


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 36 | MAE/CSI: 0.8868015843103648 | MAE: 0.10633397312859885 | CSI: 0.11990728817924286 | Loss: 0.010835183784365654


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 37 | MAE/CSI: 0.9737914151813813 | MAE: 0.11570057581573896 | CSI: 0.11881453667694783 | Loss: 0.01084214635193348


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 38 | MAE/CSI: 0.9109421263596105 | MAE: 0.10894433781190019 | CSI: 0.1195952351510655 | Loss: 0.010812942869961262


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 39 | MAE/CSI: 0.9110078346453826 | MAE: 0.10894433781190019 | CSI: 0.11958660910243069 | Loss: 0.010806724429130554


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 40 | MAE/CSI: 0.9516903840669387 | MAE: 0.11332053742802303 | CSI: 0.11907290367147468 | Loss: 0.010813283734023571


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 41 | MAE/CSI: 0.9131134429218587 | MAE: 0.10917466410748561 | CSI: 0.11956308928847442 | Loss: 0.010803856886923313


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 42 | MAE/CSI: 0.9324002005527466 | MAE: 0.11124760076775432 | CSI: 0.11931314547216096 | Loss: 0.010799448005855083


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 43 | MAE/CSI: 0.9046681478553775 | MAE: 0.10825335892514396 | CSI: 0.11966084931902005 | Loss: 0.010800926014780998


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 44 | MAE/CSI: 0.9323809801656068 | MAE: 0.11124760076775432 | CSI: 0.11931560502989075 | Loss: 0.010798103176057339


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 45 | MAE/CSI: 0.9466129584040338 | MAE: 0.11278310940499041 | CSI: 0.1191438469152095 | Loss: 0.010802337899804115


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 46 | MAE/CSI: 0.9281252628678694 | MAE: 0.1107869481765835 | CSI: 0.11936637500128827 | Loss: 0.010795307345688343


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 47 | MAE/CSI: 0.9366597046688451 | MAE: 0.11170825335892515 | CSI: 0.11926236689928153 | Loss: 0.010794851928949356


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 48 | MAE/CSI: 0.9374134629679352 | MAE: 0.11178502879078694 | CSI: 0.11924837140265523 | Loss: 0.010798281989991665


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 49 | MAE/CSI: 0.921677368643685 | MAE: 0.11009596928982726 | CSI: 0.11945174421600453 | Loss: 0.010794124566018581

Model saved at ../models/baseline_bs128_epochs50_lr0.0001_adamw.ckpt


<IPython.core.display.Javascript object>

In [None]:
# AdamW bs50 lr 5e-4
fold = 0
train_fold(df, fold)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Using native 16bit precision.


Training fold0...



  | Name      | Type       | Params
-----------------------------------------
0 | criterion | L1Loss     | 0     
1 | encoder   | Encoder    | 666 K 
2 | decoder   | Decoder    | 664 K 
3 | out       | Sequential | 577   


HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

Epoch 0 | MAE/CSI: 1000000000000.0 | MAE: 1.0 | CSI: 0.0 | Loss: 0.07034783065319061


HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 0 | MAE/CSI: 0.9381335799455521 | MAE: 0.11247600767754319 | CSI: 0.11989338200976986 | Loss: 0.013658225536346436


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 1 | MAE/CSI: 1.091616736304388 | MAE: 0.12852207293666026 | CSI: 0.11773552810363963 | Loss: 0.012366656213998795


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 2 | MAE/CSI: 0.8717283066361261 | MAE: 0.10495201535508637 | CSI: 0.12039532794249776 | Loss: 0.012938014231622219


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 3 | MAE/CSI: 0.922680779149949 | MAE: 0.11040307101727448 | CSI: 0.11965467744766979 | Loss: 0.011683200486004353


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 4 | MAE/CSI: 1.1422260861540856 | MAE: 0.13351247600767754 | CSI: 0.11688795906953622 | Loss: 0.011974487453699112


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 5 | MAE/CSI: 1.0207991020989804 | MAE: 0.12084452975047985 | CSI: 0.11838228452687405 | Loss: 0.011398336850106716


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 6 | MAE/CSI: 1.3015827350484537 | MAE: 0.14971209213051823 | CSI: 0.11502310848003323 | Loss: 0.012028665281832218


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 7 | MAE/CSI: 1.0668947256657269 | MAE: 0.1256046065259117 | CSI: 0.11772914749997415 | Loss: 0.011318517848849297


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 8 | MAE/CSI: 0.8600657788579773 | MAE: 0.1034932821497121 | CSI: 0.1203318219291014 | Loss: 0.011268786154687405


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 9 | MAE/CSI: 1.1294213724123174 | MAE: 0.13213051823416508 | CSI: 0.11698956780923994 | Loss: 0.011137530207633972


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 10 | MAE/CSI: 0.9699192152207351 | MAE: 0.11531669865642995 | CSI: 0.11889309629690772 | Loss: 0.011074976995587349


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 11 | MAE/CSI: 0.8592266415081566 | MAE: 0.10333973128598849 | CSI: 0.1202706320927646 | Loss: 0.01094895415008068


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 12 | MAE/CSI: 0.9156201995627933 | MAE: 0.10948176583493283 | CSI: 0.11957115612597288 | Loss: 0.010989376343786716


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 13 | MAE/CSI: 0.9186473420278555 | MAE: 0.10978886756238004 | CSI: 0.11951144094001237 | Loss: 0.010939965955913067


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 14 | MAE/CSI: 0.9428760437961236 | MAE: 0.11239923224568138 | CSI: 0.11920891721058764 | Loss: 0.011009103618562222


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 15 | MAE/CSI: 0.8349164229815289 | MAE: 0.10065259117082534 | CSI: 0.1205540918821011 | Loss: 0.010793941095471382


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 16 | MAE/CSI: 0.949548565503695 | MAE: 0.11309021113243763 | CSI: 0.11909892262487758 | Loss: 0.010786321945488453


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 17 | MAE/CSI: 0.8586650737446018 | MAE: 0.10326295585412668 | CSI: 0.12025987665125665 | Loss: 0.010778253898024559


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 18 | MAE/CSI: 1.0735752097766498 | MAE: 0.12621880998080615 | CSI: 0.11756867039244652 | Loss: 0.011026301421225071


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 19 | MAE/CSI: 0.8782398711638773 | MAE: 0.1054126679462572 | CSI: 0.12002719462700097 | Loss: 0.010951431468129158


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 20 | MAE/CSI: 0.961744436310654 | MAE: 0.1143953934740883 | CSI: 0.11894572939975458 | Loss: 0.010739585384726524


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 21 | MAE/CSI: 0.7626325761511776 | MAE: 0.0926679462571977 | CSI: 0.12151060570230005 | Loss: 0.011532123200595379


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 22 | MAE/CSI: 0.8935769142914931 | MAE: 0.10702495201535508 | CSI: 0.11977139326536917 | Loss: 0.010652078315615654


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 23 | MAE/CSI: 0.8583507216152066 | MAE: 0.10318618042226488 | CSI: 0.12021447390087271 | Loss: 0.010725504718720913


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 24 | MAE/CSI: 0.8592220571537372 | MAE: 0.10326295585412668 | CSI: 0.12018191920647006 | Loss: 0.010612650774419308


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 25 | MAE/CSI: 0.910653194743647 | MAE: 0.10886756238003839 | CSI: 0.11954887218045113 | Loss: 0.010592028498649597


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 26 | MAE/CSI: 0.8931625445119947 | MAE: 0.10694817658349329 | CSI: 0.1197410003808818 | Loss: 0.01058614905923605


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 27 | MAE/CSI: 0.9508534944090756 | MAE: 0.11316698656429942 | CSI: 0.11901621777567127 | Loss: 0.01062245573848486


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 28 | MAE/CSI: 0.8332982458772665 | MAE: 0.10042226487523992 | CSI: 0.1205117919919364 | Loss: 0.010614056140184402


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 29 | MAE/CSI: 0.9256429876075178 | MAE: 0.11047984644913628 | CSI: 0.1193547057853964 | Loss: 0.010645455680787563


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 30 | MAE/CSI: 0.8837848408830448 | MAE: 0.10595009596928982 | CSI: 0.11988222819317047 | Loss: 0.010543121956288815


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 31 | MAE/CSI: 0.8787736290945952 | MAE: 0.1054126679462572 | CSI: 0.11995429136168505 | Loss: 0.010546290315687656


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 32 | MAE/CSI: 0.8894002764619824 | MAE: 0.10656429942418426 | CSI: 0.11981590544046786 | Loss: 0.010572988539934158


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 33 | MAE/CSI: 0.8058613562069132 | MAE: 0.09742802303262955 | CSI: 0.12089923692383636 | Loss: 0.010588102042675018


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 34 | MAE/CSI: 0.9003947044562759 | MAE: 0.10771593090211132 | CSI: 0.1196319018404908 | Loss: 0.010574890300631523


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 35 | MAE/CSI: 0.937954560868547 | MAE: 0.11178502879078694 | CSI: 0.11917957804516235 | Loss: 0.010592584498226643


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 36 | MAE/CSI: 0.8038668596813368 | MAE: 0.09719769673704415 | CSI: 0.1209126804590137 | Loss: 0.010560299269855022


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 37 | MAE/CSI: 0.9095679654120318 | MAE: 0.10871401151631478 | CSI: 0.1195226917057903 | Loss: 0.010540400631725788


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 38 | MAE/CSI: 0.8960538365522246 | MAE: 0.1072552783109405 | CSI: 0.11969735961706728 | Loss: 0.010512854903936386


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 39 | MAE/CSI: 0.8620193727041435 | MAE: 0.1035700575815739 | CSI: 0.12014817863757975 | Loss: 0.010515077039599419


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 40 | MAE/CSI: 0.9139790870740366 | MAE: 0.10917466410748561 | CSI: 0.11944984918208304 | Loss: 0.010519146919250488


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 41 | MAE/CSI: 0.8697157778705387 | MAE: 0.10441458733205375 | CSI: 0.12005598839064253 | Loss: 0.010522506199777126


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 42 | MAE/CSI: 0.9010305223722821 | MAE: 0.10779270633397313 | CSI: 0.11963269129803683 | Loss: 0.010514034889638424


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 43 | MAE/CSI: 0.8783836508380664 | MAE: 0.10533589251439539 | CSI: 0.11992014242639416 | Loss: 0.010503781028091908


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 44 | MAE/CSI: 0.8798422581432144 | MAE: 0.105489443378119 | CSI: 0.11989585905985017 | Loss: 0.010508840903639793


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 45 | MAE/CSI: 0.9060126520071495 | MAE: 0.10833013435700575 | CSI: 0.11956801498975632 | Loss: 0.010515524074435234


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 46 | MAE/CSI: 0.8869249315791313 | MAE: 0.10625719769673704 | CSI: 0.11980404869966141 | Loss: 0.01050545647740364


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 47 | MAE/CSI: 0.891837017510536 | MAE: 0.10679462571976968 | CSI: 0.11974679635633781 | Loss: 0.010504990816116333


In [15]:
# AdamW bs50 lr 5e-4
for fold in range(5):
    train_fold(df, fold)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Using native 16bit precision.


Training fold0...



  | Name      | Type       | Params
-----------------------------------------
0 | criterion | L1Loss     | 0     
1 | encoder   | Encoder    | 666 K 
2 | decoder   | Decoder    | 664 K 
3 | out       | Sequential | 577   


HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

Epoch 0 | MAE/CSI: 1000000000000.0 | MAE: 1.0 | CSI: 0.0 | Loss: 0.07034780830144882


HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 0 | MAE/CSI: 0.9437922424302221 | MAE: 0.11309021113243763 | CSI: 0.1198253238872696 | Loss: 0.01358635164797306


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 1 | MAE/CSI: 1.1144372062329246 | MAE: 0.1309021113243762 | CSI: 0.1174602845195231 | Loss: 0.01243552751839161


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 2 | MAE/CSI: 0.8626368456547189 | MAE: 0.10395393474088292 | CSI: 0.1205071812822022 | Loss: 0.012849031016230583


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 3 | MAE/CSI: 0.9002768018883869 | MAE: 0.10794625719769674 | CSI: 0.11990340856320238 | Loss: 0.011554539203643799


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 4 | MAE/CSI: 1.0671439714543172 | MAE: 0.1256813819577735 | CSI: 0.11777359505243346 | Loss: 0.011750211007893085


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 5 | MAE/CSI: 0.9762669221542187 | MAE: 0.11608445297504799 | CSI: 0.11890646947037925 | Loss: 0.011374273337423801


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 6 | MAE/CSI: 1.268498090098286 | MAE: 0.14641074856046066 | CSI: 0.1154205510454083 | Loss: 0.012063690461218357


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 7 | MAE/CSI: 1.1031854423828713 | MAE: 0.12944337811900192 | CSI: 0.11733601001686725 | Loss: 0.01138242892920971


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 8 | MAE/CSI: 0.8217496448119358 | MAE: 0.09927063339731286 | CSI: 0.12080398698463693 | Loss: 0.011296030133962631


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 9 | MAE/CSI: 1.1088743781210597 | MAE: 0.12998080614203456 | CSI: 0.11721869375426691 | Loss: 0.011095265857875347


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 10 | MAE/CSI: 0.9620862761671343 | MAE: 0.11447216890595009 | CSI: 0.11898326765561493 | Loss: 0.011082764714956284


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 11 | MAE/CSI: 0.8788553600805569 | MAE: 0.105489443378119 | CSI: 0.12003049440077472 | Loss: 0.010911011137068272


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 12 | MAE/CSI: 0.9072385822422571 | MAE: 0.10856046065259117 | CSI: 0.11966032174621005 | Loss: 0.010981745086610317


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 13 | MAE/CSI: 0.9236828223175657 | MAE: 0.11032629558541267 | CSI: 0.11944175307674865 | Loss: 0.010909012518823147


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 14 | MAE/CSI: 0.957795533843998 | MAE: 0.11401151631477927 | CSI: 0.11903533926103191 | Loss: 0.010993537493050098


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 15 | MAE/CSI: 0.8396774933378116 | MAE: 0.10119001919385796 | CSI: 0.12051057697256679 | Loss: 0.01081531960517168


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 16 | MAE/CSI: 0.9688093562689909 | MAE: 0.11516314779270634 | CSI: 0.11887080471151268 | Loss: 0.01078347209841013


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 17 | MAE/CSI: 0.8454214584932136 | MAE: 0.1018042226487524 | CSI: 0.12041830926476794 | Loss: 0.010767708532512188


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 18 | MAE/CSI: 1.0531790624176152 | MAE: 0.12406909788867562 | CSI: 0.11780437184424918 | Loss: 0.011013168841600418


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 19 | MAE/CSI: 0.9021290016177816 | MAE: 0.10802303262955854 | CSI: 0.11974233444988405 | Loss: 0.011004014872014523


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 20 | MAE/CSI: 0.9261467066757963 | MAE: 0.11055662188099807 | CSI: 0.11937268802357572 | Loss: 0.010737544856965542


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 21 | MAE/CSI: 0.738513689685624 | MAE: 0.08998080614203455 | CSI: 0.12184040377044293 | Loss: 0.011556853540241718


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 22 | MAE/CSI: 0.9217058638475384 | MAE: 0.11009596928982726 | CSI: 0.11944805127888043 | Loss: 0.010670513845980167


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 23 | MAE/CSI: 0.8700357218950803 | MAE: 0.10449136276391555 | CSI: 0.12010008340283569 | Loss: 0.010724274441599846


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 24 | MAE/CSI: 0.8609963941380611 | MAE: 0.1034932821497121 | CSI: 0.12020176025528849 | Loss: 0.010626938194036484


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 25 | MAE/CSI: 0.9355479405170692 | MAE: 0.11155470249520154 | CSI: 0.1192399637292886 | Loss: 0.010589144192636013


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 26 | MAE/CSI: 0.8959708216309327 | MAE: 0.1072552783109405 | CSI: 0.11970845000823588 | Loss: 0.010607089847326279


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 27 | MAE/CSI: 0.9535970473260313 | MAE: 0.11347408829174664 | CSI: 0.11899584694497975 | Loss: 0.010617554187774658


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 28 | MAE/CSI: 0.8343800580415958 | MAE: 0.10057581573896353 | CSI: 0.1205395727867638 | Loss: 0.010615648701786995


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 29 | MAE/CSI: 0.9420075384325132 | MAE: 0.11224568138195777 | CSI: 0.11915582073556538 | Loss: 0.010660631582140923


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 30 | MAE/CSI: 0.8873040214686102 | MAE: 0.10633397312859885 | CSI: 0.11983939050756717 | Loss: 0.01054992526769638


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 31 | MAE/CSI: 0.8851448177581197 | MAE: 0.10610364683301343 | CSI: 0.11987151108319863 | Loss: 0.01055346429347992


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 32 | MAE/CSI: 0.9027612589958749 | MAE: 0.10802303262955854 | CSI: 0.11965847177448426 | Loss: 0.010560257360339165


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 33 | MAE/CSI: 0.8119980124723106 | MAE: 0.0981190019193858 | CSI: 0.1208365050301397 | Loss: 0.010586261749267578


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 34 | MAE/CSI: 0.9156443179485217 | MAE: 0.10940499040307101 | CSI: 0.11948415804870009 | Loss: 0.010580405592918396


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 35 | MAE/CSI: 0.9486229827182243 | MAE: 0.11293666026871402 | CSI: 0.1190532617543715 | Loss: 0.010587424039840698


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 36 | MAE/CSI: 0.8343972284753538 | MAE: 0.10057581573896353 | CSI: 0.12053709229344582 | Loss: 0.01055158395320177


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 37 | MAE/CSI: 0.93222359442797 | MAE: 0.11117082533589251 | CSI: 0.1192533915676923 | Loss: 0.010545892640948296


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 38 | MAE/CSI: 0.8951787451674442 | MAE: 0.1071785028790787 | CSI: 0.11972860555143727 | Loss: 0.010513982735574245


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 39 | MAE/CSI: 0.8696352179660404 | MAE: 0.10441458733205375 | CSI: 0.12006710994915289 | Loss: 0.010518943890929222


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 40 | MAE/CSI: 0.9230175593085749 | MAE: 0.11017274472168906 | CSI: 0.11936148300720906 | Loss: 0.010523991659283638


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 41 | MAE/CSI: 0.865406931654537 | MAE: 0.10395393474088292 | CSI: 0.12012144915603129 | Loss: 0.010527591221034527


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 42 | MAE/CSI: 0.9123202405399485 | MAE: 0.109021113243762 | CSI: 0.11949873344728876 | Loss: 0.010516504757106304


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 43 | MAE/CSI: 0.878891236774245 | MAE: 0.1054126679462572 | CSI: 0.11993823983530623 | Loss: 0.010505661368370056


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 44 | MAE/CSI: 0.8987529011976216 | MAE: 0.10756238003838772 | CSI: 0.1196795914585178 | Loss: 0.010510049760341644


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 45 | MAE/CSI: 0.9165924059496204 | MAE: 0.10948176583493283 | CSI: 0.11944433002430284 | Loss: 0.010518706403672695


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 46 | MAE/CSI: 0.8994440072829465 | MAE: 0.10763915547024952 | CSI: 0.11967299197924282 | Loss: 0.010508017614483833


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 47 | MAE/CSI: 0.898789915093023 | MAE: 0.10756238003838772 | CSI: 0.1196746628230207 | Loss: 0.010506379418075085


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 48 | MAE/CSI: 0.9051807552372788 | MAE: 0.10825335892514396 | CSI: 0.11959308491469405 | Loss: 0.010508042760193348


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 49 | MAE/CSI: 0.8824160365458403 | MAE: 0.10579654510556621 | CSI: 0.11989417771555629 | Loss: 0.010502020828425884

Model saved at ../models/baseline_fold0_bs128_epochs50_lr0.0005_adamw.ckpt


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Using native 16bit precision.

  | Name      | Type       | Params
-----------------------------------------
0 | criterion | L1Loss     | 0     
1 | encoder   | Encoder    | 666 K 
2 | decoder   | Decoder    | 664 K 
3 | out       | Sequential | 577   


Training fold1...


HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

Epoch 0 | MAE/CSI: 1000000000000.0 | MAE: 1.0 | CSI: 0.0 | Loss: 0.05332362651824951


HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 0 | MAE/CSI: 1.4649042747142322 | MAE: 0.1602899718082964 | CSI: 0.10942009971136185 | Loss: 0.013296589255332947


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 1 | MAE/CSI: 1.290671138246638 | MAE: 0.14337494965767217 | CSI: 0.1110855781986066 | Loss: 0.0125888055190444


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 2 | MAE/CSI: 1.2783221086827037 | MAE: 0.14208618606524365 | CSI: 0.11115053482911558 | Loss: 0.011991249397397041


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 3 | MAE/CSI: 0.9015394723596947 | MAE: 0.10398711236407572 | CSI: 0.11534393728873313 | Loss: 0.01250691618770361


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 4 | MAE/CSI: 1.2422832403087438 | MAE: 0.13846153846153847 | CSI: 0.11145730214140571 | Loss: 0.012093799188733101


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 5 | MAE/CSI: 0.7875750740947998 | MAE: 0.0919049536850584 | CSI: 0.11669357843746118 | Loss: 0.01224138680845499


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 6 | MAE/CSI: 1.02749542826656 | MAE: 0.11695529601288764 | CSI: 0.11382561206055195 | Loss: 0.01118035800755024


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 7 | MAE/CSI: 1.195252290179063 | MAE: 0.13370922271445831 | CSI: 0.11186694542390864 | Loss: 0.011427664197981358


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 8 | MAE/CSI: 1.0641914459521074 | MAE: 0.12057994361659283 | CSI: 0.11330662736877063 | Loss: 0.011057481169700623


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 9 | MAE/CSI: 0.9369481790171504 | MAE: 0.10753121224325413 | CSI: 0.11476751292170329 | Loss: 0.010967950336635113


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 10 | MAE/CSI: 0.7484682012246785 | MAE: 0.08763592428513894 | CSI: 0.11708703741988837 | Loss: 0.011205273680388927


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 11 | MAE/CSI: 0.9380351910149244 | MAE: 0.1076923076923077 | CSI: 0.11480625537603764 | Loss: 0.011112474836409092


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 12 | MAE/CSI: 0.9998407717848802 | MAE: 0.11397503020539669 | CSI: 0.11399318113516482 | Loss: 0.010790626518428326


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 13 | MAE/CSI: 0.8450261420411643 | MAE: 0.09794603302456706 | CSI: 0.11590887920595327 | Loss: 0.01095589343458414


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 14 | MAE/CSI: 1.0267842632014805 | MAE: 0.11671365283930729 | CSI: 0.1136691094917749 | Loss: 0.010724475607275963


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 15 | MAE/CSI: 0.9396971828684089 | MAE: 0.10777285541683447 | CSI: 0.1146889204104242 | Loss: 0.010688774287700653


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 16 | MAE/CSI: 0.7939414270633738 | MAE: 0.09246878775674587 | CSI: 0.11646802220407489 | Loss: 0.010742655955255032


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 17 | MAE/CSI: 1.1852240181120506 | MAE: 0.13266210229561015 | CSI: 0.11192998139351164 | Loss: 0.011140896938741207


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 18 | MAE/CSI: 1.0105798988834238 | MAE: 0.11510269834877165 | CSI: 0.11389767248976207 | Loss: 0.010770338587462902


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 19 | MAE/CSI: 0.8335278197512083 | MAE: 0.09665726943213854 | CSI: 0.11596165975618583 | Loss: 0.010712211951613426


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 20 | MAE/CSI: 0.8510294880021778 | MAE: 0.09850986709625453 | CSI: 0.11575376468641403 | Loss: 0.010666263289749622


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 21 | MAE/CSI: 0.8465844434347423 | MAE: 0.09802658074909384 | CSI: 0.11579067098201822 | Loss: 0.01059988234192133


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 22 | MAE/CSI: 0.9672245166705439 | MAE: 0.11059202577527184 | CSI: 0.11433956012094602 | Loss: 0.010529949329793453


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 23 | MAE/CSI: 0.997738360733415 | MAE: 0.11373338703181635 | CSI: 0.11399119399119399 | Loss: 0.010567311197519302


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 24 | MAE/CSI: 0.898391589075928 | MAE: 0.10342327829238825 | CSI: 0.11512048815803082 | Loss: 0.010495680384337902


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 25 | MAE/CSI: 0.9244896330381611 | MAE: 0.10616190092629883 | CSI: 0.11483298149757855 | Loss: 0.010658398270606995


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 26 | MAE/CSI: 0.9298419514112347 | MAE: 0.10672573499798631 | CSI: 0.11477836081183179 | Loss: 0.010502760298550129


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 27 | MAE/CSI: 0.9704039361029636 | MAE: 0.11091421667337897 | CSI: 0.11429695670632578 | Loss: 0.010476275347173214


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 28 | MAE/CSI: 1.0245480664178566 | MAE: 0.11647200966572695 | CSI: 0.11368135227849807 | Loss: 0.010573608800768852


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 29 | MAE/CSI: 0.9706752428478566 | MAE: 0.11091421667337897 | CSI: 0.11426501035196687 | Loss: 0.010471746325492859


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 30 | MAE/CSI: 0.9757942295074935 | MAE: 0.11147805074506645 | CSI: 0.11424340027134232 | Loss: 0.010471350513398647


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 31 | MAE/CSI: 0.895234102893301 | MAE: 0.10310108739428112 | CSI: 0.11516662184804419 | Loss: 0.010457354597747326


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 32 | MAE/CSI: 0.8317717117439819 | MAE: 0.09641562625855819 | CSI: 0.115915971770152 | Loss: 0.01048274990171194


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 33 | MAE/CSI: 1.022891463602481 | MAE: 0.11631091421667338 | CSI: 0.11370797230628912 | Loss: 0.010477441363036633


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 34 | MAE/CSI: 0.8913645102599291 | MAE: 0.1026983487716472 | CSI: 0.11521476072769395 | Loss: 0.01042830292135477


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 35 | MAE/CSI: 0.9030110911423105 | MAE: 0.10390656463954893 | CSI: 0.11506676458115697 | Loss: 0.010428202338516712


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 36 | MAE/CSI: 0.9846335480455448 | MAE: 0.11236407571486105 | CSI: 0.11411765924176996 | Loss: 0.010422425344586372


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 37 | MAE/CSI: 0.9260881520169612 | MAE: 0.10632299637535239 | CSI: 0.1148087211167334 | Loss: 0.01040708739310503


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 38 | MAE/CSI: 0.909877958160892 | MAE: 0.10463149416028997 | CSI: 0.11499508612217452 | Loss: 0.010411178693175316


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 39 | MAE/CSI: 1.0135557133956454 | MAE: 0.115344341522352 | CSI: 0.1138016785825303 | Loss: 0.010445826686918736


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 40 | MAE/CSI: 0.9502076884845897 | MAE: 0.10881997583568265 | CSI: 0.11452230617948453 | Loss: 0.010401807725429535


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 41 | MAE/CSI: 0.9053927143266041 | MAE: 0.10414820781312928 | CSI: 0.11503097624292821 | Loss: 0.010389027185738087


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 42 | MAE/CSI: 0.946331936735179 | MAE: 0.10841723721304873 | CSI: 0.11456575964892307 | Loss: 0.010387437418103218


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 43 | MAE/CSI: 0.936402672617956 | MAE: 0.10737011679420057 | CSI: 0.11466233484050534 | Loss: 0.010386471636593342


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 44 | MAE/CSI: 0.9192851882635985 | MAE: 0.10559806685461136 | CSI: 0.11486975772246705 | Loss: 0.010383939370512962


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 45 | MAE/CSI: 0.9495266410697014 | MAE: 0.10873942811115586 | CSI: 0.11451961788845075 | Loss: 0.010390044189989567


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 46 | MAE/CSI: 0.9316930161089786 | MAE: 0.10688683044703987 | CSI: 0.11472322814278324 | Loss: 0.010382372885942459


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 47 | MAE/CSI: 0.9208344140544454 | MAE: 0.10575916230366492 | CSI: 0.11485144417777042 | Loss: 0.010380524210631847


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 48 | MAE/CSI: 0.9222982587823745 | MAE: 0.10592025775271849 | CSI: 0.11484382274735394 | Loss: 0.010379559360444546


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 49 | MAE/CSI: 0.9301114122389976 | MAE: 0.10672573499798631 | CSI: 0.11474510858881105 | Loss: 0.010381213389337063

Model saved at ../models/baseline_fold1_bs128_epochs50_lr0.0005_adamw.ckpt


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Using native 16bit precision.

  | Name      | Type       | Params
-----------------------------------------
0 | criterion | L1Loss     | 0     
1 | encoder   | Encoder    | 666 K 
2 | decoder   | Decoder    | 664 K 
3 | out       | Sequential | 577   


Training fold2...


HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

Epoch 0 | MAE/CSI: 1000000000000.0 | MAE: 1.0 | CSI: 0.0 | Loss: 0.06792119145393372


HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 0 | MAE/CSI: 1.0583238103064436 | MAE: 0.12316623519259434 | CSI: 0.11637859225322782 | Loss: 0.018375571817159653


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 1 | MAE/CSI: 1.3090439302248436 | MAE: 0.14827018121911037 | CSI: 0.11326600872159744 | Loss: 0.0128381522372365


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 2 | MAE/CSI: 1.0749716990031783 | MAE: 0.124499882325253 | CSI: 0.11581689307693106 | Loss: 0.012292868457734585


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 3 | MAE/CSI: 1.5262270258583122 | MAE: 0.16905938652231897 | CSI: 0.11076948819309364 | Loss: 0.014321698807179928


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 4 | MAE/CSI: 0.8281121601940147 | MAE: 0.09837608849140975 | CSI: 0.11879560912079053 | Loss: 0.013426098972558975


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 5 | MAE/CSI: 0.9191661424525188 | MAE: 0.10802541774535185 | CSI: 0.1175254535118094 | Loss: 0.011718537658452988


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 6 | MAE/CSI: 0.9063162075373175 | MAE: 0.10661332078136032 | CSI: 0.11763369108244068 | Loss: 0.01155879907310009


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 7 | MAE/CSI: 1.0341454411945987 | MAE: 0.12002824193927983 | CSI: 0.11606514631016929 | Loss: 0.01131217647343874


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 8 | MAE/CSI: 1.0352961355132178 | MAE: 0.12018514160194556 | CSI: 0.11608769460086121 | Loss: 0.011417840607464314


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 9 | MAE/CSI: 1.0188106526559253 | MAE: 0.11845924531262257 | CSI: 0.1162720913869459 | Loss: 0.011291706003248692


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 10 | MAE/CSI: 0.8717325295321008 | MAE: 0.10284772887738292 | CSI: 0.11798083172566053 | Loss: 0.011077452450990677


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 11 | MAE/CSI: 0.8649155800572065 | MAE: 0.10214168039538715 | CSI: 0.1180943929669604 | Loss: 0.011104998178780079


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 12 | MAE/CSI: 0.9185423571477423 | MAE: 0.10786851808268612 | CSI: 0.11743445170750849 | Loss: 0.011087087914347649


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 13 | MAE/CSI: 0.937323037184238 | MAE: 0.10982976386600769 | CSI: 0.11717386590113488 | Loss: 0.01098643708974123


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 14 | MAE/CSI: 0.9917942060852485 | MAE: 0.11555660155330666 | CSI: 0.11651268059775531 | Loss: 0.011159125715494156


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 15 | MAE/CSI: 0.87468919949303 | MAE: 0.10316152820271436 | CSI: 0.11794078200763437 | Loss: 0.010921032167971134


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 16 | MAE/CSI: 0.8852680079904909 | MAE: 0.10425982584137444 | CSI: 0.1177720474471377 | Loss: 0.01089795958250761


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 17 | MAE/CSI: 1.0215452563791814 | MAE: 0.1186161449752883 | CSI: 0.11611442981458897 | Loss: 0.010930635966360569


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 18 | MAE/CSI: 0.9635993753446791 | MAE: 0.11257550796265788 | CSI: 0.1168281246772561 | Loss: 0.010946526192128658


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 19 | MAE/CSI: 0.990630094343072 | MAE: 0.11539970189064093 | CSI: 0.11649121357066851 | Loss: 0.011101718060672283


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 20 | MAE/CSI: 0.9296621215830078 | MAE: 0.1089668157213462 | CSI: 0.11721120307114405 | Loss: 0.010813501663506031


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 21 | MAE/CSI: 1.053979582174606 | MAE: 0.1219894877226014 | CSI: 0.1157417940391735 | Loss: 0.010878403671085835


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 22 | MAE/CSI: 0.8713598432778761 | MAE: 0.10276927904605006 | CSI: 0.11794126139504187 | Loss: 0.01081886701285839


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 23 | MAE/CSI: 0.8094857284627778 | MAE: 0.09610104338275673 | CSI: 0.11871863826981134 | Loss: 0.010966574773192406


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 24 | MAE/CSI: 0.8459103347898103 | MAE: 0.10002353494939986 | CSI: 0.11824366110080396 | Loss: 0.010708808898925781


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 25 | MAE/CSI: 0.8446276818383552 | MAE: 0.09986663528673413 | CSI: 0.11823746419076277 | Loss: 0.010703754611313343


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 26 | MAE/CSI: 0.9895471686168905 | MAE: 0.11524280222797521 | CSI: 0.1164601404378356 | Loss: 0.010782686993479729


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 27 | MAE/CSI: 0.769290903636 | MAE: 0.09170785282811642 | CSI: 0.11921088939684109 | Loss: 0.01082665380090475


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 28 | MAE/CSI: 0.8198090405058381 | MAE: 0.09719934102141681 | CSI: 0.11856339247079187 | Loss: 0.010702860541641712


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 29 | MAE/CSI: 0.9320700526860811 | MAE: 0.10920216521534479 | CSI: 0.11716089890422832 | Loss: 0.010679910890758038


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 30 | MAE/CSI: 0.885834141218907 | MAE: 0.10425982584137444 | CSI: 0.11769677977982106 | Loss: 0.010661674663424492


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 31 | MAE/CSI: 0.8572212453409812 | MAE: 0.1012002824193928 | CSI: 0.11805619957340258 | Loss: 0.0106227220967412


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 32 | MAE/CSI: 0.9336796264837647 | MAE: 0.10935906487801052 | CSI: 0.11712696922489657 | Loss: 0.010662314482033253


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 33 | MAE/CSI: 0.9107966428687078 | MAE: 0.10692712010669177 | CSI: 0.11739955449220361 | Loss: 0.010680872946977615


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 34 | MAE/CSI: 0.9366501332996858 | MAE: 0.10967286420334196 | CSI: 0.11709053391797782 | Loss: 0.010653090663254261


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 35 | MAE/CSI: 0.9027136102945182 | MAE: 0.10606417196203029 | CSI: 0.1174948186795624 | Loss: 0.010605936869978905


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 36 | MAE/CSI: 0.9012293495300664 | MAE: 0.10590727229936456 | CSI: 0.1175142291512002 | Loss: 0.010662487708032131


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 37 | MAE/CSI: 0.8489388914597604 | MAE: 0.10033733427473131 | CSI: 0.1181914685300271 | Loss: 0.010626079514622688


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 38 | MAE/CSI: 0.9501680730955672 | MAE: 0.11108496116733349 | CSI: 0.11691085431283532 | Loss: 0.010639404878020287


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 39 | MAE/CSI: 0.850572468240016 | MAE: 0.10049423393739704 | CSI: 0.11814893814337384 | Loss: 0.010601965710520744


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 40 | MAE/CSI: 0.8908257676194905 | MAE: 0.10480897466070448 | CSI: 0.11765373035839485 | Loss: 0.010582203976809978


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 41 | MAE/CSI: 0.8185795509506226 | MAE: 0.09704244135875108 | CSI: 0.1185497991554228 | Loss: 0.010606820695102215


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 42 | MAE/CSI: 0.9011550090739481 | MAE: 0.10590727229936456 | CSI: 0.11752392344497607 | Loss: 0.010577197186648846


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 43 | MAE/CSI: 0.8879572498856813 | MAE: 0.10449517533537303 | CSI: 0.11768041237113402 | Loss: 0.010582765564322472


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 44 | MAE/CSI: 0.8653973955432505 | MAE: 0.1020632305640543 | CSI: 0.1179379913653646 | Loss: 0.010572961531579494


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 45 | MAE/CSI: 0.8799125258902079 | MAE: 0.10363222719071154 | CSI: 0.11777560171107561 | Loss: 0.010572567582130432


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 46 | MAE/CSI: 0.8857702229511826 | MAE: 0.10425982584137444 | CSI: 0.11770527292407608 | Loss: 0.010572829283773899


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 47 | MAE/CSI: 0.8616950848384399 | MAE: 0.10167098140738998 | CSI: 0.11798951066964791 | Loss: 0.010569415986537933


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 48 | MAE/CSI: 0.8690044359771177 | MAE: 0.1024554797207186 | CSI: 0.11789983511953833 | Loss: 0.010570191778242588


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 49 | MAE/CSI: 0.8623908562366291 | MAE: 0.10174943123872283 | CSI: 0.11798528532860705 | Loss: 0.01056947186589241

Model saved at ../models/baseline_fold2_bs128_epochs50_lr0.0005_adamw.ckpt


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Using native 16bit precision.

  | Name      | Type       | Params
-----------------------------------------
0 | criterion | L1Loss     | 0     
1 | encoder   | Encoder    | 666 K 
2 | decoder   | Decoder    | 664 K 
3 | out       | Sequential | 577   


Training fold3...


HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

Epoch 0 | MAE/CSI: 1000000000000.0 | MAE: 1.0 | CSI: 0.0 | Loss: 0.06584510207176208


HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 0 | MAE/CSI: 0.9761461365590198 | MAE: 0.11572746298736532 | CSI: 0.1185554689529748 | Loss: 0.016777969896793365


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 1 | MAE/CSI: 0.8674826352979895 | MAE: 0.10386791721571971 | CSI: 0.1197348661384703 | Loss: 0.013262536376714706


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 2 | MAE/CSI: 0.984027520372654 | MAE: 0.11627005658476088 | CSI: 0.11815732200228003 | Loss: 0.01209369394928217


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 3 | MAE/CSI: 0.9891841708089898 | MAE: 0.1166576234400434 | CSI: 0.11793316844490899 | Loss: 0.011719717644155025


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 4 | MAE/CSI: 1.0537721338059343 | MAE: 0.12340128672195953 | CSI: 0.11710433666072982 | Loss: 0.011907537467777729


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 5 | MAE/CSI: 1.2459166432193667 | MAE: 0.1432447097124254 | CSI: 0.11497134297928997 | Loss: 0.012449313886463642


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 6 | MAE/CSI: 0.9462066203023297 | MAE: 0.11192930780559647 | CSI: 0.11829267033545683 | Loss: 0.011184750124812126


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 7 | MAE/CSI: 0.9627440123754404 | MAE: 0.11371211533989613 | CSI: 0.11811251368716814 | Loss: 0.011118467897176743


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 8 | MAE/CSI: 0.8771097191461886 | MAE: 0.10448802418417177 | CSI: 0.11912765518663641 | Loss: 0.01101650483906269


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 9 | MAE/CSI: 0.9138557533172998 | MAE: 0.10851871947911014 | CSI: 0.11874819311939867 | Loss: 0.011204872280359268


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 10 | MAE/CSI: 0.8554829748464027 | MAE: 0.10216262305247656 | CSI: 0.11942098914354644 | Loss: 0.010996270924806595


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 11 | MAE/CSI: 1.097735005830804 | MAE: 0.12781954887218044 | CSI: 0.11643934846948278 | Loss: 0.011197962798178196


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 12 | MAE/CSI: 0.9800621222273554 | MAE: 0.1154949228741958 | CSI: 0.11784449194989208 | Loss: 0.010927666909992695


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 13 | MAE/CSI: 0.8232752663309416 | MAE: 0.09859700798387722 | CSI: 0.11976189740579397 | Loss: 0.010764451697468758


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 14 | MAE/CSI: 1.0497969881777247 | MAE: 0.12278117975350748 | CSI: 0.11695706992414377 | Loss: 0.010931246913969517


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 15 | MAE/CSI: 0.8431765659576419 | MAE: 0.10076738237345942 | CSI: 0.11950923026207351 | Loss: 0.010661198757588863


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 16 | MAE/CSI: 1.0157086375232767 | MAE: 0.11921556468490814 | CSI: 0.11737181341156055 | Loss: 0.010978718288242817


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 17 | MAE/CSI: 0.9975301019021314 | MAE: 0.11727773040849547 | CSI: 0.11756811166286404 | Loss: 0.01100278552621603


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 18 | MAE/CSI: 0.8262334319267731 | MAE: 0.09890706146810324 | CSI: 0.11970837495237409 | Loss: 0.010737092234194279


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 19 | MAE/CSI: 0.790807704258165 | MAE: 0.09503139291527789 | CSI: 0.12017003931901929 | Loss: 0.010724930092692375


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 20 | MAE/CSI: 0.8398138968407228 | MAE: 0.10037981551817689 | CSI: 0.11952626158599382 | Loss: 0.010706203989684582


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 21 | MAE/CSI: 0.8178549615586439 | MAE: 0.09797690101542517 | CSI: 0.11979740372044184 | Loss: 0.010585488751530647


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 22 | MAE/CSI: 0.8513987351271364 | MAE: 0.101620029455081 | CSI: 0.11935656615587412 | Loss: 0.010546199977397919


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 23 | MAE/CSI: 0.9827907270442443 | MAE: 0.11572746298736532 | CSI: 0.11775392237819983 | Loss: 0.010765822604298592


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 24 | MAE/CSI: 1.0323321113161896 | MAE: 0.1209208588481513 | CSI: 0.11713367967692959 | Loss: 0.01074812188744545


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 25 | MAE/CSI: 0.8066592825906277 | MAE: 0.09673668707852104 | CSI: 0.11992261065544246 | Loss: 0.010556642897427082


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 26 | MAE/CSI: 0.9649874075775624 | MAE: 0.11378962871095263 | CSI: 0.1179182524211764 | Loss: 0.010568722151219845


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 27 | MAE/CSI: 0.8237076721618036 | MAE: 0.09859700798387722 | CSI: 0.11969902832674571 | Loss: 0.010539213195443153


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 28 | MAE/CSI: 0.8324059296905256 | MAE: 0.0995271684365553 | CSI: 0.11956566488266776 | Loss: 0.01049722172319889


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 29 | MAE/CSI: 0.956361963256247 | MAE: 0.11285946825827456 | CSI: 0.11800915614946796 | Loss: 0.010548410937190056


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 30 | MAE/CSI: 0.9077731340303097 | MAE: 0.10766607239748857 | CSI: 0.1186046032432878 | Loss: 0.010499398224055767


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 31 | MAE/CSI: 0.8091561787903138 | MAE: 0.09696922719169057 | CSI: 0.11983993910279489 | Loss: 0.010456260293722153


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 32 | MAE/CSI: 0.9129124508798632 | MAE: 0.10820866599488412 | CSI: 0.11853126320018957 | Loss: 0.010469174012541771


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 33 | MAE/CSI: 0.9177270800913426 | MAE: 0.10875125959227967 | CSI: 0.11850065444351689 | Loss: 0.010467367246747017


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 34 | MAE/CSI: 0.8670352892262351 | MAE: 0.10324781024726766 | CSI: 0.11908143939393939 | Loss: 0.010418031364679337


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 35 | MAE/CSI: 0.8610727589358931 | MAE: 0.1026277032788156 | CSI: 0.1191858669466922 | Loss: 0.01043530460447073


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 36 | MAE/CSI: 0.7667947876794768 | MAE: 0.09231842492830013 | CSI: 0.12039521709180263 | Loss: 0.010462336242198944


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 37 | MAE/CSI: 0.8685409212526897 | MAE: 0.10340283698938067 | CSI: 0.11905350048374812 | Loss: 0.010403391905128956


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 38 | MAE/CSI: 0.7888173193755786 | MAE: 0.09472133943105186 | CSI: 0.12008019740900679 | Loss: 0.010409584268927574


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 39 | MAE/CSI: 0.7951825273769476 | MAE: 0.09541895977056042 | CSI: 0.11999629831470494 | Loss: 0.010402645915746689


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 40 | MAE/CSI: 0.8022224399679789 | MAE: 0.0961940934811255 | CSI: 0.11990950226244344 | Loss: 0.010389924049377441


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 41 | MAE/CSI: 0.7697754150328099 | MAE: 0.09262847841252617 | CSI: 0.12033182224689302 | Loss: 0.010405430570244789


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 42 | MAE/CSI: 0.8248249742946362 | MAE: 0.09867452135493372 | CSI: 0.11963086039979835 | Loss: 0.010377668775618076


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 43 | MAE/CSI: 0.8526745313404479 | MAE: 0.10169754282613751 | CSI: 0.1192688875852913 | Loss: 0.010379205457866192


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 44 | MAE/CSI: 0.8434967891742426 | MAE: 0.10068986900240291 | CSI: 0.11937196477076302 | Loss: 0.010375346057116985


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 45 | MAE/CSI: 0.8384023821780489 | MAE: 0.10014727540500737 | CSI: 0.11945013221932975 | Loss: 0.010371914133429527


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 46 | MAE/CSI: 0.8541217422244934 | MAE: 0.10185256956825052 | CSI: 0.11924830446550784 | Loss: 0.010385624133050442


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 47 | MAE/CSI: 0.8484406626048611 | MAE: 0.10123246259979847 | CSI: 0.11931590158367548 | Loss: 0.01037642452865839


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 48 | MAE/CSI: 0.8419922996367154 | MAE: 0.1005348422602899 | CSI: 0.1194011421515666 | Loss: 0.010372198186814785


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 49 | MAE/CSI: 0.8477266269374579 | MAE: 0.10115494922874196 | CSI: 0.11932496398435892 | Loss: 0.010375108569860458

Model saved at ../models/baseline_fold3_bs128_epochs50_lr0.0005_adamw.ckpt


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Using native 16bit precision.

  | Name      | Type       | Params
-----------------------------------------
0 | criterion | L1Loss     | 0     
1 | encoder   | Encoder    | 666 K 
2 | decoder   | Decoder    | 664 K 
3 | out       | Sequential | 577   


Training fold4...


HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

Epoch 0 | MAE/CSI: 1000000000000.0 | MAE: 1.0 | CSI: 0.0 | Loss: 0.051339007914066315


HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 0 | MAE/CSI: 1.2923509983887724 | MAE: 0.14612412270325684 | CSI: 0.1130684488069754 | Loss: 0.013486770913004875


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 1 | MAE/CSI: 1.2824421687627843 | MAE: 0.14502010882422522 | CSI: 0.11308120736769645 | Loss: 0.013099905103445053


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 2 | MAE/CSI: 1.0337804664341868 | MAE: 0.119785505874931 | CSI: 0.11587131868245944 | Loss: 0.013135443441569805


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 3 | MAE/CSI: 0.8676452444378151 | MAE: 0.10212128381042504 | CSI: 0.11769935289131468 | Loss: 0.012131770141422749


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 4 | MAE/CSI: 1.0104902871287074 | MAE: 0.11710432931156849 | CSI: 0.11588862436600766 | Loss: 0.011633886024355888


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 5 | MAE/CSI: 1.06292039288501 | MAE: 0.12254554057251005 | CSI: 0.11529136273209549 | Loss: 0.011597689241170883


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 6 | MAE/CSI: 0.9304383558461906 | MAE: 0.10866650895039823 | CSI: 0.11679065922711304 | Loss: 0.011226480826735497


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 7 | MAE/CSI: 0.8195911449757846 | MAE: 0.09683778881791656 | CSI: 0.118153776319726 | Loss: 0.011528192088007927


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 8 | MAE/CSI: 0.8785722230291139 | MAE: 0.10314643955524012 | CSI: 0.11740234536295317 | Loss: 0.011167199350893497


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 9 | MAE/CSI: 0.8861688488530748 | MAE: 0.10393502089740557 | CSI: 0.11728579833407306 | Loss: 0.011075042188167572


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 10 | MAE/CSI: 0.9453812543293786 | MAE: 0.11024367163472912 | CSI: 0.11661292322956716 | Loss: 0.011175395920872688


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 11 | MAE/CSI: 0.9979241244851905 | MAE: 0.11576374102988723 | CSI: 0.11600455203807158 | Loss: 0.011165978386998177


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 12 | MAE/CSI: 1.007878343679855 | MAE: 0.11671003864048576 | CSI: 0.11579774421321423 | Loss: 0.011008608154952526


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 13 | MAE/CSI: 1.1125746265559728 | MAE: 0.12751360302815234 | CSI: 0.11461128088258145 | Loss: 0.011036342941224575


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 14 | MAE/CSI: 0.7850852822097438 | MAE: 0.09305259837552243 | CSI: 0.11852546530082239 | Loss: 0.010939818806946278


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 15 | MAE/CSI: 1.0153115494115053 | MAE: 0.11749861998265121 | CSI: 0.11572666542574093 | Loss: 0.011188478209078312


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 16 | MAE/CSI: 0.986025980481453 | MAE: 0.11442315274820598 | CSI: 0.11604476455209614 | Loss: 0.011116521432995796


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 17 | MAE/CSI: 1.0617361140641164 | MAE: 0.12223010803564388 | CSI: 0.11512286943571073 | Loss: 0.010877163149416447


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 18 | MAE/CSI: 0.9572121558345978 | MAE: 0.11142654364797729 | CSI: 0.11640736378850802 | Loss: 0.011018428951501846


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 19 | MAE/CSI: 0.7521329035471026 | MAE: 0.08942512420156139 | CSI: 0.11889537577610972 | Loss: 0.010738796554505825


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 20 | MAE/CSI: 0.9892293528688089 | MAE: 0.11473858528507215 | CSI: 0.11598784948236315 | Loss: 0.01070215180516243


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 21 | MAE/CSI: 0.9864335429133184 | MAE: 0.11442315274820598 | CSI: 0.1159968186090711 | Loss: 0.010646497830748558


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 22 | MAE/CSI: 0.84469233657039 | MAE: 0.09944010724706254 | CSI: 0.117723463255229 | Loss: 0.010614980012178421


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 23 | MAE/CSI: 0.8534076363799895 | MAE: 0.10038640485766107 | CSI: 0.11763007568414757 | Loss: 0.010652757249772549


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 24 | MAE/CSI: 0.8727730631181053 | MAE: 0.10243671634729122 | CSI: 0.11736924599901007 | Loss: 0.010685691609978676


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 25 | MAE/CSI: 0.8287695273573423 | MAE: 0.09770522829429856 | CSI: 0.11789191695430426 | Loss: 0.010555294342339039


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 26 | MAE/CSI: 1.0085243594280549 | MAE: 0.11671003864048576 | CSI: 0.1157235693401246 | Loss: 0.010769852437078953


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 27 | MAE/CSI: 0.861976988570521 | MAE: 0.10125384433404305 | CSI: 0.11746699236263566 | Loss: 0.010523894801735878


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 28 | MAE/CSI: 0.87451386176962 | MAE: 0.10259443261572432 | CSI: 0.11731595930022783 | Loss: 0.010531861335039139


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 29 | MAE/CSI: 0.8629818743300961 | MAE: 0.1013327024682596 | CSI: 0.11742158842682274 | Loss: 0.010494248941540718


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 30 | MAE/CSI: 0.8834115504948823 | MAE: 0.10354073022632285 | CSI: 0.11720554272517321 | Loss: 0.010468833148479462


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 31 | MAE/CSI: 0.8848855029908551 | MAE: 0.10369844649475593 | CSI: 0.11718854715483198 | Loss: 0.01055654976516962


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 32 | MAE/CSI: 0.8179083911242114 | MAE: 0.0965223562810504 | CSI: 0.1180112067899963 | Loss: 0.010491106659173965


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 33 | MAE/CSI: 0.8508487946200639 | MAE: 0.10007097232079488 | CSI: 0.11761310934762445 | Loss: 0.01048244908452034


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 34 | MAE/CSI: 0.8400339456350763 | MAE: 0.09888810030754672 | CSI: 0.11771917173173999 | Loss: 0.010445468127727509


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 35 | MAE/CSI: 0.8725101451576892 | MAE: 0.10235785821307468 | CSI: 0.11731423271153252 | Loss: 0.01043706201016903


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 36 | MAE/CSI: 0.9569020424454356 | MAE: 0.1112688273795442 | CSI: 0.11628026949783844 | Loss: 0.010467946529388428


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 37 | MAE/CSI: 0.8362668945479855 | MAE: 0.09849380963646401 | CSI: 0.11777796093299267 | Loss: 0.010502946563065052


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 38 | MAE/CSI: 0.8416683758276093 | MAE: 0.09904581657597981 | CSI: 0.11767795894403989 | Loss: 0.010454561561346054


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 39 | MAE/CSI: 0.8865609501271944 | MAE: 0.10385616276318903 | CSI: 0.11714497773379515 | Loss: 0.010411867871880531


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 40 | MAE/CSI: 0.8445140682777522 | MAE: 0.099361249112846 | CSI: 0.11765493654194824 | Loss: 0.010427672415971756


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 41 | MAE/CSI: 0.8718152715712825 | MAE: 0.10227900007885814 | CSI: 0.11731728430685519 | Loss: 0.010411477647721767


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 42 | MAE/CSI: 0.8327265628359038 | MAE: 0.09809951896538129 | CSI: 0.11780519962094681 | Loss: 0.010417300276458263


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 43 | MAE/CSI: 0.8829431978104287 | MAE: 0.1034618720921063 | CSI: 0.11717840099770158 | Loss: 0.010404979810118675


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 44 | MAE/CSI: 0.8866157844264348 | MAE: 0.10385616276318903 | CSI: 0.11713773269837344 | Loss: 0.010403498075902462


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 45 | MAE/CSI: 0.870380796129511 | MAE: 0.10212128381042504 | CSI: 0.11732943128303948 | Loss: 0.010401197709143162


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 46 | MAE/CSI: 0.8607582781995061 | MAE: 0.10109612806560997 | CSI: 0.11745007933727615 | Loss: 0.01040069479495287


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 47 | MAE/CSI: 0.8755787668255379 | MAE: 0.10267329074994086 | CSI: 0.11726334013479256 | Loss: 0.010401761159300804


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 48 | MAE/CSI: 0.8866706187256753 | MAE: 0.10385616276318903 | CSI: 0.11713048855905998 | Loss: 0.010412335395812988


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 49 | MAE/CSI: 0.8792732161385125 | MAE: 0.10306758142102358 | CSI: 0.11721906175282382 | Loss: 0.010402982123196125

Model saved at ../models/baseline_fold4_bs128_epochs50_lr0.0005_adamw.ckpt


<IPython.core.display.Javascript object>

## Inference

In [21]:
def inference(checkpoints):
    datamodule = NowcastingDataModule()
    datamodule.setup("test")

    test_paths = datamodule.test_dataset.paths
    test_filenames = [path.name for path in test_paths]
    final_preds = np.zeros((len(datamodule.test_dataset), 14400))

    for checkpoint in checkpoints:
        print("Inference from", checkpoint)
        model = Baseline.load_from_checkpoint(str(checkpoint))
        model.cuda()
        model.eval()
        preds = []
        with torch.no_grad():
            for batch in tqdm(datamodule.test_dataloader()):
                batch = batch.cuda()
                imgs = model(batch)
                imgs = imgs.detach().cpu().numpy()
                imgs = imgs[:, 0, 4:124, 4:124]
                imgs = args["rng"] * imgs
                imgs = imgs.clip(0, 255)
                imgs = imgs.round()
                preds.append(imgs)

        preds = np.concatenate(preds)
        preds = preds.astype(np.uint8)
        preds = preds.reshape(-1, 14400)
        final_preds += preds / len(checkpoints)

        del model
        gc.collect()
        torch.cuda.empty_cache()

    final_preds = final_preds.round()
    final_preds = final_preds.astype(np.uint8)

    subm = pd.DataFrame()
    subm["file_name"] = test_filenames
    for i in tqdm(range(14400)):
        subm[str(i)] = final_preds[:, i]

    return subm

<IPython.core.display.Javascript object>

In [22]:
checkpoints = [
    args["model_dir"]
    / f"baseline_fold{fold}_bs{args['batch_size']}_epochs{args['max_epochs']}_lr{args['lr']}_{args['optimizer']}.ckpt"
    for fold in range(5)
]

<IPython.core.display.Javascript object>

In [23]:
subm = inference(checkpoints)

Inference from ../models/baseline_fold0_bs128_epochs50_lr0.0005_adamw.ckpt


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=11.0), HTML(value='')))


Inference from ../models/baseline_fold1_bs128_epochs50_lr0.0005_adamw.ckpt


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=11.0), HTML(value='')))


Inference from ../models/baseline_fold2_bs128_epochs50_lr0.0005_adamw.ckpt


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=11.0), HTML(value='')))


Inference from ../models/baseline_fold3_bs128_epochs50_lr0.0005_adamw.ckpt


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=11.0), HTML(value='')))


Inference from ../models/baseline_fold4_bs128_epochs50_lr0.0005_adamw.ckpt


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=11.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=14400.0), HTML(value='')))




<IPython.core.display.Javascript object>

In [24]:
output_csv = (
    args["output_dir"]
    / f"baseline_bs{args['batch_size']}_epochs{args['max_epochs']}_lr{args['lr']}_{args['optimizer']}.csv"
)
subm.to_csv(output_csv, index=False)
subm.head()

Unnamed: 0,file_name,0,1,2,3,4,5,6,7,8,...,14390,14391,14392,14393,14394,14395,14396,14397,14398,14399
0,test_00000.npy,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
1,test_00001.npy,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
2,test_00002.npy,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
3,test_00003.npy,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
4,test_00004.npy,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0


<IPython.core.display.Javascript object>