In [None]:
#torch
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
from torch.autograd import Variable
#os.environ['CUDA_VISIBLE_DEVICES'] = '0'

#torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models

#image
from PIL import Image
import cv2

#jupyter
from ipywidgets import FloatProgress
from IPython.display import display
from __future__ import print_function

#os
import os
import os.path as path
import glob

#math
import math
import numpy as np
import random


## ConvLSTM

#### LSTMCell

In [None]:
# Dummy tensors shared by the cell smoke tests below:
#   t      - random input feature map, 256 channels on a 6x6 grid
#   ht, ct - zero-filled hidden and cell states, 128 channels on a 6x6 grid
t = Variable(torch.rand(1, 256, 6, 6))
ht, ct = Variable(torch.zeros(1, 128, 6, 6)), Variable(torch.zeros(1, 128, 6, 6))

In [145]:
import torch.nn.functional as F

class ConvLSTMCell(nn.Module):
    """Convolutional LSTM cell with convolutional peephole connections.

    The input, the previous hidden state and the previous cell state are each
    passed through one 2-D convolution whose output channels hold the stacked
    gate pre-activations:

    * ``w_i`` (input)  -> ``4 * out_channels`` channels, ordered [i, f, g, o]
    * ``w_h`` (hidden) -> ``4 * out_channels`` channels, ordered [i, f, g, o]
    * ``w_c`` (cell)   -> ``3 * out_channels`` channels, ordered [i, f, o]

    Args:
        in_channels: number of channels of the input ``x``.
        out_channels: number of channels of the hidden and cell states.
        kernel_size: side length of the square convolution kernels.
        padding: padding passed to every convolution (default 1).
        stride: stride passed to every convolution (default 1).
        bias: if True, each of the three convolutions gets its own learnable
            bias (``bias_i``, ``bias_h``, ``bias_c``); otherwise these are
            registered as ``None``.

    Note:
        The gate sums in ``forward`` add the three convolution outputs
        element-wise, so ``padding``/``stride``/``kernel_size`` must be chosen
        such that the convolutions preserve the spatial size of the states
        (e.g. ``kernel_size=3, padding=1, stride=1``).
    """

    def __init__(self, in_channels, out_channels, kernel_size, padding=1, stride=1, bias=False):
        super(ConvLSTMCell, self).__init__()

        self.k = kernel_size
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.padding = padding
        self.stride = stride

        # Uninitialised storage; filled by reset_parameters() below.
        self.w_i = nn.Parameter(torch.Tensor(4 * out_channels, in_channels, kernel_size, kernel_size))
        self.w_h = nn.Parameter(torch.Tensor(4 * out_channels, out_channels, kernel_size, kernel_size))
        self.w_c = nn.Parameter(torch.Tensor(3 * out_channels, out_channels, kernel_size, kernel_size))

        self.bias = bias
        if bias:
            # nn.Parameter (not a bare `Parameter`, which is never imported in
            # this notebook and made bias=True fail with a NameError).
            self.bias_i = nn.Parameter(torch.Tensor(4 * out_channels))
            self.bias_h = nn.Parameter(torch.Tensor(4 * out_channels))
            self.bias_c = nn.Parameter(torch.Tensor(3 * out_channels))
        else:
            self.register_parameter('bias_i', None)
            self.register_parameter('bias_h', None)
            self.register_parameter('bias_c', None)

        # Not read by forward(); retained so the module's state_dict keys are
        # unchanged.
        self.register_buffer('wc_blank', torch.zeros(out_channels))
        self.reset_parameters()

    def reset_parameters(self):
        """Draw every weight (and bias, if enabled) from U(-stdv, stdv).

        ``stdv = 1 / sqrt(4 * in_channels * k * k)`` is used for all tensors.
        """
        n = 4 * self.in_channels * self.k * self.k
        stdv = 1. / math.sqrt(n)

        self.w_i.data.uniform_(-stdv, stdv)
        self.w_h.data.uniform_(-stdv, stdv)
        self.w_c.data.uniform_(-stdv, stdv)

        if self.bias:
            self.bias_i.data.uniform_(-stdv, stdv)
            self.bias_h.data.uniform_(-stdv, stdv)
            self.bias_c.data.uniform_(-stdv, stdv)

    def forward(self, x, hx):
        """Advance the cell by one time step.

        Args:
            x: input tensor with ``in_channels`` channels on dim 1.
            hx: tuple ``(h, c)`` of previous hidden and cell states, each with
                ``out_channels`` channels on dim 1.

        Returns:
            ``(h_t, (h_t, c_t))``: the new hidden state, followed by the new
            ``(hidden, cell)`` state tuple to feed into the next step.
        """
        h, c = hx
        n = self.out_channels

        wx = F.conv2d(x, self.w_i, self.bias_i, padding=self.padding, stride=self.stride)
        wh = F.conv2d(h, self.w_h, self.bias_h, padding=self.padding, stride=self.stride)
        wc = F.conv2d(c, self.w_c, self.bias_c, padding=self.padding, stride=self.stride)

        # Input and forget gates see the previous cell state through wc.
        i = torch.sigmoid(wx[:, :n] + wh[:, :n] + wc[:, :n])
        f = torch.sigmoid(wx[:, n:2 * n] + wh[:, n:2 * n] + wc[:, n:2 * n])
        # Candidate values have no peephole term.
        g = torch.tanh(wx[:, 2 * n:3 * n] + wh[:, 2 * n:3 * n])

        c_t = f * c + i * g
        # NOTE(review): the output-gate peephole is conv(c, w_c) scaled
        # element-wise by the new cell state c_t. Confirm this is intended
        # rather than a convolution of c_t itself.
        o_t = torch.sigmoid(wx[:, 3 * n:] + wh[:, 3 * n:] + wc[:, 2 * n:] * c_t)
        h_t = o_t * torch.tanh(c_t)

        return h_t, (h_t, c_t)

