In [4]:
import tensorflow as tf
import numpy as np
from nbutil import imshow_multi
import os
from tensorflow.contrib.layers.python.layers import batch_norm
from tensorflow.contrib.layers import xavier_initializer
from PIL import Image
import urllib, cStringIO
# import cv2
from tensorflow.contrib import layers
import math

In [8]:
# Number of examples per training batch.
BATCH_SIZE = 16
# Number of stacked upscaler stages, and the scale factor each one applies.
PYRAMID_LAYERS, LAYER_UPSCALE = 2, 4
# Side length input images are resized to: LAYER_UPSCALE raised to PYRAMID_LAYERS.
MAX_SIZE = pow(LAYER_UPSCALE, PYRAMID_LAYERS)


In [6]:
# Parse the CelebA attribute table into:
#   keys                - attribute names, taken from the second line of the file
#   attributes_by_image - image path -> {attribute name: bool}
#   image_names         - list of image paths
#   attr_vector         - 0/1 matrix, one row per entry of image_names, one column per key
attr_file = '../data/celeba/list_attr_celeba.txt'
keys = None
attributes_by_image = {}
# `with` closes the file handle even if parsing fails part-way through.
with open(attr_file) as attr_lines:
    for i, line in enumerate(attr_lines):
        fields = line.split()
        if i == 1:
            keys = fields
        elif i > 1:
            if not fields:
                # Tolerate blank lines (e.g. a trailing newline) instead of raising IndexError.
                continue
            image = os.path.join('../data/celeba/img_align_celeba', fields[0])
            values = fields[1:]
            # A flag is treated as set only when it is exactly the string '1'.
            attributes_by_image[image] = {attr: val == '1' for attr, val in zip(keys, values)}
        # The first line (i == 0) is ignored.

# Materialise as a real list so this also works where dict.keys() returns a view (Python 3).
image_names = list(attributes_by_image.keys())
attr_vector = np.zeros((len(attributes_by_image), len(keys)))
for i, image_name in enumerate(image_names):
    attrs = attributes_by_image[image_name]
    attr_vector[i] = [(1 if attrs[key] else 0) for key in keys]

def create_qs(train):
    """Build the queue-based input pipeline and return one shuffled batch.

    `train` is accepted but is not consulted anywhere in the body.

    Returns:
        (images_batch, attrs_batch): the two tensors produced by
        `tf.train.shuffle_batch`, BATCH_SIZE examples each.
    """
    names = tf.convert_to_tensor(image_names, dtype=tf.string)
    labels = tf.convert_to_tensor(attr_vector, dtype=tf.float32)
    filename_q, attr_q = tf.train.slice_input_producer(
        [names, labels], num_epochs=None, shuffle=True)

    # Decode the JPEG and rescale pixel values by 1/255.
    decoded = tf.image.decode_jpeg(tf.read_file(filename_q))
    scaled = tf.cast(decoded, tf.float32) / 255.0
    face = tf.reshape(scaled, [218, 178, 3])  # images are 178x218

    # Crop/pad to 160x160, then bilinear-resize to MAX_SIZE x MAX_SIZE.
    # resize_bilinear is fed a batch of one, hence the reshape round-trip.
    square = tf.image.resize_image_with_crop_or_pad(face, 160, 160)
    as_batch = tf.reshape(square, [1, 160, 160, 3])
    resized = tf.image.resize_bilinear(as_batch, [MAX_SIZE, MAX_SIZE])
    face = tf.reshape(resized, [MAX_SIZE, MAX_SIZE, 3])

    # Augmentation: random horizontal flip, brightness shift and contrast scaling.
    augmented = tf.image.random_flip_left_right(face)
    augmented = tf.image.random_brightness(augmented, max_delta=0.3)
    augmented = tf.image.random_contrast(augmented, lower=0.6, upper=1.6)

    images_batch, attrs_batch = tf.train.shuffle_batch(
        [augmented, attr_q],
        batch_size=BATCH_SIZE,
        capacity=BATCH_SIZE*20,
        min_after_dequeue=BATCH_SIZE*10)
    return images_batch, attrs_batch

