In [None]:
import os
import skimage
import skimage.io
import skimage.transform
import numpy as np
import random
import scipy.misc
import torch.optim as optim
from models import *

In [None]:
SEED = 123
# Seed every RNG the notebook draws from. Python's `random` chooses the
# training mini-batches (random.sample in the training loop); numpy and torch
# (plus its CUDA generator) cover anything else that samples.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
use_cuda = torch.cuda.is_available()
if use_cuda:
    torch.cuda.manual_seed(SEED)

In [None]:
# ---- model / optimisation hyper-parameters ----
IMAGE_SIZE = 64            # faces are resized to IMAGE_SIZE x IMAGE_SIZE
BATCH_SIZE = 32
N_EPOCHS = 10 ** 6         # NOTE: the training loop draws one mini-batch per "epoch"
LEARNING_RATE = 2e-3

# ---- reporting cadence, in training iterations ----
PRINT_EVERY = 100
SAVE_IMAGE_EVERY = 100
SAVE_MODEL_EVERY = 1000

# ---- filesystem locations ----
# NOTE(review): machine-specific absolute output path; adjust per environment.
_OUTPUT_ROOT = '/mnt/disk0/kevin1kevin1k/final/'
_DATA_ROOT = '../hw4/data/'
IMAGE_DIR = _OUTPUT_ROOT + 'images/'
MODEL_DIR = _OUTPUT_ROOT + 'models/'
TAGS_PATH = _DATA_ROOT + 'tags_clean.csv'
FACES_DIR = _DATA_ROOT + 'faces/'

In [None]:
# Tags that end in " hair" / " eyes" but are not usable colour attributes;
# they are excluded from the conditioning vocabulary.
HAIR_REMOVE = (
    'pubic hair',
    'damage hair',
    'short hair',
    'long hair',
)

EYES_REMOVE = (
    '11 eyes',
    'bicolored eyes',
)

index2tags = dict()   # face id (str, first CSV field) -> list of kept tags
tag2indices = dict()  # kept tag -> list of face ids carrying it


def is_hair(tag):
    """True for a '<something> hair' tag that is not listed in HAIR_REMOVE."""
    return tag.endswith(' hair') and tag not in HAIR_REMOVE


def is_eyes(tag):
    """True for a '<something> eyes' tag that is not listed in EYES_REMOVE."""
    return tag.endswith(' eyes') and tag not in EYES_REMOVE


# Each line looks like "<face id>,<tag>:<x>\t<tag>:<x>\t..."; the part after the
# last ':' of each entry is ignored (presumably a count — confirm against the data).
with open(TAGS_PATH) as f:
    for line in f:
        line = line.strip()
        if not line:
            continue  # tolerate blank / trailing lines
        # Split on the first comma only, so extra commas cannot break unpacking.
        index, tag_field = line.split(',', 1)
        tags = []
        for entry in tag_field.split('\t'):
            # rsplit drops only the trailing ":<x>", keeping any ':' inside the tag.
            tag = entry.rsplit(':', 1)[0].strip()
            if tag in tags or not (is_hair(tag) or is_eyes(tag)):
                continue  # repeated on this line, or not a hair/eye colour tag
            tags.append(tag)
            tag2indices.setdefault(tag, []).append(index)
        if tags:
            index2tags[index] = tags

num_indices = len(index2tags)
# print(num_indices)

# has_hair = set()
# has_eyes = set()
# for tag in tag2indices:
#     if is_hair(tag):
#         print(tag, len(tag2indices[tag]))
#         has_hair.update(tag2indices[tag])
# for tag in tag2indices:
#     if is_eyes(tag):
#         print(tag, len(tag2indices[tag]))
#         has_eyes.update(tag2indices[tag])

# has_both = sorted(list(has_hair & has_eyes))
# print(len(has_both))

In [None]:
# Fixed tag vocabulary: column j of every tag vector stands for tags_list[j].
tags_list = sorted(tag2indices)
num_tags = len(tags_list)

def tags_to_vector(tags):
    """Encode an iterable of tag names as a multi-hot float vector.

    Returns a 1-D numpy array of length ``num_tags`` with 1.0 in the column of
    every tag in ``tags`` and 0.0 elsewhere. A tag that is not in ``tags_list``
    raises ValueError (from ``list.index``).
    """
    hot_columns = [tags_list.index(name) for name in tags]
    vec = np.zeros(num_tags)
    vec[np.array(hot_columns, dtype=int)] = 1.0
    return vec

