In [None]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import cv2
from keras.utils import to_categorical  # Only for categorical one hot encoding
from tensorflow.keras import layers
from sklearn.metrics import accuracy_score
import tensorflow_datasets as tfds
import tensorflow_addons as tfa
from tensorflow.keras import backend as K
import time

In [None]:
# Load the 'train' split of the LFW dataset through TensorFlow Datasets.
# The resulting dataset is what the training cells below pass to trainGan as `data`.
lfw = tfds.load('lfw', split='train', shuffle_files=True)

[1mDownloading and preparing dataset lfw/0.1.0 (download: 172.20 MiB, generated: Unknown size, total: 172.20 MiB) to /root/tensorflow_datasets/lfw/0.1.0...[0m


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Dl Completed...', max=1.0, style=Progre…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Dl Size...', max=1.0, style=ProgressSty…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Extraction completed...', max=1.0, styl…









HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Shuffling and writing examples to /root/tensorflow_datasets/lfw/0.1.0.incomplete0UWD7T/lfw-train.tfrecord


HBox(children=(FloatProgress(value=0.0, max=13233.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, description='Computing statistics...', max=1.0, style=ProgressStyle(de…



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


[1mDataset lfw downloaded and prepared to /root/tensorflow_datasets/lfw/0.1.0. Subsequent calls will reuse this data.[0m


In [None]:
# Mount Google Drive (Colab only). trainGan's default `modeldir` lives under
# '/content/gdrive/My Drive/...', so this must run before any checkpoint is saved or loaded.
from google.colab import drive
drive.mount('/content/gdrive')

Mounted at /content/gdrive


In [None]:
def plotImages(imgs):
    """Show a batch of images on a grid that is 8 columns wide.

    Args:
      imgs: array/tensor indexed as imgs[i, :, :, :], with pixel values nominally
        in [-1, 1] (they are mapped back with `* 127.5 + 127.5`).

    The grid has at least 8 rows (the original fixed 8x8 layout) and grows when
    more than 64 images are passed, instead of failing on the 65th subplot.
    Pixel values are clipped to [0, 255] before the uint8 cast: the generator's
    output layer is linear, so values outside [-1, 1] would otherwise wrap around.
    """
    count = imgs.shape[0]
    cols = 8
    rows = max(8, -(-count // cols))  # ceil(count / cols), never fewer than 8 rows
    plt.figure(figsize=(8, rows))

    for i in range(count):
        plt.subplot(rows, cols, i + 1)
        pixels = tf.clip_by_value(imgs[i, :, :, :] * 127.5 + 127.5, 0., 255.)
        plt.imshow(tf.cast(pixels, tf.uint8))
        plt.axis('off')
    plt.show()

In [None]:
class PixelNorm(layers.Layer):
  """Normalizes every feature vector along the last axis to unit root-mean-square.

  Computes x * rsqrt(mean(x**2, axis=-1, keepdims=True) + epsilon).

  Args:
    epsilon: small constant added before the reciprocal square root.
    name: optional layer name; the empty string lets Keras pick a name.
    **kwargs: standard Keras Layer arguments (trainable, dtype, ...). Accepting
      them is what lets `from_config(get_config())` rebuild the layer.
  """
  def __init__(self, epsilon=1e-8, name='', **kwargs):
    if name != '':
      kwargs['name'] = name
    super(PixelNorm, self).__init__(**kwargs)
    # Plain float copy so the value can be serialized in get_config().
    self._epsilon_value = float(epsilon)
    self.epsilon = tf.constant(epsilon, dtype=tf.float32)
  
  def call(self, x):
    return x * tf.math.rsqrt(tf.reduce_mean(tf.math.square(x), axis=-1, keepdims=True) + self.epsilon)

  def get_config(self):
    """Return the base layer config (which already carries `name`) plus `epsilon`."""
    config = super(PixelNorm, self).get_config()
    config['epsilon'] = self._epsilon_value
    return config

class FadeAdd(layers.Layer):
  """Blends two tensors as new*alpha + old*(1 - alpha).

  `alpha` is a non-trainable variable that starts at 0 (output equals `old`) and is
  raised towards 1 with incrementAlpha().

  Args:
    **kwargs: standard Keras Layer arguments (name, trainable, dtype, ...), accepted
      so the layer can be re-created from its config.
  """
  def __init__(self, **kwargs):
    super(FadeAdd, self).__init__(**kwargs)
    self.alpha = tf.Variable(initial_value=0., trainable=False)
 
  def incrementAlpha(self, step=0.1):
    """Add `step` to alpha, saturating at 1."""
    self.alpha.assign(tf.minimum(self.alpha+step, 1.))
 
  def call(self, input):
    """Blend `input = [new, old]` using the current alpha."""
    new, old = input
    # Re-clamp in case alpha was assigned directly rather than through incrementAlpha().
    self.alpha.assign(tf.minimum(self.alpha, 1.))
    return (new*self.alpha) + (old*(1-self.alpha))

class MinibatchStddev(layers.Layer):
  """Appends one feature map holding the per-group standard deviation of the batch.

  The batch is split into groups of `group_size` samples (or the whole batch when it
  is smaller). For each group the standard deviation over the group is averaged over
  height, width and channels, and the resulting scalar is broadcast to an H x W x 1
  map that is concatenated onto the channel axis.

  Args:
    group_size: number of samples per statistics group. The batch size must be
      divisible by the effective group size for the reshape to succeed.
    **kwargs: standard Keras Layer arguments, accepted so the layer can be re-created
      from its config.
  """
  def __init__(self, group_size=4, **kwargs):
    super(MinibatchStddev, self).__init__(**kwargs)
    self.group_size = group_size

  def call(self, layer):
    group_size = tf.minimum(self.group_size, tf.shape(layer)[0])
    shape = tf.shape(layer)
    minibatch = tf.reshape(layer,(group_size, -1, shape[1], shape[2], shape[3]))   # G, M, H, W, C
    minibatch -= tf.reduce_mean(minibatch, axis=0, keepdims=True)
    minibatch = tf.reduce_mean(tf.math.square(minibatch), axis = 0)               # variance over the group
    # Small epsilon keeps sqrt differentiable at zero variance. The original added 1e8,
    # which swamped the variance and made the extra feature map effectively constant.
    minibatch = tf.math.sqrt(minibatch + 1e-8)
    minibatch = tf.reduce_mean(minibatch, axis=[1,2, 3], keepdims=True)           # M, 1, 1, 1
    minibatch = tf.tile(minibatch,[group_size, shape[1], shape[2], 1])            # N, H, W, 1
    return K.concatenate([layer, minibatch], axis=3)

  def get_config(self):
    """Return the base layer config plus `group_size`."""
    config = super(MinibatchStddev, self).get_config()
    config['group_size'] = self.group_size
    return config

# Feature-map budget: channel count is fmap_base / 2**(stage * fmap_decay), capped at fmap_max.
fmap_base = 4096
fmap_max = 512
fmap_decay = 1.
def nf(stage):
  """Return the number of feature maps used at `stage` (truncated to int, capped at fmap_max)."""
  uncapped = int(fmap_base / (2.0 ** (stage * fmap_decay)))
  return min(uncapped, fmap_max)

def generatorBase():
  """Build the 4x4 base generator.

  Maps a 512-dimensional latent vector through PixelNorm -> Dense -> Reshape(4, 4, C)
  -> PixelNorm -> Conv2D/LeakyReLU/PixelNorm (named 'final_4') and a linear
  3-channel Conv2D output named 'out_4'. The 'final_4' / 'out_4' names are looked up
  by generatorAddStage when the next resolution is attached.

  Returns:
    tf.keras Model taking a (512,) latent and producing a (4, 4, 3) image.
  """
  resolution = 4
  res = int(np.log2(resolution))

  # Explicit 1-tuple: `(512)` is just the int 512, not a shape tuple.
  inputLayer = layers.Input((512,))
  x = PixelNorm()(inputLayer)
  x = layers.Dense(4*4*nf(res-1), use_bias=True, kernel_initializer='random_normal', activation='relu')(x)
  x = layers.Reshape((4, 4, nf(res-1)))(x)
  x = PixelNorm()(x)

  x = layers.Conv2D(nf(res-1), kernel_size=(3, 3), padding='same', kernel_initializer='he_uniform')(x)
  x = layers.LeakyReLU(0.2)(x)
  x = PixelNorm(name='final_4')(x)

  out = layers.Conv2D(3, (4, 4), strides=(1, 1), padding='same', use_bias=True, activation='linear', name='out_4', kernel_initializer='he_uniform')(x)
  print(out.shape)
  model = tf.keras.models.Model(inputs=inputLayer, outputs=out)
  return model
 
def generatorAddStage(gen, resolution=8, freeze=False, initialAlpha=0):
  """Attach a new, higher-resolution stage to an existing generator.

  The new branch upsamples the features of layer 'final_<resolution//2>' by 2, applies
  two Conv2D/LeakyReLU/PixelNorm groups (the last named 'final_<resolution>') and a
  linear RGB Conv2D named 'out_<resolution>'. The previous RGB output
  'out_<resolution//2>' is upsampled by 2 and blended with the new output through a
  FadeAdd layer.

  Args:
    gen: current generator; must contain layers 'final_<resolution//2>' and
      'out_<resolution//2>'.
    resolution: output resolution of the new stage.
    freeze: when True, marks `gen` as non-trainable before the new layers are added.
    initialAlpha: starting blend factor of the FadeAdd layer. The default 0 keeps the
      FadeAdd default (output equals the upsampled old image). Previously this
      argument was accepted but never applied.

  Returns:
    (model, alpha): the grown generator and its FadeAdd layer.
  """
  print("Current Shape: ", gen.output.shape)

  res = int(np.log2(resolution))
 
  if freeze:
    print("Freezing")
    gen.trainable = False
  
  newDepth = nf(res - 1)
  x = gen.get_layer('final_'+str(resolution // 2)).output
  print("Choosing layer ", x)
 
  print("New Depth: ", newDepth)
  
  x = layers.UpSampling2D(size=(2,2), interpolation='nearest')(x)

  x = layers.Conv2D(newDepth, (3, 3), strides=(1, 1), padding='same', use_bias=True, kernel_initializer='he_uniform')(x)
  x = layers.LeakyReLU(0.2)(x)
  x = PixelNorm()(x)

  x = layers.Conv2D(newDepth, (3, 3), strides=(1, 1), padding='same', use_bias=True, kernel_initializer='he_uniform')(x)
  x = layers.LeakyReLU(0.2)(x)
  x = PixelNorm(name='final_'+str(resolution))(x)
 
  out = layers.Conv2D(3, (4, 4), strides=(1, 1), padding='same', use_bias=True, activation='linear', name='out_'+str(resolution), kernel_initializer='he_uniform')(x)
  print("New Shape: ", out.shape)
 
  # Upsample the previous stage's RGB output so it can be blended with the new one.
  lastOut = gen.get_layer('out_'+str(resolution // 2)).output
  up = layers.UpSampling2D((2,2), interpolation='nearest')(lastOut)
 
  alpha = FadeAdd()
  if initialAlpha != 0:
    # Honour the requested starting blend; FadeAdd.call clamps values above 1.
    alpha.alpha.assign(initialAlpha)
  out = alpha([out, up])
 
  inputLayer = gen.input
  model = tf.keras.models.Model(inputs=inputLayer, outputs=out)
  return model, alpha
 
def changeGenAlpha(gen, alpha, step=0.1):
  """Unfinished helper: computes alpha + step and looks up gen's last layer, then discards both.

  Nothing is assigned or returned (the function returns None).
  NOTE(review): presumably meant to push the new alpha into the trailing FadeAdd
  layer - confirm intent before relying on it.
  """
  bumpedAlpha = alpha + step
  lastLayer = gen.layers[-1]
 
def reBaseModel(layers, inpTensor):
  """Apply each callable in `layers`, in order, starting from `inpTensor`.

  Returns the output of the last callable, or `inpTensor` itself when `layers` is empty.
  Note that the `layers` parameter shadows the Keras `layers` module inside this function.
  """
  current = inpTensor
  for apply_layer in layers:
    current = apply_layer(current)
  return current
 
def descriminatorBase():
  """Build the 4x4 base discriminator and its encoder twin.

  A 1x1 'sup_conv_4' + 'sup_act_4' block lifts the RGB input to features; these names
  are looked up by descriminatorAddStage when a larger stage is added. The shared
  trunk (`baseLayers`) is MinibatchStddev -> Conv2D('depth_4') -> BatchNormalization
  -> LeakyReLU -> Flatten -> Dense -> LeakyReLU -> Dense(1).

  Returns:
    (dis, enc, baseLayers): `dis` ends in the Dense(1) score, `enc` stops one layer
    earlier (all of baseLayers except the last), and `baseLayers` is the trunk list so
    later stages can re-apply it.
  """
  resolution = 4
  res = int(np.log2(resolution))

  inputLayer = layers.Input((4, 4, 3))

  fromRgb = tf.keras.layers.Conv2D(nf(res-1), (1, 1), padding='same', name='sup_conv_4', kernel_initializer='he_uniform')(inputLayer)
  x = tf.keras.layers.LeakyReLU(0.2, name='sup_act_4')(fromRgb)

  # Layers are constructed in the same order as before so auto-generated names match.
  baseLayers = [
    MinibatchStddev(),
    layers.Conv2D(nf(res-1), kernel_size=(3, 3), strides=(1, 1), padding='same', name='depth_4', kernel_initializer='he_uniform'),
    layers.BatchNormalization(momentum=0.8),
    layers.LeakyReLU(0.2),
    layers.Flatten(),
    layers.Dense(nf(res-2), kernel_initializer='he_uniform'),
    layers.LeakyReLU(0.2),
    layers.Dense(1, kernel_initializer='he_uniform'),
  ]

  # Encoder output first, then the full discriminator output, both from the same input.
  encOut = reBaseModel(baseLayers[:-1], x)
  desOut = reBaseModel(baseLayers, x)

  dis = tf.keras.models.Model(inputs=inputLayer, outputs=desOut)
  enc = tf.keras.models.Model(inputs=inputLayer, outputs=encOut)
  return dis, enc, baseLayers
 
def descriminatorAddStage(des, enc, baseLayers, resolution=8, freeze=False):
  """Grow the discriminator/encoder pair to accept `resolution` x `resolution` inputs.

  A new input feeds two paths that are blended by a FadeAdd layer:
    * new path: 1x1 'sup_conv_<resolution>' + 'sup_act_<resolution>', then two 3x3
      Conv2D/LeakyReLU blocks ('depth_<resolution>', 'depth2_<resolution>') and an
      AveragePooling2D;
    * old path: the input average-pooled by 2 and passed through the previous stage's
      'sup_conv_<resolution//2>' / 'sup_act_<resolution//2>' layers taken from `des`.
  The blended tensor is then run through `baseLayers` (all of them for the
  discriminator, all but the last for the encoder).

  Args:
    des: current discriminator; must contain the 'sup_conv_'/'sup_act_' layers of
      resolution // 2.
    enc: current encoder (only used for logging and freezing here).
    baseLayers: ordered list of trunk layers to re-apply after the fade.
    resolution: new input resolution.
    freeze: when True, marks `des` and `enc` non-trainable before rebuilding.

  Returns:
    (des, enc, baseLayers, beta): the rebuilt models, the trunk list with the new
    stage's layers prepended, and the FadeAdd layer controlling the blend.
  """
  print("Current Shape: ", enc.input.shape)
 
  if freeze:
    print("Freezing")
    des.trainable = False
    enc.trainable = False
 
  print("Previous Input layer ", enc.input)
    
  newSize = resolution
 
  res = int(np.log2(resolution))

  print("New input ", newSize, newSize)
 
  inputLayer = layers.Input((newSize, newSize, 3))
  # inp = tf.keras.layers.GaussianNoise(0.00)(inputLayer) 
  inp = inputLayer

  # "from RGB" block for the new resolution; named so the next stage can find it.
  processingLayers = []
  processingLayers.append(tf.keras.layers.Conv2D(nf(res-1), (1, 1), padding='same', name='sup_conv_'+str(resolution), kernel_initializer='he_uniform'))
  # processingLayers.append(tf.keras.layers.BatchNormalization(name='sup_bn_'+str(resolution)))
  # NOTE(review): LeakyReLU is created without a slope here (Keras default), whereas
  # descriminatorBase passes 0.2 explicitly - confirm the difference is intended.
  processingLayers.append(tf.keras.layers.LeakyReLU(name='sup_act_'+str(resolution)))
  
  x = reBaseModel(processingLayers, inp)

  # New convolutional stage; ends with a 2x average pool so its output matches the
  # spatial size of the previous stage's features.
  newLayers = []
  newLayers.append(layers.Conv2D(nf(res-1), (3, 3), strides=(1, 1), padding='same', name='depth_'+str(resolution), kernel_initializer='he_uniform'))
  # newLayers.append(layers.BatchNormalization(momentum=0.8))
  newLayers.append(layers.LeakyReLU())
  newLayers.append(layers.Conv2D(nf(res-2), (3, 3), strides=(1, 1), padding='same', name='depth2_'+str(resolution), kernel_initializer='he_uniform'))
  # newLayers.append(layers.BatchNormalization(momentum=0.8))
  newLayers.append(layers.LeakyReLU())
  newLayers.append(layers.AveragePooling2D())
  
  newInp = reBaseModel(newLayers, x)
  # print(newInp, baseLayers)
 
  # Old path: downsample the image and reuse the previous stage's from-RGB layers.
  small = layers.AveragePooling2D((2, 2))(inp)
  sup = des.get_layer('sup_conv_'+str(resolution // 2))(small)
  # sup = des.get_layer('sup_bn_'+str(resolution // 2))(sup)
  sup = des.get_layer('sup_act_'+str(resolution // 2))(sup)

  print("====>", sup)
 
  print(newInp.shape, sup.shape)
  # FadeAdd computes new*alpha + old*(1 - alpha); alpha starts at 0, so the old path
  # is used exclusively until the caller increments it.
  beta = FadeAdd()
  out = beta([newInp, sup])

  print("==>", out)
 
  desOut = reBaseModel(baseLayers, out)
  encOut = reBaseModel(baseLayers[:-1], out)
 
  des = tf.keras.models.Model(inputs=inputLayer, outputs=desOut)
  enc = tf.keras.models.Model(inputs=inputLayer, outputs=encOut)
 
  # Prepend the new stage so the next call re-applies it as part of the trunk.
  baseLayers = newLayers + baseLayers
  return des, enc, baseLayers, beta
 
def generateBaseModels():
  """Create the 4x4 starting models.

  Returns:
    (gen, des, enc, baseLayers) as produced by generatorBase() and descriminatorBase().
  """
  generator = generatorBase()
  discriminator, encoder, trunkLayers = descriminatorBase()
  return generator, discriminator, encoder, trunkLayers

In [None]:
# Build the 4x4 base generator, discriminator, encoder and shared trunk layers.
# These four names are the module-level globals that trainGan reads and reassigns.
gen, des, enc, baseLayers = generateBaseModels()

(None, 4, 4, 3)


In [None]:
# Show which GPU this runtime was given and how much of its memory is in use.
!nvidia-smi

Wed Oct 21 09:06:20 2020       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 455.23.05    Driver Version: 418.67       CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla P100-PCIE...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   42C    P0    33W / 250W |   4651MiB / 16280MiB |      0%      Default |
|                               |                      |                 ERR! |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [None]:
def lerp(a, b, t):
    """Linear interpolation: returns `a` at t=0 and `b` at t=1."""
    span = b - a
    return a + span * t

def discriminator_loss(real_output, fake_output, wgan_target=1., wgan_lambda=10.):
    """Un-reduced Wasserstein critic loss: fake score minus real score.

    `wgan_target` and `wgan_lambda` are accepted but not used here.
    """
    return fake_output - real_output

def descriminator_WGANGPloss(reals, fakes, des, batch_size, apply_penalty=True, wgan_target=1., wgan_lambda=10., wgan_epsilon=0.001):
    """WGAN-GP critic loss with an additional drift term on the real scores.

    loss = mean(D(fake) - D(real))
           + wgan_lambda / wgan_target**2 * mean((||grad D(mixed)|| - wgan_target)**2)
           + wgan_epsilon * mean(D(real)**2)

    Args:
      reals, fakes: image batches; both are indexed over axes 1..3 when the gradient
        norm is taken, so they are assumed to be rank-4 tensors.
      des: critic model, called with training=True.
      batch_size: number of samples, used to draw one mixing factor per sample.
      apply_penalty: when False only the Wasserstein term is returned.
      wgan_target: target gradient norm.
      wgan_lambda: gradient-penalty weight.
      wgan_epsilon: weight of the drift term.

    Returns:
      (scalar loss, real_output, fake_output).

    Each term is reduced to its mean before the terms are added. The original summed
    a [batch, 1] tensor with a [batch] tensor, which broadcast to [batch, batch]
    before the final mean; the result is the same value without that intermediate.
    """
    real_output = des(reals, training=True)
    fake_output = des(fakes, training=True)
    wgan_loss = tf.reduce_mean(discriminator_loss(real_output, fake_output))

    if not apply_penalty:
      return wgan_loss, real_output, fake_output

    # One random interpolation factor per sample, broadcast over H, W, C.
    mixing_factors = tf.random.uniform([batch_size, 1, 1, 1], 0.0, 1.0, dtype=tf.float32)
    mixed_images = lerp(reals, fakes, mixing_factors)

    with tf.GradientTape() as gp_tape:
      # mixed_images is a plain tensor, so it has to be watched explicitly.
      gp_tape.watch(mixed_images)
      mixed_output = des(mixed_images, training=True)

    # Gradient of the (implicitly summed) critic output w.r.t. the interpolated images.
    mixed_gradients = gp_tape.gradient(mixed_output, [mixed_images])[0]
    mixed_norms = tf.sqrt(tf.reduce_sum(tf.math.square(mixed_gradients), axis=[1,2,3]))
    gradient_penalty = tf.reduce_mean(tf.math.square(mixed_norms - wgan_target))
    drift_penalty = tf.reduce_mean(tf.square(real_output))

    total_loss = wgan_loss + (gradient_penalty * (wgan_lambda / (wgan_target**2))) + wgan_epsilon*drift_penalty
    return total_loss, real_output, fake_output
 
def generator_loss(fake_output):
    """Wasserstein generator loss: the negated mean critic score of the fakes."""
    mean_score = tf.reduce_mean(fake_output)
    return -mean_score
 
def generator_enc_loss(real, fake):
  """Element-wise absolute difference between `real` and `fake` (not reduced)."""
  difference = real - fake
  return tf.abs(difference)

def generator_hinge_loss(fake_output):
  """Generator loss used with descriminator_hinge_loss: the mean critic score of the fakes.

  The original computed this value but never returned it, so callers received None
  and could not differentiate it.
  NOTE(review): the sign follows this file's discriminator hinge loss, which uses
  relu(1 + real) + relu(1 - fake) - confirm that convention is the intended one.
  """
  gen_loss = K.mean(fake_output)
  return gen_loss

def descriminator_hinge_loss(real, fakes, des, batch_size, apply_penalty=True, wgan_target=1., penalty_lambda=10):
  """Hinge discriminator loss with an optional gradient penalty on the real batch.

  loss = mean(relu(1 + D(real)) + relu(1 - D(fake)))
         [+ penalty_lambda * mean(sum(grad_real D(real)**2))]

  Args:
    real, fakes: image batches.
    des: discriminator model, called with training=True.
    batch_size, wgan_target: accepted for signature parity with
      descriminator_WGANGPloss; not used here.
    apply_penalty: add the squared-gradient penalty when True.
    penalty_lambda: weight of the penalty.

  Returns:
    (loss, real_output, fake_output), the same triple descriminator_WGANGPloss
    returns, which is what trainDes / trainDesGen unpack.

  Fixes over the original: it read an undefined name `reals` instead of the `real`
  parameter, never watched `real` on the tape (so the gradient would be None), and
  returned only the loss although callers unpack three values.
  """
  with tf.GradientTape() as gp_tape:
    # `real` is a plain tensor, so it must be watched to get d(output)/d(real).
    gp_tape.watch(real)
    real_output = des(real, training=True)
  fake_output = des(fakes, training=True)

  disc_loss = K.mean(K.relu(1 + real_output) + K.relu(1 - fake_output))

  if apply_penalty:
    gradient = gp_tape.gradient(real_output, [real])[0]
    # Sum squared gradients over every non-batch axis, then average over the batch.
    sample_axes = list(range(1, len(gradient.shape)))
    penalty = K.mean(K.sum(tf.math.square(gradient), axis=sample_axes)) * penalty_lambda
    disc_loss += penalty
  return disc_loss, real_output, fake_output

In [None]:
# @tf.function
def trainGenEnc(gen, enc, real, batch_size, coeff=1, generator_optimizer=None, enc_optimizer=None):
  """One autoencoder-style step: encode `real`, regenerate it, and minimise the L1 error.

  Args:
    gen: generator, fed with the encoder output.
    enc: encoder applied to `real`.
    real: batch of real images.
    batch_size: unused here; kept for a uniform trainer signature.
    coeff: multiplier applied to the reconstruction loss.
    generator_optimizer, enc_optimizer: optimizers for `gen` and `enc`; both are
      required despite the None defaults.

  Only the forward pass runs inside the tape contexts. Both gradients are computed
  before either optimizer step, so the backward pass and weight updates of one
  network are not recorded on the other network's still-active tape.
  """
  with tf.GradientTape() as enc_tape, tf.GradientTape() as gen_tape:
    real_enc = enc(real, training=True)
    enc_fake = gen(real_enc, training=True)
    gen_loss = generator_enc_loss(real, enc_fake) * coeff

  gradients_of_generator = gen_tape.gradient(gen_loss, gen.trainable_variables)
  gradients_of_enc = enc_tape.gradient(gen_loss, enc.trainable_variables)

  generator_optimizer.apply_gradients(zip(gradients_of_generator, gen.trainable_variables))
  enc_optimizer.apply_gradients(zip(gradients_of_enc, enc.trainable_variables))
 
# @tf.function
def trainDes(gen, des, real, batch_size, hinge=False, discriminator_optimizer=None):
  """One discriminator-only step against freshly generated fakes.

  The generator runs with training=False and only `des` variables are updated.
  `hinge` selects descriminator_hinge_loss, otherwise descriminator_WGANGPloss is used.
  `discriminator_optimizer` is required despite the None default.
  """
  with tf.GradientTape() as disc_tape:
    latent = tf.random.normal([batch_size, 512])
    fake = gen(latent, training=False)

    if hinge:
      des_loss, real_output, fake_output = descriminator_hinge_loss(real, fake, des, batch_size)
    else:
      des_loss, real_output, fake_output = descriminator_WGANGPloss(real, fake, des, batch_size)

  critic_variables = des.trainable_variables
  critic_gradients = disc_tape.gradient(des_loss, critic_variables)
  discriminator_optimizer.apply_gradients(zip(critic_gradients, critic_variables))
 
# @tf.function
def trainDesGen(gen, des, real, batch_size, hinge=False, generator_optimizer=None, discriminator_optimizer=None):
  """One joint step: update the discriminator and the generator from the same fake batch.

  Args:
    gen, des: generator and discriminator.
    real: batch of real images.
    batch_size: number of latent vectors to sample.
    hinge: use the hinge losses instead of WGAN-GP.
    generator_optimizer, discriminator_optimizer: required despite the None defaults.

  Only the forward pass and loss computation run inside the tape contexts. Both
  gradients are computed before either update is applied, so the discriminator's
  backward pass and update are not recorded on the still-active generator tape.
  """
  with tf.GradientTape() as disc_tape, tf.GradientTape() as gen_tape:
    noise = tf.random.normal([batch_size, 512])
    fake = gen(noise, training=True)
    
    if hinge:
      des_loss, real_output, fake_output = descriminator_hinge_loss(real, fake, des, batch_size)
      gen_loss = generator_hinge_loss(fake_output)
    else:
      des_loss, real_output, fake_output = descriminator_WGANGPloss(real, fake, des, batch_size)
      gen_loss = generator_loss(fake_output)

  gradients_of_discriminator = disc_tape.gradient(des_loss, des.trainable_variables)
  gradients_of_generator = gen_tape.gradient(gen_loss, gen.trainable_variables)

  discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, des.trainable_variables))
  generator_optimizer.apply_gradients(zip(gradients_of_generator, gen.trainable_variables))
 
# @tf.function
def trainGen(gen, des, batch_size, hinge=False, generator_optimizer=None):
  """One generator-only step; the discriminator is run with training=False.

  `hinge` selects generator_hinge_loss, otherwise generator_loss is used.
  `generator_optimizer` is required despite the None default.
  """
  with tf.GradientTape() as gen_tape:
    latent = tf.random.normal([batch_size, 512])
    generated = gen(latent, training=True)
    fake_output = des(generated, training=False)

    if hinge:
      gen_loss = generator_hinge_loss(fake_output)
    else:
      gen_loss = generator_loss(fake_output)

  generator_variables = gen.trainable_variables
  generator_gradients = gen_tape.gradient(gen_loss, generator_variables)
  generator_optimizer.apply_gradients(zip(generator_gradients, generator_variables))
 
def evalGan(gen, des, data, batches, batch_size):
  """Score the discriminator on `data` plus fresh fakes, repeated `batches` times.

  Each round generates `batch_size` fakes, runs `des` on [fakes, data], and computes
  binary accuracy (threshold 0.5, fakes labelled 0 and reals 1) on the raw scores,
  together with generator_loss of the fake scores divided by batch_size.

  Returns:
    (mean accuracy, mean generator-loss figure) over the rounds.
  """
  accuracyTotal = 0
  lossTotal = 0
  for _ in range(batches):
    latent = tf.random.normal([batch_size, 512])
    generated = gen.predict(latent)
    scores = des.predict(tf.concat((generated, data), axis=0))

    # Fakes come first in the concatenated batch.
    fakeScores = scores[:batch_size]
    realScores = scores[batch_size:]

    targets = tf.concat((tf.zeros_like(fakeScores), tf.ones_like(realScores)), axis=0)
    targets = tf.reshape(targets, [-1])
    flatScores = tf.reshape(scores, [-1])

    roundAccuracy = tf.keras.metrics.binary_accuracy(targets, flatScores, threshold=0.5)
    accuracyTotal += roundAccuracy.numpy()
    lossTotal += tf.reduce_sum(generator_loss(fakeScores)).numpy() / batch_size
  return accuracyTotal / batches, lossTotal / batches
 
def augmenter(size):
  """Return a dataset map function that resizes sample['image'] to size x size (nearest neighbour)."""
  target = [size, size]

  def augment(sample):
    resized = tf.image.resize(sample['image'], target, method='nearest')
    sample['image'] = resized
    return sample

  return augment

from IPython.display import clear_output

def trainGan(data, name='A02', modeldir='/content/gdrive/My Drive/AI Research/GANs/models/', loadModel=None, epochs=10, 
             batchSize=5, iters=206, loss='mse', hinge=False, sizes=[4, 8, 16, 32, 64, 128, 256], des_steps=1, gen_steps=1, enc_steps=1, grow=True,
             gen_alpha=None, des_alpha=None):
  """Progressively train the global gen / des / enc models over increasing resolutions.

  The loop walks the fixed ladder 4, 8, ..., 256. For every resolution listed in
  `sizes` it trains for `epochs` epochs of `iters` steps, shows real/fake grids,
  and saves the generator and the discriminator/encoder weights. After each ladder
  step, when `grow` is true, both networks get a new stage for the next resolution.

  Args:
    data: dataset of dicts with an 'image' entry (uint8 pixel values are rescaled
      to [-1, 1] here).
    name: filename prefix for saved models.
    modeldir: directory for saved / loaded weights.
    loadModel: resolution at which weights are loaded before training, or None.
    epochs: epochs for the first trained resolution; multiplied by 1.6 after every
      ladder step.
    batchSize: base batch size, scaled by nf(log2(size) - 2) / 512 per resolution.
    iters: training steps per epoch.
    loss: unused.
    hinge: use hinge losses instead of WGAN-GP.
    sizes: resolutions to train at; other ladder entries are skipped.
    des_steps, gen_steps, enc_steps: extra discriminator / generator / encoder
      steps per training step (a joint step always runs).
    grow: whether to add a stage after a ladder step. It is forced to True once a
      resolution in `sizes` has been trained, so grow=False only suppresses growth
      for the skipped leading resolutions.
    gen_alpha, des_alpha: FadeAdd layers of the current stage, or None.

  Side effects: reassigns the module-level gen, des, enc and baseLayers.

  NOTE(review): trainStep is wrapped in tf.function, so Python values such as
  `coeff` are presumably captured when it is first traced - confirm that the
  per-epoch decay of `coeff` actually reaches the encoder steps.
  """
  global gen, des, enc, baseLayers
  realData = data
  noise = tf.random.normal([64, 512])
  results = []

  generator_optimizer = tfa.optimizers.Lookahead(tf.keras.optimizers.Adam(1e-4, beta_1 = 0, beta_2 = 0.999))
  discriminator_optimizer = tfa.optimizers.Lookahead(tf.keras.optimizers.Adam(1e-4, beta_1 = 0, beta_2 = 0.999))
  enc_optimizer = tfa.optimizers.Lookahead(tf.keras.optimizers.Adam(3e-4, beta_1 = 0, beta_2 = 0.999))

  for size in [4, 8, 16, 32, 64, 128, 256]:
    # Scale the batch with the feature-map budget; never let it truncate to 0,
    # which happens for small batchSize at high resolutions and breaks .batch().
    batch_size = max(1, int((nf(int(np.log2(size)) - 2) / 512) * batchSize))
    print("Current Batch Size", batch_size)
    if loadModel == size:
      gen.load_weights(modeldir + name + '_' + str(size) + '_gen.h5')
      des.load_weights(modeldir + name + '_' + str(size)  + '_des.h5')
      enc.load_weights(modeldir + name + '_' + str(size)  + '_enc.h5')

    if size in sizes:
      print("Size", size)
      # From the first trained resolution onwards, always grow to the next one.
      grow = True
      coeff = 1
      print("Input shape: ",des.input.shape)
      currentData = realData.map(augmenter(size)).batch(batch_size, drop_remainder=True).shuffle(4096).repeat().prefetch(tf.data.experimental.AUTOTUNE)
      # Fixed batch of 64 real images used for evaluation and display.
      REAL = next(iter(realData.map(augmenter(size)).batch(64)))['image']
      REAL = (tf.cast(REAL, tf.float32) - 127.5) / 127.5
      iterData = iter(currentData)

      def getTrainers():
        """Wrap the step functions in fresh tf.functions bound to the current models."""
        def _trainDes(real):
          trainDes(gen, des, real, batch_size, hinge, discriminator_optimizer)

        def _trainGen():
          trainGen(gen, des, batch_size, hinge, generator_optimizer)

        def _trainGenEnc(real, coeff):
          trainGenEnc(gen, enc, real, batch_size, coeff, generator_optimizer, enc_optimizer)

        def _trainGenDes(real):
          trainDesGen(gen, des, real, batch_size, hinge, generator_optimizer, discriminator_optimizer)

        return tf.function(_trainDes), tf.function(_trainGen), tf.function(_trainGenEnc), tf.function(_trainGenDes)
      
      _trainDes, _trainGen, _trainGenEnc, _trainGenDes = getTrainers()

      @tf.function
      def trainStep(batch):
          # `batch` is overwritten below; the step pulls its data from iterData.
          if des_alpha is not None and gen_alpha is not None:
            des_alpha.incrementAlpha(0.1 / ((220*60)/batch_size))
            gen_alpha.incrementAlpha(0.1 / ((220*60)/batch_size))
            
          batch = next(iterData)
          real = batch['image']
          real = (tf.cast(real, tf.float32) - 127.5) / 127.5

          for des_iter in range(des_steps):
            _trainDes(real)
          
          for gen_iter in range(gen_steps):
            _trainGen()
          
          _trainGenDes(real)

          for enc_iter in range(enc_steps):
            if coeff >= 0.15:
              _trainGenEnc(real, coeff)
            else:
              _trainGenEnc(real, 0.15)
        
      for epoch in range(epochs):
        print("Running epoch ", epoch)
        t = time.time()
        for _ in range(iters):
          try:
            trainStep(iterData)
          except Exception as e:
            # Best effort: report the failure and rebuild the traced step functions.
            print("Error in epoch ", epoch, e)
            _trainDes, _trainGen, _trainGenEnc, _trainGenDes = getTrainers()

        coeff *= 0.9

        fake = gen.predict(noise)
        real = REAL
        print("Evaluating:")
        desAcc, genLoss = evalGan(gen, des, real, 10, batch_size)
        results.append({'desAcc':desAcc, 'genLoss':genLoss})
        print("Epoch ", epoch, desAcc, genLoss, "of ", epochs, "Epochs")

        print("Real: ")
        plotImages(real)
  
        print("Fake: ")
        plotImages(fake)

        # Strong discriminator: raise the reconstruction weight again (capped at 1).
        if desAcc > 0.8:
          coeff *= 2
          coeff = min(coeff, 1.)

        if des_alpha is not None and gen_alpha is not None:
          print("Alpha, Beta: ", gen_alpha.alpha, des_alpha.alpha)


      gen.save(modeldir + name + '_' + str(size) + '_gen.h5')
      des.save_weights(modeldir + name + '_' + str(size)  + '_des.h5')
      enc.save_weights(modeldir + name + '_' + str(size)  + '_enc.h5')
    
    if grow==True:
      des, enc, baseLayers, des_alpha = descriminatorAddStage(des, enc, baseLayers, resolution=size*2, freeze=False)
      gen, gen_alpha = generatorAddStage(gen, resolution=size*2, freeze=False)
    epochs *= 1.6
    epochs = int(epochs)
    clear_output(wait=True)

In [None]:
# Fresh run: rebuild the 4x4 base models and train through trainGan's default `sizes`.
# gen_steps=0 and enc_steps=0 leave only the discriminator step and the joint step.
# NOTE(review): `global` at notebook top level has no effect; these names are already module-level.
global gen, des, enc, baseLayers
gen, des, enc, baseLayers = generateBaseModels()
trainGan(lfw, name='New02', epochs=10, batchSize=128, hinge=False, des_steps=1, gen_steps=0, enc_steps=0)

In [None]:
# Manually attach the 64x64 stage to both networks and keep the returned FadeAdd
# layers (des_alpha, gen_alpha) so a later trainGan call can advance the fade-in.
# Requires the current models to be at 32x32 (the stage looks up the '*_32' layers).
des, enc, baseLayers, des_alpha = descriminatorAddStage(des, enc, baseLayers, resolution=64, freeze=False)
gen, gen_alpha = generatorAddStage(gen, resolution=64, freeze=False)

Current Shape:  (None, 32, 32, 3)
Previous Input layer  Tensor("input_44:0", shape=(None, 32, 32, 3), dtype=float32)
New input  64 64
====> Tensor("sup_act_32/LeakyRelu_2:0", shape=(None, 32, 32, 256), dtype=float32)
(None, 32, 32, 256) (None, 32, 32, 256)
==> Tensor("fade_add_22/add:0", shape=(None, 32, 32, 256), dtype=float32)
Current Shape:  (None, 32, 32, 3)
Choosing layer  Tensor("final_32/mul_2:0", shape=(None, 32, 32, 256), dtype=float32)
New Depth:  128
New Shape:  (None, 64, 64, 3)


In [None]:
# Resume training at 64x64 with the already-grown models: grow=False skips stage
# creation for the resolutions below 64, and the FadeAdd handles are passed in.
# NOTE(review): `global` at notebook top level has no effect.
global gen, des, enc, baseLayers
trainGan(lfw, name='New02', epochs=10, batchSize=128, hinge=False, des_steps=1, gen_steps=0, enc_steps=0, sizes=[64, 128, 256], grow=False, des_alpha=des_alpha, gen_alpha=gen_alpha)

In [None]:
# Recover the FadeAdd handles from the current models.
# NOTE(review): 'fade_add_26' is an auto-generated Keras name, so it presumably depends on
# how many FadeAdd layers this kernel session has created - verify before re-running.
des_alpha = des.get_layer('fade_add_26')
# The generator's FadeAdd is the last layer added by generatorAddStage.
gen_alpha = gen.layers[-1]

In [None]:
# Continue training at 256x256 only, reusing the existing models and FadeAdd handles.
# NOTE(review): `global` at notebook top level has no effect.
global gen, des, enc, baseLayers
trainGan(lfw, name='New02', epochs=10, batchSize=128, hinge=False, des_steps=1, gen_steps=0, enc_steps=0, sizes=[256], grow=False, des_alpha=des_alpha, gen_alpha=gen_alpha)

Current Batch Size 16
Size 256
Input shape:  (None, 256, 256, 3)
Running epoch  0


In [None]:
# List the checkpoints saved so far in the default trainGan model directory.
!ls '/content/gdrive/My Drive/AI Research/GANs/models/'

A01_des.h5     A02_16_gen.h5  A02_8_enc.h5   A04_4_des.h5  A04_8_gen.h5
A01_enc.h5     A02_4_des.h5   A02_8_gen.h5   A04_4_enc.h5
A01_gen.h5     A02_4_enc.h5   A04_16_des.h5  A04_4_gen.h5
A02_16_des.h5  A02_4_gen.h5   A04_16_enc.h5  A04_8_des.h5
A02_16_enc.h5  A02_8_des.h5   A04_16_gen.h5  A04_8_enc.h5


In [None]:
# @tf.function
def trainGenEnc(gen, enc, real, batch_size, coeff=1):
  """Earlier variant of the encoder/generator reconstruction step.

  Encodes `real`, regenerates it and minimises the L1 error scaled by `coeff`.
  Reads `generator_optimizer` and `enc_optimizer` from module scope; they are not
  defined anywhere in the visible notebook, so they must exist before this is called.
  `batch_size` is unused.

  Only the forward pass runs inside the tape contexts, and both gradients are computed
  before either update, so one network's backward pass and update are not recorded on
  the other network's still-active tape.
  """
  with tf.GradientTape() as enc_tape, tf.GradientTape() as gen_tape:
    real_enc = enc(real, training=True)
    enc_fake = gen(real_enc, training=True)
    gen_loss = generator_enc_loss(real, enc_fake) * coeff

  gradients_of_generator = gen_tape.gradient(gen_loss, gen.trainable_variables)
  gradients_of_enc = enc_tape.gradient(gen_loss, enc.trainable_variables)

  generator_optimizer.apply_gradients(zip(gradients_of_generator, gen.trainable_variables))
  enc_optimizer.apply_gradients(zip(gradients_of_enc, enc.trainable_variables))
 
# @tf.function
def trainDes(gen, des, real, batch_size, smooth):
  """Earlier variant of the discriminator-only step.

  Scores fakes and reals in one concatenated forward pass (fakes first) and applies
  discriminator_loss(real_scores, fake_scores, smooth). Uses the module-level
  `discriminator_optimizer`, which is not defined in the visible notebook.
  """
  with tf.GradientTape() as disc_tape:
    latent = tf.random.normal([batch_size, 512])
    generated = gen(latent, training=False)

    scores = des(tf.concat((generated, real), axis=0), training=True)
    fake_output = scores[:batch_size]
    real_output = scores[batch_size:]

    des_loss = discriminator_loss(real_output, fake_output, smooth)

  critic_variables = des.trainable_variables
  critic_gradients = disc_tape.gradient(des_loss, critic_variables)
  discriminator_optimizer.apply_gradients(zip(critic_gradients, critic_variables))
 
# @tf.function
def trainDesGen(gen, des, real, batch_size, smooth):
  """Earlier variant of the joint discriminator + generator step.

  Uses the module-level `discriminator_optimizer` and `generator_optimizer`, which
  are not defined in the visible notebook.
  NOTE(review): generator_loss is called with two arguments here, while the only
  generator_loss visible in this notebook takes one - verify which definition this
  cell was written against.

  Only the forward pass runs inside the tape contexts, and both gradients are computed
  before either update, so the discriminator's backward pass and update are not
  recorded on the still-active generator tape.
  """
  with tf.GradientTape() as disc_tape, tf.GradientTape() as gen_tape:
    noise = tf.random.normal([batch_size, 512])
    fake = gen(noise, training=True)
    X = tf.concat((fake, real), axis=0)
    pred = des(X, training=True)
 
    # Fakes occupy the first batch_size rows of the concatenated batch.
    fake_output = pred[:batch_size]
    real_output = pred[batch_size:]
 
    des_loss = discriminator_loss(real_output, fake_output, smooth)
    gen_loss = generator_loss(fake_output, smooth)

  gradients_of_discriminator = disc_tape.gradient(des_loss, des.trainable_variables)
  gradients_of_generator = gen_tape.gradient(gen_loss, gen.trainable_variables)

  discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, des.trainable_variables))
  generator_optimizer.apply_gradients(zip(gradients_of_generator, gen.trainable_variables))
 
# @tf.function
def trainGen(gen, des, batch_size, smooth):
  """Earlier variant of the generator-only step (discriminator run with training=False).

  Uses the module-level `generator_optimizer`, which is not defined in the visible
  notebook, and calls generator_loss(fake_output, smooth).
  """
  with tf.GradientTape() as gen_tape:
    latent = tf.random.normal([batch_size, 512])
    generated = gen(latent, training=True)
    fake_output = des(generated, training=False)
    gen_loss = generator_loss(fake_output, smooth)

  generator_variables = gen.trainable_variables
  generator_gradients = gen_tape.gradient(gen_loss, generator_variables)
  generator_optimizer.apply_gradients(zip(generator_gradients, generator_variables))
 
def evalGan(gen, des, data, batches, batch_size):
  """Earlier variant of the evaluation helper.

  For each of `batches` rounds: generate `batch_size` fakes, score [fakes, data] with
  `des`, compute binary accuracy at threshold 0.5 (fakes labelled 0, reals 1) and add
  generator_loss(fake_scores, 1) divided by batch_size.

  Returns:
    (mean accuracy, mean generator-loss figure) over the rounds.
  """
  accuracyTotal = 0
  lossTotal = 0
  for _ in range(batches):
    latent = tf.random.normal([batch_size, 512])
    generated = gen.predict(latent)
    scores = des.predict(tf.concat((generated, data), axis=0))

    # Fakes come first in the concatenated batch.
    fakeScores = scores[:batch_size]
    realScores = scores[batch_size:]

    targets = tf.concat((tf.zeros_like(fakeScores), tf.ones_like(realScores)), axis=0)
    targets = tf.reshape(targets, [-1])
    flatScores = tf.reshape(scores, [-1])

    roundAccuracy = tf.keras.metrics.binary_accuracy(targets, flatScores, threshold=0.5)
    accuracyTotal += roundAccuracy.numpy()
    lossTotal += tf.reduce_sum(generator_loss(fakeScores, 1)).numpy() / batch_size
  return accuracyTotal / batches, lossTotal / batches
 
def augmenter(size):
  """Return a dataset map function that resizes sample['image'] to size x size (nearest neighbour)."""
  target = [size, size]

  def augment(sample):
    resized = tf.image.resize(sample['image'], target, method='nearest')
    sample['image'] = resized
    return sample

  return augment

from IPython.display import clear_output

def trainGan(data, name='A02', modeldir='/content/gdrive/My Drive/AI Research/GANs/models/', epochs=10, batch_size=5, loss='mse', smooth=1., sizes=[4, 8, 16, 32, 64, 128, 256]):
  """Earlier variant of the progressive training loop.

  For every resolution in `sizes`: iterate the resized dataset for `epochs` epochs,
  running two discriminator steps, one generator step, one joint step and one
  encoder step per batch; evaluate and plot after each epoch; save the models;
  then add a new stage to both networks and multiply `epochs` by 1.6.

  Args:
    data: dataset of dicts with an 'image' entry.
    name, modeldir: filename prefix and directory for saved models.
    epochs: epochs for the first resolution.
    batch_size: fixed batch size for all resolutions.
    loss: unused.
    smooth: forwarded to the step functions.
    sizes: resolutions to train at, in order.

  Side effects: reassigns the module-level gen, des, enc and baseLayers.

  NOTE(review): the step functions called here read module-level optimizers that
  are not defined in the visible notebook; the recorded output of this cell is a
  NameError.
  """
  global gen, des, enc, baseLayers
  realData = data
  # print(realData.shape)
  # Fixed latent batch so the plotted fakes are comparable between epochs.
  noise = tf.random.normal([64, 512])
  results = []
  gen_alpha, des_alpha = None, None
  for size in sizes:
    coeff = 1
    print("Input shape: ",des.input.shape)
    currentData = realData.map(augmenter(size)).batch(batch_size, drop_remainder=True).shuffle(4096).prefetch(tf.data.experimental.AUTOTUNE)
    for epoch in range(epochs):
      b = 0
      for batch in currentData:
        b += 1
        # Advance the fade-in of the newest stage a little on every batch.
        if des_alpha != None and gen_alpha != None:
          des_alpha.incrementAlpha(0.1 / ((220*60)/batch_size))
          gen_alpha.incrementAlpha(0.1 / ((220*60)/batch_size))

        real = batch['image']
        # Rescale pixel values from [0, 255] to [-1, 1].
        real = (tf.cast(real, tf.float32) - 127.5) / 127.5
 
        trainDes(gen, des, real, batch_size, smooth)
        trainDes(gen, des, real, batch_size, smooth)
        trainGen(gen, des, batch_size, smooth)
        trainDesGen(gen, des, real, batch_size, smooth)
        # The reconstruction weight decays per epoch; below 0.15 a fixed 0.2 is used instead.
        if coeff >= 0.15:
          trainGenEnc(gen, enc, real, batch_size, coeff)
        else:
          trainGenEnc(gen, enc, real, batch_size, 0.2)
        # trainDesGenEnc(gen, des, enc, real, batch_size)

      coeff *= 0.9

      fake = gen.predict(noise)
      print("Evaluating:")
      # `real` is the last training batch of the epoch.
      desAcc, genLoss = evalGan(gen, des, real, 10, batch_size)
      results.append({'desAcc':desAcc, 'genLoss':genLoss})
      print("Epoch ", epoch, desAcc, genLoss, "of ", epochs, "Epochs")

      print("epoch length: ", b)
      print("Real: ")
      plotImages(real)
 
      print("Fake: ")
      plotImages(fake)

      # Strong discriminator: raise the reconstruction weight again (capped at 1).
      if desAcc > 0.8:
        coeff *= 2
        coeff = min(coeff, 1.)

      if des_alpha != None and gen_alpha != None:
        print("Alpha, Beta: ", gen_alpha.alpha, des_alpha.alpha)


    gen.save(modeldir + name + '_' + str(size) + '_gen.h5')
    des.save_weights(modeldir + name + '_' + str(size)  + '_des.h5')
    enc.save_weights(modeldir + name + '_' + str(size)  + '_enc.h5')
    # NOTE(review): no `resolution` is passed, so both calls use their default of 8
    # on every iteration regardless of `size` - verify this is what was intended.
    des, enc, baseLayers, des_alpha = descriminatorAddStage(des, enc, baseLayers, freeze=False)
    gen, gen_alpha = generatorAddStage(gen, freeze=False)
    epochs *= 1.6
    epochs = int(epochs)
    clear_output(wait=True)

NameError: ignored