Fixes for variable image sizes #6

Open — wants to merge 3 commits into master
2 changes: 1 addition & 1 deletion models/basemodel.py → basemodel.py
@@ -39,4 +39,4 @@ def _build_gen_graph(self):

def _build_train_graph(self, X):
'''build computational graph for training'''
pass
6 changes: 3 additions & 3 deletions config.py
@@ -3,7 +3,7 @@

model_zoo = ['DCGAN', 'LSGAN', 'WGAN', 'WGAN-GP', 'EBGAN', 'BEGAN', 'DRAGAN', 'CoulombGAN']

-def get_model(mtype, name, training):
+def get_model(mtype, name, training, image_shape=[64, 64, 3]):
model = None
if mtype == 'DCGAN':
model = dcgan.DCGAN
@@ -26,7 +26,7 @@ def get_model(mtype, name, training):

assert model, mtype + ' is work in progress'

-return model(name=name, training=training)
+return model(name=name, training=training, image_shape=image_shape)


def get_dataset(dataset_name):
@@ -35,7 +35,7 @@ def get_dataset(dataset_name):
lsun_bedroom_128 = './data/lsun/bedroom_128_tfrecords/*.tfrecord'

if dataset_name == 'celeba':
-path = celebA_128
+path = celebA_64
n_examples = 202599
elif dataset_name == 'lsun':
path = lsun_bedroom_128
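For context, a minimal sketch of how the updated factory would be called with the new keyword (the model and name arguments here are illustrative, not from this PR):

    import config

    # Hypothetical: build a DCGAN that generates 64x64 RGB images.
    model = config.get_model('DCGAN', name='dcgan_celeba', training=True,
                             image_shape=[64, 64, 3])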
8 changes: 4 additions & 4 deletions convert.py
@@ -31,7 +31,7 @@ def convert(source_dir, target_dir, crop_size, out_size, exts=[''], num_shards=1
if not tf.gfile.Exists(source_dir):
print('source_dir does not exist')
return

if tfrecords_prefix and not tfrecords_prefix.endswith('-'):
tfrecords_prefix += '-'

@@ -122,11 +122,11 @@ def export_images(db_path, out_dir, flat=False, limit=-1):

if __name__ == "__main__":
# CelebA
-convert('./data/celebA', './data/celebA_128_tfrecords', crop_size=[128, 128], out_size=[128, 128],
+convert('/home/ibhat/image_completion/dcgan-completion.tensorflow/data/celebA', './data/celebA_tfrecords', crop_size=[64, 64], out_size=[64, 64],
exts=['jpg'], num_shards=128, tfrecords_prefix='celebA')

# LSUN
# export_images('./tf.gans-comparison/data/lsun/bedroom_val_lmdb/',
# './tf.gans-comparison/data/lsun/bedroom_val_images/', flat=True)
# convert('./data/lsun/bedroom_train_images', './data/lsun/bedroom_128_tfrecords', crop_size=[128, 128],
# out_size=[128, 128], exts=['webp'], num_shards=128, tfrecords_prefix='lsun_bedroom')
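A sketch of targeting a different resolution with the same entry point (the paths below are placeholders, not the ones in this PR):

    # Hypothetical: build 128x128 records from a local CelebA copy.
    convert('./data/celebA', './data/celebA_128_tfrecords',
            crop_size=[128, 128], out_size=[128, 128],
            exts=['jpg'], num_shards=128, tfrecords_prefix='celebA')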
17 changes: 10 additions & 7 deletions eval.py
@@ -12,7 +12,7 @@
def build_parser():
parser = ArgumentParser()
models_str = ' / '.join(config.model_zoo)
parser.add_argument('--model', help=models_str, required=True)
parser.add_argument('--name', help='default: name=model')
parser.add_argument('--dataset', '-D', help='CelebA / LSUN', required=True)
parser.add_argument('--sample_size', '-N', help='# of samples. It should be a square number. (default: 16)',
@@ -27,8 +27,8 @@ def sample_z(shape):

def get_all_checkpoints(ckpt_dir, force=False):
'''
When training is interrupted and resumed, not all checkpoints can be fetched with get_checkpoint_state
(the checkpoint state is rewritten from the point of resume).
This function forcibly fetches all checkpoints when force=True.
'''

@@ -41,14 +41,17 @@ def get_all_checkpoints(ckpt_dir, force=False):
ckpts = map(lambda x: os.path.join(ckpt_dir, x), ckpts) # fn => path
else:
ckpts = tf.train.get_checkpoint_state(ckpt_dir).all_model_checkpoint_paths

return ckpts


-def eval(model, name, dataset, sample_shape=[4,4], load_all_ckpt=True):
+def eval(model, name, sample_dir, dataset, sample_shape=[4,4], load_all_ckpt=True):
if name is None:
name = model.name
-dir_name = os.path.join('eval', dataset, name)
+if sample_dir is None:
+    dir_name = os.path.join('eval', dataset, name)
+else:
+    dir_name = os.path.join(sample_dir, dataset, name)
if tf.gfile.Exists(dir_name):
tf.gfile.DeleteRecursively(dir_name)
tf.gfile.MakeDirs(dir_name)
@@ -69,7 +72,7 @@ def eval(model, name, dataset, sample_shape=[4,4], load_all_ckpt=True):
print("Evaluating {} ...".format(v))
restorer.restore(sess, v)
global_step = int(v.split('/')[-1].split('-')[-1])

fake_samples = sess.run(model.fake_sample, {model.z: z_})

# inverse transform: [-1, 1] => [0, 1]
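A minimal sketch of calling the updated eval (the model object would come from config.get_model; the name and sample_dir values are illustrative):

    # Hypothetical: write sample grids under ./samples/celeba/dcgan/ instead of ./eval/.
    eval(model, name='dcgan', sample_dir='./samples', dataset='celeba',
         sample_shape=[4, 4], load_all_ckpt=True)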
41 changes: 20 additions & 21 deletions inputpipe.py
@@ -2,26 +2,25 @@
import tensorflow as tf


-def read_parse_preproc(filename_queue):
+def read_parse_preproc(filename_queue, image_size=64):  # default keeps existing get_batch() callers working
''' read, parse, and preproc single example. '''
with tf.variable_scope('read_parse_preproc'):
reader = tf.TFRecordReader()
key, records = reader.read(filename_queue)

# parse records
features = tf.parse_single_example(
records,
features={
"image": tf.FixedLenFeature([], tf.string)
}
)

image = tf.decode_raw(features["image"], tf.uint8)
-image = tf.reshape(image, [128, 128, 3]) # The image_shape must be explicitly specified
-image = tf.image.resize_images(image, [64, 64])
+image = tf.reshape(image, [image_size, image_size, 3])
+image = tf.image.resize_images(image, [image_size, image_size])
image = tf.cast(image, tf.float32)
image = image / 127.5 - 1.0 # preproc - normalize

return [image]


@@ -31,51 +30,51 @@ def get_batch(tfrecords_list, batch_size, shuffle=False, num_threads=1, min_after
with tf.variable_scope(name):
filename_queue = tf.train.string_input_producer(tfrecords_list, shuffle=shuffle, num_epochs=num_epochs)
data_point = read_parse_preproc(filename_queue)

if min_after_dequeue is None:
min_after_dequeue = batch_size * 10
capacity = min_after_dequeue + 3*batch_size
if shuffle:
batch = tf.train.shuffle_batch(data_point, batch_size=batch_size, capacity=capacity,
min_after_dequeue=min_after_dequeue, num_threads=num_threads, allow_smaller_final_batch=True)
else:
batch = tf.train.batch(data_point, batch_size, capacity=capacity, num_threads=num_threads,
allow_smaller_final_batch=True)

return batch


-def get_batch_join(tfrecords_list, batch_size, shuffle=False, num_threads=1, min_after_dequeue=None, num_epochs=None):
+def get_batch_join(tfrecords_list, batch_size, shuffle=False, num_threads=1, min_after_dequeue=None, num_epochs=None, image_size=64):
name = "batch_join" if not shuffle else "shuffle_batch_join"
with tf.variable_scope(name):
filename_queue = tf.train.string_input_producer(tfrecords_list, shuffle=shuffle, num_epochs=num_epochs)
-example_list = [read_parse_preproc(filename_queue) for _ in range(num_threads)]
+example_list = [read_parse_preproc(filename_queue, image_size=image_size) for _ in range(num_threads)]

if min_after_dequeue is None:
min_after_dequeue = batch_size * 10
capacity = min_after_dequeue + 3*batch_size
if shuffle:
batch = tf.train.shuffle_batch_join(tensors_list=example_list, batch_size=batch_size, capacity=capacity,
min_after_dequeue=min_after_dequeue, allow_smaller_final_batch=True)
else:
batch = tf.train.batch_join(example_list, batch_size, capacity=capacity, allow_smaller_final_batch=True)

return batch


# interfaces
-def shuffle_batch_join(tfrecords_list, batch_size, num_threads, num_epochs, min_after_dequeue=None):
+def shuffle_batch_join(tfrecords_list, batch_size, num_threads, num_epochs, min_after_dequeue=None, image_size=64):
return get_batch_join(tfrecords_list, batch_size, shuffle=True, num_threads=num_threads,
-num_epochs=num_epochs, min_after_dequeue=min_after_dequeue)
+num_epochs=num_epochs, min_after_dequeue=min_after_dequeue, image_size=image_size)

def batch_join(tfrecords_list, batch_size, num_threads, num_epochs, min_after_dequeue=None):
return get_batch_join(tfrecords_list, batch_size, shuffle=False, num_threads=num_threads,
num_epochs=num_epochs, min_after_dequeue=min_after_dequeue)

def shuffle_batch(tfrecords_list, batch_size, num_threads, num_epochs, min_after_dequeue=None):
return get_batch(tfrecords_list, batch_size, shuffle=True, num_threads=num_threads,
num_epochs=num_epochs, min_after_dequeue=min_after_dequeue)

def batch(tfrecords_list, batch_size, num_threads, num_epochs, min_after_dequeue=None):
return get_batch(tfrecords_list, batch_size, shuffle=False, num_threads=num_threads,
num_epochs=num_epochs, min_after_dequeue=min_after_dequeue)
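A minimal sketch of the intended input flow with the new parameter (the file pattern and batch settings are illustrative):

    import glob
    import inputpipe as ip

    # Hypothetical: stream 64x64 CelebA examples for training.
    tfrecords = glob.glob('./data/celebA_tfrecords/*.tfrecord')
    X = ip.shuffle_batch_join(tfrecords, batch_size=128, num_threads=4,
                              num_epochs=None, image_size=64)

Note that read_parse_preproc now reshapes and resizes to the same image_size, so the resize_images call is effectively a no-op in this path.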
2 changes: 1 addition & 1 deletion models/__init__.py
@@ -7,4 +7,4 @@ def get_all_modules_cwd():
return [basename(f)[:-3] for f in modules if isfile(f) and not f.endswith('__init__.py')]


__all__ = get_all_modules_cwd()
51 changes: 29 additions & 22 deletions models/dcgan.py
@@ -4,7 +4,7 @@
from utils import expected_shape
import ops
from basemodel import BaseModel

+import math
'''Original hyperparams:
optimizer - SGD
init - stddev 0.02
@@ -13,15 +13,14 @@
class DCGAN(BaseModel):
def __init__(self, name, training, D_lr=2e-4, G_lr=2e-4, image_shape=[64, 64, 3], z_dim=100):
self.beta1 = 0.5
super(DCGAN, self).__init__(name=name, training=training, D_lr=D_lr, G_lr=G_lr,
image_shape=image_shape, z_dim=z_dim)

def _build_train_graph(self):
with tf.variable_scope(self.name):
X = tf.placeholder(tf.float32, [None] + self.shape)
z = tf.placeholder(tf.float32, [None, self.z_dim])
global_step = tf.Variable(0, name='global_step', trainable=False)

G = self._generator(z)
D_real_prob, D_real_logits = self._discriminator(X)
D_fake_prob, D_fake_logits = self._discriminator(G, reuse=True)
@@ -65,23 +64,29 @@ def _build_train_graph(self):
self.z = z
self.D_train_op = D_train_op
self.G_train_op = G_train_op
+self.G_loss = G_loss
+self.D_loss = D_loss
self.fake_sample = G
self.global_step = global_step

def _discriminator(self, X, reuse=False):
with tf.variable_scope('D', reuse=reuse):
net = X

-with slim.arg_scope([slim.conv2d], kernel_size=[5,5], stride=2, padding='SAME', activation_fn=ops.lrelu,
+width = self.shape[0]
+filter_num = 64
+stride = 2
+num_conv_layers = 4
+with slim.arg_scope([slim.conv2d], kernel_size=[5,5], stride=stride, padding='SAME', activation_fn=ops.lrelu,
normalizer_fn=slim.batch_norm, normalizer_params=self.bn_params):
-net = slim.conv2d(net, 64, normalizer_fn=None)
-expected_shape(net, [32, 32, 64])
-net = slim.conv2d(net, 128)
-expected_shape(net, [16, 16, 128])
-net = slim.conv2d(net, 256)
-expected_shape(net, [8, 8, 256])
-net = slim.conv2d(net, 512)
-expected_shape(net, [4, 4, 512])
+for layer_num in range(1, num_conv_layers + 1):
+    if layer_num == 1:  # no batch norm for the first convolution
+        net = slim.conv2d(net, filter_num, normalizer_fn=None)
+    else:
+        net = slim.conv2d(net, filter_num)
+    output_dim = math.ceil(width / stride)  # padding='SAME'; see https://www.tensorflow.org/api_guides/python/nn#Convolution -- Ishaan
+    expected_shape(net, [output_dim, output_dim, filter_num])
+    width = width // 2
+    filter_num = filter_num * 2

net = slim.flatten(net)
logits = slim.fully_connected(net, 1, activation_fn=None)
@@ -94,16 +99,18 @@ def _generator(self, z, reuse=False):
net = z
net = slim.fully_connected(net, 4*4*1024, activation_fn=tf.nn.relu)
net = tf.reshape(net, [-1, 4, 4, 1024])

-with slim.arg_scope([slim.conv2d_transpose], kernel_size=[5,5], stride=2, padding='SAME',
+filter_num = 512
+input_size = 4
+stride = 2
+with slim.arg_scope([slim.conv2d_transpose], kernel_size=[5,5], stride=stride, padding='SAME',
activation_fn=tf.nn.relu, normalizer_fn=slim.batch_norm, normalizer_params=self.bn_params):
-net = slim.conv2d_transpose(net, 512)
-expected_shape(net, [8, 8, 512])
-net = slim.conv2d_transpose(net, 256)
-expected_shape(net, [16, 16, 256])
-net = slim.conv2d_transpose(net, 128)
-expected_shape(net, [32, 32, 128])
+while input_size < (self.shape[0] // stride):
+    net = slim.conv2d_transpose(net, filter_num)
+    expected_shape(net, [input_size*stride, input_size*stride, filter_num])
+    filter_num = filter_num // 2
+    input_size = input_size * stride

net = slim.conv2d_transpose(net, 3, activation_fn=tf.nn.tanh, normalizer_fn=None)
-expected_shape(net, [64, 64, 3])
+expected_shape(net, [self.shape[0], self.shape[1], 3])

return net
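As a sanity check on the new loops, a short sketch (assuming stride 2 and SAME padding, as in this PR) of the shapes produced for a 64x64 input:

    import math

    # Discriminator: spatial width halves and the filter count doubles per layer.
    width, filters, stride = 64, 64, 2
    for _ in range(4):
        out = int(math.ceil(width / float(stride)))  # SAME padding
        print('%dx%d -> %dx%dx%d' % (width, width, out, out, filters))
        width, filters = out, filters * 2
    # 64x64 -> 32x32x64 -> 16x16x128 -> 8x8x256 -> 4x4x512, matching the old hardcoded expected_shape calls.

    # Generator: 4x4x1024 is upsampled until one stride short of the target,
    # so the final conv2d_transpose lands exactly on [64, 64, 3].
    size = 4
    while size < 64 // stride:
        size *= stride  # 4 -> 8 -> 16 -> 32
    print(size * stride)  # 64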
10 changes: 5 additions & 5 deletions models/wgan.py
@@ -17,7 +17,7 @@ class WGAN(BaseModel):
def __init__(self, name, training, D_lr=5e-5, G_lr=5e-5, image_shape=[64, 64, 3], z_dim=100):
self.ld = 10. # lambda
self.n_critic = 5
super(WGAN, self).__init__(name=name, training=training, D_lr=D_lr, G_lr=G_lr,
image_shape=image_shape, z_dim=z_dim)

def _build_train_graph(self):
@@ -51,7 +51,7 @@ def _build_train_graph(self):

# weight clipping
''' Is it right to clip gamma of the batch_norm? '''

# ver 1. clips all variables in critic
C_clips = [tf.assign(var, tf.clip_by_value(var, -0.01, 0.01)) for var in C_vars] # with gamma

@@ -95,8 +95,8 @@ def _critic(self, X, reuse=False):
''' K-Lipschitz function '''
with tf.variable_scope('critic', reuse=reuse):
net = X
with slim.arg_scope([slim.conv2d], kernel_size=[5,5], stride=2, activation_fn=ops.lrelu,
normalizer_fn=slim.batch_norm, normalizer_params=self.bn_params):
net = slim.conv2d(net, 64, normalizer_fn=None)
expected_shape(net, [32, 32, 64])
@@ -118,7 +118,7 @@ def _generator(self, z, reuse=False):
net = slim.fully_connected(net, 4*4*1024, activation_fn=tf.nn.relu)
net = tf.reshape(net, [-1, 4, 4, 1024])

with slim.arg_scope([slim.conv2d_transpose], kernel_size=[5,5], stride=2, activation_fn=tf.nn.relu,
normalizer_fn=slim.batch_norm, normalizer_params=self.bn_params):
net = slim.conv2d_transpose(net, 512)
expected_shape(net, [8, 8, 512])