In [None]:
import os
import sys

import numpy as np  
import pandas as pd 

import sklearn
import tensorflow as tf
from tensorflow import keras

%matplotlib inline
import matplotlib as mpl
import matplotlib.pyplot as plt

import glob
import imageio

import math
import random
import time
import datetime
import shutil
from tqdm import tqdm, tqdm_notebook


from dataclasses import dataclass
from pathlib import Path
import warnings
from scipy import linalg

import xml.etree.ElementTree as ET 

import cv2
from PIL import Image
from tensorflow.keras.preprocessing.image import ImageDataGenerator, array_to_img, img_to_array, load_img

from IPython import display
from IPython.display import Image as IpyImage

from tensorflow.keras.models import Model
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Input
from tensorflow.keras.layers import Dense
from tensorflow.keras.layers import Conv2D
from tensorflow.keras.layers import Conv2DTranspose
from tensorflow.keras.layers import Flatten
from tensorflow.keras.layers import Reshape
from tensorflow.keras.layers import Dropout
from tensorflow.keras.layers import UpSampling2D
from tensorflow.keras.layers import AveragePooling2D
from tensorflow.keras.layers import LeakyReLU
from tensorflow.keras.layers import BatchNormalization
from tensorflow.keras.layers import Add
from tensorflow.keras.layers import Layer

from tensorflow.keras.optimizers import Adam
from tensorflow.keras.optimizers import RMSprop
from tensorflow.keras.metrics import Accuracy

from tensorflow.keras.models import save_model
from tensorflow.keras.models import load_model

from tensorflow.keras import backend as K

import pickle

In [None]:
# Runtime check cell: verify that TensorFlow can see a GPU and print the driver report.
# IPython line magic requesting TensorFlow 2.x (presumably the Colab-specific magic --
# confirm before running on another Jupyter runtime).
%tensorflow_version 2.x
import tensorflow as tf
# Name of the first GPU visible to TensorFlow; compared against '/device:GPU:0' below.
device_name = tf.test.gpu_device_name()
# NOTE(review): this aborts the cell when no GPU is present, although the next cell
# tries a TPU and then falls back to the default strategy -- confirm this is intended.
if device_name != '/device:GPU:0':
    raise SystemError('GPU device not found')
print('Found GPU at: {}'.format(device_name))

# IPython shell capture: runs `nvidia-smi` and stores its output lines.
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
# The word 'failed' in the captured output is treated as "no GPU attached".
if gpu_info.find('failed') >= 0:
    print('Select the Runtime > "Change runtime type" menu to enable a GPU accelerator, ')
    print('and then re-execute this cell.')
else:
    print(gpu_info)

In [None]:
# Pick a tf.distribute strategy: try to connect to a TPU first, otherwise fall back
# to TensorFlow's default strategy.
try: # detect TPUs
    # detect and init the TPU
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)

    # instantiate a distribution strategy
    strategy = tf.distribute.experimental.TPUStrategy(tpu)
except ValueError: # detect GPUs
    # Only ValueError is treated as "no TPU available"; any other exception propagates.
    #strategy = tf.distribute.MirroredStrategy() # for GPU or multi-GPU machines
    strategy = tf.distribute.get_strategy() # default strategy that works on CPU and single GPU
    #strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy() # for clusters of multi-GPU machines

# Shorthand for tf.data's autotune constant (not used elsewhere in the visible code).
AUTO = tf.data.experimental.AUTOTUNE
print("Number of accelerators: ", strategy.num_replicas_in_sync)

# Directory configuration. All paths are plain strings that are concatenated with
# file names elsewhere, so each one must keep its trailing '/'.
FishDIR='../../../CustomFish/'   # training images (*.jpg are read from here)
GenFishDIR='./'                  # not referenced in the visible code -- presumably output for generated images
GIFDIR = './GIFs/'               # frame images, GIFs and the loss plot are written here
#h5DIR = '/content/drive/My Drive/h5s/'
#h5inDIR = '/content/drive/My Drive/h5s/'
h5inDIR = './'                   # directory models / loss histories are loaded from
h5DIR = './'                     # directory models / loss histories are saved to

In [None]:
# Global hyper-parameters.
codings_size = 256   # latent vector length; default for create_gens / create_gans / toggle_train
dropoutVal = 0.4     # not referenced in the visible code
LReluAlpha = 0.2     # not referenced in the visible code (layers below hard-code alpha=0.2)

##############
### Layers ###
##############

# from 1710.10196 use mini batch standard dev on discriminator output (only needs to be done once)
# need to combine layers to fade out, inherit Add class
# fade parameter is updated in WGANGP class
class AddWithFade(Add):
    """Blend two inputs as ``(1 - fade) * inputs[0] + fade * inputs[1]``.

    ``inputs[0]`` is the lower-resolution path and ``inputs[1]`` the
    higher-resolution path. ``fade`` is meant to be driven from outside
    (see ``update_fade``), increasing linearly from 0 to 1.

    Args:
        fade: initial blend weight.
        **kwargs: forwarded to ``Add``.
    """
    def __init__(self, fade=0.0, **kwargs):
        super(AddWithFade, self).__init__(**kwargs)
        # fade will increase linearly from 0-1.
        # It must NOT be trainable: a variable created with K.variable is tracked
        # as a trainable weight, so the optimizers would receive a gradient for it
        # and move the blend weight themselves.
        self.fade = tf.Variable(fade, trainable=False, dtype=K.floatx(), name='fade_param')

    def _merge_function(self, inputs):
        """Return the fade-weighted combination of exactly two inputs."""
        assert (len(inputs) == 2)
        #input[0] = lower res layer, input[1] = higher res layer
        output = ((1. - self.fade) * inputs[0]) + (self.fade * inputs[1])
        return output

    def get_config(self):
        """Return the layer config with the current fade value as a plain float."""
        base_config = super(AddWithFade, self).get_config()
        # Store the numeric value, not the variable object, so the config can be
        # serialised and passed back to __init__(fade=...) when the model is reloaded.
        base_config["fade"] = float(K.get_value(self.fade))
        return base_config

# update the fade parameter in the model
def update_fade(model,newfade):
    """Set the fade weight of every AddWithFade layer in ``model`` to ``newfade``.

    Only the layers listed directly in ``model.layers`` are inspected.
    """
    fade_layers = (lyr for lyr in model.layers if isinstance(lyr, AddWithFade))
    for lyr in fade_layers:
        K.set_value(lyr.fade, newfade)
    
# toggle trainability of WGAN
def toggle_train(WGAN, trainTog=True, size=codings_size):
    """Freeze or unfreeze the layers of a WGAN's critic and generator.

    Args:
        WGAN: object exposing ``discriminator``, ``generator`` and ``trainFreeze``.
        trainTog: when equal to False, every layer whose ``output.shape[1]``
            differs from ``size`` is made non-trainable and ``trainFreeze`` is
            set to True. Otherwise every layer is made trainable and
            ``trainFreeze`` is set to False.
        size: the ``output.shape[1]`` value of layers that stay untouched
            while freezing.
    """
    freeze = bool(trainTog == False)
    WGAN.trainFreeze = freeze
    for network in (WGAN.discriminator, WGAN.generator):
        for lyr in network.layers:
            if not freeze:
                lyr.trainable = True
            elif lyr.output.shape[1] != size:
                lyr.trainable = False
            
# this has the generator creating realistic deviations across batches
# useful discussion: https://towardsdatascience.com/gan-ways-to-improve-gan-performance-acf37f9f59b
# "We append the similarity o(x) in one of the dense layers in the discriminator
#  to classify whether this image is real or generated. If the mode starts to collapse, 
#  the similarity of generated images increases. The discriminator can use this score 
#  to detect generated images and penalize the generator if mode is collapsing."
class MBStDev(Layer):
    """Minibatch standard deviation layer (from arXiv 1710.10196).

    Appends one extra feature map to the input. The map is filled with a single
    scalar: the standard deviation across the batch, averaged over all
    positions and channels.
    """
    def __init__(self, **kwargs):
        # small constant added under the square root to keep it away from zero
        self.smallParam = 1e-8
        super(MBStDev, self).__init__(**kwargs)

    # perform the operation
    def call(self, ins):
        """Return ``ins`` with the minibatch-stddev feature map concatenated on the last axis."""
        # mean value of every (position, channel) entry across the batch (axis 0)
        pixMean = K.mean(ins, axis=0, keepdims=True)
        # standard deviation across the batch (small param regulates singularity)
        stDev = K.sqrt(K.mean(K.square(ins - pixMean), axis=0, keepdims=True)+self.smallParam)
        # average of those standard deviations over every remaining axis -> one scalar
        meanStDev = K.mean(stDev, keepdims=True)
        # scale this up to be the size of one input feature map for each sample
        shape = K.shape(ins)
        outs = K.tile(meanStDev, (shape[0], shape[1], shape[2], 1))
        # concatenate with the output
        joinedInandOut = K.concatenate([ins, outs], axis=-1)
        return joinedInandOut

    def compute_output_shape(self, input_shape):
        """Return ``input_shape`` with the channel dimension increased by one.

        ``compute_output_shape`` is the hook name Keras looks for; the original
        method was called ``correct_output_shape`` and was therefore never used
        by the framework.
        """
        # create a copy of the input shape as a list
        input_shape = list(input_shape)
        # add one to the channel dimension (assume channels-last)
        input_shape[-1] += 1
        # convert list to a tuple
        return tuple(input_shape)

    # Backward-compatible alias for code that calls the old method name directly.
    correct_output_shape = compute_output_shape

    def get_config(self):
        """Return the base layer config (this layer adds no constructor arguments)."""
        base_config = super(MBStDev, self).get_config()
        return base_config

