In [None]:
#import tensorflow as tf
#tf.__version__
#!conda install -y gdown
#!gdown --id 13t7OBZ1_cyQySXgyTgiIQVP6mu3DBJlT

In [None]:
%load_ext tensorboard
%matplotlib inline

import io
import os
import argparse
from glob import glob
import matplotlib.pyplot as plt
import IPython.display as display
import imageio
import numpy as np

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 

import tensorflow as tf
from tensorflow.keras import Model, layers, losses, metrics, regularizers, optimizers, initializers
from tensorflow.keras.utils import plot_model

# Default paths.
# Everything produced by a run (plots, checkpoints, summaries, generated
# images) is written below SCRIPT_PATH/EXPERIMENT_ID.
SCRIPT_PATH = './LSGAN'
EXPERIMENT_ID = 'EXP_1'
MODEL_SAVE_PATH = os.path.join(SCRIPT_PATH, EXPERIMENT_ID)
IMG_SAVE_PATH = os.path.join(MODEL_SAVE_PATH, 'generated_img')
# exist_ok=True makes the cell safely re-runnable and avoids the
# check-then-create race of `if not os.path.exists(...)`.
# IMG_SAVE_PATH is nested inside MODEL_SAVE_PATH, so both calls are cheap.
os.makedirs(MODEL_SAVE_PATH, exist_ok=True)
os.makedirs(IMG_SAVE_PATH, exist_ok=True)

# Defaults for the argparse options declared in the next cell.
DEFAULT_TFRECORDS_DIR = '../input/dcgan-dataset'
DEFAULT_NUM_EPOCHS = 2000
DEFAULT_LEARNING_RATE = 1e-4
DEFAULT_BATCH_SIZE = 64
DEFAULT_SAVE_PERIOD = 15
DEFAULT_LATENT_DEPTH = 100

# Spatial size / channel count the images are resized to and the
# discriminator expects.
IMAGE_WIDTH = 128
IMAGE_HEIGHT = 128
IMAGE_CHANNEL = 3
# Used as the shuffle buffer size of the input pipeline.
# NOTE(review): presumably the number of training images - confirm.
NUM_INPUT_DATA = 8960

# Populated by the argparse cell below.
args = None

In [None]:
# Hyper-parameters are exposed through argparse so the same code can run as a
# script; inside the notebook an empty argument list is parsed, which means
# every option takes its default value.
parser = argparse.ArgumentParser()
for flag, dest, kind, default, help_text in (
    ('--tfrecords-dir', 'tfrecords_dir', str, DEFAULT_TFRECORDS_DIR,
     'Directory of TFRecords files.'),
    ('--num-train-epochs', 'num_train_epochs', int, DEFAULT_NUM_EPOCHS,
     'Number of times to iterate over all of the training data.'),
    ('--learning-rate', 'learning_rate', float, DEFAULT_LEARNING_RATE,
     'How large a learning rate to use when training.'),
    ('--batch-size', 'batch_size', int, DEFAULT_BATCH_SIZE,
     'How many images to train on at a time.'),
    ('--save-period', 'save_period', int, DEFAULT_SAVE_PERIOD,
     'How many epochs to save ckpt files.'),
    ('--latent-depth', 'latent_depth', int, DEFAULT_LATENT_DEPTH,
     'How many latent variables you have.'),
):
    parser.add_argument(flag, type=kind, dest=dest, default=default,
                        help=help_text)
# Parse an explicitly empty argument list instead of sys.argv, which inside a
# notebook holds the kernel's own command line.
args = parser.parse_args([])

In [None]:
def _parse_function(example):
    """Turn one serialized TFRecord example into a float32 RGB image tensor.

    Args:
        example: a serialized example carrying a 'filename' string and the
            JPEG bytes under 'image/encoded'.

    Returns:
        The decoded 3-channel image converted to tf.float32. The filename
        feature is parsed but not returned.
    """
    feature_spec = {
        'filename': tf.io.FixedLenFeature([], tf.string, default_value=''),
        'image/encoded': tf.io.FixedLenFeature([], tf.string, default_value=''),
    }
    parsed = tf.io.parse_single_example(example, features=feature_spec)

    # Decode the JPEG bytes, then switch to floating point.
    decoded = tf.io.decode_jpeg(parsed['image/encoded'], channels=3)
    return tf.image.convert_image_dtype(decoded, dtype=tf.float32)

