In [None]:
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
import numpy as np

# Prefer the GPU when one is available; otherwise fall back to the CPU so the
# notebook still runs end-to-end on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
class ConvRNNCell(nn.Module):
    """Single-step convolutional RNN cell.

    The input and the previous hidden state are concatenated along the
    channel axis, passed through one 2-D convolution, and squashed with tanh.

    Args:
        input_c: number of channels of the input tensor.
        hidden_size: number of channels of the hidden state.
        kernel_size: integer side length of the convolution kernel. Padding is
            ``kernel_size // 2``, which preserves the spatial size for odd
            kernel sizes.
        bias: whether the convolution has a bias term.
    """

    def __init__(self, input_c, hidden_size, kernel_size, bias):
        super(ConvRNNCell, self).__init__()
        self.input_c = input_c
        self.hidden_size = hidden_size
        self.conv = nn.Conv2d(
            in_channels=input_c + hidden_size,
            out_channels=hidden_size,
            kernel_size=kernel_size,
            padding=kernel_size // 2,
            bias=bias,
        )

    def forward(self, x, h_cur):
        """Advance the cell by one step.

        Args:
            x: input tensor of shape (batch, input_c, height, width).
            h_cur: previous hidden state of shape
                (batch, hidden_size, height, width), or ``None`` to start
                from an all-zero state.

        Returns:
            The next hidden state, with ``hidden_size`` channels.
        """
        batch, _, height, width = x.size()
        if h_cur is None:
            # Build the zero state to match the input's batch size, device and
            # dtype so the concatenation below works for any batch and does
            # not depend on a notebook-level global device.
            h_cur = torch.zeros(
                batch, self.hidden_size, height, width,
                device=x.device, dtype=x.dtype,
            )

        combined = torch.cat((x, h_cur), dim=1)  # concatenate along channel axis
        h_next = torch.tanh(self.conv(combined))

        return h_next

In [None]:
class ConvLSTMCell(nn.Module):
    """Single-step convolutional LSTM cell.

    One 2-D convolution over the channel-wise concatenation of the input and
    the previous hidden state produces all four gate pre-activations at once
    (input, forget, output, candidate), in that channel order.

    Args:
        input_c: number of channels of the input tensor.
        hidden_size: number of channels of the hidden and cell states.
        kernel_size: integer side length of the convolution kernel. Padding is
            ``kernel_size // 2``, which preserves the spatial size for odd
            kernel sizes.
        bias: whether the convolution has a bias term.
    """

    def __init__(self, input_c, hidden_size, kernel_size, bias):
        super(ConvLSTMCell, self).__init__()
        self.input_c = input_c
        self.hidden_size = hidden_size
        self.conv = nn.Conv2d(
            in_channels=input_c + hidden_size,
            out_channels=4 * hidden_size,
            kernel_size=kernel_size,
            padding=kernel_size // 2,
            bias=bias,
        )

    def forward(self, x, cur_state):
        """Advance the cell by one step.

        Args:
            x: input tensor of shape (batch, input_c, height, width).
            cur_state: tuple ``(h_cur, c_cur)`` of tensors shaped
                (batch, hidden_size, height, width), or ``None`` to start
                from all-zero hidden and cell states.

        Returns:
            Tuple ``(h_next, c_next)``.
        """
        batch, _, height, width = x.size()
        if cur_state is None:
            # Zero states follow the input's batch size, device and dtype so
            # that any batch works and no global device is required.
            state_shape = (batch, self.hidden_size, height, width)
            cur_state = (
                torch.zeros(state_shape, device=x.device, dtype=x.dtype),
                torch.zeros(state_shape, device=x.device, dtype=x.dtype),
            )

        h_cur, c_cur = cur_state
        combined = torch.cat((x, h_cur), dim=1)  # concatenate along channel axis
        combined_conv = self.conv(combined)
        # Split the 4 * hidden_size output channels into the four gates.
        cc_i, cc_f, cc_o, cc_g = combined_conv.chunk(4, 1)

        i = torch.sigmoid(cc_i)  # input gate
        f = torch.sigmoid(cc_f)  # forget gate
        o = torch.sigmoid(cc_o)  # output gate
        g = torch.tanh(cc_g)     # candidate cell update

        c_next = f * c_cur + i * g
        h_next = o * torch.tanh(c_next)

        next_state = (h_next, c_next)
        return next_state