In [146]:
# Test convLSTM Cell: a single step on the dummy tensors must yield a hidden
# state of size (1, 128, 6, 6).
c = ConvLSTMCell(256, 128, 3)
o = c(t, (ht, ct))
print(torch.Size([1, 128, 6, 6]) == o[0].size())

True


In [None]:
from torch.nn import init

class ConvGRUCell(nn.Module):
    """
    Generate a convolutional GRU cell

    Each gate is a 2-D convolution over the channel-wise concatenation of the
    input and the (possibly reset-scaled) previous state.

    Args:
        input_size: number of channels of the input.
        hidden_size: number of channels of the hidden state.
        kernel_size: side length of the square gate kernels. Padding is
            ``kernel_size // 2``, so odd kernel sizes preserve the spatial
            size of the state.
    """

    def __init__(self, input_size, hidden_size, kernel_size):
        super(ConvGRUCell, self).__init__()
        padding = kernel_size // 2
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.reset_gate = nn.Conv2d(input_size + hidden_size, hidden_size, kernel_size, padding=padding)
        self.update_gate = nn.Conv2d(input_size + hidden_size, hidden_size, kernel_size, padding=padding)
        self.out_gate = nn.Conv2d(input_size + hidden_size, hidden_size, kernel_size, padding=padding)

        # Prefer the in-place initialisers; fall back to the old names on
        # PyTorch releases that predate them.
        orthogonal = getattr(init, 'orthogonal_', None) or init.orthogonal
        constant = getattr(init, 'constant_', None) or init.constant

        gates = (self.reset_gate, self.update_gate, self.out_gate)
        for gate in gates:
            orthogonal(gate.weight)
        for gate in gates:
            constant(gate.bias, 0.)

    def forward(self, input_, prev_state):
        """Advance the cell by one time step.

        Args:
            input_: input tensor, laid out as [batch, channel, height, width].
            prev_state: previous hidden state with ``hidden_size`` channels,
                or None to start from an all-zero state.

        Returns:
            The new hidden state, with the same size as ``prev_state``.
        """
        # generate empty prev_state, if None is provided
        if prev_state is None:
            batch_size = input_.size(0)
            spatial_size = input_.size()[2:]
            state_size = [batch_size, self.hidden_size] + list(spatial_size)
            # Allocate from the input so the state shares its dtype and device,
            # instead of moving to CUDA whenever CUDA merely exists.
            prev_state = Variable(input_.data.new(*state_size).zero_())

        # data size is [batch, channel, height, width]
        stacked_inputs = torch.cat([input_, prev_state], dim=1)
        update = torch.sigmoid(self.update_gate(stacked_inputs))
        reset = torch.sigmoid(self.reset_gate(stacked_inputs))
        out_inputs = torch.tanh(self.out_gate(torch.cat([input_, prev_state * reset], dim=1)))
        new_state = prev_state * (1 - update) + out_inputs * update

        return new_state

In [147]:
# Test GruCell: a single step from a zero state must yield a hidden state of
# size (1, 128, 6, 6).
c = ConvGRUCell(256, 128, 3)
o = c(t, Variable(torch.zeros(1, 128, 6, 6)))
print(torch.Size([1, 128, 6, 6]) == o.size())

True


#### ConvRNN

