In [3]:
# import libraries

import numpy as np
import pandas as pd

from torchvision.datasets import ImageFolder
from torchvision.transforms import transforms

import torch
from torch import nn, optim
from torch.utils.data import DataLoader, Subset

import pytorch_lightning as pl
from pytorch_lightning.metrics import functional as FM

from sklearn.model_selection import StratifiedShuffleSplit

from tqdm import tqdm, trange
import matplotlib.pyplot as plt

from os.path import basename

In [5]:
# helper
def show_in_row(list_of_images: list, titles: list = None, disable_ticks: bool = False):
    """Draw the given images side by side in one row and show the figure.

    :param list_of_images: images to draw, one subplot each; every item must expose ``.shape``
    :param titles: optional titles, indexed in parallel with ``list_of_images``
    :param disable_ticks: when True, clear the x and y ticks of every subplot
    """
    total = len(list_of_images)
    for position, image in enumerate(list_of_images, start=1):
        axes = plt.subplot(1, total, position)
        if titles is not None:
            axes.set_title(titles[position - 1])

        # 2-D arrays and arrays with a single trailing channel get the gray colormap.
        single_channel = len(image.shape) == 2 or image.shape[2] == 1
        axes.imshow(image, cmap='gray' if single_channel else None)

        if disable_ticks:
            # pyplot acts on the current axes, i.e. the subplot created above.
            plt.xticks([])
            plt.yticks([])
    plt.show()

In [58]:
# Random seed constant.
RANDOM_STATE = 42

# Data location, split/batch sizes and target image size.
DATA_RAW_ROOT = '../data/raw'

TEST_SIZE = 0.1
TRAIN_BATCH = 64
VALID_BATCH = 1024
IMG_SIZE = (256, 256)

# TODO: consider adding randomized augmentation as a transform for train samples.

# Steps run in the listed order: resize to IMG_SIZE, reduce to a single channel,
# convert to a tensor, then normalize that channel with mean 0.5 and std 0.5.
dataset_transform = transforms.Compose([
    transforms.Resize(IMG_SIZE),
    transforms.Grayscale(),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])




In [60]:
from pathlib import Path
from torch.utils.data import Dataset
import enum
from collections import defaultdict
from PIL import Image
from skimage import io

# Directory holding the raw JAFFE images; used as the default `data_dir` of
# JAFFEDataset and checked for existence by JAFFEDataModule.prepare_data.
RAW_JAFFE_DATA_DIR = Path('../data/raw/jaffe')

@enum.unique
class Expression(enum.Enum):
    """Facial-expression labels.

    Each value is the two-letter code that the dataset loader reads from an image
    file name and passes to ``Expression(code)``. ``@enum.unique`` makes a duplicated
    code a definition-time error instead of silently turning one label into an alias
    of another (which would merge two classes in any index table built from the enum).
    """

    NEUTRAL  = 'NE'
    ANGRY    = 'AN'
    DISGUST  = 'DI'
    FEAR     = 'FE'
    HAPPY    = 'HA'
    SAD      = 'SA'
    SURPRISE = 'SU'
    
    
# Lookup tables between Expression members and integer class indices; the index
# of a member is its position in the enum's definition order.
ExpDec = dict(enumerate(Expression))
ExpEnc = {member: index for index, member in ExpDec.items()}