In [None]:
class LSTMCell(nn.Module):
    """Fully-connected LSTM cell that processes one flattened sample per call.

    Note that the recurrent state is ordered ``(c, h)`` (cell state first),
    both when passed in and when returned.

    Args:
        flatten_dim: total number of elements of the input once flattened.
        hidden_size: size of the hidden and cell state vectors.
        bias: whether the two linear layers have bias terms.
    """

    def __init__(self, flatten_dim, hidden_size, bias):
        super(LSTMCell, self).__init__()

        self.flatten_dim = flatten_dim
        self.hidden_size = hidden_size
        self.bias = bias

        # Each layer emits the four gate pre-activations stacked together.
        self.i2h = nn.Linear(flatten_dim, 4 * hidden_size, bias=bias)
        self.h2h = nn.Linear(hidden_size, 4 * hidden_size, bias=bias)

    def forward(self, x, cur_state):
        """Advance the cell by one step.

        Args:
            x: input tensor of any shape holding ``flatten_dim`` elements; it
                is flattened to a single row of shape (1, flatten_dim).
            cur_state: tuple ``(c_cur, h_cur)`` of tensors shaped
                (1, hidden_size), or ``None`` to start from all-zero states.

        Returns:
            Tuple ``(c_next, h_next)``, each of shape (1, hidden_size).
        """
        # reshape (rather than view) also accepts non-contiguous inputs.
        x = x.reshape(1, -1)

        if cur_state is None:
            # Zero states live on the input's device/dtype, so the cell does
            # not depend on a notebook-level global device.
            cur_state = (
                torch.zeros(1, self.hidden_size, device=x.device, dtype=x.dtype),
                torch.zeros(1, self.hidden_size, device=x.device, dtype=x.dtype),
            )

        c_cur, h_cur = cur_state

        preact = self.i2h(x) + self.h2h(h_cur)
        ingate, forgetgate, cellgate, outgate = preact.chunk(4, 1)

        ingate = torch.sigmoid(ingate)
        forgetgate = torch.sigmoid(forgetgate)
        cellgate = torch.tanh(cellgate)
        outgate = torch.sigmoid(outgate)

        c_next = (forgetgate * c_cur) + (ingate * cellgate)
        h_next = outgate * torch.tanh(c_next)

        next_state = (c_next, h_next)

        return next_state
    

In [None]:
# Smoke test: push one random 512-channel 7x7 feature map through a ConvRNN
# cell (128 hidden channels, 3x3 kernel, with bias), starting from no state.
a = torch.rand(1, 512, 7, 7).float().to(device)

rn = ConvRNNCell(input_c=512, hidden_size=128, kernel_size=3, bias=True)
rn = rn.float().to(device)
h = rn(a, None)

In [None]:
# Shape of the hidden state produced by the ConvRNN cell.
h.shape

In [None]:
# Count the trainable scalars in the ConvRNN cell: multiply out each
# parameter tensor's shape and add the results.
model_parameters = (p for p in rn.parameters() if p.requires_grad)
params = sum(np.prod(p.size()) for p in model_parameters)
params

In [None]:
# Build a ConvLSTM cell (512 input channels, 128 hidden channels, 3x3 kernel,
# with bias) and count its trainable scalars.
ls = ConvLSTMCell(input_c=512, hidden_size=128, kernel_size=3, bias=True)
ls = ls.float().to(device)
model_parameters = (p for p in ls.parameters() if p.requires_grad)
params = sum(np.prod(p.size()) for p in model_parameters)
params