In [None]:
import torch
import torch.nn as nn
from torch.autograd import Variable
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import torch.nn.functional as F
import glob
import os
from skimage.transform import resize
import cv2
import itertools
import warnings
warnings.filterwarnings('ignore')
import skimage as sk
import logging
import matplotlib.pyplot as plt
from PIL import Image

# Route INFO-level (and more severe) log records to the file cnn_extract.log.
logging.basicConfig(filename='cnn_extract.log',level=logging.INFO)

In [None]:
# Per-channel mean/std applied to the tensor produced by ToTensor
# (presumably the ImageNet statistics expected by torchvision's pretrained
# models -- confirm against the torchvision model documentation).
normalize = transforms.Normalize(
   mean=[0.485, 0.456, 0.406],
   std=[0.229, 0.224, 0.225]
)
# Image pipeline: resize, centre-crop to 224x224, convert to a tensor, then
# normalise. transforms.Scale was deprecated in favour of transforms.Resize
# and is no longer available in current torchvision releases.
preprocess = transforms.Compose([
   transforms.Resize(256),
   transforms.CenterCrop(224),
   transforms.ToTensor(),
   normalize
])


In [None]:
def test(pairs,batch_size,num_classes,model):
    """Evaluate ``model`` on one tuple of class folders without updating it.

    Every entry of ``pairs`` is either a folder path or ``None``; the position
    of the entry inside ``pairs`` is the class label handed to ``create_sets``.

    Args:
        pairs: iterable of folder paths (``None`` entries are skipped).
        batch_size: number of frames per mini-batch.
        num_classes: width of the one-hot target vectors.
        model: network returning per-class log-probabilities.

    Returns:
        ``(pair_loss, pair_corrects, num_samples)``: the loss summed over all
        frames, the number of correctly classified frames and the number of
        frames evaluated. ``(0.0, 0, 0)`` when no frame was found.
    """
    X_parts = []
    Y_parts = []
    all_frame_list = []
    for label, folder in enumerate(pairs):
        if folder is None:
            continue
        X, Y, frame_list = create_sets(folder, label)
        if len(frame_list) == 0:
            # A folder without frames yields an array of the wrong rank that
            # cannot be concatenated with real image batches; skip it.
            continue
        X_parts.append(X)
        Y_parts.append(Y)
        all_frame_list.extend(frame_list)

    if not X_parts:
        logging.warning("test: no frames found for the given folders")
        return 0.0, 0, 0

    # Single concatenation instead of growing the arrays folder by folder.
    X_batch = np.concatenate(X_parts, axis=0)
    Y_batch = np.concatenate(Y_parts, axis=0)
    num_samples = Y_batch.shape[0]

    # Use the device the model lives on instead of a notebook-level global.
    device = next(model.parameters()).device

    pair_loss = 0.0
    pair_corrects = 0

    # Evaluate in eval mode (dropout off, BatchNorm running statistics frozen)
    # and without building an autograd graph; restore the previous mode
    # afterwards so an interleaved training loop is unaffected.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            sub_batches = create_batches(X_batch, Y_batch, all_frame_list, batch_size)
            for X_minibatch, Y_minibatch, _ in sub_batches:
                Y_minibatch = one_hot(Y_minibatch, num_classes)
                X_minibatch = torch.from_numpy(X_minibatch).to(device).float()
                Y_minibatch = torch.from_numpy(Y_minibatch.reshape(-1, num_classes)).to(device).float()

                output = model(X_minibatch)
                loss, correct = evaluate(output, Y_minibatch)

                pair_loss += loss.item()
                pair_corrects += correct
    finally:
        model.train(was_training)

    return pair_loss,pair_corrects,num_samples

In [None]:

def model():
    """Create a pretrained ResNet-34 whose early stages are frozen.

    All parameters are first marked as non-trainable; afterwards every
    parameter belonging to a top-level child at position 6 or later is made
    trainable again.

    Returns:
        The torchvision ResNet-34 module with ``requires_grad`` flags set.
    """
    network = models.resnet34(pretrained=True)

    # Freeze everything ...
    for weight in network.parameters():
        weight.requires_grad = False

    # ... then re-enable gradients from the seventh top-level child onwards.
    for block in list(network.children())[6:]:
        for weight in block.parameters():
            weight.requires_grad = True

    return network