In [None]:
# glob() returns matches in arbitrary, filesystem-dependent order; sorting
# makes the shard order (and the preview below) reproducible across runs.
train_data_files = sorted(glob(os.path.join(args.tfrecords_dir, 'train-*')))
if not train_data_files:
    # Fail here with a clear message instead of a less obvious failure later
    # in the pipeline when the directory is wrong or empty.
    raise FileNotFoundError(
        "No TFRecord files matching 'train-*' found in {!r}".format(args.tfrecords_dir))

# Stream the shards and decode every record into a float32 image.
train_dataset = tf.data.TFRecordDataset(train_data_files) \
        .map(_parse_function, num_parallel_calls=tf.data.experimental.AUTOTUNE)

In [None]:
# Sanity check: draw the first five decoded images, one figure per image.
for sample in train_dataset.take(5):
    preview_fig = plt.figure()
    preview_ax = preview_fig.gca()
    preview_ax.axis('off')
    preview_ax.imshow(sample.numpy())

In [None]:
def preprocessing_data(image):
    """Resize `image` to IMAGE_HEIGHT x IMAGE_WIDTH and return the result."""
    target_size = [IMAGE_HEIGHT, IMAGE_WIDTH]
    return tf.image.resize(image, target_size)

In [None]:
# Input pipeline: resize -> cache -> shuffle -> batch -> prefetch.
# The cache sits after the resize, so later epochs reuse the already
# decoded and resized images.
train_dataset = (
    train_dataset
    .map(preprocessing_data)
    .cache()
    .shuffle(NUM_INPUT_DATA)
    .batch(args.batch_size)
    .prefetch(tf.data.experimental.AUTOTUNE)
)

In [None]:
def generator_model():
    """Build the generator: latent vector -> 3-channel image with sigmoid output.

    A dense projection is reshaped to an 8x8x512 feature map and then passed
    through four stride-2 transposed convolutions (256, 128, 64 and finally 3
    filters). Every stage except the last is followed by batch normalisation
    and ReLU. The model summary is printed and a diagram is written to
    MODEL_SAVE_PATH/generator.png.

    Returns:
        The Keras `Model` named "generator".
    """
    def _upsample(tensor, filters, **conv_kwargs):
        # A fresh initializer instance is created for every layer.
        return layers.Conv2DTranspose(
            filters, 3, 2, padding='same',
            kernel_initializer=initializers.RandomNormal(stddev=0.02),
            **conv_kwargs)(tensor)

    def _bn_relu(tensor):
        normalised = layers.BatchNormalization(epsilon=1e-5)(tensor)
        return layers.ReLU()(normalised)

    latent = layers.Input(shape=(args.latent_depth,))
    net = layers.Dense(8*8*512)(latent)
    net = layers.Reshape((8, 8, 512))(net)
    net = _bn_relu(net)

    for filters in (256, 128, 64):
        net = _bn_relu(_upsample(net, filters))

    image = _upsample(net, 3, activation='sigmoid')

    generator = Model(latent, image, name="generator")
    generator.summary()
    plot_model(generator, to_file=os.path.join(MODEL_SAVE_PATH, "generator.png"), show_shapes=True)
    return generator

In [None]:
def discriminator_model():
    """Build the discriminator: image -> single unbounded score.

    Four stride-2 5x5 convolutions (32, 64, 128, 256 filters), each followed
    by LeakyReLU(0.2), are flattened and fed to a one-unit dense layer with
    no activation. The model summary is printed and a diagram is written to
    MODEL_SAVE_PATH/discriminator.png.

    Returns:
        The Keras `Model` named "discriminator".
    """
    image = layers.Input(shape=(IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_CHANNEL))

    features = image
    for filters in (32, 64, 128, 256):
        features = layers.Conv2D(
            filters, 5, 2, padding='same',
            kernel_initializer=initializers.RandomNormal(stddev=0.02))(features)
        features = layers.LeakyReLU(0.2)(features)

    flat = layers.Flatten()(features)
    score = layers.Dense(1)(flat)

    discriminator = Model(image, score, name="discriminator")
    discriminator.summary()
    plot_model(discriminator, to_file=os.path.join(MODEL_SAVE_PATH, "discriminator.png"), show_shapes=True)
    return discriminator

