In [None]:
%reload_ext nb_black

In [None]:
import gc
from pathlib import Path

import cv2
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
from sklearn import model_selection

import torch
import torch.nn as nn
import torch.nn.functional as F

import pytorch_lightning as pl

In [None]:
# Root data directory; the data module reads `*.npy` files from its
# `train/` and `test/` subfolders.
PATH = Path("../input")

# 🔥 Baseline ⚡️

## Dataset

In [None]:
class NowcastingDataset(torch.utils.data.Dataset):
    """Dataset over ``.npy`` files that each hold one channel-last array.

    Channels 0-3 of a file form the input ``x`` and channel 4 is the target
    ``y``. Both are returned as float32 tensors in channel-first layout.
    With ``test=True`` only ``x`` is returned.
    """

    def __init__(self, paths, test=False):
        self.paths = paths
        self.test = test

    def __len__(self):
        return len(self.paths)

    @staticmethod
    def _as_float_tensor(array):
        """Copy a numpy array into a float32 torch tensor."""
        return torch.tensor(array.astype(np.float32), dtype=torch.float)

    def __getitem__(self, idx):
        frames = np.load(self.paths[idx])

        # Channel-last input slice -> channel-first tensor.
        inputs = self._as_float_tensor(frames[:, :, :4]).permute(2, 0, 1)
        if self.test:
            return inputs

        # 2-D target -> add a trailing channel axis -> move it to the front.
        target = self._as_float_tensor(frames[:, :, 4])
        target = target.unsqueeze(-1).permute(2, 0, 1)
        return inputs, target

In [None]:
class NowcastingDataModule(pl.LightningDataModule):
    """Lightning data module that serves the train/validation/test loaders.

    Args:
        batch_size: Batch size of the training loader; the validation and
            test loaders use twice this value.
        test: Stored as ``self.test``; not read anywhere in this class.
        num_workers: Number of worker processes used by every DataLoader.
        seed: ``random_state`` of the train/validation split, so repeated
            ``setup`` calls yield the same split. Pass ``None`` for an
            unseeded (different every call) split.
    """

    # Stage names that build the train/validation datasets. "train" is what
    # this notebook uses; "fit"/"validate" are the names Lightning passes.
    _TRAIN_STAGES = ("train", "fit", "validate")

    def __init__(self, batch_size, test=False, num_workers=4, seed=42):
        super().__init__()
        self.test = test
        self.batch_size = batch_size
        # Honour the caller's value (it used to be overwritten with a literal 4).
        self.num_workers = num_workers
        self.seed = seed

    def setup(self, stage="train"):
        """Create the datasets for ``stage``.

        For ``"train"``, ``"fit"`` or ``"validate"`` the files under
        ``PATH / "train"`` are split 90/10 into ``train_dataset`` and
        ``val_dataset``. Any other value (e.g. ``"test"``) builds
        ``test_dataset`` from ``PATH / "test"``.
        """
        if stage in self._TRAIN_STAGES:
            # Sort first: glob order depends on the filesystem, which would
            # make even a seeded split differ between machines.
            paths = sorted((PATH / "train").glob("*.npy"))
            train_paths, val_paths = model_selection.train_test_split(
                paths, test_size=0.1, shuffle=True, random_state=self.seed
            )
            self.train_dataset = NowcastingDataset(train_paths)
            self.val_dataset = NowcastingDataset(val_paths)
        else:
            test_paths = sorted((PATH / "test").glob("*.npy"))
            self.test_dataset = NowcastingDataset(test_paths, test=True)

    def train_dataloader(self):
        """Shuffled loader over the training split."""
        return torch.utils.data.DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            pin_memory=True,
            num_workers=self.num_workers,
        )

    def val_dataloader(self):
        """Unshuffled loader over the validation split (double batch size)."""
        return torch.utils.data.DataLoader(
            self.val_dataset,
            batch_size=2 * self.batch_size,
            pin_memory=True,
            num_workers=self.num_workers,
        )

    def test_dataloader(self):
        """Unshuffled loader over the test files (double batch size)."""
        return torch.utils.data.DataLoader(
            self.test_dataset,
            batch_size=2 * self.batch_size,
            pin_memory=True,
            num_workers=self.num_workers,
        )

