In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import skorch
from skorch import NeuralNetRegressor
from skorch.callbacks import Checkpoint
import numpy as np

In [2]:
from satforecast.data import data

# Fetch the raw daily rainfall images, process them, and list the resulting
# '*.npy' files found in the directory returned by process_gs_rainfall_daily.
# NOTE(review): the log below shows GPM_3IMERGDL_<date>.PNG files being
# processed one by one in this run; if download()/process_gs_rainfall_daily()
# do not skip already-finished work, cache their results so Restart & Run All
# stays cheap — TODO confirm against satforecast.data.
# NOTE(review): scale=0.1 presumably sets the processed resolution; make sure
# it matches CNN3D's image_size (default (60, 180)) — TODO confirm.
data.download()
image_dir = data.process_gs_rainfall_daily(scale=0.1)
image_files = data.get_files(image_dir, '*.npy')

Processing file number 0 (GPM_3IMERGDL_2015-01-01.PNG)
Processing file number 100 (GPM_3IMERGDL_2015-04-11.PNG)
Processing file number 200 (GPM_3IMERGDL_2015-07-20.PNG)
Processing file number 300 (GPM_3IMERGDL_2015-10-28.PNG)
Processing file number 400 (GPM_3IMERGDL_2016-02-05.PNG)
Processing file number 500 (GPM_3IMERGDL_2016-05-15.PNG)
Processing file number 600 (GPM_3IMERGDL_2016-08-23.PNG)
Processing file number 700 (GPM_3IMERGDL_2016-12-01.PNG)
Processing file number 800 (GPM_3IMERGDL_2017-03-11.PNG)
Processing file number 900 (GPM_3IMERGDL_2017-06-19.PNG)
Processing file number 1000 (GPM_3IMERGDL_2017-09-27.PNG)
Processing file number 1100 (GPM_3IMERGDL_2018-01-05.PNG)
Processing file number 1200 (GPM_3IMERGDL_2018-04-15.PNG)
Processing file number 1300 (GPM_3IMERGDL_2018-07-24.PNG)
Processing file number 1400 (GPM_3IMERGDL_2018-11-01.PNG)
Processing file number 1500 (GPM_3IMERGDL_2019-02-09.PNG)
Processing file number 1600 (GPM_3IMERGDL_2019-05-20.PNG)
Processing file number 170

In [3]:
class CNN3D(nn.Module):
    """3D convolutional network that predicts one frame from a frame sequence.

    ``forward`` takes ``(batch_size, seq_len, n_channels, height, width)`` and
    returns ``(batch_size, n_channels, height, width)``. The input is permuted
    to the ``(batch, channels, depth, height, width)`` layout that
    ``nn.Conv3d`` expects. It then passes through
    ``Conv3d -> BatchNorm3d -> ReLU -> MaxPool3d`` stages, is flattened, and
    goes through Linear/ReLU layers whose last width is
    ``in_channels * height * width``.

    Args:
        seq_len: frames per sample. Fixed at construction because the
            flattened conv output size feeds the first Linear layer.
        image_size: ``(height, width)`` of each frame.
        in_channels: channels per frame.
        conv_channels: output channels per conv stage; an int means one stage.
        conv_kernals, conv_stride, conv_padding: ``nn.Conv3d`` settings per
            stage. An int is used for every stage; a sequence must have one
            entry per stage.
        batch_feats: accepted for backward compatibility but ignored.
            ``nn.BatchNorm3d`` must be sized to the conv stage's output
            channels, so it is built from ``conv_channels``.
        pool_kernals, pool_stride, pool_padding: ``nn.MaxPool3d`` settings per
            stage, broadcast the same way as the conv settings.
        linear_feats: hidden Linear widths. ``None`` means no hidden layer;
            an int means one hidden layer.

    Raises:
        ValueError: if a per-stage sequence does not have exactly one entry
            per conv stage.
    """

    def __init__(self,
            seq_len = 5, # Data
            image_size = (60, 180),
            in_channels = 1, # Conv
            conv_channels = 2,
            conv_kernals = 3,
            conv_stride = 3,
            conv_padding = 1,
            batch_feats = 10, # BatchNorm (ignored; see class docstring)
            pool_kernals = 2, # Pool
            pool_stride = 2,
            pool_padding = 1,
            linear_feats = None # Linear
            ):
        super().__init__()
        # forward() receives (batch_size, seq_len, n_channels, height, width);
        # nn.Conv3d wants (batch_size, n_channels, seq_len, height, width).
        self.permute_order = (0, 2, 1, 3, 4)
        height, width = image_size

        # Convolutional stages. Each per-stage setting is broadcast
        # independently, so any mix of ints and sequences works.
        conv_channels = self._per_stage(conv_channels, None, 'conv_channels')
        n_stages = len(conv_channels)
        conv_kernals = self._per_stage(conv_kernals, n_stages, 'conv_kernals')
        conv_stride = self._per_stage(conv_stride, n_stages, 'conv_stride')
        conv_padding = self._per_stage(conv_padding, n_stages, 'conv_padding')
        pool_kernals = self._per_stage(pool_kernals, n_stages, 'pool_kernals')
        pool_stride = self._per_stage(pool_stride, n_stages, 'pool_stride')
        pool_padding = self._per_stage(pool_padding, n_stages, 'pool_padding')

        conv_layers = []
        stage_params = zip(
            [in_channels] + conv_channels[:-1], conv_channels,
            conv_kernals, conv_stride, conv_padding,
            pool_kernals, pool_stride, pool_padding,
        )
        for in_c, out_c, c_kern, c_stride, c_pad, p_kern, p_stride, p_pad in stage_params:
            conv_layers.extend([
                nn.Conv3d(in_c, out_c, c_kern, c_stride, c_pad),
                # num_features must equal the channels produced by the conv.
                nn.BatchNorm3d(out_c),
                nn.ReLU(inplace=True),
                nn.MaxPool3d(p_kern, p_stride, p_pad),
            ])
        conv_stack = nn.Sequential(*conv_layers)

        # Size the first Linear layer with a dummy pass. The dummy is built
        # directly in the post-permute Conv3d layout (batch, C, T, H, W).
        # eval() keeps BatchNorm from folding the dummy zeros into its running
        # statistics. It also allows batch size 1 when a stage's output is 1x1x1.
        with torch.no_grad():
            conv_stack.eval()
            dummy = torch.zeros(1, in_channels, seq_len, height, width)
            n_conv_feats = conv_stack(dummy)[0].numel()
        conv_stack.train()

        # Fully connected layers: flattened conv features -> hidden widths ->
        # one value per output channel/pixel.
        if linear_feats is None:
            linear_feats = []
        elif isinstance(linear_feats, int):
            linear_feats = [linear_feats]
        widths = [n_conv_feats, *linear_feats, in_channels * height * width]

        fc_layers = [nn.Flatten()]
        for in_feats, out_feats in zip(widths[:-1], widths[1:]):
            # Every Linear, including the output one, is followed by ReLU,
            # so predictions are non-negative.
            fc_layers.extend([
                nn.Linear(in_feats, out_feats),
                nn.ReLU(inplace=True),
            ])

        self.model = nn.Sequential(*conv_stack, *fc_layers)

    @staticmethod
    def _per_stage(value, n_stages, name):
        """Return ``value`` as a list with one entry per conv stage.

        An int is repeated ``n_stages`` times (once when ``n_stages`` is None).
        Any other iterable is converted to a list and, when ``n_stages`` is
        given, must have exactly that many entries.
        """
        if isinstance(value, int):
            return [value] * (1 if n_stages is None else n_stages)
        value = list(value)
        if n_stages is not None and len(value) != n_stages:
            raise ValueError(
                f'{name} has {len(value)} entries but there are {n_stages} conv stages'
            )
        return value

    def forward(self, x):
        """Predict one frame from a sequence of frames.

        Args:
            x: tensor of shape ``(batch_size, seq_len, n_channels, height, width)``.
                ``seq_len``, ``n_channels`` and the spatial size must match the
                values the module was built with.

        Returns:
            Tensor of shape ``(batch_size, n_channels, height, width)``.
        """
        batch_size, _, n_channels, height, width = x.shape
        x = self.model(x.permute(*self.permute_order))
        return x.view(batch_size, n_channels, height, width)

