In [1]:
#torch
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
from torch.autograd import Variable
#os.environ['CUDA_VISIBLE_DEVICES'] = '0'

#torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models

#image
from PIL import Image
import cv2

#jupyter
from ipywidgets import FloatProgress
from IPython.display import display
from __future__ import print_function

#os
import os
import os.path as path
import glob

#math
import math
import numpy as np
import random


## ConvLSTM

#### LSTMCell

In [95]:
import torch.nn.functional as F

class ConvLSTMCell(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, padding=1, stride=1, bias=False):
        super(ConvLSTMCell, self).__init__()
        
        self.k = kernel_size
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.padding = padding
        self.stride = stride
        
        self.w_i = nn.Parameter(torch.Tensor(4*out_channels, in_channels, kernel_size, kernel_size))
        self.w_h = nn.Parameter(torch.Tensor(4*out_channels, out_channels, kernel_size, kernel_size))
        self.w_c = nn.Parameter(torch.Tensor(3*out_channels, out_channels, kernel_size, kernel_size))

        self.bias = bias
        if bias:
          self.bias_i = Parameter(torch.Tensor(4 * out_channels))
          self.bias_h = Parameter(torch.Tensor(4 * out_channels))
          self.bias_c = Parameter(torch.Tensor(3 * out_channels))
        else:
          self.register_parameter('bias_i', None)
          self.register_parameter('bias_h', None)
          self.register_parameter('bias_c', None)
        
        self.register_buffer('wc_blank', torch.zeros(out_channels))
        self.reset_parameters()
        
    def reset_parameters(self):
        n = 4 * self.in_channels * self.k * self.k
        stdv = 1. / math.sqrt(n)
        
        self.w_i.data.uniform_(-stdv, stdv)
        self.w_h.data.uniform_(-stdv, stdv)
        self.w_c.data.uniform_(-stdv, stdv)
        
        if self.bias:
            self.bias_i.data.uniform_(-stdv, stdv)
            self.bias_h.data.uniform_(-stdv, stdv)
            self.bias_c.data.uniform_(-stdv, stdv)

        
    def forward(self, x, hx):
        h, c = hx
        print(self.w_h.size())
        wx = F.conv2d(x, self.w_i, self.bias_i, padding=self.padding, stride=self.stride)
        wh = F.conv2d(h, self.w_h, self.bias_h, padding=self.padding, stride=self.stride)
        wc = F.conv2d(c, self.w_c, self.bias_c, padding=self.padding, stride=self.stride)
        
        
        #wc = torch.cat((wc[:, :2 * self.out_channels], Variable(self.wc_blank).expand(wc.size(0), wc.size(1) // 3, wc.size(2), wc.size(3)), wc[:, 2 * self.out_channels:]), 1)
        
        i = F.sigmoid(wx[:, :self.out_channels] + wh[:, :self.out_channels] + wc[:, :self.out_channels])
        f = F.sigmoid(wx[:, self.out_channels:2*self.out_channels] + wh[:, self.out_channels:2*self.out_channels] 
                + wc[:, self.out_channels:2*self.out_channels])
        g = F.tanh(wx[:, 2*self.out_channels:3*self.out_channels] + wh[:, 2*self.out_channels:3*self.out_channels])
        """
        
        wxhc = wx + wh + torch.cat((wc[:, :2 * self.out_channels], Variable(self.wc_blank).expand(wc.size(0), wc.size(1) // 3, wc.size(2), wc.size(3)), wc[:, 2 * self.out_channels:]), 1)
    
        i = F.sigmoid(wxhc[:, :self.out_channels])
        f = F.sigmoid(wxhc[:, self.out_channels:2 * self.out_channels])
        g = F.tanh(wxhc[:, 2 * self.out_channels:3 * self.out_channels])
        o = F.sigmoid(wxhc[:, 3 * self.out_channels:])
        """

        c_t = f * c + i * g
        o_t = F.sigmoid(wx[:, 3*self.out_channels:] + wh[:, 3*self.out_channels:] 
                        + wc[:, 2*self.out_channels: ]*c_t)
        h_t = o_t * F.tanh(c_t)
        
        return h_t, (h_t, c_t)

In [97]:
from torch.nn import functional as F

