In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [75]:
conv = nn.Conv2d(
in_channels=1,
out_channels=8,
kernel_size=3,
stride=1,
padding=0)
pool = nn.MaxPool2d(
        kernel_size=3
        )
# convlstm1 = ConvLSTMCell(
#            shape=[self.size1,self.size1], 
#            input_channel=8, 
#            filter_size=3,
#            hidden_size=self.hidden_size)
# convlstm2 = ConvLSTMCell(
#            shape=[self.size1,self.size1], 
#            input_channel=self.hidden_size,
#            filter_size=3,
#            hidden_size=self.hidden_size)
deconv = nn.ConvTranspose2d(
           in_channels=64 , 
           out_channels=1, 
           kernel_size=6,
           stride=3,
           padding=0, 
           output_padding=1, 
           )

In [67]:
# Based on the tensorflow implementation by yunbo:
# https://github.com/Yunbo426/predrnn-pp/blob/master/layers/CausalLSTMCell.py

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

class CausalLSTMCell(nn.Module):

    def __init__(self, input_channels, layer_name, filter_size, num_hidden_in, num_hidden_out,
                 seq_shape, forget_bias=1.0, initializer=0.001):
        super(CausalLSTMCell,self).__init__()
        
        self.layer_name = layer_name
        self.filter_size = filter_size
        self.input_channels = input_channels
        self.num_hidden_in = num_hidden_in
        self.num_hidden = num_hidden_out
        self.batch = seq_shape[0]
        self.x_channels = 1
        self.height = seq_shape[2]
        self.width = seq_shape[3]
        # self.layer_norm = tln
        self._forget_bias = forget_bias

        #########
        #NOT SURE ABOUT INPUT CHANNELS
        #########
        self.conv_h = nn.Conv2d(in_channels=self.num_hidden, ###hidden state has similar spatial struture as inputs, we simply concatenate them on the feature dimension
                           out_channels=self.num_hidden*4, ##lstm has four gates
                           kernel_size=self.filter_size,
                           stride=1,
                           padding=1
        )

        self.conv_c = nn.Conv2d(in_channels=self.num_hidden, ###hidden state has similar spatial struture as inputs, we simply concatenate them on the feature dimension
                           out_channels=self.num_hidden*3, 
                           kernel_size=self.filter_size,
                           stride=1,
                           padding=1
        )

        self.conv_m = nn.Conv2d(in_channels=self.num_hidden_in, ###hidden state has similar spatial struture as inputs, we simply concatenate them on the feature dimension
                           out_channels=self.num_hidden*3, 
                           kernel_size=self.filter_size,
                           stride=1,
                           padding=1
        )

        self.conv_x = nn.Conv2d(in_channels=self.input_channels, ###hidden state has similar spatial struture as inputs, we simply concatenate them on the feature dimension
                           out_channels=self.num_hidden*7, 
                           kernel_size=self.filter_size,
                           stride=1,
                           padding=1)

        self.conv_o = nn.Conv2d(in_channels=self.num_hidden,out_channels=self.num_hidden, 
                           kernel_size=self.filter_size,
                           stride=1,
                           padding=1)
        
        self.conv_1_1 =  nn.Conv2d(in_channels=self.num_hidden*2,out_channels=self.num_hidden, 
                           kernel_size=1,
                           stride=1,
                           padding=0)

    def forward(self,x,h,c,m):
        if h is None:
            h = torch.zeros([self.batch,self.num_hidden,self.height,self.width])
        if c is None:
            c = torch.zeros([self.batch,self.num_hidden,self.height,self.width])
        if m is None:
            m = torch.zeros([self.batch,self.num_hidden_in,self.height,self.width])

        h_cc = self.conv_h(h)
        c_cc = self.conv_c(c)
        m_cc = self.conv_m(m)

        i_h, g_h, f_h, o_h = torch.chunk(h_cc, 4, dim=1)
        i_c, g_c, f_c = torch.chunk(c_cc, 3, dim=1)
        i_m, f_m, m_m = torch.chunk(m_cc, 3, dim=1)

        if x is None:
            i = torch.sigmoid(i_h + i_c)
            f = torch.sigmoid(f_h + f_c + self._forget_bias)
            g = torch.tanh(g_h + g_c)

        else:
            x_cc = self.conv_x(x)
            
            i_x, g_x, f_x, o_x, i_x_, g_x_, f_x_ = torch.chunk(x_cc, 7, dim = 1)
            print(i_x.shape,i_h.shape,i_c.shape)
            i = torch.sigmoid(i_x + i_h + i_c)
            f = torch.sigmoid(f_x + f_h + f_c + self._forget_bias)
            g = torch.tanh(g_x + g_h + g_c)
        
        c_new = f * c + i * g
        
        c2m = self.conv_h(c_new)

        i_c, g_c, f_c, o_c = torch.chunk(c2m, 4, dim=1)

        if x is None:
            ii = torch.sigmoid(i_c + i_m)
            ff = torch.sigmoid(f_c + f_m + self._forget_bias)
            gg = torch.tanh(g_c)
        else:
            ii = torch.sigmoid(i_c + i_x_ + i_m)
            ff = torch.sigmoid(f_c + f_x_ + f_m + self._forget_bias)
            gg = torch.tanh(g_c + g_x_)

        m_new = ff * torch.tanh(m_m) + ii * gg
        o_m = self.conv_o(m_new)        

        if x is None:
            o = torch.tanh(o_h + o_c + o_m)
        else:
            o = torch.tanh(o_x + o_h + o_c + o_m)
        
        cell = torch.cat((c_new, m_new),1)
        print(cell.shape)
        cell = self.conv_1_1(cell)
        print(cell.shape)
        h_new = o * torch.tanh(cell)

        return h_new, c_new, m_new