In [4]:
# skorch wrapper: trains CNN3D with Adam on an MSE objective for 50 epochs
# of 32-sample batches.
# ValidSplit(cv=5) holds out 1/5 of the data for validation (presumably the
# first of five unshuffled folds; see the skorch ValidSplit docs).
# Checkpoint writes the module parameters to checkpoints/best_params_cnn3d.pt
# whenever valid_loss reaches a new best.
# No module__* arguments are given, so CNN3D is built with its defaults.
net = NeuralNetRegressor(
    module=CNN3D,
    optimizer=optim.Adam,
    criterion=nn.MSELoss,
    lr=1e-3,
    batch_size=32,
    max_epochs=50,
    train_split=skorch.dataset.ValidSplit(cv=5),
    callbacks=[
        Checkpoint(
            monitor='valid_loss_best',
            dirname='checkpoints',
            f_params='best_params_cnn3d.pt',
        ),
    ],
    device='cpu' if not torch.cuda.is_available() else 'cuda',
)

In [6]:
import os
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

# NOTE(review): everything below is a bare string literal, not code. This
# ImageSequenceDataset draft is never executed, and the os / Dataset /
# transforms imports above appear to be used only by it (in the visible
# notebook). Issues to address if it is revived:
#  - `Image` (presumably PIL.Image) is not imported in this cell.
#  - It yields (frame[idx], frame[idx + 1]) pairs of RGB images resized to
#    224x224 and normalized with the standard ImageNet mean/std. CNN3D above
#    defaults to seq_len=5, in_channels=1 and image_size=(60, 180), and the
#    processed data are listed as '*.npy' files, not images.
#  Consider deleting it or replacing it with a Dataset built over `image_files`.
"""
class ImageSequenceDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.image_files = sorted(os.listdir(root_dir))

    def __len__(self):
        return len(self.image_files) - 1

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        image1_path = os.path.join(self.root_dir, self.image_files[idx])
        image2_path = os.path.join(self.root_dir, self.image_files[idx + 1])

        image1 = Image.open(image1_path).convert('RGB')
        image2 = Image.open(image2_path).convert('RGB')

        if self.transform:
            image1 = self.transform(image1)
            image2 = self.transform(image2)

        return image1, image2

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

dataset = ImageSequenceDataset('path/to/dataset', transform=transform)
loader = DataLoader(dataset, batch_size=32, shuffle=True)
""";