# Instantiate the input pipeline once; these tensors feed the graph built below.
images_batch, attrs_batch = create_qs(train=True)


In [7]:
# Display the attribute names read from the second line of the attribute file.
keys

['5_o_Clock_Shadow',
 'Arched_Eyebrows',
 'Attractive',
 'Bags_Under_Eyes',
 'Bald',
 'Bangs',
 'Big_Lips',
 'Big_Nose',
 'Black_Hair',
 'Blond_Hair',
 'Blurry',
 'Brown_Hair',
 'Bushy_Eyebrows',
 'Chubby',
 'Double_Chin',
 'Eyeglasses',
 'Goatee',
 'Gray_Hair',
 'Heavy_Makeup',
 'High_Cheekbones',
 'Male',
 'Mouth_Slightly_Open',
 'Mustache',
 'Narrow_Eyes',
 'No_Beard',
 'Oval_Face',
 'Pale_Skin',
 'Pointy_Nose',
 'Receding_Hairline',
 'Rosy_Cheeks',
 'Sideburns',
 'Smiling',
 'Straight_Hair',
 'Wavy_Hair',
 'Wearing_Earrings',
 'Wearing_Hat',
 'Wearing_Lipstick',
 'Wearing_Necklace',
 'Wearing_Necktie',
 'Young']

In [5]:
# Weight initializer shared by the conv layers below: truncated normal, mean 0, stddev 0.02.
initializer = tf.truncated_normal_initializer(mean=0.00, stddev=0.02)

def lrelu(x, alpha=0.1):
    """Leaky ReLU: element-wise maximum of `x` and `alpha * x`.

    Args:
        x: input tensor.
        alpha: factor applied to `x` before taking the maximum. Defaults to
            0.1, the value that was previously hard-coded, so existing
            single-argument uses (e.g. as `activation_fn`) are unaffected.

    Returns:
        The result of `tf.maximum(alpha * x, x)`.
    """
    return tf.maximum(alpha * x, x)

def identity(x):
    """Return the argument unchanged (passed as a no-op `activation_fn`)."""
    return x

def l2_loss(y, y_):
    """Element-wise squared difference between `y` and `y_`; no reduction is applied."""
    delta = y - y_
    return delta * delta

# Length of the latent vector z sampled per batch entry for the generator.
Z_SIZE = 32

