# In [None]:
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F

class MyConvAugmentedLSTMCell(nn.Module):
    """Pair of convolutional LSTM cells advanced together for one time step.

    One cell consumes the main input ``x``; a second cell with its own,
    independent weights (the ``*_cam`` convolutions) consumes ``cam``
    (presumably class-activation maps -- confirm with the caller). The two
    recurrences share neither state nor parameters.

    For every gate g in {i, f, c, o} there is an input-to-hidden convolution
    ``conv_g_xx`` (with bias) and a hidden-to-hidden convolution ``conv_g_hh``
    (without bias). These attribute names define the ``state_dict`` layout and
    must not be renamed, or existing checkpoints will stop loading.

    Args:
        input_size: channel count of ``x`` and of ``cam``.
        hidden_size: channel count of the hidden and cell states.
        kernel_size, stride, padding: forwarded to every ``nn.Conv2d``.
            The hidden-to-hidden convolutions are applied to the previous
            hidden state, so the recurrence is only shape-consistent when
            these preserve spatial size (e.g. stride=1 and
            padding=kernel_size // 2 for odd kernels).
    """

    def __init__(self, input_size, hidden_size, kernel_size=3, stride=1, padding=1):
        super(MyConvAugmentedLSTMCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding

        def conv(in_channels, bias=True):
            # Every gate convolution maps to hidden_size channels with the
            # same geometry; only the input channel count and bias differ.
            return nn.Conv2d(in_channels, hidden_size, kernel_size=kernel_size,
                             stride=stride, padding=padding, bias=bias)

        # Main-input branch. Registration order fixes the state_dict key
        # order, so it matches the historical layout exactly.
        self.conv_i_xx = conv(input_size)
        self.conv_i_hh = conv(hidden_size, bias=False)
        self.conv_f_xx = conv(input_size)
        self.conv_f_hh = conv(hidden_size, bias=False)
        self.conv_c_xx = conv(input_size)
        self.conv_c_hh = conv(hidden_size, bias=False)
        self.conv_o_xx = conv(input_size)
        self.conv_o_hh = conv(hidden_size, bias=False)

        # CAM branch: identical layout, independent weights.
        self.conv_i_xx_cam = conv(input_size)
        self.conv_i_hh_cam = conv(hidden_size, bias=False)
        self.conv_f_xx_cam = conv(input_size)
        self.conv_f_hh_cam = conv(hidden_size, bias=False)
        self.conv_c_xx_cam = conv(input_size)
        self.conv_c_hh_cam = conv(hidden_size, bias=False)
        self.conv_o_xx_cam = conv(input_size)
        self.conv_o_hh_cam = conv(hidden_size, bias=False)

        # Xavier-normal weights and zero biases for every convolution
        # (the only child modules of this cell are the convolutions above).
        for module in self.children():
            nn.init.xavier_normal_(module.weight)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)

    @staticmethod
    def _lstm_step(inp, state, convs):
        """Run one ConvLSTM update of ``inp`` from ``state`` with the given gate convs.

        ``convs`` is ((i_xx, i_hh), (f_xx, f_hh), (c_xx, c_hh), (o_xx, o_hh)).
        Returns the new ``(h, c)`` pair.
        """
        (i_xx, i_hh), (f_xx, f_hh), (c_xx, c_hh), (o_xx, o_hh) = convs
        x_i, x_f, x_c, x_o = i_xx(inp), f_xx(inp), c_xx(inp), o_xx(inp)
        if state is None:
            # Shape, dtype and device come from the gate pre-activations, so
            # the default state has hidden_size channels and lives wherever the
            # input lives. NOTE(review): the random (rather than zero) initial
            # state is kept for compatibility; it makes state=None calls
            # non-deterministic.
            state = (torch.randn_like(x_i), torch.randn_like(x_i))
        h_prev, c_prev = state
        i = torch.sigmoid(x_i + i_hh(h_prev))
        f = torch.sigmoid(x_f + f_hh(h_prev))
        c_tilde = torch.tanh(x_c + c_hh(h_prev))
        c = (c_tilde * i) + (c_prev * f)
        o = torch.sigmoid(x_o + o_hh(h_prev))
        h = o * torch.tanh(c)
        return h, c

    def forward(self, x, cam, state, statecam):
        """Advance both cells by one time step.

        Args:
            x: main input, expected 4-D (N, input_size, H, W).
            cam: auxiliary input with the same layout as ``x``.
            state: ``(h, c)`` of the main cell, or ``None`` to start from a
                randomly drawn state.
            statecam: ``(h, c)`` of the CAM cell, or ``None`` likewise.

        Returns:
            ``((h, c), (h_cam, c_cam))`` -- the updated states of both cells.
        """
        main_convs = ((self.conv_i_xx, self.conv_i_hh),
                      (self.conv_f_xx, self.conv_f_hh),
                      (self.conv_c_xx, self.conv_c_hh),
                      (self.conv_o_xx, self.conv_o_hh))
        cam_convs = ((self.conv_i_xx_cam, self.conv_i_hh_cam),
                     (self.conv_f_xx_cam, self.conv_f_hh_cam),
                     (self.conv_c_xx_cam, self.conv_c_hh_cam),
                     (self.conv_o_xx_cam, self.conv_o_hh_cam))
        main = self._lstm_step(x, state, main_convs)
        aux = self._lstm_step(cam, statecam, cam_convs)
        return main, aux