# from 1710.10196 - use pixel normalization, a variant of "local response normalization"
# Used to "disallow the scenario where the magnitudes in the generator and discriminator spiral out 
# of control as a result of competition" - apply BEFORE activation function in generator only
class PixelNorm(Layer):
    """Pixel-wise feature normalisation.

    Divides every feature vector by the square root of the mean, over the last
    axis, of ``ins**2 + smallParam``.
    """
    def __init__(self, **kwargs):
        # keeps the denominator strictly positive
        self.smallParam = 1e-8
        super().__init__(**kwargs)

    def call(self, ins):
        """Return ``ins`` normalised over its last (filter) axis."""
        shifted_squares = ins**2 + self.smallParam
        # -1 is over the filters
        mean_square = K.mean(shifted_squares, axis=-1, keepdims=True)
        scale = K.sqrt(mean_square)
        return ins / scale

    def get_config(self):
        """Return the base layer config; the layer has no constructor arguments of its own."""
        return super().get_config()

# this kernel implements equalized learning rate
# adapted from https://github.com/keras-team/keras/blob/master/keras/layers/convolutional.py
class Conv2DELR(Conv2D):
    """Conv2D with an equalised learning rate: the kernel is rescaled at call time.

    The kernel is initialised from N(0, 1) by default and multiplied by a
    constant ``c = sqrt(2 / n)`` on every forward pass.

    Args:
        *args: positional arguments forwarded to ``Conv2D``.
        cHe: scaling constant carried through ``get_config``. ``build``
            always recomputes the constant, so this only matters before the
            layer is built.
        kernel_initializer: kernel initialiser; ``None`` (the default) means a
            fresh ``RandomNormal(stddev=1.)`` for this layer.
        **kwargs: keyword arguments forwarded to ``Conv2D``.
    """
    def __init__(self, *args, cHe=None, kernel_initializer=None, **kwargs):
        if kernel_initializer is None:
            # Build a new initialiser per layer instead of sharing one default
            # instance (created once at class-definition time) between every
            # Conv2DELR in the notebook.
            kernel_initializer = tf.keras.initializers.RandomNormal(stddev=1.)
        super().__init__(*args, kernel_initializer=kernel_initializer, **kwargs)
        # Always define the attribute so get_config() works on an unbuilt layer.
        self.c = cHe

    def build(self, input_shape):
        """Create the weights and compute the kernel scaling constant."""
        super().build(input_shape)
        # The number of inputs.
        # NOTE(review): this is the product of every non-batch input dimension
        # (so it includes the spatial size), not kernel_h * kernel_w * in_channels.
        # Kept as in the original -- confirm this is the intended fan-in.
        n = np.prod([int(val) for val in input_shape[1:]])
        # He initialisation constant
        self.c = np.sqrt(2/n)

    def call(self, inputs):
        """Convolve ``inputs`` with the rescaled kernel, then apply bias and activation."""
        outputs = K.conv2d(
            inputs,
            self.kernel*self.c, # scale kernel
            strides=self.strides,
            padding=self.padding,
            data_format=self.data_format,
            dilation_rate=self.dilation_rate)

        if self.use_bias:
            outputs = K.bias_add(
                outputs,
                self.bias,
                data_format=self.data_format)

        if self.activation is not None:
            return self.activation(outputs)
        return outputs

    def get_config(self):
        """Return the Conv2D config extended with the scaling constant ``cHe``."""
        base_config = super().get_config()
        # Plain float (or None when the layer has not been built yet).
        base_config['cHe'] = None if self.c is None else float(self.c)
        return base_config

#############
### Model ###
#############

# from https://keras.io/examples/generative/wgan_gp/
# Define the loss functions to be used for the discriminator
# This should be (fake_loss - real_loss)
# We will add the gradient penalty later to this loss function
def critic_loss(real_img, fake_img):
    """Critic loss: mean fake score minus mean real score, plus a drift term.

    Args:
        real_img: critic outputs for real images.
        fake_img: critic outputs for generated images.

    Returns:
        ``mean(fake_img) - mean(real_img) + 0.001 * tf.nn.l2_loss(real_img)``.
        The gradient penalty is added separately by the caller.
    """
    # this is a simple implementation of the drift loss
    epsilonDrift = 0.001
    mean_fake = tf.reduce_mean(fake_img)
    mean_real = tf.reduce_mean(real_img)
    drift_term = tf.reduce_mean(tf.nn.l2_loss(real_img))
    epsilonLoss = epsilonDrift * drift_term
    return mean_fake - mean_real + epsilonLoss

# Define the loss functions to be used for generator
def generator_loss(fake_img):
    """Generator loss: the negated mean critic score of the generated images."""
    mean_score = tf.reduce_mean(fake_img)
    return -mean_score



## learning rate modifications
def set_lr(opt, lr):
    """Assign ``lr`` to the ``lr`` attribute of optimizer ``opt``."""
    setattr(opt, "lr", lr)

def adapt_lr_GAN(optg, optd, histg, histd, lrmax=0.001, lrmin=0.00001):
    """Heuristically adapt both learning rates from the recent loss histories.

    Only the last 20 entries of each history are used; nothing happens until
    both histories hold more than 20 entries.

    Args:
        optg: generator optimizer (its ``lr`` attribute is read and written).
        optd: critic optimizer (its ``lr`` attribute is read and written).
        histg: sequence of generator losses.
        histd: sequence of critic losses.
        lrmax: upper bound applied when a learning rate is doubled.
        lrmin: lower bound applied when a learning rate is halved.

    Returns:
        True when the "Cond 2" branch fired (the caller should pause the fade),
        otherwise False.
    """
    pauseFade = False
    # The original tested an undefined name (`hist1`), which raised NameError.
    if min(len(histg), len(histd)) > 20:
        # Convert to arrays so that plain Python lists can be subtracted element-wise.
        rhg = np.asarray(histg[-20:], dtype=float)
        rhd = np.asarray(histd[-20:], dtype=float)
        hdiff = rhg - rhd
        avg = np.mean(rhg)
        avd = np.mean(rhd)
        avdiff = np.mean(hdiff)
        # change = mean of the newest 10 entries minus mean of the oldest 10
        changeg = np.mean(rhg[-10:])-np.mean(rhg[:10])
        changed = np.mean(rhd[-10:])-np.mean(rhd[:10])
        changediff = np.mean(hdiff[-10:])-np.mean(hdiff[:10])
        stdg = np.std(rhg)
        stdd = np.std(rhd)
        stddiff = np.std(hdiff)
        if 3*stdg > stddiff:
            optg.lr = max(optg.lr/2,lrmin)
            optd.lr = max(optd.lr/2,lrmin)
            print("Cond 1 adapt: {}".format(optg.lr))
        elif changeg > 5 and avg > 30:
            optg.lr = min(2*optg.lr,lrmax)
            # The original doubled optg.lr a second time and never touched optd.
            optd.lr = min(2*optd.lr,lrmax)
            print("Cond 2 adapt: {}".format(optg.lr))
            pauseFade = True
        elif stdg > changeg and avd < 0:
            optg.lr = max(optg.lr/2,lrmin)
            optd.lr = max(optd.lr/2,lrmin)
            print("Cond 3 adapt: {}".format(optg.lr))
        # Diagnostics: generator, critic and (generator - critic) statistics.
        print("avg {}; delg {}; stdg {}".format(avg,changeg, stdg))
        print("avd {}; deld {}; stdd {}".format(avd,changed, stdd))
        print("avdiff {}; deldiff {}; stddiff {}".format(avdiff,changediff, stddiff))
    return pauseFade

def adjust_lr(opt,lrmax=0.00006,lrmin=0.00002):
    """Assign a randomly drawn learning rate to ``opt.lr``.

    The new rate is ``1 / (1/lrmax + u)`` with ``u`` uniform on
    ``[0, 1/lrmin)``, so it is at most ``lrmax`` and stays above
    ``1 / (1/lrmax + 1/lrmin)`` (which is smaller than ``lrmin``).
    """
    inverse_offset = tf.random.uniform(shape=(),maxval=1./lrmin)
    inverse_lr = 1./lrmax + inverse_offset
    opt.lr = 1./inverse_lr

def adjust_lr_prob(opt,lrhigh=0.005,prob=0.003,lrmax=0.001,lrmin=0.0001):
    """Randomise ``opt.lr``, occasionally jumping to a high learning rate.

    A uniform draw in [0, 1) below ``prob`` sets ``opt.lr = lrhigh``;
    otherwise the rate is drawn by ``adjust_lr`` with ``lrmax`` / ``lrmin``.
    """
    spike_draw = tf.random.uniform(shape=(),maxval=1.)
    if spike_draw < prob:
        opt.lr = lrhigh
        return
    adjust_lr(opt,lrmax=lrmax,lrmin=lrmin)