In [169]:
class convRNN_1_layer(nn.Module):
    """
        Define a RNN with 1 recurrent layer
        args : r_type : lstm | gru

        Each element of the input sequence goes through the convolutional
        ``features`` stack (3 -> 256 channels) and then through one recurrent
        cell with 128 hidden channels.

        raises : ValueError if r_type is neither "lstm" nor "gru"
    """
    _FEATURE_CHANNELS = 256
    _HIDDEN_CHANNELS = 128

    def __init__(self, r_type="lstm"):
        super(convRNN_1_layer, self).__init__()
        # Fail early with a clear error; returning a value from __init__
        # would only surface as an unrelated TypeError.
        if r_type not in ("lstm", "gru"):
            raise ValueError("Error : r_type must be 'lstm' or 'gru', got %r" % (r_type,))

        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )

        self.r_type = r_type
        if r_type == "lstm":
            self.convRNN = ConvLSTMCell(self._FEATURE_CHANNELS, self._HIDDEN_CHANNELS,
                                        kernel_size=3, padding=1, stride=1)
        else:
            self.convRNN = ConvGRUCell(self._FEATURE_CHANNELS, self._HIDDEN_CHANNELS, 3)

        # Built but not applied by forward(), which returns the raw recurrent
        # state.
        self.classifier = nn.Sequential(
            nn.Conv2d(128, 6, kernel_size=1, padding=0, stride=1),
            nn.AvgPool2d(kernel_size=6, stride=1, padding=0)
        )

    def _zero_state(self, xt):
        """Return an all-zero state matching xt's batch, spatial size, dtype and device."""
        size = (xt.size(0), self._HIDDEN_CHANNELS, xt.size(2), xt.size(3))
        return Variable(xt.data.new(*size).zero_())

    def forward(self, x):
        """Run the recurrent cell over a sequence and return its last hidden state.

        Args:
            x: iterable of per-step inputs (for example a tensor whose first
               dimension is time); each step is fed to ``self.features``.

        Returns:
            The hidden state after the final step, with 128 channels.

        Raises:
            ValueError: if ``x`` contains no steps.
        """
        ht = None
        ct = None

        for frame in x:
            xt = self.features(frame)
            if ht is None:
                # Size the initial state from the first feature map rather
                # than assuming batch size 1 on a 6x6 grid.
                ht = self._zero_state(xt)
                ct = self._zero_state(xt)

            if self.r_type == "lstm":
                # The cell's first return value is the same tensor as ht.
                _, (ht, ct) = self.convRNN(xt, (ht, ct))
            else:
                ht = self.convRNN(xt, ht)

        if ht is None:
            raise ValueError("convRNN_1_layer.forward received an empty sequence")
        return ht
        

In [165]:
def testconvRNN(r_type="gru"):
    """Smoke-test convRNN_1_layer by printing its output for a random sequence.

    Args:
        r_type: recurrent cell type passed to convRNN_1_layer (default "gru").
    """
    # torch.rand gives well-defined values; torch.Tensor(...) would hand the
    # model uninitialised memory.
    x = Variable(torch.rand(3, 1, 3, 225, 225))
    m = convRNN_1_layer(r_type)
    print(m(x))

In [168]:
# Run the smoke test; it prints the model output for its test input.
testconvRNN()

0
1
2
Variable containing:
( 0 , 0 ,.,.) = 
1.00000e-02 *
 -0.0219 -0.0403 -0.1340 -0.1339 -0.0979  0.1944
  0.4033  0.5376  0.4314  0.4571  0.5379  0.6263
  0.3643  0.5017  0.4147  0.4297  0.4806  0.6200
  0.3815  0.5549  0.5013  0.5114  0.5650  0.6845
  0.5531  0.7888  0.7034  0.6976  0.7009  0.7421
  0.3112  0.2488  0.2539  0.2553  0.2769  0.3705

( 0 , 1 ,.,.) = 
1.00000e-02 *
 -0.8323 -0.6060 -0.7092 -0.6883 -0.6448 -0.4667
 -0.7171 -0.3268 -0.3492 -0.3390 -0.2938 -0.4691
 -0.8129 -0.4514 -0.4954 -0.4771 -0.4078 -0.4968
 -0.7693 -0.4077 -0.4672 -0.4539 -0.3774 -0.5185
 -0.8139 -0.4552 -0.5114 -0.5084 -0.4187 -0.5753
 -0.7730 -0.6422 -0.6360 -0.6432 -0.6081 -0.7433

( 0 , 2 ,.,.) = 
1.00000e-02 *
  0.7481  1.2899  1.3627  1.3442  1.2990  0.7380
  1.0561  1.8520  1.9480  1.9349  1.9691  1.2708
  0.9932  1.8623  1.8975  1.8814  1.8921  1.2319
  0.9961  1.8889  1.9104  1.8987  1.9232  1.2420
  1.0630  1.9169  1.9595  1.9386  1.9375  1.2786
  0.6413  1.0986  1.1128  1.1032  1.1468  0.5