# **Import Modules**

In [1]:
import sys, os, copy, glob
import random as pyrand

# Hide TensorFlow's C++-side INFO/WARNING/ERROR log spam. This must happen before
# tensorflow is first imported (In [5]) or it has no effect.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

# Put the vendored kerasstylegan sources first on the import path so that the
# `externals` and `stylegan` packages imported in In [2] resolve.
sys.path.insert(0, '../input/kerasstylegan/')

In [2]:
from externals.io import TFRecordReader
from stylegan import gan, functional, processing


# from stylegan.callbacks import LearningRateScheduler

In [3]:
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns

from tqdm import tqdm

In [4]:
from kaggle_secrets import UserSecretsClient
from kaggle_datasets import KaggleDatasets

In [5]:
import tensorflow as tf
from tensorflow.keras import backend, optimizers
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers.experimental import preprocessing
from tensorflow.keras.callbacks import ModelCheckpoint

# **Config**

In [6]:
# Log the TensorFlow version for reproducibility: several tf.*.experimental APIs
# used below (preprocessing layers, TPU init) differ between releases.
print(tf.__version__)

In [7]:
# Registers tqdm's progress_apply/progress_map helpers on pandas objects.
# NOTE(review): neither pandas nor tqdm is used in any visible cell; this is a
# candidate for removal together with the corresponding imports.
tqdm.pandas()

In [8]:
class Config:
    """Run-wide hyper-parameters, read directly as class attributes (never instantiated).

    ``num_replicas`` is not declared here; it is attached after the distribution
    strategy has been created.
    """

    # --- bookkeeping / reproducibility ---
    experiment_id = 0          # used in the saved-weights filename
    seed = 1265                # default seed for set_seed() and show_images()

    # --- image geometry ---
    low_resolution = 4         # forwarded as `low_resolution` to the generator/discriminator cfgs
    resolution = 256           # training image side length in pixels
    crop_resolution = 128      # side length of the optional random crop

    # --- conditioning / augmentation switches ---
    num_clusters = 1
    use_crop = False
    use_pseudo_labels = False

    # --- optimisation ---
    batch_size = 16            # batch size of the training dataset fed to model.fit
    epochs = 5

    # --- input pipeline ---
    shuffle = True
    buffer_size = 30_000       # shuffle buffer size, counted in decoded examples

In [9]:
# Pick a distribution strategy: TPUStrategy when a TPU is reachable, otherwise the
# default (non-distributed) strategy. The resulting replica count is stored on Config.
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    print('Running on TPU ', tpu.master())
except ValueError:
    # The resolver is expected to raise ValueError when no TPU is attached.
    tpu = None

if tpu is not None:
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)

    # Stable name since TF 2.3; the tf.distribute.experimental alias is deprecated.
    strategy = tf.distribute.TPUStrategy(tpu)
else:
    strategy = tf.distribute.get_strategy()

Config.num_replicas = strategy.num_replicas_in_sync

print('REPLICAS: ', Config.num_replicas)

In [10]:
def set_seed(tf_seed=Config.seed, np_seed=Config.seed, py_rand=Config.seed, py_hash=Config.seed):
    """Seed the Python, NumPy and TensorFlow global random generators.

    Args:
        tf_seed: global seed passed to ``tf.random.set_seed``.
        np_seed: seed for NumPy's legacy global RNG.
        py_rand: seed for the stdlib ``random`` module.
        py_hash: value written to the ``PYTHONHASHSEED`` environment variable.

    Note:
        Python reads ``PYTHONHASHSEED`` only at interpreter start-up. Assigning it
        here affects child processes, not ``hash()`` randomisation in this kernel.
    """
    # Python-level sources first.
    os.environ['PYTHONHASHSEED'] = str(py_hash)
    pyrand.seed(py_rand)

    # Then the numeric libraries (independent of each other).
    np.random.seed(np_seed)
    tf.random.set_seed(tf_seed)

In [11]:
# Seed Python, NumPy and TensorFlow with Config.seed before any random layers are built.
set_seed()

