In [1]:
import torch
import torch.nn as nn
from torch.optim import lr_scheduler
from torch.autograd import Variable
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import torch.nn.functional as F
import matplotlib.pyplot as plt
import glob
import os
from skimage.transform import resize
import cv2
import itertools
import warnings
warnings.filterwarnings('ignore')
import skimage as sk

device = torch.device("cpu")

# Seed torch (used for weight initialisation) and numpy so that a
# Restart & Run All reproduces the same model and training trajectory.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

In [2]:
class LSTMCell(nn.Module):
    """Single LSTM cell that flattens each input sample before projecting it.

    All four gates are computed jointly: an input-to-hidden and a
    hidden-to-hidden linear layer each emit ``4 * hidden_size`` features,
    whose sum is split into input, forget, cell and output gate
    pre-activations (in that order).

    Args:
        flatten_dim: number of features in one flattened input sample.
        hidden_size: size of the hidden and cell state vectors.
        bias: whether the two linear projections carry a bias term.
    """

    def __init__(self, flatten_dim, hidden_size, bias):
        super(LSTMCell, self).__init__()

        self.flatten_dim = flatten_dim
        self.hidden_size = hidden_size
        self.bias = bias

        # One projection per source, each producing all four gates at once.
        self.i2h = nn.Linear(flatten_dim, 4 * hidden_size, bias=bias)
        self.h2h = nn.Linear(hidden_size, 4 * hidden_size, bias=bias)

    def forward(self, x, cur_state):
        """Advance the cell by one time step.

        Args:
            x: input holding ``flatten_dim`` features per sample. It is
                reshaped to ``(batch, flatten_dim)``; an unbatched sample
                becomes a batch of one.
            cur_state: ``(c, h)`` tuple from the previous step, each of shape
                ``(batch, hidden_size)``, or ``None`` to start from zeros.

        Returns:
            ``(c_next, h_next)``, each of shape ``(batch, hidden_size)``.
        """
        # reshape(-1, flatten_dim) instead of view(1, -1): identical for a
        # single sample, but a batch of B samples no longer collapses into one
        # (1, B * flatten_dim) row that the input projection rejects.
        x = x.reshape(-1, self.flatten_dim)

        if cur_state is None:
            # Build the zero state with the weights' device/dtype so the cell
            # does not depend on a global device setting.
            zeros = self.h2h.weight.new_zeros(x.size(0), self.hidden_size)
            cur_state = (zeros, zeros.clone())

        c_cur, h_cur = cur_state
        preact = self.i2h(x) + self.h2h(h_cur)
        ingate, forgetgate, cellgate, outgate = preact.chunk(4, 1)

        # torch.sigmoid / torch.tanh replace the deprecated F.sigmoid / F.tanh.
        ingate = torch.sigmoid(ingate)
        forgetgate = torch.sigmoid(forgetgate)
        cellgate = torch.tanh(cellgate)
        outgate = torch.sigmoid(outgate)

        c_next = (forgetgate * c_cur) + (ingate * cellgate)
        h_next = outgate * torch.tanh(c_next)

        return (c_next, h_next)
    

In [3]:
def array_to_tensor(array):
    """Wrap a NumPy array as a torch tensor that shares the array's memory."""
    return torch.from_numpy(array)

def normalize(X):
    """Standardise pixel data channel-wise with the ImageNet mean/std.

    The statistics broadcast against the last axis of ``X``, so ``X`` is
    expected to be channels-last with 3 channels.
    """
    imagenet_mean = np.array([0.485, 0.456, 0.406])
    imagenet_std = np.array([0.229, 0.224, 0.225])

    centred = X - imagenet_mean
    return centred / imagenet_std


def one_hot(array, num_classes):
    """Encode integer class labels as one-hot rows.

    Args:
        array: integer labels (any shape; flattened) or a list of labels.
        num_classes: width of each one-hot row.

    Returns:
        Float array of shape ``(n_labels, num_classes)``. The result is always
        2-D. The previous ``np.squeeze`` collapsed it to 1-D for a single
        label (or ``num_classes == 1``), which made row slicing such as
        ``Y[i:i+1]`` select a scalar entry instead of the target row.
    """
    labels = np.asarray(array).reshape(-1)
    return np.eye(num_classes)[labels]

def im_toarray(im):
    """Read an image file and resize it to a 224x224x3 array.

    Args:
        im: path of the image file.

    Returns:
        The output of skimage ``resize`` with anti-aliasing. Presumably these
        are floats in [0, 1] for 8-bit JPEGs (skimage's default
        ``preserve_range=False``); confirm before relying on the range.
    """
    target_shape = (224, 224, 3)
    return resize(plt.imread(im), target_shape, anti_aliasing=True)

def create_sets(folder,label):
    """Load every ``.jpg`` frame in *folder* as one normalised sample.

    Frames are returned in natural filename order, so "frame2" comes before
    "frame10". This assumes frame filenames encode temporal order (TODO
    confirm the naming scheme). The downstream LSTM consumes the frames as a
    sequence, and ``glob`` alone returns them in arbitrary order.

    Args:
        folder: directory containing the frames (glob metacharacters in the
            name are treated literally).
        label: class label assigned to every frame.

    Returns:
        ``(X, Y)``. ``X`` has shape ``(n_frames, 224, 224, 3)`` and is
        normalised with ``normalize``; ``Y`` has shape ``(n_frames,)`` and is
        filled with *label*.

    Raises:
        ValueError: if *folder* contains no ``.jpg`` files.
    """
    import re

    def _natural_key(path):
        # Split out digit runs so numeric parts compare as integers.
        name = os.path.basename(path)
        return [int(tok) if tok.isdigit() else tok for tok in re.split(r"(\d+)", name)]

    pattern = os.path.join(glob.escape(folder), "*.jpg")
    frame_paths = sorted(glob.glob(pattern), key=_natural_key)
    if not frame_paths:
        # Without this check, normalize() fails with an opaque broadcast error.
        raise ValueError("no .jpg frames found in " + repr(folder))

    X_ = [im_toarray(path) for path in frame_paths]
    Y_ = [label] * len(frame_paths)

    return normalize(np.array(X_)), np.array(Y_)

