In [1]:
import tensorflow as tf
tf.__version__

'2.0.0'

In [2]:
import argparse
import time
import sys
import os
from run_step import *
from datasets import *
from containers import *
from losses import *
from utils import *

In [None]:
if __name__ == "__main__":
    # Configuration. The argparse CLI (--config, --batch_size, --output_path,
    # --test_batch_size) is disabled so this cell can run inside Jupyter;
    # the config path and output directory are hard-coded instead.
    config = get_config('./configs/funit_animals.yaml')
    output_dir = "./outputs/"

    # Strategy: MirroredStrategy over the visible devices. All model
    # variables must be created under strategy.scope().
    strategy = tf.distribute.MirroredStrategy()

    # Datasets. Judging by the names used below (and the loader log output),
    # get_datasets returns: train content, train class, test content, test class.
    datasets = get_datasets(config)
    # - Train: (content, class) samples drawn in lock-step via zip.
    train_content_dataset = datasets[0]
    train_class_dataset = datasets[1]
    train_dataset = tf.data.Dataset.zip((train_content_dataset, train_class_dataset))
    global_batch_size = config['batch_size']

    def train_ds_fn(input_context):
        """Build one input pipeline's shard of the training data.

        Args:
            input_context: tf.distribute.InputContext supplied by the strategy.

        Returns:
            `train_dataset` batched at the per-replica batch size and sharded
            so each input pipeline sees a disjoint subset of batches.
        """
        batch_size = input_context.get_per_replica_batch_size(global_batch_size)
        d = train_dataset.batch(batch_size)
        return d.shard(input_context.num_input_pipelines, input_context.input_pipeline_id)

    dist_train_dataset = strategy.experimental_distribute_datasets_from_function(train_ds_fn)
    # - Test: zip the *test* splits. Zipping the train splits here would make
    #   every "test" preview image a training image.
    test_content_dataset = datasets[2]
    test_class_dataset = datasets[3]
    test_dataset = tf.data.Dataset.zip((test_content_dataset, test_class_dataset))

    # Epoch bound for the outer training loop; training is also stopped by the
    # iteration counter checked inside the loop.
    EPOCHS = config['max_iter']
    # Networks: the training container and a second copy used for previews are
    # both built under the strategy scope so their variables are distributed.
    with strategy.scope():
        networks, test_networks = FUNIT(config), FUNIT(config)

        # NOTE: the step functions below close over `networks`; it is not yet
        # clear how to decouple them from it.
        def gen_train_step(x, config):
            """Run one generator update on a single replica.

            Args:
                x: pair ``(co_data, cl_data)``; each element is unpacked into
                    two parts (presumably ``(images, labels)`` — confirm
                    against get_datasets).
                config: config dict; reads 'batch_size', 'gan_w', 'r_w', 'fm_w'.

            Returns:
                The weighted generator loss. Each of its terms has already been
                divided by config['batch_size'] (the global batch size).
            """
            co_data, cl_data = x
            xa, la = co_data
            _, lb = cl_data
            n_global = config['batch_size']

            with tf.GradientTape() as tape:
                xt_g, xr, xa_gan_feat, xb_gan_feat = networks.gen_update(co_data, cl_data, config)

                # Discriminator responses and features for the reconstruction
                # (scored against la) and the translation (scored against lb).
                resp_xr_fake, xr_gan_feat = networks.dis(xr, la)
                resp_xt_fake, xt_gan_feat = networks.dis(xt_g, lb)

                # Adversarial term: mean of translation and reconstruction losses.
                adv_t = GANloss.gen_loss(resp_xt_fake, lb)
                adv_r = GANloss.gen_loss(resp_xr_fake, la)
                l_adv = 0.5 * (adv_t + adv_r) / n_global

                # Reconstruction term (xr vs. xa).
                l_x_rec = tf.reduce_sum(recon_loss(xr, xa)) / n_global

                # Feature-matching terms: content (xr vs. xa) and class (xt vs. xb).
                l_c_rec = tf.reduce_sum(featmatch_loss(xr_gan_feat, xa_gan_feat)) / n_global
                l_m_rec = tf.reduce_sum(featmatch_loss(xt_gan_feat, xb_gan_feat)) / n_global

                G_loss = (config['gan_w'] * l_adv
                          + config['r_w'] * l_x_rec
                          + config['fm_w'] * (l_c_rec + l_m_rec))

            gen_vars = networks.gen.trainable_variables
            grads = tape.gradient(G_loss, gen_vars)
            networks.opt_gen.apply_gradients(zip(grads, gen_vars))
            return G_loss
        
        def dis_train_step(x, config):
            """Run one discriminator update on a single replica.

            Args:
                x: pair ``(co_data, cl_data)``; ``cl_data`` is unpacked into
                    ``(xb, lb)``.
                config: config dict; reads 'gan_w' and 'batch_size'.

            Returns:
                The *unscaled* discriminator loss. The gradient is taken from
                this value divided by config['batch_size'].
                NOTE(review): gen_train_step returns its loss already divided by
                the batch size, so the two logged losses use different scales.
            """
            co_data, cl_data = x
            xb, lb = cl_data

            with tf.GradientTape() as tape:
                resp_real, _, _, resp_fake, _ = networks.dis_update(co_data, cl_data, config)

                # GAN loss on the real and fake responses, both scored against lb.
                l_real = GANloss.dis_loss(resp_real, lb, 'real')
                l_fake = GANloss.dis_loss(resp_fake, lb, 'fake')
                # Gradient penalty on the class-side inputs (xb, lb).
                l_reg = gradient_penalty(networks.dis, xb, lb)

                gan_w = config['gan_w']
                # The penalty weight 10 is hard-coded rather than read from config.
                D_loss = gan_w * l_real + gan_w * l_fake + 10 * l_reg
                scaled_loss = D_loss * (1.0 / config['batch_size'])

            dis_vars = networks.dis.trainable_variables
            grads = tape.gradient(scaled_loss, dis_vars)
            networks.opt_dis.apply_gradients(zip(grads, dis_vars))
            return D_loss
            
    
    # Distributed train step, modelled on the custom-training-loop tutorial:
    # https://www.tensorflow.org/tutorials/distribute/custom_training
    with strategy.scope():

        @tf.function
        def distributed_train_step(dataset_inputs, config):
            """Run a discriminator step, then a generator step, on every replica.

            Returns:
                (gen_loss, dis_loss): each step's per-replica losses reduced
                with MEAN across replicas.
                NOTE(review): the tutorial divides per-replica losses by the
                global batch size and reduces with SUM. gen_train_step already
                divides by the global batch size, so MEAN here divides the
                reported G loss by the replica count — confirm this is intended.
            """
            def _mean_over_replicas(step_fn):
                per_replica = strategy.experimental_run_v2(step_fn, args=(dataset_inputs, config))
                return strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica, axis=None)

            # Discriminator update first, then generator update.
            dis_loss = _mean_over_replicas(dis_train_step)
            gen_loss = _mean_over_replicas(gen_train_step)
            return gen_loss, dis_loss
        
        
            
        
        # Checkpoints. Generator and discriminator are saved under separate
        # prefixes inside one directory. test_networks.gen is loaded from the
        # generator checkpoint before preview images are written.
        checkpoint_dir = './training_checkpoints'
        gen_ckpt_prefix = os.path.join(checkpoint_dir, "gen_ckpt")
        dis_ckpt_prefix = os.path.join(checkpoint_dir, "dis_ckpt")
        gen_ckpt = tf.train.Checkpoint(optimizer=networks.opt_gen, net=networks.gen)
        dis_ckpt = tf.train.Checkpoint(optimizer=networks.opt_dis, net=networks.dis)
        test_gen_ckpt = tf.train.Checkpoint(optimizer=test_networks.opt_gen, net=test_networks.gen)

        max_iter = config['max_iter']
        # Preview sources: (dataset, output file-name pattern).
        preview_sources = (
            (train_dataset, 'train_%s_%02d'),
            (test_dataset, 'test_%s_%02d'),
        )

        # `iteration` is the 1-based index of the step being run. Training stops
        # after exactly `max_iter` steps, or earlier if all EPOCHS are exhausted.
        iteration = 1
        finished = False
        for epoch in range(1, EPOCHS + 1):
            print("epoch %d: " % epoch)

            for x in dist_train_dataset:
                start_time = time.time()
                G_loss, D_loss = distributed_train_step(x, config)
                print(" (%d/%d) G_loss: %.4f, D_loss: %.4f, time: %.5f" % (iteration, max_iter, G_loss, D_loss, (time.time() - start_time)))

                # Test step: checkpoint and write preview images at either interval.
                is_save_iter = iteration % config['image_save_iter'] == 0
                if is_save_iter or iteration % config['image_display_iter'] == 0:
                    # Restore from the exact path returned by save().
                    # tf.train.latest_checkpoint expects a *directory*; given a
                    # file prefix it returns None and restore(None) loads nothing.
                    # Both checkpoints also share one directory (and one
                    # 'checkpoint' state file), so "latest" would be the dis one.
                    gen_path = gen_ckpt.save(gen_ckpt_prefix)
                    dis_ckpt.save(dis_ckpt_prefix)
                    # expect_partial(): the preview copy may never create the
                    # optimizer slot variables stored in the checkpoint
                    # (presumably it only runs inference via test_step — confirm).
                    test_gen_ckpt.restore(gen_path).expect_partial()

                    key_str = '%08d' % iteration if is_save_iter else 'current'
                    image_size = max(config['crop_image_height'], config['crop_image_width'])
                    for source_dataset, name_pattern in preview_sources:
                        # Uses the training batch size; the --test_batch_size CLI option is disabled.
                        preview_batches = source_dataset.batch(global_batch_size).take(4)
                        for idx, (co_data, cl_data) in enumerate(preview_batches):
                            test_returns = test_step(test_networks, co_data, cl_data, config)
                            write_images((test_returns['xa'], test_returns['xr'], test_returns['xt'], test_returns['xb']),
                                         test_returns['display_list'],
                                         os.path.join(output_dir, name_pattern % (key_str, idx)),
                                         image_size)

                # TODO: periodic snapshots every config['snapshot_save_iter'] steps
                # (e.g. via tf.train.CheckpointManager) are not implemented yet.

                iteration += 1
                if iteration > max_iter:
                    print("End of iteration")
                    finished = True
                    break
            if finished:
                # Leave the epoch loop as well. A plain break is used instead of
                # sys.exit(0), which raises SystemExit inside a Jupyter kernel.
                break