class RESNET(nn.Module):
    """Classifier built from an existing network with its last child removed.

    Every top-level child of ``original_model`` except the final one is reused
    as a feature extractor; a new linear layer maps the flattened features to
    ``num_classes`` scores.

    Args:
        original_model: module whose children (all but the last) form the
            feature extractor. The extractor must produce 512 features per
            sample once flattened, because ``fc1`` has ``in_features=512``.
        num_classes: number of output classes (defaults to 99).
    """
    def __init__(self, original_model, num_classes=99):
        super(RESNET, self).__init__()
        self.feature_extractor = nn.Sequential(*list(original_model.children())[:-1])
        self.fc1 = nn.Linear(in_features=512,out_features=num_classes,bias=True)
        self.dropout = nn.Dropout(p=0.5)

    def forward(self,x):
        """Return per-class log-probabilities of shape ``(batch, num_classes)``."""
        batch = x.size(0)
        # Flatten whatever the extractor returns to (batch, features).
        features = self.feature_extractor(x).view(batch,-1)
        # NOTE(review): dropout is applied to the class scores themselves,
        # directly before log_softmax -- kept as in the original design.
        scores = self.dropout(self.fc1(features))
        return F.log_softmax(scores,dim=1)

def create_batches(X,Y,all_frame_list,batch_size,shuffle=True):
    """Yield aligned mini-batches of images, labels and frame paths.

    Args:
        X: 4-D array of images, indexed by sample along axis 0.
        Y: 2-D array of labels with one row per sample.
        all_frame_list: sequence of frame identifiers, one per sample.
        batch_size: maximum number of samples per yielded batch.
        shuffle: when true (default) the samples are permuted with
            ``np.random.permutation`` before batching; when false the
            original order is kept.

    Yields:
        ``(X_batch, Y_batch, frames_batch)`` tuples; the last batch may be
        smaller than ``batch_size``.
    """
    num_samples = Y.shape[0]
    all_frame_list = np.array(all_frame_list)
    if shuffle:
        # One shared permutation keeps images, labels and paths aligned.
        index = np.random.permutation(num_samples)
        X = X[index,:,:,:]
        Y = Y[index,:]
        all_frame_list = all_frame_list[index]
    for start in range(0,num_samples,batch_size):
        stop = start + batch_size
        yield(X[start:stop,:,:,:],Y[start:stop,:],all_frame_list[start:stop])


def save_features(frame_array,conv_features):
    """Pair every frame path with its feature row and derive a ``.npy`` path.

    The target path is the frame path with its last three characters replaced
    by ``npy``. Writing the array to disk is currently disabled, so the
    function has no lasting effect and returns ``None``.
    """
    for position, jpg_path in enumerate(frame_array):
        npy_path = jpg_path[:-3] + "npy"
        array = conv_features[position]
        # Persisting is switched off for now:
        #np.save(npy_path,array.cpu().detach().numpy())



def _safe_ratio(numerator, denominator):
    """Return ``numerator/denominator``, or 0.0 when the denominator is zero."""
    return numerator/denominator if denominator else 0.0


