In [1]:
import torch
import torchvision
from torchvision import datasets, transforms
from torch.utils.data import Dataset, DataLoader
from skimage import io
import os

In [2]:
path = '/home/jeet/WEBEmo/category.txt'

# Each non-empty line is a comma-separated triple such as "affection,love,+".
# Stripping every field (not just '\n') keeps a CRLF file from producing keys
# like '+\r', and skipping blank lines avoids rows like [''] that would break
# the indexing in the next cell.
with open(path, 'r') as f:
    content = [
        [field.strip() for field in line.split(',')]
        for line in f
        if line.strip()
    ]

print(content)

[['affection', 'love', '+'], ['cheerfullness', 'joy', '+'], ['confusion', 'confusion', '-'], ['contentment', 'joy', '+'], ['disappointment', 'sadness', '-'], ['disgust', 'anger', '-'], ['enthrallment', 'joy', '+'], ['envy', 'anger', '-'], ['exasperation', 'anger', '-'], ['gratitude', 'love', '+'], ['horror', 'fear', '-'], ['irritabilty', 'anger', '-'], ['lust', 'love', '+'], ['neglect', 'sadness', '-'], ['nervousness', 'fear', '-'], ['optimism', 'joy', '+'], ['pride', 'joy', '+'], ['rage', 'anger', '-'], ['relief', 'joy', '+'], ['sadness', 'sadness', '-'], ['shame', 'sadness', '-'], ['suffering', 'sadness', '-'], ['surprise', 'surprise', '+'], ['sympathy', 'sadness', '-'], ['zest', 'joy', '+']]


In [3]:
level1 = {}
level2 = {}

# Group the row indices of `content` by the '+'/'-' column (level1) and by the
# middle column (level2, e.g. 'joy', 'anger'). make_dataset later looks up each
# folder's integer name in these lists, so this assumes folder N holds the
# category on row N of category.txt (0-based) -- TODO confirm against the data.
for row_idx, row in enumerate(content):
    level1.setdefault(row[2], []).append(row_idx)
    level2.setdefault(row[1], []).append(row_idx)

print(level1)
print(level2)

{'+': [0, 1, 3, 6, 9, 12, 15, 16, 18, 22, 24], '-': [2, 4, 5, 7, 8, 10, 11, 13, 14, 17, 19, 20, 21, 23]}
{'love': [0, 9, 12], 'joy': [1, 3, 6, 15, 16, 18, 24], 'confusion': [2], 'sadness': [4, 13, 19, 20, 21, 23], 'anger': [5, 7, 8, 11, 17], 'fear': [10, 14], 'surprise': [22]}


In [4]:
# Inspect the level-1 grouping as (label, list of row indices) pairs.
level1.items()

dict_items([('+', [0, 1, 3, 6, 9, 12, 15, 16, 18, 22, 24]), ('-', [2, 4, 5, 7, 8, 10, 11, 13, 14, 17, 19, 20, 21, 23])])

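#### Custom Dataset Loader

Each split directory (`train/`, `test/`) contains integer-named sub-folders. The helpers below look up each folder's integer name in the index lists of `level1` / `level2` built above and emit one `(image path, label)` sample for every file inside that folder.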

In [5]:
def get_key(label_dict, val):
    """Return the first key of ``label_dict`` whose value list contains ``val``.

    Used to translate an integer folder index into its grouped label
    (e.g. ``'+'``/``'-'`` for ``level1``). Keys are scanned in the dict's
    iteration order; ``None`` is returned when no list contains ``val``.
    """
    matches = (key for key, members in label_dict.items() if val in members)
    return next(matches, None)
        
def make_dataset(root_dir, label_dict):
    """Collect ``(image_path, label)`` pairs from integer-named class folders.

    Every immediate entry of ``root_dir`` whose name parses as an integer is
    treated as a category folder; its label is the ``label_dict`` key whose
    value list contains that integer (see ``get_key``). Every file found
    (recursively) below such a folder becomes one sample. Entries whose names
    are not integers are skipped.

    Args:
        root_dir: directory holding the integer-named category folders.
        label_dict: mapping ``label -> list of int folder indices``
            (e.g. ``level1`` or ``level2``).

    Returns:
        list of ``(path, label)`` tuples in sorted (deterministic) order.

    Raises:
        ValueError: if an integer-named folder is not listed in ``label_dict``.
            Previously such files were silently given a ``None`` label, which
            only surfaced later as a failure when batching.
    """
    images = []
    for target in sorted(os.listdir(root_dir)):
        # Only integer-named entries are category folders; narrow the except so
        # genuine errors (and KeyboardInterrupt) are no longer swallowed.
        try:
            folder_idx = int(target)
        except ValueError:
            continue

        label = get_key(label_dict, folder_idx)
        if label is None:
            raise ValueError(
                'folder {!r} in {!r} is not mapped by label_dict'.format(target, root_dir))

        d = os.path.join(root_dir, target)
        for root, _, fnames in sorted(os.walk(d)):
            for fname in sorted(fnames):
                images.append((os.path.join(root, fname), label))

    return images

def pil_loader(path):
    """Load the image at ``path`` with Pillow and return it converted to RGB.

    Args:
        path: filesystem path of the image file.

    Returns:
        A Pillow image in ``'RGB'`` mode.
    """
    # The notebook's import cell never brings ``Image`` into scope (it only
    # imports skimage.io), so without this import the first call -- made from
    # Level1ImageDataset.__getitem__ -- raised NameError.
    from PIL import Image

    with open(path, 'rb') as f:
        img = Image.open(f)
        # Image.open is lazy: convert() must run while ``f`` is still open.
        return img.convert('RGB')
    
def find_classes(root_dir, label_dict):
    """Return the set of labels reachable from the integer-named folders of ``root_dir``.

    Used to populate ``Level1ImageDataset.classes``. Each entry of ``root_dir``
    whose name parses as an integer is mapped through ``get_key``; other
    entries are ignored. An unmapped integer folder contributes ``None``
    (``get_key``'s not-found result).

    Args:
        root_dir: directory holding the integer-named category folders.
        label_dict: mapping ``label -> list of int folder indices``.

    Returns:
        set of labels (unordered).
    """
    classes = set()
    for label_dir in os.listdir(root_dir):
        # Only the int() parse is expected to fail; the previous bare except
        # also hid errors from get_key/label_dict and KeyboardInterrupt.
        try:
            folder_idx = int(label_dir)
        except ValueError:
            continue
        classes.add(get_key(label_dict, folder_idx))
    return classes

In [6]:
# Sanity check: the level-1 labels reachable from the train split's folders.
classes = find_classes(root_dir='/home/jeet/WEBEmo/train', label_dict=level1)
print(classes)

{'-', '+'}


In [7]:
# Per-channel mean/std shared by both pipelines -- presumably the standard
# ImageNet statistics; confirm they match the backbone being fine-tuned.
NORM_MEAN = [0.485, 0.456, 0.406]
NORM_STD = [0.229, 0.224, 0.225]

# train: random resized crop + horizontal flip as augmentation.
# test: deterministic Resize(224) + CenterCrop(224), no augmentation.
data_transforms = dict(
    train=transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(NORM_MEAN, NORM_STD),
    ]),
    test=transforms.Compose([
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(NORM_MEAN, NORM_STD),
    ]),
)

In [8]:
class Level1ImageDataset(Dataset):
    """Images stored under integer-named category folders, labelled by group.

    Every integer-named sub-folder of ``root_dir`` is looked up in
    ``label_dict`` (``label -> list of folder indices``, e.g. ``level1`` or
    ``level2``), and every file below it becomes one sample carrying that
    label (see ``make_dataset``).

    Args:
        root_dir: split directory, e.g. ``<data_dir>/train``.
        label_dict: mapping ``label -> list of int folder indices``.
        transform: optional callable applied to the RGB Pillow image
            (e.g. an entry of ``data_transforms``).
        target_transform: optional callable applied to the label. Labels are
            the ``label_dict`` keys (strings such as ``'+'``/``'-'``); pass
            e.g. ``{c: i for i, c in enumerate(sorted(ds.classes))}.__getitem__``
            to obtain integer targets for a classification loss.

    Attributes:
        samples: list of ``(path, label)`` tuples.
        classes: set of labels found under ``root_dir``.
    """

    def __init__(self, root_dir, label_dict, transform=None, target_transform=None):
        super(Level1ImageDataset, self).__init__()

        samples = make_dataset(root_dir, label_dict)
        classes = find_classes(root_dir, label_dict)

        self.root_dir = root_dir
        self.transform = transform
        self.target_transform = target_transform
        self.label_dict = label_dict
        self.samples = samples
        self.classes = classes

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        """Return ``(image, label)`` for ``index`` after applying the optional transforms."""
        path, label = self.samples[index]
        sample = pil_loader(path)

        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            label = self.target_transform(label)

        return sample, label

    def __repr__(self):
        # key=str keeps the ordering total even if a None label slipped in.
        return '{}(root_dir={!r}, num_samples={}, classes={})'.format(
            type(self).__name__, self.root_dir, len(self.samples),
            sorted(self.classes, key=str))

In [9]:
data_dir = '/home/jeet/WEBEmo/'

# One dataset per split directory (data_dir/train, data_dir/test), labelled
# with the level-1 grouping and paired with that split's transform pipeline.
dset_l1 = dict(
    (split, Level1ImageDataset(root_dir=os.path.join(data_dir, split),
                               label_dict=level1,
                               transform=data_transforms[split]))
    for split in ('train', 'test')
)

In [10]:
# Quick check that the train split dataset was constructed.
print (dset_l1['train'])

<__main__.Level1ImageDataset object at 0x7f199e6fd908>


In [11]:
# Shuffle only the training split. Evaluation metrics do not depend on order,
# and a fixed test order keeps predictions reproducible and aligned with
# dset_l1['test'].samples.
dset_loaders = {
    x: DataLoader(dset_l1[x], batch_size=32, shuffle=(x == 'train'), num_workers=16)
    for x in ['train', 'test']
}

dset_sizes = {x: len(dset_l1[x]) for x in ['train', 'test']}
print(dset_sizes)

dset_classes = dset_l1['train'].classes
print(dset_classes)

{'train': 213952, 'test': 53489}
{'-', '+'}
