In [1]:
%load_ext autoreload
%autoreload 2

import pytorch_lightning as pl
import torch
import torchvision.transforms as T
from torch import nn
from torch.utils.data import DataLoader

from dutils.data_utils import load_config
from dutils.w4c_dataloader import RainData
from meteopress.models.unet_attention.model import UNet_Attention

In [2]:
# Load the experiment config and build the training split of RainData
# from the config's 'dataset' section.
params = load_config("models/configurations/config_smaat.yaml")
dataset_section = params["dataset"]
data = RainData("training", **dataset_section)

In [3]:
class DataModule(pl.LightningDataModule):
    """LightningDataModule that builds ``RainData`` splits and their loaders.

    Args:
        params: Full config dict as returned by ``load_config``. Its
            ``'dataset'`` section is unpacked into ``RainData`` (same as the
            standalone ``RainData("training", **params["dataset"])`` call) and
            its ``'train'`` section supplies ``batch_size`` and ``n_workers``.
        mode: ``'train'`` builds the training and validation splits,
            ``'val'`` only the validation split, ``'predict'`` only the test
            split.

    Raises:
        ValueError: if ``mode`` is not one of the values above.
    """

    _MODES = ('train', 'val', 'predict')

    def __init__(self, params: dict, mode: str):
        super().__init__()
        if mode not in self._MODES:
            # Fail fast instead of leaving the module without any dataset,
            # which would otherwise surface later as an AttributeError.
            raise ValueError(f"mode must be one of {self._MODES}, got {mode!r}")
        self.params = params
        # Only the dataset section goes to RainData; unpacking the whole
        # config would also pass 'train' (and any other section) as kwargs.
        dataset_params = self.params['dataset']
        if mode == 'train':
            self.train_ds = RainData('training', **dataset_params)
        if mode in ('train', 'val'):
            self.val_ds = RainData('validation', **dataset_params)
        if mode == 'predict':
            self.test_ds = RainData('test', **dataset_params)

    def __load_dataloader(self, dataset, shuffle=True, pin=True):
        """Wrap ``dataset`` in a DataLoader configured from ``params['train']``."""
        train_cfg = self.params['train']
        n_workers = train_cfg['n_workers']
        # prefetch_factor only applies to worker processes; PyTorch 2.x raises
        # ValueError when it is passed together with num_workers=0.
        extra = {'prefetch_factor': 2} if n_workers > 0 else {}
        return DataLoader(dataset,
                          batch_size=train_cfg['batch_size'],
                          num_workers=n_workers,
                          shuffle=shuffle, pin_memory=pin,
                          persistent_workers=False, **extra)

    def train_dataloader(self) -> DataLoader:
        """Shuffled loader over the training split (mode 'train')."""
        return self.__load_dataloader(self.train_ds, shuffle=True, pin=True)

    def val_dataloader(self) -> DataLoader:
        """Ordered loader over the validation split (modes 'train' and 'val')."""
        return self.__load_dataloader(self.val_ds, shuffle=False, pin=True)

    def test_dataloader(self) -> DataLoader:
        """Ordered loader over the test split (mode 'predict')."""
        return self.__load_dataloader(self.test_ds, shuffle=False, pin=True)

    def predict_dataloader(self) -> DataLoader:
        """Loader used by ``Trainer.predict``; serves the test split built in 'predict' mode."""
        return self.__load_dataloader(self.test_ds, shuffle=False, pin=True)


In [3]:
# Stack element 0 of the first 16 training samples into one batch tensor
# (this is the `x` fed to RotUNet further down).
d = list(torch.tensor(data[idx][0]) for idx in range(16))
x = torch.stack(d, dim=0)

In [9]:
# Stack element 1 of the same 16 samples into one tensor.
ys = torch.stack([torch.tensor(data[idx][1]) for idx in range(16)])

In [6]:
# Shape of the freshly stacked batch; dims 1 and 2 are flattened into
# 44 channels before RotUNet below.
x.shape

torch.Size([16, 11, 4, 252, 252])

In [4]:
from meteopress.models.equivariant.model import *

In [5]:
# NOTE(review): magic positional hyper-parameters. 44 matches the flattened
# 11 x 4 input channels and 32 the output channels seen in res.shape; the
# meaning of 3 and 8 depends on RotUNet's signature (not visible here). The
# constants cell below defines N = 8 and kernel_size = 5 — confirm whether
# 3 is meant to be the kernel size, and pass named constants instead.
rot_unet = RotUNet(44, 32, 3, 8)

  full_mask[mask] = norms.to(torch.uint8)