def train_model(model,nb_epoch,learning_rate,paired_training_list,paired_test_list,device,batch_size):
    """Fine-tune ``model`` and log per-pair and per-epoch metrics.

    For every epoch each entry of ``paired_training_list`` is trained on once;
    directly afterwards one entry of ``paired_test_list`` (cycled with a
    modulo) is evaluated. A checkpoint is written through ``save_model`` at the
    end of every epoch.

    Args:
        model: network to optimise; only parameters with ``requires_grad``
            set are handed to the optimiser.
        nb_epoch: number of epochs to run.
        learning_rate: learning rate for Adam.
        paired_training_list: sequence of folder tuples used for training.
        paired_test_list: sequence of folder tuples used for evaluation. When
            empty, evaluation is skipped instead of failing.
        device: kept for interface compatibility; this function does not use
            it.
        batch_size: mini-batch size forwarded to ``train`` and ``test``.
    """
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.Adam(trainable_params,lr=learning_rate)
    logging.info("There are total number of %d pairs in training set"%len(paired_training_list))
    logging.info("There are total number of %d pairs in test set"%len(paired_test_list))
    num_classes = 99

    len_train = len(paired_training_list)
    len_test = len(paired_test_list)
    if len_test == 0:
        # `k % len_test` would raise ZeroDivisionError below.
        logging.warning("test set is empty; evaluation is skipped")

    for epoch in range(1,nb_epoch+1):

        train_epoch_loss = 0.0
        train_epoch_correct = 0
        train_total_image = 0

        test_epoch_loss = 0.0
        test_epoch_correct = 0
        test_total_image = 0

        for k in range(len_train):
            train_pairs = paired_training_list[k]

            pair_training_loss,pair_training_corrects,train_num_samples = train(train_pairs,batch_size,num_classes,optimizer,model)
            logging.info("pair: "+str(k+1))
            logging.info("pair training loss:" +str(_safe_ratio(pair_training_loss,train_num_samples)))
            logging.info("pair training accuracy: "+str(_safe_ratio(100*pair_training_corrects,train_num_samples))+" "+str(pair_training_corrects)+"/"+str(train_num_samples))
            logging.info("********")
            train_epoch_loss += pair_training_loss
            train_epoch_correct += pair_training_corrects
            train_total_image += train_num_samples

            if len_test == 0:
                continue

            # Cycle through the test pairs when there are fewer of them.
            test_pairs = paired_test_list[k%len_test]
            pair_test_loss,pair_test_corrects,test_num_samples = test(test_pairs,batch_size,num_classes,model)
            logging.info("pair test loss:" +str(_safe_ratio(pair_test_loss,test_num_samples)))
            logging.info("pair test accuracy: "+str(_safe_ratio(100*pair_test_corrects,test_num_samples))+" "+str(pair_test_corrects)+"/"+str(test_num_samples))
            logging.info("********")
            test_epoch_loss += pair_test_loss
            test_epoch_correct += pair_test_corrects
            test_total_image += test_num_samples

        logging.info("epoch: "+str(epoch))
        logging.info("epoch training loss: " +str(_safe_ratio(train_epoch_loss,train_total_image)))
        logging.info("epoch training accuracy: "+str(_safe_ratio(100*train_epoch_correct,train_total_image))+" "+str(train_epoch_correct)+"/"+str(train_total_image))
        logging.info("epoch test loss: " +str(_safe_ratio(test_epoch_loss,test_total_image)))
        logging.info("epoch test accuracy: "+str(_safe_ratio(100*test_epoch_correct,test_total_image))+" "+str(test_epoch_correct)+"/"+str(test_total_image))
        logging.info("-----------------")
        logging.info("-----------------")

        save_model(epoch,model,optimizer)

        



def evaluate(out,real):
    """Compute the summed cross-entropy and the number of correct predictions.

    Args:
        out: tensor of per-class log-probabilities, one row per sample.
        real: one-hot target tensor with the same shape as ``out``.

    Returns:
        ``(log_loss, correct)`` where ``log_loss`` is the scalar tensor
        ``sum(-real * out)`` (summed over the batch, not averaged) and
        ``correct`` is a Python int counting rows whose argmax matches.
    """
    log_loss = torch.sum(-real*out)

    real_arg = torch.argmax(real,dim=1)
    out_arg = torch.argmax(out,dim=1)
    # Tensor-level sum: also yields 0 for an empty batch, where the builtin
    # sum() would return a plain int that has no .item().
    correct = (out_arg==real_arg).sum().item()
    return log_loss,correct

In [None]:

def train(pairs,batch_size,num_classes,optimizer,model):
    """Run one optimisation pass over the frames of one tuple of class folders.

    Every entry of ``pairs`` is either a folder path or ``None``; the position
    of the entry inside ``pairs`` is the class label handed to ``create_sets``.

    Args:
        pairs: iterable of folder paths (``None`` entries are skipped).
        batch_size: number of frames per mini-batch.
        num_classes: width of the one-hot target vectors.
        optimizer: optimiser stepping the trainable parameters of ``model``.
        model: network returning per-class log-probabilities.

    Returns:
        ``(pair_loss, pair_corrects, num_samples)``: the loss summed over all
        frames, the number of correctly classified frames and the number of
        frames seen. ``(0.0, 0, 0)`` when no frame was found.
    """
    X_parts = []
    Y_parts = []
    all_frame_list = []
    for label, folder in enumerate(pairs):
        if folder is None:
            continue
        X, Y, frame_list = create_sets(folder, label)
        if len(frame_list) == 0:
            # A folder without frames yields an array of the wrong rank that
            # cannot be concatenated with real image batches; skip it.
            continue
        X_parts.append(X)
        Y_parts.append(Y)
        all_frame_list.extend(frame_list)

    if not X_parts:
        logging.warning("train: no frames found for the given folders")
        return 0.0, 0, 0

    # Single concatenation instead of growing the arrays folder by folder.
    X_batch = np.concatenate(X_parts, axis=0)
    Y_batch = np.concatenate(Y_parts, axis=0)
    num_samples = Y_batch.shape[0]

    # Use the device the model lives on instead of a notebook-level global.
    device = next(model.parameters()).device
    # Make sure dropout / BatchNorm behave as in training even if an
    # evaluation step left the model in eval mode.
    model.train()

    pair_loss = 0.0
    pair_corrects = 0

    sub_batches = create_batches(X_batch, Y_batch, all_frame_list, batch_size)
    for X_minibatch, Y_minibatch, _ in sub_batches:
        Y_minibatch = one_hot(Y_minibatch, num_classes)
        X_minibatch = torch.from_numpy(X_minibatch).to(device).float()
        Y_minibatch = torch.from_numpy(Y_minibatch.reshape(-1, num_classes)).to(device).float()

        output = model(X_minibatch)
        loss, correct = evaluate(output, Y_minibatch)

        # Back-propagate the per-sample average; the summed loss is what
        # gets accumulated for reporting.
        av_loss = loss/Y_minibatch.size(0)

        optimizer.zero_grad()
        av_loss.backward()
        optimizer.step()

        pair_loss += loss.item()
        pair_corrects += correct

    return pair_loss,pair_corrects,num_samples