Data Loader
	Root: ../FUNIT/datasets/animals/
	List: ../FUNIT/datasets/animals_list_train.txt
	Number of classes: 119
Instructions for updating:
Deprecated in favor of operator or tf.math.divide.
Data Loader
	Root: ../FUNIT/datasets/animals/
	List: ../FUNIT/datasets/animals_list_train.txt
	Number of classes: 119
Data Loader
	Root: ../FUNIT/datasets/animals/
	List: ../FUNIT/datasets/animals_list_test.txt
	Number of classes: 30
Data Loader
	Root: ../FUNIT/datasets/animals/
	List: ../FUNIT/datasets/animals_list_test.txt
	Number of classes: 30
epoch 1: 
INFO:tensorflow:batch_all_reduce: 48 all-reduces with algorithm = nccl, num_packs = 1, agg_small_grads_max_bytes = 0 and agg_small_grads_max_group = 10
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:batch_all_reduce: 50 all-reduces with algorithm = nccl, num_packs = 1, agg_small_grads_max_bytes = 0 and agg_small_grads_max_group = 10


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping i

 (101/100000) G_loss: 142.9775, D_loss: 1.6515, time: 0.47674
 (102/100000) G_loss: 149.5659, D_loss: 1.0895, time: 0.46046
 (103/100000) G_loss: 138.6330, D_loss: 1.5499, time: 0.46237
 (104/100000) G_loss: 156.2024, D_loss: 1.0958, time: 0.46056
 (105/100000) G_loss: 150.5532, D_loss: 1.8694, time: 0.45712
 (106/100000) G_loss: 152.6653, D_loss: 1.3480, time: 0.45766
 (107/100000) G_loss: 140.5617, D_loss: 0.8225, time: 0.46278
 (108/100000) G_loss: 150.3112, D_loss: 1.1982, time: 0.46022
 (109/100000) G_loss: 159.9258, D_loss: 1.2904, time: 0.45733
 (110/100000) G_loss: 148.5028, D_loss: 1.3268, time: 0.45877
 (111/100000) G_loss: 138.3707, D_loss: 1.7226, time: 0.45809
 (112/100000) G_loss: 141.7455, D_loss: 1.4467, time: 0.45872
 (113/100000) G_loss: 132.5948, D_loss: 0.5914, time: 0.45863
 (114/100000) G_loss: 170.2041, D_loss: 0.6984, time: 0.45992
 (115/100000) G_loss: 162.4535, D_loss: 0.6633, time: 0.46264
 (116/100000) G_loss: 149.2420, D_loss: 4.6778, time: 0.45981
 (117/10

KeyboardInterrupt: 

In [None]:
# List of working processes of output.
# Reference: distributed training loop split into separate step functions:
# -- https://github.com/tensorflow/examples/blob/master/tensorflow_examples/models/nmt_with_attention/distributed_train.py

In [6]:
# Scratch random-normal inputs, presumably for trying out individual
# blocks interactively: one 4-D tensor [64, 128, 128, 3] and one 2-D tensor [2, 128].
test1, test2 = (
    tf.random.normal([64, 128, 128, 3]),
    tf.random.normal([2, 128]),
)

In [4]:
# block = ContentEncoder(3,2,64,'in','relu','reflect')
# block = ClassEncoder(4,64,64,'none','relu','reflect')
# block = Decoder(3,2,32,3,'relu','reflect')
# block = MLP(32,256,3,'relu')