In [2]:
!pip install pytorch-lightning



In [3]:
# Root directory of the competition HDF5 files.
PATH = "/kaggle/input/nowcasting-yandex-cup/"

In [4]:
import h5py
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import tqdm
import pytorch_lightning as L



In [5]:
# Sanity check: is a CUDA device visible to this kernel?
print(torch.cuda.is_available(), flush=True)

True


In [6]:
from pytorch_lightning import seed_everything
import random
# Seed every RNG the notebook touches (torch's generator also drives the
# random 80/20 month splits) and force deterministic cuDNN kernels.
seed = 7
seed_everything(seed, workers=True)
for seed_fn in (random.seed, np.random.seed, torch.manual_seed,
                torch.cuda.manual_seed, torch.cuda.manual_seed_all):
    seed_fn(seed)
del seed_fn
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [8]:
class RadarDataset(data.Dataset):
    """Windows of consecutive radar frames read lazily from HDF5 files.

    Each HDF5 file maps timestamp keys to groups holding ``intensity``,
    ``reflectivity`` and ``radial_velocity`` datasets. A sample is a window of
    ``in_seq_len + out_seq_len`` sorted timestamps spanning exactly
    ``(seq_len - 1) * TIME_STEP``.

    The first ``in_seq_len`` frames form the inputs: intensity followed by
    ``NUM_LEVELS`` reflectivity and ``NUM_LEVELS`` radial-velocity levels. The
    intensity of the remaining frames forms the targets.

    Args:
        list_of_files: HDF5 files to index. If a timestamp occurs in several
            files, the last file wins.
        in_seq_len: input frames per sample.
        out_seq_len: target frames per sample (0 for inference).
        mode: ``'sequentially'`` for non-overlapping windows, ``'overlap'`` for
            a stride-1 sliding window.
        rotate: quarter turns applied to the last two axes of BOTH inputs and
            targets (0 disables the augmentation).
        with_time: if True, items are ``((inputs, last_timestamp), targets)``.

    Raises:
        ValueError: unknown ``mode`` (a subclass of the ``Exception`` raised
            previously).
    """

    # Spacing between consecutive timestamps; windows with gaps are dropped.
    TIME_STEP = 600
    # Levels of reflectivity / radial velocity stacked after intensity, giving
    # 1 + 2 * NUM_LEVELS = 17 channels per input frame.
    NUM_LEVELS = 8
    # Sentinel values in the raw data and what they are replaced with.
    FILL_TO_ZERO = -1e6
    FILL_TO_MASK = -2e6  # becomes -1, which training_step treats as "ignore"

    def __init__(self, list_of_files, in_seq_len=4, out_seq_len=12, mode='sequentially', rotate=0, with_time=False):
        self.in_seq_len = in_seq_len
        self.out_seq_len = out_seq_len
        self.seq_len = in_seq_len + out_seq_len
        self.with_time = with_time
        self.rotate = rotate
        self.__prepare_timestamps_mapping(list_of_files)
        self.__prepare_sequences(mode)

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, index):
        """Return ``(inputs, targets)``, or ``((inputs, last_ts), targets)`` with ``with_time``.

        Assuming 2-D ``(H, W)`` fields, inputs are ``(in_seq_len, 17, H, W)``
        and targets are ``(out_seq_len, 1, H, W)``.
        """
        sequence = self.sequences[index]
        frames = []
        intensities = []
        for timestamp in sequence:
            with h5py.File(self.timestamp_to_file[timestamp], 'r') as d:
                group = d[timestamp]
                intensity = np.asarray(group['intensity'])
                intensities.append(intensity[None])
                frames.append(self._stack_channels(intensity, group))
        inputs = self._rotate(self._clean(np.stack(frames)))[:self.in_seq_len]
        targets = self._rotate(self._clean(np.stack(intensities)))[self.in_seq_len:]
        if self.with_time:
            return (inputs, sequence[-1]), targets
        return inputs, targets

    @classmethod
    def _stack_channels(cls, intensity, group):
        """Stack intensity with the first NUM_LEVELS reflectivity and radial-velocity levels."""
        levels = range(cls.NUM_LEVELS)
        channels = [intensity]
        channels += [group['reflectivity'][k] for k in levels]
        channels += [group['radial_velocity'][k] for k in levels]
        return np.stack(channels)

    @classmethod
    def _clean(cls, array):
        """Replace the raw sentinel values in ``array`` in place and return it."""
        array[array == cls.FILL_TO_ZERO] = 0
        array[array == cls.FILL_TO_MASK] = -1
        return array

    def _rotate(self, array):
        """Rotate ``array`` by ``self.rotate`` quarter turns over its last two axes."""
        if self.rotate % 4 == 0:
            return array
        # np.rot90 returns a negatively-strided view, which torch cannot wrap;
        # make it contiguous so DataLoader's collate works.
        return np.ascontiguousarray(np.rot90(array, k=self.rotate, axes=(-2, -1)))

    def __prepare_timestamps_mapping(self, list_of_files):
        """Map every timestamp key to the file that contains it (later files win)."""
        self.timestamp_to_file = {}
        for filename in list_of_files:
            with h5py.File(filename, 'r') as d:
                self.timestamp_to_file.update(dict.fromkeys(d.keys(), filename))

    def __prepare_sequences(self, mode):
        """Build ``self.sequences``: windows of ``seq_len`` sorted timestamps without gaps."""
        # np.unique both deduplicates and sorts the keys.
        timestamps = np.unique(list(self.timestamp_to_file.keys()))
        if mode == 'sequentially':
            self.sequences = [
                timestamps[index * self.seq_len: (index + 1) * self.seq_len]
                for index in range(len(timestamps) // self.seq_len)
            ]
        elif mode == 'overlap':
            self.sequences = [
                timestamps[index: index + self.seq_len]
                for index in range(len(timestamps) - self.seq_len + 1)
            ]
        else:
            raise ValueError(f'Unknown mode {mode}')
        expected_span = (self.seq_len - 1) * self.TIME_STEP
        self.sequences = [
            seq for seq in self.sequences
            if int(seq[-1]) - int(seq[0]) == expected_span
        ]

In [9]:
class ConvLSTMCell(nn.Module):
    """ConvLSTM cell with separate convolution paths for intensity and the other channels.

    Input channel 0 (intensity) goes through two stacked convolutions with
    ``activation`` between them. Channels ``1:`` go through a single
    convolution. Each path yields the four gate pre-activations for its half of
    the state, so ``H`` and ``C`` have ``2 * out_channels`` channels: the first
    half belongs to the intensity path, the second to the other path.

    Args:
        in_channels: intensity input channels. ``forward`` slices exactly one
            channel, so this is expected to be 1.
        out_channels: hidden channels per path.
        kernel_size, padding: passed to every ``nn.Conv2d``.
        activation: ``'tanh'`` or ``'relu'``.
        refl_channels: number of non-intensity input channels (default 16, the
            previously hard-coded value).

    Raises:
        ValueError: unsupported ``activation``. Previously this surfaced only
            later, as an ``AttributeError`` in ``forward``.
    """

    _ACTIVATIONS = {'tanh': torch.tanh, 'relu': torch.relu}

    def __init__(self, in_channels, out_channels, kernel_size, padding, activation, refl_channels=16):
        super().__init__()

        if activation not in self._ACTIVATIONS:
            raise ValueError(
                f'Unsupported activation {activation!r}; expected one of {sorted(self._ACTIVATIONS)}'
            )
        self.activation = self._ACTIVATIONS[activation]

        self.conv_int_0 = nn.Conv2d(
            in_channels=in_channels + out_channels,
            out_channels=4 * out_channels,
            kernel_size=kernel_size,
            padding=padding
        )

        self.conv_int_1 = nn.Conv2d(
            in_channels=4 * out_channels,
            out_channels=4 * out_channels,
            kernel_size=kernel_size,
            padding=padding
        )
        self.conv_refl_0 = nn.Conv2d(
            in_channels=refl_channels + out_channels,
            out_channels=4 * out_channels,
            kernel_size=kernel_size,
            padding=padding
        )

    def forward(self, X, H_prev, C_prev):
        """One time step; returns the new hidden and cell states ``(H, C)``."""
        int_X = X[:, 0:1, :, :]
        refl_X = X[:, 1:, :, :]
        # First half of the state belongs to the intensity path, second to the other path.
        H_prev_int, H_prev_refl = torch.chunk(H_prev, chunks=2, dim=1)
        conv_int_output = self.conv_int_1(self.activation(self.conv_int_0(torch.cat([int_X, H_prev_int], dim=1))))
        conv_refl_output = self.conv_refl_0(torch.cat([refl_X, H_prev_refl], dim=1))

        i_conv_int, f_conv_int, C_conv_int, o_conv_int = torch.chunk(conv_int_output, chunks=4, dim=1)
        i_conv_refl, f_conv_refl, C_conv_refl, o_conv_refl = torch.chunk(conv_refl_output, chunks=4, dim=1)

        input_gate = torch.sigmoid(torch.cat([i_conv_int, i_conv_refl], dim=1))
        forget_gate = torch.sigmoid(torch.cat([f_conv_int, f_conv_refl], dim=1))
        output_gate = torch.sigmoid(torch.cat([o_conv_int, o_conv_refl], dim=1))
        C_conv = torch.cat([C_conv_int, C_conv_refl], dim=1)
        C = forget_gate * C_prev + input_gate * self.activation(C_conv)
        H = output_gate * self.activation(C)
        return H, C


class ConvLSTM(nn.Module):
    """Unroll a ConvLSTMCell over the time axis.

    ``X`` is indexed as ``(batch, time, channels, height, width)``. The result
    stacks the hidden state after every step:
    ``(batch, time, 2 * out_channels, height, width)``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, padding, activation):
        super().__init__()
        self.out_channels = out_channels
        self.convLSTMCell = ConvLSTMCell(in_channels, out_channels, kernel_size, padding, activation)

    def forward(self, X):
        batch_size, seq_len, _, height, width = X.size()
        # Allocate on X's device/dtype rather than "cuda if available", so a
        # CPU-resident model still works on a machine that has a GPU.
        state_shape = (batch_size, 2 * self.out_channels, height, width)
        output = X.new_zeros((batch_size, seq_len) + state_shape[1:])
        H = X.new_zeros(state_shape)
        C = X.new_zeros(state_shape)
        for time_step in range(seq_len):
            H, C = self.convLSTMCell(X[:, time_step], H, C)
            output[:, time_step] = H
        return output


class Seq2Seq(nn.Module):
    """Encoder-forecaster built from stacked ConvLSTM layers.

    The observed frames are followed by ``out_seq_len - 1`` all-zero frames.
    The ConvLSTM stack runs over the padded sequence, and a final convolution
    maps the hidden state at each of the last ``out_seq_len`` steps to one
    predicted frame.

    Output: ``(batch, out_seq_len, num_channels, height, width)``.
    """

    def __init__(
        self, num_channels, num_kernels, kernel_size, padding, activation, num_layers, out_seq_len
    ):
        super().__init__()
        self.out_seq_len = out_seq_len

        # Not referenced by forward; kept so the module's attributes are unchanged.
        self.activation = torch.relu
        self.sequential = nn.Sequential()
        # Layer 1 reads the raw channels, later layers read num_kernels.
        # Layer 1 is always created, even when num_layers < 1.
        for layer_index in range(1, max(num_layers, 1) + 1):
            self.sequential.add_module(
                f'convlstm{layer_index}',
                ConvLSTM(
                    in_channels=num_channels if layer_index == 1 else num_kernels,
                    out_channels=num_kernels,
                    kernel_size=kernel_size,
                    padding=padding,
                    activation=activation,
                ),
            )
        # Each ConvLSTM layer here emits 2 * num_kernels hidden channels.
        self.conv = nn.Conv2d(
            in_channels=2 * num_kernels,
            out_channels=num_channels,
            kernel_size=kernel_size,
            padding=padding,
        )

    def forward(self, X):
        batch_size, seq_len, num_channels, height, width = X.size()
        horizon = self.out_seq_len
        # Observed frames followed by zero placeholders for the future steps.
        padded = torch.zeros(
            batch_size, seq_len + horizon - 1, num_channels, height, width,
            device=self.conv.weight.device
        )
        padded[:, :seq_len] = X
        hidden = self.sequential(padded)
        predictions = []
        for step in range(seq_len - 1, seq_len - 1 + horizon):
            predictions.append(self.conv(hidden[:, step]))
        return torch.stack(predictions, dim=1)


class ConvLSTMModel(L.LightningModule):
    """Lightning wrapper around a single-layer ``Seq2Seq`` forecaster (12 steps, 1 output channel).

    Args:
        lr: Adam learning rate (default 5e-4, the previously hard-coded value).
            It is stored as ``current_lr`` and read in ``configure_optimizers``,
            so assigning ``model.current_lr`` before ``trainer.fit`` now takes
            effect.
    """

    def __init__(self, lr=5e-4):
        super().__init__()
        self.current_lr = lr
        self.model = Seq2Seq(
            num_channels=1,
            num_kernels=32,
            kernel_size=(3, 3),
            padding=(1, 1),
            activation='relu',
            num_layers=1,
            out_seq_len=12
        )

    def forward(self, x):
        # Move the batch to wherever the weights live (not blindly to CUDA).
        x = x.to(device=next(self.parameters()).device)
        return self.model(x)

    def training_step(self, batch):
        x, y = batch
        out = self.forward(x)
        # Forcing predictions to -1 where the target is -1 zeroes the error
        # (and gradient) there; those pixels still count in the mean.
        out[y == -1] = -1
        loss = F.mse_loss(out, y)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        # getattr keeps models pickled before `current_lr` existed working.
        lr = getattr(self, 'current_lr', 5e-4)
        return torch.optim.Adam(self.parameters(), lr=lr)

In [10]:
def prepare_midyear_loaders(train_batch_size=1, months=range(4, 10), num_workers=4):
    """Split each 2021 month in ``months`` 80/20 and build loaders.

    Returns:
        ``(train_loader, val_loaders)``:
        - ``train_loader`` is one shuffled loader over the concatenated 80% splits.
        - ``val_loaders`` has one unshuffled loader per month over its 20% split,
          in the order of ``months``.
    """
    train_datasets = []
    val_datasets = []
    for month in months:
        month_dataset = RadarDataset([f'{PATH}2021-{month:02d}-train.hdf5'])
        train_month, val_month = torch.utils.data.random_split(month_dataset, [0.8, 0.2])
        train_datasets.append(train_month)
        val_datasets.append(val_month)
    full_train_dataset = torch.utils.data.ConcatDataset(train_datasets)
    train_loader = data.DataLoader(full_train_dataset, batch_size=train_batch_size, shuffle=True, num_workers=num_workers)
    val_loaders = [
        data.DataLoader(val_dataset, batch_size=train_batch_size, shuffle=False, num_workers=num_workers)
        for val_dataset in val_datasets
    ]
    return train_loader, val_loaders
    
def prepare_month_loaders(train_batch_size=1, oversampled_months=range(5, 9), num_workers=4):
    """Split every 2021 month 80/20 and build loaders, oversampling some months.

    The training split of each month in ``oversampled_months`` (May-August by
    default) is added twice, doubling its weight per epoch.

    Returns:
        ``(train_loader, val_loaders)``:
        - ``train_loader`` is shuffled over the concatenated training splits.
        - ``val_loaders`` holds 12 unshuffled loaders, one per month
          (index 0 = January).
    """
    train_datasets = []
    val_datasets = []
    for month in range(1, 13):
        month_dataset = RadarDataset([f'{PATH}2021-{month:02d}-train.hdf5'])
        train_month, val_month = torch.utils.data.random_split(month_dataset, [0.8, 0.2])
        train_datasets.append(train_month)
        if month in oversampled_months:
            train_datasets.append(train_month)
        val_datasets.append(val_month)
    full_train_dataset = torch.utils.data.ConcatDataset(train_datasets)
    train_loader = data.DataLoader(full_train_dataset, batch_size=train_batch_size, shuffle=True, num_workers=num_workers)
    val_loaders = [
        data.DataLoader(val_dataset, batch_size=train_batch_size, shuffle=False, num_workers=num_workers)
        for val_dataset in val_datasets
    ]
    return train_loader, val_loaders


def prepare_train_loader(train_batch_size=1):
    """Shuffled training loader over a fixed subset of 2021 months (Jan, Mar, Apr, Jun, Jul, Sep, Oct, Dec)."""
    file_names = (
        '2021-01-train.hdf5', '2021-03-train.hdf5', '2021-04-train.hdf5',
        '2021-06-train.hdf5', '2021-07-train.hdf5', '2021-09-train.hdf5',
        '2021-10-train.hdf5', '2021-12-train.hdf5',
    )
    dataset = RadarDataset([PATH + name for name in file_names])
    return data.DataLoader(dataset, batch_size=train_batch_size, shuffle=True, num_workers=4)
def prepare_valid_loader(valid_batch_size=1):
    """Validation loader over the Aug, May, Feb and Nov 2021 files.

    Not shuffled, consistent with the other validation loaders, so evaluation
    order is deterministic.
    """
    valid_dataset = RadarDataset([PATH + '2021-08-train.hdf5', PATH + '2021-05-train.hdf5', PATH + '2021-02-train.hdf5', PATH + '2021-11-train.hdf5'])
    return data.DataLoader(valid_dataset, batch_size=valid_batch_size, shuffle=False, num_workers=4)

def prepare_test_loader(test_batch_size=1):
    """Unshuffled loader over the public test file; each item carries its last input timestamp."""
    return data.DataLoader(
        RadarDataset([PATH + '2022-test-public.hdf5'], out_seq_len=0, with_time=True),
        batch_size=test_batch_size,
        shuffle=False,
    )

def evaluate_on_val(model, valid_loader):
    """Score ``model`` on ``valid_loader`` with a masked per-horizon RMSE.

    For each forecast step:
    - squared errors are summed over all pixels whose target is not -1;
    - the sum is averaged over samples.
    The result is the mean over steps of the square roots of those averages.

    Averaging over samples (not batches) makes the score independent of batch
    size. For batch_size=1 it equals the previous per-batch average.

    Returns:
        The score as a float, or ``nan`` for an empty loader.
    """
    totals = None
    n_samples = 0
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        for inputs, target in tqdm.tqdm(valid_loader):
            output_np = model(inputs).detach().cpu().numpy()
            target_np = target.detach().cpu().numpy()
            batch_totals = np.sum(
                np.square(target_np - output_np) * (target_np != -1),
                axis=(0, 2, 3, 4),
            )
            if totals is None:
                totals = np.zeros(batch_totals.shape, dtype=float)
            totals += batch_totals
            n_samples += target_np.shape[0]
    if n_samples == 0:
        return float('nan')
    return np.mean(np.sqrt(totals / n_samples))


def process_test(model, test_loader, output_file='../output.hdf5'):
    """Run ``model`` over ``test_loader`` and write predicted intensities to ``output_file``.

    Prediction step ``k`` (1-based) of a sample is written as dataset
    ``intensity`` in group ``str(last_input_timestamp + 600 * k)``.

    The file is opened once in append mode. ``create_group`` raises if a group
    already exists, so write into a fresh file.

    Every sample in a batch is written. Previously only sample 0 was kept, and
    it was paired with the batch's last timestamp.
    """
    model.eval()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    with h5py.File(output_file, mode='a') as f_out, torch.no_grad():
        for (inputs, last_input_timestamps), _ in tqdm.tqdm(test_loader):
            output = model(inputs).detach().cpu().numpy()
            for sample, last_ts in zip(output, last_input_timestamps):
                for step in range(sample.shape[0]):
                    timestamp_out = str(int(last_ts) + 600 * (step + 1))
                    f_out.create_group(timestamp_out)
                    f_out[timestamp_out].create_dataset(
                        'intensity',
                        data=sample[step, 0]
                    )

In [11]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# One shuffled training loader plus 12 per-month validation loaders.
train_loader, val_loaders = prepare_month_loaders()

In [13]:
# NOTE(review): torch.load unpickles arbitrary objects; only load trusted checkpoints.
# weights_only=False is required to restore a whole pickled module on
# torch >= 2.6, where the default became True.
model = torch.load(
    '/kaggle/input/models-for-baseline/baseline_11e-4.pt',
    map_location='cuda' if torch.cuda.is_available() else 'cpu',
    weights_only=False,
)  # 11e-4

In [20]:
# NOTE(review): takes effect only if ConvLSTMModel.configure_optimizers reads `current_lr` — verify.
model.current_lr = 3e-4
trainer = L.Trainer(max_epochs=1)
trainer.fit(model, train_loader)
model.to(device)

Training: 0it [00:00, ?it/s]

ConvLSTMModel(
  (model): Seq2Seq(
    (sequential): Sequential(
      (convlstm1): ConvLSTM(
        (convLSTMCell): ConvLSTMCell(
          (conv_int_0): Conv2d(33, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (conv_int_1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (conv_refl_0): Conv2d(48, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
      )
    )
    (conv): Conv2d(64, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
)

In [None]:
# Score on each month's held-out split (val_loaders[0] is January).
for i in range(1, 13):
    print('month', i)
    month_score = evaluate_on_val(model, val_loaders[i - 1])
    print(month_score)
    


In [24]:
import torch
# Pickles the whole module (not just the state_dict).
torch.save(obj=model, f='baseline_18e-4.pt')

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
# NOTE(review): no visible cell defines `val_loader` (NameError on a fresh
# kernel). Evaluate on all held-out month splits from `val_loaders` instead;
# confirm the "#168" note refers to a comparable validation set.
val_loader = data.DataLoader(
    torch.utils.data.ConcatDataset([loader.dataset for loader in val_loaders]),
    batch_size=1, shuffle=False, num_workers=4,
)
evaluate_on_val(model, val_loader)  # 168

In [None]:
# NOTE(review): this cell originally fit on `train_loaders[2]`, but no visible
# cell defines `train_loaders` (NameError on a fresh kernel). It now uses the
# `train_loader` from prepare_month_loaders — confirm that is the intended data.
trainer = L.Trainer(max_epochs=1)
trainer.fit(model, train_loader)

In [None]:
torch.save(obj=model, f='two_conv_with_act_3_data.pt')
model.to(device)

In [None]:
evaluate_on_val(model=model, valid_loader=val_loader)  # checked

In [None]:
# Disabled cleanup cell: the code sits inside a string literal, so running the
# cell only displays the string. Move it out of the quotes to delete `model`
# and release cached CUDA memory.
'''
import gc
del model
gc.collect()
torch.cuda.empty_cache()
'''

In [None]:
# Fall back to CPU instead of failing when no GPU is present.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model.to(device)

In [None]:
# NOTE(review): this cell originally fit on `train_loaders[3]`, but no visible
# cell defines `train_loaders` (NameError on a fresh kernel). It now uses the
# `train_loader` from prepare_month_loaders — confirm that is the intended data.
trainer = L.Trainer(max_epochs=1)
trainer.fit(model, train_loader)

In [None]:
torch.save(obj=model, f='rotate_full_data.pt')

In [None]:
model.to(device)
evaluate_on_val(model=model, valid_loader=val_loader)

In [None]:
# NOTE(review): this cell originally fit on `train_loaders[4]`, but no visible
# cell defines `train_loaders` (NameError on a fresh kernel). It now uses the
# `train_loader` from prepare_month_loaders — confirm that is the intended data.
trainer = L.Trainer(max_epochs=1)
trainer.fit(model, train_loader)

In [None]:
model.to(device)
evaluate_on_val(model=model, valid_loader=val_loader)  # checked

In [None]:
# Sanity check: is a CUDA device visible to this kernel?
print(torch.cuda.is_available(), flush=True)

In [None]:

# NOTE(review): no visible cell saves 'two_conv_with_act_2_data.pt' (the save
# cell writes '..._3_data.pt'); confirm the intended checkpoint.
# torch.load unpickles arbitrary objects; only load trusted files.
# weights_only=False is required for whole pickled modules on torch >= 2.6.
model = torch.load(
    'two_conv_with_act_2_data.pt',
    map_location='cuda' if torch.cuda.is_available() else 'cpu',
    weights_only=False,
)

In [None]:
class TestDataset(data.Dataset):
    """Intensity-only windows of consecutive frames read lazily from HDF5 files.

    Like RadarDataset, but reads only ``intensity`` and defaults to
    ``mode='overlap'``. Assuming 2-D intensity fields:
    - inputs are ``(in_seq_len, 1, H, W)``;
    - targets are ``(out_seq_len, 1, H, W)``.
    With ``with_time`` the inputs come as ``(inputs, last_timestamp)``.

    NOTE(review): no visible cell instantiates this class.

    Raises:
        ValueError: unknown ``mode`` (a subclass of the ``Exception`` raised previously).
    """

    # Spacing between consecutive timestamps; windows with gaps are dropped.
    TIME_STEP = 600

    def __init__(self, list_of_files, in_seq_len=4, out_seq_len=12, mode='overlap', with_time=False):
        self.in_seq_len = in_seq_len
        self.out_seq_len = out_seq_len
        self.seq_len = in_seq_len + out_seq_len
        self.with_time = with_time
        self.__prepare_timestamps_mapping(list_of_files)
        self.__prepare_sequences(mode)

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, index):
        sequence = self.sequences[index]
        frames = []
        for timestamp in sequence:
            with h5py.File(self.timestamp_to_file[timestamp], 'r') as d:
                frames.append(np.array(d[timestamp]['intensity']))
        stacked = np.expand_dims(frames, axis=1)
        # Sentinel values: -1e6 -> 0, -2e6 -> -1.
        stacked[stacked == -1e6] = 0
        stacked[stacked == -2e6] = -1
        inputs = stacked[:self.in_seq_len]
        targets = stacked[self.in_seq_len:]
        if self.with_time:
            return (inputs, sequence[-1]), targets
        return inputs, targets

    def __prepare_timestamps_mapping(self, list_of_files):
        """Map every timestamp key to the file that contains it (later files win)."""
        self.timestamp_to_file = {}
        for filename in list_of_files:
            with h5py.File(filename, 'r') as d:
                self.timestamp_to_file.update(dict.fromkeys(d.keys(), filename))

    def __prepare_sequences(self, mode):
        """Build ``self.sequences``: windows of ``seq_len`` sorted timestamps without gaps."""
        # np.unique both deduplicates and sorts the keys.
        timestamps = np.unique(list(self.timestamp_to_file.keys()))
        if mode == 'sequentially':
            self.sequences = [
                timestamps[index * self.seq_len: (index + 1) * self.seq_len]
                for index in range(len(timestamps) // self.seq_len)
            ]
        elif mode == 'overlap':
            self.sequences = [
                timestamps[index: index + self.seq_len]
                for index in range(len(timestamps) - self.seq_len + 1)
            ]
        else:
            raise ValueError(f'Unknown mode {mode}')
        expected_span = (self.seq_len - 1) * self.TIME_STEP
        self.sequences = [
            seq for seq in self.sequences
            if int(seq[-1]) - int(seq[0]) == expected_span
        ]

In [None]:
def prepare_test_loader(test_batch_size=1):
    """Unshuffled loader (2 workers) over the public test file.

    Each item is ``((inputs, last_input_timestamp), targets)``. The targets are
    empty because ``out_seq_len=0``.
    """
    # Build the path from PATH like the other loaders, instead of repeating the directory literal.
    test_dataset = RadarDataset([PATH + '2022-test-public.hdf5'], out_seq_len=0, with_time=True)
    return data.DataLoader(test_dataset, batch_size=test_batch_size, shuffle=False, num_workers=2)

In [None]:
# Predict on the public test set and write the submission file.
test_loader = prepare_test_loader()
model.to(device)
process_test(model=model, test_loader=test_loader, output_file='augmented_base.hdf5')

In [None]:
from IPython.display import FileLink
# Download link for the written submission file.
FileLink(path=r'augmented_base.hdf5')

In [None]:
import matplotlib.pyplot as plt

In [None]:
# `radial_velocity` and `timestamps` are not defined by any visible cell. When
# missing, load a short run of frames so the cell works on a fresh kernel.
# NOTE(review): file, frame count and level 0 are guesses — adjust to what was
# originally inspected.
if 'radial_velocity' not in globals() or 'timestamps' not in globals():
    with h5py.File(PATH + '2021-01-train.hdf5', 'r') as sample_file:
        timestamps = sorted(sample_file.keys())[:8]
        radial_velocity = [np.array(sample_file[ts]['radial_velocity'][0]) for ts in timestamps]

fig, axs = plt.subplots(1, len(radial_velocity), figsize=(20, 2))
# np.atleast_1d: plt.subplots returns a bare Axes (not an array) for one panel.
for ax, image, title in zip(np.atleast_1d(axs), radial_velocity, timestamps):
    ax.imshow(image)
    ax.set_title(title)

In [None]:
import torch
import torch.nn as nn
import numpy as np

# Toy check: apply a Conv2d to one unbatched 10-channel 252x252 image and
# reduce the result to one value per output channel.
input_tensor = torch.randn(10, 252, 252)

# kernel_size=3 with no padding: an unbatched (10, 252, 252) input gives (32, 250, 250).
conv_layer = nn.Conv2d(in_channels=10, out_channels=32, kernel_size=3)
output_tensor = conv_layer(input_tensor)
output_array = output_tensor.detach().numpy()

# np.reshape(output_array, (32,)) cannot work: the array has 32 * 250 * 250
# elements. Average over the spatial axes to get one value per channel instead.
reshaped_array = output_array.mean(axis=(-2, -1))

print(reshaped_array)

In [None]:
    def forward(self, X, H_prev, C_prev):
        """One ConvLSTM step with separate convolutions for two equal channel halves.

        NOTE(review): detached method fragment. No visible class defines
        ``conv_int``/``conv_ev``, and as a top-level cell this indented ``def``
        does not parse.
        """
        x_int, x_ev = torch.chunk(X, chunks=2, dim=1)
        h_int, h_ev = torch.chunk(H_prev, chunks=2, dim=1)
        gates_int = torch.chunk(self.conv_int(torch.cat([x_int, h_int], dim=1)), chunks=4, dim=1)
        gates_ev = torch.chunk(self.conv_ev(torch.cat([x_ev, h_ev], dim=1)), chunks=4, dim=1)
        # For each gate (input, forget, candidate, output), join both paths channel-wise.
        i_pre, f_pre, c_pre, o_pre = (torch.cat([a, b], dim=1) for a, b in zip(gates_int, gates_ev))
        C = torch.sigmoid(f_pre) * C_prev + torch.sigmoid(i_pre) * self.activation(c_pre)
        H = torch.sigmoid(o_pre) * self.activation(C)
        return H, C

In [None]:
class ConvLSTMCell(nn.Module):
    """ConvLSTM cell with three convolution paths: intensity, wind and reflectivity.

    Input channels are split as ``[intensity | wind | reflectivity]``. ``H`` and
    ``C`` are split into three equal blocks of ``out_channels``, one per path;
    each path's convolution yields that block's four gate pre-activations.

    NOTE(review): this cell redefines ``ConvLSTMCell`` (and the classes after
    it), replacing the two-path version defined earlier in the notebook.

    Args:
        in_channels: intensity input channels.
        out_channels: hidden channels per path; H and C have ``3 * out_channels``.
        kernel_size, padding: passed to every ``nn.Conv2d``.
        activation: ``'tanh'`` or ``'relu'``.
        wind_channels, refl_channels: input channels of the wind and
            reflectivity paths (default 10 each, the previously hard-coded sizes).

    Raises:
        ValueError: unsupported ``activation``.
    """

    _ACTIVATIONS = {'tanh': torch.tanh, 'relu': torch.relu}

    def __init__(self, in_channels, out_channels, kernel_size, padding, activation,
                 wind_channels=10, refl_channels=10):
        super().__init__()

        if activation not in self._ACTIVATIONS:
            raise ValueError(
                f'Unsupported activation {activation!r}; expected one of {sorted(self._ACTIVATIONS)}'
            )
        self.activation = self._ACTIVATIONS[activation]

        self.conv_wind = nn.Conv2d(
            in_channels=wind_channels + out_channels,
            out_channels=4 * out_channels,
            kernel_size=kernel_size,
            padding=padding
        )

        self.conv_refl = nn.Conv2d(
            in_channels=refl_channels + out_channels,
            out_channels=4 * out_channels,
            kernel_size=kernel_size,
            padding=padding
        )

        self.conv_int = nn.Conv2d(
            in_channels=in_channels + out_channels,
            out_channels=4 * out_channels,
            kernel_size=kernel_size,
            padding=padding
        )

    def forward(self, X, H_prev, C_prev):
        """One time step; returns the new hidden and cell states ``(H, C)``."""
        # Split sizes come from the conv layers themselves (previously the state
        # was sliced with a hard-coded 32), so any out_channels works, including
        # for cells unpickled from older checkpoints.
        hidden = self.conv_int.out_channels // 4
        n_int = self.conv_int.in_channels - hidden
        n_wind = self.conv_wind.in_channels - hidden
        int_X = X[:, :n_int]
        wind_X = X[:, n_int:n_int + n_wind]
        refl_X = X[:, n_int + n_wind:]
        int_H_prev, wind_H_prev, refl_H_prev = torch.split(H_prev, hidden, dim=1)

        conv_int_output = self.conv_int(torch.cat([int_X, int_H_prev], dim=1))
        conv_wind_output = self.conv_wind(torch.cat([wind_X, wind_H_prev], dim=1))
        conv_refl_output = self.conv_refl(torch.cat([refl_X, refl_H_prev], dim=1))

        i_conv_int, f_conv_int, C_conv_int, o_conv_int = torch.chunk(conv_int_output, chunks=4, dim=1)
        i_conv_wind, f_conv_wind, C_conv_wind, o_conv_wind = torch.chunk(conv_wind_output, chunks=4, dim=1)
        i_conv_refl, f_conv_refl, C_conv_refl, o_conv_refl = torch.chunk(conv_refl_output, chunks=4, dim=1)

        input_gate = torch.sigmoid(torch.cat([i_conv_int, i_conv_wind, i_conv_refl], dim=1))
        forget_gate = torch.sigmoid(torch.cat([f_conv_int, f_conv_wind, f_conv_refl], dim=1))
        output_gate = torch.sigmoid(torch.cat([o_conv_int, o_conv_wind, o_conv_refl], dim=1))
        C_conv = torch.cat([C_conv_int, C_conv_wind, C_conv_refl], dim=1)
        C = forget_gate * C_prev + input_gate * self.activation(C_conv)
        H = output_gate * self.activation(C)
        return H, C


class ConvLSTM(nn.Module):
    """Unroll the three-path ConvLSTMCell over the time axis.

    ``X`` is indexed as ``(batch, time, channels, height, width)``. The result
    stacks the hidden state after every step:
    ``(batch, time, 3 * out_channels, height, width)``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, padding, activation):
        super().__init__()
        self.out_channels = out_channels
        self.convLSTMCell = ConvLSTMCell(in_channels, out_channels, kernel_size, padding, activation)

    def forward(self, X):
        batch_size, seq_len, _, height, width = X.size()
        # Allocate on X's device/dtype rather than "cuda if available", so a
        # CPU-resident model still works on a machine that has a GPU.
        state_shape = (batch_size, 3 * self.out_channels, height, width)
        output = X.new_zeros((batch_size, seq_len) + state_shape[1:])
        H = X.new_zeros(state_shape)
        C = X.new_zeros(state_shape)
        for time_step in range(seq_len):
            H, C = self.convLSTMCell(X[:, time_step], H, C)
            output[:, time_step] = H
        return output