def upscaler(img, z_vec, scope):
    """Upscale `img` by a factor of LAYER_UPSCALE, conditioned on `z_vec`.

    The block stacks log2(LAYER_UPSCALE) stages. Each stage:
      1. doubles the spatial resolution with a stride-2 transposed convolution
         (batch-normalised, leaky-ReLU),
      2. concatenates `z_vec`, tiled over every spatial position, onto the
         channel axis,
      3. mixes the result with a 1x1 convolution. The last stage emits
         3 channels through tanh; earlier stages use leaky-ReLU.

    Args:
        img: 4-D input tensor whose height, width and channel count are
            statically known (assumes NHWC layout - TODO confirm).
        z_vec: tensor whose last dimension is statically known; it is reshaped
            to [-1, 1, 1, that dimension] before tiling.
        scope: variable scope name under which the stage variables are created.

    Returns:
        (upscaled, sampling_loss). `sampling_loss` is the mean squared
        difference between the input and the output resized back to the
        input's height/width. When the input's channel count differs from the
        output's (e.g. a latent vector is passed as the first "image"), the
        two cannot be compared and the loss is a constant 0.
    """
    with tf.variable_scope(scope):
        orig = img
        n_layers = int(round(math.log(LAYER_UPSCALE) / math.log(2)))
        z_reshaped = tf.reshape(tf.cast(z_vec, tf.float32), [-1, 1, 1, z_vec.get_shape()[-1].value])
        for i in range(n_layers):
            # Channel width halves per stage (32 * 2**k, the same schedule shape the
            # critic uses). The previous `32 ** k` collapsed to a single channel in
            # the last stage and would explode for larger LAYER_UPSCALE.
            n_channels = 32 * 2 ** (n_layers - 1 - i)
            # Transposed convolution: a plain stride-2 conv2d would halve the
            # resolution rather than double it.
            img = layers.conv2d_transpose(img,
                      n_channels * 2,
                      scope='up_'+str(i),
                      kernel_size=4,
                      activation_fn=lrelu,
                      stride=2,
                      normalizer_fn=layers.batch_norm,
                      weights_initializer=initializer)

            _, h, w, _ = [x.value for x in img.get_shape()]
            z_tiled = tf.tile(z_reshaped, [1, h, w, 1])
            img = tf.concat([img, z_tiled], axis=3)

            if i + 1 < n_layers:
                act = lrelu
                out_channels = n_channels
            else:
                act = tf.tanh
                out_channels = 3
            # Use out_channels (it was computed but ignored before), so the final
            # stage really produces a 3-channel image.
            img = layers.conv2d(img,
                                out_channels,
                                scope='up_1x1_'+str(i),
                                kernel_size=1,
                                stride=1,
                                activation_fn=act,
                                weights_initializer=initializer)

        _, h, w, in_channels = [x.value for x in orig.get_shape()]
        if in_channels == img.get_shape()[-1].value:
            # Compare the *input* with the downsampled output; comparing the
            # output with its own downsample mixes two different resolutions.
            small_again = tf.image.resize_images(img, [h, w])
            sampling_loss = tf.reduce_mean(l2_loss(orig, small_again))
        else:
            sampling_loss = tf.constant(0.0)
        return img, sampling_loss
    
def upscaler_critic(small, big, scope):
    """Critic for one pyramid level.

    Args:
        small: lower-resolution 4-D tensor, or None for the first level. When
            given, it is resized to `big`'s height/width and concatenated onto
            `big` along the channel axis.
        big: 4-D tensor with statically known height and width.
        scope: variable scope name under which the critic variables are created.

    Returns:
        The final 1-channel convolution output averaged over axes 1 and 2
        (the spatial axes), i.e. one score per batch entry.
    """
    with tf.variable_scope(scope):
        _, h, w, _ = [x.value for x in big.get_shape()]
        # Test against None explicitly: truth-testing a tf.Tensor raises, and
        # without the else branch `img` was unbound whenever `small` was None.
        if small is not None:
            small_upscaled = tf.image.resize_images(small, [h, w])
            img = tf.concat([big, small_upscaled], axis=3)
        else:
            img = big
        n_layers = 2
        for i in range(n_layers):
            n_channels = 32 * 2**i
            img = layers.conv2d(img,
                                n_channels,
                                scope='critic_'+str(i),
                                kernel_size=5,
                                stride=2,
                                activation_fn=lrelu,
                                normalizer_fn=layers.batch_norm,
                                weights_initializer=initializer)
        # once we've downsampled using convolution, a final layer computes critic scores across the image,
        # which are averaged:
        img = layers.conv2d(img, 1,
                            # Same name as before ('final_1'), but no longer reliant on
                            # the loop variable leaking out of the for-loop.
                            scope='final_'+str(n_layers - 1),
                            kernel_size=5,
                            stride=1,
                            activation_fn=identity,
                            weights_initializer=initializer)
        return tf.reduce_mean(img, axis=[1,2])
    