In [None]:
class LSGAN(tf.keras.Model):
    """Least-squares GAN wrapping a generator and a discriminator.

    The forward pass maps latent vectors to ``(generated_images, scores)``.
    Training is driven by the custom `train_step`, which updates the
    discriminator and then the generator with mean-squared-error losses,
    using the two optimizers supplied to `compile`.
    """

    def __init__(self):
        """Create both sub-networks and wire them into one model named 'lsgan'."""
        # First, a bare init. NOTE(review): presumably needed so attributes
        # can be assigned before the second init below - confirm against the
        # Keras version in use.
        super(LSGAN, self).__init__()
        self.input_layer = layers.Input(shape=(args.latent_depth,))
        self.generator = generator_model()
        self.discriminator = discriminator_model()
        self.mse = losses.MeanSquaredError()
        # Trace generator -> discriminator on the symbolic input, then
        # re-initialise with explicit inputs/outputs.
        self.out = self.call(self.input_layer)
        super(LSGAN, self).__init__(inputs = self.input_layer, outputs = self.out, name='lsgan')
 
        # Running means updated in train_step and exposed via `metrics`.
        self.d_loss_tracker = metrics.Mean(name='losses/d_loss')
        self.g_loss_tracker = metrics.Mean(name='losses/g_loss')
        self.d_norm_grad_tracker = metrics.Mean(name='grads/d_norm_grad')
        self.g_norm_grad_tracker = metrics.Mean(name='grads/g_norm_grad')

    def build(self, input_shape, **kwags):
        """Forward to the parent `build` unchanged."""
        super(LSGAN, self).build(input_shape, **kwags)
 
    def call(self, inputs, training=False):
        """Generate images from `inputs` and score them.

        Returns:
            Tuple ``(images, outputs)``: the generator output and the
            discriminator output computed on those images.
        """
        images = self.generator(inputs, training)
        outputs = self.discriminator(images, training)
        return images, outputs
 
    def compile(self, d_optimizer, g_optimizer, **kwags):
        """Store the discriminator/generator optimizers; other kwargs go to the parent `compile`."""
        super(LSGAN, self).compile(**kwags)
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
 
    @tf.function
    def train_step(self, real_images):
        """Run one discriminator update followed by one generator update.

        Args:
            real_images: a batch of real images; its leading dimension is
                used as the number of latent vectors to sample.

        Returns:
            Dict with the current values of the running d_loss and g_loss means.
        """
        
        # Sample random points in the latent space
        batch_size = tf.shape(real_images)[0]
        random_latent_vectors = tf.random.normal(shape=(batch_size, args.latent_depth))
 
        # Decode them to fake images
        # (done outside the tape: the discriminator step does not need
        # gradients with respect to the generator)
        generated_images = self.generator(random_latent_vectors, training=True)

        # Add random noise to the labels - important trick!
        #labels += 0.05 * tf.random.uniform(tf.shape(labels))
 
        # Train the discriminator
        with tf.GradientTape() as tape:
            r_logit = self.discriminator(real_images, training=True)
            f_logit = self.discriminator(generated_images, training=True)

            # Least-squares targets: 1 for real images, 0 for generated ones.
            r_loss = self.mse(tf.ones_like(r_logit), r_logit)
            f_loss = self.mse(tf.zeros_like(f_logit), f_logit)
            d_loss = r_loss + f_loss

        grads = tape.gradient(d_loss, self.discriminator.trainable_weights)
        self.d_optimizer.apply_gradients(
            zip(grads, self.discriminator.trainable_weights)
        )

        self.d_loss_tracker.update_state(d_loss)
        # The gradient norm is tracked on a log scale.
        self.d_norm_grad_tracker.update_state(tf.math.log(tf.linalg.global_norm(grads)))

        # Sample random points in the latent space
        random_latent_vectors = tf.random.normal(shape=(batch_size, args.latent_depth))

        with tf.GradientTape() as tape:
            # Decode them to fake images
            generated_images = self.generator(random_latent_vectors, training=True)

            # Generator target: the discriminator should output 1 on fakes.
            f_logit = self.discriminator(generated_images, training=True)
            g_loss = self.mse(tf.ones_like(f_logit), f_logit)

        # Only the generator weights are updated in this step.
        grads = tape.gradient(g_loss, self.generator.trainable_weights)
        self.g_optimizer.apply_gradients(zip(grads, self.generator.trainable_weights))

        self.g_loss_tracker.update_state(g_loss)
        self.g_norm_grad_tracker.update_state(tf.math.log(tf.linalg.global_norm(grads)))
        
        return {"losses/d_loss": self.d_loss_tracker.result(), "losses/g_loss": self.g_loss_tracker.result()}
 
    @property
    def metrics(self):
        """List of the four trackers (losses and log gradient norms).

        NOTE(review): presumably overridden so Keras resets these between
        epochs - confirm against the Keras version in use.
        """
        return [self.d_loss_tracker, self.g_loss_tracker, self.g_norm_grad_tracker, self.d_norm_grad_tracker]