# implementation of wasserstein loss with gradient penalty
# Useful information: https://medium.com/@jonathan_hui/gan-wasserstein-gan-wgan-gp-6a1a2aa1b490
# keras implementation: https://keras.io/examples/generative/wgan_gp/
class WGANGP(keras.Model):
    """Wasserstein GAN with gradient penalty, with fade-in bookkeeping.

    Wraps a critic (``discriminator``) and a ``generator`` and implements one
    training step: ``d_steps`` critic updates followed by one generator update.

    Args:
        discriminator: critic model.
        generator: generator model mapping latent vectors to images.
        latent_dim: length of the latent vector fed to the generator.
        discriminator_steps: number of critic updates per generator update.
        gp_weight: multiplier of the gradient penalty in the critic loss.
        nMaxFade: number of images over which the fade goes from 0 to 1.
    """
    def __init__(
        self,
        discriminator,
        generator,
        latent_dim,
        discriminator_steps=1,
        gp_weight=10.0,
        #nMaxFade = 800000 this is the recommended value for faces
        nMaxFade = 100000
    ):
        super(WGANGP, self).__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.latent_dim = latent_dim
        self.d_steps = discriminator_steps
        self.gp_weight = gp_weight
        self.nMaxFade = nMaxFade
        # current fade weight and the number of images seen while fading
        self.fade = 0.0
        self.nRunFade = 0
        # set by toggle_train; while True the fade is not advanced
        self.trainFreeze = False

    def compile(self, d_optimizer, g_optimizer, d_loss_fn, g_loss_fn):
        """Store the optimizers and loss functions used by ``train_step``."""
        super(WGANGP, self).compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.d_loss_fn = d_loss_fn
        self.g_loss_fn = g_loss_fn

    def get_config(self):
        """Return a config dict holding the constructor and compile arguments.

        The base config is taken from the class above ``keras.Model`` in the
        MRO (``super(keras.Model, self)``), bypassing ``keras.Model.get_config``.
        The optimizer / loss entries only exist after ``compile`` has run.
        """
        base_config = super(keras.Model,self).get_config()
        base_config["discriminator"] = self.discriminator
        base_config["d_optimizer"] = self.d_optimizer
        base_config["d_loss_fn"] = self.d_loss_fn
        base_config["generator"] = self.generator
        base_config["g_loss_fn"] = self.g_loss_fn
        base_config["g_optimizer"] = self.g_optimizer
        base_config["latent_dim"] = self.latent_dim
        base_config["discriminator_steps"] = self.d_steps
        base_config["gp_weight"] = self.gp_weight
        base_config["nMaxFade"] = self.nMaxFade
        return base_config

    def gradient_penalty(self, batch_size, real_images, fake_images):
        """ Calculates the gradient penalty.

        This loss is calculated on an interpolated image
        and added to the discriminator loss.
        """
        # Get the interpolated image. The interpolation weight must lie in
        # [0, 1] so that every sample sits on the segment between a real and a
        # fake image; the original drew it from a standard normal, which also
        # produced points outside that segment.
        alpha = tf.random.uniform([batch_size, 1, 1, 1], 0.0, 1.0)
        diff = fake_images - real_images
        interpolated = real_images + alpha * diff

        with tf.GradientTape() as gp_tape:
            gp_tape.watch(interpolated)
            # 1. Get the discriminator output for this interpolated image.
            pred = self.discriminator(interpolated, training=True)

        # 2. Calculate the gradients w.r.t to this interpolated image.
        grads = gp_tape.gradient(pred, [interpolated])[0]
        # 3. Calculate the norm of the gradients
        norm = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3]))
        gp = tf.reduce_mean((norm - 1.0) ** 2)
        return gp

    def train_step(self, real_images, withFade = True, adaptlr = False, randg=False, randd=False):
        """Run one training step on a batch of real images.

        Args:
            real_images: batch of real images (if a tuple is passed, its first
                element is used).
            withFade: advance the fade weight while it is below 1 and the
                model is not frozen.
            adaptlr: accepted but not used in this method.
            randg: randomise the generator learning rate via ``adjust_lr``.
            randd: randomise the critic learning rate via ``adjust_lr``.

        Returns:
            dict with the last critic loss (``"d_loss"``) and the generator
            loss (``"g_loss"``).
        """
        if isinstance(real_images, tuple):
            real_images = real_images[0]

        # Get the batch size
        batch_size = tf.shape(real_images)[0]

        #print("Batch Size {}".format(batch_size))

        # update the fade if the model is being run with fade
        # NOTE(review): batch_size is a tensor, so this Python-side bookkeeping
        # presumably relies on eager execution -- confirm how train_step is invoked.
        if self.fade < 1:
            if withFade and not self.trainFreeze:
                self.nRunFade += batch_size
                self.fade = min(self.nRunFade,self.nMaxFade)/self.nMaxFade
                update_fade(self.generator,self.fade)
                update_fade(self.discriminator,self.fade)

        # For each batch, we are going to perform the following steps:
        # 1. Train the discriminator (d_steps times):
        #    a. compute the discriminator loss from real and fake logits
        #    b. calculate the gradient penalty
        #    c. multiply the gradient penalty by a constant weight factor
        #    d. add it to the discriminator loss and apply the gradients
        # 2. Train the generator and get the generator loss
        # 3. Return generator and discriminator losses as a loss dictionary.

        # Train discriminator first. The original paper recommends training
        # the discriminator for `x` more steps (typically 5) as compared to
        # one step of the generator. Here the number of discriminator steps
        # is `self.d_steps` (constructor default: 1).
        for i in range(self.d_steps):
            # Get the latent vector
            random_latent_vectors = tf.random.normal(
                shape=(batch_size, self.latent_dim)
            )
            with tf.GradientTape() as tape:
                # Generate fake images from the latent vector
                fake_images = self.generator(random_latent_vectors, training=True)
                # Get the logits for the fake images
                fake_logits = self.discriminator(fake_images, training=True)
                # Get the logits for real images
                real_logits = self.discriminator(real_images, training=True)

                # Calculate discriminator loss using fake and real logits
                d_cost = self.d_loss_fn(real_img=real_logits, fake_img=fake_logits)
                # Calculate the gradient penalty
                gp = self.gradient_penalty(batch_size, real_images, fake_images)
                # Add the gradient penalty to the original discriminator loss
                d_loss = d_cost + gp * self.gp_weight

            if randd:
                adjust_lr(self.d_optimizer)
            # Get the gradients w.r.t the discriminator loss
            d_gradient = tape.gradient(d_loss, self.discriminator.trainable_variables)
            # Update the weights of the discriminator using the discriminator optimizer

            self.d_optimizer.apply_gradients(
                zip(d_gradient, self.discriminator.trainable_variables)
            )

        # Train the generator now.
        # Get the latent vector
        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))
        with tf.GradientTape() as tape:
            # Generate fake images using the generator
            generated_images = self.generator(random_latent_vectors, training=True)
            # Get the discriminator logits for fake images
            gen_img_logits = self.discriminator(generated_images, training=True)
            # Calculate the generator loss
            g_loss = self.g_loss_fn(gen_img_logits)

        if randg:
            adjust_lr(self.g_optimizer)
        # Get the gradients w.r.t the generator loss
        gen_gradient = tape.gradient(g_loss, self.generator.trainable_variables)
        # Update the weights of the generator using the generator optimizer

        self.g_optimizer.apply_gradients(
            zip(gen_gradient, self.generator.trainable_variables)
        )
        return {"d_loss": d_loss, "g_loss": g_loss}

# take from a model that criticises an nxnx3 figure to one that criticizes a 2nx2nx3 figure
def add_critic_level(old_c_model, new_level=8, nSkip=3,nSkipPassby=1,nColors=3,nameX=None):
    """Grow a critic so that it accepts images of side ``new_level``.

    Two models sharing one new input are returned: the plain grown critic and
    a fade-in variant that blends the new block with a down-sampled pass-by
    path through an ``AddWithFade`` layer.

    Args:
        old_c_model: critic for the previous resolution; its layers are reused.
        new_level: side length of the new input; also the key into ``leveldict``.
        nSkip: index of the first old layer appended after the new block.
        nSkipPassby: index of the first old layer applied on the pass-by path
            (old layers ``nSkipPassby .. nSkip-1`` are used there).
        nColors: number of input channels.
        nameX: suffix for the new layer names; defaults to ``str(new_level)``.

    Returns:
        ``[criticModel, criticFadeOut]``.
    """
    namebase = str(new_level) if nameX == None else nameX

    filters = leveldict[new_level]['filters']
    conv2filters = leveldict[new_level/2]['filters']

    def _run_old_layers(first, last, tensor):
        # feed `tensor` through old_c_model.layers[first:last], in order
        for idx in range(first, last):
            tensor = old_c_model.layers[idx](tensor)
        return tensor

    n_old = len(old_c_model.layers)

    # define new input layer
    layer_in = Input(shape=(new_level, new_level, nColors,),name="Input_C0_{}".format(namebase))

    # 1x1 convolution on the raw image (the pass-by path of the next level starts here)
    x = Conv2DELR(filters, (1,1), padding="SAME",name="Conv_C0_{}".format(namebase))(layer_in)
    x = LeakyReLU(alpha=0.2,name="Act_C0_{}".format(namebase))(x)
    # first layer set
    x = Conv2DELR(filters, (3,3), padding="SAME",name="Conv_C1_{}".format(namebase))(x)
    x = LeakyReLU(alpha=0.2,name="Act_C1_{}".format(namebase))(x)
    # second layer set, then halve the resolution
    x = Conv2DELR(conv2filters, (3,3), padding="SAME",name="Conv_C2_{}".format(namebase))(x)
    x = LeakyReLU(alpha=0.2,name="Act_C2_{}".format(namebase))(x)
    newcriticlayers = AveragePooling2D(name="AvPool_C0_{}".format(namebase))(x)

    # plain model: new block followed by the tail of the earlier model
    criticModel = Model(layer_in, _run_old_layers(nSkip, n_old, newcriticlayers))

    # pass-by model: down-sample the input first, then reuse the old input-processing layers
    pathtofade = AveragePooling2D(name="AvPool_CPass_{}".format(namebase))(layer_in)
    pathtofade = _run_old_layers(nSkipPassby, nSkip, pathtofade)
    blended = AddWithFade(name="Fade_CPass_{}".format(namebase))([pathtofade,newcriticlayers])
    criticFadeOut = Model(layer_in, _run_old_layers(nSkip, n_old, blended))

    return [criticModel, criticFadeOut]