In [None]:
# Module-level placeholders; fill_images_and_tag_matrix() rebinds both names.
images = np.zeros((num_indices, IMAGE_SIZE, IMAGE_SIZE, 3))
tag_matrix = np.zeros((num_indices, num_tags))

def fill_images_and_tag_matrix():
    """Load every tagged face image and its multi-hot tag vector into globals.

    Rebinds:
      images     -- float array (num_indices, 3, IMAGE_SIZE, IMAGE_SIZE), channels-first
      tag_matrix -- float array (num_indices, num_tags)
    Row i of both arrays belongs to the i-th face id of sorted(index2tags).
    Fresh buffers are allocated on every call, so re-running this cell does not
    transpose `images` a second time.
    """
    global images
    global tag_matrix

    face_ids = sorted(index2tags.keys())
    images_hwc = np.zeros((len(face_ids), IMAGE_SIZE, IMAGE_SIZE, 3))
    tag_rows = np.zeros((len(face_ids), num_tags))

    for i, face_index in enumerate(face_ids):
        image = skimage.io.imread(os.path.join(FACES_DIR, face_index + '.jpg'))
        # Normalise channel layout so every image is (H, W, 3): replicate a
        # grayscale plane, drop an alpha channel if present.
        if image.ndim == 2:
            image = np.stack([image] * 3, axis=-1)
        image = image[..., :3]

        # resize() returns floats; with the default preserve_range=False,
        # integer input is rescaled to [0, 1].
        # NOTE(review): the sample-saving code maps decoder output from [-1, 1];
        # confirm the reconstruction target range matches the Decoder output.
        images_hwc[i] = skimage.transform.resize(
            image, (IMAGE_SIZE, IMAGE_SIZE), mode='constant')
        tag_rows[i] = tags_to_vector(index2tags[face_index])

    # (N, H, W, C) -> (N, C, H, W); presumably the layout the models expect.
    images = images_hwc.transpose(0, 3, 1, 2)
    tag_matrix = tag_rows

In [None]:
# Populate the global `images` and `tag_matrix` arrays (rows in sorted face-id order).
fill_images_and_tag_matrix()

In [None]:
def _set_requires_grad(net, flag):
    """Assign ``flag`` to ``requires_grad`` on every parameter of ``net``."""
    for param in net.parameters():
        param.requires_grad = flag


def enable_gradients(net):
    """Make autograd track every parameter of ``net`` in subsequently built graphs."""
    _set_requires_grad(net, True)


def disable_gradients(net):
    """Stop autograd from tracking ``net``'s parameters in subsequently built graphs."""
    _set_requires_grad(net, False)

In [None]:
# Networks (classes come from `models`): an encoder, a decoder that is called
# with (latent, tags), and a discriminator sized by num_tags.
enc = Encoder()
dec = Decoder(num_tags, image_size=IMAGE_SIZE)
dis = Discriminator(num_tags)
if use_cuda:
    enc, dec, dis = enc.cuda(), dec.cuda(), dis.cuda()

# One Adam optimiser per network, all sharing LEARNING_RATE and beta1 = 0.5.
betas = (0.5, 0.999)
enc_opt, dec_opt, dis_opt = (
    optim.Adam(net.parameters(), lr=LEARNING_RATE, betas=betas)
    for net in (enc, dec, dis)
)

In [None]:
# Weight of the adversarial term ramps linearly from 0 to LAMBDA_E_MAX over the
# first LAMBDA_E_RAMP_STEPS iterations.
LAMBDA_E_MAX = 0.0001
LAMBDA_E_RAMP_STEPS = 500000
# Row of `images` whose reconstructions are written every SAVE_IMAGE_EVERY steps.
SAMPLE_INDEX = 3


def _to_variable(array):
    """Wrap a numpy array as a FloatTensor Variable (on GPU when use_cuda), no grad."""
    tensor = torch.FloatTensor(array)
    if use_cuda:
        tensor = tensor.cuda()
    return Variable(tensor, requires_grad=False)


def _scalar(loss_value):
    """Python number held by a one-element loss, across PyTorch versions."""
    # .item() exists from PyTorch 0.4 on; older Variables only support .data[0],
    # which fails on 0-dim tensors in newer releases.
    if hasattr(loss_value, 'item'):
        return loss_value.item()
    return loss_value.data[0]