class Seq2Seq(nn.Module):
    """Encoder-forecaster built from stacked three-path ConvLSTM layers.

    The observed frames are followed by ``out_seq_len - 1`` all-zero frames.
    The ConvLSTM stack runs over the padded sequence, and a final convolution
    maps the hidden state at each of the last ``out_seq_len`` steps to one
    predicted frame.

    Output: ``(batch, out_seq_len, num_channels, height, width)``.
    """

    def __init__(
        self, num_channels, num_kernels, kernel_size, padding, activation, num_layers, out_seq_len
    ):
        super().__init__()
        self.out_seq_len = out_seq_len

        # Not referenced by forward; kept so the module's attributes are unchanged.
        self.activation = torch.relu
        self.sequential = nn.Sequential()
        # Layer 1 reads the raw channels, later layers read num_kernels.
        # Layer 1 is always created, even when num_layers < 1.
        for layer_index in range(1, max(num_layers, 1) + 1):
            self.sequential.add_module(
                f'convlstm{layer_index}',
                ConvLSTM(
                    in_channels=num_channels if layer_index == 1 else num_kernels,
                    out_channels=num_kernels,
                    kernel_size=kernel_size,
                    padding=padding,
                    activation=activation,
                ),
            )
        # Each three-path ConvLSTM layer emits 3 * num_kernels hidden channels.
        self.conv = nn.Conv2d(
            in_channels=3 * num_kernels,
            out_channels=num_channels,
            kernel_size=kernel_size,
            padding=padding,
        )

    def forward(self, X):
        batch_size, seq_len, num_channels, height, width = X.size()
        horizon = self.out_seq_len
        # Observed frames followed by zero placeholders for the future steps.
        padded = torch.zeros(
            batch_size, seq_len + horizon - 1, num_channels, height, width,
            device=self.conv.weight.device
        )
        padded[:, :seq_len] = X
        hidden = self.sequential(padded)
        predictions = []
        for step in range(seq_len - 1, seq_len - 1 + horizon):
            predictions.append(self.conv(hidden[:, step]))
        return torch.stack(predictions, dim=1)


