In [None]:
import os
import re
import copy
import pickle
import datetime
from collections import deque

import numpy as np
import matplotlib.pyplot as plt
from skimage.measure import block_reduce
from skimage.transform import resize

import importlib
from IPython.display import clear_output

import data_utils
import CGAN
import Params

### Load the data

In [None]:
# Load the pre-processed "lossy" dataset: one pickled object plus three numpy arrays.
with open("./data/lossy/table.pkl", mode="rb") as table_file:
    lookup_table = pickle.load(table_file)

face_data, landmarks, labels = [
    np.load(npy_path)
    for npy_path in (
        "./data/lossy/face_data.npy",
        "./data/lossy/landmarks.npy",
        "./data/lossy/labels.npy",
    )
]

In [None]:
# One-hot encode the expression labels: `label_set` holds the distinct labels
# in sorted order and `index_set[k]` is the position of `labels[k]` in it.
label_set, index_set = np.unique(labels, return_inverse=True)

# Every face needs exactly one label; otherwise rows of X and X_cond would not
# line up. Fail with a readable message rather than an indexing error.
assert len(index_set) == face_data.shape[0], (
    "labels ({}) and face_data ({}) disagree on the number of samples"
    .format(len(index_set), face_data.shape[0]))

# Row k of the identity matrix is the one-hot vector for class k.
label_mat = np.eye(len(label_set))[index_set]

# Keep only channel 2 of the last axis; the 2:3 slice keeps that axis so X
# stays 4-D. NOTE(review): what channel 2 represents is not shown here.
X = face_data[:, :, :, 2:3]
X_cond = label_mat

In [None]:
# Sanity check: display the real training faces that carry one expression label.
label = 7
print("Expression {}: {}".format(label, label_set[label]))
label_mask = X_cond[:, label] == 1
data_utils.visualize_z(X[label_mask], z_channel=0)

### Train

In [None]:
def _train_discriminator_step(M, X, X_cond):
    """Run one discriminator update on a real batch and on a generated batch.

    Returns ``(loss_real, acc_real, loss_fake, acc_fake, X_fake)``.
    """
    params = M.params
    # Draw the batch indices before the generator noise.
    indices = np.random.randint(X.shape[0], size=params.batch_size)
    noise = np.random.uniform(size=(params.batch_size, params.n_rand))
    # In the conditional case the generated batch reuses the real batch's conditions.
    z = noise if X_cond is None else np.concatenate([noise, X_cond[indices]], axis=1)

    X_real = X[indices]
    X_fake = M.G.predict(z)

    Y_real = np.full((X_real.shape[0], 1), params.real_l, dtype=float)
    Y_fake = np.full((X_fake.shape[0], 1), params.fake_l, dtype=float)

    loss_real, acc_real = M.D.train_on_batch(X_real, Y_real)
    loss_fake, acc_fake = M.D.train_on_batch(X_fake, Y_fake)
    return loss_real, acc_real, loss_fake, acc_fake, X_fake


def _train_generator_step(M, X_cond):
    """Run one update of the stacked model ``M.GD``; returns ``(loss, acc)``."""
    params = M.params
    noise = np.random.uniform(size=(params.batch_size, params.n_rand))
    if X_cond is None:
        X_GD = noise
    else:
        # Random one-hot conditions, drawn uniformly over all classes.
        n_classes = X_cond.shape[1]
        cond = np.zeros((params.batch_size, n_classes))
        label_idx = np.random.randint(n_classes, size=params.batch_size)
        cond[np.arange(params.batch_size), label_idx] = 1
        X_GD = np.concatenate([noise, cond], axis=1)

    # The generator is trained to make D output the "real" target.
    Y_GD = np.full((X_GD.shape[0], 1), params.real_l, dtype=float)
    return M.GD.train_on_batch(X_GD, Y_GD)


def _print_window_stats(history, window):
    """Print mean loss/accuracy of each model over the last ``window`` steps."""
    rows = (
        ("Discriminator (real) ", "loss_D_real", "acc_D_real"),
        ("Discriminator (fake) ", "loss_D_fake", "acc_D_fake"),
        ("Adversarial          ", "loss_GD", "acc_GD"),
    )
    for title, loss_key, acc_key in rows:
        print("{}:: loss = {}, acc = {}".format(title,
                                                 np.mean(history[loss_key][-window:]),
                                                 np.mean(history[acc_key][-window:])))
    print()


def _save_checkpoint(M, history, model_dir, suffix, global_step):
    """Pickle the model and its history, tagged with ``suffix`` and ``global_step``."""
    prefix = "{}/CGAN-{}-{}".format(model_dir, suffix, global_step)
    with open(prefix + "-model.pkl", "wb") as f:
        pickle.dump(M, f)
    with open(prefix + "-history.pkl", "wb") as f:
        pickle.dump(history, f)


