In [19]:
# Stretch the notebook container to the full browser width.
# `display` and `HTML` come from the public `IPython.display` module; importing
# `display` from `IPython.core.display` is deprecated.
from IPython.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

In [None]:
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, Conv2DTranspose, BatchNormalization, UpSampling2D, Dense, MaxPool2D, LeakyReLU, Reshape, Dropout, Flatten, GaussianNoise, Embedding, multiply, concatenate
from tensorflow.keras.losses import BinaryCrossentropy, SparseCategoricalCrossentropy, CategoricalCrossentropy
from tensorflow.keras.optimizers import Adam
import tensorflow.keras.backend as K
import tensorflow as tf

import numpy as np
import pandas as pd
import os
import cv2
import matplotlib.pyplot as plt
import time

In [None]:
def normalize_img(img):
    """Subtract 127.5 and divide by 127.5, so 0 maps to -1 and 255 maps to 1.

    Works for any operand that supports arithmetic with floats (plain numbers
    or NumPy arrays).
    """
    half_range = 127.5
    centered = img - half_range
    return centered / half_range

def unnormalize_img(img):
    """Add 1 and divide by 2, so -1 maps to 0 and 1 maps to 1.

    Used in this notebook on generator output before handing it to
    ``plt.imshow``.
    """
    shifted = img + 1.
    return shifted / 2.

In [None]:
# Class names to keep; a directory is selected when its path contains one of them.
img_classes = [
    'chihuahua', 'chimpanzee', 'dalmatian', 'dolphin', 'fox',
    'giant+panda', 'giraffe', 'otter', 'polar+bear', 'zebra',
]

# File count of every selected directory; the smallest one is used later to cap
# how many images are taken per class.
selected_dir_sizes = [
    len(files)
    for root, dirs, files in os.walk("./AwA2/AwA2-data/Animals_with_Attributes2/JPEGImagesCleaned/", topdown=False)
    if any(x in root for x in img_classes)
]
# Stays at infinity when no directory matches (e.g. the dataset path is missing).
min_images = min(selected_dir_sizes, default=float('inf'))
print('Minimum images across all included classes is:', min_images)

In [None]:
# Load at most `min_images` images from every selected class directory so that
# all classes contribute the same number of samples.
train_imgs = []
train_classes = []
classes_loaded = 0
for root, dirs, files in os.walk("./AwA2/AwA2-data/Animals_with_Attributes2/JPEGImagesCleaned/", topdown=False):
    # Take the label from the class name found in the path instead of from a
    # running counter: os.walk does not promise any particular directory order,
    # so a counter could pair images with the wrong entry of `img_classes`.
    matching = [idx for idx, cls_name in enumerate(img_classes) if cls_name in root]
    if not matching:
        continue
    class_index = matching[0]
    print(f'{len(files)} samples in class {class_index}, ({root})')

    loaded_before = len(train_imgs)
    for name in files[:min_images]:
        path = os.path.join(root, name)
        bgr_img = cv2.imread(path)
        if bgr_img is None:
            # cv2.imread signals an unreadable file by returning None.
            print(f'Skipping unreadable image: {path}')
            continue
        # OpenCV loads images as BGR; matplotlib expects RGB.
        train_imgs.append(cv2.cvtColor(bgr_img, cv2.COLOR_BGR2RGB))
        train_classes.append(class_index)

    # Preview the last image of this class (only if at least one was loaded).
    if len(train_imgs) > loaded_before:
        plt.imshow(train_imgs[-1])
        plt.title(img_classes[class_index])
        plt.show()

    classes_loaded += 1
    if classes_loaded >= len(img_classes):
        break

# Stack into one array (requires all images to share a shape) and rescale
# with the same helper used elsewhere in the notebook.
train_imgs = np.array(train_imgs)
train_imgs = normalize_img(train_imgs)
train_classes = np.array(train_classes, dtype=np.float32)

In [None]:
# Histogram of the class ids (one bin per class) to check the classes are balanced.
plt.hist(train_classes, bins=10)
plt.show()
# Every image must have exactly one label.
assert(len(train_imgs) == len(train_classes))