In [None]:
def visualize(x, y=None, test=False):
    """Plot the four input channels of one sample, plus the target if given.

    Args:
        x: Channel-last array, read as ``x[:, :, i]`` for ``i`` in 0..3.
        y: Channel-last target array; ``y[:, :, 0]`` is drawn in a fifth
            panel. Ignored when ``test`` is true.
        test: When true, draw only the four input panels.

    If ``y`` is ``None`` the target panel is skipped even when ``test`` is
    false, instead of failing after the figure has been created.
    """
    # `plt.cm.get_cmap` was deprecated in Matplotlib 3.7 and removed in 3.9;
    # `plt.get_cmap` works across versions.
    cmap = plt.get_cmap("RdBu").reversed()

    n_inputs = 4
    show_target = (not test) and y is not None
    n_panels = n_inputs + 1 if show_target else n_inputs

    fig, axes = plt.subplots(1, n_panels, figsize=(10, 10))
    for i in range(n_inputs):
        axes[i].imshow(x[:, :, i], cmap=cmap)
    if show_target:
        axes[-1].imshow(y[:, :, 0], cmap=cmap)
    plt.show()

In [None]:
# Sanity check: pull a single training batch and plot its first sample.
datamodule = NowcastingDataModule(batch_size=32)
datamodule.setup()
for batch in datamodule.train_dataloader():
    xs, ys = batch
    # Channel-first tensors -> channel-last numpy arrays, as `visualize` expects.
    x = xs[0].permute(1, 2, 0).numpy()
    y = ys[0].permute(1, 2, 0).numpy()
    visualize(x, y)
    break  # one batch is enough

## Model

