In [1]:
import sys
import numpy as np
import tensorflow as tf
import tensorlayer as tl

import os, time
from glob import glob
from random import shuffle

from utils import *
import vgg

from models import dcgan_decoder, vanilla_encoder

  return f(*args, **kwds)


In [2]:
## params
# --- dataset / input pipeline ---
dataset = 'CelebA'      # sub-directory of ./data holding the *.jpg images
train_size = np.inf     # presumably a cap on training images; inf = no cap (TODO confirm in training loop)
original_size = 108     # presumably the crop size applied to raw images when is_crop is True
is_crop = True
input_size = 64         # height/width of the input_img placeholder
c_dim = 3               # channel count of the input_img placeholder

# --- optimisation ---
num_epochs = 25
batch_size = 64
lr = 0.0005             # Adam learning rate
l = 0.5                 # weight between pixel and perceptual loss: l * pixel + (1 - l) * perceptual

# --- checkpointing / sampling (presumably iteration intervals and sample count) ---
save_step = 500
sample_step = 500
sample_size = 64

In [3]:
## paths
checkpoint_dir = './checkpoints/'
# VGG-19 weights (.mat) consumed by vgg.VGG for the perceptual loss.
vgg_path = checkpoint_dir + 'imagenet-vgg-verydeep-19.mat'
# Directory of the pretrained DCGAN generator parameters (net_g.npz).
# Keep the trailing slash: the loading cell appends the file name to it directly.
dcgan_path = checkpoint_dir + 'dcgan/'

In [4]:
## Define model
print('Building Model')

# Fixed-size batch placeholder of shape (batch_size, input_size, input_size, c_dim).
input_img = tf.placeholder(tf.float32,
                           name='input_img',
                           shape=[batch_size, input_size, input_size, c_dim])
print('Input Shape: ', input_img.shape)

# Encode each image into a 100-dimensional latent code; the second return value is unused.
encoder_net, _ = vanilla_encoder.model(input_img, z_dim=100)
print('Latent Shape: ', encoder_net.outputs.shape)

[TL] InputLayer  encoder/d/in: (64, 64, 64, 3)
[TL] Conv2dLayer encoder/enc/h0/conv2d: shape:[5, 5, 3, 64] strides:[1, 2, 2, 1] pad:SAME act:<lambda>
[TL] Conv2dLayer encoder/enc/h1/conv2d: shape:[5, 5, 64, 128] strides:[1, 2, 2, 1] pad:SAME act:identity
[TL] BatchNormLayer encoder/enc/h1/bn: decay:0.900000 epsilon:0.000010 act:<lambda> is_train:True
[TL] Conv2dLayer encoder/enc/h2/conv2d: shape:[5, 5, 128, 256] strides:[1, 2, 2, 1] pad:SAME act:identity


Building Model
Input Shape:  (64, 64, 64, 3)


[TL] BatchNormLayer encoder/enc/h2/bn: decay:0.900000 epsilon:0.000010 act:<lambda> is_train:True
[TL] Conv2dLayer encoder/enc/h3/conv2d: shape:[5, 5, 256, 512] strides:[1, 2, 2, 1] pad:SAME act:identity
[TL] BatchNormLayer encoder/enc/h3/bn: decay:0.900000 epsilon:0.000010 act:<lambda> is_train:True
[TL] FlattenLayer encoder/enc/h4/flatten: 8192
[TL] DenseLayer  encoder/enc/h4/lin: 100 identity


Latent Shape:  (64, 100)


In [5]:
# Decode the latent codes back to images; the printed shape should equal input_img's.
decoder_net, _ = dcgan_decoder.model(
    encoder_net.outputs,
    image_size=input_size,
    c_dim=c_dim,
    batch_size=batch_size,
)
print('Output Shape: ', decoder_net.outputs.shape)
print('Successfully built!')