def _save_samples(step, probe_tags):
    """Write reconstructions of face SAMPLE_INDEX under four tag vectors.

    Files are IMAGE_DIR/test_<step>_<k>.jpg for k = 1..4: the face's own tags
    scaled by 0.1, 0.5 and 1.0, then ``probe_tags`` (shape (1, num_tags)).
    """
    x_sample = _to_variable(images[SAMPLE_INDEX:SAMPLE_INDEX + 1])
    own_tags = tag_matrix[SAMPLE_INDEX:SAMPLE_INDEX + 1]
    # The encoding does not depend on the tags, so compute it once.
    ex_sample = enc(x_sample)
    variants = (own_tags * 0.1, own_tags * 0.5, own_tags, probe_tags)
    for k, tag_vec in enumerate(variants, start=1):
        x_rec = dec(ex_sample, _to_variable(tag_vec))
        # (1, C, H, W) -> (H, W, C), mapped from [-1, 1] to [0, 1] as before
        # (assumes a tanh-style decoder output).
        img = (x_rec.data.cpu().numpy()[0].transpose(1, 2, 0) + 1.0) * 0.5
        # NOTE(review): scipy.misc.imsave was removed in SciPy 1.2 and rescales
        # float input to the full 0-255 range; revisit if upgrading SciPy.
        scipy.misc.imsave(os.path.join(IMAGE_DIR, 'test_{}_{}.jpg'.format(step, k)), img)


def _save_models(step):
    """torch.save the whole enc/dec/dis modules to MODEL_DIR/<name>_<step>.pt."""
    for name, net in (('enc', enc), ('dec', dec), ('dis', dis)):
        torch.save(net, os.path.join(MODEL_DIR, '{}_{}.pt'.format(name, step)))


bce_logits = nn.BCEWithLogitsLoss()
# Built once, up front, so a tag missing from the vocabulary fails immediately
# (ValueError) rather than after SAVE_IMAGE_EVERY iterations. reshape adds the
# batch axis so it matches the (1, num_tags) rows used for the other samples.
probe_tags = tags_to_vector(('black hair', 'green eyes')).reshape(1, -1)

for ep in range(N_EPOCHS):
    indices = random.sample(range(num_indices), BATCH_SIZE)
    X = _to_variable(images[indices])
    tags = _to_variable(tag_matrix[indices])

    # --- discriminator step: learn to predict the tags from the latent code ---
    enable_gradients(dis)
    disable_gradients(enc)
    dis.zero_grad()

    EX = enc(X).detach()  # encoder is held fixed for this step
    P = dis(EX)
    L_dis = bce_logits(P, tags)
    L_dis.backward()
    dis_opt.step()

    # --- encoder / decoder step ---
    enable_gradients(enc)
    enable_gradients(dec)
    disable_gradients(dis)
    enc.zero_grad()
    dec.zero_grad()

    # Re-encode now that enc tracks gradients: the EX above was built while enc
    # was frozen, so no gradient could ever reach the encoder through it. Also
    # query the discriminator just updated instead of reusing its stale output.
    EX = enc(X)
    X_ = dec(EX, tags)
    P = dis(EX)

    lambda_E = LAMBDA_E_MAX * min(ep / float(LAMBDA_E_RAMP_STEPS), 1.0)
    # Fader Networks objective: reconstruction - lambda_E * log p_dis(1 - y | E(x))
    # = reconstruction + lambda_E * BCE(P, 1 - y). Minimising it pushes the
    # encoder to make the discriminator predict the flipped tags.
    L_enc_dec = (X - X_).norm()**2 / BATCH_SIZE + lambda_E * bce_logits(P, 1 - tags)
    L_enc_dec.backward()
    enc_opt.step()
    dec_opt.step()

    step = ep + 1
    if step % SAVE_IMAGE_EVERY == 0:
        _save_samples(step, probe_tags)

    if step % SAVE_MODEL_EVERY == 0:
        _save_models(step)

    if step % PRINT_EVERY == 0:
        print('ep {}, L_dis {:.4f}, L_enc_dec {:.4f}'.format(step, _scalar(L_dis), _scalar(L_enc_dec)))

In [None]:
# import matplotlib.pyplot as plt
# %matplotlib inline

# img = images[3].transpose(1, 2, 0)
# plt.figure()
# plt.imshow(img)