In [None]:
class Block(nn.Module):
    """3x3 convolution (no bias) followed by batch norm and in-place ReLU.

    With ``kernel_size=3`` and ``padding=1`` the spatial size is preserved;
    only the channel count changes from ``in_ch`` to ``out_ch``.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        conv = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=False)
        norm = nn.BatchNorm2d(out_ch)
        act = nn.ReLU(inplace=True)
        # Kept as a Sequential called `net` so parameter names stay
        # `net.0.*` / `net.1.*`.
        self.net = nn.Sequential(conv, norm, act)

    def forward(self, x):
        """Apply conv -> batch norm -> ReLU to ``x``."""
        out = self.net(x)
        return out

In [None]:
class Encoder(nn.Module):
    """Contracting path: ``Block`` + 2x2 max-pool stages, then a bottleneck conv.

    Args:
        chs: Channel sizes of the stages; stage ``i`` maps ``chs[i]`` to
            ``chs[i + 1]`` channels.
        bottleneck_ch: Output channels of the final 3x3 convolution.
    """

    def __init__(self, chs=(4, 64, 128), bottleneck_ch=512):
        super().__init__()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.blocks = nn.ModuleList(
            [Block(chs[i], chs[i + 1]) for i in range(len(chs) - 1)]
        )
        # The input width follows the last stage instead of a hard-coded 128,
        # so non-default `chs` no longer cause a channel mismatch here.
        self.conv = nn.Conv2d(chs[-1], bottleneck_ch, kernel_size=3, padding=1)

    def forward(self, x):
        """Return each stage's pre-pooling features, then the bottleneck output.

        The returned list has ``len(self.blocks) + 1`` tensors, ordered from
        the first stage to the bottleneck.
        """
        ftrs = []
        for block in self.blocks:
            x = block(x)
            ftrs.append(x)  # keep the full-resolution feature as a skip connection
            x = self.pool(x)
        ftrs.append(self.conv(x))
        return ftrs

In [None]:
class Decoder(nn.Module):
    """Expanding path: transposed-conv upsampling, skip concatenation, ``Block``.

    Stage ``i`` maps ``chs[i]`` to ``chs[i + 1]`` channels with a stride-2
    transposed convolution, concatenates a skip feature on the channel axis,
    and feeds the result to a ``Block(2 * chs[i + 1], chs[i + 1])``.
    """

    def __init__(self, chs=[512, 128, 64]):
        super().__init__()
        n_stages = len(chs) - 1
        # Build all upsamplers before the blocks, matching the registration
        # (and therefore initialisation) order of the two module lists.
        upsamplers = [
            nn.ConvTranspose2d(chs[k], chs[k + 1], kernel_size=2, stride=2)
            for k in range(n_stages)
        ]
        fusers = [Block(2 * chs[k + 1], chs[k + 1]) for k in range(n_stages)]
        self.tr_convs = nn.ModuleList(upsamplers)
        self.blocks = nn.ModuleList(fusers)

    def forward(self, x, ftrs):
        """Run one decoder stage per entry of ``ftrs``, in the order given."""
        out = x
        for stage, skip in enumerate(ftrs):
            upsampled = self.tr_convs[stage](out)
            merged = torch.cat([skip, upsampled], dim=1)
            out = self.blocks[stage](merged)
        return out

In [None]:
class Baseline(pl.LightningModule):
    """Encoder-decoder with skip connections, trained with L1 loss and Adam.

    Args:
        lr: Learning rate passed to Adam.
        enc_chs: Channel sizes passed to ``Encoder``.
        dec_chs: Channel sizes passed to ``Decoder``; its last entry is also
            the input width of the single-channel output head.
    """

    def __init__(self, lr=1e-3, enc_chs=[4, 64, 128], dec_chs=[512, 128, 64]):
        super().__init__()
        # Store the constructor arguments in checkpoints so that
        # `load_from_checkpoint` rebuilds the same architecture even when
        # non-default values were used.
        self.save_hyperparameters()
        self.lr = lr
        self.criterion = nn.L1Loss()
        self.encoder = Encoder(enc_chs)
        self.decoder = Decoder(dec_chs)
        self.out = nn.Sequential(
            # Follow the decoder's final width instead of a hard-coded 64.
            nn.Conv2d(dec_chs[-1], 1, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Encode ``x``, decode with the skip features, apply the output head."""
        ftrs = self.encoder(x)
        # Deepest feature first: it seeds the decoder, the rest are the skips.
        ftrs = ftrs[::-1]
        x = self.decoder(ftrs[0], ftrs[1:])
        out = self.out(x)
        return out

    def shared_step(self, batch, batch_idx):
        """Compute the L1 loss of the prediction for an ``(x, y)`` batch."""
        x, y = batch
        y_hat = self(x)
        loss = self.criterion(y_hat, y)
        return loss

    def training_step(self, batch, batch_idx):
        """Log and return the training loss."""
        loss = self.shared_step(batch, batch_idx)
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        """Log and return the validation loss."""
        loss = self.shared_step(batch, batch_idx)
        self.log("val_loss", loss)
        return {"loss": loss}

    def validation_epoch_end(self, outputs):
        """Print the mean of the per-batch validation losses for this epoch."""
        avg_loss = torch.stack([x["loss"] for x in outputs]).mean()
        print(f"Epoch {self.current_epoch} | MAE: {avg_loss}")

    def configure_optimizers(self):
        """Adam over all parameters with learning rate ``self.lr``."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer

## Train

In [None]:
# Rebuild the data module with the training batch size; `setup()` with its
# default stage creates the train/validation datasets.
datamodule = NowcastingDataModule(batch_size=256)
datamodule.setup()

In [None]:
# Instantiate the model with its default learning rate and channel sizes.
model = Baseline()

In [None]:
# Trainer configuration: one GPU, at most 10 epochs, 16-bit precision,
# progress-bar refresh rate of 50, benchmark mode enabled.
trainer = pl.Trainer(
    gpus=1, max_epochs=10, precision=16, progress_bar_refresh_rate=50, benchmark=True
)

In [None]:
# lr_finder = trainer.tuner.lr_find(model, datamodule)
# fig = lr_finder.plot(suggest=True)

In [None]:
# model.lr = lr_finder.suggestion()
# model.lr

In [None]:
# Train the model using the loaders provided by the data module.
trainer.fit(model, datamodule)

In [None]:
# Save the trained model; the Inference section reloads this exact file name.
trainer.save_checkpoint("baseline_bs256_epoch10.ckpt")

## Inference

In [None]:
# Reload the saved checkpoint and switch the data module to the test files.
model = Baseline.load_from_checkpoint("baseline_bs256_epoch10.ckpt")
datamodule = NowcastingDataModule(batch_size=128)
# Any stage other than "train" builds `test_dataset`.
datamodule.setup("test")

In [None]:
# Predict every test sample and collect the results as uint8 rows,
# one flattened prediction per sample.
# Fall back to the CPU when no GPU is visible instead of crashing on "cuda".
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()

# Build the loader once; tqdm reads the total from its length.
test_loader = datamodule.test_dataloader()

preds = []
with torch.no_grad():
    for batch in tqdm(test_loader):
        imgs = model(batch.to(device)).cpu().numpy()
        # Round to the nearest integer and clamp into the uint8 range
        # before the cast below.
        imgs = np.clip(np.round(imgs), 0, 255)
        preds.append(imgs)

preds = np.concatenate(preds).astype(np.uint8)
preds = preds.reshape(len(preds), -1)

In [None]:
# File names in dataset order; the test loader does not shuffle, so they
# line up row-for-row with `preds`.
test_paths = datamodule.test_dataset.paths
test_filenames = [path.name for path in test_paths]

In [None]:
# Build the submission frame in one shot: one row per test file, one column
# per value of the flattened prediction (named "0", "1", ...). Adding the
# columns one at a time fragments the frame and is much slower.
value_columns = [str(i) for i in range(preds.shape[1])]
subm = pd.DataFrame(preds, columns=value_columns)
subm.insert(0, "file_name", test_filenames)

In [None]:
# Write the submission CSV (without the index column) and preview the first rows.
subm.to_csv("baseline_epoch10.csv", index=False)
subm.head()