def train(M, X, X_cond=None, train_steps=20, prev_steps=0, prev_hist=None, interval=5,
          suffix="Recent", stats_window=10, save_every=500, model_dir="./models"):
    """Train a (conditional) GAN and return its per-step loss/accuracy history.

    Each outer step runs ``M.params.steps_D`` discriminator updates (one real
    batch and one generated batch each), then ``M.params.steps_GD`` updates of
    the stacked model ``M.GD``. Only the metrics from the last inner update of
    each phase are appended to the history.

    Args:
        M: model wrapper exposing ``G``, ``D`` and ``GD`` (``predict`` /
            ``train_on_batch``, the latter returning a ``(loss, acc)`` pair) and
            ``params`` with ``steps_D``, ``steps_GD``, ``batch_size``,
            ``n_rand``, ``real_l`` and ``fake_l``.
        X: real training samples, indexed along axis 0.
        X_cond: optional condition matrix aligned row-wise with ``X``. When
            given, condition rows are appended to the generator noise.
        train_steps: number of outer steps to run in this call.
        prev_steps: steps completed by earlier calls. Used only to number the
            printed steps and the checkpoint files.
        prev_hist: history dict from an earlier call, extended in place.
            ``None`` starts a new history.
        interval: print statistics every ``interval`` steps and clear the cell
            output every ``10 * interval`` steps.
        suffix: tag used in the config and checkpoint file names.
        stats_window: number of most recent steps averaged in printed stats.
        save_every: checkpoint when the global step is a multiple of this.
            ``0``/``None`` disables checkpointing.
        model_dir: directory for config and checkpoint files; created if missing.

    Returns:
        The history dict, mapping each metric name to a list with one value per step.

    Raises:
        ValueError: if ``steps_D``/``steps_GD`` < 1 (no metrics would exist
            to record) or ``interval`` < 1.
    """
    if M.params.steps_D < 1 or M.params.steps_GD < 1:
        raise ValueError("steps_D and steps_GD must both be >= 1, got {} and {}"
                         .format(M.params.steps_D, M.params.steps_GD))
    if interval < 1:
        raise ValueError("interval must be >= 1, got {}".format(interval))

    # Record the model description next to its checkpoints.
    os.makedirs(model_dir, exist_ok=True)
    with open("{}/CGAN-{}-config.txt".format(model_dir, suffix), "wt") as f:
        f.write(str(M))

    metric_keys = ("loss_D_real", "loss_D_fake", "loss_GD",
                   "acc_D_real", "acc_D_fake", "acc_GD")
    # Extending an existing history in place keeps resumed runs in one record.
    history = {key: [] for key in metric_keys} if prev_hist is None else prev_hist

    for step in range(1, train_steps + 1):
        global_step = step + prev_steps

        for _ in range(M.params.steps_D):
            loss_D_real, acc_D_real, loss_D_fake, acc_D_fake, X_fake = \
                _train_discriminator_step(M, X, X_cond)
        for _ in range(M.params.steps_GD):
            loss_GD, acc_GD = _train_generator_step(M, X_cond)

        step_values = (loss_D_real, loss_D_fake, loss_GD, acc_D_real, acc_D_fake, acc_GD)
        for key, value in zip(metric_keys, step_values):
            history[key].append(value)

        # Periodically wipe the cell output so it does not grow without bound.
        if step % (10 * interval) == 0:
            clear_output(wait=True)

        if step % interval == 0:
            print("Step {}:".format(global_step))
            data_utils.visualize_z(X_fake)
            _print_window_stats(history, stats_window)

        if save_every and global_step % save_every == 0:
            _save_checkpoint(M, history, model_dir, suffix, global_step)

    return history

In [None]:
# Build a freshly initialised conditional GAN. Params is constructed from the
# training data X and the condition matrix X_cond.
# NOTE(review): the training cell below passes prev_steps=6400 to this new
# network, which only offsets step numbers and checkpoint names. Confirm that
# this is intended rather than resuming from a loaded checkpoint.
network = CGAN.CGAN(Params.Params(X, X_cond))

In [None]:
# Run (or resume) training. `history` only exists if an earlier `train` call
# ran in this kernel session. On a fresh kernel, start a new history instead
# of failing with NameError.
# NOTE(review): prev_steps=6400 only offsets printed step numbers and
# checkpoint names/periods; it does not restore any weights.
try:
    prev_history = history
except NameError:
    prev_history = None

history = train(network,
                X,
                X_cond,
                interval=50,
                train_steps=3600,
                prev_steps=6400,
                prev_hist=prev_history,
                suffix="dcgan_cond_lossy")

In [None]:
# Plot the training curves recorded by the training cell above. The
# window/n_points/power arguments are interpreted by data_utils.visualize_history.
data_utils.visualize_history(history, window=100, n_points=600, power=2)

### Test

In [None]:
def test(M, labels=[0]):
    
    fakes = []
    for i in labels:
        label_mat = np.zeros((network.params.batch_size, network.params.n_cond))
        label_idx = np.vstack(np.zeros(network.params.batch_size) + i).astype(np.int32)
        label_mat[np.arange(network.params.batch_size), label_idx] = 1

        noise = np.random.uniform(size=(network.params.batch_size, network.params.n_rand))
        noise = np.concatenate([noise, label_mat], axis=1)

        fake_data = network.G.predict(noise)
        fake_pred = network.D.predict(fake_data)
        indices = sorted(range(fake_pred.shape[0]), key=lambda x: fake_pred[x], reverse=True)

        fakes.append((fake_data[indices], face_pres[indices]))
        
        data_utils.visualize_z(fake_data[indices][-2:])
    
    return fakes

In [None]:
# Restore a pickled network from disk (presumably a checkpoint written by
# `train`). NOTE: pickle.load can execute arbitrary code; only load
# checkpoints you created yourself.
# NOTE(review): this is the "noisy" run, while the data above is "lossy".
MODEL_PATH = "./models/CGAN-dcgan_cond_noisy/CGAN-dcgan_cond_noisy-6000-model.pkl"
with open(MODEL_PATH, mode="rb") as model_file:
    network = pickle.load(model_file)

In [None]:
# Sample and rank generated faces for the default label(s) using the loaded network.
fakes = test(network)

In [None]:
# NOTE(review): this repeats the earlier history plot. `history` is still
# whatever the last `train` call in this kernel returned (suffix
# "dcgan_cond_lossy"), not the history of the "noisy" network loaded above.
# Load the matching history pickle if that is what should be shown here.
data_utils.visualize_history(history, window=100, n_points=600, power=2)