In [39]:
%load_ext autoreload
%autoreload 2

from dutils.w4c_dataloader import RainData
from dutils.data_utils import load_config
from meteopress.models.unet_attention.model import UNet_Attention
from torch.utils.data import DataLoader
import pytorch_lightning as pl
import torch

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [18]:
params = load_config("models/configurations/config_smaat.yaml")

In [20]:
data = RainData("training", **params["dataset"])

In [21]:
len(data)

20273

In [22]:
x = data[0]

In [23]:
x[0].shape

(11, 4, 252, 252)

In [24]:
x[1].shape

(1, 32, 252, 252)

In [25]:
class DataModule(pl.LightningDataModule):
    def __init__(self, params: dict, mode: str):
        super().__init__()
        self.params = params
        if mode in ['train']:
            self.train_ds = RainData('training', **self.params)
            self.val_ds = RainData('validation', **self.params)
        if mode in ['val']:
            self.val_ds = RainData('validation', **self.params)
        if mode in ['predict']:
            self.test_ds = RainData('test', **self.params)

    def __load_dataloader(self, dataset, shuffle=True, pin=True):
        dl = DataLoader(dataset,
                        batch_size=self.params['train']['batch_size'],
                        num_workers=self.params['train']['n_workers'],
                        shuffle=shuffle, pin_memory=pin, prefetch_factor=2,
                        persistent_workers=False)
        return dl

    def train_dataloader(self) -> DataLoader:
        return self.__load_dataloader(self.train_ds, shuffle=True, pin=True)

    def val_dataloader(self) -> DataLoader:
        return self.__load_dataloader(self.val_ds, shuffle=False, pin=True)

    def test_dataloader(self) -> DataLoader:
        return self.__load_dataloader(self.test_ds, shuffle=False, pin=True)

In [84]:
from turtle import forward
import pytorch_lightning as pl
import wandb
import torch
import torch.nn as nn
import torch.nn.functional as F
from meteopress.models.unet_attention.model import UNet_Attention
from dutils.evaluate import *

class ModelBase(pl.LightningModule):
    def __init__(self, params: dict) -> None:
        super().__init__()
        self.save_hyperparameters()
        self.params = params
        self.loss = params['train']['loss']

        losses = {
            'MSE': F.mse_loss,
            'SmoothL1Loss': nn.SmoothL1Loss(),
            'L1': nn.L1Loss(),
            'BCELoss': nn.BCELoss(),
            'CrossEntropy': nn.CrossEntropyLoss(),
        }

        self.loss_fn = losses[self.loss]

    def retrieve_only_valid_pixels(self, x, mask):
        """ we asume 1s in mask are invalid pixels """
        return x[~mask]

    def get_target_mask(self, metadata):
        return metadata['target']['mask']

    def calculate_loss(self, pred, target, mask=None):
        if mask is not None:
            pred = self.retrieve_only_valid_pixels(pred, mask)
            target = self.retrieve_only_valid_pixels(target, mask)

        return self.loss_fn(pred, target)

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(),
                                      lr=float(self.params['train']["lr"]),
                                      weight_decay=float(self.params['train']["weight_decay"]))
        return optimizer


class SmaAt(ModelBase):
    def __init__(self, params: dict):
        super().__init__(params)

        self.model = UNet_Attention(
            n_channels=11 * 4,
            n_classes=32
        )

    def forward(self, x):
        return self.model(x.swapaxes(1, 2))

    def training_step(self, batch, batch_idx):
        x, y, meta = batch
        # flatten ?
        # x = x.swapaxes(1, 2)
        pred = self.model(x)
        mask = self.get_target_mask(meta)
        loss = self.calculate_loss(pred, y, mask)

        self.log('train_loss', loss,
                 batch_size=self.params['train']['batch_size'], sync_dist=True)

        return loss

    def validation_step(self, batch, batch_idx):
        x, y, meta = batch
        x = x.swapaxes(1, 2)
        pred = self.model(x)
        mask = self.get_target_mask(meta)
        loss = self.calculate_loss(pred, y, mask)

        recall, precision, f1, acc, csi = recall_precision_f1_acc(y, pred)
        iou = iou_class(pred, y)

        self.log_dict(
            {
                'val_loss': loss,
                'val_recall': recall,
                'val_precision': precision,
                'val_f1': f1,
                'val_acc': acc,
                'val_csi': csi,
                'val_iou': iou
            },
            batch_size=self.params['train']['batch_size'], sync_dist=True)

        return loss

    def validation_epoch_end(self, outputs):
        avg_loss = torch.stack([x for x in outputs]).mean()
        self.log(f'val_loss_epoch', avg_loss, prog_bar=True,
                 batch_size=self.bs, sync_dist=True)

    def test_step(self, batch, batch_idx):
        x, y, meta = batch
        x = x.swapaxes(1, 2)
        pred = self.model(x)
        mask = self.get_target_mask(meta)
        loss = self.calculate_loss(pred, y, mask)

        recall, precision, f1, acc, csi = recall_precision_f1_acc(y, pred)
        iou = iou_class(pred, y)

        self.log_dict(
            {
                'test_loss': loss,
                'test_recall': recall,
                'test_precision': precision,
                'test_f1': f1,
                'test_acc': acc,
                'test_csi': csi,
                'test_iou': iou
            },
            batch_size=self.params['train']['batch_size'], sync_dist=True)


In [92]:
smaat = SmaAt(params)

In [130]:
d = [torch.tensor(data[i][0]) for i in range(16)]
x = torch.stack(d)
x = x.swapaxes(1, 2)

In [146]:
d = [torch.tensor(data[i][0]) for i in range(16)]
x = torch.stack(d)
x.shape

torch.Size([16, 11, 4, 252, 252])

In [113]:
x.shape

torch.Size([16, 4, 11, 252, 252])

In [123]:
x.squeeze(0).shape

torch.Size([16, 11, 252, 252])

In [131]:
x.flatten(1, 2).shape

torch.Size([16, 44, 252, 252])

In [114]:
x = x[:, 0]

In [119]:
x.shape

torch.Size([16, 11, 252, 252])

In [116]:
x.unsqueeze(1).shape

torch.Size([16, 1, 11, 252, 252])

In [135]:
res = smaat(x.flatten(1, 2).swapaxes(1, 2))

In [137]:
res.shape

torch.Size([16, 32, 252, 252])

In [153]:
res.unsqueeze(1).shape

torch.Size([16, 1, 32, 252, 252])

In [152]:
res.reshape((16, 32, 11, 252, 252))

RuntimeError: shape '[16, 32, 11, 252, 252]' is invalid for input of size 32514048