# take from a model that generates an nxnx3 figure to one that generates a 2nx2nx3 figure
def add_generator_level(old_g_model, new_level=8, nDrop=3,nDropPassby=1,nColors=3,nameX=None):
    """Grow a generator so that it emits images of twice the previous side.

    The old model's final layer is removed, its penultimate output is
    up-sampled, and two convolution / PixelNorm / LeakyReLU sets plus a 1x1
    output convolution are appended. A second, fade-in model blends the new
    output with the old final layer applied to the up-sampled features.

    Args:
        old_g_model: generator for the previous resolution.
        new_level: key into ``leveldict`` for the new block's filter count;
            also the default layer-name suffix.
        nDrop: not used in this function.
        nDropPassby: not used in this function.
        nColors: number of output channels.
        nameX: suffix for the new layer names; defaults to ``str(new_level)``.

    Returns:
        ``[genModel, genModelFade]``.
    """
    namebase = str(new_level) if nameX == None else nameX
    filters = leveldict[new_level]['filters']

    # reuse the old input; drop the old final layer and up-sample what precedes it
    layer_in = old_g_model.input
    penultimate = old_g_model.layers[-2].output
    sizeaugment = UpSampling2D(name="UpSample_G0_{}".format(namebase))(penultimate)

    # first layer set
    x = Conv2DELR(filters, (3,3), padding="SAME",name="Conv_G1_{}".format(namebase))(sizeaugment)
    x = PixelNorm(name="Pix_G1_{}".format(namebase))(x)
    x = LeakyReLU(alpha=0.2,name="Act_G1_{}".format(namebase))(x)
    # second layer set
    x = Conv2DELR(filters, (3,3), padding="SAME",name="Conv_G2_{}".format(namebase))(x)
    x = PixelNorm(name="Pix_G2_{}".format(namebase))(x)
    x = LeakyReLU(alpha=0.2,name="Act_G2_{}".format(namebase))(x)

    # 1x1 convolution down to the colour channels
    conv_out = Conv2DELR(nColors, (1,1), padding="SAME",name="Conv_Gout_{}".format(namebase))(x)
    genModel = Model(layer_in, conv_out)

    # pass-by model: the old final layer is applied to the up-sampled features
    old_output_layer = old_g_model.layers[-1]
    sizeaugment_fade = old_output_layer(sizeaugment)
    conv_out_fade = AddWithFade(name="Fade_GPass_{}".format(namebase))([sizeaugment_fade,conv_out])
    genModelFade = Model(layer_in, conv_out_fade)

    return [genModel, genModelFade]

# nLevels = number of pixelation levels (4,8,16,32,64,128) 
def create_critics(nColors=3, initialSize = 4, nLevels = 6):
    """Build the critics for every resolution level.

    Args:
        nColors: number of image channels.
        initialSize: side length of the lowest-resolution images.
        nLevels: number of resolution levels (e.g. 4, 8, 16, 32, 64, 128).

    Returns:
        List with one ``[critic, critic_with_fade]`` pair per level. For the
        first level both entries are the same model.
    """
    # first create the lowest level critic
    thisleveldict = leveldict[initialSize]
    filters = thisleveldict['filters']

    #define new input layer
    layer_in = Input(shape=(initialSize, initialSize, nColors,))
    # new round start passby here
    # define new input processing layer
    conv_0 = Conv2DELR(filters, (1,1), padding="SAME")(layer_in)
    act_0 = LeakyReLU(alpha=0.2)(conv_0)
    # first layer set (newround start main model here)
    # apply minibatch standard deviation
    miniBDev = MBStDev()(act_0)
    conv_1 = Conv2DELR(filters, (3,3), padding="SAME")(miniBDev)
    act_1 = LeakyReLU(alpha=0.2)(conv_1)
    # second layer set (the end is a 4x4 convolution)
    conv_2 = Conv2DELR(filters, (4,4), padding="SAME")(act_1)
    act_2 = LeakyReLU(alpha=0.2)(conv_2)
    dense_out = Flatten()(act_2)
    out_classifier = Dense(1)(dense_out)

    # define the model (compilation happens later, in create_gans)
    initial_critic = Model(layer_in,out_classifier)

    # collect all models
    modellist = [[initial_critic,initial_critic]]
    # Start from the requested base resolution; the original hard-coded 4 here,
    # which only matched the default initialSize.
    curlevel = initialSize
    for _ in range(1,nLevels):
        curlevel = curlevel * 2
        # modellist[-1][0] corresponds to the version with no fade
        newmodels = add_critic_level(modellist[-1][0], new_level=curlevel)
        modellist.append(newmodels)
    return modellist


# nLevels = number of pixelation levels (4,8,16,32,64,128) 
def create_gens(nInputs = codings_size, nColors=3, initialSize = 4, nLevels = 6):
    """Build the generators for every resolution level.

    Args:
        nInputs: length of the latent input vector.
        nColors: number of image channels.
        initialSize: side length of the lowest-resolution images.
        nLevels: number of resolution levels (e.g. 4, 8, 16, 32, 64, 128).

    Returns:
        List with one ``[generator, generator_with_fade]`` pair per level. For
        the first level both entries are the same model.
    """
    # first create the lowest level generator
    thisleveldict = leveldict[initialSize]
    filters = thisleveldict['filters']
    # N(0, 1) weight initialization for the dense projection
    kinit = tf.keras.initializers.RandomNormal(stddev=1.)

    #define new input layer
    layer_in = Input(shape=(nInputs,))
    dense_0 = Dense(nInputs*initialSize*initialSize, kernel_initializer=kinit)(layer_in)
    reshape_0 = Reshape((initialSize, initialSize, nInputs))(dense_0)
    #may want to add activation functions
    # first layer set (start with 4x4)
    conv_1 = Conv2DELR(filters, (4,4), padding="SAME")(reshape_0)
    pixnorm_1 = PixelNorm()(conv_1)
    act_1 = LeakyReLU(alpha=0.2)(pixnorm_1)
    # second layer set
    conv_2 = Conv2DELR(filters, (3,3), padding="SAME")(act_1)
    pixnorm_2 = PixelNorm()(conv_2)
    act_2 = LeakyReLU(alpha=0.2)(pixnorm_2)
    # save for combo below
    conv_out = Conv2DELR(nColors, (1,1), padding="SAME")(act_2)
    genModel = Model(layer_in, conv_out)

    # collect all models
    modellist = [[genModel,genModel]]
    # Start from the requested base resolution; the original hard-coded 4 here,
    # which only matched the default initialSize.
    curlevel = initialSize
    for _ in range(1,nLevels):
        curlevel = curlevel * 2
        # modellist[-1][0] corresponds to the version with no fade
        newmodels = add_generator_level(modellist[-1][0], new_level=curlevel)
        modellist.append(newmodels)
    return modellist

def create_gans(gens,crits,nInputs = codings_size):
    """Pair generators and critics into compiled WGANGP models.

    Args:
        gens: list of ``[generator, generator_with_fade]`` pairs.
        crits: list of ``[critic, critic_with_fade]`` pairs, same length as ``gens``.
        nInputs: latent dimension passed to every WGANGP.

    Returns:
        List of ``[wgan, wgan_with_fade]`` pairs, one per level. Every model is
        compiled with the module-level ``optimizercrit`` / ``optimizergen``
        objects, so all of them share the same two optimizers.
    """
    assert len(gens)==len(crits), "Generators and Discriminators created with different lengths"

    def _compiled_wgan(gen, crit):
        # build one WGANGP and attach the shared optimizers and loss functions
        wgan = WGANGP(discriminator=crit,generator=gen,latent_dim=nInputs,discriminator_steps=1)
        wgan.compile(d_optimizer=optimizercrit,g_optimizer=optimizergen,g_loss_fn=generator_loss,d_loss_fn=critic_loss)
        return wgan

    ganlist = []
    for gen_pair, crit_pair in zip(gens, crits):
        with strategy.scope():
            # index 0 = standard models, index 1 = fade-in models
            plain = _compiled_wgan(gen_pair[0], crit_pair[0])
            faded = _compiled_wgan(gen_pair[1], crit_pair[1])
            ganlist.append([plain, faded])
    return ganlist
 
