In [2]:
import torch
import torch.nn as nn
from math import sqrt
import os
# from ConvLSTM import ConvLSTMCell -> ConvLSTM.py
import numpy as np

## Model Architecture (UNet based)

![image](https://user-images.githubusercontent.com/44194558/145318807-8818aa85-6bab-40ec-b9cd-9fe91d3bff7b.png)

1. 4개의 이미지를 입력으로 받음 (t-3, t-2, t-1, t)

2. Generator1은 각각의 이미지를 입력으로 받아 LSTM1, LSTM2의 입력을 출력 (out1s, out2s)

3. LSTM1, LSTM2는 2의 출력을 각각 입력으로 받음

 - LSTM1은 out1 4개를 입력으로 받아 output1을 출력 
 - LSTM2는 out2 4개를 입력으로 받아 output2를 출력

 - output1, 2는 4개 시점의 time information이 모두 고려된 단일 출력

4. Generator2는 output1, output2를 입력으로 받아 예측 이미지 생성 (t-2_hat, t-1_hat, t_hat, t+1_hat)

### Convolution

In [None]:
# input에 2D Convolution 적용
# (N, C_in, H, W) -> (N, C_out, H_out, W_out)
# Output size = ((W - Kernel size + 2*Padding size) / Strides) + 1

# conv 3x3은 zero padding을 사용하여, 연산 후에도 입력의 원본 사이즈를 계속 유지 (파란색 화살표)
def conv3x3(in_channels, out_channels):
    """Build a size-preserving 3x3 convolution layer.

    With kernel 3, stride 1 and one pixel of zero padding the output keeps
    the input's height and width; only the channel count changes.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by the convolution.

    Returns:
        A new ``nn.Conv2d`` layer with a bias term.
    """
    layer = nn.Conv2d(
        in_channels,
        out_channels,
        3,          # kernel_size
        stride=1,
        padding=1,  # zero padding of 1 keeps H and W unchanged
        bias=True,
    )
    return layer

# conv 2x2는 zero padding을 사용하지 않기 때문에 연산 후 입력의 사이즈가 감소, 채널 증가 (주황색 화살표)    
def conv2x2(in_channels, out_channels):
    """Build a downsampling 2x2 convolution layer.

    Kernel 2 with stride 2 and no padding halves the height and width of
    the input (for even sizes); the channel count is set by the caller.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by the convolution.

    Returns:
        A new ``nn.Conv2d`` layer with a bias term.
    """
    layer = nn.Conv2d(
        in_channels,
        out_channels,
        2,          # kernel_size
        stride=2,   # non-overlapping windows -> output is half the input size
        padding=0,
        bias=True,
    )
    return layer

def conv1x1(in_channels, out_channels):
    """Build a pointwise (1x1) convolution layer.

    Kernel 1, stride 1 and no padding leave the spatial size untouched and
    only remap the channels.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by the convolution.

    Returns:
        A new ``nn.Conv2d`` layer with a bias term.
    """
    layer = nn.Conv2d(
        in_channels,
        out_channels,
        1,          # kernel_size
        stride=1,
        padding=0,
        bias=True,
    )
    return layer

### Generators

In [5]:
# Generating the inputs of LSTM1 & LSTM2
class Generator1(nn.Module):
    """Encoder half of the UNet-style generator.

    A single image is pushed through four resolution stages. Each stage is
    three convolution + LeakyReLU(0.1) pairs; stages 2-4 start with a
    stride-2 ``conv2x2`` that halves the spatial size while the channel
    count doubles (16 -> 32 -> 64 -> 128).

    ``forward`` returns two feature maps:
        out1 -- activation at the end of stage 2 (32 channels, 1/2 size),
                used as the input of LSTM1.
        out2 -- activation at the end of stage 4 (128 channels, 1/8 size),
                used as the input of LSTM2.

    Args:
        in_channels: Number of channels of the input image.
    """

    def __init__(self, in_channels):
        super().__init__()

        self.in_channels = in_channels

        # Stage 1: full resolution, 16 channels (3x3 convs keep H and W).
        self.conv3_1_1 = conv3x3(self.in_channels, 16)
        self.relu1_1 = nn.LeakyReLU(negative_slope=0.1)
        self.conv3_1_2 = conv3x3(16, 16)
        self.relu1_2 = nn.LeakyReLU(negative_slope=0.1)
        self.conv3_1_3 = conv3x3(16, 16)
        self.relu1_3 = nn.LeakyReLU(negative_slope=0.1)

        # Stage 2: 1/2 resolution, 32 channels (conv2x2 downsamples).
        self.conv2_2_1 = conv2x2(16, 32)
        self.relu2_1 = nn.LeakyReLU(negative_slope=0.1)
        self.conv3_2_1 = conv3x3(32, 32)
        self.relu2_2 = nn.LeakyReLU(negative_slope=0.1)
        self.conv3_2_2 = conv3x3(32, 32)
        self.relu2_3 = nn.LeakyReLU(negative_slope=0.1)

        # Stage 3: 1/4 resolution, 64 channels.
        self.conv2_3_1 = conv2x2(32, 64)
        self.relu3_1 = nn.LeakyReLU(negative_slope=0.1)
        self.conv3_3_1 = conv3x3(64, 64)
        self.relu3_2 = nn.LeakyReLU(negative_slope=0.1)
        self.conv3_3_2 = conv3x3(64, 64)
        self.relu3_3 = nn.LeakyReLU(negative_slope=0.1)

        # Stage 4: 1/8 resolution, 128 channels.
        self.conv2_4_1 = conv2x2(64, 128)
        self.relu4_1 = nn.LeakyReLU(negative_slope=0.1)
        self.conv3_4_1 = conv3x3(128, 128)
        self.relu4_2 = nn.LeakyReLU(negative_slope=0.1)
        self.conv3_4_2 = conv3x3(128, 128)
        self.relu4_3 = nn.LeakyReLU(negative_slope=0.1)

    def forward(self, x):
        """Encode one image and return ``(out1, out2)``.

        For a 128x128 input, out1 is (N, 32, 64, 64) and out2 is
        (N, 128, 16, 16).
        """
        # Stage 1: full resolution, 16 channels.
        feat = self.relu1_1(self.conv3_1_1(x))
        feat = self.relu1_2(self.conv3_1_2(feat))
        feat = self.relu1_3(self.conv3_1_3(feat))

        # Stage 2: downsample to 1/2 size, 32 channels -> input of LSTM1.
        feat = self.relu2_1(self.conv2_2_1(feat))
        feat = self.relu2_2(self.conv3_2_1(feat))
        out1 = self.relu2_3(self.conv3_2_2(feat))

        # Stage 3: downsample to 1/4 size, 64 channels.
        feat = self.relu3_1(self.conv2_3_1(out1))
        feat = self.relu3_2(self.conv3_3_1(feat))
        feat = self.relu3_3(self.conv3_3_2(feat))

        # Stage 4: downsample to 1/8 size, 128 channels -> input of LSTM2.
        feat = self.relu4_1(self.conv2_4_1(feat))
        feat = self.relu4_2(self.conv3_4_1(feat))
        out2 = self.relu4_3(self.conv3_4_2(feat))

        return out1, out2

In [None]:
class Generator2(nn.Module):
    """Decoder half of the UNet-style generator.

    Takes the two LSTM outputs and upsamples back to the input resolution
    with three PixelShuffle(2) steps. Each shuffle doubles H and W and
    divides the channel count by four. The 1/2-size feature map ``x1`` is
    concatenated as a skip connection before the last upsampling stage.

    Args:
        out_channels: Number of channels of the generated image.
    """

    def __init__(self, out_channels):
        super().__init__()

        self.out_channels = out_channels
        # (N, C * 4, H, W) -> (N, C, 2H, 2W); parameter-free, so it is shared.
        self.PS = nn.PixelShuffle(2)

        # Stage 4: 1/8 resolution, 128 -> 256 channels (256 / 4 = 64 after shuffle).
        self.conv3_4_3 = conv3x3(128, 128)
        self.relu4_6 = nn.LeakyReLU(negative_slope=0.1)
        self.conv3_4_4 = conv3x3(128, 256)
        self.relu4_7 = nn.LeakyReLU(negative_slope=0.1)

        # Stage 5: 1/4 resolution, 64 -> 128 channels (128 / 4 = 32 after shuffle).
        self.conv3_5_1 = conv3x3(64, 64)
        self.relu5_1 = nn.LeakyReLU(negative_slope=0.1)
        self.conv3_5_2 = conv3x3(64, 128)
        self.relu5_2 = nn.LeakyReLU(negative_slope=0.1)

        # Stage 6: 1/2 resolution, 64 channels (32 skip + 32 upsampled).
        self.conv3_6_1 = conv3x3(64, 64)
        self.relu6_1 = nn.LeakyReLU(negative_slope=0.1)
        self.conv3_6_2 = conv3x3(64, 64)
        self.relu6_2 = nn.LeakyReLU(negative_slope=0.1)

        # Stage 7: full resolution, 16 channels, then 1x1 projection to the image.
        self.conv3_7_1 = conv3x3(16, 16)
        self.relu7_1 = nn.LeakyReLU(negative_slope=0.1)
        self.conv3_7_2 = conv3x3(16, 16)
        self.relu7_2 = nn.LeakyReLU(negative_slope=0.1)
        self.conv1_7_1 = conv1x1(16, self.out_channels)

    def forward(self, x1, x2):
        """Generate an image from the two LSTM feature maps.

        Args:
            x1: 1/2-size feature map with 32 channels (LSTM1 output).
            x2: 1/8-size feature map with 128 channels (LSTM2 output).

        Returns:
            Tensor with ``out_channels`` channels at 8x the spatial size of x2.
        """
        # Stage 4: 1/8 size, expand to 256 channels.
        feat = self.relu4_6(self.conv3_4_3(x2))
        feat = self.relu4_7(self.conv3_4_4(feat))

        # Stage 5: pixel-shuffle up to 1/4 size (64 channels), expand to 128.
        feat = self.PS(feat)
        feat = self.relu5_1(self.conv3_5_1(feat))
        feat = self.relu5_2(self.conv3_5_2(feat))

        # Stage 6: pixel-shuffle up to 1/2 size (32 channels) and attach the
        # skip connection along the channel axis (32 + 32 = 64 channels).
        upsampled = self.PS(feat)
        feat = torch.cat((x1, upsampled), 1)
        feat = self.relu6_1(self.conv3_6_1(feat))
        feat = self.relu6_2(self.conv3_6_2(feat))

        # Stage 7: pixel-shuffle up to full size (16 channels), refine, project.
        feat = self.PS(feat)
        feat = self.relu7_1(self.conv3_7_1(feat))
        feat = self.relu7_2(self.conv3_7_2(feat))
        out = self.conv1_7_1(feat)  # no activation on the generated image

        return out

### Convolutional LSTM (ConvLSTM.py)

![image](https://user-images.githubusercontent.com/44194558/145341328-9d6b9d47-d997-419c-9cea-d2348f2ba01e.png)

![image](https://user-images.githubusercontent.com/44194558/145342215-05296e7a-ce7f-46dc-b298-0171abfd7388.png)


m=4 (4개 시점의 이미지)





**`ConvLSTMCell` parameters**

- `input_dim` (int): Number of channels of the input tensor.
- `hidden_dim` (int): Number of channels of the hidden state.
- `kernel_size` ((int, int)): Size of the convolutional kernel.
- `bias` (bool): Whether or not to add the bias.

```
# 코드로 형식 지정됨
```



In [7]:
import torch.nn as nn
import torch

class ConvLSTMCell(nn.Module):
    """Single convolutional LSTM cell.

    The input and the previous hidden state are concatenated along the
    channel axis and passed through one convolution that produces all four
    gate pre-activations at once.

    Parameters
    ----------
    input_dim: int
        Number of channels of input tensor.
    hidden_dim: int
        Number of channels of hidden state.
    kernel_size: (int, int)
        Size of the convolutional kernel.
    bias: bool
        Whether or not to add the bias.
    """

    def __init__(self, input_dim, hidden_dim, kernel_size, bias):

        super(ConvLSTMCell, self).__init__()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

        self.kernel_size = kernel_size
        # Half-kernel padding keeps H and W unchanged for odd kernel sizes.
        self.padding = kernel_size[0] // 2, kernel_size[1] // 2
        self.bias = bias

        # FIX: the original used `sef.input_dim` (NameError) and the keyword
        # names `in_channel` / `out_channel`, which nn.Conv2d does not accept.
        self.conv = nn.Conv2d(in_channels=self.input_dim + self.hidden_dim,
                              out_channels=4 * self.hidden_dim,  # 4 gates x hidden channels
                              kernel_size=self.kernel_size,
                              padding=self.padding,
                              bias=self.bias)

    def forward(self, input_tensor, cur_state):
        """Advance the cell by one time step.

        Args:
            input_tensor: Input at the current step, (N, input_dim, H, W).
            cur_state: Pair ``(h_cur, c_cur)``, each (N, hidden_dim, H, W).

        Returns:
            Tuple ``(h_next, c_next)``, each (N, hidden_dim, H, W).
        """
        h_cur, c_cur = cur_state

        # Current input and previous hidden state, joined on the channel axis.
        combined = torch.cat([input_tensor, h_cur], dim=1)

        combined_conv = self.conv(combined)  # (N, 4 * hidden_dim, H, W)

        # One hidden_dim-sized chunk per gate.
        cc_i, cc_f, cc_o, cc_g = torch.split(combined_conv, self.hidden_dim, dim=1)
        i = torch.sigmoid(cc_i)  # input gate
        f = torch.sigmoid(cc_f)  # forget gate
        o = torch.sigmoid(cc_o)  # output gate
        g = torch.tanh(cc_g)     # candidate cell state

        c_next = f * c_cur + i * g
        h_next = o * torch.tanh(c_next)

        return h_next, c_next

    def init_hidden(self, batch_size, image_size):
        """Return zero-filled ``(h, c)`` states on the same device as the weights.

        Args:
            batch_size: Number of samples N.
            image_size: Pair ``(height, width)`` of the feature map.
        """
        height, width = image_size
        device = self.conv.weight.device
        return (torch.zeros(batch_size, self.hidden_dim, height, width, device=device),
                torch.zeros(batch_size, self.hidden_dim, height, width, device=device))

## Model (testblock)

In [None]:
#  model = testblock(3, 3, 4, 128) - hscnn_main_output4.py
class testblock(nn.Module):
    """Full model: Generator1 -> two ConvLSTM autoencoders -> Generator2.

    Four consecutive frames (t-3, t-2, t-1, t) are encoded by the shared
    ``Generator1``. The 1/2-size features go through the LSTM1
    encoder-decoder and the 1/8-size features through the LSTM2
    encoder-decoder. The shared ``Generator2`` then decodes each of the four
    decoder time steps into an image.

    Args:
        in_channels: Number of channels of each input image (e.g. 3 for RGB).
        out_channels: Number of channels of each generated image.
        time: Number of input time steps; stored only, ``forward`` always
            takes four frames.
        size: Input image size; stored only, not used by the computation.
    """

    def __init__(self, in_channels, out_channels, time, size):
        super(testblock, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.time = time
        self.size = size
        self.PS = nn.PixelShuffle(2)  # not used in forward; kept for compatibility

        self.G1 = Generator1(self.in_channels)   # -> out1 (32 channels), out2 (128 channels)
        self.G2 = Generator2(self.out_channels)  # -> generated image with out_channels channels

        # LSTM1: encoder-decoder on the 1/2-size features (32 channels).
        self.encoder_1_convlstm = ConvLSTMCell(input_dim=32, hidden_dim=32,
                                               kernel_size=(3, 3), bias=True)
        self.encoder_2_convlstm = ConvLSTMCell(input_dim=32, hidden_dim=32,
                                               kernel_size=(3, 3), bias=True)
        self.decoder_1_convlstm = ConvLSTMCell(input_dim=32, hidden_dim=32,
                                               kernel_size=(3, 3), bias=True)
        self.decoder_2_convlstm = ConvLSTMCell(input_dim=32, hidden_dim=32,
                                               kernel_size=(3, 3), bias=True)

        # LSTM2: encoder-decoder on the 1/8-size features (128 channels).
        self.encoder_1_convlstm_2 = ConvLSTMCell(input_dim=128, hidden_dim=128,
                                                 kernel_size=(3, 3), bias=True)
        self.encoder_2_convlstm_2 = ConvLSTMCell(input_dim=128, hidden_dim=128,
                                                 kernel_size=(3, 3), bias=True)
        self.decoder_1_convlstm_2 = ConvLSTMCell(input_dim=128, hidden_dim=128,
                                                 kernel_size=(3, 3), bias=True)
        self.decoder_2_convlstm_2 = ConvLSTMCell(input_dim=128, hidden_dim=128,
                                                 kernel_size=(3, 3), bias=True)

    def _run_autoencoder(self, cells, x, seq_len, states):
        """Run a two-layer ConvLSTM encoder followed by a two-layer decoder.

        Args:
            cells: ``(enc1, enc2, dec1, dec2)`` callables taking
                ``input_tensor`` and ``cur_state`` keyword arguments.
            x: Input sequence of shape (N, T, C, H, W).
            seq_len: Number of time steps to encode and to decode.
            states: Four ``(h, c)`` pairs, one per cell, in the same order.

        Returns:
            Decoder outputs stacked on a new time axis: (N, seq_len, C, H, W).
        """
        enc1, enc2, dec1, dec2 = cells
        (h_t, c_t), (h_t2, c_t2), (h_t3, c_t3), (h_t4, c_t4) = states

        # Encoder: each step consumes one frame and updates both layers.
        for t in range(seq_len):
            h_t, c_t = enc1(input_tensor=x[:, t], cur_state=[h_t, c_t])
            h_t2, c_t2 = enc2(input_tensor=h_t, cur_state=[h_t2, c_t2])

        # Final hidden state of the second encoder layer summarises all steps.
        encoder_vector = h_t2

        # Decoder: each prediction is fed back as the next decoder input.
        outputs = []
        for _ in range(seq_len):
            h_t3, c_t3 = dec1(input_tensor=encoder_vector, cur_state=[h_t3, c_t3])
            h_t4, c_t4 = dec2(input_tensor=h_t3, cur_state=[h_t4, c_t4])
            encoder_vector = h_t4
            outputs.append(h_t4)

        return torch.stack(outputs, 1)

    # FIX: this method was indented inside __init__ in the original, which made
    # it a dead local function and left `self.autoencoder1` undefined.
    def autoencoder1(self, x, seq_len, h_t, c_t, h_t2, c_t2, h_t3, c_t3, h_t4, c_t4):
        """Encoder-decoder of LSTM1 (1/2-size features); returns (N, seq_len, C, H, W)."""
        cells = (self.encoder_1_convlstm, self.encoder_2_convlstm,
                 self.decoder_1_convlstm, self.decoder_2_convlstm)
        states = ((h_t, c_t), (h_t2, c_t2), (h_t3, c_t3), (h_t4, c_t4))
        return self._run_autoencoder(cells, x, seq_len, states)

    def autoencoder2(self, x, seq_len, h_t, c_t, h_t2, c_t2, h_t3, c_t3, h_t4, c_t4):
        """Encoder-decoder of LSTM2 (1/8-size features); returns (N, seq_len, C, H, W)."""
        cells = (self.encoder_1_convlstm_2, self.encoder_2_convlstm_2,
                 self.decoder_1_convlstm_2, self.decoder_2_convlstm_2)
        states = ((h_t, c_t), (h_t2, c_t2), (h_t3, c_t3), (h_t4, c_t4))
        return self._run_autoencoder(cells, x, seq_len, states)

    def forward(self, x1, x2, x3, x4):
        """Generate four images from four consecutive input frames.

        Args:
            x1, x2, x3, x4: Frames at t-3, t-2, t-1 and t, each
                (N, in_channels, H, W). H and W should be divisible by 8 so
                the skip connection in Generator2 lines up.

        Returns:
            Tuple ``(out1, out2, out3, out4)``, one generated image per decoder
            step. The last one is the t+1 prediction; the others are used
            for the loss.
        """
        # Shared encoder applied to every frame.
        y1_1, y1_2 = self.G1(x1)
        y2_1, y2_2 = self.G1(x2)
        y3_1, y3_2 = self.G1(x3)
        y4_1, y4_2 = self.G1(x4)

        # ---- LSTM1 on the 1/2-size features ----
        # (N, C, H, W) x 4 -> (N, 4, C, H, W)
        stack1 = torch.stack((y1_1, y2_1, y3_1, y4_1), dim=1)
        b, seq_len, _, h, w = stack1.size()

        # FIX: the original never initialised these states before using them.
        h_t, c_t = self.encoder_1_convlstm.init_hidden(batch_size=b, image_size=(h, w))
        h_t2, c_t2 = self.encoder_2_convlstm.init_hidden(batch_size=b, image_size=(h, w))
        h_t3, c_t3 = self.decoder_1_convlstm.init_hidden(batch_size=b, image_size=(h, w))
        h_t4, c_t4 = self.decoder_2_convlstm.init_hidden(batch_size=b, image_size=(h, w))

        output1 = self.autoencoder1(stack1, seq_len, h_t, c_t, h_t2, c_t2, h_t3, c_t3, h_t4, c_t4)

        # Split along the time axis. Equivalent to the original
        # view(-1, 4 * 32, 64, 64) + channel slicing, but without hard-coding
        # the feature-map size.
        x1_1_dec, x2_1_dec, x3_1_dec, x4_1_dec = output1.unbind(dim=1)

        # ---- LSTM2 on the 1/8-size features ----
        stack2 = torch.stack((y1_2, y2_2, y3_2, y4_2), dim=1)
        b_2, seq_len_2, _, h_2, w_2 = stack2.size()

        h_t_2, c_t_2 = self.encoder_1_convlstm_2.init_hidden(batch_size=b_2, image_size=(h_2, w_2))
        h_t2_2, c_t2_2 = self.encoder_2_convlstm_2.init_hidden(batch_size=b_2, image_size=(h_2, w_2))
        h_t3_2, c_t3_2 = self.decoder_1_convlstm_2.init_hidden(batch_size=b_2, image_size=(h_2, w_2))
        h_t4_2, c_t4_2 = self.decoder_2_convlstm_2.init_hidden(batch_size=b_2, image_size=(h_2, w_2))

        output2 = self.autoencoder2(stack2, seq_len_2, h_t_2, c_t_2, h_t2_2, c_t2_2,
                                    h_t3_2, c_t3_2, h_t4_2, c_t4_2)

        # Each slice is (N, 128, H/8, W/8).
        x1_2_dec, x2_2_dec, x3_2_dec, x4_2_dec = output2.unbind(dim=1)

        # Shared decoder applied to every decoder time step.
        out1 = self.G2(x1_1_dec, x1_2_dec)
        out2 = self.G2(x2_1_dec, x2_2_dec)
        out3 = self.G2(x3_1_dec, x3_2_dec)
        out4 = self.G2(x4_1_dec, x4_2_dec)

        return out1, out2, out3, out4
