In [4]:
import os

import pandas as pd
import numpy as np
from matplotlib import pyplot as plt

import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
from torch import nn
from torch.utils.data import DataLoader
from torchvision.utils import make_grid

import cv2
from cv2 import imread, imwrite
from tqdm import tqdm

from sklearn.model_selection import train_test_split

In [5]:
def parse_data(data: pd.DataFrame, image_shape: tuple = (48, 48)):
    """
    Decode a dataframe of pixel strings into an image array and a label array.

    Parameters
    ----------
    data : pd.DataFrame
        Must provide an 'emotion' column (values convertible to int) and a
        'pixels' column holding one string of space-separated integers per row.
    image_shape : tuple, default (48, 48)
        Shape each row's pixel vector is reshaped to.

    Returns
    -------
    image_array : np.ndarray
        float64 array of shape (len(data), *image_shape).
    image_label : np.ndarray
        Integer array of shape (len(data),).

    Raises
    ------
    ValueError
        If a row's pixel count does not match `image_shape`.
    """
    image_array = np.zeros(shape=(len(data), *image_shape))
    # Explicit dtype keeps the labels integral even for an empty frame.
    image_label = np.array(list(map(int, data['emotion'])), dtype=int)

    # Walk the column values positionally instead of looking rows up by index
    # label: a label-based lookup returns a Series (not a string) whenever the
    # index contains duplicates, e.g. after concatenating or filtering frames.
    for i, pixels in enumerate(data['pixels']):
        image = np.fromstring(pixels, dtype=int, sep=' ')
        image_array[i] = np.reshape(image, image_shape)

    return image_array, image_label

def show_img(images: np.ndarray, labels: np.ndarray):
    """
    Draw up to the first ten images on a 2x5 grid, each titled with its label.

    `images[k]` is shown in grayscale in the k-th panel (row-major order) and
    titled with `labels[k]`. Axis ticks are removed from every panel that
    receives an image.
    """
    _, grid = plt.subplots(nrows=2, ncols=5, figsize=(18, 9))
    panels = grid.flatten()
    for position, (panel, label) in enumerate(zip(panels, labels[:10])):
        panel.imshow(images[position], cmap='gray')
        # Tick marks carry no information for raw pixel grids.
        panel.set_xticks([])
        panel.set_yticks([])
        panel.set_title("Label:{}".format(label))

In [6]:
# Read the raw CSV (path is relative to the notebook's working directory) and
# decode it: parse_data reads the 'emotion' and 'pixels' columns and returns
# one image array plus one label array, row-aligned with the dataframe.
df = pd.read_csv("../data/icml_face_data.csv")
images, labels = parse_data(df)

In [7]:
# Two-stage hold-out with a fixed seed for reproducibility:
#   1) 20 % of all samples are set aside as the test set;
#   2) 20 % of the remaining 80 % become the validation set.
# Net proportions: 64 % train / 16 % validation / 20 % test.
# Note: X_train / y_train are rebound by the second call, so both lines must
# be run together (re-running only the second line would shrink the train set).
X_train, X_test, y_train, y_test = train_test_split(images, labels, test_size=0.2, random_state=42)
X_train, X_valid, y_train, y_valid = train_test_split(X_train, y_train, test_size=0.2, random_state=42)

In [8]:
def _write_split(split_dir, description, split_images, split_labels, start_index):
    """
    Write one data split to disk as PNG files, one sub-folder per label.

    Files go to '../data/images/<split_dir>/<label>/<index>.png', where
    <index> starts at `start_index` and increases by one per image.

    Parameters
    ----------
    split_dir : str
        Folder name of the split under '../data/images/'.
    description : str
        Text shown next to the progress bar.
    split_images, split_labels : sequences of equal length
        Images and their labels, paired by position.
    start_index : int
        File-name index used for the first image of this split.

    Returns
    -------
    int
        The next unused index (start_index + number of images written).

    Raises
    ------
    OSError
        If cv2.imwrite reports that a file could not be written.
    """
    # cv2.imwrite does not create missing folders; it just returns False.
    # Create every class folder up front so the writes cannot fail silently.
    for label in np.unique(split_labels):
        os.makedirs(f'../data/images/{split_dir}/{label}', exist_ok=True)

    next_index = start_index
    progress = tqdm(zip(split_images, split_labels),
                    total=len(split_images), desc=description)
    for image, label in progress:
        target = f'../data/images/{split_dir}/{label}/{next_index}.png'
        # Convert explicitly to 8-bit instead of relying on the encoder's
        # implicit conversion (assumes pixel values are 0-255 intensities;
        # TODO confirm against the dataset).
        pixels_8bit = np.clip(image, 0, 255).astype(np.uint8)
        if not imwrite(target, pixels_8bit):
            raise OSError(f"cv2.imwrite failed for {target}")
        next_index += 1
    return next_index


# A single running index is shared across all three splits, so every image
# gets a file name that is unique over the whole dataset.
index = 1
for split_dir, description, split_images, split_labels in (
    ("train", "train data", X_train, y_train),
    ("valid", "validation data", X_valid, y_valid),
    ("test", "test data", X_test, y_test),
):
    index = _write_split(split_dir, description, split_images, split_labels, index)

train data: 22967it [00:03, 6725.11it/s]
validation data: 5742it [00:00, 7663.92it/s]
test data: 7178it [00:00, 8331.89it/s]