########################
### Image processing ###
########################
# Augmentation settings for the shared image generator used by multiprocessingI.
_augmentation_options = {
    "rotation_range": 25,
    "width_shift_range": 0.05,
    "height_shift_range": 0.1,
    "shear_range": 10,
    "zoom_range": [0.95,1.2],
    "brightness_range": [0.8,1.4],
    "horizontal_flip": True,
    "fill_mode": 'nearest',
}
datagen = ImageDataGenerator(**_augmentation_options)

# adjust the saturation of the image
def colorSaturationAdjust(img, satRange=[0.6,1.4]):
    """Scale each of the three colour channels by its own random factor.

    Three factors are drawn with ``random.triangular(satRange[0], satRange[1])``,
    multiplied onto the last axis of ``img``, and the result is clipped to [0, 1].
    """
    low, high = satRange[0], satRange[1]
    # one draw per channel, in R, G, B order
    channel_gains = [random.triangular(low, high) for _ in range(3)]
    scaled = img * channel_gains
    return np.clip(scaled, 0, 1)

def multiprocessingI(imglist, new_length=None, batch_size=32):
    """Augment one batch of images and optionally resize it.

    This merges the two definitions that previously shared this name (the
    second silently replaced the first). Calling it with only ``imglist``
    behaves like the first definition; passing ``new_length`` and
    ``batch_size`` behaves like the second.

    Args:
        imglist: array of images with values in [-1, 1].
        new_length: when given, every augmented image is resized to
            ``(new_length, new_length)`` with ``cv2.INTER_AREA``.
        batch_size: batch size handed to ``datagen.flow`` (32 is that
            method's own default).

    Returns:
        Array holding the first augmented batch only, with values in [-1, 1].
    """
    # Only the first batch produced by the generator is used.
    first_batch = datagen.flow((imglist+1.)/2.000,batch_size=batch_size)[0]
    imglistproc = []
    for img in first_batch:
        # generator output is divided by 255, colour-jittered, then mapped to [-1, 1]
        newimg = np.clip(2.*(colorSaturationAdjust(img/255.0)-0.5),-1.,1.)
        if new_length is not None:
            newimg = cv2.resize(newimg, (new_length,new_length), interpolation = cv2.INTER_AREA)
        imglistproc.append(newimg)
    return np.asarray(imglistproc)


##################################
### Image display and file I/O ###
##################################
def plot_multiple_images(images, n_cols=None):
    """Draw ``images`` on a grid in a new matplotlib figure.

    Args:
        images: array of images; a trailing channel axis of size 1 is squeezed
            away. Values are mapped from [-1, 1] to [0, 1] and clipped for display.
        n_cols: number of grid columns; defaults to one row holding every image.
    """
    n_cols = n_cols or len(images)
    n_rows = (len(images) - 1) // n_cols + 1
    single_channel = images.shape[-1] == 1
    if single_channel:
        images = np.squeeze(images, axis=-1)
    plt.figure(figsize=(n_cols, n_rows))
    plt.subplots_adjust(hspace=0.03, wspace=0)
    for position, image in enumerate(images, start=1):
        plt.subplot(n_rows, n_cols, position,snap=True)
        displayable = np.clip((image + 1.)/2.,0.,1.)
        plt.imshow(displayable, cmap="binary")
        plt.axis("off")

def read_image(src):
    """Read the image at ``src`` and return it in RGB channel order.

    Raises:
        FileNotFoundError: if ``cv2.imread`` returns ``None`` for ``src``;
            the offending path is part of the exception message.
    """
    img = cv2.imread(src)
    if img is None:
        # Put the path into the exception itself instead of only printing it,
        # so it is still visible when the error is caught or logged elsewhere.
        raise FileNotFoundError("could not read image: {}".format(src))
    # OpenCV loads BGR; convert to RGB
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    return img

def write_image(img,filename):
    """Write an RGB image to ``filename`` (converted to BGR for OpenCV first)."""
    bgr_image = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    cv2.imwrite(filename, bgr_image)

def constructGIF(ix,filename):
    """Assemble frames ``<GIFDIR><filename>_0.jpg`` .. ``_<ix-1>.jpg`` into ``<GIFDIR><filename>.gif``."""
    base_path = GIFDIR+filename
    frame_template = base_path+"_{i}.jpg"
    with imageio.get_writer(base_path+".gif", mode='I') as writer:
        frame_index = 0
        while frame_index < ix:
            frame = imageio.imread(frame_template.format(i=frame_index))
            writer.append_data(frame)
            frame_index += 1

#IpyImage(filename=GIFDIR+modelname+"_test.gif")

def plot_history(d_hist, g_hist):
    """Plot critic and generator loss histories, save the figure, then show it.

    The plot is written to ``GIFDIR + '/plot_line_plot_loss.png'``.
    """
    plt.plot(d_hist, label='crit')
    plt.plot(g_hist, label='gen')
    plt.legend()
    # Save BEFORE showing: once plt.show() has displayed the figure, a later
    # savefig can write an empty canvas.
    plt.savefig(GIFDIR+'/plot_line_plot_loss.png')
    plt.show()
    plt.close()
    
def saveh5(model, name):
    """Save ``model`` to ``<h5DIR><name>.h5``."""
    model.save(h5DIR + name + '.h5')

def saveh5s(model, name):
    """Save a GAN's generator and critic as ``gen_<name>.h5`` and ``crit_<name>.h5``."""
    for prefix, attribute in (('gen_', 'generator'), ('crit_', 'discriminator')):
        saveh5(getattr(model, attribute), prefix + name)

def write_images_and_GIF(imgsetforGIF,modelname):
    """Render every image set as a frame file and combine the frames into a GIF.

    Frame ``i`` is saved as ``<GIFDIR><modelname>_<i>.jpg`` (8 images per grid
    row); the GIF is then built by ``constructGIF``.
    """
    frame_template = GIFDIR+modelname+"_{i}.jpg"
    frame_count = 0
    for imgs in imgsetforGIF:
        plot_multiple_images(imgs, 8)
        plt.savefig(frame_template.format(i=frame_count))
        plt.close()
        frame_count += 1
    constructGIF(len(imgsetforGIF),modelname)

# Custom layer classes Keras needs in order to deserialise the saved models.
customOs={
    'LeakyReLU': LeakyReLU,
    'Conv2DELR': Conv2DELR,
    'PixelNorm':PixelNorm,
    'AddWithFade':AddWithFade,
    'MBStDev':MBStDev,
}
def loadh5s(name):
    """Load ``gen_<name>.h5`` and ``crit_<name>.h5`` from ``h5inDIR``.

    Returns:
        ``(generator, critic)``.
    """
    def _load(prefix):
        # one saved model, resolved against the input directory
        return load_model(h5inDIR + prefix + name + '.h5',custom_objects=customOs)

    generator = _load('gen_')
    critic = _load('crit_')
    return generator, critic

def load_hists(tag=""):
    """Load the pickled loss histories written by ``dump_hists``.

    Reads ``g_loss_hist<tag>.dat`` and ``d_loss_hist<tag>.dat`` from ``h5inDIR``.

    Returns:
        ``(d_loss_hist, g_loss_hist)``.
    """
    def _unpickle(stem):
        # NOTE: pickle.load can execute arbitrary code; only use on files you created.
        with open(h5inDIR + stem + tag + '.dat', 'rb') as filehandle:
            return pickle.load(filehandle)

    g_loss_hist = _unpickle('g_loss_hist')
    d_loss_hist = _unpickle('d_loss_hist')
    return d_loss_hist, g_loss_hist

def dump_hists(d_loss_hist, g_loss_hist, tag=""):
    """Pickle the loss histories into ``h5DIR``.

    Writes ``g_loss_hist<tag>.dat`` and ``d_loss_hist<tag>.dat``. This single
    definition replaces the two that previously shared this name; with the
    default ``tag=""`` it writes the same files as the untagged version.

    Args:
        d_loss_hist: critic loss history.
        g_loss_hist: generator loss history.
        tag: optional suffix inserted before the ``.dat`` extension.
    """
    histories = (('g_loss_hist', g_loss_hist), ('d_loss_hist', d_loss_hist))
    for stem, history in histories:
        with open(h5DIR + stem + tag + '.dat', 'wb') as filehandle:
            # store the data as binary data stream
            pickle.dump(history, filehandle)

In [None]:
# Run configuration and random seeds.
batch_size = 32
#batch_size = 64
modelname = 'WGANGPU64'

# Seed every random source the notebook uses. Python's `random` module drives
# colorSaturationAdjust and was previously left unseeded.
tf.random.set_seed(7777)
np.random.seed(7777)
random.seed(7777)
# Fixed latent vectors: 32 samples of length codings_size.
fixednoise = tf.random.normal(shape=[32, codings_size])