class JAFFEDataset(Dataset):
    """JAFFE face images labelled from their file names.

    Every ``*.tiff`` file directly inside ``data_dir`` becomes one sample. The file
    name is split on ``.``: the first token is kept as the subject identifier and the
    first two characters of the second token are looked up in ``Expression``
    (e.g. ``TM.HA3.182.tiff`` -> ``iden='TM'``, ``exp=Expression.HAPPY``).
    """

    def __init__(self, data_dir=RAW_JAFFE_DATA_DIR, expression=None):
        """
        :param data_dir (Path)         : Path to images
        :param expression (Expression) : Enum member, expression label.
                                         Accepted for interface compatibility only;
                                         it is currently not used by the dataset.
        """
        self.data_dir = Path(data_dir)
        self.samples = self.build_dict()

        # Module-level transform pipeline, applied to every image in __getitem__.
        self.transform = dataset_transform

    def __len__(self):
        """Number of image files found by build_dict."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load one sample.

        :param idx: integer position in ``self.samples`` (a tensor index is converted
                    with ``tolist()`` first)
        :return: ``(image, desc)`` where ``image`` is the transformed image and ``desc``
                 is the sample's description dict (keys ``'path'``, ``'exp'``, ``'iden'``)

        Note: ``desc`` holds a ``Path`` and an ``Expression`` member. PyTorch's default
        collate function cannot batch those types, so a ``DataLoader`` over this
        dataset needs a custom ``collate_fn``.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        desc = self.samples[idx]
        # Read as gray-scale and wrap in a PIL image, which the transform pipeline expects.
        image = Image.fromarray(io.imread(desc['path'], as_gray=True))

        if self.transform:
            image = self.transform(image)

        return (image, desc)

    def build_dict(self):
        """Scan ``data_dir`` and describe every ``.tiff`` file in it.

        :return: list of dicts with keys ``'path'`` (Path), ``'exp'`` (Expression) and
                 ``'iden'`` (str), ordered by path
        :raises ValueError: if a file's two-letter code is not a value of ``Expression``
        """
        samples = []
        # iterdir() yields entries in arbitrary, filesystem-dependent order. Sorting
        # pins each sample's index, so the seeded split in exp_split selects the same
        # files on every run and every machine.
        for entry in sorted(self.data_dir.iterdir()):
            tokens = entry.name.split('.')
            if tokens[-1] != 'tiff':
                continue

            samples.append({
                'path': entry,
                'exp' : Expression(tokens[1][:2]),
                'iden': tokens[0],
            })

        return samples

    def exp_split(self, test_size=0.1, random_state=42):
        """Split the dataset into train and test subsets, stratified by expression.

        :param test_size: forwarded to ``StratifiedShuffleSplit``
        :param random_state: seed forwarded to ``StratifiedShuffleSplit``
        :return: ``(train_ds, test_ds)`` as ``Subset`` views of this dataset
        """
        labels = [ExpEnc[sample['exp']] for sample in self.samples]

        splitter = StratifiedShuffleSplit(n_splits=1, test_size=test_size, random_state=random_state)
        # The builtin int replaces np.int, an alias for it that NumPy 1.24 removed.
        positions = np.arange(len(labels), dtype=int)
        train_idx, test_idx = next(splitter.split(X=positions, y=labels))

        train_ds = Subset(self, train_idx)
        test_ds = Subset(self, test_idx)

        return train_ds, test_ds
        
            
# Smoke test: build the dataset, split it, and check basic invariants.
jaffe = JAFFEDataset()
samples = jaffe.samples

one, two = jaffe.exp_split()

# The two subsets must partition the dataset: nothing lost, nothing shared.
assert len(one) + len(two) == len(jaffe)
assert set(one.indices).isdisjoint(two.indices)

# Indexing the dataset returns the stored description object next to the image.
assert jaffe[0][1] is samples[0]

# Last expression of the cell: display one (image, description) training sample.
one[0]

(tensor([[[-0.1059, -0.1059, -0.0902,  ..., -0.0039, -0.0196, -0.0039],
          [-0.1451, -0.0824, -0.1608,  ..., -0.0353, -0.0745,  0.0275],
          [-0.1451, -0.1059, -0.0824,  ..., -0.0196,  0.0275,  0.0275],
          ...,
          [-0.4980, -0.5137, -0.5529,  ..., -0.4275, -0.4353, -0.4667],
          [-0.5373, -0.5686, -0.5529,  ..., -0.4039, -0.4588, -0.4275],
          [-0.4745, -0.5843, -0.5451,  ..., -0.4588, -0.4588, -0.3961]]]),
 {'path': PosixPath('../data/raw/jaffe/TM.HA3.182.tiff'),
  'exp': <Expression.HAPPY: 'HA'>,
  'iden': 'TM'})

In [69]:
class JAFFEDataModule(pl.LightningDataModule):
    """Lightning data module that serves train/test loaders over JAFFEDataset."""

    # Stage names Lightning may hand to setup(). Used to recognise a stage that
    # landed in the ``scenario`` slot because it was passed positionally.
    _LIGHTNING_STAGES = ('fit', 'validate', 'test', 'predict')

    def __init__(self, batch_size=32):
        """
        :param batch_size: batch size used by both the train and the test loader
        """
        super().__init__()
        self.batch_size = batch_size

    def prepare_data(self):
        """Report whether the raw JAFFE directory exists; nothing is downloaded."""
        if not RAW_JAFFE_DATA_DIR.exists():
            print('JAFFE data not found')
        else:
            print('JAFFE data found')

    def setup(self, scenario=None, stage=None):
        """Build the dataset and its train/test split.

        :param scenario: split strategy; ``'exp'`` or ``None`` selects the
                         expression-stratified split (the only one available)
        :param stage: Lightning stage name; accepted but not used
        :raises ValueError: if ``scenario`` names an unknown split strategy
        """
        # Lightning's own hook signature is setup(stage). When the stage arrives
        # positionally it lands in ``scenario``; treat it as "no scenario chosen"
        # so the default split is still created.
        if stage is None and scenario in self._LIGHTNING_STAGES:
            scenario = None

        # Fail here rather than later with an AttributeError on a missing split.
        if scenario not in (None, 'exp'):
            raise ValueError(f'unknown split scenario: {scenario!r}')

        self.dataset = JAFFEDataset()
        self.jaffe_train, self.jaffe_test = self.dataset.exp_split()

    @staticmethod
    def _collate(batch):
        """Assemble a batch from ``(image, desc)`` samples.

        Images are stacked into one tensor; descriptions are returned as a plain list.
        The default collate function cannot be used because each description holds a
        ``Path`` and an ``Expression`` member, which it does not know how to batch.
        """
        images, descs = zip(*batch)
        return torch.stack(images), list(descs)

    def train_dataloader(self):
        """Shuffled loader over the training subset."""
        return DataLoader(self.jaffe_train, batch_size=self.batch_size, shuffle=True,
                          collate_fn=self._collate)

    # TODO: exp_split produces no validation subset yet, so val_dataloader is not defined.

    def test_dataloader(self):
        """Loader over the test subset, in fixed order."""
        return DataLoader(self.jaffe_test, batch_size=self.batch_size, shuffle=False,
                          collate_fn=self._collate)

    
# Smoke check of the data module: run its hooks in the order
# prepare_data -> setup -> train_dataloader and keep the training loader in `dl`.
# Note: a DataLoader is an iterable, not an iterator, so `next(dl)` raises
# "TypeError: 'DataLoader' object is not an iterator"; obtain batches with
# `iter(dl)` or a for-loop instead.
dm = JAFFEDataModule()
dm.prepare_data()
dm.setup()
dl = dm.train_dataloader()


JAFFE data found


TypeError: 'DataLoader' object is not an iterator