# Build the training graph under the 'pg' variable scope: conditioning vector,
# generator pyramid, critic losses, and the two train ops.
# NOTE(review): `discriminator`, `reconstruction`, `kl_divergence`, `z_mean`,
# `z_log_stddev` and `USE_GAN` are not defined in any cell visible here; this cell
# presumably only runs when they are left over in the kernel - confirm and restore
# their definitions so a fresh-kernel run works.
scopename = 'pg'
with tf.variable_scope(scopename):
    attributes = ['Male', 'Smiling']
    attr_indices = [keys.index(attr) for attr in attributes]
    # Pick the columns of attrs_batch for the chosen attributes:
    # transpose -> gather rows -> transpose back.
    cond_vec = tf.transpose(tf.gather(tf.transpose(attrs_batch, [1, 0]), attr_indices), [1, 0])
    
    # generator:
    # z is uniform in [-1, 1); each pyramid stage consumes the previous stage's output.
    z = tf.random_uniform([BATCH_SIZE, Z_SIZE], minval=-1, maxval=1)
    gen_image_layers = [tf.reshape(z, [-1, 1, 1, Z_SIZE])]
    # Per-stage sampling losses returned by upscaler.
    # NOTE(review): this list is collected but not used in gen_loss below.
    gen_losses = []
    for i in xrange(PYRAMID_LAYERS):
        gen_img, scale_loss = upscaler(gen_image_layers[-1], z, 'upscaler_' + str(i))
        gen_image_layers.append(gen_img)
        gen_losses.append(scale_loss)
    
    # critic:
    # NOTE(review): placeholder loop - upscaler_critic is never called here.
    for i in xrange(PYRAMID_LAYERS):
        pass
    
    real_disc_output = discriminator(images_batch, cond_vec)
    fake_disc_output = discriminator(reconstruction, cond_vec)
#     reals_are_real = tf.reduce_mean(l2_loss(real_disc_output, tf.ones_like(real_disc_output)))
#     fakes_are_real = tf.reduce_mean(l2_loss(fake_disc_output, tf.ones_like(fake_disc_output)))
#     fakes_are_fake = tf.reduce_mean(l2_loss(fake_disc_output, -tf.ones_like(fake_disc_output)))
    # the critic outputs large positive numbers for real images, and large negative numbers for fake ones:
    reals_are_real = tf.reduce_mean(-real_disc_output)
    fakes_are_real = tf.reduce_mean(-fake_disc_output)
    fakes_are_fake = tf.reduce_mean(fake_disc_output)
    
    # Generator objective: squared reconstruction error + KL term, plus the
    # adversarial term only when USE_GAN is truthy.
    diff_loss = tf.reduce_mean(l2_loss(images_batch, reconstruction))
    kl = kl_divergence(z_mean, z_log_stddev)
    gen_loss = diff_loss + kl + (fakes_are_real if USE_GAN else 0)
    
    disc_loss = reals_are_real + fakes_are_fake
    
    # Critic variables are those under '<scopename>/discriminator'; every other
    # trainable variable is optimised as part of the generator.
    disc_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=scopename+'/discriminator')
    gen_vars = [v for v in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) if v not in disc_vars]
    
    # Only the generator step advances global_step.
    global_step = tf.contrib.framework.get_or_create_global_step()
    train_gen = tf.train.AdamOptimizer(0.001).minimize(gen_loss, var_list=gen_vars, global_step=global_step)
    
    # Critic step plus clipping of every critic variable to [-0.01, 0.01].
    # NOTE(review): tf.group does not order the clip relative to the optimizer
    # update - confirm whether clipping is meant to run after the step.
    train_disc_op = tf.train.RMSPropOptimizer(0.0001).minimize(disc_loss, var_list=disc_vars)
    weight_clip_op = tf.group(*[w.assign(tf.clip_by_value(w, -0.01, 0.01)) for w in disc_vars])
    train_disc = tf.group(weight_clip_op, train_disc_op)

print 'ok'

ok


In [6]:
# Module-level handles that create_session() rebinds via `global`.
session = None
saver = None
save_path = None

