In [1]:
import tensorflow as tf

# TensorFlow version in use; the training code below relies on the 2.0-era
# experimental tf.distribute APIs (experimental_run_v2,
# experimental_distribute_datasets_from_function).
tf.version.VERSION

'2.0.0'

In [2]:
# Turn on memory growth for every visible GPU, so TensorFlow claims GPU memory
# as it is needed rather than reserving all of it at start-up. TensorFlow only
# accepts this setting before the GPUs have been initialised, which is why it
# runs at the top of the notebook. Prints the GPU count and refuses to continue
# without at least one GPU.
physical_devices = tf.config.experimental.list_physical_devices(device_type='GPU')
print(len(physical_devices))
assert physical_devices, 'Not enough GPU hardware devices available'
for physical_device in physical_devices:
    tf.config.experimental.set_memory_growth(physical_device, enable=True)

4


In [2]:
import argparse
from run_step import *
from datasets import *
from containers import *
from losses import *
from utils import *

In [None]:
if __name__ == "__main__":
    # Command-line parsing is disabled while this runs inside a notebook; the
    # argparse setup is kept (commented out) for when it becomes a script again.
#     parser = argparse.ArgumentParser()
#     parser.add_argument('--config',
#                         type=str,
#                         default='../FUNIT/configs/funit_animals.yaml',
#                         help='configuration file for training and testing')
#     parser.add_argument('--batch_size', type=int, default=0)
#     parser.add_argument('--output_path',
#                         type=str,
#                         default='.',
#                         help="outputs path")
#     parser.add_argument('--test_batch_size',
#                          type=int,
#                          default=4)
#     opts = parser.parse_args()
#     config = get_config(opts.config)
    config = get_config('./configs/funit_animals.yaml')

#     if opts.batch_size != 0:
#         config['batch_size'] = opts.batch_size

    # Strategy: synchronous data parallelism over the visible GPUs.
    strategy = tf.distribute.MirroredStrategy()
    # num_gpus_available = strategy.num_replicas_in_sync

    # Datasets
    # NOTE(review): the indices below assume get_datasets returns
    # (train_content, train_class, test_content, test_class); the Data Loader
    # log printed by this cell (train, train, test, test) is consistent with
    # that ordering — confirm against get_datasets.
    datasets = get_datasets(config)
    # - Train: each element pairs one content sample with one class sample.
    train_content_dataset = datasets[0]
    train_class_dataset = datasets[1]
    train_dataset = tf.data.Dataset.zip((train_content_dataset, train_class_dataset))
    global_batch_size = config['batch_size']

    def train_ds_fn(input_context):
        """Build one input pipeline's share of the training data.

        Batches ``train_dataset`` with the per-replica batch size derived from
        ``global_batch_size`` (``get_per_replica_batch_size`` raises
        ``ValueError`` when the global batch is not divisible by the number of
        replicas), then shards the batches across input pipelines. With a
        single pipeline the shard keeps every batch.
        """
        batch_size = input_context.get_per_replica_batch_size(global_batch_size)
        d = train_dataset.batch(batch_size)
        return d.shard(input_context.num_input_pipelines, input_context.input_pipeline_id)

    dist_train_dataset = strategy.experimental_distribute_datasets_from_function(train_ds_fn)
    # - Test (disabled for now)