In [None]:
# os.listdir returns entries in arbitrary, filesystem-dependent order; sort the
# listing so Xtrain has the same order on every run and machine (a prerequisite
# for the fixed seeds to give repeatable shuffles).
filelist = sorted(os.listdir(FishDIR))
FishFiles = [ FishDIR+i for i in filelist if i.endswith('jpg') ]
# Scale by 2/255 and shift by -1: maps 0..255 pixel values onto [-1, 1], the
# range of the generator's tanh output (assumes read_image yields 0..255 values).
Xtrain = np.stack([ read_image(i) * np.float32(2. / 255.) - 1  for i in FishFiles ])

# Input pipeline: shuffle with a 1000-element buffer, emit full batches only
# (drop_remainder=True) and prefetch one batch.
dataset = tf.data.Dataset.from_tensor_slices(Xtrain)
dataset = dataset.shuffle(1000)
dataset = dataset.batch(batch_size, drop_remainder=True).prefetch(1)

In [None]:
# Fade-in configuration: do we fade and if so how much?
# NOTE(review): none of these three names is read in the cells shown below;
# presumably they are consumed by the model/training code defined earlier.
setFade = False
fadenImages = 1000000 #800k is NVIDIAs choice
netEpochs = 1000

# Resize without ELR, shorter path

In [None]:
# Generator: latent vector of length codings_size -> 64x64x3 image with a tanh
# output. The feature map starts at 8x8 (Reshape) and is doubled by each of the
# three UpSampling2D layers: 8x8 -> 16x16 -> 32x32 -> 64x64. The stride-1
# "SAME" transposed convolutions keep the spatial size (assuming the custom
# PixelNorm layer preserves shape).
generator = Sequential([
    Input(shape=(codings_size,)),
    Dense(8 * 8 * 128),
    Reshape([8, 8, 128]),
    # 8x8 stage
    Conv2DTranspose(128, (5,5), padding="SAME"),
    PixelNorm(),
    LeakyReLU(alpha=0.2),
    Conv2DTranspose(128, (4,4), padding="SAME"),
    PixelNorm(),
    LeakyReLU(alpha=0.2),
    # image 8x8 -> 16x16
    UpSampling2D(),
    Conv2DTranspose(128, (5,5), padding="SAME"),
    PixelNorm(),
    LeakyReLU(alpha=0.2),
    # image 16x16 -> 32x32
    UpSampling2D(),
    Conv2DTranspose(128, (5,5),  padding="SAME"),
    PixelNorm(),
    LeakyReLU(alpha=0.2),
    Conv2DTranspose(128, (5,5), padding="SAME"),
    PixelNorm(),
    LeakyReLU(alpha=0.2),
    # image 32x32 -> 64x64
    UpSampling2D(),
    Conv2DTranspose(128, (5,5),  padding="SAME"),
    PixelNorm(),
    LeakyReLU(alpha=0.2),
    # project to 3 colour channels at 64x64
    Conv2DTranspose(3, (5,5), padding="SAME",activation="tanh")
])
# Critic: 64x64x3 image -> one unbounded score. Each stride-2 convolution halves
# the spatial size: 64 -> 32 -> 16 -> 8 -> 4. dropoutVal is defined elsewhere in
# the notebook.
discriminator = keras.models.Sequential([
    Input(shape=(64, 64, 3,)),
    Conv2D(128,  (5,5), padding="SAME"),
    LeakyReLU(alpha=0.2),
    Dropout(dropoutVal),
    # image 64x64 -> 32x32
    Conv2D(128,  (4,4), strides=2, padding="SAME"),
    LeakyReLU(alpha=0.2),
    Dropout(dropoutVal),
    # image 32x32 -> 16x16
    Conv2D(128,  (4,4), strides=2, padding="SAME"),
    LeakyReLU(alpha=0.2),
    Dropout(dropoutVal),
    Conv2D(128,  (3,3), padding="SAME"),
    LeakyReLU(alpha=0.2),
    Dropout(dropoutVal),
    # image 16x16 -> 8x8
    Conv2D(128,  (4,4), strides=2, padding="SAME"),
    LeakyReLU(alpha=0.2),
    Dropout(dropoutVal),
    # image 8x8 -> 4x4
    Conv2D(256,  (4,4), strides=2, padding="SAME"),
    LeakyReLU(alpha=0.2),
    Dropout(dropoutVal),
    Conv2D(256,  (3,3), padding="SAME"),
    LeakyReLU(alpha=0.2),
    Dropout(dropoutVal),
    Conv2D(256,  (4,4), padding="SAME"),
    LeakyReLU(alpha=0.2),
    # custom layer defined elsewhere (presumably minibatch standard deviation)
    MBStDev(),
    Flatten(),
    Dense(1) # note: use WGAN -- linear output, no sigmoid
])

# Separate Adam optimizers for critic and generator (beta_1=0, beta_2=0.99).
# `learning_rate` is the supported argument name; `lr` is only a deprecated
# alias and is rejected by newer Keras releases.
optimizercrit=Adam(learning_rate=0.001, beta_1=0, beta_2=0.99, epsilon=1e-8)
optimizergen=Adam(learning_rate=0.001, beta_1=0, beta_2=0.99, epsilon=1e-8)

# compile first
# Wrap critic and generator in the WGANGP model and attach optimizers and loss
# functions. discriminator_steps=1 here (presumably critic updates per generator
# update -- confirm in WGANGP); the next cell rebuilds the wrapper with 4.
wgan = WGANGP(discriminator=discriminator,generator=generator,latent_dim=codings_size,discriminator_steps=1)
wgan.compile(d_optimizer=optimizercrit,g_optimizer=optimizergen,g_loss_fn=generator_loss,d_loss_fn=critic_loss)

# File-name stem for the preview frames (same value as set in the config cell).
modelname = 'WGANGPU64'

# Train model

In [None]:
# compile first
# Rebuild the WGANGP wrapper around the same networks and optimizers, this time
# with discriminator_steps=4; this replaces the wrapper built in the model cell.
wgan = WGANGP(discriminator=discriminator,generator=generator,latent_dim=codings_size,discriminator_steps=4)
wgan.compile(d_optimizer=optimizercrit,g_optimizer=optimizergen,g_loss_fn=generator_loss,d_loss_fn=critic_loss)

# File-name stem for the preview frames saved by the training functions.
modelname = 'WGANGPU64'

In [None]:
# Global per-epoch loss histories; the training functions append to them.
# Re-running this cell discards the history collected so far.
d_loss_hist, g_loss_hist = list(), list()
# Frame counter start value (the training calls below pass iCount explicitly).
iCount=0

In [None]:
def train_wgan(wgan, dataset, batch_size, codings_size, image_size, n_epochs=10,iCount=0):
    """Train ``wgan`` for ``n_epochs`` passes over ``dataset``.

    Every batch goes through ``multiprocessingI`` and ``wgan.train_step`` on
    ``/device:GPU:0``. The per-epoch sums of the reported ``d_loss`` and
    ``g_loss`` are appended to the global ``d_loss_hist`` / ``g_loss_hist``.
    On every 5th epoch and on the last one, the generator output for the
    global ``fixednoise`` is plotted, saved as a numbered PNG under ``GIFDIR``
    and shown.

    Args:
        wgan: model exposing ``generator`` and ``train_step``.
        dataset: iterable of image batches.
        batch_size: forwarded to ``multiprocessingI``.
        codings_size: accepted but not used in this function.
        image_size: forwarded to ``multiprocessingI``.
        n_epochs: number of epochs to run.
        iCount: index of the first frame file; advanced locally per saved
            frame and not returned.
    """
    gen_net = wgan.generator
    for epoch in range(n_epochs):
        print("Epoch {}/{}".format(epoch + 1, n_epochs))  
        critic_total = 0
        gen_total = 0
        tick = time.time()
        for real_batch in dataset:
            with tf.device('/device:GPU:0'):
                prepared = multiprocessingI(real_batch,image_size,batch_size)
                step_losses = wgan.train_step(prepared)
            critic_total += step_losses['d_loss']
            gen_total += step_losses['g_loss']
        is_final_epoch = (epoch+1 == n_epochs)
        if epoch % 5 == 0 or is_final_epoch:
            preview = gen_net(fixednoise)
            plot_multiple_images(preview, 8)
            frame_file = (GIFDIR+modelname+"_{i}.png").format(i=iCount)
            # save before show: showing first leaves the saved file blank
            plt.savefig(frame_file)
            plt.show()
            plt.close()
            iCount += 1
        d_loss_hist.append(critic_total)
        g_loss_hist.append(gen_total)
        print('Epoch took {time:.3f} seconds: d_loss={dl:.3f},  g_loss={gl:.3f}'.format(time=time.time() - tick, dl = critic_total, gl = gen_total)) 
        
