In [None]:
import torch
import torch.nn as nn
from torch.autograd import Variable
import numpy as np
import torchvision
from torchvision import models
import torch.nn.functional as F
import matplotlib.pyplot as plt
import glob
import os
from skimage.transform import resize
import cv2
import itertools
import warnings
warnings.filterwarnings('ignore')
import skimage as sk
from natsort import natsorted

# Target device for the data tensors built later (the X*/Y* tensors are moved with .to(device)).
# NOTE(review): hard-coded to CPU, and the model is never moved with .to(device) in this
# notebook; both would need changing to train on a GPU.
device = torch.device("cpu")

In [None]:
def array_to_tensor(array):
    """Wrap a NumPy array as a torch tensor without copying.

    Uses ``torch.from_numpy``, so the tensor keeps the array's dtype and shares
    its memory: in-place edits to one are visible through the other.
    """
    return torch.from_numpy(array)

def normalize(X):
    """Standardize images channel-wise: (X - mean) / std.

    The 3-element mean/std vectors broadcast against the LAST axis of ``X``, so
    ``X`` must be channels-last (e.g. (N, H, W, 3)). The constants are the
    standard torchvision ImageNet statistics, which assume pixel values in [0, 1].
    """
    channel_mean, channel_std = (
        np.array([0.485, 0.456, 0.406]),
        np.array([0.229, 0.224, 0.225]),
    )
    centred = X - channel_mean
    return centred / channel_std


def one_hot(array, num_classes):
    """One-hot encode integer class labels.

    Args:
        array: integer label(s) in ``[0, num_classes)``; any shape (it is
            flattened). A NumPy array, Python int or list is accepted.
        num_classes: length of each one-hot row.

    Returns:
        float64 array of shape (number_of_labels, num_classes). It is always 2-D
        for array/list input, even when there is a single label, so callers can
        slice rows (``Y[i:i+1]``) uniformly. A scalar (0-d) label returns a 1-D
        row of length ``num_classes``.
    """
    labels = np.asarray(array)
    rows = np.eye(num_classes)[labels.reshape(-1)]
    # np.squeeze would also drop the batch axis when there is exactly one label
    # (turning (1, C) into (C,)), so only collapse it for a genuinely scalar input.
    return rows[0] if labels.ndim == 0 else rows

def im_toarray(im):
    """Read the image file at path ``im`` and resize it to (224, 224, 3).

    Resizing uses skimage's ``resize`` with anti-aliasing enabled.
    NOTE(review): ``preserve_range`` is left at its default, so 8-bit images
    should come back as floats rescaled to [0, 1] -- confirm this matches what
    ``normalize`` expects.
    """
    target_shape = (224, 224, 3)
    return resize(plt.imread(im), target_shape, anti_aliasing=True)

def create_sets(folder, label):
    """Load every ``*.jpg`` frame in ``folder`` as one labelled, normalized clip.

    Frames are read in natural filename order (``frame2`` before ``frame10``).
    The clip is later fed frame-by-frame to a recurrent model, so order matters,
    and ``glob`` alone returns files in arbitrary filesystem order.

    Args:
        folder: directory containing the frame images.
        label: integer class id attached to every frame.

    Returns:
        (X, Y): X is ``normalize`` applied to the stacked frames, shape
        (N, 224, 224, 3) as produced by ``im_toarray`` (channels-last);
        Y is an (N,) array repeating ``label``.

    Raises:
        ValueError: if ``folder`` contains no ``*.jpg`` files, instead of
            failing later with an opaque broadcasting error.
    """
    frame_paths = natsorted(glob.glob(os.path.join(folder, "*.jpg")))
    if not frame_paths:
        raise ValueError("no .jpg frames found in {!r}".format(folder))

    frames = [im_toarray(path) for path in frame_paths]
    labels = [label] * len(frame_paths)
    return normalize(np.array(frames)), np.array(labels)

