In [2]:
import torch.nn as nn
from torch.autograd import Variable
import torch

In [8]:
class Encoder(nn.Module):

    def __init__(self, input_size, input_dim, hidden_dim, kernel_size, num_layers,
                 batch_first=False, bias=True, return_all_layers=False):
        super(Encoder, self).__init__()
        '''
        Encoder network for generator 
        input size (N,3,64,64)
        output size(N,256,8,8)
        '''
        
        self.conv1 = nn.Conv2d(3,32,3,padding=1)
        self.conv2 = nn.Conv2d(32,64,3,padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        self.pool = nn.AvgPool2d(2,stride=2)
        self.conv3 = nn.Conv2d(64,128,3,padding=1)
        self.bn2 - nn.BatchNorm2d(128)
        self.conv4 = nn.Conv2d(128,256,3,padding=1)
        self.bn3 = nn.BatchNorm2d(256)
        self.conv5 = nn.Conv2d(256,256,3,padding=1)
        self.bn4 = nn.BatchNorm(256)
        
        #output (256*8*8)
    def forward(self,input_tensor):
        h1 = self.conv1(input_tensor)
        h1 = F.LeakyReLU(h1)

        h2 = self.conv2(h1)
        h2 = self.bn1(h2)
        h2 = F.LeakyReLU(h2)

        h3 = self.pool(h2)
        h3 = self.conv3(h3)
        h3 = self.bn2(h3)
        h3 = F.LeakyReLU(h3)

        h4 = self.pool(h3)
        h4 = self.conv4(h4)
        h4 = self.bn3(h4)
        h4 = self.LeakyReLU(h4)

        h5 = self.pool(h4)
        h5 = self.conv5(h4)
        h5 = self.bn4(h4)
        out = F.LeakyReLU(h4)
        cache = (h2,h3,h4)

        return out,cache
        

In [8]:
class ConvLSTMCell(nn.Module):

    def __init__(self, input_size, input_dim, hidden_dim, kernel_size, bias):
        """
        Initialize ConvLSTM cell.
        
        Parameters
        ----------
        input_size: (int, int)
            Height and width of input tensor as (height, width).
        input_dim: int
            Number of channels of input tensor.
        hidden_dim: int
            Number of channels of hidden state.
        kernel_size: (int, int)
            Size of the convolutional kernel.
        bias: bool
            Whether or not to add the bias.
        """

        super(ConvLSTMCell, self).__init__()

        self.height, self.width = input_size
        self.input_dim  = input_dim
        self.hidden_dim = hidden_dim

        self.kernel_size = kernel_size
        self.padding     = kernel_size[0] // 2, kernel_size[1] // 2
        self.bias        = bias
        
        self.conv = nn.Conv2d(in_channels=self.input_dim + self.hidden_dim,
                              out_channels=4 * self.hidden_dim,
                              kernel_size=self.kernel_size,
                              padding=self.padding,
                              bias=self.bias)

    def forward(self, input_tensor, cur_state):
        
        h_cur, c_cur = cur_state
        
        combined = torch.cat([input_tensor, h_cur], dim=1)  # concatenate along channel axis
        
        combined_conv = self.conv(combined)
        cc_i, cc_f, cc_o, cc_g = torch.split(combined_conv, self.hidden_dim, dim=1) 
        i = torch.sigmoid(cc_i)
        f = torch.sigmoid(cc_f)
        o = torch.sigmoid(cc_o)
        g = torch.tanh(cc_g)

        c_next = f * c_cur + i * g
        h_next = o * torch.tanh(c_next)
        
        return h_next, c_next

    def init_hidden(self, batch_size):
        #return (Variable(torch.zeros(batch_size, self.hidden_dim, self.height, self.width)).cuda(),
        #        Variable(torch.zeros(batch_size, self.hidden_dim, self.height, self.width)).cuda())
        return (Variable(torch.zeros(batch_size, self.hidden_dim, self.height, self.width)),
                Variable(torch.zeros(batch_size, self.hidden_dim, self.height, self.width)))

class ConvLSTM(nn.Module):

    def __init__(self, input_size, input_dim, hidden_dim, kernel_size, num_layers,
                 batch_first=False, bias=True, return_all_layers=False):
        super(ConvLSTM, self).__init__()

        self._check_kernel_size_consistency(kernel_size)

        # Make sure that both `kernel_size` and `hidden_dim` are lists having len == num_layers
        kernel_size = self._extend_for_multilayer(kernel_size, num_layers)
        hidden_dim  = self._extend_for_multilayer(hidden_dim, num_layers)
        if not len(kernel_size) == len(hidden_dim) == num_layers:
            raise ValueError('Inconsistent list length.')

        self.height, self.width = input_size

        self.input_dim  = input_dim
        self.hidden_dim = hidden_dim
        self.kernel_size = kernel_size
        self.num_layers = num_layers
        self.batch_first = batch_first
        self.bias = bias
        self.return_all_layers = return_all_layers

        cell_list = []
        for i in range(0, self.num_layers):
            cur_input_dim = self.input_dim if i == 0 else self.hidden_dim[i-1]

            cell_list.append(ConvLSTMCell(input_size=(self.height, self.width),
                                          input_dim=cur_input_dim,
                                          hidden_dim=self.hidden_dim[i],
                                          kernel_size=self.kernel_size[i],
                                          bias=self.bias))

        self.cell_list = nn.ModuleList(cell_list)

    def forward(self, input_tensor, hidden_state=None):
        """
        
        Parameters
        ----------
        input_tensor: todo 
            5-D Tensor either of shape (t, b, c, h, w) or (b, t, c, h, w)
        hidden_state: todo
            None. todo implement stateful
            
        Returns
        -------
        last_state_list, layer_output
        """
        if not self.batch_first:
            # (t, b, c, h, w) -> (b, t, c, h, w)
            input_tensor = input_tensor.permute(1, 0, 2, 3, 4)

        # Implement stateful ConvLSTM
        if hidden_state is not None:
            raise NotImplementedError()
        else:
            hidden_state = self._init_hidden(batch_size=input_tensor.size(0))

        layer_output_list = []
        last_state_list   = []

        seq_len = input_tensor.size(1)
        cur_layer_input = input_tensor

        for layer_idx in range(self.num_layers):

            h, c = hidden_state[layer_idx]
            output_inner = []
            for t in range(seq_len):

                h, c = self.cell_list[layer_idx](input_tensor=cur_layer_input[:, t, :, :, :],
                                                 cur_state=[h, c])
                output_inner.append(h)

            layer_output = torch.stack(output_inner, dim=1)
            cur_layer_input = layer_output

            layer_output_list.append(layer_output)
            last_state_list.append([h, c])

        if not self.return_all_layers:
            layer_output_list = layer_output_list[-1:]
            last_state_list   = last_state_list[-1:]

        return layer_output_list, last_state_list

    def _init_hidden(self, batch_size):
        init_states = []
        for i in range(self.num_layers):
            init_states.append(self.cell_list[i].init_hidden(batch_size))
        return init_states

    @staticmethod
    def _check_kernel_size_consistency(kernel_size):
        if not (isinstance(kernel_size, tuple) or
                    (isinstance(kernel_size, list) and all([isinstance(elem, tuple) for elem in kernel_size]))):
            raise ValueError('`kernel_size` must be tuple or list of tuples')

    @staticmethod
    def _extend_for_multilayer(param, num_layers):
        if not isinstance(param, list):
            param = [param] * num_layers
        return param

In [11]:
class Decoder(nn.Module):

    def __init__(self,p,kernel_size=3,upsample_size=2):
        super(Decoder, self).__init__()
        #convLSTM out: (N,256,8,8)
        #encoder out: (N,256,8,8)
        #noise z : (N,p,8,8)
        #input size:(N,512+p,8,8)
        
        self.s1 = Gate(2,3,256)
        self.s2 = Gate(4,3,128)
        self.s3 = Gate(8,3,64)
        
        self.conv6 = nn.Conv2d(512+p,256,kernel_size,padding=1)
        self.bn6 = nn.BatchNorm2d(256)
        #LeakyReLU
        self.upsample = nn.Upsample(scale_factor=upsample_size, mode='bilinear')
        #Gating  is it gonna backprop?
        
        self.conv7 = nn.Con2d(256,128,kernel_size,padding=1)
        self.bn7 = nn.BatchNorm2d(128)
        #LeakyReLU
        #upsample
        #Gating
        
        self.conv8 = nn.Conv2d(128,64,kernel_size,padding=1)
        self.bn8 = nn.BatchNorm2d(64)
        #LeakyReLU
        #upsample
        #Gating
        
        self.conv9 = nn.Conv2d(64,64,kernel_size,padding=1)
        self.bn9 = nn.BatchNorm2d(64)
        #LeakyReLU
        self.conv10 = nn.Conv2d(64,3,kernel_size,padding=1)
        #Tanh
        
    def forward(self,lstm_out,z,encoder_out,encoder_cache):
        #z - tiling noise size(p*8*8)
        #decoder_input:(N,512+P,8,8)
        
        e3,e2,e1 = encoder_cache
        decoder_input = torch.cat((lstm_out,encoder_out,z),1)
        u1 = self.conv6(decoder_input)
        u1 = self.bn6(u1)
        u1 = F.LeakyReLU(u1)
        u1 = self.upsample(u1)
        
        s1 = self.s1(lstm_out)
        u2 = s1*u1 + (1-s1)*e1
        u2 = self.conv7(u2)
        u2 = self.bn7(u2)
        u2 = F.LeakyReLU(u2)
        u2 = self.upsample(u2)
        
        s2 = self.s2(lstm_out)
        u3 = s2*u2 + (1-s2)*e2
        u3 = self.conv8(u3)
        u3 = self.bn8(u3)
        u3 = F.LeakyReLU(u3)
        u3 = self.upsample(u3)
        
        s3 = self.s3(lstm_out)
        u4 = s3*u3 + (1-s3)*e3
        u4 = self.conv9(u4)
        u4 = self.bn9(u4)
        u4 = F.LeakyReLU(u4)
        u4 = self.conv10(u4)
        out = F.tanh(u4)
        return out
        

In [4]:
class Gate(nn.Module):

    def __init__(self, upsample_size,kernel_size,num_filters):
        super(Gate, self).__init__()
        self.upsample = nn.Upsample(scale_factor=upsample_size, mode='bilinear'),
        self.conv = nn.Conv2d(256,num_filters,kernel_size,padding=1)
    def forward(self,lstm_out):
        h = self.upsample(lstm_out)
        h = self.conv(h)
        h = self.LeakyReLU(h)
        out = F.sigmoid(h)
        return out
        

In [15]:
class Appearance_D(nn.Module):

    def __init__(self,kernel_size=4):
        super(Appearance_D, self).__init__()
        #network 1
        self.network1 = nn.Sequential(
            nn.Conv2d(3,64,kernel_size,stride=2,padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(),
            nn.Conv2d(64,128,kernel_size,stride=2,padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(),
            nn.Conv2d(128,256,kernel_size,stride=2,padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU())
        
        #network 2
        #input to 2: (256*4,8,8)
        self.network2 = nn.Sequential(
            nn.ConvTranspose2d(256*4,256,3,padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(),
            nn.Conv2d(256,512,kernel_size,stride=2,padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(),
            nn.Conv2d(512,1024,kernel_size,stride=4),
            nn.BatchNorm2d(1024),
            nn.LeakyReLU(),
            nn.Linear(1024,64),
            nn.LeakyReLU(),
            nn.Linear(64,1),
            nn.Sigmoid())
        
    def forward(self,x,y_a):
        x_out = self.network1(x)  #(3,256,8,8)
        y_out = self.network1(y_a) #(1,256,8,8)
        y_out = torch.squeeze(y_out)
        x_1 = torch.cat((x_out[0],x_out[1],x_out[2],y_out),0)
        x_1 = x_1.view(1,256*4,8,8)
        out = self.network2(x_1)
        return out
        
        

In [20]:
class Motion_D(nn.Module):

    def __init__(self,num_keypoints,num_classes,kernel_size=4):
        super(Motion_D, self).__init__()
        #input y_a (3*64*64) 
        #y_l (c*4*4) ??
        #x (3*64*64)
        self.encoder = nn.Sequential(
            nn.Conv2d(3,64,kernel_size,stride=2,padding=1),
            nn.LeakyReLU(),
            nn.Conv2d(64,128,kernel_size,stride=2,padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(),
            nn.Conv2d(128,256,kernel_size,stride=2,padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU())
        #TODO:convLSTM
        
        #last hidden state h_out (1*256*8*8)
        self.conv4 = nn.Conv2d(256,64,kernel_size,stride=2,padding=1)
        self.bn4 = nn.BatchNorm2d(64)
        #LeakyRelu -> flatten
        self.fc_h1 = nn.Linear(1024,64)
        self.bn_fc1 = nn.BatchNorm1d(64)
        #LeakyRelu
        self.fc_h2 = nn.Linear(64,num_classes)
        self.bn_fc2 = nn.BatchNorm1d(num_classes)
        #LeakyRelu -> softmax
        
        #output: ot (N,256,8,8)
        self.conv6 = nn.Conv2d(256,64,kernel_size,stride=2,padding=1)
        self.bn6 = nn.BatchNorm2d(64)
        #leakyRelu -> flatten
        self.fc_o = nn.Linear(1024,2*num_keypoints)
        
        #concatenate y_l (c*4*4) + (64*4*4) = (c+64,4,4)
        self.conv5 = nn.Conv2d(num_classes+64,64,kernel_size,stride=2,padding=1)
        self.bn5 = nn.BatchNorm(64)
        #leakyrelu
        self.fc_y = nn.Linear(64,1)
        #sigmoid
        
        #TODO forward
        
        

In [13]:
y = torch.rand((2,3,8,8))
print(y[1].size())

torch.Size([3, 8, 8])


In [10]:


T = 10
batch_size = 16
input_dim = 1
input_size = (15,1)
hidden_dim = 3
kernel_size = (3,3)
bias = False
#input tensor t, b, c, h, w
y_m = torch.rand((T,batch_size,input_dim,15,1))
model = ConvLSTM(input_size, input_dim, hidden_dim, kernel_size, 1)
out = model.forward(y_m)
print(out)

([tensor([[[[[ 0.0111],
           [ 0.0058],
           [ 0.0232],
           ...,
           [ 0.0195],
           [ 0.0058],
           [ 0.0307]],

          [[-0.0094],
           [ 0.0052],
           [-0.0022],
           ...,
           [ 0.0070],
           [ 0.0107],
           [-0.0107]],

          [[-0.0335],
           [-0.0342],
           [-0.0224],
           ...,
           [-0.0160],
           [-0.0305],
           [-0.0171]]],


         [[[ 0.0244],
           [ 0.0247],
           [ 0.0114],
           ...,
           [ 0.0205],
           [ 0.0327],
           [ 0.0433]],

          [[-0.0066],
           [ 0.0259],
           [ 0.0206],
           ...,
           [ 0.0127],
           [ 0.0287],
           [ 0.0013]],

          [[-0.0517],
           [-0.0237],
           [-0.0423],
           ...,
           [-0.0499],
           [-0.0176],
           [-0.0101]]],


         [[[ 0.0119],
           [ 0.0026],
           [ 0.0153],
           ...,
           [