In [1]:
from model_components import *
from data_utils import *

%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import time
import functools
import tensorflow as tf
import glob
import cv2
import logging
import os

from sklearn.model_selection import train_test_split
from sklearn.utils import shuffle

## Dataset Pre-Processing

#### Set paths for unprocessed data

In [2]:
# Raw dataset root and the per-class image sub-directories inside it.
dataset_dir = 'datasets/magnetic_tiles'
(free_images_dir,
 blowhole_images_dir,
 break_images_dir,
 crack_images_dir) = ('MT_Free/Imgs', 'MT_Blowhole/Imgs', 'MT_Break/Imgs', 'MT_Crack/Imgs')

# Root for the rescaled 256x256 copies; the same class sub-directories are reused under it.
new_dataset_dir = 'datasets/rescaled_magnetic_tiles'


#### Get file names of the raw images and their ground-truth (gt) images

In [3]:
# One (image names, ground-truth names) pair per class, read from the raw dataset.
(
    (free_img_names, free_img_gt_names),
    (blowhole_img_names, blowhole_img_gt_names),
    (break_img_names, break_img_gt_names),
    (crack_img_names, crack_img_gt_names),
) = [get_img_and_gt_names(dataset_dir, class_dir)
     for class_dir in (free_images_dir, blowhole_images_dir, break_images_dir, crack_images_dir)]

#### Crop, rescale to 256x256 and save the defect images together with their ground-truth images

In [4]:
# Crop/rescale every crack image together with its ground truth and write the
# results under the rescaled dataset root.
inp, op = crop_and_save_with_gt(
    crack_img_names,
    crack_img_gt_names,
    save_dir=os.path.join(new_dataset_dir, crack_images_dir),
    gt_thresh=0.5,
    max_attempts_per_image=100,
    random_crop_range=50,
    scale_size=256,
)

# Conversion rate = crops written per source image.
print("Image Names: {}, Created {}, Conversion Rate {:.2f}".format(inp, op, op / inp))

Image Names: 57, Created 214, Conversion Rate 3.75


In [5]:
# Crop/rescale every blowhole image together with its ground truth and write the
# results under the rescaled dataset root.
inp, op = crop_and_save_with_gt(
    blowhole_img_names,
    blowhole_img_gt_names,
    save_dir=os.path.join(new_dataset_dir, blowhole_images_dir),
    gt_thresh=0.5,
    max_attempts_per_image=100,
    random_crop_range=50,
    scale_size=256,
)

# Conversion rate = crops written per source image.
print("Image Names: {}, Created {}, Conversion Rate {:.2f}".format(inp, op, op / inp))

Image Names: 115, Created 429, Conversion Rate 3.73


In [6]:
# Crop/rescale every break image together with its ground truth and write the
# results under the rescaled dataset root.
inp, op = crop_and_save_with_gt(
    break_img_names,
    break_img_gt_names,
    save_dir=os.path.join(new_dataset_dir, break_images_dir),
    gt_thresh=0.5,
    max_attempts_per_image=100,
    random_crop_range=50,
    scale_size=256,
)

# Conversion rate = crops written per source image.
print("Image Names: {}, Created {}, Conversion Rate {:.2f}".format(inp, op, op / inp))

Image Names: 85, Created 283, Conversion Rate 3.33


#### Process and save free images (256x256)

In [7]:
# Defect-free images have no ground truth, so they go through the image-only
# scale-and-crop helper; image_count=None is passed through unchanged.
inp, op = scale_and_random_crop(
    free_img_names,
    save_dir=os.path.join(new_dataset_dir, free_images_dir),
    image_count=None,
    random_crop_range=50,
    scale_size=256,
)

# Conversion rate = crops written per source image.
print("Image Names: {}, Created {}, Conversion Rate {:.2f}".format(inp, op, op / inp))

Image Names: 952, Created 2960, Conversion Rate 3.11


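#### Load processed (256x256) defect image names

The cells from here on only read the rescaled dataset, so once the import and path cells have run they can be executed without repeating the preprocessing above.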

In [3]:
# processed defect images (256x256)
# Image / ground-truth name lists for the processed (256x256) defect classes.
(
    (blowhole_img_names, blowhole_img_gt_names),
    (break_img_names, break_img_gt_names),
    (crack_img_names, crack_img_gt_names),
) = [get_img_and_gt_names(new_dataset_dir, class_dir)
     for class_dir in (blowhole_images_dir, break_images_dir, crack_images_dir)]

#### Load free image names

In [4]:
# glob order depends on the filesystem and train_test_split shuffles with a fresh
# random state unless one is given, so sort the names and fix the seed to get the
# same defect-free train/test split after every kernel restart.
FREE_SPLIT_SEED = 42
free_img_names = sorted(glob.glob(os.path.join(new_dataset_dir, free_images_dir, "*.jpg")))

train_img_names, test_img_names = train_test_split(free_img_names, train_size=2560,
                                                   random_state=FREE_SPLIT_SEED)



#### Final image count

In [5]:
# Sanity check: number of images in every split / class.
print(
    "Free Training Images: {}\nFree Test Images: {}\nBlowhole Test Images: {}"
    "\nBreak Test Images: {}\nCrack Test Images: {}".format(
        *map(len, (train_img_names, test_img_names, blowhole_img_names,
                   break_img_names, crack_img_names))
    )
)

Free Training Images: 2560
Free Test Images: 400
Blowhole Test Images: 427
Break Test Images: 278
Crack Test Images: 213


#### Generators

In [6]:
def training_generator(training_image_names, batch_size):
    """Yield batches of loaded images forever, reshuffling before every pass.

    Args:
        training_image_names: sequence of image paths.
        batch_size: number of paths per batch. The last batch of a pass is shorter
            when the number of paths is not a multiple of ``batch_size``.

    Yields:
        Whatever ``load_and_normalize`` returns for one slice of paths.
    """
    names = training_image_names
    while True:
        # Each pass shuffles the previous pass's ordering.
        names = shuffle(names)
        for start in range(0, len(names), batch_size):
            yield load_and_normalize(names[start:start + batch_size])

In [7]:
def testing_generator(testing_image_names, testing_gt_image_names, test_batch_size):
    """Yield (images, ground truths) batches forever, shuffling both lists together.

    Args:
        testing_image_names: sequence of image paths.
        testing_gt_image_names: ground-truth paths aligned index-by-index with
            ``testing_image_names``.
        test_batch_size: number of pairs per batch (the last batch of a pass may
            be shorter).

    Yields:
        ``(load_and_normalize(image_slice), load_and_normalize(gt_slice))``.
    """
    images, truths = testing_image_names, testing_gt_image_names
    while True:
        # Joint shuffle keeps every image paired with its own ground truth.
        images, truths = shuffle(images, truths)
        for start in range(0, len(images), test_batch_size):
            stop = start + test_batch_size
            image_batch = load_and_normalize(images[start:stop])
            truth_batch = load_and_normalize(truths[start:stop])
            yield image_batch, truth_batch

### Training and Evaluation

In [9]:
def get_summary_dir(checkpoint_dir):
    """Return the path of a new ``logs<N>`` summary directory under ``checkpoint_dir``.

    ``N`` is one more than the largest numeric suffix among existing ``logs<digits>``
    entries, or 0 when there are none. Counting the ``logs*`` matches instead (the
    previous approach) reuses an existing directory once an earlier one has been
    deleted (e.g. logs1 and logs2 present -> logs2 again) and is skewed by
    non-numeric names such as ``logs_old``. For a contiguous logs0..logs{k-1} both
    approaches return ``logs{k}``.

    The directory is not created here.
    """
    prefix = "logs"
    used_indices = []
    for path in glob.glob(os.path.join(checkpoint_dir, prefix + "*")):
        suffix = os.path.basename(os.path.normpath(path))[len(prefix):]
        if suffix.isdecimal():
            used_indices.append(int(suffix))

    next_index = max(used_indices) + 1 if used_indices else 0
    return os.path.join(checkpoint_dir, "{}{}".format(prefix, next_index))

In [10]:
def train(trainx_names, testx_names, 
          blowhole_x_names, blowhole_y_names, 
          crack_x_names, crack_y_names, 
          break_x_names, break_y_names,
          checkpoint_dir="train/train01/"
         ):
    """Build the training and evaluation graphs, then train with periodic evaluation.

    The model is an encoder / generator / discriminator GAN whose discriminator
    scores (image, latent) pairs: real images with their encodings, and generated
    images with the noise ``z`` that produced them. Every ``freq_epoch_test`` epochs
    one batch of held-out defect-free images and one batch of each defect class are
    scored with ``(1 - w) * ||x - G(E(x))||_2 + w * ||f(x) - f(G(E(x)))||_2``, where
    ``f`` is the discriminator's penultimate layer. For ``w = 0.1`` the fraction of
    defect images scoring at or above ``mean + k * std`` of the defect-free batch
    (k = 0..3) is logged as "accuracy".

    Args:
        trainx_names: defect-free image paths used for training.
        testx_names: held-out defect-free image paths (threshold statistics).
        blowhole_x_names, blowhole_y_names: blowhole image / ground-truth paths.
        crack_x_names, crack_y_names: crack image / ground-truth paths.
        break_x_names, break_y_names: break image / ground-truth paths.
        checkpoint_dir: directory for checkpoints and TensorBoard logs
            (defaults to the previously hard-coded location).

    Raises:
        ValueError: if there are fewer training images than one batch.
    """
    penultimate_layer_units = 1024
    latent_dimensions = 200
    batch_size = 16
    learning_rate = 0.00005
    nb_epochs = 500
    freq_epoch_test = 5      # evaluate every `freq_epoch_test` epochs
    num_test_images = 32     # batch size for each defect class during evaluation

    # The generator samples a z batch of exactly `batch_size` rows, and the label ops
    # below mix tensors shaped like the real and the fake discriminator outputs, so
    # every training step needs a full batch.
    nr_batches_train = len(trainx_names) // batch_size
    if nr_batches_train == 0:
        raise ValueError("Need at least {} training images, got {}".format(
            batch_size, len(trainx_names)))

    # ============================== TRAINING GRAPH ==============================

    # Image input placeholder
    x = tf.placeholder(dtype=tf.float32, shape=(None,256,256,3))

    # Ground truth input placeholder (only read by the test image summaries)
    gt = tf.placeholder(dtype=tf.float32, shape=(None,256,256,3))

    # mean and variance of the free image scores, fed at evaluation time
    mean_inp = tf.placeholder(dtype=tf.float32)
    var_inp = tf.placeholder(dtype=tf.float32)

    # NOTE(review): `training_mode` is fed in every sess.run below, but no graph op
    # reads it -- the model builders receive Python booleans instead.
    training_mode = tf.placeholder(dtype=tf.bool)

    # Encoder
    with tf.variable_scope('encoder_model'):
        encoding_real_image = get_encoder_model(x, latent_dimensions=latent_dimensions,
                                             reuse=False, training_mode=True)
    # Generator
    with tf.variable_scope('generator_model'):
        z = tf.random_normal([batch_size, latent_dimensions])
        generated_image = get_generator_model(z, reuse=False, training_mode=True)
        regenerated_real_image = get_generator_model(encoding_real_image, reuse=True, training_mode=False)

    # Discriminator on (image, latent) pairs
    with tf.variable_scope('discriminator_model'):
        discriminator_fake, dis_fake_penultimate_layer = get_discriminator_model(generated_image, z, reuse=False, 
                                                     training_mode=True, penultimate_layer_units=penultimate_layer_units)
        discriminator_real, dis_real_penultimate_layer = get_discriminator_model(x, encoding_real_image, reuse=True,
                                                     training_mode=True, penultimate_layer_units=penultimate_layer_units)

    # Prepare labels for the loss functions
    with tf.variable_scope('labels'):

        # Step 1: swapped labels -- the discriminator targets real ~ 0 and fake ~ 1,
        # while the generator/encoder targets are the opposite.
        labels_dis_enc = tf.zeros_like(discriminator_real)
        labels_dis_gen = tf.ones_like(discriminator_fake)
        labels_gen = tf.zeros_like(discriminator_fake)
        labels_enc = tf.ones_like(discriminator_real)

        # Step 2: soft labels for the discriminator (perturb by U(0, 0.1))
        random_soft = tf.random_uniform(shape=(tf.shape(labels_dis_enc)), minval=0.0, maxval=0.1)
        soft_labels_dis_enc = tf.add(labels_dis_enc, random_soft)
        soft_labels_dis_gen = tf.subtract(labels_dis_gen, random_soft)

        # Step 3: with ~5% probability per batch, swap the real/fake discriminator labels
        random_flip = tf.ones_like(labels_dis_enc) * tf.random_uniform(shape=(1,), minval=0, maxval=1)
        mask = random_flip >= 0.05
        labels_dis_enc = tf.where(mask, soft_labels_dis_enc, soft_labels_dis_gen)
        labels_dis_gen = tf.where(mask, soft_labels_dis_gen, soft_labels_dis_enc)

    # Loss Functions
    with tf.variable_scope('loss_functions'):
        loss_dis_enc = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=labels_dis_enc,
                                                                              logits=discriminator_real))
        # Fake pairs are trained towards the soft/noisy `labels_dis_gen` from above.
        # Previously this used all-zero labels -- the same target as real pairs -- and
        # `labels_dis_gen` was computed but never used.
        loss_dis_gen = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=labels_dis_gen,
                                                                              logits=discriminator_fake))
        loss_discriminator = loss_dis_gen + loss_dis_enc
        # generator
        loss_generator = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=labels_gen,
                                                                                logits=discriminator_fake))
        # encoder
        loss_encoder = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=labels_enc,
                                                                              logits=discriminator_real))
    with tf.name_scope('optimizers'):
        # Each optimizer only updates its own sub-model's variables and runs that
        # sub-model's UPDATE_OPS (e.g. batch-norm statistics) first.
        tvars = tf.trainable_variables()
        dvars = [var for var in tvars if 'discriminator_model' in var.name]
        gvars = [var for var in tvars if 'generator_model' in var.name]
        evars = [var for var in tvars if 'encoder_model' in var.name]

        # `op` (not `x`) so the comprehension never shadows the input placeholder.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        update_ops_gen = [op for op in update_ops if ('generator_model' in op.name)]
        update_ops_enc = [op for op in update_ops if ('encoder_model' in op.name)]
        update_ops_dis = [op for op in update_ops if ('discriminator_model' in op.name)]

        optimizer_dis = tf.train.AdamOptimizer(learning_rate=learning_rate, beta1=0.5, name='dis_optimizer')
        optimizer_gen = tf.train.AdamOptimizer(learning_rate=learning_rate, beta1=0.5, name='gen_optimizer')
        optimizer_enc = tf.train.AdamOptimizer(learning_rate=learning_rate, beta1=0.5, name='enc_optimizer')

        with tf.control_dependencies(update_ops_gen):
            gen_op = optimizer_gen.minimize(loss_generator, var_list=gvars)
        with tf.control_dependencies(update_ops_enc):
            # Only the encoder step advances the global step (one increment per batch).
            enc_op = optimizer_enc.minimize(loss_encoder, var_list=evars, global_step=tf.train.get_or_create_global_step())
        with tf.control_dependencies(update_ops_dis):
            dis_op = optimizer_dis.minimize(loss_discriminator, var_list=dvars)

    with tf.name_scope('loss'):
        with tf.name_scope('discriminator'):
            tf.summary.scalar('loss_total', loss_discriminator, ['dis'])
            tf.summary.scalar('loss_dis_enc', loss_dis_enc, ['dis'])
            tf.summary.scalar('loss_dis_gen', loss_dis_gen, ['dis'])

        with tf.name_scope('generator'):
            tf.summary.scalar('loss_generator', loss_generator, ['gen'])
            tf.summary.scalar('loss_encoder', loss_encoder, ['gen'])

    with tf.name_scope('train_img_regen'):
        for p in range(4):
            tf.summary.image('img_{}_regen'.format(p+1), regenerated_real_image[p:p+1,:,:,:], 1, ['image'])
            tf.summary.image('img_{}_input'.format(p+1), x[p:p+1,:,:,:], 1, ['image'])

    sum_op_dis = tf.summary.merge_all('dis')
    sum_op_gen = tf.summary.merge_all('gen')
    sum_op_im = tf.summary.merge_all('image')

    # ============================== TESTING GRAPH ===============================

    # Encoder
    with tf.variable_scope('encoder_model'):
        encoding_test = get_encoder_model(x, latent_dimensions=latent_dimensions,
                                             reuse=True, training_mode=False)
    # Generator
    with tf.variable_scope('generator_model'):
        regenerated_image_test = get_generator_model(encoding_test, reuse=True, training_mode=False)

    # Discriminator
    # NOTE(review): unlike the encoder/generator above, the test discriminator is built
    # with training_mode=True; if that switches batch-norm/dropout behaviour, test
    # scores depend on batch composition -- confirm this is intended.
    with tf.variable_scope('discriminator_model'):
        discriminator_fake_test, dis_fake_penultimate_layer_test = get_discriminator_model(regenerated_image_test, 
                                                                                      encoding_test, 
                                                                                      reuse=True, 
                                                                                      training_mode=True, 
                                                                                      penultimate_layer_units=penultimate_layer_units
                                                                                     )

        discriminator_real_test, dis_real_penultimate_layer_test = get_discriminator_model(x, 
                                                                                      encoding_test, 
                                                                                      reuse=True,
                                                                                      training_mode=True, 
                                                                                      penultimate_layer_units=penultimate_layer_units
                                                                                     )

    with tf.name_scope('testing'):
        # Per-image L2 distance between input and reconstruction
        with tf.variable_scope('reconstruction_loss'):
            delta = x - regenerated_image_test
            delta_flat = tf.layers.flatten(delta)
            gen_score = tf.norm(delta_flat, ord='euclidean', axis=1,
                              keep_dims=False, name='epsilon')

        # Per-image L2 distance between penultimate discriminator features
        with tf.variable_scope('discriminator_loss'):
            fm = dis_real_penultimate_layer_test - dis_fake_penultimate_layer_test
            fm = tf.contrib.layers.flatten(fm)
            dis_score = tf.norm(fm, ord='euclidean', axis=1,
                             keep_dims=False, name='d_loss')
            dis_score = tf.squeeze(dis_score)

        # Mixing weights w for (1 - w) * gen_score + w * dis_score
        score_weights = (0.1, 0.2, 0.3, 0.4, 0.5)

        with tf.variable_scope('score'):
            mean_scores = [tf.reduce_mean((1 - w) * gen_score + w * dis_score)
                           for w in score_weights]

    with tf.name_scope('test_anomaly_score'):
        for w, mean_score_w in zip(score_weights, mean_scores):
            tf.summary.scalar("mean_score_w={}".format(w), mean_score_w, ['scr'])

    with tf.variable_scope('accuracy'):

        # For defect accuracy calculation (w = 0.1)
        accuracy_weight = score_weights[0]
        all_test_scores = (1 - accuracy_weight) * gen_score + accuracy_weight * dis_score
        free_std = tf.sqrt(var_inp)

        # Fraction of scores >= mean + k * std of the defect-free scores, k = 0..3
        test_accs = []
        for k in range(4):
            free_thresh = mean_inp if k == 0 else mean_inp + k * free_std
            bool_list = tf.greater_equal(all_test_scores, free_thresh)
            test_accs.append(tf.reduce_sum(tf.cast(bool_list, tf.int32)) / tf.size(all_test_scores))

        # Statistics of the defect-free scores, fed back as mean_inp / var_inp
        mean, var = tf.nn.moments(all_test_scores, axes=[0])

    with tf.name_scope('test_accuracy'):
        for k, test_acc in enumerate(test_accs):
            tf.summary.scalar('threshold with w=0.1, stddev_{}'.format(k), test_acc, ['test_acc'])

    with tf.name_scope('test_img_regen'):
        for p in range(2):
            tf.summary.image('{}_0_input'.format(p+1), x[p:p+1,:,:,:], 1, ['t_image'])
            tf.summary.image('{}_1_regen'.format(p+1), regenerated_image_test[p:p+1,:,:,:], 1, ['t_image'])
            tf.summary.image('{}_2_ground_truth'.format(p+1), gt[p:p+1,:,:,:], 1, ['t_image'])
            tf.summary.image('{}_3_difference'.format(p+1), delta[p:p+1,:,:,:], 1, ['t_image'])

    sum_op_scr = tf.summary.merge_all('scr')
    sum_op_t_img = tf.summary.merge_all('t_image')
    sum_op_test_acc = tf.summary.merge_all('test_acc')

    gs = tf.train.get_global_step()

    # ================================= TRAINING =================================

    train_gen = training_generator(trainx_names, batch_size=batch_size)
    test_gen = training_generator(testx_names, batch_size=batch_size)

    blowhole_gen = testing_generator(blowhole_x_names, blowhole_y_names, test_batch_size=num_test_images)
    crack_gen = testing_generator(crack_x_names, crack_y_names, test_batch_size=num_test_images)
    break_gen = testing_generator(break_x_names, break_y_names, test_batch_size=num_test_images)

    summary_dir = get_summary_dir(checkpoint_dir)

    free_writer = tf.summary.FileWriter(os.path.join(summary_dir, "free"))
    blowhole_writer = tf.summary.FileWriter(os.path.join(summary_dir, "blowhole"))
    crack_writer = tf.summary.FileWriter(os.path.join(summary_dir, "crack"))
    break_writer = tf.summary.FileWriter(os.path.join(summary_dir, "break"))
    all_writers = (free_writer, blowhole_writer, crack_writer, break_writer)

    # Defect classes evaluated against the defect-free thresholds, in logging order.
    defect_sets = (
        (blowhole_gen, blowhole_writer),
        (crack_gen, crack_writer),
        (break_gen, break_writer),
    )

    try:
        step_saver = tf.train.CheckpointSaverHook(checkpoint_dir=checkpoint_dir, save_steps=1600, save_secs=None)

        summary_saver = tf.train.SummarySaverHook(save_steps=8,
                                                  save_secs=None,
                                                  output_dir=summary_dir, 
                                                  summary_op=[sum_op_dis, sum_op_gen, sum_op_im]
                                                 )
        config = tf.ConfigProto(log_device_placement=True)

        mnt = tf.train.MonitoredTrainingSession(config=config, hooks=[step_saver, summary_saver])

        with mnt as sess:
            epoch = 0
            step = 0
            while not mnt.should_stop() and epoch < nb_epochs:

                begin = time.time()
                train_loss_dis, train_loss_gen, train_loss_enc = 0.0, 0.0, 0.0

                for t in range(nr_batches_train):

                    print("Starting Epoch {}, Batch {}, Step {}".format(epoch+1, t+1, step+1))

                    # A fresh batch for every step (previously one batch was fetched per
                    # epoch and reused for all of its steps). A short trailing batch from
                    # the generator is skipped because the graph needs exactly
                    # `batch_size` images; this assumes load_and_normalize returns one
                    # entry per path. The next batch starts a new, full pass.
                    trainx = next(train_gen)
                    if len(trainx) != batch_size:
                        trainx = next(train_gen)

                    feed_dict = {x: trainx,
                                 training_mode: True,
                                 }

                    # train discriminator
                    _, ld = sess.run([dis_op, loss_discriminator], feed_dict=feed_dict)
                    train_loss_dis += ld

                    # train generator and encoder
                    _, _, le, lg = sess.run([gen_op, enc_op, loss_encoder, loss_generator],
                                            feed_dict=feed_dict)
                    train_loss_gen += lg
                    train_loss_enc += le

                    step += 1

                train_loss_gen /= nr_batches_train
                train_loss_enc /= nr_batches_train
                train_loss_dis /= nr_batches_train

                print("Epoch %d | time = %ds | loss gen = %.4f | loss enc = %.4f | loss dis = %.4f "
                      % (epoch+1, time.time() - begin, train_loss_gen, train_loss_enc, train_loss_dis))

                if (epoch+1) % freq_epoch_test == 0:
                    # Inspect reconstruction of the first training images of the last
                    # batch (previously computed and discarded; now written out).
                    im_summary, current_step = sess.run([sum_op_im, gs],
                                                        feed_dict={x: trainx[0:4], training_mode: False})
                    free_writer.add_summary(im_summary, current_step)

                    print("Evaluating")

                    # Defect-free test batch: its score mean/variance set the thresholds
                    testx = next(test_gen)
                    free_score_summary, free_t_img_summary, current_step, mean_score, var_score = \
                        sess.run([sum_op_scr, sum_op_t_img, gs, mean, var],
                                 feed_dict={x: testx,
                                            gt: np.zeros_like(testx),
                                            mean_inp: 0,
                                            var_inp: 0,
                                            training_mode: False})
                    free_writer.add_summary(free_score_summary, current_step)
                    free_writer.add_summary(free_t_img_summary, current_step)
                    free_writer.flush()

                    # One batch per defect class, thresholded with the free statistics
                    for defect_gen, defect_writer in defect_sets:
                        defect_x, defect_y = next(defect_gen)
                        score_summary, t_img_summary, acc_summary = \
                            sess.run([sum_op_scr, sum_op_t_img, sum_op_test_acc],
                                     feed_dict={x: defect_x,
                                                gt: defect_y,
                                                mean_inp: mean_score,
                                                var_inp: var_score,
                                                training_mode: False})
                        for summary in (score_summary, t_img_summary, acc_summary):
                            defect_writer.add_summary(summary, current_step)
                        defect_writer.flush()

                epoch += 1
    finally:
        # Flush and release the event files even if training is interrupted.
        for writer in all_writers:
            writer.close()

## II. Train

In [43]:
# Build the model in a brand-new graph so every run of this cell starts from an
# empty graph, then train on the splits prepared above.
with tf.Graph().as_default():
    train(trainx_names=train_img_names,
          testx_names=test_img_names,
          blowhole_x_names=blowhole_img_names,
          blowhole_y_names=blowhole_img_gt_names,
          crack_x_names=crack_img_names,
          crack_y_names=crack_img_gt_names,
          break_x_names=break_img_names,
          break_y_names=break_img_gt_names)
    



Encoder: 

Input shape of x is (?, 256, 256, 3)
Output shape of layer_01 is (?, 128, 128, 16)
Output shape of layer_02 is (?, 64, 64, 32)
Output shape of layer_03 is (?, 32, 32, 64)
Output shape of layer_04 is (?, 16, 16, 128)
Output shape of layer_04 is (?, 200)

Generator:

Input shape of z is (16, 200)
Output shape of layer_01 is (16, 8, 8, 1024)
Output shape of layer_02 is (16, 16, 16, 512)
Output shape of layer_03 is (16, 32, 32, 256)
Output shape of layer_04 is (16, 64, 64, 128)
Output shape of layer_05 is (16, 128, 128, 64)
Output shape of layer_06 is (16, 256, 256, 3)

Generator:

Input shape of z is (?, 200)
Output shape of layer_01 is (?, 8, 8, 1024)
Output shape of layer_02 is (?, 16, 16, 512)
Output shape of layer_03 is (?, 32, 32, 256)
Output shape of layer_04 is (?, 64, 64, 128)
Output shape of layer_05 is (?, 128, 128, 64)
Output shape of layer_06 is (?, 256, 256, 3)

Discriminator: 

Input shape of x is (16, 256, 256, 3)
Output shape of x_layer_01 is (16, 128, 128, 64)

Starting Epoch 24, Batch 2, Step 48
Epoch 24 | time = 3s | loss gen = 0.9173 | loss enc = 0.5829 | loss dis = 1.7688 
Evaluating
Starting Epoch 25, Batch 1, Step 49
Starting Epoch 25, Batch 2, Step 50
Epoch 25 | time = 2s | loss gen = 0.8861 | loss enc = 0.5465 | loss dis = 1.6589 
Starting Epoch 26, Batch 1, Step 51
Starting Epoch 26, Batch 2, Step 52
Epoch 26 | time = 2s | loss gen = 0.8199 | loss enc = 0.6299 | loss dis = 1.7940 
Evaluating
Starting Epoch 27, Batch 1, Step 53
Starting Epoch 27, Batch 2, Step 54
Epoch 27 | time = 3s | loss gen = 0.7798 | loss enc = 0.6348 | loss dis = 1.5995 
Starting Epoch 28, Batch 1, Step 55
Starting Epoch 28, Batch 2, Step 56
Epoch 28 | time = 2s | loss gen = 0.8983 | loss enc = 0.6130 | loss dis = 1.7615 
Evaluating
Starting Epoch 29, Batch 1, Step 57
Starting Epoch 29, Batch 2, Step 58
Epoch 29 | time = 2s | loss gen = 0.8677 | loss enc = 0.6114 | loss dis = 1.8158 
Starting Epoch 30, Batch 1, Step 59
Starting Epoch 30, Batch 2, Step 60
Epoch 3

KeyboardInterrupt: 