In [9]:
# RotUNet hyper-parameters. Note that the RotUNet cell above hard-codes its
# arguments rather than using these names.
input_frames, output_frames, N, kernel_size = 44, 32, 8, 5

In [6]:
# Resize each frame to 256x256 (the raw frames are 252x252, see x.shape).
resize = T.Resize((256, 256))

In [50]:
# Check: merging dims 1 and 2 gives 44 channels, resized to 256x256.
resize(x.flatten(1, 2)).shape

torch.Size([16, 44, 256, 256])

In [7]:
# Forward the flattened, resized batch through RotUNet.
# NOTE(review): requires x to be the 5-D stacked batch; a later cell rebinds
# x = x[:, 0], after which re-running this cell no longer sees 44 channels.
res = rot_unet(resize(x.flatten(1, 2)))

In [9]:
# RotUNet output shape.
res.shape

torch.Size([16, 32, 256, 256])

In [10]:
image_size = 256
# Centre crop keeping 2/12 of the frame side; integer arithmetic gives
# 256 * 2 // 12 == 42 px without a float round-trip.
crop_size = image_size * 2 // 12
radar_crop = T.CenterCrop(size=(crop_size, crop_size))

In [15]:
# Re-check RotUNet output shape before cropping.
res.shape

torch.Size([16, 32, 256, 256])

In [22]:
# Centre-crop the RotUNet output with radar_crop.
# NOTE(review): the name `crop` is later rebound to a T.CenterCrop transform,
# so this tensor is lost on a top-to-bottom run; use a distinct name.
crop = radar_crop(res)

In [19]:
from models.models import SRCNN

In [20]:
# Instantiate the SRCNN imported from models.models.
# NOTE(review): a later cell redefines SRCNN locally, shadowing this import.
srcnn = SRCNN(32)

In [24]:
# Upsample the cropped maps back to full resolution (252x252 per the output).
srcnn(crop).shape

torch.Size([16, 32, 252, 252])

In [146]:
# NOTE(review): duplicate of the earlier stacking cell; re-running it restores
# the 5-D batch after `x = x[:, 0]` has reduced it. Prefer one stacking cell
# plus non-mutating names downstream.
d = [torch.tensor(data[i][0]) for i in range(16)]
x = torch.stack(d)
x.shape


torch.Size([16, 11, 4, 252, 252])

In [113]:
# NOTE(review): the recorded (16, 4, 11, ...) output disagrees with the
# (16, 11, 4, ...) stacked batch; it presumably comes from an earlier run of
# the commented-out swapaxes(1, 2) — hidden kernel state.
x.shape


torch.Size([16, 4, 11, 252, 252])

In [123]:
# NOTE(review): dim 0 has size 16, so squeeze(0) is a no-op; this just
# shows the current x.shape.
x.squeeze(0).shape


torch.Size([16, 11, 252, 252])

In [23]:
# Merge dims 1 and 2 (11 x 4 -> 44 channels), as done before RotUNet.
x.flatten(1, 2).shape


torch.Size([16, 44, 252, 252])

In [24]:
# Keep index 0 along dim 1, dropping that axis.
# NOTE(review): rebinds x in place, so the cell is not idempotent — each
# re-run drops another axis. With the (16, 11, 4, ...) stacked batch this
# yields (16, 4, 252, 252); the recorded (16, 11, 252, 252) implies x had
# been swapped to (16, 4, 11, ...) first. Use a new name instead.
x = x[:, 0]


In [25]:
# Shape after selecting index 0 along dim 1.
x.shape


torch.Size([16, 11, 252, 252])

In [26]:
# Insert a singleton axis at dim 1, matching the (16, 1, ...) layout of ys.
x.unsqueeze(1).shape


torch.Size([16, 1, 11, 252, 252])

In [10]:
from matplotlib import pyplot as plt
from utils.visualization import animate


In [10]:
# Element 2 of the first 16 samples; presumably per-sample metadata —
# TODO confirm against RainData.__getitem__.
meta = [data[i][2] for i in range(16)]

In [11]:
# Shape of the stacked element-1 tensor.
ys.shape

torch.Size([16, 1, 32, 252, 252])

In [12]:
from torchmetrics import Dice

In [13]:
# Micro-averaged Dice metric.
# NOTE(review): not used in the cells shown; remove or wire into evaluation.
dice = Dice(average='micro')