def create_session():
    """(Re)create the global TF session, initialise variables and start queue runners.

    Any previously opened session is closed first. `save_path` is reset inside
    this function (currently to None). When it is set, the directory is created
    if missing, a `tf.train.Saver` is built and the latest checkpoint found
    under that directory, if any, is restored; otherwise `saver` stays None.

    Side effects: rebinds the module-level `session`, `saver` and `save_path`,
    and prints a one-line status message.
    """
    global session
    global saver
    global save_path

    if session is not None:
        session.close()

    session = tf.InteractiveSession()

    save_path = None # 'models/lsgan-face-cond-1'

    init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
    session.run(init_op)
    tf.train.start_queue_runners(sess=session)

    saver = None
    if save_path:
        if not os.path.exists(save_path):
            # makedirs rather than mkdir: a nested path such as 'models/<run>'
            # must not fail just because the parent directory is missing.
            os.makedirs(save_path)
        saver = tf.train.Saver()
        ckpt = tf.train.get_checkpoint_state(save_path)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(session, ckpt.model_checkpoint_path)
            # Single-argument print() emits the same text on Python 2 and 3.
            print('Restored from checkpoint ' + str(ckpt.model_checkpoint_path))
        else:
            print('Did not restore from checkpoint')
    else:
        print('Will not save progress')

# Open the initial session (initialises variables and starts the input queue runners).
create_session()


Will not save progress


In [10]:
def avg(x):
    """Arithmetic mean of the sequence `x`, computed with float division."""
    total = sum(x)
    count = len(x)
    return total / float(count)
def flatten(l):
    """Concatenate the items of every sub-iterable of `l` into one flat list (one level deep)."""
    flat = []
    for group in l:
        for element in group:
            flat.append(element)
    return flat

# Training loop: runs until interrupted, logging averaged losses every 200 steps.
# NOTE(review): `USE_GAN` and `reconstruction` are not defined in any cell visible
# here - confirm where they come from before relying on a fresh-kernel run.
savecount = 0

# Loss values accumulated since the last log line.
# NOTE(review): `gen_losses` is also the name of the list of per-stage scale losses
# built in the graph cell; this assignment replaces that list.
gen_losses = []
disc_losses = []

while True:
    examples_ = None
    # Read the step counter before running the train ops.
    step_ = global_step.eval()
    
    feed = {}
    
    if USE_GAN:
        # One generator step and one critic step in the same session.run call.
        _, _, gen_loss_, disc_loss_ = session.run([train_gen, train_disc, gen_loss, disc_loss], feed_dict=feed)
        gen_losses.append(gen_loss_)
        disc_losses.append(disc_loss_)
    else:
        _, gen_loss_ = session.run([train_gen, gen_loss], feed_dict=feed)
        gen_losses.append(gen_loss_)
    
    if step_ % 200 == 0:
        if USE_GAN:
            print "Step: {}, disc loss: {}, gen loss: {}".format(step_, avg(disc_losses), avg(gen_losses))
        else:
            print "Step: {}, gen loss: {}".format(step_, avg(gen_losses))
        # Reset the accumulators so each log line averages only the latest window.
        disc_losses = []
        gen_losses = []
        
        if step_ % 800 == 0:
            # Show three originals interleaved with their reconstructions.
            examples_, originals_ = session.run([reconstruction[:3], images_batch[:3]])
            imshow_multi(flatten(zip(originals_, examples_)))
        
        # `saver` is None when create_session() found no save_path, so nothing is saved then.
        if step_ % 2000 == 0 and saver:
            should_save = True
            if should_save:
                saver.save(session, save_path + '/model.ckpt', global_step=step_)
                print 'Saved'
                savecount += 1
                # After more than four saves, recreate the session and reset the counter.
                if savecount > 4:
                    create_session()
                    savecount = 0


KeyboardInterrupt: 

In [8]:
# Evaluate the `kl` and `diff_loss` tensors once on a fresh batch and print both values.
kl_loss_, diff_loss_ = session.run([kl, diff_loss])
print kl_loss_, diff_loss_

0.00730133 0.095987