#     test_content_dataset = datasets[2]
#     test_class_dataset = datasets[3]
#     test_dataset = tf.data.Dataset.zip((test_content_dataset, test_class_dataset))
#     def test_ds_fn(input_context):
#         batch_size = input_context.get_per_replica_batch_size(global_batch_size)
#         d = test_dataset.batch(batch_size)
#         return d.shard(input_context.num_input_pipelines, input_context.input_pipeline_id)
#     dist_test_dataset = strategy.experimental_distribute_datasets_from_function(test_ds_fn)

    # Running means of the two losses, one Keras metric per name.
    # NOTE(review): metrics_list is not updated by the live training loop; it
    # is only referenced from commented-out code.
    lossnames = ["G_loss","D_loss"]
    metrics_list = [tf.keras.metrics.Mean(itemname, dtype=tf.float32) for itemname in lossnames]

    # NOTE(review): the config key is `max_iter`, which reads like an iteration
    # count, but it is used as the number of full passes over
    # dist_train_dataset — confirm which is intended.
    EPOCHS = config['max_iter']
    # Networks: built under the strategy scope so their variables are mirrored
    # across replicas.
    with strategy.scope():
        networks = FUNIT(config)

        def gen_train_step(x, config):
            """Run one generator update on a single replica.

            Args:
                x: a ``(co_data, cl_data)`` pair from the zipped training
                    dataset; each half is unpacked as a two-tuple (presumably
                    images and class labels — confirm against get_datasets).
                config: training config dict; reads ``batch_size``, ``gan_w``,
                    ``r_w`` and ``fm_w`` here and is forwarded to
                    ``networks.gen_update``.

            Returns:
                The weighted generator loss ``G_loss`` for this replica, before
                the extra ``1 / batch_size`` factor applied to the value that is
                differentiated.
            """
            co_data, cl_data = x
            xa, la = co_data
            xb, lb = cl_data
            # Local copy of the batch size; same config key as the notebook-level
            # global_batch_size.
            global_batch_size = config['batch_size']
            with tf.GradientTape() as g_tape:
                # Presumably: translated image, reconstruction of xa, and features
                # of xa / xb — inferred from the names only, confirm in FUNIT.
                xt_g, xr, xa_gan_feat, xb_gan_feat = networks.gen_update(co_data,cl_data,config)
                
                # Score the reconstruction against the content labels and the
                # translation against the class labels; both calls also return
                # features used by the feature-matching terms below.
                resp_xr_fake, xr_gan_feat = networks.dis(xr, la)
                resp_xt_fake, xt_gan_feat = networks.dis(xt_g, lb)
                
                # Generator - GAN loss
                l_adv_t = GANloss.gen_loss(resp_xt_fake,lb)
                l_adv_r = GANloss.gen_loss(resp_xr_fake,la)
                l_adv = 0.5 * (l_adv_t + l_adv_r)
                # Generator - Reconstruction loss
                l_x_rec = recon_loss(xr, xa)
                l_x_rec = tf.reduce_sum(l_x_rec) / global_batch_size
                # Generator - Feature Matching loss
                l_c_rec = featmatch_loss(xr_gan_feat, xa_gan_feat)
                l_c_rec = tf.reduce_sum(l_c_rec) / global_batch_size
                l_m_rec = featmatch_loss(xt_gan_feat, xb_gan_feat)
                l_m_rec = tf.reduce_sum(l_m_rec) / global_batch_size
                
                G_loss = config['gan_w'] * l_adv +\
                         config['r_w'] * l_x_rec +\
                         config['fm_w'] * (l_c_rec + l_m_rec)
                
                # NOTE(review): the reconstruction and feature-matching terms were
                # already divided by global_batch_size above, so this second
                # division scales their gradient by 1 / global_batch_size**2. The
                # TF custom-training guide divides a per-example-summed loss by the
                # global batch size exactly once. The right factor for l_adv
                # depends on whether GANloss.gen_loss returns a sum or a mean over
                # the replica batch — confirm before changing either division.
                loss = G_loss * (1.0 / global_batch_size)
            gen_grad = g_tape.gradient(loss, networks.gen.trainable_variables)
            networks.opt_gen.apply_gradients(zip(gen_grad, networks.gen.trainable_variables))
            return G_loss
        
        def dis_train_step(x, config):
            """Run one discriminator update on a single replica.

            Args:
                x: a ``(co_data, cl_data)`` pair from the zipped training
                    dataset; each half is unpacked as a two-tuple (presumably
                    images and class labels — confirm against get_datasets).
                config: training config dict; reads ``gan_w`` and
                    ``batch_size`` here and is forwarded to
                    ``networks.dis_update``.

            Returns:
                The weighted discriminator loss ``D_loss`` for this replica,
                before the ``1 / batch_size`` factor applied to the value that
                is differentiated.
            """
            co_data, cl_data = x
            # xa and la are unpacked but not used below; the losses use only the
            # class half (xb, lb).
            xa, la = co_data
            xb, lb = cl_data
            with tf.GradientTape() as d_tape:
                # Only resp_real and resp_fake are used; the other outputs of
                # dis_update are ignored in this step.
                resp_real, real_gen_feat, xt_d, resp_fake, fake_gan_feat =\
                                                    networks.dis_update(co_data,cl_data,config)
                # Discriminator - GAN loss
                # Both terms are scored with the class labels lb (presumably
                # because the fake batch is translated towards class b — confirm).
                l_real = GANloss.dis_loss(resp_real, lb, 'real')
                l_fake = GANloss.dis_loss(resp_fake, lb, 'fake')
                # Discriminator - Gradient Penalty, computed on the class batch (xb, lb)
                l_reg = gradient_penalty(networks.dis, xb, lb)

                # NOTE(review): 10 is a hard-coded gradient-penalty weight; consider
                # moving it into the config next to gan_w.
                D_loss = config['gan_w'] * l_real +\
                         config['gan_w'] * l_fake +\
                         10 * l_reg
                # NOTE(review): this 1 / batch_size factor is only the correct
                # per-replica scaling if GANloss.dis_loss and gradient_penalty
                # return sums over the replica batch; if they return means it
                # over-shrinks the gradient — confirm.
                loss = D_loss * (1.0 / config['batch_size'])
            dis_grad = d_tape.gradient(loss, networks.dis.trainable_variables)
            networks.opt_dis.apply_gradients(zip(dis_grad, networks.dis.trainable_variables))
            return D_loss
    
    # Distributed training loop, modelled on the TF 2.0 tutorial
    # "Custom training with tf.distribute.Strategy":
    # https://www.tensorflow.org/tutorials/distribute/custom_training
    with strategy.scope():
        
        @tf.function
        def distributed_train_step(dataset_inputs, config):
            """Run one discriminator update, then one generator update, on every replica.

            Args:
                dataset_inputs: one distributed batch from ``dist_train_dataset``.
                config: training config dict, forwarded to the per-replica steps.

            Returns:
                ``(gen_loss, dis_loss)``: each replica's returned loss averaged
                across replicas with ``ReduceOp.MEAN``.
            """
            # NOTE(review): the tutorial divides per-replica losses by the global
            # batch size and combines them with ReduceOp.SUM. Terms that
            # gen_train_step already divides by global_batch_size are therefore
            # under-reported here by a factor of num_replicas_in_sync. This only
            # affects the printed values, not the gradients.
            dis_per_replica_losses = strategy.experimental_run_v2(dis_train_step, args=(dataset_inputs, config))
            dis_loss = strategy.reduce(tf.distribute.ReduceOp.MEAN, dis_per_replica_losses, axis=None)
            
            gen_per_replica_losses = strategy.experimental_run_v2(gen_train_step, args=(dataset_inputs, config))
            gen_loss = strategy.reduce(tf.distribute.ReduceOp.MEAN, gen_per_replica_losses, axis=None)
            return gen_loss, dis_loss
        
        for epoch in range(EPOCHS):
            print("epoch %d: " % epoch)
            for x in dist_train_dataset:
                G_loss, D_loss = distributed_train_step(x, config)
                # '\r' rewrites a single progress line throughout the epoch.
                print(" G_loss: %.4f, D_loss: %.4f" % (G_loss,D_loss), end='\r')
            # Terminate the progress line so the epoch's last losses stay
            # readable instead of being overwritten by the next "epoch N:" header.
            print()

