In [1]:
import torch
import torch.nn as nn
import torchvision
from torchvision import datasets, transforms, models
from torch.autograd import Variable
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from skimage import io
import os
from PIL import Image
import numpy as np

In [2]:
path = '/home/jeet/WEBEmo/category.txt'

with open(path, 'r') as f:
    content = f.readlines()
    
content = [elem.strip('\n').split(',') for elem in content]   

for elem in content:
    if elem[2] == '+':
        elem[2] = 'positive'
    else:
        elem[2] = 'negative'
        
print (content)

[['affection', 'love', 'positive'], ['cheerfullness', 'joy', 'positive'], ['confusion', 'confusion', 'negative'], ['contentment', 'joy', 'positive'], ['disappointment', 'sadness', 'negative'], ['disgust', 'anger', 'negative'], ['enthrallment', 'joy', 'positive'], ['envy', 'anger', 'negative'], ['exasperation', 'anger', 'negative'], ['gratitude', 'love', 'positive'], ['horror', 'fear', 'negative'], ['irritabilty', 'anger', 'negative'], ['lust', 'love', 'positive'], ['neglect', 'sadness', 'negative'], ['nervousness', 'fear', 'negative'], ['optimism', 'joy', 'positive'], ['pride', 'joy', 'positive'], ['rage', 'anger', 'negative'], ['relief', 'joy', 'positive'], ['sadness', 'sadness', 'negative'], ['shame', 'sadness', 'negative'], ['suffering', 'sadness', 'negative'], ['surprise', 'surprise', 'positive'], ['sympathy', 'sadness', 'negative'], ['zest', 'joy', 'positive']]


In [3]:
level1 = dict()
level2 = dict()

for i, elem in enumerate(content):
    
    if elem[2] not in level1.keys():
        level1[elem[2]] = []
    
    level1[elem[2]].append(i)
    
    if elem[1] not in level2.keys():
        level2[elem[1]] = []
        
    level2[elem[1]].append(i)
    
# print (level1)
# print (level2)

#### Custom Dataset Loader

In [4]:
# Function to extract the label of the folder 
def get_key(label_dict, val):
    for key, val_list in label_dict.items():
        if val in val_list:
            return key
        
# Function to make the dataset. Returns list of tuple (path, label) for the image
def make_dataset(root_dir, label_dict, class_to_idx):
    images = []
    for target in sorted(os.listdir(root_dir)):
        d = os.path.join(root_dir, target)
        
        try :
            int(target)
        except:
            continue
        
        label = get_key(label_dict, int(target))
        label = class_to_idx[label]

        for root, _, fnames in sorted(os.walk(d)):
            for fname in sorted(fnames):
                path = os.path.join(root, fname)
                item = (path, label)
                images.append(item)

    return images

# Helper function to load the images given the path of the image
def pil_loader(path):
    with open(path, 'rb') as f:
        img = Image.open(f)
        return img.convert('RGB')
    
# Attribute of the class Level1ImageDataSet    
def find_classes(root_dir, label_dict):
    classes = []
    
    for label_dir in sorted(os.listdir(root_dir)):
        try:
            int(label_dir)
            classes.append(get_key(label_dict, int(label_dir)))
        except:
            continue
    
    classes = list(set(classes))
    
    class_to_idx = {classes[i]: i for i in range(len(classes))}
    return classes, class_to_idx

In [5]:
class Level1ImageDataset(Dataset):
    
    def __init__(self, root_dir, label_dict, transform=None):
        super(Level1ImageDataset, self).__init__()
        
        classes, class_to_idx = find_classes(root_dir, label_dict)
        samples = make_dataset(root_dir, label_dict, class_to_idx)
        
        self.root_dir = root_dir
        self.transform = transform
        self.label_dict = label_dict
        
        self.samples = samples
        self.classes = classes
        self.class_to_idx = class_to_idx
        
    def __len__(self):
        return (len(self.samples))
    
    def __getitem__(self, index):
        
        path, label = self.samples[index]
        sample = pil_loader(path)
        
        if self.transform is not None:
            sample = self.transform(sample)
        
        return sample, label
    
    def __repr__(self):
        fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
        fmt_str += '    Number of datapoints: {}\n'.format(self.__len__())
        fmt_str += '    Root Location: {}\n'.format(self.root_dir)
        tmp = '    Transforms (if any): '
        fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        return fmt_str