def update_2nd_gen(avmodel, newmodel, decayconst = 0.999):
    """Blend ``newmodel``'s weights into ``avmodel`` as an exponential moving average.

    Each weight array of ``avmodel`` becomes
    ``decayconst * old + (1 - decayconst) * new`` and the result is written
    back with ``set_weights``. ``newmodel`` is left untouched.

    Args:
        avmodel: model holding the running average; updated in place.
        newmodel: model supplying the latest weights.
        decayconst: fraction of the old average kept at each update.
    """
    averaged = avmodel.get_weights()
    latest = newmodel.get_weights()
    blended = [
        decayconst * old + (1.-decayconst)*latest[idx]
        for idx, old in enumerate(averaged)
    ]
    avmodel.set_weights(blended)
    
def train_wgan_withav(wgan, dataset, batch_size, codings_size, image_size, avmodel=None, n_epochs=10,iCount=0,randd=False,randg=False):
    """Train ``wgan`` while maintaining a weight-averaged copy of its generator.

    After every batch the averaged model is updated with ``update_2nd_gen``.
    The per-epoch sums of ``d_loss`` and ``g_loss`` are appended to the global
    ``d_loss_hist`` / ``g_loss_hist``. The live generator's output for the
    global ``fixednoise`` is shown on every 5th epoch and on the last one;
    the averaged generator's output is shown and saved as a numbered PNG under
    ``GIFDIR`` on every 10th epoch (9, 19, ...) and on the last one.

    Args:
        wgan: model exposing ``generator`` and ``train_step``.
        dataset: iterable of image batches.
        batch_size: forwarded to ``multiprocessingI``.
        codings_size: accepted but not used in this function.
        image_size: forwarded to ``multiprocessingI``.
        avmodel: existing averaged generator to continue from; when ``None`` a
            clone of ``wgan.generator`` with identical weights is created.
        n_epochs: number of epochs to run.
        iCount: index of the first saved frame; advanced locally only.
        randd: forwarded to ``wgan.train_step``.
        randg: forwarded to ``wgan.train_step``.

    Returns:
        The averaged generator model.
    """
    live_gen = wgan.generator
    if avmodel is None: 
        # start the average from an exact copy of the current generator
        avmodel = keras.models.clone_model(wgan.generator)
        avmodel.set_weights(wgan.generator.get_weights())
    for epoch in range(n_epochs):
        print("Epoch {}/{}".format(epoch + 1, n_epochs))  
        critic_total = 0
        gen_total = 0
        tick = time.time()
        for real_batch in dataset:
            prepared = multiprocessingI(real_batch,image_size,batch_size)
            step_losses = wgan.train_step(prepared,randd=randd,randg=randg)
            critic_total += step_losses['d_loss']
            gen_total += step_losses['g_loss']
            # refresh the moving average once per batch
            update_2nd_gen(avmodel,wgan.generator)
        is_final_epoch = (epoch+1 == n_epochs)
        if epoch % 5 == 0 or is_final_epoch:
            # live generator: display only, nothing is written to disk
            live_preview = live_gen(fixednoise)
            plot_multiple_images(live_preview, 8)
            plt.show()
            plt.close()
        if epoch % 10 == 9 or is_final_epoch:
            averaged_preview = avmodel(fixednoise)
            plot_multiple_images(averaged_preview, 8)
            frame_file = (GIFDIR+modelname+"_{i}.png").format(i=iCount)
            # save before show: showing first leaves the saved file blank
            plt.savefig(frame_file)
            plt.show()
            plt.close()
            iCount += 1
        d_loss_hist.append(critic_total)
        g_loss_hist.append(gen_total)
        print('Epoch took {time:.3f} seconds: d_loss={dl:.3f},  g_loss={gl:.3f}'.format(time=time.time() - tick, dl = critic_total, gl = gen_total)) 
    return avmodel

In [None]:
# Stage 1: learning rate 1e-4 for both optimizers, 100 epochs. No avmodel is
# passed, so train_wgan_withav starts the averaged generator as a clone of the
# current generator. Saved frames are numbered from 0 (one per 10 epochs).
wgan.d_optimizer.lr = 0.0001
wgan.g_optimizer.lr = 0.0001
avgan = train_wgan_withav(wgan, dataset, batch_size, codings_size, 64, n_epochs=100,iCount=0)

In [None]:
# Stage 2: learning rate 5e-5, 300 epochs. Frames continue at index 10 because
# the 100-epoch stage before it saves 10 frames (one per 10 epochs).
wgan.d_optimizer.lr = 0.00005
wgan.g_optimizer.lr = 0.00005
# Pass the existing averaged generator, as every later stage does; without
# avmodel the function would re-clone the live generator and discard the
# average accumulated so far.
avgan = train_wgan_withav(wgan, dataset, batch_size, codings_size, 64, avmodel=avgan, n_epochs=300,iCount=10)

In [None]:
# Stage 3: learning rate 2.5e-5, 600 epochs, continuing the averaged generator.
# Frames continue at index 40 (10 + 30 frames from the first two stages).
wgan.d_optimizer.lr = 0.000025
wgan.g_optimizer.lr = 0.000025
avgan = train_wgan_withav(wgan, dataset, batch_size, codings_size, 64, avmodel=avgan, n_epochs=600,iCount=40)

In [None]:
# Stage 4: rebuild and recompile the wrapper around the same networks and the
# same optimizer objects, then train 200 epochs with randg=True and randd=True
# (both flags are forwarded to wgan.train_step; their effect is defined there).
# NOTE(review): the learning rate is not set in this cell; presumably the
# optimizers keep the value from the previous stage -- confirm.
wgan = WGANGP(discriminator=discriminator,generator=generator,latent_dim=codings_size,discriminator_steps=4)
wgan.compile(d_optimizer=optimizercrit,g_optimizer=optimizergen,g_loss_fn=generator_loss,d_loss_fn=critic_loss)
avgan = train_wgan_withav(wgan, dataset, batch_size, codings_size, 64, avmodel=avgan, n_epochs=200,iCount=100,randg=True, randd=True)

In [None]:
# Stage 5: learning rate 3e-5, 200 epochs, only randg=True is forwarded to
# train_step. Frames continue at index 120.
wgan.d_optimizer.lr = 0.00003
wgan.g_optimizer.lr = 0.00003
avgan = train_wgan_withav(wgan, dataset, batch_size, codings_size, 64, avmodel=avgan, n_epochs=200,iCount=120,randg=True)

In [None]:
# Stage 6: learning rate 2.5e-5, 300 epochs, randd/randg back at their False
# defaults. Frames continue at index 140.
wgan.d_optimizer.lr = 0.000025
wgan.g_optimizer.lr = 0.000025
avgan = train_wgan_withav(wgan, dataset, batch_size, codings_size, 64, avmodel=avgan, n_epochs=300,iCount=140)

In [None]:
# Stage 7: learning rate 2e-5, 600 epochs. Frames continue at index 170; this
# stage saves 60 frames, so the last frame written is number 229.
wgan.d_optimizer.lr = 0.00002
wgan.g_optimizer.lr = 0.00002
avgan = train_wgan_withav(wgan, dataset, batch_size, codings_size, 64, avmodel=avgan, n_epochs=600,iCount=170)

In [None]:
# Checkpoint the run under the tag modname: generator and critic (saveh5s), the
# averaged generator, a tar.gz archive of the GIFs directory, and the pickled
# loss histories (modname is used as the file-name tag, with no separator).
modname="64_WGAN-GP"
saveh5s(wgan,modname)
saveh5(avgan,'avgen'+modname)
!tar -czf gif64WGAN.tar.gz GIFs
dump_hists(d_loss_hist,g_loss_hist,modname)

In [None]:
modname="64_WGAN-GP_a"
!tar -czf gif64_a.tar.gz GIFs
dump_hists(d_loss_hist,g_loss_hist,modname)
saveh5(wgan.generator,'gen_'+modname)
saveh5(wgan.discriminator,'crit_'+modname)
saveh5(avgan,'avgen'+modname)



# Generate from an existing h5 (version saved without fading)

In [None]:
# Restore the "_a" checkpoint (generator and critic) from h5inDIR.
loadmodelname = "64_WGAN-GP_a"
generator, critic = loadh5s(loadmodelname)

# Fresh optimizers for the restored networks. `learning_rate` is the supported
# argument name; `lr` is only a deprecated alias.
optimizercrit=Adam(learning_rate=0.00003, beta_1=0, beta_2=0.99, epsilon=1e-8)
optimizergen=Adam(learning_rate=0.00003, beta_1=0, beta_2=0.99, epsilon=1e-8)

# Wrap the loaded networks so they can be trained further or sampled.
wgan = WGANGP(discriminator=critic,generator=generator,latent_dim=codings_size,discriminator_steps=4)
wgan.compile(d_optimizer=optimizercrit,g_optimizer=optimizergen,g_loss_fn=generator_loss,d_loss_fn=critic_loss)

In [None]:
# Rebuild and recompile the wrapper around the loaded generator/critic.
# NOTE(review): identical to the last two statements of the previous cell; this
# cell is redundant unless the networks or optimizers were changed in between.
wgan = WGANGP(discriminator=critic,generator=generator,latent_dim=codings_size,discriminator_steps=4)
wgan.compile(d_optimizer=optimizercrit,g_optimizer=optimizergen,g_loss_fn=generator_loss,d_loss_fn=critic_loss)

