In [None]:
from collections import OrderedDict

import torch
import torch.nn as nn
from convolution_lstm import ConvLSTMCell, ConvLSTM

In [None]:
# `_pair` is defined in torch.nn.modules.utils; `torch.nn.utils` does not export it.
from torch.nn.modules.utils import _pair


class ConvLSTM2DCell(nn.Module):
    """Single-step ConvLSTM cell: a PyTorch port of Keras' ``ConvLSTM2DCell``.

    Tensors use PyTorch's channels-first layout, ``(batch, channels, H, W)``.

    Args:
        in_channels: Number of channels of each input frame.
        out_channels: Number of channels of the hidden and cell states
            (Keras ``filters``); also stored as ``self.filters``.
        kernel_size: Int or pair, size of both the input and recurrent convolutions.
        stride: Int or pair, stride of the input convolution only. The
            recurrent convolution always uses stride 1, as in Keras.
        padding: Int, pair, or ``'valid'`` / ``'same'`` (case-insensitive),
            applied to the input convolution. The recurrent convolution always
            uses ``'same'`` padding so the state keeps its spatial size.
        padding_mode: Forwarded to the input ``nn.Conv2d`` (e.g. ``'zeros'``).
        dilation_rate: Int or pair, dilation of the input convolution.
        dropout: Fraction of input units dropped in training mode, clamped to [0, 1].
        recurrent_dropout: Fraction of hidden-state units dropped in training
            mode, clamped to [0, 1].
        **kwargs: Forwarded to ``nn.Module.__init__``.

    Call arguments:
        inputs: 4D tensor ``(batch, in_channels, H, W)``.
        states: Optional ``(h, c)`` pair from the previous step. Each is
            ``(batch, out_channels, H_out, W_out)``, where ``H_out``/``W_out`` is
            the spatial size produced by the input convolution. ``None`` starts
            from zeros.

    Returns:
        ``(h, (h, c))``: the new hidden state and the new state pair.

    Differences from Keras:
        * Gates use ``torch.sigmoid`` rather than Keras' default ``hard_sigmoid``.
        * Dropout masks are redrawn on every call, not shared across a sequence.
        * Weights use PyTorch's default initialisation.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=(1, 1),
                 padding=0,
                 padding_mode='zeros',
                 dilation_rate=(1, 1),
                 dropout=0.,
                 recurrent_dropout=0.,
                 **kwargs):
        super(ConvLSTM2DCell, self).__init__(**kwargs)
        self.in_channels = in_channels
        self.filters = out_channels
        self.kernel_size = _pair(kernel_size)
        self.strides = _pair(stride)
        # nn.Conv2d only accepts the lower-case strings 'valid' / 'same'.
        self.padding = padding.lower() if isinstance(padding, str) else _pair(padding)
        self.dilation_rate = _pair(dilation_rate)
        self.dropout = min(1., max(0., dropout))
        self.recurrent_dropout = min(1., max(0., recurrent_dropout))
        self.state_size = (self.filters, self.filters)

        # One convolution per source yields all four gate pre-activations
        # (input, forget, candidate, output) stacked along the channel axis.
        self.input_conv = nn.Conv2d(in_channels, 4 * self.filters, self.kernel_size,
                                    stride=self.strides, padding=self.padding,
                                    dilation=self.dilation_rate, padding_mode=padding_mode)
        # Keras' recurrent conv uses stride 1, 'same' padding and no bias.
        self.recurrent_conv = nn.Conv2d(self.filters, 4 * self.filters, self.kernel_size,
                                        padding='same', bias=False)
        self.input_dropout = nn.Dropout(self.dropout)
        self.hidden_dropout = nn.Dropout(self.recurrent_dropout)

    def forward(self, inputs, states=None):
        """Advance the cell by one time step; see the class docstring for shapes."""
        x_gates = self.input_conv(self.input_dropout(inputs))
        if states is None:
            h_prev = x_gates.new_zeros((x_gates.shape[0], self.filters) + tuple(x_gates.shape[2:]))
            c_prev = torch.zeros_like(h_prev)
        else:
            h_prev, c_prev = states
        gates = x_gates + self.recurrent_conv(self.hidden_dropout(h_prev))
        in_gate, forget_gate, candidate, out_gate = torch.chunk(gates, 4, dim=1)
        c = torch.sigmoid(forget_gate) * c_prev + torch.sigmoid(in_gate) * torch.tanh(candidate)
        h = torch.sigmoid(out_gate) * torch.tanh(c)
        return h, (h, c)

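References:

- Keras `ConvLSTM2DCell` reference implementation: https://github.com/tensorflow/tensorflow/blob/1cf0898dd4331baf93fe77205550f2c2e6c90ee5/tensorflow/python/keras/layers/convolutional_recurrent.py
- Shi et al. (2015), "Convolutional LSTM Network: A Machine Learning Approach for Precipitation Nowcasting": https://arxiv.org/pdf/1506.04214v2.pdf
- Chong & Tay (2017), "Abnormal Event Detection in Videos using Spatiotemporal Autoencoder": https://arxiv.org/pdf/1701.01546.pdf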


In [None]:
class VideoAutoencoderLSTM(nn.Module):
    def __init__(self, input_channels, hidden_channels, kernel_size):
        super(ConvLSTMCell, self).__init__()
        self.conv_encoder = nn.Sequential(OrderedDict([
              ('conv1', nn.Conv3d(in_channels=1, out_channels=128, kernel_size=(11,11,1),stride=(4,4,1), padding=0)),
              ('nonl1', nn.Tanh()),
              ('conv2', nn.Conv3d(in_channels=128, out_channels=64, kernel_size=(5,5,1),stride=(2,2,1), padding=0)),
              ('nonl2', nn.Tanh())
            ]))
        self.rnn_encoder = ConvLSTMCell()
        self.rnn_bottleneck = ConvLSTMCell()
        self.rnn_decoder = ConvLSTMCell()
        self.conv_decoder = nn.Sequential(OrderedDict([
              ('deconv1', nn.ConvTranspose3d(1,20,5)),
              ('nonl1', nn.Tanh()),
              ('deconv2', nn.ConvTranspose3d(20,64,5)),
              ('nonl2', nn.Tanh())
            ]))
    
    def forward(self, x):
        
        model.add(Conv3D(filters=128,kernel_size=(11,11,1),strides=(4,4,1),padding='valid',input_shape=(227,227,10,1),activation='tanh'))
	model.add(Conv3D(filters=64,kernel_size=(5,5,1),strides=(2,2,1),padding='valid',activation='tanh'))



	model.add(ConvLSTM2D(filters=64,kernel_size=(3,3),strides=1,padding='same',dropout=0.4,recurrent_dropout=0.3,return_sequences=True))

	
	model.add(ConvLSTM2D(filters=32,kernel_size=(3,3),strides=1,padding='same',dropout=0.3,return_sequences=True))


	model.add(ConvLSTM2D(filters=64,kernel_size=(3,3),strides=1,return_sequences=True, padding='same',dropout=0.5))




	model.add(Conv3DTranspose(filters=128,kernel_size=(5,5,1),strides=(2,2,1),padding='valid',activation='tanh'))
	model.add(Conv3DTranspose(filters=1,kernel_size=(11,11,1),strides=(4,4,1),padding='valid',activation='tanh'))

	model.compile(optimizer='adam',loss='mean_squared_error',metrics=['accuracy'])

In [None]:


if __name__ == '__main__':
    # Gradient check.
    # NOTE(review): gradcheck is applied to `loss_fn` with the ConvLSTM output as
    # its input, so it verifies MSELoss's gradient w.r.t. that output, not the
    # ConvLSTM's own backward pass.
    torch.manual_seed(0)
    # Fall back to CPU when no GPU is present.
    # NOTE(review): convolution_lstm.ConvLSTM may itself hard-code .cuda() — confirm.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    convlstm = ConvLSTM(input_channels=512, hidden_channels=[128, 64, 64, 32, 32], kernel_size=3, step=5,
                        effective_step=[4]).to(device)
    loss_fn = torch.nn.MSELoss()

    # Plain tensors carry autograd state; the deprecated (and never imported)
    # `Variable` wrapper is unnecessary.
    inputs = torch.randn(1, 512, 64, 32, device=device)
    target = torch.randn(1, 32, 64, 32, device=device).double()

    output = convlstm(inputs)
    output = output[0][0].double()
    res = torch.autograd.gradcheck(loss_fn, (output, target), eps=1e-6, raise_exception=True)
    print(res)