In [None]:
def create_pair_lists(train_test_path):
    """Group the sub-folders of every class directory into positional tuples.

    The n-th tuple holds the n-th sub-folder of each class directory found
    directly under ``train_test_path``; classes with fewer sub-folders are
    padded with ``None`` (``itertools.zip_longest``).

    Both glob results are sorted because ``glob`` returns entries in
    arbitrary, file-system dependent order. The position inside a tuple
    follows the class directory order, so sorting keeps that order identical
    between calls and between directory trees that contain the same class
    folder names.

    Args:
        train_test_path: directory containing one sub-directory per class.

    Returns:
        A list of tuples of folder paths / ``None``; empty when no class
        directory is found.
    """
    class_dirs = sorted(glob.glob(train_test_path+"/*"))
    folders_per_class = [sorted(glob.glob(class_dir+"/*")) for class_dir in class_dirs]
    return list(itertools.zip_longest(*folders_per_class))

def im_toarray(im):
    """Load an image file and return it as a preprocessed NumPy array.

    Args:
        im: path (or file object) accepted by ``PIL.Image.open``.

    Returns:
        The result of the module-level ``preprocess`` pipeline converted to a
        NumPy array.
    """
    # The context manager closes the underlying file handle, which matters
    # when thousands of frames are read. Converting to RGB guarantees three
    # channels for the 3-channel normalisation in `preprocess`.
    with Image.open(im) as image:
        im_array = preprocess(image.convert("RGB")).numpy()
    return im_array

def create_sets(folder,label):
    """Load every ``.jpg`` frame of ``folder`` and tag it with ``label``.

    Returns:
        A tuple ``(images, labels, frame_list)``: the stacked image arrays,
        a column vector holding ``label`` once per frame, and the frame paths
        in the order they were read.
    """
    frame_list = list(glob.glob(folder+"/*.jpg"))
    images = [im_toarray(path) for path in frame_list]
    labels = np.array([label] * len(frame_list)).reshape(-1,1)
    return np.array(images),labels,frame_list


def one_hot(array, num_classes):
    """Turn integer class labels into one-hot rows.

    The labels are flattened, used to pick rows of a ``num_classes`` identity
    matrix, and the result is squeezed (so a single label gives a 1-D vector).
    """
    identity = np.eye(num_classes)
    flat_labels = array.reshape(-1)
    encoded = identity[flat_labels]
    return np.squeeze(encoded)


In [None]:
def save_model(epoch,model,optimizer):
    """Write a checkpoint to ``cnn_extract_model.pth``.

    The checkpoint stores the epoch number together with the state
    dictionaries of the model and the optimiser; the file is overwritten on
    every call.
    """
    checkpoint = dict(
        epoch=epoch,
        state_dict=model.state_dict(),
        optimizer=optimizer.state_dict(),
    )
    torch.save(checkpoint, "cnn_extract_model.pth")

In [None]:
# Build the folder tuples for the training and test splits.
pair_training = create_pair_lists(train_test_path= "Frames/Train")
pair_test = create_pair_lists(train_test_path= "Frames/Test")

# Fall back to the CPU when no CUDA device is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Bind the backbone to its own name: assigning it to `model` would overwrite
# the builder function and make this cell fail when it is run a second time.
base_model = model()
resnet = RESNET(base_model).to(device).float()

# 25 epochs, learning rate 1e-2, mini-batches of 32 frames.
train_model(resnet,25,1.e-2,pair_training,pair_test,device,32)