In [1]:
import cv2
import glob
import io
import matplotlib.pyplot as plt
import os
import torch
from torch.utils.data import Dataset, DataLoader

In [2]:
class EigenfacesDataset(Dataset):
    """Face images (``*.pgm``) found under a directory, labelled from their file names.

    File names are parsed as
    ``<user_id>_<head_position>_<facial_expression>_<eye_state>[_<scale>].pgm``.
    Only images whose scale is 1 are kept: either no scale suffix, or an
    explicit ``_1``.
    """

    def __init__(self, data_dir):
        """
        Arguments:
            data_dir (string): Directory with all the images. It is searched
                recursively (including ``data_dir`` itself) for ``*.pgm`` files.
        """
        # recursive=True is what makes '**' match any number of directory
        # levels; without it glob treats '**' like '*' (exactly one level).
        pattern = os.path.join(data_dir, '**', '*.pgm')
        self.img_path_list = [
            path for path in sorted(glob.glob(pattern, recursive=True))
            if self.get_labels(path)['scale'] == 1]

    def __len__(self):
        return len(self.img_path_list)

    def __getitem__(self, idx):
        """Return the labels of image ``idx`` (minus ``scale``) plus the image under ``'image'``.

        Raises:
            OSError: if OpenCV cannot decode the image file.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_path = self.img_path_list[idx]
        # cv2.imread with default flags returns a 3-channel array, even for
        # these grayscale PGM files (each pixel value appears three times).
        image = cv2.imread(img_path)
        if image is None:
            # cv2.imread reports failure by returning None instead of raising;
            # fail here with the offending path rather than downstream.
            raise OSError('cv2 could not read image: {}'.format(img_path))

        # get_labels builds a fresh dict on every call, so mutating it is safe.
        item = self.get_labels(img_path)
        del item['scale']
        item['image'] = image
        return item

    def get_labels(self, img_path):
        """Parse the labels encoded in an image file name.

        Arguments:
            img_path (string): Path whose base name (without extension) is
                ``<user_id>_<head_position>_<facial_expression>_<eye_state>``,
                optionally followed by ``_<scale>``.

        Returns:
            dict: ``user_id``, ``head_position``, ``facial_expression`` and
            ``eye_state`` as strings, and ``scale`` as an int. ``scale`` is 1
            unless the name has exactly five ``_``-separated parts.

        Raises:
            IndexError: if the name has fewer than four ``_``-separated parts.
            ValueError: if a fifth part is present but is not an integer.
        """
        img_name = os.path.splitext(os.path.basename(img_path))[0]
        parts = img_name.split('_')

        # Only a name with exactly five parts carries an explicit scale.
        scale = int(parts[4]) if len(parts) == 5 else 1

        return {
            'user_id': parts[0],
            'head_position': parts[1],
            'facial_expression': parts[2],
            'eye_state': parts[3],
            'scale': scale}

In [3]:
# Root directory searched for the face images (*.pgm).
DATA_DIR = '../data'

eigenfaces_dataset = EigenfacesDataset(data_dir=DATA_DIR)
# glob returns an empty list rather than raising when DATA_DIR is wrong; without
# this check the mistake would only surface as a bare StopIteration from next(iter(...)).
assert len(eigenfaces_dataset) > 0, \
    'no scale-1 .pgm images found under {!r}'.format(DATA_DIR)

# Default DataLoader settings: one image per batch, so each label comes back
# as a length-1 list and the image gains a leading batch dimension.
eigenfaces_dataloader = DataLoader(eigenfaces_dataset)

In [4]:
# Pull a single batch to sanity-check labels and image layout.
first_batch = next(iter(eigenfaces_dataloader))
first_batch

{'user_id': ['an2i'],
 'head_position': ['left'],
 'facial_expression': ['angry'],
 'eye_state': ['open'],
 'image': tensor([[[[34, 34, 34],
           [ 3,  3,  3],
           [ 1,  1,  1],
           ...,
           [52, 52, 52],
           [52, 52, 52],
           [52, 52, 52]],
 
          [[47, 47, 47],
           [21, 21, 21],
           [ 0,  0,  0],
           ...,
           [52, 52, 52],
           [52, 52, 52],
           [53, 53, 53]],
 
          [[47, 47, 47],
           [ 6,  6,  6],
           [ 1,  1,  1],
           ...,
           [52, 52, 52],
           [53, 53, 53],
           [55, 55, 55]],
 
          ...,
 
          [[ 0,  0,  0],
           [ 0,  0,  0],
           [ 0,  0,  0],
           ...,
           [60, 60, 60],
           [ 0,  0,  0],
           [ 0,  0,  0]],
 
          [[ 0,  0,  0],
           [ 0,  0,  0],
           [ 0,  0,  0],
           ...,
           [60, 60, 60],
           [ 4,  4,  4],
           [ 0,  0,  0]],
 
          [[ 0,  0,  0