In [1]:
import tensorflow as tf
# Show the installed TensorFlow version (recorded run: '2.0.0').
tf.__version__

'2.0.0'

In [18]:
# limit GPU growth: ask TensorFlow to allocate GPU memory on demand instead of
# mapping (nearly) all of it when the first GPU op runs.
for gpu in tf.config.experimental.list_physical_devices('GPU'):
    try:
        tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as err:
        # Raised when the GPUs were already initialized; memory growth must be
        # configured before any GPU work happens (run this cell right after import).
        print(err)

In [9]:
from datasets import get_dataset
from blocks import *
from losses import *
from utils import show

In [2]:
# config
# Training split: image root directory and the list file naming its images.
data_loader_train = "../FUNIT/datasets/animals/"
data_list_train = "../FUNIT/datasets/animals_list_train.txt"
# Test split paths are not set yet.
data_loader_test, data_list_test = "", ""
# Image sizes, given as (height, width).
crop_size = (128, 128)
resize_size = (140, 140)

In [None]:
import argparse
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--config',
                        type=str,
                        default='../FUNIT/configs/funit_animals.yaml',
                        help='configuration file for training and testing')
    parser.add_argument('--batch_size', type=int, default=0)
    # parse_known_args instead of parse_args: inside a Jupyter kernel sys.argv holds
    # the kernel's own flags (e.g. `-f <connection file>`), which parse_args() rejects
    # by exiting. Unknown arguments are ignored here.
    opts, _unknown_args = parser.parse_known_args()
    # NOTE(review): only `get_dataset` (singular) is imported explicitly above;
    # `get_config`, `FUNIT` and `get_datasets` must come from a star import — confirm.
    config = get_config(opts.config)

    # NOTE(review): 'max_iter' is used as an epoch count here — confirm that is intended.
    epochs = config['max_iter']
    if opts.batch_size != 0:
        # 0 (the default) keeps the batch size from the config file.
        config['batch_size'] = opts.batch_size

    # Networks
    networks = FUNIT(config)

    # Datasets: entries 0-3 are train content, train class, test content, test class.
    datasets = get_datasets(config)
    train_content_dataset = datasets[0]
    train_class_dataset = datasets[1]
    test_content_dataset = datasets[2]
    test_class_dataset = datasets[3]
    train_dataset = tf.data.Dataset.zip((train_content_dataset, train_class_dataset))
    # Previously zipped the *train* datasets here; the test split uses the test datasets.
    test_dataset = tf.data.Dataset.zip((test_content_dataset, test_class_dataset))

    # NOTE(review): train_step is defined in a later cell; on a fresh kernel that cell
    # must be executed before this loop runs.
    for epoch in range(epochs):
        for (co_data, cl_data) in train_dataset:
            # train_step's first parameter is the network container; it was missing.
            train_step(networks, co_data, cl_data, config)

    # NOTICED - need to check correctness of AdaIN
    # layer = AdaptiveInstanceNorm2D()