class ConvLSTMModel(L.LightningModule):
    """Lightning wrapper around a single-layer three-path ``Seq2Seq`` (12 steps, 1 output channel).

    NOTE(review): redefines the ConvLSTMModel declared earlier in the notebook.

    Args:
        lr: Adam learning rate (default 3e-4, this cell's previously hard-coded
            value). It is stored as ``current_lr`` and read in
            ``configure_optimizers``, so assigning ``model.current_lr`` before
            ``trainer.fit`` takes effect.
    """

    def __init__(self, lr=3e-4):
        super().__init__()
        self.current_lr = lr
        self.model = Seq2Seq(
            num_channels=1,
            num_kernels=32,
            kernel_size=(3, 3),
            padding=(1, 1),
            activation='relu',
            num_layers=1,
            out_seq_len=12
        )

    def forward(self, x):
        # Move the batch to wherever the weights live (not blindly to CUDA).
        x = x.to(device=next(self.parameters()).device)
        return self.model(x)

    def training_step(self, batch):
        x, y = batch
        out = self.forward(x)
        # Forcing predictions to -1 where the target is -1 zeroes the error
        # (and gradient) there; those pixels still count in the mean.
        out[y == -1] = -1
        loss = F.mse_loss(out, y)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        # getattr keeps models pickled before `current_lr` existed working.
        lr = getattr(self, 'current_lr', 3e-4)
        return torch.optim.Adam(self.parameters(), lr=lr)