In [None]:
class Conv(nn.Module):
    """2-D convolution, optionally followed by max-pooling.

    The pooling window is derived from the convolution kernel as
    ``kernel_size // 2`` (e.g. a 4x4 or 5x5 conv is followed by 2x2 pooling,
    whose stride defaults to its window). Every constructor argument is also
    kept as an attribute of the same name.
    """

    def __init__(self, input_dim, out_dim, kernel_size, stride, padding, bias, use_max):
        super().__init__()
        # Record the configuration verbatim.
        (self.input_dim, self.out_dim, self.kernel_size,
         self.stride, self.padding, self.bias, self.use_max) = (
            input_dim, out_dim, kernel_size, stride, padding, bias, use_max)

        self.conv2d = nn.Conv2d(input_dim, out_dim, kernel_size,
                                stride=stride, padding=padding, bias=bias)
        # Always constructed (it has no parameters); only applied when use_max is truthy.
        self.maxpool = nn.MaxPool2d(kernel_size=kernel_size // 2)

    def forward(self, x):
        features = self.conv2d(x)
        return self.maxpool(features) if self.use_max else features

In [None]:
class ConvLSTMCell(nn.Module):
    """One time step of a convolutional LSTM cell.

    Args:
        input_size: (batch, channels, height, width) of the expected input.
            Only ``channels`` is needed to build the layer; batch/height/width
            are stored as attributes, while the zero initial state is sized
            from the actual input.
        hidden_dim: number of channels of the hidden and cell states.
        kernel_size: (kh, kw) of the gate convolution. Padding of ``k // 2``
            keeps the spatial size unchanged for odd kernels.
        bias: whether the gate convolution has a bias term.
    """

    def __init__(self, input_size, hidden_dim, kernel_size, bias):
        super(ConvLSTMCell, self).__init__()
        self.batch, self.input_channel, self.height, self.width = input_size
        self.hidden_dim = hidden_dim
        self.kernel_size = kernel_size

        self.padding = kernel_size[0] // 2, kernel_size[1] // 2
        self.bias = bias

        # One convolution over the concatenated [input, h] produces all four
        # gates at once (4 * hidden_dim channels, split in forward).
        self.conv = nn.Conv2d(in_channels=self.input_channel + self.hidden_dim,
                              out_channels=4 * self.hidden_dim,
                              kernel_size=self.kernel_size,
                              padding=self.padding,
                              bias=self.bias)

    def forward(self, input_tensor, cur_state):
        """Advance the cell by one step.

        Args:
            input_tensor: (B, input_channel, H, W) tensor.
            cur_state: ``(h, c)`` from the previous step, each
                (B, hidden_dim, H, W), or ``None`` to start from zeros.

        Returns:
            ``(h_next, c_next)``, each (B, hidden_dim, H, W).
        """
        if cur_state is None:
            # Size the zero state from the input itself (batch, spatial size,
            # device and dtype) rather than a fixed batch of 1 on the default
            # device, so batched input and non-CPU / non-float32 models work.
            batch, _, height, width = input_tensor.shape
            zeros = input_tensor.new_zeros(batch, self.hidden_dim, height, width)
            cur_state = (zeros, zeros.clone())

        h_cur, c_cur = cur_state
        combined = torch.cat((input_tensor, h_cur), dim=1)  # concatenate along channel axis
        combined_conv = self.conv(combined)
        cc_i, cc_f, cc_o, cc_g = torch.split(combined_conv, self.hidden_dim, dim=1)

        i = torch.sigmoid(cc_i)  # input gate
        f = torch.sigmoid(cc_f)  # forget gate
        o = torch.sigmoid(cc_o)  # output gate
        g = torch.tanh(cc_g)     # candidate cell update

        c_next = f * c_cur + i * g
        h_next = o * torch.tanh(c_next)

        return h_next, c_next
        

In [None]:
class Stack(nn.Module):
    """Frame classifier: Conv -> ConvLSTM -> Conv -> ConvLSTM -> Conv -> Linear.

    Each call processes one time step and threads two ConvLSTM states through
    successive calls. The hard-coded layer sizes assume 3x224x224 frames:
    224 -conv5-> 220 -pool2-> 110 (convlstm1); 110 -conv4-> 107 -pool2-> 53
    (convlstm2); 53 -conv5/stride5-> 10 -pool2-> 5, hence 1024*5*5 features.

    Args:
        num_classes: output size of the final linear layer (default 99).
    """

    def __init__(self, num_classes=99):
        super(Stack, self).__init__()

        self.conv1 = Conv(3, 128, 5, 1, 0, True, True)
        self.convlstm1 = ConvLSTMCell((1, 128, 110, 110), 256, (3, 3), True)
        self.conv2 = Conv(256, 256, 4, 1, 0, True, True)
        self.convlstm2 = ConvLSTMCell((1, 256, 53, 53), 512, (3, 3), True)
        self.conv3 = Conv(512, 1024, 5, 5, 0, True, True)
        self.linear = nn.Linear(1024 * 5 * 5, num_classes, bias=True)

    def forward(self, x, state1, state2):
        """Run one time step.

        Args:
            x: frame batch, (B, 3, 224, 224).
            state1, state2: ``(h, c)`` of the first / second ConvLSTM from the
                previous step, or ``None`` to let each cell start from zeros.

        Returns:
            ``(probs, state2, state1)`` -- note the states come back in reverse
            order. ``probs`` is the (B, num_classes) softmax output.
        """
        conv1_o = self.conv1(x)
        convlstm_h1, convlstm_c1 = self.convlstm1(conv1_o, state1)
        conv2_o = self.conv2(convlstm_h1)
        convlstm_h2, convlstm_c2 = self.convlstm2(conv2_o, state2)
        conv3_o = self.conv3(convlstm_h2)
        # Flatten per sample; view(1, -1) would merge a whole batch into one row.
        flatten = torch.flatten(conv3_o, start_dim=1)
        soft = F.softmax(self.linear(flatten), dim=1)

        state2 = (convlstm_h2, convlstm_c2)
        state1 = (convlstm_h1, convlstm_c1)

        return soft, state2, state1

In [None]:
# Smoke test: push one random 3x224x224 frame through a freshly built model with
# empty recurrent state, then report the trainable-parameter count. `s_model`
# built here is the model the optimizer and training loop below use.
a = torch.rand(1, 3, 224, 224)
s_model = Stack()
o, d, f = s_model(a, None, None)

model_parameters = (p for p in s_model.parameters() if p.requires_grad)
params = sum(np.prod(p.size()) for p in model_parameters)
print(params)
o.size()

In [None]:
def calculate_loss(out, real, eps=1e-12):
    """Cross-entropy between predicted probabilities and one-hot targets.

    Args:
        out: predicted probabilities (e.g. a softmax output), shape (B, C).
        real: one-hot targets, same shape as ``out``.
        eps: lower bound applied to ``out`` before taking the log.

    Returns:
        Scalar tensor: the mean of ``-real * log(out)`` over ALL elements, i.e.
        for one-hot rows the usual per-sample cross-entropy divided by C.
    """
    # A probability that underflowed to exactly 0 would give log(0) = -inf, and
    # 0 * -inf = NaN for every non-target class, turning the whole loss into NaN.
    # Clamping leaves all values >= eps untouched.
    log_prob = torch.log(torch.clamp(out, min=eps))
    return torch.mean(-real * log_prob)

In [None]:
def load_clip(folder, label, num_classes=99):
    """Load one clip from ``folder`` as model-ready tensors on ``device``.

    Returns:
        (X, Y): X is a float32 (N, 3, 224, 224) channels-first frame tensor;
        Y holds the float32 one-hot targets produced by ``one_hot``.
    """
    X, Y = create_sets(folder, label)
    # create_sets returns channels-last (N, 224, 224, 3) arrays, while Conv2d
    # expects (N, C, H, W). transpose reorders the axes; reshape would only
    # reinterpret the same buffer and scramble pixels across channels.
    X = np.ascontiguousarray(X.transpose(0, 3, 1, 2))
    Y = one_hot(Y, num_classes)
    return array_to_tensor(X).to(device).float(), array_to_tensor(Y).to(device).float()


folder1 = "Frames/Train/ApplyLipstick/1"
X1, Y1 = load_clip(folder1, 78)

folder2 = "Frames/Train/Bowling/1"
X2, Y2 = load_clip(folder2, 55)

# NOTE(review): only the first clip is trained on; X2/Y2 are loaded but not
# added to these lists.
xlist = [X1]
ylist = [Y1]

optimizer = torch.optim.SGD(s_model.parameters(), lr=1.e-1)

In [None]:
# Train on each clip as a sequence: the ConvLSTM states start from None at the
# beginning of a clip, are carried across its frames, and the frame-averaged
# loss is back-propagated once per clip.
assert len(xlist) == len(ylist), "xlist and ylist must pair up clip-for-clip"
for epoch in range(500):
    for inp, real in zip(xlist, ylist):
        n_frames = inp.size(0)
        if n_frames == 0:
            continue  # an empty clip has no loss to average
        state1 = None
        state2 = None
        pair_loss = 0
        for i in range(n_frames):
            out, state2, state1 = s_model(inp[i:i+1], state1, state2)
            pair_loss += calculate_loss(out, real[i:i+1])

        # Average over the number of frames, not the last loop index
        # (which is n_frames - 1, and 0 for a one-frame clip).
        pair_loss_av = pair_loss / n_frames

        optimizer.zero_grad()
        pair_loss_av.backward()
        optimizer.step()

        print(pair_loss_av.item())
            
            