Data Loader
	Root: ../FUNIT/datasets/animals/
	List: ../FUNIT/datasets/animals_list_train.txt
	Number of classes: 119
Instructions for updating:
Deprecated in favor of operator or tf.math.divide.
Data Loader
	Root: ../FUNIT/datasets/animals/
	List: ../FUNIT/datasets/animals_list_train.txt
	Number of classes: 119
Data Loader
	Root: ../FUNIT/datasets/animals/
	List: ../FUNIT/datasets/animals_list_test.txt
	Number of classes: 30
Data Loader
	Root: ../FUNIT/datasets/animals/
	List: ../FUNIT/datasets/animals_list_test.txt
	Number of classes: 30
epoch 0: 
INFO:tensorflow:batch_all_reduce: 48 all-reduces with algorithm = nccl, num_packs = 1, agg_small_grads_max_bytes = 0 and agg_small_grads_max_group = 10
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:batch_all_reduce: 50 all-reduces with algorithm = nccl, num_packs = 1, agg_small_grads_max_bytes = 0 and agg_small_grads_max_group = 10


In [None]:
    # NOTE(review): stale, non-distributed draft of the training loop, superseded
    # by the distributed loop in the previous cell. It could not run as written:
    # the cell body is indented (IndentationError in a fresh kernel), it loops
    # over `epochs` while the previous cell defines `EPOCHS`, and the disabled
    # image-dump section uses `opts` and `test_dataset`, whose definitions are
    # commented out. It also hands unbatched `train_dataset` elements to
    # `strategy.experimental_run_v2` outside any tf.function. The draft is kept
    # as comments for reference, so Restart & Run All does not stop at this cell.
    #
    # for epoch in range(epochs):
    #     print("epoch %d:" % epoch)
    #     for (co_data, cl_data) in train_dataset:
    #         dis_per_replica_result = strategy.experimental_run_v2(train_dis, args=(networks, co_data, cl_data, config))
    #         dis_loss = strategy.reduce(tf.distribute.ReduceOp.SUM,
    #                                     dis_per_replica_result,
    #                                     axis=None)
    #         gen_per_replica_result = strategy.experimental_run_v2(train_gen, args=(networks, co_data, cl_data, config))
    #         gen_loss = strategy.reduce(tf.distribute.ReduceOp.SUM,
    #                                     gen_per_replica_result,
    #                                     axis=None)
    #         print(" G_loss: %.4f, D_loss: %.4f" % (gen_loss,dis_loss), end='\r')
    #
    # #         train_returns = train_step(networks,co_data,cl_data,config)
    # #         print(" G_loss: %.4f, D_loss: %.4f" % (train_returns['G_loss'],train_returns['D_loss']), end='\r')
    # #         for idx, itemname in enumerate(lossnames):
    # #             metrics_list[idx](train_returns[itemname])
    #
    # #     for idx, itemname in enumerate(lossnames):
    # #         print("    {}: {:.4f}".format(itemname,metrics_list[idx].result()))
    # #         metrics_list[idx].reset_states()
    #
    #     '''if epoch % config['image_save_iter'] == 0 or\
    #         epoch % config['image_display_iter'] == 0:
    #         if epoch % config['image_save_iter'] == 0:
    #             key_str = '%08d' % (epoch + 1)
    #         else:
    #             key_str = 'current'
    #         output_train_dataset = train_dataset.take(opts.test_batch_size)
    #         output_test_dataset = test_dataset.take(opts.test_batch_size)
    #         for idx, (co_data, cl_data) in output_train_dataset.enumerate():
    #             test_returns = test_step(co_data,cl_data,config)
    #             write_images(zip(test_returns['xa'],test_returns['xr'],test_returns['xt'],test_returns['xb']),
    #                          test_returns['display_list'],
    #                          'train_%s_%02d' % (key_str, idx),
    #                          max(config['crop_image_height'], config['crop_image_width']))
    #         for idx, (co_data, cl_data) in output_test_dataset.enumerate():
    #             test_returns = test_step(co_data,cl_data,config)
    #             write_images(zip(test_returns['xa'],test_returns['xr'],test_returns['xt'],test_returns['xb']),
    #                          test_returns['display_list'],
    #                          'test_%s_%02d' % (key_str, idx),
    #                          max(config['crop_image_height'], config['crop_image_width']))'''

In [None]:
# Planned work, not implemented yet:
#================================ TensorBoard (later)
#================================ Check Points (later)
#================================ Distributed Training (later)

In [6]:
# Scratch tensors of standard-normal noise, presumably for shape-checking the
# network blocks interactively (no seed is set, so values differ per run):
# a 4-D batch of 64 x 128 x 128 x 3 and a 2-D batch of two 128-vectors.
test1 = tf.random.normal(shape=(64, 128, 128, 3))
test2 = tf.random.normal(shape=(2, 128))

In [4]:
# block = ContentEncoder(3,2,64,'in','relu','reflect')
# block = ClassEncoder(4,64,64,'none','relu','reflect')
# block = Decoder(3,2,32,3,'relu','reflect')
# block = MLP(32,256,3,'relu')