In [None]:
# Disabled experiment: a rotation-augmented body for RadarDataset.__getitem__
# that reads intensity plus radial_velocity[0..9] and reflectivity[0..9].
# It is kept as an inert string literal; running the cell only displays it.
'''
                
                targets.append(np.array(torch.rot90(torch.tensor(np.array(d[timestamp]['intensity'])), k=self.rotate, dims=(0, 1))))
                data.append(np.array([np.array(torch.rot90(torch.tensor(np.array(d[timestamp]['intensity'])), k=self.rotate, dims=(0, 1))),
                                      np.array(torch.rot90(torch.tensor(np.array(d[timestamp]['radial_velocity'][0])), k=self.rotate, dims=(0, 1))),
                                      np.array(torch.rot90(torch.tensor(np.array(d[timestamp]['radial_velocity'][1])), k=self.rotate, dims=(0, 1))),
                                      np.array(torch.rot90(torch.tensor(np.array(d[timestamp]['radial_velocity'][2])), k=self.rotate, dims=(0, 1))),
                                      np.array(torch.rot90(torch.tensor(np.array(d[timestamp]['radial_velocity'][3])), k=self.rotate, dims=(0, 1))),
                                      np.array(torch.rot90(torch.tensor(np.array(d[timestamp]['radial_velocity'][4])), k=self.rotate, dims=(0, 1))),
                                      np.array(torch.rot90(torch.tensor(np.array(d[timestamp]['radial_velocity'][5])), k=self.rotate, dims=(0, 1))),
                                      np.array(torch.rot90(torch.tensor(np.array(d[timestamp]['radial_velocity'][6])), k=self.rotate, dims=(0, 1))),
                                      np.array(torch.rot90(torch.tensor(np.array(d[timestamp]['radial_velocity'][7])), k=self.rotate, dims=(0, 1))),
                                      np.array(torch.rot90(torch.tensor(np.array(d[timestamp]['radial_velocity'][8])), k=self.rotate, dims=(0, 1))),
                                      np.array(torch.rot90(torch.tensor(np.array(d[timestamp]['radial_velocity'][9])), k=self.rotate, dims=(0, 1))),
                                      np.array(torch.rot90(torch.tensor(np.array(d[timestamp]['reflectivity'][0])), k=self.rotate, dims=(0, 1))),
                                      np.array(torch.rot90(torch.tensor(np.array(d[timestamp]['reflectivity'][1])), k=self.rotate, dims=(0, 1))),
                                      np.array(torch.rot90(torch.tensor(np.array(d[timestamp]['reflectivity'][2])), k=self.rotate, dims=(0, 1))),
                                      np.array(torch.rot90(torch.tensor(np.array(d[timestamp]['reflectivity'][3])), k=self.rotate, dims=(0, 1))),
                                      np.array(torch.rot90(torch.tensor(np.array(d[timestamp]['reflectivity'][4])), k=self.rotate, dims=(0, 1))),
                                      np.array(torch.rot90(torch.tensor(np.array(d[timestamp]['reflectivity'][5])), k=self.rotate, dims=(0, 1))),
                                      np.array(torch.rot90(torch.tensor(np.array(d[timestamp]['reflectivity'][6])), k=self.rotate, dims=(0, 1))),
                                      np.array(torch.rot90(torch.tensor(np.array(d[timestamp]['reflectivity'][7])), k=self.rotate, dims=(0, 1))),
                                      np.array(torch.rot90(torch.tensor(np.array(d[timestamp]['reflectivity'][8])), k=self.rotate, dims=(0, 1))),
                                      np.array(torch.rot90(torch.tensor(np.array(d[timestamp]['reflectivity'][9])), k=self.rotate, dims=(0, 1)))]))
                '''