In [12]:
# Optional random-crop stage: crop to crop_resolution, then resize back to the full
# training resolution. batch_preprocessing only references it when Config.use_crop is set.
if Config.use_crop:

    random_crop = Sequential([
        preprocessing.RandomCrop(height=Config.crop_resolution, width=Config.crop_resolution),
        preprocessing.Resizing(height=Config.resolution, width=Config.resolution,
                               interpolation='bilinear', crop_to_aspect_ratio=False),
    ])
    random_crop.build((None, Config.resolution, Config.resolution, 3))

# Geometric augmentation stack: horizontal flip, translation, rotation, zoom.
# NOTE(review): this model is built but never used in the visible cells;
# gan.build_stylegan is called with augmenter=None. Confirm whether it should be passed in.
augmenter = Sequential([
    preprocessing.RandomFlip('horizontal'),
    preprocessing.RandomTranslation(height_factor=0.2, width_factor=0.2, interpolation='nearest'),
    preprocessing.RandomRotation(factor=0.2),
    preprocessing.RandomZoom(height_factor=(-0.3, 0.2), width_factor=(-0.3, 0.2)),
])
augmenter.build((None, Config.resolution, Config.resolution, 3))

# **Utils**

In [13]:
def decode_image(bytes_str):
    """Decode one raw-bytes record field into a (256, 256, 3) uint8 image tensor.

    Assumes the field holds exactly 256 * 256 * 3 bytes; ``tf.reshape`` raises
    if the element count differs.
    """
    flat_pixels = tf.io.decode_raw(bytes_str, out_type='uint8')
    return tf.reshape(flat_pixels, shape=(256, 256, 3))

In [14]:
def dataset_generator(filenames, reader, shuffle=Config.shuffle, buffer_size=Config.buffer_size):
    """Build an unbatched tf.data pipeline of parsed examples from TFRecord files.

    Args:
        filenames: a TFRecord path or collection of paths (anything
            ``tf.data.TFRecordDataset`` accepts).
        reader: object whose ``read_example`` method is mapped over every
            serialized record.
        shuffle: if True, add a reshuffling stage after parsing.
        buffer_size: shuffle buffer size, counted in parsed examples.

    Returns:
        A ``tf.data.Dataset`` yielding whatever ``reader.read_example`` returns.
    """
    # Built locally rather than read from the AUTOTUNE / options globals, which are
    # only created in a later cell (In [17]). The values are the same, with no
    # forward reference to a lower cell.
    autotune = tf.data.experimental.AUTOTUNE
    read_options = tf.data.Options()
    # Allow parallel reads/maps to emit elements out of order for throughput.
    read_options.experimental_deterministic = False

    dataset = tf.data.TFRecordDataset(filenames, num_parallel_reads=autotune)
    dataset = dataset.with_options(read_options)
    dataset = dataset.map(reader.read_example, num_parallel_calls=autotune)

    if shuffle:
        # NOTE(review): shuffling after parsing keeps up to buffer_size parsed
        # examples in memory. For 256x256x3 uint8 images at 30k that is ~5.9 GB;
        # confirm this fits the host.
        dataset = dataset.shuffle(buffer_size, reshuffle_each_iteration=True)

    return dataset

In [15]:
def show_images(images, n=5, seed=Config.seed):
    """Plot a reproducible random selection of images on an n x n grid.

    Args:
        images: indexable, sized collection of images accepted by ``plt.imshow``
            (e.g. an (N, H, W, C) array or eager tensor). Values are assumed to be
            already scaled to a displayable range.
        n: grid side length; at most ``n * n`` distinct images are drawn.
        seed: seed for a local RandomState, so the same call shows the same images.
    """
    rng = np.random.RandomState(seed=seed)

    # Never request more distinct images than exist. Sampling without replacement
    # would raise; leftover grid cells are left blank.
    num_shown = min(n * n, len(images))
    indices = rng.choice(np.arange(len(images)), size=num_shown, replace=False)

    # squeeze=False keeps axs 2-D even for n == 1, so flatten() always works.
    fig, axs = plt.subplots(nrows=n, ncols=n, figsize=(15, 15), squeeze=False)
    axs = axs.flatten()

    for i, ax in enumerate(axs):
        if i < num_shown:
            ax.imshow(images[indices[i]])
        ax.axis('off')

    fig.tight_layout()
    # plt.show() renders under the notebook's inline backend; fig.show() there
    # typically only emits a "non-GUI backend" warning.
    plt.show()

