In [1]:
import torch
from torch import nn
import matplotlib.pyplot as plt

from satforecast.data import data
from satforecast.modeling.model_selection import rolling_batch
from satforecast.modeling.train import train

# Preprocessing knobs live in named constants instead of inline literals.
RAINFALL_SCALE = 0.1    # forwarded as `scale=` to data.process_gs_rainfall_daily
IMAGE_GLOB = '*.npy'    # file pattern handed to data.get_files

data.download()
image_dir = data.process_gs_rainfall_daily(scale=RAINFALL_SCALE)
image_files = data.get_files(image_dir, IMAGE_GLOB)

In [3]:
class CNN3D(nn.Module):
    """Map a sequence of frames to a single frame with a 3D CNN and dense layers.

    Input is shaped ``(batch, seq_len, in_channels, height, width)``. It is
    permuted to Conv3d layout and passed through one
    ``Conv3d -> BatchNorm3d -> ReLU -> MaxPool3d`` block per entry of
    ``conv_channels``. The result is flattened and mapped through Linear+ReLU
    layers to ``in_channels * height * width`` features. Those are reshaped to
    ``(batch, in_channels, height, width)``.

    Per-block arguments (``conv_kernals``, ``conv_stride``, ``conv_padding``,
    ``batch_feats``, ``pool_kernals``, ``pool_stride``, ``pool_padding``)
    accept one of:

    * a single int (or ``None``), repeated for every conv block;
    * a sequence with exactly one entry per conv block. Entries may be
      per-dimension tuples, e.g. ``conv_kernals=[(3, 3, 3)]``.

    Args:
        seq_len: Number of frames per input sequence.
        image_size: ``(height, width)`` of each frame.
        in_channels: Channels per frame; also the number of output channels.
        conv_channels: Output channels of each conv block (int or sequence).
        conv_kernals, conv_stride, conv_padding: Conv3d kernel size, stride,
            and padding.
        batch_feats: BatchNorm3d ``num_features``. ``None`` (the default)
            uses the block's conv output channels, which is the only value
            BatchNorm3d accepts there.
        pool_kernals, pool_stride, pool_padding: MaxPool3d kernel size,
            stride, and padding.
        linear_feats: Hidden sizes of the dense layers between the flattened
            conv output and the output image (int, sequence, or ``None``).

    Raises:
        ValueError: if a per-block sequence does not have one entry per
            conv block.
    """

    def __init__(self,
            seq_len = 5, # Data
            image_size = (60, 180),
            in_channels = 1, # Conv
            conv_channels = 2,
            conv_kernals = 3,
            conv_stride = 3,
            conv_padding = 1,
            batch_feats = None, # BatchNorm; None -> conv output channels
            pool_kernals = 2, # Pool
            pool_stride = 2,
            pool_padding = 1,
            linear_feats = None # Linear
            ):
        super().__init__()
        # Swapping (batch_size, seq_len, n_channels, height, width)
        # To (batch_size, n_channels, seq_len, height, width), the Conv3d layout.
        self.permute_order = (0, 2, 1, 3, 4)

        # ---- Convolutional blocks --------------------------------------------
        if isinstance(conv_channels, int):
            conv_channels = [conv_channels]
        conv_channels = list(conv_channels)
        n_blocks = len(conv_channels)

        # Broadcast/validate every per-block argument independently, so any
        # mix of ints and sequences works.
        conv_kernals = self._per_block(conv_kernals, n_blocks, 'conv_kernals')
        conv_stride = self._per_block(conv_stride, n_blocks, 'conv_stride')
        conv_padding = self._per_block(conv_padding, n_blocks, 'conv_padding')
        batch_feats = self._per_block(batch_feats, n_blocks, 'batch_feats')
        pool_kernals = self._per_block(pool_kernals, n_blocks, 'pool_kernals')
        pool_stride = self._per_block(pool_stride, n_blocks, 'pool_stride')
        pool_padding = self._per_block(pool_padding, n_blocks, 'pool_padding')

        conv_layers = []
        for in_c, out_c, c_kern, c_stride, c_pad, b_feats, p_kern, p_stride, p_pad in zip(
                [in_channels] + conv_channels[:-1],
                conv_channels,
                conv_kernals, conv_stride, conv_padding,
                batch_feats,
                pool_kernals, pool_stride, pool_padding
        ):
            conv_layers.extend([
                nn.Conv3d(in_c, out_c, c_kern, c_stride, c_pad),
                # BatchNorm3d normalizes the conv output, so num_features must
                # equal out_c; None selects exactly that.
                nn.BatchNorm3d(out_c if b_feats is None else b_feats),
                nn.ReLU(inplace=True),
                nn.MaxPool3d(p_kern, p_stride, p_pad),
            ])
        conv_stack = nn.Sequential(*conv_layers)

        # Infer the flattened conv output size with a dummy forward pass.
        # The probe is built directly in Conv3d layout (no permute needed).
        # It runs in eval mode so BatchNorm's running statistics are not
        # updated by the all-zero input.
        with torch.no_grad():
            conv_stack.eval()
            probe = torch.zeros(1, in_channels, seq_len, image_size[0], image_size[1])
            n_conv_feats = conv_stack(probe)[0].numel()
        conv_stack.train()

        # ---- Fully connected head --------------------------------------------
        if linear_feats is None:
            linear_feats = []
        elif isinstance(linear_feats, int):
            linear_feats = [linear_feats]

        feat_sizes = [
            n_conv_feats,
            *linear_feats,
            in_channels * image_size[0] * image_size[1],  # one output frame
        ]
        # Flatten (batch, C, D, H, W) -> (batch, C*D*H*W) before the Linear layers.
        fc_layers = [nn.Flatten()]
        for in_feats, out_feats in zip(feat_sizes[:-1], feat_sizes[1:]):
            # Every Linear, including the last, is followed by ReLU, so the
            # predicted frame is non-negative.
            fc_layers.extend([
                nn.Linear(in_feats, out_feats),
                nn.ReLU(inplace=True),
            ])

        self.model = nn.Sequential(*conv_layers, *fc_layers)

    @staticmethod
    def _per_block(value, n_blocks, name):
        """Return ``value`` as a list with one entry per conv block.

        An int or ``None`` is repeated ``n_blocks`` times. Any other value is
        treated as an explicit per-block sequence and must have exactly
        ``n_blocks`` entries (``zip`` would otherwise silently drop blocks).

        Raises:
            ValueError: if a per-block sequence has the wrong length.
        """
        if value is None or isinstance(value, int):
            return [value] * n_blocks
        values = list(value)
        if len(values) != n_blocks:
            raise ValueError(
                f'{name} has {len(values)} entries but there are {n_blocks} conv blocks'
            )
        return values

    def forward(self, x):
        """Map a frame sequence to one predicted frame.

        Args:
            x: Tensor shaped (batch_size, seq_len, n_channels, height, width).
                It should match the ``seq_len``, ``image_size`` and
                ``in_channels`` the model was built with.

        Returns:
            Tensor shaped (batch_size, n_channels, height, width).
        """
        batch_size, _, n_channels, height, width = x.shape
        x = x.permute(*self.permute_order)
        x = self.model(x)
        return x.view(batch_size, n_channels, height, width)