In [None]:
# @tf.function
# NOTE(review): the `.numpy()` calls at the end only work in eager mode; move them
# to the caller before enabling tf.function.
def train_step(nets, co_data, cl_data, config):
    """Run one joint generator/discriminator update and return both losses.

    Args:
        nets: network container. Uses ``gen_update``, ``dis_update``, ``dis``,
            ``gen.trainable_variables``, ``dis.trainable_variables``,
            ``opt_gen`` and ``opt_dis``.
        co_data: pair unpacked as ``(xa, la)``; presumably content images and
            their labels.
        cl_data: pair unpacked as ``(xb, lb)``; presumably class images and
            their labels.
        config: mapping providing the loss weights ``'gan_w'``, ``'r_w'`` and
            ``'fm_w'``.

    Returns:
        dict with keys ``'G_loss'`` and ``'D_loss'`` holding the ``.numpy()``
        values of the two losses.
    """
    xa, la = co_data
    xb, lb = cl_data
    with tf.GradientTape() as g_tape, tf.GradientTape() as d_tape:
        xt, xr, xa_gan_feat, xb_gan_feat = nets.gen_update(co_data, cl_data, config)
        # The discriminator pass's translated image is kept as `xt_fake`. Unpacking it
        # into `xt` used to overwrite the generator's output above, which the
        # generator losses below are meant to use.
        resp_real, real_gen_feat, xt_fake, resp_fake, fake_gan_feat = \
            nets.dis_update(co_data, cl_data, config)

        # Generator - GAN loss
        # NOTE(review): these receive images rather than discriminator responses;
        # confirm GANloss runs the discriminator internally.
        l_adv_t = GANloss.gen_loss(xt, lb)
        l_adv_r = GANloss.gen_loss(xr, la)
        l_adv = 0.5 * (l_adv_t + l_adv_r)
        # Generator - Reconstruction loss
        l_x_rec = recon_loss(xr, xa)
        # Generator - Feature Matching loss
        _, xr_gan_feat = nets.dis(xr, la)
        _, xt_gan_feat = nets.dis(xt, lb)
        l_c_rec = featmatch_loss(xr_gan_feat, xa_gan_feat)
        l_m_rec = featmatch_loss(xt_gan_feat, xb_gan_feat)

        # The GAN term is `l_adv`; the previous `l_dav` raised NameError.
        G_loss = (config['gan_w'] * l_adv
                  + config['r_w'] * l_x_rec
                  + config['fm_w'] * (l_c_rec + l_m_rec))

        # Discriminator - GAN loss
        l_real = GANloss.dis_loss(xb, lb, 'real')
        l_fake = GANloss.dis_loss(xt_fake, lb, 'fake')
        # Discriminator - Gradient Penalty
        # TODO: the reference appears to use autograd.grad(...)[0] here; confirm
        # gradient_penalty computes the same quantity.
        l_reg = gradient_penalty(resp_real, xb)

        # The gradient-penalty weight is hard-coded to 10.
        D_loss = (config['gan_w'] * l_real
                  + config['gan_w'] * l_fake
                  + 10 * l_reg)

    # Update Gradient. Each (non-persistent) tape is queried exactly once.
    gen_vars = nets.gen.trainable_variables
    dis_vars = nets.dis.trainable_variables
    gen_grad = g_tape.gradient(G_loss, gen_vars)
    dis_grad = d_tape.gradient(D_loss, dis_vars)
    nets.opt_gen.apply_gradients(zip(gen_grad, gen_vars))
    # Discriminator gradients go to the discriminator's optimizer; this previously
    # used opt_gen. NOTE(review): assumes the network container exposes `opt_dis`.
    nets.opt_dis.apply_gradients(zip(dis_grad, dis_vars))

    return {'G_loss': G_loss.numpy(), 'D_loss': D_loss.numpy()}

In [None]:
#================================ All Outputs
# First, test by rendering Dataset output with matplotlib
import matplotlib.pyplot as plt
# Constrained Layout Guide - https://matplotlib.org/3.1.1/tutorials/intermediate/constrainedlayout_guide.html
# Customizing Figure Layouts - https://matplotlib.org/3.1.1/tutorials/intermediate/gridspec.html
# matplotlib.pyplot.figure - https://matplotlib.org/3.1.1/api/_as_gen/matplotlib.pyplot.figure.html
def write_images(images, display_labels):
    """Placeholder for drawing dataset images with matplotlib.

    Not implemented yet. The original cell declared this signature without a body,
    which made the whole cell an IndentationError. Raising keeps the cell runnable
    and makes any premature call fail loudly.

    Args:
        images: images to draw (format not decided yet).
        display_labels: labels to show alongside the images (format not decided yet).

    Raises:
        NotImplementedError: always, until the plotting code is written.
    """
    raise NotImplementedError("write_images is not implemented yet")
#================================ TensorBoard (later)
#================================ Check Points (later)
#================================ Distributed Training (later)

In [5]:
# Random dummy inputs for smoke-testing the blocks below: one rank-4 tensor and
# one rank-2 tensor (no seed is set, so values differ between runs).
test1 = tf.random.normal(shape=(2, 25, 25, 32))
test2 = tf.random.normal(shape=(2, 128))

In [4]:
# block = ContentEncoder(3,2,64,'in','relu','reflect')
# block = ClassEncoder(4,64,64,'none','relu','reflect')
# block = Decoder(3,2,32,3,'relu','reflect')
# block = MLP(32,256,3,'relu')

In [6]:
# NOTE(review): `block` is only assigned in the commented-out lines of the previous
# cell, so this raises NameError on a fresh kernel; uncomment one of them first.
result = block(test2)