# **Load Data**

In [16]:
# Attach the notebook's Google Cloud credential to TensorFlow, presumably so the
# (TPU) input pipeline can read the dataset directly from its GCS bucket.
user_secrets = UserSecretsClient()
user_credential = user_secrets.get_gcloud_credential()

user_secrets.set_tensorflow_credential(user_credential)

# Map of dataset key -> gs:// root of the corresponding Kaggle dataset.
GCS_PATH = {}
GCS_PATH['celeba-hq'] = KaggleDatasets().get_gcs_path('celeba-hq-256')

print(GCS_PATH)

In [17]:
# Let tf.data pick parallelism / prefetch sizes at runtime.
AUTOTUNE = tf.data.experimental.AUTOTUNE

# Allow elements to be produced out of order for throughput. As a consequence,
# example order may differ between runs even with fixed seeds.
options = tf.data.Options()
options.experimental_deterministic = False

In [18]:
def get_celebA_HQ_reader():
    """Create the CelebA-HQ (256px) record reader and list its TFRecord shards.

    Returns:
        (reader, filenames): a ``TFRecordReader`` that decodes the ``'image'``
        bytes feature with ``decode_image``, and a numpy array of shard paths
        under the dataset's GCS root.
    """
    reader = TFRecordReader(
        dtypes={'image': 'bytes'},
        decode_fn={'image': decode_image},
    )

    shard_pattern = f'{GCS_PATH["celeba-hq"]}/tfrecords/celebA-HQ/256/*'
    filenames = np.array(tf.io.gfile.glob(shard_pattern))

    return reader, filenames

In [19]:
# Shared reader + shard list; both are consumed by the dataset helpers and cells below.
celeba_hq_reader, celeba_hq_filenames = get_celebA_HQ_reader()

In [20]:
def batch_preprocessing(add_noise):
    """Build the per-batch map function for the CelebA-HQ input pipeline.

    Args:
        add_noise: if True, the returned function adds N(0, 1) noise to the
            normalized images.

    Returns:
        A function ``images -> (images, labels)`` intended for
        ``tf.data.Dataset.map`` over already-batched uint8 images.
    """
    
    def _preprocessing(images):
        """Map one batch of uint8 images to (normalized float32 images, labels)."""

        # Batch size read from the tensor, so it works inside the traced graph.
        batch_size = tf.shape(images)[0]

        images = tf.cast(images, dtype=tf.float32)
        # NOTE(review): presumably rescales [0, 255] to the model's input range;
        # processing.normalize is not visible here, so confirm.
        images = processing.normalize(images)
        
        # With Config.use_crop, a fraction of batches (per the < 0.1 threshold) are
        # randomly cropped and resized back to full resolution.
        # NOTE(review): if random_uniform_state returns a Python number rather than a
        # tensor, it is evaluated once when tf.data traces this function, so the
        # decision would be fixed for every batch. Confirm it returns a tensor.
        if Config.use_crop and (processing.random_uniform_state(precision=2) < 0.1):
            
            images = random_crop(images, training=True)
            
        # One label per image; the sampler depends on the pseudo-label switch.
        if Config.use_pseudo_labels:
            
            labels = processing.sample_pseudo_labels(size=batch_size, num_classes=Config.num_clusters)
            
        else:
            
            labels = processing.sample_adversarial_labels(size=batch_size, num_classes=Config.num_clusters)
        
        if add_noise:
            
            # NOTE(review): unit-variance noise is large relative to a normalized
            # image range. Confirm stddev=1.0 is intended.
            images += tf.random.normal(shape=tf.shape(images), mean=0.0, stddev=1.0)
            
        return images, labels
    
    return _preprocessing