In [17]:
# Add a singleton axis at dim 1; the result depends on the current rank of x,
# which earlier cells mutate.
xs = x.unsqueeze(1)

In [26]:
# Animate x[0][9], with colour limits from that slice's own min/max.
# NOTE(review): what x[0][9] contains depends on x's current rank (see the
# x = x[:, 0] cell) — confirm which slice is intended.
animate(x[0][9], vmin=x[0][9].min(), vmax=x[0][9].max())


In [21]:
# Animate the 32 frames of ys[0][0], with colour limits from that slice's
# own min/max.
animate(ys[0][0], vmin=ys[0][0].min(), vmax=ys[0][0].max())


In [None]:
# NOTE(review): exact duplicate of the previous cell (never executed); remove.
animate(ys[0][0], vmin=ys[0][0].min(), vmax=ys[0][0].max())


In [62]:
# Centre crop keeping 2/12 of the 252-px frame side. Integer arithmetic
# gives exactly 42 (the float expression produced 42.0 and was never used;
# the transform hard-coded 42 separately).
crop_size = 252 * 2 // 12
crop = T.CenterCrop((crop_size, crop_size))


In [69]:
class SRCNN(nn.Module):
    """SRCNN-style upsampler: bilinear upsample, 9x9/1x1 convs, bilinear
    upsample, 5x5/1x1 convs.

    With the default ``scale_factors=(2, 3)`` an input of shape
    ``(B, channels, H, W)`` maps to ``(B, channels, 6H, 6W)``, e.g. a 42x42
    centre crop back to 252x252.

    Args:
        channels: Number of input and output channels.
        scale_factors: Bilinear upsampling factors applied before the first
            and before the second conv block (default ``(2, 3)``).
    """

    def __init__(self, channels, scale_factors=(2, 3)):
        super().__init__()
        first_scale, second_scale = scale_factors

        # Layer order and count are kept so state_dict keys (model.1, model.3,
        # model.6, model.8) and parameter shapes stay the same.
        self.model = nn.Sequential(
            nn.Upsample(scale_factor=first_scale, mode='bilinear',
                        align_corners=True),
            # 'Same' padding (kernel_size // 2) keeps the spatial size, so the
            # following 1x1 conv needs no edge padding. The previous padding=2
            # shrank the map by 4 px and the 1x1 conv only restored the size
            # by replicating border values.
            nn.Conv2d(channels, channels * 2, kernel_size=9, padding=4,
                      padding_mode='replicate'),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels * 2, channels, kernel_size=1, padding=0),

            nn.Upsample(scale_factor=second_scale, mode='bilinear',
                        align_corners=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels, channels, kernel_size=5, padding=2,
                      padding_mode='replicate'),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels, channels, kernel_size=1, padding=0),
        )

    def forward(self, x):
        """Upsample ``x`` of shape (B, channels, H, W) by the product of the scale factors."""
        return self.model(x)


In [70]:
# Rebind srcnn to the locally defined SRCNN, which shadows the class imported
# from models.models earlier.
srcnn = SRCNN(32)


In [71]:
# NOTE(review): the recorded (16, 32, 252, 252) differs from the
# (16, 32, 256, 256) RotUNet output shown above — res came from a different
# kernel state. Re-run top to bottom to get consistent outputs.
res.shape


torch.Size([16, 32, 252, 252])

In [72]:
# Centre-crop res with the CenterCrop transform, then upsample back to 252x252.
srcnn(crop(res)).shape

torch.Size([16, 32, 252, 252])

In [73]:
# Shape of the centre crop fed to SRCNN.
crop(res).shape


torch.Size([16, 32, 42, 42])

In [74]:
# Current shape of res.
res.shape


torch.Size([16, 32, 252, 252])

In [153]:
# Insert a singleton axis at dim 1, giving the same (16, 1, 32, 252, 252)
# layout as ys.
res.unsqueeze(1).shape


torch.Size([16, 1, 32, 252, 252])

In [152]:
# NOTE(review): this raises RuntimeError. res has 16*32*252*252 = 32,514,048
# elements, which cannot fill 16*32*11*252*252. Fix the target shape or
# remove the cell.
res.reshape((16, 32, 11, 252, 252))


RuntimeError: shape '[16, 32, 11, 252, 252]' is invalid for input of size 32514048

In [None]:
# Hinge
# TODO(review): placeholder cell with only a heading; fill in or remove.