[TL] InputLayer  generator/g/in: (64, 100)
[TL] DenseLayer  generator/g/h0/lin: 8192 identity
[TL] ReshapeLayer generator/g/h0/reshape: (64, 4, 4, 512)
[TL] BatchNormLayer generator/g/h0/batch_norm: decay:0.900000 epsilon:0.000010 act:relu is_train:False
[TL] DeConv2dLayer generator/g/h1/decon2d: shape:[5, 5, 256, 512] out_shape:[64, 8, 8, 256] strides:[1, 2, 2, 1] pad:SAME act:identity
[TL] BatchNormLayer generator/g/h1/batch_norm: decay:0.900000 epsilon:0.000010 act:relu is_train:False
[TL] DeConv2dLayer generator/g/h2/decon2d: shape:[5, 5, 128, 256] out_shape:[64, 16, 16, 128] strides:[1, 2, 2, 1] pad:SAME act:identity
[TL] BatchNormLayer generator/g/h2/batch_norm: decay:0.900000 epsilon:0.000010 act:relu is_train:False
[TL] DeConv2dLayer generator/g/h3/decon2d: shape:[5, 5, 64, 128] out_shape:[64, 32, 32, 64] strides:[1, 2, 2, 1] pad:SAME act:identity
[TL] BatchNormLayer generator/g/h3/batch_norm: decay:0.900000 epsilon:0.000010 act:relu is_train:False
[TL] DeConv2dLayer generator/

Output Shape:  (64, 64, 64, 3)
Successfully built!


In [6]:
## Define loss and training ops
# Pixel reconstruction loss. tf.nn.l2_loss returns sum(x ** 2) / 2, so dividing
# by the element count yields half of the per-element mean squared error.
num_pixel_elems = batch_size * input_size * input_size * c_dim
loss_pixel = tf.nn.l2_loss(input_img - decoder_net.outputs) / num_pixel_elems

# Perceptual loss at VGG layer relu4_3. Presumably LossCalculator compares VGG
# features of decoder_net.outputs against input_img (see utils); the division by
# batch_size assumes content_loss sums over the batch — TODO confirm.
# NOTE(review): the VGG weights may expect a specific pixel range/mean
# subtraction; verify it matches the decoder's output range.
vgg_net = vgg.VGG(vgg_path)
loss_calc = LossCalculator(vgg_net, decoder_net.outputs)
loss_perc = loss_calc.content_loss(input_img, content_layer='relu4_3', content_weight=1) / batch_size

# Convex blend: l weights the pixel term, (1 - l) the perceptual term.
loss = l * loss_pixel + (1 - l) * loss_perc

# NOTE(review): decoder variables are optimised too, although the decoder is
# initialised from a pretrained DCGAN generator and its BatchNorm layers are built
# with is_train=False (see build log). If only the encoder should learn, pass
# var_list=encoder_net.all_params instead.
train_param = encoder_net.all_params + decoder_net.all_params
optimizer = tf.train.AdamOptimizer(learning_rate=lr)
train_op = optimizer.minimize(loss, var_list=train_param)

In [7]:
# Open an InteractiveSession (installed as the default session) and initialise
# every global variable, including the optimizer variables created by minimize().
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())

In [None]:
# Load the pretrained DCGAN generator weights into the decoder. This must run
# after variable initialisation, otherwise the initializer would overwrite them.
# NOTE(review): assign_params presumably assigns the arrays in order to
# decoder_net.all_params, so net_g.npz must come from the same generator
# architecture — confirm.
print('Loading trained parameters of decoder network...')
decoder_params = tl.files.load_npz(path=dcgan_path, name='net_g.npz')
tl.files.assign_params(sess, decoder_params, decoder_net)
print('Successfully loaded trained parameters of the decoder network')
decoder_net.print_params()

In [9]:
# TODO: define the directory where trained encoder/decoder parameters are saved.

# Collect the training images from ./data/<dataset>/*.jpg. glob returns files in
# arbitrary filesystem order, so sort for a deterministic starting order; the
# training loop below reshuffles this list every epoch.
image_dir = os.path.join("./data", dataset)
data_files = sorted(glob(os.path.join(image_dir, "*.jpg")))
if not data_files:
    # Fail fast instead of letting the training loop silently iterate over nothing.
    raise FileNotFoundError('No .jpg images found in {}'.format(image_dir))
print('Found {} training images in {}'.format(len(data_files), image_dir))

In [None]:
# Get the model up and training
iter_counter = 0
for epoch in range(num_epochs):
    shuffle(data_files)
    