In [21]:
def get_celeba_hq_dataset_generator(filenames, batch_size=32, shuffle=Config.shuffle, 
                                    buffer_size=Config.buffer_size, add_noise=False):
    """Build the batched, preprocessed, prefetched CelebA-HQ training dataset.

    Args:
        filenames: TFRecord path(s) to read.
        batch_size: images per batch; incomplete final batches are dropped.
        shuffle: whether to shuffle parsed examples before batching.
        buffer_size: shuffle buffer size in examples.
        add_noise: forwarded to ``batch_preprocessing``.

    Returns:
        A ``tf.data.Dataset`` of ``(images, labels)`` batches.
    """
    return (
        dataset_generator(filenames, celeba_hq_reader, shuffle, buffer_size)
        # drop_remainder=True keeps the batch dimension static (TPU-friendly).
        .batch(batch_size, drop_remainder=True)
        .map(batch_preprocessing(add_noise), num_parallel_calls=AUTOTUNE)
        .prefetch(AUTOTUNE)
    )

In [22]:
# Smoke-test the input pipeline: one 25-image batch (exactly fills show_images'
# default 5x5 grid), read from the first TFRecord shard only.
sample_dataset = get_celeba_hq_dataset_generator(celeba_hq_filenames[0], batch_size=25)
sample_x, sample_y = next(sample_dataset.as_numpy_iterator())

# denormalize presumably maps back to [0, 255]; dividing gives imshow floats in [0, 1].
show_images(processing.denormalize(sample_x) / 255.0)

# **Build Model**

In [23]:
def generator_wgan_loss(*args, **kwargs):
    """Generator WGAN loss: forwards all arguments to ``functional.generator_wgan_loss``."""
    return functional.generator_wgan_loss(*args, **kwargs)

def discriminator_wgan_loss(*args, **kwargs):
    """Discriminator WGAN loss: forwards all arguments to ``functional.discriminator_wgan_loss``."""
    return functional.discriminator_wgan_loss(*args, **kwargs)

def wgan_gradient_penalty(*args, **kwargs):
    """WGAN gradient penalty: forwards all arguments to ``functional.wgan_gradient_penalty``."""
    return functional.wgan_gradient_penalty(*args, **kwargs)

In [24]:
# Re-seed so model construction is reproducible regardless of earlier random draws.
set_seed()

# Keyword configurations handed to gan.build_stylegan below.
mapping_cfgs = dict(
    latent_dim=256,
    disentangled_latent_dim=256,
    # NOTE(review): the mapping network gets num_clusters=None while the
    # discriminator gets Config.num_clusters. Confirm this asymmetry is intended.
    num_clusters=None,
    depth=4,
    learning_rate_multiplier=0.01,
)

synthesis_cfgs = dict(
    resolution=Config.resolution,
    low_resolution=Config.low_resolution,
    constant_input_dim=256,
    in_filters=8192,
    in_decay=1.0,
    use_skip=True,
    fused=True,
    learning_rate_multiplier=1.0,
)

discriminator_cfgs = dict(
    resolution=Config.resolution,
    low_resolution=Config.low_resolution,
    num_clusters=Config.num_clusters,
    in_filters=4096,
    in_decay=1.0,
    use_skip=True,
    learning_rate_multiplier=1.0,
)

