In [1]:
import tensorflow as tf
from trainer.config import config
from trainer.utils import dataset
from trainer import utils
from trainer.models import networks
from trainer import models
from trainer import callbacks

# Output locations taken from the parsed run configuration.
LOG_DIR, MODEL_DIR = config.job_dir, config.model_dir

# Load data (build your custom data loader and replace below).
(train_horses, train_zebras,
 test_horses, test_zebras) = dataset.generate_dataset()

# NOTE(review): hard-coded rather than derived from the loader; presumably the
# number of training examples per domain — confirm whenever the dataset changes.
dataset_count = 1000
# Build the sub-networks. All four are constructed for RGB images of the
# configured size (config.in_h x config.in_w, 3 channels).
img_shape = (config.in_h, config.in_w, 3)

# Generators, named by translation direction (presumably A->B and B->A).
g_AB = networks.create_generator(shape=img_shape)
g_BA = networks.create_generator(shape=img_shape)

# Discriminators, presumably one per image domain.
d_A = networks.create_discriminator(shape=img_shape)
d_B = networks.create_discriminator(shape=img_shape)

# Wrap the four networks in the CycleGAN model.
# NOTE(review): the wrapper is given an unconstrained (None, None, 3) shape while
# the sub-networks are built for img_shape — confirm this mismatch is intended.
model = models.CycleGAN(
    shape=(None, None, 3),
    g_AB=g_AB,
    g_BA=g_BA,
    d_A=d_A,
    d_B=d_B,
)

# Compile the CycleGAN. The six generator-side losses and weights are matched
# positionally to the combined model's outputs — presumably (adversarial A,
# adversarial B, cycle A, cycle B, identity A, identity B); TODO confirm against
# models.CycleGAN.
# All loss weights come from the parsed config, so overriding disc_loss /
# cycle_consistency_loss / id_loss on the command line actually takes effect
# (previously the adversarial and identity weights were hard-coded to 1).
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5),
    d_loss='mse',
    g_loss=[
        'mse', 'mse',  # presumably adversarial terms
        'mae', 'mae',  # cycle-consistency terms
        'mae', 'mae',  # presumably identity terms
    ],
    loss_weights=[
        config.disc_loss, config.disc_loss,
        config.cycle_consistency_loss, config.cycle_consistency_loss,
        config.id_loss, config.id_loss,
    ],
    metrics=[utils.ssim, utils.psnr, utils.mae, utils.mse],
)

Unknown args: ['-f', '/Users/chriszhou/Library/Jupyter/runtime/kernel-865f341a-55b4-4025-9863-c27b4cface8d.json']
Parsed args: {'bs': 1, 'in_h': 256, 'in_w': 256, 'epochs': 20, 'm': True, 'is_test': False, 'cycle_consistency_loss': 10, 'disc_loss': 1, 'id_loss': 1, 'buffer_size': 1000, 'job_dir': 'gs://duke-bme590-cz/ds-cyclegan/tmp/1574131206.8193479', 'model_dir': './trained_models'}


In [3]:
# TensorBoard logging into LOG_DIR, written once per epoch.
tensorboard = tf.keras.callbacks.TensorBoard(log_dir=LOG_DIR, write_graph=True, update_freq='epoch')
start_tensorboard = callbacks.StartTensorBoard(LOG_DIR)

prog_bar = tf.keras.callbacks.ProgbarLogger(count_mode='steps', stateful_metrics=None)
# Copies ./trainer alongside the logs (presumably for reproducibility — confirm in callbacks.LogCode).
log_code = callbacks.LogCode(LOG_DIR, './trainer')
copy_keras = callbacks.CopyKerasModel(MODEL_DIR, LOG_DIR)

# Checkpoint all four sub-networks, keeping only the best val_ssim (higher is better);
# presumably mirrors Keras ModelCheckpoint semantics.
saving = callbacks.MultiModelCheckpoint(MODEL_DIR + '/model.{epoch:02d}-{val_ssim:.10f}.hdf5',
                                        monitor='val_ssim', verbose=1, freq='epoch', mode='max', save_best_only=True,
                                        multi_models=[('g_AB', g_AB), ('g_BA', g_BA), ('d_A', d_A), ('d_B', d_B)])

# Halve the learning rate of each trained model when val_ssim plateaus for 3 epochs,
# never going below min_lr (presumably mirrors Keras ReduceLROnPlateau).
reduce_lr = callbacks.MultiReduceLROnPlateau(training_models=[model.d_A, model.d_B, model.combined],
                                             monitor='val_ssim', mode='max', factor=0.5, patience=3, min_lr=0.000002)

# NOTE(review): early stopping is currently disabled; re-enable or delete this block.
# early_stopping = callbacks.MultiEarlyStopping(multi_models=[g_AB, g_BA, d_A, d_B], full_model=model,
#                                               monitor='val_ssim', mode='max', patience=1,
#                                               restore_best_weights=True, verbose=1)

# Sample-image generation. `interval` is presumably counted in batches, so
# dataset_count // bs ~ once per epoch. Floor division keeps it an exact int, and
# max(1, ...) avoids a zero interval when the batch size exceeds dataset_count.
image_gen = callbacks.GenerateImages(g_AB, test_horses, test_zebras, LOG_DIR,
                                     interval=max(1, dataset_count // config.bs))

# TODO: fit the model — this cell only builds the callbacks; no training call is made here.

In [5]:
# Pull one element from the training set for interactive inspection.
# Unlike `for x in ...: break`, this raises StopIteration on an empty dataset
# instead of silently leaving `x` undefined or stale from earlier kernel state.
x = next(iter(train_horses))
    


In [2]:
# Print the layer-by-layer architecture and parameter counts of generator g_AB.
g_AB.summary()

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 256, 256, 3) 0                                            
__________________________________________________________________________________________________
reflection_padding2d (Reflectio multiple             0           input_1[0][0]                    
__________________________________________________________________________________________________
conv2d (Conv2D)                 multiple             9472        reflection_padding2d[0][0]       
__________________________________________________________________________________________________
batch_normalization (BatchNorma multiple             256         conv2d[0][0]                     
______________________________________________________________________________________________