In [None]:
def lsgan_model():
    """Instantiate the LSGAN, print its summary and save its diagram.

    The diagram is written to MODEL_SAVE_PATH/lsgan.png.

    Returns:
        The constructed `LSGAN` instance.
    """
    gan = LSGAN()
    gan.summary()
    diagram_path = os.path.join(MODEL_SAVE_PATH, "lsgan.png")
    plot_model(gan, to_file=diagram_path, show_shapes=True)
    return gan

In [None]:
# Build the combined generator + discriminator model.
model = lsgan_model()

# One Adam optimizer per sub-network. `learning_rate` is the supported keyword;
# the old `lr` alias is deprecated and rejected by newer Keras releases.
generator_optimizer = tf.keras.optimizers.Adam(learning_rate=args.learning_rate)
discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=args.learning_rate)

In [None]:
# Fixed (seeded) latent vectors so the preview grid shows the same samples at
# every epoch. Shape is [num_samples, 1, latent_depth]: one batch of size 1
# per sample.
sample_noise = tf.random.stateless_normal([4, 1, args.latent_depth], seed = [40, 40])

def generate_and_save_images(model, epoch):
    """Render the fixed noise samples with `model`, then save and show the grid.

    Args:
        model: callable returning ``(images, scores)`` for a latent batch
            (the LSGAN model).
        epoch: epoch number, used for subplot titles and the file name
            ``image_at_epoch_XXXX.png`` inside IMG_SAVE_PATH.
    """
    num_samples = sample_noise.shape[0]
    # squeeze=False keeps `axes` two-dimensional regardless of the sample count.
    fig, axes = plt.subplots(1, num_samples, figsize=(15, 10), squeeze=False)

    for i, ax in enumerate(axes[0]):
        images, _ = model(sample_noise[i], training=False)

        image = np.reshape(images[0, :, :, :], [IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_CHANNEL])

        ax.imshow(image)
        ax.axis('off')
        ax.set_title('Epoch_{:04d}'.format(epoch))

    fig.tight_layout()
    fig.savefig(os.path.join(IMG_SAVE_PATH, 'image_at_epoch_{:04d}.png'.format(epoch)))
    plt.show()
    # Release the figure explicitly: this runs once per save period, and
    # unclosed figures would otherwise pile up on backends that do not close
    # them on show().
    plt.close(fig)

In [None]:
# Summary writer for the per-epoch training metrics.
train_summary_writer = tf.summary.create_file_writer(
    os.path.join(MODEL_SAVE_PATH, 'summaries', 'train'))

class MyCallback(tf.keras.callbacks.Callback):
    """Per-epoch bookkeeping: summaries, preview images and checkpoints.

    The epoch counter starts at the value stored in ``ckpt.epoch`` and is
    advanced, together with that variable, at the start of every epoch. The
    `epoch` argument Keras passes to the hooks is not used.
    """

    def __init__(self, ckpt, manager, period):
        """
        Args:
            ckpt: checkpoint object exposing an `epoch` variable.
            manager: checkpoint manager whose `save()` is called periodically.
            period: save a checkpoint and preview every `period` epochs.
        """
        super().__init__()
        self.ckpt = ckpt
        self.manager = manager
        self.period = period
        self.epoch = int(ckpt.epoch)

    def on_epoch_begin(self, epoch, logs=None):
        """Advance the local counter and the checkpointed epoch variable."""
        self.epoch += 1
        self.ckpt.epoch.assign_add(1)

    def on_epoch_end(self, epoch, logs=None):
        """Write every model metric as a scalar; save periodically."""
        with train_summary_writer.as_default():
            for tracker in self.model.metrics:
                tf.summary.scalar(tracker.name, tracker.result(), step=self.epoch)

        if self.epoch % self.period != 0:
            return

        display.clear_output(wait=True)
        generate_and_save_images(self.model, self.epoch)
        saved_to = self.manager.save()
        print("Saved checkpoint for epoch {}: {}".format(self.epoch, saved_to))

