In [None]:
import os
import re
import copy
import pickle
import datetime

import numpy as np
import matplotlib.pyplot as plt
from skimage.measure import block_reduce
from skimage.transform import resize

import importlib
from IPython.display import clear_output

import data_utils
from CGAN import CGAN
from Params import Params

## Load and preprocess the data

In [None]:
# Raw arrays stored as .npy files under ./data/.
face_data = np.load("./data/face_data.npy")
landmarks = np.load("./data/landmarks.npy")

# Pickled helper objects. NOTE: unpickling runs arbitrary code, so these
# files must come from a trusted source.
with open("./data/scaler.pkl", "rb") as scaler_file:
    scaler = pickle.load(scaler_file)
with open("./data/table.pkl", "rb") as table_file:
    lookup_table = pickle.load(table_file)

# Visual sanity check of channel 2 of the loaded face data.
data_utils.visualize_z(face_data, z_channel=2)

## Train the model

In [None]:
# Keep only channel index 2 of the fourth axis. Slicing with 2:3 (instead of
# indexing with a bare 2) leaves that axis in place with length 1.
X = face_data[:, :, :, 2:3]
# Params is constructed from the training array and handed to the CGAN constructor.
network = CGAN(Params(X))

In [None]:
def _generator_input(noise, X_cond, indices):
    """Return the generator input: the noise alone, or noise joined with the sampled conditions."""
    if X_cond is None:
        return noise
    return np.concatenate([noise, X_cond[indices]], axis=1)


def train(network, X, X_cond=None, train_steps=500, prev_steps=0, interval=5, suffix=None):
    """Alternate discriminator and generator updates, then pickle the network.

    Parameters
    ----------
    network : object exposing ``G`` (generator), ``D`` (discriminator) and
        ``GD`` (stacked model used for the generator update).
    X : array of real samples, indexed along axis 0.
    X_cond : optional array of conditioning vectors, indexed with the same
        row indices as ``X``. When given, each noise batch is concatenated
        with the sampled conditions along axis 1 before it is fed to the
        generator -- in both the discriminator and the generator step.
    train_steps : number of outer training iterations to run.
    prev_steps : steps already performed in earlier calls; only used for the
        printed step counter and the output file name.
    interval : progress is printed every ``interval`` steps (step 0 excluded);
        the notebook output is cleared every ``5 * interval`` steps.
    suffix : tag appended to the saved file name; defaults to today's date.

    Side effects
    ------------
    Writes ``./models/CGAN-<prev_steps + train_steps>-<suffix>.pkl`` (the
    directory is created if it is missing).
    """
    for step in range(train_steps):

        # --- Discriminator update(s): half real batch, half generated batch.
        for _ in range(Params.steps_D):
            indices = np.random.randint(X.shape[0], size=(Params.batch_size))
            noise = np.random.uniform(size=(Params.batch_size, Params.n_rand))
            z = _generator_input(noise, X_cond, indices)
            X_real = X[indices]
            X_fake = network.G.predict(z)

            X_D = np.concatenate((X_real, X_fake))
            Y_D = np.concatenate((np.zeros((X_real.shape[0], 1)) + Params.real_l,
                                  np.zeros((X_fake.shape[0], 1)) + Params.fake_l))

            loss_D, acc_D = network.D.train_on_batch(X_D, Y_D)

        # --- Generator update(s) through the stacked model, with all-ones targets.
        for _ in range(Params.steps_GD):
            noise = np.random.uniform(size=(Params.batch_size, Params.n_rand))
            # The generator input must have the same layout as in the
            # discriminator step, so conditions are attached here as well.
            # Indices are only drawn when conditions are actually in use.
            cond_indices = None
            if X_cond is not None:
                cond_indices = np.random.randint(X.shape[0], size=(Params.batch_size))
            X_GD = _generator_input(noise, X_cond, cond_indices)
            Y_GD = np.ones((X_GD.shape[0], 1))
            loss_GD, acc_GD = network.GD.train_on_batch(X_GD, Y_GD)

        # Keep the notebook output short: wipe it every fifth report.
        if step % (5 * interval) == 0 and step > 0:
            clear_output()

        if step % interval == 0 and step > 0:
            print("Step {}:".format(step + prev_steps))
            print()
            # Shows the last fake batch produced in the discriminator step.
            data_utils.visualize_z(X_fake)
            print()
            print("Descriminator :: loss = {}, acc = {}".format(loss_D, acc_D))
            print("Adversarial   :: loss = {}, acc = {}".format(loss_GD, acc_GD))

    if suffix is None:
        suffix = str(datetime.date.today())
    # Make sure the target directory exists so a long run is not lost at save time.
    os.makedirs("./models", exist_ok=True)
    with open("./models/CGAN-{}-{}.pkl".format(prev_steps + train_steps, suffix), "wb") as f:
        pickle.dump(network, f)

In [None]:
# Unconditional run (no X_cond) with the defaults declared by train():
# 500 steps, a progress report every 5 steps, and the network pickled
# under ./models/ once training finishes.
train(network, X)

## Things to try

* Penalize overconfidence by labeling real images with 0.9 (Better GAN training)

* Switch to Adam optimizer (DCGAN paper)

* Switch from max-pooling and upsampling to strided convolutions (DCGAN)

* Add conditional data to the random noise (encode the label as a one-hot vector) to account for the different topology of the data and to prevent "averaging" of faces w.r.t. facial expressions. (InfoGAN?)

* Remove dropout from generator (not done)

* Make the discriminator more complex (harder to train, but smarter; bring it to the same level as the generator)

In [None]:
# Walk through every pickled model in ./models/ in sorted file-name order and
# display one batch of samples generated from fresh uniform noise.
for fname in sorted(name for name in os.listdir("./models/") if name.endswith(".pkl")):
    with open(os.path.join("./models/", fname), "rb") as f:
        # NOTE: pickle.load executes arbitrary code; keep only trusted files in ./models/.
        model = pickle.load(f)
        noise_shape = (model.params.batch_size, model.params.n_rand)
        X_fake = model.G.predict(np.random.uniform(size=noise_shape))
        print("Model:", fname)
        data_utils.visualize_z(X_fake)
        # Drop the reference before the next model is loaded.
        del model