In [None]:
# Preview the restored (non-fading) model on the fixed latent batch. The model
# loaded in this section lives in `wgan`; `wganFade` is not created by any cell
# of this workflow and would fail (or show a stale model) on a fresh kernel.
generated_images = wgan.generator(fixednoise)
plot_multiple_images(generated_images, 8)

In [None]:
# Plot both loss histories, leaving out the first 25 recorded epochs.
plot_history(d_loss_hist[25:], g_loss_hist[25:])

# Testing

In [None]:
# Restore the "_b" checkpoint for the experiments below.
# NOTE(review): the save cells shown in this notebook write "64_WGAN-GP" and
# "64_WGAN-GP_a" only; confirm that the "_b" files exist in h5inDIR.
loadmodelname = "64_WGAN-GP_b"
generator, critic = loadh5s(loadmodelname)

# `learning_rate` is the supported argument name; `lr` is a deprecated alias.
optimizercrit=Adam(learning_rate=0.00003, beta_1=0, beta_2=0.99, epsilon=1e-8)
optimizergen=Adam(learning_rate=0.00003, beta_1=0, beta_2=0.99, epsilon=1e-8)

wgan = WGANGP(discriminator=critic,generator=generator,latent_dim=codings_size,discriminator_steps=4)
wgan.compile(d_optimizer=optimizercrit,g_optimizer=optimizergen,g_loss_fn=generator_loss,d_loss_fn=critic_loss)

In [None]:
# Re-seed TensorFlow and NumPy and redraw the fixed batch of 32 latent vectors,
# so this section can be run without the training cells above.
tf.random.set_seed(7777)
np.random.seed(7777)
fixednoise = tf.random.normal(shape=[32, codings_size])

In [None]:
# Images produced by the loaded generator for the fixed latent batch.
ims = generator(fixednoise)
plot_multiple_images(ims, 8)

In [None]:
# Draw 100 fresh latent vectors and plot their images. The interpolation cells
# below pick their end points from `noise` by index, so re-running this cell
# changes which images those indices refer to.
noise = tf.random.normal(shape=[100, codings_size])
ims = generator(noise)
plot_multiple_images(ims, 10)

In [None]:
# Latent-space walk along a closed loop through four hand-picked latent vectors.
pA = noise[28]
pB = noise[20]
pC = noise[74]
pD = noise[92]
# Each segment is split into nstep equal steps; v1..v4 are the per-step offsets
# for A->B, B->C, C->D and D->A.
nstep = 9
v1 = (pB-pA)/nstep
v2 = (pC-pB)/nstep
v3 = (pD-pC)/nstep
v4 = (pA-pD)/nstep

# Every segment contributes nstep+1 = 10 points including both end points, in
# the order B->C, C->D, D->A, A->B: 40 latent vectors in total.
data = [pB + i * v2 for i in range(nstep+1)]
data = data + [pC + i * v3 for i in range(nstep+1)]
data = data + [pD + i * v4 for i in range(nstep+1)]
data = data + [pA + i * v1 for i in range(nstep+1)]

# Generate the 40 images and plot them with the same layout argument (10) as
# the segment length.
imx = generator(np.array(data))
plot_multiple_images(imx, 10)

plt.savefig("RoughTransition.png")

In [None]:
# Second latent-space walk, same construction with four different end points.
pA = noise[6]
pB = noise[21]
pC = noise[34]
pD = noise[72]
# Per-step offsets for the segments A->B, B->C, C->D and D->A.
nstep = 9
v1 = (pB-pA)/nstep
v2 = (pC-pB)/nstep
v3 = (pD-pC)/nstep
v4 = (pA-pD)/nstep

# Segments in the order A->B, B->C, C->D, D->A, each with nstep+1 = 10 points
# including both end points: 40 latent vectors in total.
data = [pA + i * v1 for i in range(nstep+1)]
data = data + [pB + i * v2 for i in range(nstep+1)]
data = data + [pC + i * v3 for i in range(nstep+1)]
data = data + [pD + i * v4 for i in range(nstep+1)]

imx = generator(np.array(data))
plot_multiple_images(imx, 10)

plt.savefig("SmoothTransition.png")

# An evaluation metric

In [None]:
# larger plot labels
mpl.rc('axes', labelsize=14)
mpl.rc('xtick', labelsize=12)
mpl.rc('ytick', labelsize=12)

# Resize every training image to 64x64 with area interpolation so the real
# images have the same size as the generator output they are compared with.
Xr = list()
for X in Xtrain:
    Xr.append(cv2.resize(X, (64,64), interpolation = cv2.INTER_AREA))

# Cell output: shape of the stacked resized images.
np.array(Xr).shape

In [None]:
# Build a pool of generated images in several generator calls: 256 first, then
# five further batches of 262, 264, 266, 268 and 270 latent vectors, each
# concatenated onto Xf along axis 0 (1586 images in total).
# NOTE(review): the reason for the varying batch sizes is not evident here.
lottanoise = tf.random.normal(shape=[256, codings_size])
Xf = generator(lottanoise)
for i in range(5):
    lottanoise = tf.random.normal(shape=[256+2*i+6, codings_size])
    Xf = tf.concat([Xf, generator(lottanoise)], 0)

In [None]:
def get_pixel_dist(list1, list2, areSame=True):
    """Nearest-neighbour Euclidean (pixel) distance from each item of ``list1`` to ``list2``.

    Every pair is compared, so the cost grows with ``len(list1) * len(list2)``.

    Args:
        list1: iterable of equally shaped arrays.
        list2: iterable of arrays with the same shape as those in ``list1``.
        areSame: when True, pairs whose arrays are element-wise identical are
            skipped, so an item is not matched with itself (or an exact
            duplicate) when both arguments hold the same data.

    Returns:
        ``(distlist, rClose)``: ``distlist[k]`` is the smallest distance found
        for the k-th item of ``list1`` (it stays at the 1e10 sentinel when no
        candidate was compared); ``rClose`` is the ``[X, Y]`` pair with the
        smallest distance overall, or ``[]`` if no distance fell below the
        1e6 starting threshold.
    """
    distlist = list()
    totalmin = 1000000
    rClose = []
    for X in list1:
        mindist = 10000000000
        for Y in list2:
            # Test the cheap flag first: the full-array identity comparison is
            # only needed (and only evaluated) when areSame is set.
            if areSame and (X*1. == Y*1.).all():
                continue
            ceval = np.sqrt(np.sum((X-Y)**2))
            if ceval < mindist:
                mindist = ceval
                if mindist < totalmin:
                    rClose = [X,Y]
                    totalmin = mindist
        distlist.append(mindist)
    return distlist, rClose

In [None]:
# Nearest-neighbour distance of every real image to the other real images
# (identical pairs are skipped by the default areSame=True).
realdistlist, closest = get_pixel_dist(Xr,Xr)

In [None]:
# Nearest-neighbour distance of every generated image to the other generated
# images; Xf is converted to a NumPy array first.
fakedistlist, closestf = get_pixel_dist(np.array(Xf),np.array(Xf))

In [None]:
# Distance from every real image to its closest generated image. areSame=False
# because the two sets differ, so no pair is skipped.
crossdistlist, closestc = get_pixel_dist(Xr,np.array(Xf),areSame=False)

In [None]:
# Overlay the three nearest-neighbour distance distributions as normalised
# histograms (density=True), save the figure, then show it.
plt.hist([realdistlist,fakedistlist,crossdistlist],label=["Real-Real","Fake-Fake","Real-Fake"],density=True)
plt.xlabel("Pixel distance")
plt.ylabel("Fraction per unit pixel")
plt.legend(loc="upper right")
plt.savefig("pixelDistributions.png")
plt.show()

In [None]:
# Directory holding the numbered preview frames and the resulting GIF.
GIFDIR="./GIFs/"
def constructGIF(filename,ix):
    """Assemble numbered PNG frames into an animated GIF.

    Frames are read from ``GIFDIR/<filename>_<i>.png`` for ``i = 0 .. ix-1``
    and written to ``GIFDIR/<filename>.gif``. If a frame is missing, the GIF
    is built from the frames before it instead of failing part-way through
    with a half-written file.

    Args:
        filename: file-name stem shared by the frames and the GIF.
        ix: maximum number of frames to include.

    Returns:
        Number of frames written to the GIF.

    Raises:
        FileNotFoundError: if frames were requested but frame 0 does not exist.
    """
    gif_path = GIFDIR+filename+".gif"
    frames_path = GIFDIR+filename+"_{i}.png"
    # Collect the contiguous run of existing frames before opening the writer.
    frame_files = []
    for i in range(ix):
        frame_file = frames_path.format(i=i)
        if not os.path.exists(frame_file):
            break
        frame_files.append(frame_file)
    if ix > 0 and not frame_files:
        raise FileNotFoundError(frames_path.format(i=0))
    if len(frame_files) < ix:
        print("constructGIF: found {} of {} requested frames; using those.".format(len(frame_files), ix))
    with imageio.get_writer(gif_path, mode='I') as writer:
        for frame_file in frame_files:
            writer.append_data(imageio.imread(frame_file))
    return len(frame_files)
# 250 is an upper bound: the training stages in this notebook save frames 0..229.
constructGIF("WGANGPU64",250)
IpyImage(filename=GIFDIR+"WGANGPU64"+".gif")