def calculate_loss(out,real):
    """Cross-entropy style loss between predicted probabilities and one-hot targets.

    Computes ``mean(-real * log(out))`` over all elements. That is the usual
    per-sample cross-entropy averaged over samples and divided by the number
    of classes. An unused MSE term and a blended (MSE + log) loss, which were
    never returned, have been dropped.

    Args:
        out: predicted probabilities (e.g. softmax output), floating point.
        real: targets with the same shape as *out*.

    Returns:
        Scalar tensor.
    """
    # log(0) = -inf, and 0 * -inf = nan, which would poison the loss even for
    # classes whose target is 0. Clamping to the smallest positive normal
    # float leaves every realistic probability unchanged.
    safe_out = out.clamp_min(torch.finfo(out.dtype).tiny)
    return torch.mean(-real * torch.log(safe_out))

In [9]:
class StackedLSTM(nn.Module):
    """Two stacked LSTMCells followed by a linear classifier and a softmax.

    Args:
        input_dim: flattened size of one input frame (default 3*224*224).
        hidden_size: hidden size of both LSTM layers (default 99).
        num_classes: number of output classes (default 99).
    """

    def __init__(self, input_dim=3 * 224 * 224, hidden_size=99, num_classes=99):
        super(StackedLSTM, self).__init__()
        self.lstm1 = LSTMCell(input_dim, hidden_size, True).to(device).float()
        self.lstm2 = LSTMCell(hidden_size, hidden_size, True).to(device).float()
        # The LSTM output h = sigmoid(.) * tanh(.) lies in (-1, 1). A softmax
        # applied to it directly can never exceed e / (e + (C - 1) / e), which
        # is about 7% for C = 99, so the model could not become confident. A
        # learned projection gives the softmax unbounded logits.
        self.classifier = nn.Linear(hidden_size, num_classes).to(device).float()

    def forward(self, x, state1, state2):
        """Run one time step through both layers.

        Args:
            x: one input step, reshaped by the first cell to
                ``(batch, input_dim)``; presumably ``(1, 3, 224, 224)`` here.
            state1: ``(c, h)`` of the first layer, or ``None``.
            state2: ``(c, h)`` of the second layer, or ``None``.

        Returns:
            ``(next_state1, next_state2, soft_out)``. ``soft_out`` has shape
            ``(batch, num_classes)``.
        """
        next_state1 = self.lstm1(x, state1)
        # The second layer consumes the first layer's hidden state h1.
        next_state2 = self.lstm2(next_state1[1], state2)
        logits = self.classifier(next_state2[1])
        soft_out = F.softmax(logits, dim=1)

        return next_state1, next_state2, soft_out


stackLSTM = StackedLSTM()

In [34]:

# Training data: all frames of a single clip, every frame sharing one label.
folder1 = "Frames/Train/ApplyLipstick/1"
CLIP_LABEL = 78
NUM_CLASSES = 99
LEARNING_RATE = 1.e-2

X1,Y1 = create_sets(folder1,CLIP_LABEL)
# create_sets returns channels-last data, (N, 224, 224, 3).
# reshape(-1, 3, 224, 224) would only reinterpret that memory and scramble
# pixels across channels; transpose actually moves the channel axis to get
# (N, 3, 224, 224).
X1 = np.ascontiguousarray(X1.transpose(0, 3, 1, 2))
Y1 = one_hot(Y1,NUM_CLASSES)
X1,Y1 = array_to_tensor(X1).to(device).float(),array_to_tensor(Y1).to(device).float()
optimizer = torch.optim.Adam(stackLSTM.parameters(),lr=LEARNING_RATE)

In [None]:
NUM_EPOCHS = 500

if X1.size(0) == 0:
    # Otherwise the step loop never runs and `out` / the loss are undefined.
    raise ValueError("training sequence X1 is empty")

for epoch in range(1, NUM_EPOCHS + 1):

    inp = X1
    real = Y1
    num_steps = inp.size(0)
    state1 = None
    state2 = None
    epoch_loss = 0
    # Feed frames one at a time and carry both LSTM states forward. The whole
    # clip is unrolled into one graph and backpropagated once per epoch.
    for i in range(num_steps):
        state1,state2,out = stackLSTM(inp[i:i+1],state1,state2)
        loss = calculate_loss(out,real[i:i+1])
        epoch_loss += loss

    # Average over all steps. Dividing by the last index `i` (num_steps - 1)
    # inflated the loss and divided by zero for a single-frame clip.
    epoch_loss_av = epoch_loss/num_steps

    optimizer.zero_grad()
    epoch_loss_av.backward()
    optimizer.step()

    print("epoch: "+str(epoch))
    print(epoch_loss_av.item())
    # Compare the last step's prediction with the last step's target.
    # real.argmax() on the full (N, C) tensor returns a flattened index.
    print(out.argmax().item(),real[-1].argmax().item())
    print(out[:,out.argmax().item()].item()*100)
    print("*******")