In [None]:
def make_generator_model(shape):
    """Build the class-conditional generator network.

    Args:
        shape: Shape of the noise input; the notebook passes ``(100,)``. It is
            multiplied element-wise with a 100-dimensional label embedding.

    Returns:
        A Keras ``Model`` with inputs ``[noise, label]`` and a 3-channel image
        output squashed by ``tanh``.
    """
    noise_in = Input(shape=shape)
    label_in = Input(shape=(1,))
    # One 100-dimensional vector per class id (10 ids).
    label_embedding = Embedding(10, 100)(label_in)

    # Condition the noise on the class by element-wise multiplication.
    conditioned = multiply([noise_in, label_embedding])

    # Project to 16*16*128 units and fold them into a 16x16 feature map.
    h = Dense(16 * 16 * 128)(conditioned)
    h = BatchNormalization()(h)
    h = LeakyReLU()(h)
    h = Reshape((16, 16, 128))(h)

    # (filters, stride) of each noisy transposed-convolution stage. The two
    # stride-2 stages double the spatial size: 16 -> 32 -> 32 -> 64 -> 64.
    stages = ((256, 2), (128, 1), (64, 2), (32, 1))
    for filters, stride in stages:
        h = GaussianNoise(1)(h)
        h = Conv2DTranspose(filters, (3, 3), strides=(stride, stride), padding='same', use_bias=False)(h)
        h = BatchNormalization()(h)
        h = LeakyReLU()(h)

    # Final projection to 3 channels in [-1, 1].
    image = Conv2DTranspose(3, (3, 3), strides=(1, 1), padding='same', use_bias=False, activation='tanh')(h)

    return Model([noise_in, label_in], image)

In [None]:
# Build a generator with a 100-dimensional noise input and print its layer summary.
generator = make_generator_model((100,))
generator.summary()

In [None]:
# Smoke test: push one random noise vector with class id 1 through a newly
# built generator and display the resulting image.
generator = make_generator_model((100,))

noise = np.random.normal(0, 1, 100)
label = 1

# predict expects batches, so wrap the single sample in a leading batch axis.
noise_batch = np.array([noise])
label_batch = np.array([label])
generated_image = generator.predict([noise_batch, label_batch])[0]
plt.imshow(unnormalize_img(generated_image))

In [None]:
def make_discriminator_model():
    """Build the two-headed discriminator for 64x64 RGB images.

    Returns:
        A Keras ``Model`` mapping an image batch to ``[valid, label]``:
        ``valid`` is a single sigmoid unit (real vs. fake) and ``label`` is a
        10-way softmax over the class ids.
    """
    input_layer = Input((64, 64, 3))
    # Input noise. Previously this layer's output was discarded because the
    # first convolution was applied to `input_layer` directly; it is now
    # actually part of the graph.
    x = GaussianNoise(1)(input_layer)

    # Two stages; each halves the resolution with a stride-2 convolution and
    # then refines with a stride-1 convolution of the same width.
    for filters in (64, 128):
        x = Conv2D(filters, (3, 3), strides=(2, 2), padding='same')(x)
        x = LeakyReLU()(x)
        x = Dropout(0.3)(x)
        x = Conv2D(filters, (3, 3), strides=(1, 1), padding='same')(x)
        x = LeakyReLU()(x)
        x = Dropout(0.3)(x)

    # Shared dense trunk feeding both heads.
    x = Flatten()(x)
    x = Dense(64)(x)
    x = LeakyReLU()(x)
    x = Dropout(0.3)(x)

    valid = Dense(1, activation='sigmoid')(x)
    label = Dense(10, activation='softmax')(x)

    return Model(input_layer, [valid, label])

In [None]:
# Build the discriminator, print its summary, and run it once on the image
# produced by the generator smoke test.
discriminator = make_discriminator_model()
discriminator.summary()
# NOTE: this rebinds the notebook-level name `label` to the class-head output.
valid, label = discriminator.predict(np.array([generated_image]))
print(valid)
print(label)

In [None]:
# Binary cross-entropy on probabilities (the discriminator's real/fake head is a sigmoid).
cross_entropy = BinaryCrossentropy(from_logits=False)