In [None]:
#!wget https://bin.equinox.io/c/4VmDzA7iaHb/ngrok-stable-linux-amd64.zip
#!unzip ngrok-stable-linux-amd64.zip

#LOG_DIR = './LSGAN/EXP_1/summaries/'
#get_ipython().system_raw(
#    'tensorboard --logdir {} --host 0.0.0.0 --port 6006 &'
#    .format(LOG_DIR)
#)
#get_ipython().system_raw('./ngrok http 6006 &')
#! curl -s http://localhost:4040/api/tunnels | python3 -c \
#    "import sys, json; print(json.load(sys.stdin)['tunnels'][0]['public_url'])"

In [None]:
# The checkpoint holds the epoch counter, both optimizers and the model, so a
# restarted run resumes where the previous one stopped.
ckpt = tf.train.Checkpoint(epoch=tf.Variable(0), generator_optimizer=generator_optimizer,
                                 discriminator_optimizer=discriminator_optimizer,
                                 model = model)
manager = tf.train.CheckpointManager(ckpt, os.path.join(MODEL_SAVE_PATH, 'ckpt'), max_to_keep=3)
if manager.latest_checkpoint:
    status = ckpt.restore(manager.latest_checkpoint)
    # Check the restore status against the objects created so far.
    status.assert_existing_objects_matched()
    print("Restored from {}".format(manager.latest_checkpoint))
else:
    print("Initializing from scratch.")

model.compile(
    d_optimizer=discriminator_optimizer,
    g_optimizer=generator_optimizer,
)

my_callback = MyCallback(ckpt, manager, args.save_period)

print('Start learning!')
# initial_epoch resumes the epoch numbering from the restored checkpoint.
model.fit(
    train_dataset,
    epochs = args.num_train_epochs,
    callbacks = [my_callback],
    initial_epoch = int(ckpt.epoch),
)
print('Learning finished!')

# Evaluation is disabled (no `test_dataset` is defined in the visible cells).
# Kept as real comments: a bare triple-quoted string at the end of a cell is an
# expression and gets echoed as the cell's output.
# model.evaluate(
#     test_dataset
# )

In [None]:
# Assemble the saved preview images into an animated GIF.
anim_file = os.path.join(MODEL_SAVE_PATH, 'lsgan.gif')

# File names embed a zero-padded epoch number, so a lexical sort gives
# chronological order.
filenames = sorted(glob(os.path.join(IMG_SAVE_PATH, '*.png')))

if not filenames:
    # Without this guard the trailing "repeat last frame" step would hit an
    # unbound `filename` and raise NameError after the writer was opened.
    print('No generated images found in {}; skipping GIF creation.'.format(IMG_SAVE_PATH))
else:
    with imageio.get_writer(anim_file, mode='I') as writer:
        last = -1
        for i, filename in enumerate(filenames):
            # Square-root pacing: frames are picked densely at first and more
            # sparsely later on.
            frame = 2*(i**0.5)
            if round(frame) <= round(last):
                continue
            last = frame
            writer.append_data(imageio.imread(filename))
        # Append the final image once more after the loop.
        writer.append_data(imageio.imread(filenames[-1]))

In [None]:
# Show a final preview for the last epoch recorded in the checkpoint.
final_epoch = int(ckpt.epoch)
display.clear_output(wait=True)
generate_and_save_images(model, final_epoch)

# Export the trained model in SavedModel format under MODEL_SAVE_PATH/lsgan.
export_dir = os.path.join(MODEL_SAVE_PATH, 'lsgan')
tf.saved_model.save(model, export_dir)
print('lsgan.pb file is created successfully!!')