add label flip options
kaonashi-tyc committed Apr 30, 2017
1 parent 12b3553 commit b73feb0
Showing 2 changed files with 58 additions and 76 deletions.
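
In short, this commit swaps the earlier tune_mode/external_source machinery for a single flip_labels option: when it is set, each training batch is fed a second time through the new no_target_* placeholders with its embedding ids shuffled, so source images get paired with the "wrong" style labels. Those mismatched examples have no ground-truth target, so they contribute adversarial (binary), category and constant losses but no L1 loss, which gives the discriminator extra work when it saturates and d_loss collapses toward zero. Below is a minimal sketch of the per-batch mechanics, heavily simplified relative to the actual training loop in model/unet.py; all names are taken from that file and are assumed to already exist in scope.

    import numpy as np

    # Sketch only: sess, the placeholders (real_data, embedding_ids, no_target_data,
    # no_target_ids, learning_rate), the optimizers and loss_handle are the ones
    # created in model/unet.py; scheduling, summaries and sampling are omitted.
    for bid, batch in enumerate(train_batch_iter):
        labels, batch_images = batch
        shuffled_ids = labels[:]              # copy, so the real branch keeps the true labels
        if flip_labels:
            np.random.shuffle(shuffled_ids)   # mismatched ids for the no_target branch

        feed = {
            real_data: batch_images,          # real (source, target) pairs
            embedding_ids: labels,            # true style labels
            no_target_data: batch_images,     # the same images again, but...
            no_target_ids: shuffled_ids,      # ...under possibly flipped labels, with no L1 target
            learning_rate: current_lr,
        }
        sess.run([d_optimizer, loss_handle.d_loss], feed_dict=feed)   # update D
        sess.run([g_optimizer, loss_handle.g_loss], feed_dict=feed)   # update G
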
120 changes: 53 additions & 67 deletions model/unet.py
@@ -16,7 +16,7 @@
# Used to save handles(important nodes in computation graph) for later evaluation
LossHandle = namedtuple("LossHandle", ["d_loss", "g_loss", "const_loss", "l1_loss",
"category_loss", "cheat_loss", "tv_loss"])
InputHandle = namedtuple("InputHandle", ["real_data", "embedding_ids", "external_data", "external_ids"])
InputHandle = namedtuple("InputHandle", ["real_data", "embedding_ids", "no_target_data", "no_target_ids"])
EvalHandle = namedtuple("EvalHandle", ["encoder", "generator", "target", "source", "embedding"])
SummaryHandle = namedtuple("SummaryHandle", ["d_merged", "g_merged"])

@@ -152,17 +152,17 @@ def discriminator(self, image, is_training, reuse=False):

return tf.nn.sigmoid(fc1), fc1, fc2

def build_model(self, is_training=True, inst_norm=False, with_no_target_source=False):
def build_model(self, is_training=True, inst_norm=False, no_target_source=False):
real_data = tf.placeholder(tf.float32,
[self.batch_size, self.input_width, self.input_width,
self.input_filters + self.output_filters],
name='real_A_and_B_images')
embedding_ids = tf.placeholder(tf.int64, shape=None, name="embedding_ids")
external_data = tf.placeholder(tf.float32,
[self.batch_size, self.input_width, self.input_width,
self.input_filters + self.output_filters],
name='exteranl_A_and_B_images')
external_ids = tf.placeholder(tf.int64, shape=None, name="external_embedding_ids")
no_target_data = tf.placeholder(tf.float32,
[self.batch_size, self.input_width, self.input_width,
self.input_filters + self.output_filters],
name='no_target_A_and_B_images')
no_target_ids = tf.placeholder(tf.int64, shape=None, name="no_target_embedding_ids")

# target images
real_B = real_data[:, :, :, :self.input_filters]
@@ -214,36 +214,37 @@ def build_model(self, is_training=True, inst_norm=False, with_no_target_source=F
d_loss = d_loss_real + d_loss_fake + category_loss / 2.0
g_loss = cheat_loss + l1_loss + self.Lcategory_penalty * fake_category_loss + const_loss + tv_loss

if with_no_target_source:
# external data are examples that don't have the corresponding target images
# however, except L1 loss, we can compute category loss and binary loss with those examples
if no_target_source:
# no_target source are examples that don't have the corresponding target images
# however, except L1 loss, we can compute category loss, binary loss and constant losses with those examples
# it is useful when discriminator get saturated and d_loss drops to near zero
# those data could be used as additional source of losses
external_A = external_data[:, :, :, self.input_filters:self.input_filters + self.output_filters]
external_B, encoded_external_A = self.generator(external_A, embedding, external_ids,
is_training=is_training,
inst_norm=inst_norm, reuse=True)
external_labels = tf.reshape(tf.one_hot(indices=external_ids, depth=self.embedding_num),
shape=[self.batch_size, self.embedding_num])
external_AB = tf.concat([external_A, external_B], 3)
external_D, external_D_logits, external_category_logits = self.discriminator(external_AB,
is_training=is_training,
reuse=True)
encoded_external_B = self.encoder(external_B, is_training, reuse=True)[0]
external_const_loss = tf.reduce_mean(
tf.square(encoded_external_A - encoded_external_B)) * self.Lconst_penalty
external_category_loss = tf.reduce_mean(
tf.nn.sigmoid_cross_entropy_with_logits(logits=external_category_logits,
labels=external_labels)) * self.Lcategory_penalty

d_loss_external = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=external_D_logits,
labels=tf.zeros_like(external_D)))
cheat_loss += tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=external_D_logits,
labels=tf.ones_like(external_D)))
d_loss = d_loss_real + d_loss_fake + d_loss_external + category_loss + external_category_loss
g_loss = cheat_loss + l1_loss + \
(self.Lcategory_penalty * fake_category_loss + external_category_loss) / 2.0 + \
(const_loss + external_const_loss) / 2.0 + tv_loss
# those data could be used as additional source of losses to break the saturation
no_target_A = no_target_data[:, :, :, self.input_filters:self.input_filters + self.output_filters]
no_target_B, encoded_no_target_A = self.generator(no_target_A, embedding, no_target_ids,
is_training=is_training,
inst_norm=inst_norm, reuse=True)
no_target_labels = tf.reshape(tf.one_hot(indices=no_target_ids, depth=self.embedding_num),
shape=[self.batch_size, self.embedding_num])
no_target_AB = tf.concat([no_target_A, no_target_B], 3)
no_target_D, no_target_D_logits, no_target_category_logits = self.discriminator(no_target_AB,
is_training=is_training,
reuse=True)
encoded_no_target_B = self.encoder(no_target_B, is_training, reuse=True)[0]
no_target_const_loss = tf.reduce_mean(
tf.square(encoded_no_target_A - encoded_no_target_B)) * self.Lconst_penalty
no_target_category_loss = tf.reduce_mean(
tf.nn.sigmoid_cross_entropy_with_logits(logits=no_target_category_logits,
labels=no_target_labels)) * self.Lcategory_penalty

d_loss_no_target = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=no_target_D_logits,
labels=tf.zeros_like(
no_target_D)))
cheat_loss += tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=no_target_D_logits,
labels=tf.ones_like(no_target_D)))
d_loss = d_loss_real + d_loss_fake + d_loss_no_target + (category_loss + no_target_category_loss) / 3.0
g_loss = cheat_loss / 2.0 + l1_loss + \
(self.Lcategory_penalty * fake_category_loss + no_target_category_loss) / 2.0 + \
(const_loss + no_target_const_loss) / 2.0 + tv_loss

d_loss_real_summary = tf.summary.scalar("d_loss_real", d_loss_real)
d_loss_fake_summary = tf.summary.scalar("d_loss_fake", d_loss_fake)
@@ -266,8 +267,8 @@ def build_model(self, is_training=True, inst_norm=False, with_no_target_source=F
# expose useful nodes in the graph as handles globally
input_handle = InputHandle(real_data=real_data,
embedding_ids=embedding_ids,
external_data=external_data,
external_ids=external_ids)
no_target_data=no_target_data,
no_target_ids=no_target_ids)

loss_handle = LossHandle(d_loss=d_loss,
g_loss=g_loss,
@@ -357,8 +358,8 @@ def generate_fake_samples(self, input_images, embedding_ids):
feed_dict={
input_handle.real_data: input_images,
input_handle.embedding_ids: embedding_ids,
input_handle.external_data: input_images,
input_handle.external_ids: embedding_ids
input_handle.no_target_data: input_images,
input_handle.no_target_ids: embedding_ids
})
return fake_images, real_images, d_loss, g_loss, l1_loss

@@ -492,7 +493,7 @@ def filter_embedding_vars(var):
op = tf.assign(var, val, validate_shape=False)
self.sess.run(op)

def train(self, lr=0.0002, epoch=100, schedule=10, resume=True, tune_mode=None, external_source=None,
def train(self, lr=0.0002, epoch=100, schedule=10, resume=True, flip_labels=False,
freeze_encoder=False, fine_tune=None, sample_steps=50, checkpoint_steps=500):
g_vars, d_vars = self.retrieve_trainable_vars(freeze_encoder=freeze_encoder)
input_handle, loss_handle, _, summary_handle = self.retrieve_handles()
@@ -506,26 +507,13 @@ def train(self, lr=0.0002, epoch=100, schedule=10, resume=True, tune_mode=None,
tf.global_variables_initializer().run()
real_data = input_handle.real_data
embedding_ids = input_handle.embedding_ids
external_data = input_handle.external_data
external_ids = input_handle.external_ids
no_target_data = input_handle.no_target_data
no_target_ids = input_handle.no_target_ids

# filter by one type of labels
data_provider = TrainDataProvider(self.data_dir, filter_by=fine_tune)
total_batches = data_provider.compute_total_batch_num(self.batch_size)
val_batch_iter = data_provider.get_val_iter(self.batch_size)
external_data_iter = None
if tune_mode == 'external' and external_source:
# external source specified, those examples don't
# have corresponding target example, however, they are a valuable
# provider to prevent discriminator from getting stuck
# tune_mode can be either shuffle or external
# shuffling meaning it we still using the training examples themselves
# however we shuffle the embedding ids to provide 'new' examples
# external source are sources that could be completely independent of the training
# data, i.e. the glyphs from another language in case of zi2zi
print("loading external sources -> {0}".format(external_source))
external_data_iter = NeverEndingLoopingProvider(external_source) \
.get_random_embedding_iter(self.batch_size, data_provider.get_all_labels())

saver = tf.train.Saver(max_to_keep=3)
summary_writer = tf.summary.FileWriter(self.log_dir, self.sess.graph)
@@ -551,29 +539,27 @@ def train(self, lr=0.0002, epoch=100, schedule=10, resume=True, tune_mode=None,
for bid, batch in enumerate(train_batch_iter):
counter += 1
labels, batch_images = batch
external_batch, external_batch_ids = batch_images, labels[:]
if external_data_iter:
external_batch_ids, external_batch = next(external_data_iter)
elif tune_mode == 'shuffle':
np.random.shuffle(external_batch_ids)
shuffled_ids = labels[:]
if flip_labels:
np.random.shuffle(shuffled_ids)
# Optimize D
_, batch_d_loss, d_summary = self.sess.run([d_optimizer, loss_handle.d_loss,
summary_handle.d_merged],
feed_dict={
real_data: batch_images,
embedding_ids: labels,
learning_rate: current_lr,
external_data: external_batch,
external_ids: external_batch_ids
no_target_data: batch_images,
no_target_ids: shuffled_ids
})
# Optimize G
_, batch_g_loss = self.sess.run([g_optimizer, loss_handle.g_loss],
feed_dict={
real_data: batch_images,
embedding_ids: labels,
learning_rate: current_lr,
external_data: external_batch,
external_ids: external_batch_ids
no_target_data: batch_images,
no_target_ids: shuffled_ids
})
# magic move to Optimize G again
# according to https://github.com/carpedm20/DCGAN-tensorflow
@@ -591,8 +577,8 @@ def train(self, lr=0.0002, epoch=100, schedule=10, resume=True, tune_mode=None,
real_data: batch_images,
embedding_ids: labels,
learning_rate: current_lr,
external_data: external_batch,
external_ids: external_batch_ids
no_target_data: batch_images,
no_target_ids: shuffled_ids
})
passed = time.time() - start_time
log_format = "Epoch: [%2d], [%4d/%4d] time: %4.4f, d_loss: %.5f, g_loss: %.5f, " + \
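For readability, here is the loss mixing that build_model ends up with when no_target_source is enabled, restated from the diff above with the same variable names; this is not new logic, just the two final assignments pulled together with comments.

    # Restated from build_model() with no_target_source=True.
    d_loss = (d_loss_real + d_loss_fake + d_loss_no_target
              + (category_loss + no_target_category_loss) / 3.0)

    g_loss = (cheat_loss / 2.0        # adversarial term; cheat_loss now also covers the no_target fakes
              + l1_loss               # pixel loss, computed on real pairs only
              + (self.Lcategory_penalty * fake_category_loss
                 + no_target_category_loss) / 2.0          # style-category terms
              + (const_loss + no_target_const_loss) / 2.0  # embedding-consistency terms
              + tv_loss)              # total variation
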
14 changes: 5 additions & 9 deletions train.py
@@ -37,10 +37,8 @@
help='number of batches in between two samples are drawn from validation set')
parser.add_argument('--checkpoint_steps', dest='checkpoint_steps', type=int, default=500,
help='number of batches in between two checkpoints')
parser.add_argument('--tune_mode', dest='tune_mode', type=str, default=None,
help='tune the model in different ways, could be shuffle|external')
parser.add_argument('--external_source', dest='external_source', type=str, default=None,
help='external source of images that used to regulate the model')
parser.add_argument('--flip_labels', dest='flip_labels', type=int, default=None,
help='whether flip training data labels or not, in fine tuning')
args = parser.parse_args()


@@ -54,10 +52,8 @@ def main(_):
embedding_dim=args.embedding_dim, L1_penalty=args.L1_penalty, Lconst_penalty=args.Lconst_penalty,
Ltv_penalty=args.Ltv_penalty, Lcategory_penalty=args.Lcategory_penalty)
model.register_session(sess)
if args.tune_mode:
if args.tune_mode not in ['shuffle', 'external']:
raise RuntimeError("tune_mode has to be either shuffle or external")
model.build_model(is_training=True, inst_norm=args.inst_norm, with_no_target_source=True)
if args.flip_labels:
model.build_model(is_training=True, inst_norm=args.inst_norm, no_target_source=True)
else:
model.build_model(is_training=True, inst_norm=args.inst_norm)
fine_tune_list = None
@@ -67,7 +63,7 @@ def main(_):
model.train(lr=args.lr, epoch=args.epoch, resume=args.resume,
schedule=args.schedule, freeze_encoder=args.freeze_encoder, fine_tune=fine_tune_list,
sample_steps=args.sample_steps, checkpoint_steps=args.checkpoint_steps,
tune_mode=args.tune_mode, external_source=args.external_source)
flip_labels=args.flip_labels)


if __name__ == '__main__':
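A hypothetical invocation with the new flag; only --flip_labels is defined in this diff, and "..." stands for the remaining training arguments, which are assumed unchanged.

    # fine-tune with label flipping: labels are shuffled every batch and the
    # no_target loss branch is built
    python train.py --flip_labels=1 ...

    # default behaviour: the no_target placeholders are still fed, but with the
    # original labels and without the extra loss terms
    python train.py ...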