def discriminator_valid_loss(real_output, fake_output, flip):
    """Real/fake loss of the discriminator with randomised soft targets.

    Targets are drawn uniformly between 0.9 and 1.0 ("high") or between 0.0
    and 0.1 ("low"). Normally real samples get high targets and fake samples
    low ones; when ``flip`` is truthy the two are swapped.

    Args:
        real_output: Real/fake predictions for real images.
        fake_output: Real/fake predictions for generated images.
        flip: Whether to swap the targets of real and fake samples.

    Returns:
        Tuple ``(total_loss, real_loss, fake_loss)`` where ``total_loss`` is
        the sum of the other two.
    """
    def high_targets(like):
        return tf.random.uniform(like.shape, minval=0.9, maxval=1.0)

    def low_targets(like):
        return tf.random.uniform(like.shape, minval=0.0, maxval=0.1)

    if flip:
        fake_labels = high_targets(fake_output)
        real_labels = low_targets(real_output)
    else:
        real_labels = high_targets(real_output)
        fake_labels = low_targets(fake_output)

    real_loss = cross_entropy(real_labels, real_output)
    fake_loss = cross_entropy(fake_labels, fake_output)
    return real_loss + fake_loss, real_loss, fake_loss

In [None]:
# Cross-entropy objects operating on probabilities (the class head is a softmax).
sparse_cce = SparseCategoricalCrossentropy(from_logits=False)
cce = CategoricalCrossentropy(from_logits=False)


def discriminator_class_loss_real(real_output, real_classes):
    """Classification loss of the discriminator on real images.

    Args:
        real_output: Softmax vectors from the discriminator's class head.
        real_classes: Class ids, one per sample. They must be valid indices
            into the softmax vector (0-9 for the 10-way head built in this
            notebook).

    Returns:
        The sparse categorical cross-entropy between ids and predictions.
    """
    return sparse_cce(real_classes, real_output)


def discriminator_class_loss_fake(fake_output, fake_classes):
    """Classification loss of the discriminator on generated images.

    Args:
        fake_output: Softmax vectors from the discriminator's class head.
        fake_classes: Class ids the generator was conditioned on, one per
            sample; same valid range as for real images.

    Returns:
        The sparse categorical cross-entropy between ids and predictions.
    """
    return sparse_cce(fake_classes, fake_output)

In [None]:
def generator_loss(fake_output):
    """Binary cross-entropy of the discriminator's verdict on fakes against all-ones targets.

    The loss is small when the discriminator scores generated images as real.
    """
    wanted_verdict = tf.ones_like(fake_output)
    return cross_entropy(wanted_verdict, fake_output)

In [None]:
# Separate Adam optimizers (learning rate 1e-4) for the two networks.
generator_optimizer = Adam(1e-4)
discriminator_optimizer = Adam(1e-4)

In [20]:
@tf.function
def train_step(images, labels, flip, warmup):
    """Run one simultaneous generator/discriminator update on a single batch.

    Uses the notebook-level ``generator``, ``discriminator`` and their two
    optimizers. 32 noise vectors and 32 random class ids (0-9) are sampled for
    the fake half of the batch.

    Args:
        images: Batch of real images.
        labels: Class ids of ``images``; reshaped to a column vector.
        flip: Forwarded to ``discriminator_valid_loss`` to swap the real/fake
            targets for this step.
        warmup: Weight of the class loss on generated images; the class loss
            on real images is weighted ``2.0 - warmup``.

    Returns:
        Tuple of six single-element lists holding the mean of, in order: the
        generator loss, the total discriminator loss, the discriminator
        real/fake loss on real and on fake images, and the class loss on real
        and on fake images.
    """
    noise = tf.random.truncated_normal([32, 100])
    sampled_labels = tf.cast(tf.random.uniform([32, 1], minval=0, maxval=10, dtype='int32'), 'float32')
    labels = tf.reshape(labels, (labels.shape[0], 1))

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_images = generator([noise, sampled_labels], training=True)

        real_output, real_label_output = discriminator(images, training=True)
        fake_output, fake_label_output = discriminator(generated_images, training=True)

        gen_loss = generator_loss(fake_output)
        disc_loss, disc_real_loss, disc_fake_loss = discriminator_valid_loss(real_output, fake_output, flip)
        disc_class_real_loss = discriminator_class_loss_real(real_label_output, labels)
        disc_class_fake_loss = discriminator_class_loss_fake(fake_label_output, sampled_labels)
        # Class terms are added to the discriminator objective only.
        class_term = (2.0 - warmup) * disc_class_real_loss + warmup * disc_class_fake_loss
        disc_loss = disc_loss + class_term

    # Each tape differentiates its own network's loss.
    generator_grads = gen_tape.gradient(gen_loss, generator.trainable_variables)
    discriminator_grads = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(generator_grads, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(discriminator_grads, discriminator.trainable_variables))

    reported_losses = (
        gen_loss,
        disc_loss,
        disc_real_loss,
        disc_fake_loss,
        disc_class_real_loss,
        disc_class_fake_loss,
    )
    # Callers expect every entry wrapped in a one-element list.
    return tuple([K.mean(loss)] for loss in reported_losses)