class Conv2dLSTMCell(nn.Module):
  def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
    super(Conv2dLSTMCell, self).__init__()
    if in_channels % groups != 0:
        raise ValueError('in_channels must be divisible by groups')
    if out_channels % groups != 0:
        raise ValueError('out_channels must be divisible by groups')
    kernel_size = _pair(kernel_size)
    stride = _pair(stride)
    padding = _pair(padding)
    dilation = _pair(dilation)
    self.in_channels = in_channels
    self.out_channels = out_channels
    self.kernel_size = kernel_size
    self.stride = stride
    self.padding = padding
    self.padding_h = tuple(k // 2 for k, s, p, d in zip(kernel_size, stride, padding, dilation))
    self.dilation = dilation
    self.groups = groups
    self.weight_ih = Parameter(torch.Tensor(4 * out_channels, in_channels // groups, *kernel_size))
    self.weight_hh = Parameter(torch.Tensor(4 * out_channels, out_channels // groups, *kernel_size))
    self.weight_ch = Parameter(torch.Tensor(3 * out_channels, out_channels // groups, *kernel_size))
    if bias:
      self.bias_ih = Parameter(torch.Tensor(4 * out_channels))
      self.bias_hh = Parameter(torch.Tensor(4 * out_channels))
      self.bias_ch = Parameter(torch.Tensor(3 * out_channels))
    else:
      self.register_parameter('bias_ih', None)
      self.register_parameter('bias_hh', None)
      self.register_parameter('bias_ch', None)
    self.register_buffer('wc_blank', torch.zeros(out_channels))
    self.reset_parameters()

  def reset_parameters(self):
    n = 4 * self.in_channels
    for k in self.kernel_size:
      n *= k
    stdv = 1. / math.sqrt(n)
    self.weight_ih.data.uniform_(-stdv, stdv)
    self.weight_hh.data.uniform_(-stdv, stdv)
    self.weight_ch.data.uniform_(-stdv, stdv)
    if self.bias_ih is not None:
      self.bias_ih.data.uniform_(-stdv, stdv)
      self.bias_hh.data.uniform_(-stdv, stdv)
      self.bias_ch.data.uniform_(-stdv, stdv)

  def forward(self, i, hx):
    h_0, c_0 = hx
    print(self.weight_hh.size())
    wx = F.conv2d(i, self.weight_ih, self.bias_ih, self.stride, self.padding, self.dilation, self.groups)
    wh = F.conv2d(h_0, self.weight_hh, self.bias_hh, self.stride, self.padding_h, self.dilation, self.groups)
    # Cell uses a Hadamard product instead of a convolution?
    wc = F.conv2d(c_0, self.weight_ch, self.bias_ch, self.stride, self.padding_h, self.dilation, self.groups)
    wxhc = wx + wh + torch.cat((wc[:, :2 * self.out_channels], Variable(self.wc_blank).expand(wc.size(0), wc.size(1) // 3, wc.size(2), wc.size(3)), wc[:, 2 * self.out_channels:]), 1)

    i = F.sigmoid(wxhc[:, :self.out_channels])
    f = F.sigmoid(wxhc[:, self.out_channels:2 * self.out_channels])
    g = F.tanh(wxhc[:, 2 * self.out_channels:3 * self.out_channels])
    o = F.sigmoid(wxhc[:, 3 * self.out_channels:])
    
    c_1 = f * c_0 + i * g
    h_1 = o * F.tanh(c_1)
    return h_1, (h_1, c_1)

#### ConvRNN

In [40]:
class convRNN_1_layer(nn.Module):
    def __init__(self, num_classes=1000):
        super(convRNN_1_layer, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        self.convRNN = Conv2dLSTMCell(256,128,kernel_size=3, padding=1, stride=1)
        self.classifier = nn.Sequential(
            nn.Conv2d(256,6,kernel_size=1, padding=0, stride=1),
            nn.AvgPool2d(kernel_size=6, stride=1, padding=0)
        )

    def forward(self, x):
        ht = Variable( torch.zeros(1,128,6,6))
        ct = Variable( torch.zeros(1,128,6,6))
        x = self.features(x)
        return  self.convRNN(x, (ht, ct))
        #x = self.classifier(x).squeeze().unsqueeze(0)
        

In [61]:
def testconvRNN():
    x = Variable(torch.Tensor(1,3,225,225))
    m = convRNN_1_layer()
    print(m(x))

In [98]:
testconvRNN()

torch.Size([512, 128, 3, 3])
(Variable containing:
( 0 , 0 ,.,.) = 
1.00000e-03 *
 -0.6927 -0.4737 -0.4859 -0.4975 -0.4914 -0.3841
 -0.9575  0.0636  0.0166  0.0037 -0.0145  0.0594
 -0.9212  0.0519  0.0131  0.0107 -0.0113  0.0406
 -0.9261  0.0618  0.0198  0.0136 -0.0148  0.0365
 -0.8768  0.1190  0.0726  0.0626  0.0220  0.0245
 -0.4727  0.4313  0.3779  0.3772  0.3600  0.0772

( 0 , 1 ,.,.) = 
1.00000e-03 *
 -0.0542 -0.5659 -0.5734 -0.5710 -0.6102 -0.6753
  0.3032 -0.3662 -0.3639 -0.3625 -0.3466 -0.6217
  0.2723 -0.4251 -0.4000 -0.4061 -0.4005 -0.6372
  0.2746 -0.4231 -0.4062 -0.4108 -0.4062 -0.6411
  0.2832 -0.3731 -0.3832 -0.3835 -0.3736 -0.6256
  0.1762 -0.5745 -0.5285 -0.5341 -0.5271 -0.5160

( 0 , 2 ,.,.) = 
1.00000e-03 *
 -0.4142 -0.6181 -0.6394 -0.6357 -0.6244 -0.3590
  0.6250  0.0437  0.0638  0.0616  0.0757 -0.2885
  0.5277 -0.0850 -0.0755 -0.0735 -0.0844 -0.3836
  0.5220 -0.0872 -0.0860 -0.0885 -0.0954 -0.4019
  0.4702 -0.1569 -0.1357 -0.1413 -0.1526 -0.4455
  0.5719  0.2484  0.2