In [6]:
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'test': transforms.Compose([
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

In [7]:
data_dir = '/home/jeet/WEBEmo/'
BATCH_SIZE = 32

dset_l1 = {x: Level1ImageDataset(os.path.join(data_dir, x), level1, data_transforms[x])
         for x in ['train', 'test']}

dset_loaders = {x: torch.utils.data.DataLoader(dset_l1[x], batch_size=BATCH_SIZE, shuffle=True, num_workers=16)
                for x in ['train', 'test']}

dset_sizes = {x: len(dset_l1[x]) for x in ['train', 'test']}

dset_classes = dset_l1['train'].classes

In [8]:
def train(model, criterion, optimizer, num_epochs = 20):
    
    for epoch in range(num_epochs):
        model.train()
        
        best_acc = 0.0
        epoch_acc = 0.0
        epoch_loss = 0.0
        train_loss = 0.0
        train_acc = 0.0
        
        for i, (images, labels) in enumerate(dset_loaders['train']):
            
            images = Variable(images.cuda())
            labels = Variable(labels.cuda())
        
            optimizer.zero_grad()
            outputs = model(images)
            
            _, preds = torch.max(outputs, 1)
            
            loss = criterion(outputs, labels)
            
            loss.backward()
            optimizer.step()
            
            train_loss += loss.item() * images.size(0)
            train_acc += torch.sum(preds == labels.data)
            
            if i % 100 == 0:
                try:
                    avg = train_acc.double() / (i * BATCH_SIZE)
                    print ("Average correctly classified images till {} batches: {}".format(i, avg))
                except:
                    continue
            
        epoch_loss = train_loss / dset_sizes['train']
        epoch_acc = train_acc.double() / dset_sizes['train']
        
        print ("Epoch: {}, Epoch_Accuracy: {:.2f}, Epoch_loss: {:.4f}".format(epoch, epoch_acc, epoch_loss))
        
        # Evaluate on the test set
        test_acc = test(model)

        # Save the model if the test acc is greater than our current best
#         if test_acc > best_acc:
#             best_model_wt = model.state_dict()
#             print("Chekcpoint updated")
#             best_acc = test_acc

        # Print the metrics
        print("Test Accuracy: {:.4f}".format(test_acc))
        
    #Save the best model weights
#     torch.save(best_model_wt, "Resnet34_Transfer_Learning.model")

In [9]:
def test(model):
    model.eval()
    test_acc = 0.0
    
    for i, (images, labels) in enumerate(dset_loaders['test']):

        images = Variable(images.cuda())
        labels = Variable(labels.cuda())

        # Predict classes using images from the test set
        outputs = model(images)
        _, prediction = torch.max(outputs.data, 1)
        
        test_acc += torch.sum(prediction == labels.data)

    # Compute the average acc and loss over all test images
    test_acc = test_acc.double() / dset_sizes['test']

    return test_acc

In [10]:
model_conv = models.resnet34(pretrained=True)

for params in model_conv.parameters():
    params.requires_grad=False
    
num_features = model_conv.fc.in_features
model_conv.fc = nn.Linear(num_features, 2)

model_conv.cuda()

criterion = nn.CrossEntropyLoss()

optimizer_conv = optim.Adam(model_conv.parameters(), lr=0.001, weight_decay=0.0001)

In [11]:
train(model_conv, criterion, optimizer_conv, num_epochs=10)

Average correctly classified images till 0 batches: inf
Average correctly classified images till 100 batches: 0.5700000000000001
Average correctly classified images till 200 batches: 0.5867187500000001
Average correctly classified images till 300 batches: 0.5970833333333334
Average correctly classified images till 400 batches: 0.59875
Average correctly classified images till 500 batches: 0.601
Average correctly classified images till 600 batches: 0.6047395833333333
Average correctly classified images till 700 batches: 0.6045982142857143
Average correctly classified images till 800 batches: 0.6066796875
Average correctly classified images till 900 batches: 0.6085069444444444
Average correctly classified images till 1000 batches: 0.61021875
Average correctly classified images till 1100 batches: 0.6107386363636363
Average correctly classified images till 1200 batches: 0.6124739583333334
Average correctly classified images till 1300 batches: 0.613798076923077
Average correctly classified i

Average correctly classified images till 4700 batches: 0.6294547872340426
Average correctly classified images till 4800 batches: 0.6293098958333334
Average correctly classified images till 4900 batches: 0.6293813775510204
Average correctly classified images till 5000 batches: 0.6294937500000001
Average correctly classified images till 5100 batches: 0.6295526960784313
Average correctly classified images till 5200 batches: 0.6294891826923077
Average correctly classified images till 5300 batches: 0.629192216981132
Average correctly classified images till 5400 batches: 0.6287789351851851
Average correctly classified images till 5500 batches: 0.6290511363636363
Average correctly classified images till 5600 batches: 0.6291183035714286
Average correctly classified images till 5700 batches: 0.629172149122807
Average correctly classified images till 5800 batches: 0.6291487068965517
Average correctly classified images till 5900 batches: 0.6291260593220339
Average correctly classified images till

Average correctly classified images till 3200 batches: 0.62658203125
Average correctly classified images till 3300 batches: 0.6266761363636364
Average correctly classified images till 3400 batches: 0.6270404411764705
Average correctly classified images till 3500 batches: 0.6266517857142858
Average correctly classified images till 3600 batches: 0.6270920138888889
Average correctly classified images till 3700 batches: 0.6275337837837838
Average correctly classified images till 3800 batches: 0.627483552631579
Average correctly classified images till 3900 batches: 0.6277323717948718
Average correctly classified images till 4000 batches: 0.6278359375
Average correctly classified images till 4100 batches: 0.6280335365853659
Average correctly classified images till 4200 batches: 0.6280877976190476
Average correctly classified images till 4300 batches: 0.6283502906976745
Average correctly classified images till 4400 batches: 0.6286221590909091
Average correctly classified images till 4500 batc

Average correctly classified images till 1100 batches: 0.6298295454545455
Average correctly classified images till 1200 batches: 0.629140625
Average correctly classified images till 1300 batches: 0.6284615384615385
Average correctly classified images till 1400 batches: 0.6284151785714286
Average correctly classified images till 1500 batches: 0.6287291666666667
Average correctly classified images till 1600 batches: 0.6290625
Average correctly classified images till 1700 batches: 0.6285845588235294
Average correctly classified images till 1800 batches: 0.6288715277777778
Average correctly classified images till 1900 batches: 0.6291940789473685
Average correctly classified images till 2000 batches: 0.630296875
Average correctly classified images till 2100 batches: 0.6301785714285715
Average correctly classified images till 2200 batches: 0.6303977272727272
Average correctly classified images till 2300 batches: 0.6304891304347826
Average correctly classified images till 2400 batches: 0.6299

Average correctly classified images till 5700 batches: 0.628766447368421
Average correctly classified images till 5800 batches: 0.6287661637931033
Average correctly classified images till 5900 batches: 0.6289247881355933
Average correctly classified images till 6000 batches: 0.6288072916666666
Average correctly classified images till 6100 batches: 0.6286680327868852
Average correctly classified images till 6200 batches: 0.6287247983870968
Average correctly classified images till 6300 batches: 0.6287896825396826
Average correctly classified images till 6400 batches: 0.6288330078125001
Average correctly classified images till 6500 batches: 0.6288317307692308
Average correctly classified images till 6600 batches: 0.6288352272727272
Epoch: 6, Epoch_Accuracy: 0.63, Epoch_loss: 0.6494
Test Accuracy: 0.6522
Average correctly classified images till 0 batches: inf
Average correctly classified images till 100 batches: 0.633125
Average correctly classified images till 200 batches: 0.62046875
Aver

Average correctly classified images till 3600 batches: 0.6295052083333333
Average correctly classified images till 3700 batches: 0.6297635135135136
Average correctly classified images till 3800 batches: 0.630016447368421
Average correctly classified images till 3900 batches: 0.6298717948717949
Average correctly classified images till 4000 batches: 0.62996875
Average correctly classified images till 4100 batches: 0.6298704268292683
Average correctly classified images till 4200 batches: 0.6300223214285714
Average correctly classified images till 4300 batches: 0.6300363372093024
Average correctly classified images till 4400 batches: 0.6297727272727273
Average correctly classified images till 4500 batches: 0.6298263888888889
Average correctly classified images till 4600 batches: 0.6297622282608696
Average correctly classified images till 4700 batches: 0.6300598404255319
Average correctly classified images till 4800 batches: 0.6304101562500001
Average correctly classified images till 4900 b