Skip to content

Commit

Permalink
add author's model
Browse files Browse the repository at this point in the history
  • Loading branch information
carpedm20 committed Apr 7, 2017
1 parent 0f89f95 commit a15037a
Show file tree
Hide file tree
Showing 4 changed files with 142 additions and 7 deletions.
2 changes: 1 addition & 1 deletion README.md
Expand Up @@ -10,8 +10,8 @@ Tensorflow implementation of [BEGAN: Boundary Equilibrium Generative Adversarial
- Python 2.7
- [Pillow](https://pillow.readthedocs.io/en/4.0.x/)
- [tqdm](https://github.com/tqdm/tqdm)
- [TensorFlow 1.1.0](https://github.com/tensorflow/tensorflow) (**Need nightly build** which can be found in [here](https://github.com/tensorflow/tensorflow#installation))
- [requests](https://github.com/kennethreitz/requests) (Only used for downloading CelebA dataset)
- [TensorFlow 1.1.0](https://github.com/tensorflow/tensorflow) (**Requires the nightly build**, which can be found [here](https://github.com/tensorflow/tensorflow#installation); otherwise you'll see `ValueError: 'image' must be three-dimensional.`)


## Usage
Expand Down
1 change: 1 addition & 0 deletions config.py
Expand Up @@ -19,6 +19,7 @@ def add_argument_group(name):
# Channel width of the conv layers ('n' in the paper); only 64 or 128 are accepted.
net_arg.add_argument('--conv_hidden_num', type=int, default=128,
choices=[64, 128],help='n in the paper')
# presumably the length of the latent vector z — confirm against the trainer; only 64 or 128 are accepted.
net_arg.add_argument('--z_num', type=int, default=64, choices=[64, 128])
# Boolean switch (parsed by str2bool, on by default) read by the trainer as config.use_authors_model.
net_arg.add_argument('--use_authors_model', type=str2bool, default=True)

# Data
data_arg = add_argument_group('Data')
Expand Down
103 changes: 103 additions & 0 deletions layers.py
@@ -0,0 +1,103 @@
# Code from https://github.com/david-berthelot/tf_img_tech/blob/master/tfswag/layers.py
import numpy as N
import numpy.linalg as LA
import tensorflow as tf

__author__ = 'David Berthelot'


def unboxn(vin, n):
    """Upsample `vin` by an integer factor, duplicating each pixel.

    vin = (batch, h, w, depth), returns vout = (batch, n*h, n*w, depth).

    Args:
        vin: 4-D tensor laid out as (batch, h, w, depth).
        n: integer upsampling factor. With n == 1 the input is returned as is.

    Returns:
        A tensor in which every input pixel fills an n-by-n block.
    """
    if n == 1:
        # Nothing to duplicate; same shortcut as boxn. It also avoids calling
        # tf.batch_to_space with a block size below its documented minimum of 2.
        return vin
    s = tf.shape(vin)
    # Poor man's replacement for tf.tile (required for Adversarial Training support).
    vout = tf.concat([vin] * (n ** 2), 0)
    vout = tf.reshape(vout, [s[0] * (n ** 2), s[1], s[2], s[3]])
    # Interleave the n*n batch copies into the spatial axes, with no cropping.
    vout = tf.batch_to_space(vout, [[0, 0], [0, 0]], n)
    return vout


def boxn(vin, n):
    """Downsample `vin` by averaging each n-by-n block of pixels.

    vin = (batch, h, w, depth), returns vout = (batch, h//n, w//n, depth).
    With n == 1 the input is returned unchanged.
    """
    if n == 1:
        return vin
    shape = tf.shape(vin)
    batch, height, width, depth = shape[0], shape[1], shape[2], shape[3]
    # Split each spatial axis into (blocks, n) so the two n-sized axes can be averaged away.
    blocks = tf.reshape(vin, [batch, height // n, n, width // n, n, depth])
    return tf.reduce_mean(blocks, [2, 4])


class LayerBase:
    """Empty base class shared by the layer classes in this module."""


class LayerConv(LayerBase):
    """2-D convolution whose output is combined with a bias through `nl`.

    Calling the instance runs `tf.nn.conv2d` with the stored kernel and then
    returns `nl(conv_output, bias)`.
    """

    def __init__(self, name, w, n, nl=lambda x, y: x + y, strides=(1, 1, 1, 1),
                 padding='SAME', conv=None, use_bias=True, data_format="NCHW"):
        """Create (or adopt) the kernel and the optional bias.

        Args:
            name: name scope under which the variables are created.
            w: kernel size as (wy, wx).
            n: channel counts as (n_in, n_out).
            nl: callable taking (conv_output, bias); defaults to addition.
            strides: strides forwarded to `tf.nn.conv2d`.
            padding: padding mode forwarded to `tf.nn.conv2d`.
            conv: existing kernel to use; a new variable is created when None.
            use_bias: when False the bias is the constant 0 instead of a variable.
            data_format: layout string forwarded to `tf.nn.conv2d`.
        """
        self.nl = nl
        self.strides = list(strides)
        self.padding = padding
        self.data_format = data_format
        with tf.name_scope(name):
            if conv is None:
                kernel_shape = [w[0], w[1], n[0], n[1]]
                kernel_init = tf.truncated_normal(kernel_shape, stddev=0.01)
                conv = tf.Variable(kernel_init, name='conv')
            self.conv = conv
            if use_bias:
                self.bias = tf.Variable(tf.zeros([n[1]]), name='bias')
            else:
                self.bias = 0

    def __call__(self, vin):
        """Convolve `vin` and pass the result, together with the bias, to `self.nl`."""
        # NOTE(review): the bias is 1-D of length n_out; with data_format="NCHW"
        # the default `x + y` would broadcast it along the last axis rather than
        # the channel axis — confirm which layout callers actually use.
        convolved = tf.nn.conv2d(
            vin,
            self.conv,
            strides=self.strides,
            padding=self.padding,
            data_format=self.data_format,
        )
        return self.nl(convolved, self.bias)

class LayerEncodeConvGrowLinear(LayerBase):
    """Convolutional encoder whose channel count grows by `n` at every scale.

    A `conv_pre` layer maps `colors` channels to `n`. Each scale then applies
    `depth - 1` same-width convolutions followed by one stride-2 convolution
    that adds `n` more channels.
    """

    def __init__(self, name, n, width, colors, depth, scales, nl=lambda x, y: x + y, data_format="NCHW"):
        """Build all convolution layers under variable scope `name`.

        Args:
            name: variable scope for every layer of the encoder.
            n: base channel count and per-scale channel increment.
            width: square kernel size used by every convolution.
            colors: number of input channels.
            depth: number of convolutions per scale, including the strided one.
            scales: number of downsampling stages.
            nl: callable (conv_output, bias) forwarded to every LayerConv.
            data_format: layout string forwarded to every LayerConv.
        """
        # Downsample the two spatial axes only. Their position in the strides
        # list depends on the layout; the original hard-coded the NHWC form,
        # which strides the channel axis when data_format is "NCHW".
        down_strides = [1, 1, 2, 2] if data_format == "NCHW" else [1, 2, 2, 1]
        with tf.variable_scope(name) as vs:
            encode = []
            nn = n
            for x in range(scales):
                cl = []
                for y in range(depth - 1):
                    cl.append(LayerConv('conv_%d_%d' % (x, y), [width, width],
                                        [nn, nn], nl, data_format=data_format))
                # Last layer of the scale: halve the resolution and widen by n channels.
                cl.append(LayerConv('conv_%d_%d' % (x, depth - 1), [width, width],
                                    [nn, nn + n], nl, strides=down_strides, data_format=data_format))
                encode.append(cl)
                nn += n
            self.encode = [LayerConv('conv_pre', [width, width], [colors, n], nl, data_format=data_format), encode]
            self.variables = tf.contrib.framework.get_variables(vs)

    def __call__(self, vin, carry=0, train=True):
        """Encode `vin`.

        Args:
            vin: input tensor fed to `conv_pre`.
            carry: blend weight; each non-strided layer outputs
                `carry * input + (1 - carry) * elu(conv(input))`.
            train: accepted but not used by this method.

        Returns:
            A tuple (encoded tensor, variables collected at construction time).
        """
        vout = self.encode[0](vin)
        for convs in self.encode[1]:
            for conv in convs[:-1]:
                vtmp = tf.nn.elu(conv(vout))
                vout = carry * vout + (1 - carry) * vtmp
            # The strided layer is applied without ELU or carry blending.
            vout = convs[-1](vout)
        return vout, self.variables


class LayerDecodeConvBlend(LayerBase):
    """Convolutional decoder that doubles the resolution at every scale.

    From the second scale on, the upsampled original input is concatenated
    onto the running features before that scale's first convolution.
    """

    def __init__(self, name, n, width, colors, depth, scales, nl=lambda x, y: x + y, data_format="NCHW"):
        """Build all convolution layers under variable scope `name`.

        Args:
            name: variable scope for every layer of the decoder.
            n: channel count of the hidden feature maps.
            width: square kernel size used by every convolution.
            colors: number of output channels produced by `conv_post`.
            depth: number of convolutions per scale.
            scales: number of upsampling stages.
            nl: callable (conv_output, bias) forwarded to the per-scale layers.
            data_format: layout string forwarded to every LayerConv.
        """
        with tf.variable_scope(name) as vs:
            decode = []
            for x in range(scales):
                cl = []
                # From the second scale on, the first conv sees the features
                # concatenated with the upsampled input, hence 2 * n channels.
                n2 = 2 * n if x else n
                cl.append(LayerConv('conv_%d_%d' % (x, 0), [width, width],
                                    [n2, n], nl, data_format=data_format))
                for y in range(1, depth):
                    cl.append(LayerConv('conv_%d_%d' % (x, y), [width, width], [n, n], nl, data_format=data_format))
                decode.append(cl)
            self.decode = [decode, LayerConv('conv_post', [width, width], [n, colors], data_format=data_format)]
            self.variables = tf.contrib.framework.get_variables(vs)

    def __call__(self, data, carry=0, train=True):
        """Decode `data`.

        Args:
            data: input feature tensor; it is also re-injected at every later scale.
            carry: blend weight; each layer after the first of a scale outputs
                `carry * input + (1 - carry) * elu(conv(input))`. Defaults to 0,
                matching LayerEncodeConvGrowLinear.
            train: accepted but not used by this method.

        Returns:
            A tuple (decoded tensor, variables collected at construction time).
        """
        # NOTE(review): the concat on axis 3 and unboxn's documented
        # (batch, h, w, depth) layout look channels-last, which seems at odds
        # with the "NCHW" default data_format — confirm against callers.
        vout = data
        for x, convs in enumerate(self.decode[0]):
            if x:
                vout = tf.concat([vout, data], 3)
            vout = unboxn(convs[0](vout), 2)
            # Keep the skip input at the same resolution as the features.
            data = unboxn(data, 2)
            for conv in convs[1:]:
                vtmp = tf.nn.elu(conv(vout))
                vout = carry * vout + (1 - carry) * vtmp
        return self.decode[1](vout), self.variables
43 changes: 37 additions & 6 deletions trainer.py
Expand Up @@ -67,6 +67,7 @@ def __init__(self, config, data_loader):
self.save_step = config.save_step
self.lr_update_step = config.lr_update_step

self.use_authors_model = config.use_authors_model
self.build_model()

self.saver = tf.train.Saver()
Expand Down Expand Up @@ -144,13 +145,43 @@ def build_model(self):
(tf.shape(x)[0], self.z_num), minval=-1.0, maxval=1.0)
self.k_t = tf.Variable(0., trainable=False, name='k_t')

G, self.G_var = GeneratorCNN(
self.z, self.z_num, channel, repeat_num, self.data_format, reuse=False)
if self.use_authors_model:
from layers import LayerEncodeConvGrowLinear, LayerDecodeConvBlend

d_out, self.D_var = DiscriminatorCNN(
tf.concat([G, x], 0), channel, self.z_num, repeat_num,
self.conv_hidden_num, self.data_format)
AE_G, AE_x = tf.split(d_out, 2)
G_in = slim.fully_connected(self.z, np.prod([8, 8, self.conv_hidden_num]))
G_in = reshape(G_in, 8, 8, self.conv_hidden_num, self.data_format)

G_enc = LayerDecodeConvBlend("G_decode", self.conv_hidden_num, 3, channel,
2, repeat_num, data_format=self.data_format)
G, self.G_var = G_enc(G_in, 0)
G = tf.reshape(G, [-1, self.input_scale_size, self.input_scale_size, channel])

D_enc = LayerEncodeConvGrowLinear("D_encode", self.conv_hidden_num, 3, channel,
2, repeat_num - 1, data_format=self.data_format)
D_enc, D_enc_var = D_enc(tf.concat([G, x], 0), 0)

out = tf.reshape(D_enc, [-1, np.prod([8, 8, int_shape(D_enc)[-1]])])
out = slim.fully_connected(out, self.z_num)

# Decoder
out = slim.fully_connected(out, np.prod([8, 8, self.conv_hidden_num]))
out = reshape(out, 8, 8, self.conv_hidden_num, self.data_format)

D_dec = LayerDecodeConvBlend("D_decode", self.conv_hidden_num, 2, channel,
2, repeat_num, data_format=self.data_format)
D, D_dec_var = D_dec(out, 0)
AE_G, AE_x = tf.split(D, 2)

self.D_var = D_enc_var + D_dec_var
else:
G, self.G_var = GeneratorCNN(
self.z, self.conv_hidden_num, channel, repeat_num, self.data_format, reuse=False)

d_out, self.D_var = DiscriminatorCNN(
tf.concat([G, x], 0), channel, self.z_num, repeat_num,
self.conv_hidden_num, self.data_format)
AE_G, AE_x = tf.split(d_out, 2)
import ipdb; ipdb.set_trace()

self.G = denorm_img(G, self.data_format)
self.AE_G, self.AE_x = denorm_img(AE_G, self.data_format), denorm_img(AE_x, self.data_format)
Expand Down

0 comments on commit a15037a

Please sign in to comment.