In [68]:
clstm = CausalLSTMCell(8,'karas',3,64,64,[5,1,32,32])

In [76]:
x = torch.randn([5,1,100,100])
out = conv(x)
out = pool(out)
out,_,_  = clstm(out,None,None,None) 
out= deconv(out)
out.shape

torch.Size([5, 64, 32, 32]) torch.Size([5, 64, 32, 32]) torch.Size([5, 64, 32, 32])
torch.Size([5, 128, 32, 32])
torch.Size([5, 64, 32, 32])


torch.Size([5, 1, 100, 100])

In [3]:
shape = [5,1,100,100]

In [4]:
cell = CausalLSTMCell(1,'axristo',3,64,192,shape,1.0)

In [5]:
x = torch.zeros(shape)

In [6]:
a = cell.forward(x,None,None,None)

torch.Size([5, 384, 100, 100])
torch.Size([5, 192, 100, 100])


In [7]:
c = torch.zeros([5,192,100,100])
m = torch.zeros([5,192,100,100])

In [8]:
torch.cat((c,m),1).shape

torch.Size([5, 384, 100, 100])

In [9]:
a.shape

AttributeError: 'tuple' object has no attribute 'shape'

In [12]:
# based on the tensorflow implementation by 'yunbo':
# https://github.com/Yunbo426/predrnn-pp/blob/master/nets/predrnn_pp.py

import torch
import torch.nn as nn
import sys
sys.path.append('..')
from model_architectures.pred_rnn_pp.CausalLSTMCell import CausalLSTMCell as clstm

In [46]:
# based on the tensorflow implementation by 'yunbo':
# https://github.com/Yunbo426/predrnn-pp/blob/master/nets/predrnn_pp.py

import torch
import torch.nn as nn
from model_architectures.pred_rnn_pp.CausalLSTMCell import CausalLSTMCell as clstm

class PredRNNPP(nn.Module):

    def __init__(self,input_shape,seq_input,seq_output,batch_size,num_hidden,device):
        super(PredRNNPP,self).__init__()

        self.seq_input = seq_input
        self.seq_output = seq_output
        self.seq_length = seq_input + seq_output
        self.device = device
        self.batch_size = batch_size
        self.input_shape = input_shape #this is the dimensionality of the frame
        self.num_hidden = num_hidden
        self.num_layers = len(num_hidden)

        self.lstm = []
        self.output_channels = 1

        self.conv = nn.Conv2d(in_channels=self.num_hidden[self.num_layers-1], ###hidden state has similar spatial struture as inputs, we simply concatenate them on the feature dimension
                           out_channels=self.output_channels, 
                           kernel_size=1,
                           stride=1,
                           padding=0)

        for i in range(self.num_layers):
            if i == 0:
                num_hidden_in = self.num_hidden[self.num_layers-1]
                input_channels = 1
            else:
                num_hidden_in = self.num_hidden[i-1]
                input_channels = self.num_hidden[i-1]

            new_cell = clstm(input_channels,'lstm_'+str(i+1),3,num_hidden_in,self.num_hidden[i],self.input_shape)
            self.lstm.append(new_cell)

        self.ghu = None
        

    def forward(self,x):
        
        cell = []
        hidden = []
        mem = None
        for i in range(self.num_layers):
            cell.append(None)
            hidden.append(None)
        output = []
        x_gen = None
        # x has shape B S H W
        for t in range(self.seq_length-1):
            if t < self.seq_input:
                inputs = x[:,t,:,:].unsqueeze(1)
            else:
                inputs = x_gen
            
            hidden[0], cell[0], mem = self.lstm[0].forward(inputs, hidden[0],cell[0], mem)
            #z_t = self.ghu(self.hidden[0], z_t)
            z_t = hidden[0]
            hidden[1],cell[1],mem = self.lstm[1](z_t, hidden[1], cell[1], mem)
            for i in range(2, self.num_layers):
                hidden[i], cell[i], mem = self.lstm[i](hidden[i-1], hidden[i], cell[i], mem)
                x_gen = self.conv(hidden[num_layers-1])
                output.append(x_gen)

        output = torch.stack(output)

        return output

In [47]:
X = torch.zeros([5,12,100,100])

In [48]:
num_hidden = [4,8,8,8]
predrnn = PredRNNPP(X.shape,12,10,5,num_hidden,'cpu')

In [49]:
predrnn.forward(X)

RuntimeError: Given groups=1, weight of size 56 1 3 3, expected input[5, 4, 100, 100] to have 1 channels, but got 4 channels instead