# IMDb dataset

In [1]:
%matplotlib inline

import sys
import os
import numpy as np
import matplotlib
from matplotlib import pyplot as plt
from matplotlib.offsetbox import TextArea, DrawingArea, OffsetImage, AnnotationBbox
import logging
import torch
import pandas
import PIL.Image

# Configure logging for the notebook. ``logging.basicConfig`` does nothing once the
# root logger already has handlers, so a second call (for example one switching the
# level to DEBUG) would never take effect; configure it exactly once instead.
logging.basicConfig(
    format="%(asctime)-5.5s %(name)-30.30s %(levelname)-7.7s %(message)s",
    datefmt="%H:%M",
    level=logging.INFO,
)

# Make the repository root importable so the project packages imported below resolve.
sys.path.append("../../")
# from experiments.datasets import FFHQStyleGAN2DLoader
from experiments.architectures.image_transforms import create_image_transform
from experiments.architectures.vector_transforms import create_vector_transform
from manifold_flow.flows import ManifoldFlow

## Options

In [2]:
# Toggle for the one-off preprocessing cell further down (runs only when True).
transform: bool = True


## Preprocessing: filter the raw images, resize them to 64x64 PNGs and write age labels to CSV (only needs to be done once)

In [None]:
# Share of the raw samples that is held out as the test split.
test_fraction: float = 0.1


def process_images(filenames, ages, category="train", subfolders=True, basedir="../data/samples/imdb"):
    """Filter, downscale and save raw IMDb images, returning an index of the kept ones.

    An input image ``{basedir}/raw/{filename}`` is kept only if its age label lies in
    [20, 80] and the decoded image is square, has three channels and is at least
    64 pixels wide. Kept images are resized to 64x64 and saved as PNG below
    ``{basedir}/{category}``.

    Args:
        filenames: Raw image paths, relative to ``{basedir}/raw``.
        ages: Age label for each entry of ``filenames`` (same order).
        category: Split name; used as output folder, filename prefix and CSV name.
        subfolders: If True, shard output into ``{category}/NNN`` folders, one per
            1000 input indices; otherwise write everything into ``{category}``.
        basedir: Dataset root directory that contains the ``raw`` folder.

    Returns:
        pandas.DataFrame with columns ``age`` and ``filename`` (relative to
        ``basedir``) for every kept image. The same table is written to
        ``{category}.csv``.
    """
    filenames_out = []
    ages_out = []

    for i, (filename, age) in enumerate(zip(filenames, ages)):
        # Check the label first so images that would be discarded are never decoded.
        # The chained comparison is False for NaN, so missing ages are skipped too
        # (``age < 20 or age > 80`` would let NaN through).
        if not 20 <= age <= 80:
            continue

        # ``with`` guarantees the underlying file handle is closed.
        with PIL.Image.open(f"{basedir}/raw/{filename}") as img:
            dims = np.array(img).shape

            # Skip b/w and non-square images (the latter are often corrupted), and
            # anything smaller than the 64x64 target resolution.
            if len(dims) != 3 or dims[2] != 3 or dims[0] != dims[1] or dims[0] < 64:
                continue

            # LANCZOS is the filter the ``Image.ANTIALIAS`` alias referred to; that
            # alias was removed in Pillow 10. ``resize`` returns an independent image,
            # so it stays usable after the source file is closed.
            img_small = img.resize((64, 64), PIL.Image.LANCZOS)

        folder = f"{category}/{(i // 1000):03d}" if subfolders else category
        img_filename_out = f"{folder}/{category}_{i:05d}.png"

        os.makedirs(f"{basedir}/{folder}", exist_ok=True)
        img_small.save(f"{basedir}/{img_filename_out}")
        filenames_out.append(img_filename_out)
        ages_out.append(age)

    df_out = pandas.DataFrame({"age": ages_out, "filename": filenames_out})
    # NOTE(review): the CSV is written relative to the current working directory, not
    # ``basedir``, while the filenames inside it are relative to ``basedir`` -- confirm.
    df_out.to_csv(f"{category}.csv")

    return df_out


if transform:
    # One-off preprocessing: split raw.csv into train/test and process both splits.
    df = pandas.read_csv("../data/samples/imdb/raw.csv")
    n = len(df)
    n_test = int(round(n * test_fraction))

    # Fixed seed so the train/test split is reproducible across runs.
    np.random.seed(81357)
    idx = list(range(n))
    np.random.shuffle(idx)
    idx_train, idx_test = idx[n_test:], idx[:n_test]

    ages_all = np.array(df["age"])
    paths_all = np.array(df["path"])
    age_train, age_test = ages_all[idx_train], ages_all[idx_test]
    paths_train, paths_test = paths_all[idx_train], paths_all[idx_test]

    df_train = process_images(paths_train, age_train, "train")
    df_test = process_images(paths_test, age_test, "test", subfolders=False)
else:
    # Preprocessing was done in an earlier run: reload the index tables that
    # process_images() wrote (to the working directory, with the DataFrame index as
    # first column) so later cells still have df_train / df_test.
    df_train = pandas.read_csv("train.csv", index_col=0)
    df_test = pandas.read_csv("test.csv", index_col=0)


In [None]:
# Compare the (normalized) age distributions of the processed train and test splits.
fig = plt.figure(figsize=(5, 5))

# 61 unit-width bins over [19.5, 80.5]: one bin centred on each integer age 20..80.
hist_kwargs = dict(range=(19.5, 80.5), bins=61, histtype="step", density=True)
plt.hist(df_train["age"], color="C0", label="Train", **hist_kwargs)
plt.hist(df_test["age"], color="C1", label="Test", **hist_kwargs)

plt.xlim(0.0, 100.0)
plt.ylim(0.0, None)
plt.xlabel("Age")
plt.ylabel("Density")
# The labels passed to plt.hist are only displayed once a legend is drawn.
plt.legend()

plt.tight_layout()
plt.show()
