In [1]:
import os
import numpy as np
import pandas as pd
from PIL import Image

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models

In [2]:
# Global seeds for reproducibility; torch.manual_seed also seeds every CUDA device.
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)

# Mini-batch size used by every DataLoader built in make_loaders.
batch_size = 64

In [3]:
# Use the default CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
device

device(type='cuda')

In [4]:
# Emotion classes; a class's position in this list is its integer label index.
CLS = [
    'neutral', 'happiness', 'surprise', 'sadness',
    'anger', 'disgust', 'fear', 'contempt',
]
NUM_CLS = len(CLS)
# Name -> index lookup.
CLS_DICT = dict(zip(CLS, range(NUM_CLS)))
CLS_DICT

{'neutral': 0,
 'happiness': 1,
 'surprise': 2,
 'sadness': 3,
 'anger': 4,
 'disgust': 5,
 'fear': 6,
 'contempt': 7}

In [7]:
class FerPlusDataset(Dataset):
    """FER+ facial-expression dataset backed by a ``FER2013<split>`` folder.

    The folder holds the face images and a headerless ``label.csv`` whose rows
    are ``image_name, image_tensor, <8 emotion columns in CLS order>, unknown, NF``.
    The emotion columns are assumed to be annotator vote counts (FER+ format) -
    TODO confirm; either way they are renormalised over the 8 emotions here.

    Args:
        split: Folder suffix, e.g. ``'Train'``, ``'Valid'`` or ``'Test'``.
        transform: Callable applied to each PIL image.
        mode: Target encoding produced by ``__getitem__``:
            ``'mv'``  majority vote - int64 index of the highest-probability
                      emotion (ties resolve to the lowest index);
            ``'pld'`` probabilistic label drawing - int64 index sampled from
                      the per-image emotion distribution;
            ``'cel'`` float32 emotion probability vector (soft target);
            ``'ml'``  float32 0/1 vector marking emotions whose probability
                      exceeds ``theta``.
        theta: Probability threshold for ``'ml'`` mode (default ``0.0``, i.e.
            every emotion with non-zero probability).

    Raises:
        ValueError: from ``__getitem__`` when ``mode`` is not one of the above.
    """

    def __init__(self, split, transform, mode='mv', theta=0.0):
        self.root = f'FER2013{split}'
        self.transform = transform
        self.mode = mode
        # 'ml' mode reads this; it was previously never set (AttributeError).
        self.theta = theta

        column_names = [
            'image_name', 'image_tensor',
            'neutral', 'happiness', 'surprise', 'sadness', 'anger', 'disgust', 'fear', 'contempt',
            'unknown', 'NF'
        ]

        df = pd.read_csv(os.path.join(self.root, 'label.csv'), names=column_names, header=None)

        self.images = df['image_name'].values
        # Raw counts are not a distribution: np.random.choice ('pld') rejects
        # them and 'cel' needs soft float targets, so normalise once up front.
        self.labels = self._votes_to_probs(df[CLS].values)

    @staticmethod
    def _votes_to_probs(votes):
        """Row-normalise a 2-D array of non-negative scores into probabilities.

        Rows whose total is zero (no emotion votes at all) become a uniform
        distribution instead of NaN; argmax of such a row is still index 0.
        """
        votes = np.asarray(votes, dtype=np.float64)
        totals = votes.sum(axis=1, keepdims=True)
        probs = np.full_like(votes, 1.0 / votes.shape[1])
        np.divide(votes, totals, out=probs, where=totals > 0)
        return probs

    def _make_target(self, p):
        """Encode one probability vector ``p`` as a target according to ``self.mode``."""
        if self.mode == 'mv':
            return torch.tensor(np.argmax(p))

        if self.mode == 'pld':
            return torch.tensor(np.random.choice(len(p), p=p))

        if self.mode == "cel":
            # float32 to match model logits for soft-target cross-entropy.
            return torch.tensor(p, dtype=torch.float32)

        if self.mode == "ml":
            # float 0/1 mask so it can be fed straight into a BCE-style loss.
            return torch.tensor(p > self.theta, dtype=torch.float32)

        raise ValueError(f"Unknown mode {self.mode}")

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(transformed image, target)`` for sample ``idx``."""
        img_path = os.path.join(self.root, self.images[idx])
        # Context manager releases the file handle even if the transform raises.
        with Image.open(img_path) as img:
            x = self.transform(img)

        return x, self._make_target(self.labels[idx])

In [9]:
# Training pipeline: replicate the grayscale face into 3 channels (the VGG
# backbone takes 3-channel input), resize so the shorter side is 224 px,
# randomly mirror for augmentation, then map pixel values from [0, 1] to [-1, 1].
train_tf = transforms.Compose(
    [
        transforms.Grayscale(num_output_channels=3),
        transforms.Resize(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
)

# Evaluation pipeline: identical to training but without the random flip,
# so the same image always produces the same tensor.
val_tf = transforms.Compose(
    [
        transforms.Grayscale(num_output_channels=3),
        transforms.Resize(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
)

In [10]:
def make_loaders(mode):
    """Build the FER+ train / validation / test DataLoaders.

    Only the training split uses the augmenting ``train_tf``. Validation and
    test use the deterministic ``val_tf`` so evaluation metrics are repeatable
    (previously all three splits used ``train_tf`` and were randomly flipped).

    Args:
        mode: Target encoding forwarded to ``FerPlusDataset`` for every split.

    Returns:
        Tuple ``(train_loader, val_loader, test_loader)``; only the training
        loader shuffles.
    """
    pin = torch.cuda.is_available()

    def _loader(split, transform, shuffle):
        # One place for the DataLoader settings shared by all splits.
        ds = FerPlusDataset(split=split, transform=transform, mode=mode)
        return DataLoader(ds, batch_size=batch_size, shuffle=shuffle,
                          num_workers=1, pin_memory=pin)

    # NOTE(review): evaluation splits reuse `mode`, so with 'pld' the
    # validation/test targets are randomly sampled - confirm that is intended.
    return (
        _loader('Train', train_tf, shuffle=True),
        _loader('Valid', val_tf, shuffle=False),
        _loader('Test', val_tf, shuffle=False),
    )

In [11]:
def make_model(num_classes=None):
    """Create an ImageNet-pretrained VGG13-BN and move it to ``device``.

    Args:
        num_classes: If given, replace the final classifier layer with a freshly
            initialised ``nn.Linear`` producing this many logits (e.g.
            ``NUM_CLS`` for FER+). The default ``None`` keeps the pretrained
            1000-way ImageNet head, matching the previous behaviour.

    Returns:
        The model, placed on ``device``.
    """
    model = models.vgg13_bn(weights=models.VGG13_BN_Weights.IMAGENET1K_V1)

    if num_classes is not None:
        # The last entry of the classifier Sequential is the ImageNet output
        # layer; swap it so the logits match the task's label space.
        head = model.classifier[-1]
        model.classifier[-1] = nn.Linear(head.in_features, num_classes)

    return model.to(device)