# Build and compile inside the strategy scope so the model's and optimizers'
# variables are created under the active distribution strategy (TPU or default).
with strategy.scope():
        
    # augmenter=None: the Keras `augmenter` built earlier is not wired into the model.
    # NOTE(review): the ada_* arguments presumably configure adaptive discriminator
    # augmentation; without an augmenter they may be inert. Confirm in gan.build_stylegan.
    model = gan.build_stylegan(mapping_cfgs=mapping_cfgs, synthesis_cfgs=synthesis_cfgs, discriminator_cfgs=discriminator_cfgs, 
                               batch_size=None, use_pseudo_labels=Config.use_pseudo_labels, augmenter=None, 
                               ada_target=1.0, ada_step=100, on_batch_ada=True, ada_state_estimator=None, recursive_lookup=True)

    # WGAN losses plus gradient penalty (thin wrappers around stylegan.functional).
    generator_loss = generator_wgan_loss
    discriminator_loss = discriminator_wgan_loss
    gradient_penalty = wgan_gradient_penalty
    
    # Identical Adam settings for G and D (lr 2e-3, beta_1=0, beta_2=0.99),
    # presumably the StyleGAN2 reference hyper-parameters.
    generator_optimizer = tf.keras.optimizers.Adam(learning_rate=0.002, beta_1=0.0, beta_2=0.99, epsilon=1e-8)
    discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=0.002, beta_1=0.0, beta_2=0.99, epsilon=1e-8)
    
    model.compile(generator_optimizer=generator_optimizer, discriminator_optimizer=discriminator_optimizer,
                  generator_loss=generator_loss, discriminator_loss=discriminator_loss, gradient_penalty=gradient_penalty, 
                  clip_min=None, clip_max=None, use_clip=False, run_eagerly=False)
    
    # To resume training, uncomment and point at weights written by the save cell
    # (stylegan-celebA-HQ-{experiment_id}.h5).
#     model.load_weights('../input/stylegan-celebahq-weights-v0/stylegan-celebA-HQ-0.h5')
        
# model.summary()

# **Training**

In [25]:
# Training pipeline over all CelebA-HQ shards.
# NOTE(review): Config.batch_size is the global batch; under a TPUStrategy it is
# split across Config.num_replicas replicas. Confirm the per-replica size is intended.
dataset = get_celeba_hq_dataset_generator(celeba_hq_filenames, batch_size=Config.batch_size)

In [26]:
# NOTE(review): clear_session() resets Keras' global state (its internal graph and
# layer-name counters), but it runs here *after* the model and optimizers were
# built and compiled. It is normally called before building models. Confirm it is
# intended here, or move it ahead of gan.build_stylegan.
backend.clear_session()

# Used to name the saved weights file.
experiment_id = Config.experiment_id

In [None]:
# No callbacks by default.
# NOTE(review): the commented-out lines call a `checkpoint` helper that is not
# defined in any visible cell; tf.keras ModelCheckpoint (imported at the top) is
# presumably what was meant.
callbacks = None

# callbacks = []
# callbacks.append(checkpoint(f'stylegan-celebA-HQ-{experiment_id}.h5'))

model.fit(dataset, epochs=Config.epochs, callbacks=callbacks)

In [None]:
# Persist trained weights; the filename embeds experiment_id, so different
# experiments write different files.
model.save_weights(f'stylegan-celebA-HQ-{experiment_id}.h5')

# **Sample Debugging**

In [None]:
# set_seed(2121)

In [None]:
# Draw 25 latents (fills the 5x5 preview grid) and render them through the model.
sample_latent = model.sample_latent(size=25)

# NOTE(review): sampling with training=True; confirm inference-mode output
# (training=False) is not what should be inspected here.
sample_gen = model([sample_latent], training=True)

show_images(processing.denormalize(sample_gen) / 255.0)

In [None]:
# Second independent draw of 25 fresh latents, to compare variety against the previous grid.
sample_latent = model.sample_latent(size=25)

# NOTE(review): sampling with training=True, same as the previous cell; confirm intended.
sample_gen = model([sample_latent], training=True)

show_images(processing.denormalize(sample_gen) / 255.0)

In [None]:
# Debug: print the min/max learned `strength` of every generator layer whose name
# contains 'noise' (such layers are assumed to expose a `strength` variable).
for layer in model.generator.layers:
    if 'noise' not in layer.name:
        continue
    print(tf.math.reduce_min(layer.strength), tf.math.reduce_max(layer.strength))