In [21]:
def train(train_imgs, train_labels, epochs, warmup):
    """Train for ``epochs`` passes over the data and report per-epoch mean losses.

    Changes relative to the earlier version of this function:
      * epoch means are divided by the true number of batches (the counters
        used to start at 1, so every mean was divided by batches + 1);
      * the last full batch is no longer dropped when the number of samples is
        an exact multiple of the batch size;
      * images and labels are shuffled together through one permutation index
        instead of reseeding NumPy's global random state twice per epoch. The
        caller's arrays are therefore no longer reordered in place.

    Args:
        train_imgs: NumPy array of training images (indexed with an integer
            index array).
        train_labels: NumPy array of class ids aligned with ``train_imgs``.
        epochs: Number of passes over the data.
        warmup: Starting value of the warm-up weight handed to ``train_step``.
            It grows by 1/5000 after every batch while it is below 1.
            NOTE(review): the advanced value is local to this call and is not
            returned, so a caller cannot carry it over to a later call.

    Returns:
        Tuple of six lists (generator, discriminator, discriminator real,
        discriminator fake, class real, class fake), each with one mean loss
        per epoch. An epoch without any full batch contributes 0.0.
    """
    batch_size = 32        # must match the batch size hard-coded in train_step
    flip_every = 20        # every 20th batch trains with swapped real/fake targets
    warmup_step = 1./5000.

    num_samples = len(train_imgs)
    batch_counter = 0
    # One history list per loss, in the order train_step returns them.
    histories = tuple([] for _ in range(6))

    for epoch in range(epochs):
        start = time.time()

        loss_sums = np.zeros(6, dtype=np.float64)
        num_batches = 0

        # A single permutation keeps images and labels aligned.
        order = np.random.permutation(num_samples)

        # Only full batches are used; `+ 1` keeps the final one when
        # num_samples is a multiple of batch_size.
        for i in range(0, num_samples - batch_size + 1, batch_size):
            batch_idx = order[i:(i + batch_size)]
            batch_losses = train_step(
                train_imgs[batch_idx],
                train_labels[batch_idx],
                batch_counter % flip_every == 0,
                tf.constant(warmup, dtype='float32'),
            )

            loss_sums += [np.mean(loss) for loss in batch_losses]
            num_batches += 1

            batch_counter += 1
            if warmup < 1:
                warmup += warmup_step

        # max(..., 1) keeps the "no batches" case at 0.0 instead of dividing by zero.
        means = [float(total) / max(num_batches, 1) for total in loss_sums]

        print ('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start))
        print(f'Gen loss: {means[0]}, Disc loss: {means[1]}, Disc valid real loss: {means[2]}, Disc valid fake loss: {means[3]}, Disc class real loss: {means[4]}, Disc class fake loss: {means[5]}')

        for history, mean_value in zip(histories, means):
            history.append(mean_value)

    return histories
    


In [22]:
def long_train(starting_checkpoint=0, num_checkpoints=0, checkpoint_interval=0, g_hist=None, d_hist=None, d_real_hist=None, d_fake_hist=None, d_class_real_hist=None, d_class_fake_hist=None, noise=None, label=None, save_files=False, warmup=0.):
    """Train in checkpointed chunks, previewing (and optionally saving) after each.

    Every checkpoint trains on the notebook-level ``train_imgs`` and
    ``train_classes`` for ``checkpoint_interval`` epochs, renders one preview
    image from ``noise``/``label`` and, when ``save_files`` is true, writes
    the preview and both networks' weights to ``./results``.

    Args:
        starting_checkpoint: Index of the first checkpoint to run.
        num_checkpoints: How many checkpoints to run.
        checkpoint_interval: Epochs trained per checkpoint.
        g_hist, d_hist, d_real_hist, d_fake_hist, d_class_real_hist,
        d_class_fake_hist: Loss histories, extended in place with one value
            per epoch; a new list is used for any that is ``None``.
        noise: Noise vector for the preview image. When ``None`` one is drawn
            once and reused for every checkpoint of this call.
        label: Class id for the preview image; required when
            ``num_checkpoints`` is positive.
        save_files: Whether to write preview images and weights to disk.
        warmup: Starting warm-up weight handed to ``train``.
            NOTE(review): the same value is passed for every checkpoint, so
            the warm-up restarts at each checkpoint; confirm this is intended.

    Returns:
        Tuple ``(next_checkpoint, totalepochs)``. ``totalepochs`` is
        ``next_checkpoint * checkpoint_interval``, the same epoch count used
        in the saved file names. (It used to be read from an undefined local
        name.)

    Raises:
        ValueError: If checkpoints are requested without a preview ``label``.
    """
    # Fail before training rather than after the first (long) interval.
    if num_checkpoints > 0 and label is None:
        raise ValueError('label must be given to render the preview image after each checkpoint')

    if noise is None and num_checkpoints > 0:
        noise = np.random.normal(0, 1, 100)

    if g_hist is None:
        g_hist = []
    if d_hist is None:
        d_hist = []
    if d_real_hist is None:
        d_real_hist = []
    if d_fake_hist is None:
        d_fake_hist = []
    if d_class_real_hist is None:
        d_class_real_hist = []
    if d_class_fake_hist is None:
        d_class_fake_hist = []

    image_dir = './results/prog-imgs/animal-gan-conditional-10-samples/'
    weights_dir = './results/weights/animal-gan-conditional-10-samples/'
    if save_files and num_checkpoints > 0:
        # Saving fails if the target directories do not exist yet.
        os.makedirs(image_dir, exist_ok=True)
        os.makedirs(weights_dir, exist_ok=True)

    for i in range(starting_checkpoint, starting_checkpoint + num_checkpoints):
        print('~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~')
        print('')
        print(f'Starting checkpoint {i}')
        print('')
        print('~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~')

        gen_temp, disc_temp, disc_real_temp, disc_fake_temp, disc_class_real_temp, disc_class_fake_temp = train(train_imgs, train_classes, checkpoint_interval, warmup)
        # In-place extension so the caller's lists see the new values.
        g_hist += gen_temp
        d_hist += disc_temp
        d_real_hist += disc_real_temp
        d_fake_hist += disc_fake_temp
        d_class_real_hist += disc_class_real_temp
        d_class_fake_hist += disc_class_fake_temp

        # Epoch count reached after this checkpoint; used in the file names.
        epochs_done = str((i+1)*checkpoint_interval)

        generated_image = generator.predict( [np.array([noise]), np.array([label])])[0]
        plt.imshow(unnormalize_img(generated_image))
        if save_files:
            # Save before plt.show(), which finishes the current figure.
            plt.savefig(image_dir + epochs_done + '.png')
        plt.show()
        if save_files:
            generator.save_weights(weights_dir + 'gen_weights_conditional_10_samples' + epochs_done + '.h5')
            discriminator.save_weights(weights_dir + 'disc_weights_conditional_10_samples' + epochs_done + '.h5')

    next_checkpoint = starting_checkpoint + num_checkpoints
    totalepochs = next_checkpoint * checkpoint_interval
    return next_checkpoint, totalepochs

In [23]:
# Create new, untrained networks and new Adam optimizers (learning rate 1e-4)
# for the training run below.
discriminator = make_discriminator_model()
generator = make_generator_model((100,))
generator_optimizer = Adam(1e-4)
discriminator_optimizer = Adam(1e-4)

In [24]:
# To get passed into training function and get modified
# Run cell when starting from scratch
# Per-epoch loss histories; long_train extends these lists in place.
gen_loss_hist = []
disc_loss_hist = []
disc_loss_real_hist = []
disc_loss_fake_hist = []
disc_class_loss_real_hist = []
disc_class_loss_fake_hist = []
# Fixed noise vector and class id for the preview image rendered after each checkpoint.
noise100 = np.random.normal(0, 1, 100)
label = 1
# Bookkeeping that long_train returns and that is fed back in when resuming.
next_starting_checkpoint = 0
totalepochs = 0
# Initial warm-up weight handed to long_train.
warmup = 0.

In [None]:
# Run 40 checkpoints of 250 epochs each, saving a preview image and both
# networks' weights after every checkpoint. The histories are passed by
# reference and grow in place.
print(label)
next_starting_checkpoint, totalepochs = long_train(
    starting_checkpoint=next_starting_checkpoint,
    num_checkpoints=40,
    checkpoint_interval=250,
    g_hist=gen_loss_hist,
    d_hist=disc_loss_hist,
    d_real_hist=disc_loss_real_hist,
    d_fake_hist=disc_loss_fake_hist,
    d_class_real_hist=disc_class_loss_real_hist,
    d_class_fake_hist=disc_class_loss_fake_hist,
    noise=noise100,
    label=label,
    save_files=True,
    warmup=warmup,
)

In [None]:
# Restore the generator weights saved under epoch count 10000
# (40 checkpoints x 250 epochs in the long_train call above).
generator.load_weights('./results/weights/animal-gan-conditional-10-samples/gen_weights_conditional_10_samples10000.h5')

In [None]:
# Render one image per class id from the same noise vector, so the only thing
# that varies between the figures is the label.
print(img_classes)
noise = np.random.normal(0, 1, 100)
noise_batch = np.array([noise])
for label in range(10):
    label_batch = np.array([label])
    generated_image = generator.predict([noise_batch, label_batch])[0]
    plt.imshow(unnormalize_img(generated_image))
    plt.title(img_classes[label])
    plt.show()

In [None]:
def _plot_curves(curves, legend_labels):
    """Draw every series in `curves` on one figure, add the legend, and show it."""
    for series in curves:
        plt.plot(series)
    plt.legend(legend_labels)
    plt.show()


# Raw generator vs. discriminator loss per epoch.
_plot_curves([gen_loss_hist, disc_loss_hist], ['Gen Loss','Disc Loss'])

# The same two curves with their means subtracted, to compare their shapes.
_plot_curves(
    [gen_loss_hist - np.mean(gen_loss_hist), disc_loss_hist - np.mean(disc_loss_hist)],
    ['Gen Loss Normalized','Disc Loss normalized'],
)

# Discriminator real/fake loss, split by real and generated images.
_plot_curves([disc_loss_real_hist, disc_loss_fake_hist], ['Disc Real Loss','Disc Fake Loss'])

# Discriminator class loss, split by real and generated images.
_plot_curves(
    [disc_class_loss_real_hist, disc_class_loss_fake_hist],
    ['Disc Class Real Loss', 'Disc Class Fake Loss'],
)


In [None]:
# Scratch check: draw a (32, 100) truncated-normal tensor and cast it to int32.
# The cast result is not assigned; as the cell's last expression it is only displayed.
# NOTE: this rebinds the notebook-level name `noise` to a tensor.
noise = tf.random.truncated_normal([32, 100])
tf.dtypes.cast(noise, tf.int32)


In [None]:
# Show training samples 20..29 with their class names to eyeball that images
# and labels are still aligned.
for sample_idx in range(20, 30):
    class_id = int(train_classes[sample_idx])
    plt.imshow(unnormalize_img(train_imgs[sample_idx]))
    plt.title(img_classes[class_id])
    plt.show()

In [None]:
# Sub-model from the generator's second input (the label) to the output of
# generator.layers[2] -- presumably the label Embedding layer; verify with
# generator.summary().
label_input = generator.input[1]
layer2_output = generator.layers[2].output
intermediate_layer_model = Model(inputs=label_input, outputs=layer2_output)


In [None]:
# Sum the element-wise absolute values of the intermediate layer's output over
# all ten class ids.
total = None
for label in range(10):
    magnitude = np.abs(intermediate_layer_model.predict(np.array([label]))[0][0])
    # Each label is added exactly once (label 0 used to be added twice: once
    # when initialising `total` and once more by the unconditional `+=`).
    total = magnitude if total is None else total + magnitude
print(total)

In [None]:
# Display the intermediate layer's output for class id 0 (first batch item, first row).
intermediate_layer_model.predict(np.array([0]))[0][0]

In [None]:
# A second, newly constructed generator (no weights are loaded into it here).
test_gen = make_generator_model((100,))

In [None]:
# Same label-to-layer-2 sub-model as for `generator`, built on `test_gen`.
test_label_input = test_gen.input[1]
test_layer2_output = test_gen.layers[2].output
test_gen_embedding = Model(inputs=test_label_input, outputs=test_layer2_output)

In [None]:
# Display the sub-model's output for class id 7 (first batch item, first row).
test_gen_embedding.predict(np.array([7]))[0][0]