# Text to Image using Attention GAN

In [76]:
### Parts of this code are adapted from https://github.com/AloneTogetherY/text-to-image-synthesis

In [69]:
# importing files

import tensorflow as tf
import os
import time
import numpy as np
from PIL import Image
import gensim
from nltk.tokenize import word_tokenize
import pandas as pd
import random
import re
from random import randint,choice
from keras.preprocessing.image import array_to_img

from numpy import expand_dims
from numpy import zeros
from numpy import ones
from numpy import vstack
from numpy import asarray
from keras.callbacks import ModelCheckpoint
from keras.initializers import RandomNormal
from numpy.random import random
from tensorflow.keras import layers
from tensorflow.keras import Model
from numpy.random import randn
from numpy.random import randint
import time
from keras.layers.advanced_activations import PReLU
from keras.utils.vis_utils import plot_model
from tensorflow.keras import initializers

## Data Preprocessing


### Image Embedding

In [None]:
# Embedding the images and storing them as a numpy file.
# Every .jpg in ./images/ (sorted by file name) is resized to 64x64, kept only
# if it decodes to a (64, 64, 3) RGB array, and scaled to [-1, 1].
# Unreadable or non-RGB files (grayscale, palette, RGBA, CMYK) are skipped and
# their names printed.
# NOTE(review): skipped images shift row indices relative to the captions file;
# confirm the caption CSV was built with the same filtering.
embedding_file = os.path.join('./embeddings/',
        f'birds_image_embedding')
start = time.time()
print("Loading training images...")

training_data = []
birds_path = sorted(os.listdir('./images/'))

for file_name in birds_path:
    path = os.path.join('./images/', file_name)
    if not path.endswith('.jpg'):
        continue
    try:
        with Image.open(path) as img:
            # LANCZOS is the filter formerly aliased as ANTIALIAS (removed in Pillow 10).
            pixels = np.asarray(img.resize((64, 64), Image.LANCZOS))
    except (OSError, ValueError):
        # Corrupt / unidentifiable image file.
        print(file_name)
        continue
    if pixels.shape != (64, 64, 3):
        # 2-D (grayscale/palette) or 4-channel images would break or silently
        # misalign the (N, 64, 64, 3) stack below.
        print(file_name)
        continue
    training_data.append(pixels)

if training_data:
    training_data = np.stack(training_data).astype(np.float32)   # (N, 64, 64, 3)
else:
    training_data = np.empty((0, 64, 64, 3), dtype=np.float32)

training_data = training_data / 127.5 - 1.            # Normalizing pixels to [-1, 1]

print("Image embedding finished and saving...")
training_data = training_data[:11776]
os.makedirs(os.path.dirname(embedding_file), exist_ok=True)
np.save(embedding_file + ".npy", training_data)

print (f'Time taken to complete embedding: {time.time()-start}')


### Text Embedding

In [38]:
# Text embedding: load the pretrained GoogleNews word2vec vectors (300-d) and
# allocate the zero-padded word-feature buffer, indexed as
# [caption, token position, vector component].
model = gensim.models.KeyedVectors.load_word2vec_format(
    'GoogleNews-vectors-negative300.bin',
    binary=True,
)

vw = np.zeros(shape=(11776, 52, 300), dtype=float)
def clean_and_tokenize_comments_for_image(comment):
    """Normalise and tokenise a collection of caption strings.

    Each caption has the listed punctuation characters deleted, its hyphens
    turned into spaces (a hyphen between a 4-digit number and a following
    2- or 4-digit number, e.g. "2019-20", is kept), is lower-cased, tokenised
    with ``word_tokenize`` and stripped of a few stop words.

    Args:
        comment: iterable of caption strings.

    Returns:
        list with one list of tokens per caption, in input order.
    """
    drop_words = {'a', 'and', 'of', 'to'}
    strip_table = str.maketrans(' ', ' ', r"""!"#$%&'()*+,./:;<=>?@[\]^_`…’{|}~""")
    hyphen_pattern = re.compile(r"-(?:(?<!\b[0-9]{4}-)|(?![0-9]{2}(?:[0-9]{2})?\b))")

    stripped = [text.translate(strip_table) for text in comment]

    def _tokens(text):
        # replace with space
        spaced = hyphen_pattern.sub(' ', text)
        return [tok for tok in word_tokenize(str(spaced).lower()) if tok not in drop_words]

    return [_tokens(text) for text in stripped]
    
def getword2vec(word2vec_model, cleaned_comments, max_captions=11776):
    """Vectorise captions with a word2vec model.

    Args:
        word2vec_model: gensim ``KeyedVectors``-like object; ``model[word]``
            returns a vector and raises ``KeyError`` for out-of-vocabulary
            words, and ``index_to_key`` lists the vocabulary.
        cleaned_comments: iterable of raw caption strings; they are cleaned and
            tokenised with ``clean_and_tokenize_comments_for_image``.
        max_captions: at most this many captions are processed (default 11776,
            the number of images kept in the image embedding). ``None`` means
            no limit.

    Returns:
        (word_vectors, sentence_vectors)
        word_vectors: 1-D object array; entry k is a float32 array of shape
            (n_tokens_k, dim) with one row per token of caption k.
        sentence_vectors: float32 array of shape (n_captions, dim). Each row is
            the mean word vector of its caption, or all zeros for a caption
            with no tokens (previously NaN).

    Out-of-vocabulary handling: a few bird-anatomy terms are replaced by short
    paraphrases. Any other unknown word is replaced by a random vocabulary
    vector (non-deterministic unless the stdlib ``random`` module is seeded).
    """
    # `random` is rebound to numpy.random.random by the imports cell, so
    # `random.choice` would raise AttributeError; use the stdlib function directly.
    import itertools
    from random import choice as _random_choice

    # Exact-match replacements (the original used substring tests such as
    # `word in 'superciliary'`, which also matched fragments like 'super').
    replacements = {
        'superciliary': ('eyebrow', 'region'),
        'superciliaries': ('eyebrow', 'region'),
        'rectrices': ('large', 'tail', 'feathers'),
        'rectices': ('large', 'tail', 'feathers'),
    }
    vector_dim = getattr(word2vec_model, 'vector_size', 300)

    limited = list(itertools.islice(cleaned_comments, max_captions))
    cleaned_caption = clean_and_tokenize_comments_for_image(limited)

    vectorized_list = []
    sentence_vlist = []
    for words in cleaned_caption:
        rows = []
        for word in words:
            try:
                rows.append(word2vec_model[word])
            except KeyError:
                if word in replacements:
                    rows.extend(word2vec_model[w] for w in replacements[word])
                else:
                    rows.append(word2vec_model[_random_choice(word2vec_model.index_to_key)])

        # Accumulate in float64 (as np.append did) and cast afterwards.
        word_matrix = np.asarray(rows, dtype=np.float64).reshape(-1, vector_dim)
        vectorized_list.append(word_matrix.astype('float32'))
        if word_matrix.shape[0]:
            sentence_vlist.append(word_matrix.mean(axis=0).astype('float32'))
        else:
            sentence_vlist.append(np.zeros(vector_dim, dtype='float32'))

    # Always a 1-D object array, even when every caption has the same length.
    word_vectors = np.empty(len(vectorized_list), dtype=object)
    for k, matrix in enumerate(vectorized_list):
        word_vectors[k] = matrix
    sentence_vectors = np.asarray(sentence_vlist, dtype=np.float32).reshape(-1, vector_dim)
    return word_vectors, sentence_vectors
# Vectorise every caption in final.csv, copy the per-word vectors into the
# zero-padded buffer `vw` (all-zero rows therefore mark padding) and save
# both feature sets.
df = pd.read_csv('final.csv')
all_captions = df['captions'].values
vector_word,vector_sentence = getword2vec(model,all_captions)

max_tokens = vw.shape[1]
n_truncated = 0
for i in range(len(vector_sentence)):
    n_tokens = vector_word[i].shape[0]
    if n_tokens > max_tokens:
        # A longer caption cannot fit the fixed-size buffer (this previously
        # raised a broadcast error); keep its first `max_tokens` words.
        n_truncated += 1
        n_tokens = max_tokens
    vw[i, :n_tokens] = vector_word[i][:n_tokens]       #padding vector
if n_truncated:
    print(f'{n_truncated} captions truncated to {max_tokens} tokens')

os.makedirs('./embeddings/', exist_ok=True)
np.save('./embeddings/bird_sentence_vector.npy',vector_sentence)
np.save('./embeddings/bird_word_features.npy',vw)



## Data Loading

In [61]:
# Reload the cached image / word / sentence embeddings from disk.
image_embedding, vw, vector_sentence = (
    np.load(npy_path)
    for npy_path in (
        './embeddings/birds_image_embedding.npy',
        './embeddings/bird_word_features.npy',
        './embeddings/bird_sentence_vector.npy',
    )
)

# First 11776 captions, the same cap applied to the image embedding.
# NOTE(review): the word/sentence vectors are built from 'final.csv' but these
# captions come from './captions.csv'. Confirm both files list the same
# captions in the same order. `captions` is not used by the training code here.
captions = pd.read_csv('./captions.csv')['captions'][:11776]

## Defining Discriminator

In [58]:
# Discriminator model
def define_discriminator(word_vector_dim=300, image_shape=(64, 64, 3)):
    """Build a text-conditioned discriminator.

    The sentence vector is projected with a Dense layer to a map of the same
    shape as the image and concatenated with it along channels. The result
    goes through a conv stack and two Dense layers.

    Args:
        word_vector_dim: length of the sentence-vector input (default 300).
        image_shape: (height, width, channels) of the image input.

    Returns:
        Keras Model with inputs ``[image, sentence_vector]`` and one unbounded
        output per sample (no final activation, i.e. a logit).
    """
    height, width, channels = image_shape

    in_label = layers.Input(shape=(word_vector_dim,))
    li = layers.Dense(height * width * channels)(in_label)
    li = layers.Reshape(image_shape)(li)

    dis_input = layers.Input(shape=image_shape)

    merge = layers.Concatenate()([dis_input, li])

    discriminator = layers.Conv2D(filters=64, kernel_size=(3, 3), padding="same")(merge)
    discriminator = layers.LeakyReLU(0.2)(discriminator)
    discriminator = layers.GaussianNoise(0.2)(discriminator)

    # Conv -> BatchNorm -> LeakyReLU(0.2) stages; stride 2 halves the resolution.
    # (The original second stage used LeakyReLU() with Keras' default slope,
    # unlike every other activation in this notebook.)
    for filters, stride in ((64, 2), (128, 1), (128, 2), (256, 1), (256, 2), (512, 1)):
        discriminator = layers.Conv2D(filters=filters, kernel_size=(3, 3),
                                      strides=(stride, stride), padding="same")(discriminator)
        discriminator = layers.BatchNormalization(momentum=0.5)(discriminator)
        discriminator = layers.LeakyReLU(0.2)(discriminator)

    discriminator = layers.Flatten()(discriminator)
    discriminator = layers.Dense(1024)(discriminator)
    discriminator = layers.LeakyReLU(0.2)(discriminator)
    discriminator = layers.Dense(1)(discriminator)   # logit: no sigmoid

    discriminator_model = Model(inputs=[dis_input, in_label], outputs=discriminator)

    discriminator_model.summary()

    return discriminator_model


## Defining Generator

In [59]:

def resnet_block(model, kernel_size, filters, strides):
    """Append a residual block (Conv-BN-PReLU-Conv-BN plus identity skip).

    Args:
        model: input tensor; it is added unchanged to the block's output, so
            its shape must equal the conv output shape (callers here use
            strides=1 and filters equal to the input channel count).
        kernel_size, filters, strides: passed to both Conv2D layers.

    Returns:
        The output tensor of the block.
    """
    shortcut = model
    conv_kwargs = dict(filters=filters, kernel_size=kernel_size, strides=strides, padding="same")

    x = layers.Conv2D(**conv_kwargs)(model)
    x = layers.BatchNormalization(momentum=0.5)(x)
    x = tf.keras.layers.PReLU(alpha_initializer='zeros', alpha_regularizer=None,
                              alpha_constraint=None, shared_axes=[1, 2])(x)
    x = layers.Conv2D(**conv_kwargs)(x)
    x = layers.BatchNormalization(momentum=0.5)(x)

    return layers.Add()([shortcut, x])


# Generator model
def define_generator(latent_dim=100, text_dim=300):
    """Build the stage-1 generator (G0).

    The noise vector and the sentence vector are each projected by a Dense
    layer to an 8x8x128 map and concatenated. The result is refined by a
    residual trunk (with a long skip connection) and upsampled to 64x64.

    Args:
        latent_dim: length of the noise input (default 100).
        text_dim: length of the sentence-vector input (default 300).

    Returns:
        Keras Model with inputs ``[noise, sentence_vector]`` and outputs
        ``[G0_image (64, 64, 3), F0_features (64, 64, 64)]``. The image comes
        from a tanh activation.
    """
    kernel_init = tf.random_normal_initializer(stddev=0.02)
    base_shape = (8, 8, 128)
    n_nodes = base_shape[0] * base_shape[1] * base_shape[2]   # 8192

    random_input = layers.Input(shape=(latent_dim,))
    text_input1 = layers.Input(shape=(text_dim,))
    text_layer1 = layers.Dense(n_nodes)(text_input1)
    text_layer1 = layers.Reshape(base_shape)(text_layer1)

    gen_input_dense = layers.Dense(n_nodes)(random_input)
    generator = layers.Reshape(base_shape)(gen_input_dense)

    merge = layers.Concatenate()([generator, text_layer1])

    model = layers.Conv2D(filters=64, kernel_size=9, strides=1, padding="same")(merge)
    model = tf.keras.layers.PReLU(alpha_initializer='zeros', alpha_regularizer=None, alpha_constraint=None, shared_axes=[1, 2])(model)

    gen_model = model   # long skip connection around the residual trunk

    for _ in range(4):
        model = resnet_block(model, 3, 64, 1)

    model = layers.Conv2D(filters=64, kernel_size=3, strides=1, padding="same")(model)
    model = layers.BatchNormalization(momentum=0.5)(model)
    model = layers.Add()([gen_model, model])

    # Three stride-2 transposed convs take 8x8 -> 64x64; the last one keeps 64x64.
    for filters, stride in ((512, 2), (256, 2), (128, 2), (64, 1)):
        model = layers.Conv2DTranspose(filters=filters, kernel_size=(3, 3), strides=(stride, stride),
                                       padding="same", kernel_initializer=kernel_init)(model)
        model = layers.LeakyReLU(0.2)(model)
    F0_model = model

    G0_model = layers.Conv2D(3, (3, 3), padding='same', activation='tanh')(F0_model)

    generator_model = Model(inputs=[random_input,text_input1], outputs=[G0_model,F0_model])
    generator_model.summary()

    return generator_model
    

## Attention Module

In [65]:
## Attention module implemented by myself


def attention_module(num_words=52, word_dim=300, feature_shape=(64, 64, 64)):
    """Build the word-level attention block between the two generator stages.

    The stage-1 feature map is treated as N = H*W region vectors h_j with C
    channels, and each word vector is projected to C dims (e_i). The block
    computes:
      * s_{j,i} = e_i . h_j
      * beta_{j,i} = softmax over the words i of s_{j,i}
      * c_j = sum_i beta_{j,i} e_i
    The resulting c is reshaped back to (H, W, C).

    Fixes relative to the previous version:
      * the softmax ran over axis 0 (the batch), mixing samples;
      * padding was tested after a biased Dense layer, so it was never masked;
      * the features were reshaped so word features were contracted with the
        image height instead of the channels.

    Word slots whose input vector is entirely zero are treated as padding and
    get zero weight. The padding loop in the text-embedding cell fills unused
    slots with zeros. A caption with no words yields an all-zero context map.

    Args:
        num_words: padded caption length T of the word input.
        word_dim: dimensionality of each input word vector.
        feature_shape: (H, W, C) of the stage-1 feature map.

    Returns:
        Keras Model with inputs ``[word_vectors (T, word_dim), features (H, W, C)]``
        and a single (H, W, C) output.
    """
    height, width, channels = feature_shape

    word_feature = layers.Input(shape=(num_words, word_dim))           # (B, T, word_dim)
    new_word_feature = layers.Dense(channels)(word_feature)            # (B, T, C)

    f0_output = layers.Input(shape=feature_shape)                      # (B, H, W, C)
    regions = layers.Reshape((height * width, channels))(f0_output)    # (B, N, C)

    # True for real words; padded rows are all zeros before the Dense projection.
    word_mask = tf.reduce_any(tf.not_equal(word_feature, 0.0), axis=-1)  # (B, T)
    word_mask = tf.expand_dims(word_mask, axis=-1)                       # (B, T, 1)

    s_function = tf.matmul(new_word_feature, regions, transpose_b=True)  # (B, T, N)
    # Large finite negative (not float64.min, which overflows float32) so that
    # masked words get ~0 weight without producing NaNs.
    s_function = tf.where(word_mask, s_function, -1e9)
    beta = tf.nn.softmax(s_function, axis=1)                             # normalise over words
    beta = beta * tf.cast(word_mask, s_function.dtype)                   # zero for all-padding captions

    c = tf.matmul(beta, new_word_feature, transpose_a=True)              # (B, N, C)
    attnout = layers.Reshape((height, width, channels))(c)               # (B, H, W, C)

    model = Model(inputs=([word_feature, f0_output]),outputs=([attnout]))
    model.summary()
    return model


## Second Generator

In [66]:
## Function defining second generator 

def define_F1():
    """Build the stage-2 generator (F1 / G1).

    Two (64, 64, 64) inputs are concatenated along channels and projected back
    to 64 channels (Conv2DTranspose -> BatchNorm -> LeakyReLU). The result is
    refined by four residual blocks and mapped to a 3-channel tanh image.
    ``train_step`` feeds it ``[attention output, stage-1 features]``.

    Returns:
        Keras Model with inputs ``[input1, input2]`` and one (64, 64, 3) output.
    """
    feature_a = layers.Input(shape=(64,64,64))
    feature_b = layers.Input(shape=(64,64,64))
    stacked = layers.Concatenate(axis=3)([feature_a, feature_b])

    projection = layers.Conv2DTranspose(
        64,
        kernel_size=4,
        padding="same",
        kernel_initializer=initializers.RandomNormal(stddev=0.02),
    )
    x = projection(stacked)
    x = layers.BatchNormalization(momentum=0.8)(x)
    x = layers.LeakyReLU(alpha=0.2)(x)

    for _ in range(4):
        x = resnet_block(x, 3, 64, 1)

    image = layers.Conv2D(3, (3, 3), padding='same', activation='tanh')(x)

    f1_model = Model(inputs=[feature_a, feature_b], outputs=[image])
    f1_model.summary()
    return f1_model
    


## Training

In [74]:
from IPython.display import clear_output
import matplotlib.pyplot as pyplot


def generate_random_vectors(n_samples):
    """Create ``n_samples`` fake sentence vectors from random vocabulary words.

    For each sample a word count is drawn with ``randint(8, 25)``. That many
    random words are looked up in the global word2vec ``model`` and their
    vectors averaged.

    Returns:
        numpy array with one float32 mean vector per sample.
    """
    fake_captions = []
    for _ in range(n_samples):
        n_words = randint(8, 25)
        picked = [model[choice(model.index_to_key)] for _ in range(n_words)]
        stacked = np.asarray(picked, dtype=np.float64).reshape(-1, 300)
        fake_captions.append(np.mean(stacked, axis=0).astype('float32'))
    return np.array(fake_captions)

def get_random_word_vectors_from_dataset(n_samples):
    """Sample ``n_samples`` captions uniformly with replacement.

    Returns:
        (sentence vectors as a numpy array, padded word vectors as a tf tensor),
        both taken from the global ``vector_sentence`` / ``vw`` at the same
        random indices.
    """
    picks = np.random.randint(0, len(vector_sentence), n_samples)
    sentences = np.asarray(vector_sentence)[picks]
    words = np.asarray(vw)[picks]
    return sentences, tf.convert_to_tensor(words)

def generate_latent_points(latent_dim, n_samples):
    """Return ``[noise (n_samples, latent_dim), sentence vectors, word vectors]``.

    The noise is standard normal; the caption vectors are sampled from the
    dataset by ``get_random_word_vectors_from_dataset``.
    """
    noise = tf.random.normal([n_samples, latent_dim])
    sentences, words = get_random_word_vectors_from_dataset(n_samples)
    return [noise, sentences, words]
  

def _save_image_grid(images, path):
    """Plot a batch of images in a square grid, save the figure to ``path``,
    show it and close it."""
    import math

    images = images.numpy() if hasattr(images, 'numpy') else np.asarray(images)
    n_images = images.shape[0]
    side = max(1, math.ceil(math.sqrt(n_images)))   # 25 images -> 5x5 as before

    directory = os.path.dirname(path)
    if directory:
        os.makedirs(directory, exist_ok=True)

    fig = pyplot.figure(figsize=[7, 7])
    for i in range(n_images):
        pyplot.subplot(side, side, i + 1)
        pyplot.imshow(array_to_img(images[i]))
        pyplot.axis('off')
    pyplot.savefig(path)
    pyplot.show()
    pyplot.close(fig)   # avoid accumulating figures over many epochs


def generate_and_save_images(model1,att, model2, epoch, test_input,word_f):
    """Render and save stage-1 (G0) and stage-2 (G1) sample grids for ``epoch``.

    Args:
        model1: stage-1 generator. Called on ``test_input``; output [0] is
            plotted and output [1] (features) feeds the attention/stage-2 models.
        att: attention model, called on ``[word_f, features]``.
        model2: stage-2 generator, called on ``[attention output, features]``.
            This matches the input order used in ``train_step``; the previous
            version swapped the two inputs here.
        epoch: epoch number used in the output file names.
        test_input: ``[noise, sentence_vectors]`` for ``model1``.
        word_f: padded word vectors for the same captions.
    """
    predictions = model1(test_input, training=False)
    attn = att([word_f, predictions[1]], training=False)
    final_image = model2([attn, predictions[1]], training=False)

    print(predictions[1].shape)
    _save_image_grid(predictions[0], 'Samples/G0/image_at_epoch_{:04d}.png'.format(epoch))

    print(final_image.shape)
    _save_image_grid(final_image, 'Samples/G1/image_at_epoch_{:04d}.png'.format(epoch))



# Both discriminators end in a linear Dense(1) with no sigmoid, so their outputs
# are logits. With the default from_logits=False the loss would clip them to
# [eps, 1 - eps] as if they were probabilities, which kills the gradients.
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

def discriminator_loss(real_image_real_text, fake_image_real_text, real_image_fake_text):
    """GAN-CLS style discriminator loss on logits.

    Args:
        real_image_real_text: logits for real images with their own captions.
        fake_image_real_text: logits for generated images with real captions.
        real_image_fake_text: logits for real images with random captions.

    Returns:
        Loss for the real pairs (soft labels in [0.8, 1.0]) plus the mean of the
        two fake-pair losses (soft labels in [0.0, 0.2]).
    """
    # tf.shape works for both static and dynamic (None) batch dimensions.
    real_loss = cross_entropy(tf.random.uniform(tf.shape(real_image_real_text), 0.8, 1.0), real_image_real_text)
    fake_loss = (cross_entropy(tf.random.uniform(tf.shape(fake_image_real_text), 0.0, 0.2), fake_image_real_text) +
                 cross_entropy(tf.random.uniform(tf.shape(real_image_fake_text), 0.0, 0.2), real_image_fake_text)) / 2

    total_loss = real_loss + fake_loss
    return total_loss

def generator_loss(fake_output):
    """Non-saturating generator loss: push the discriminator logits for fakes toward 'real'."""
    return cross_entropy(tf.ones_like(fake_output), fake_output)



In [None]:
@tf.function
def train_step(batch):
    """Run one optimisation step for both GAN stages on a single batch.

    Args:
        batch: ``(images, sentence_vectors, word_vectors)`` from ``train_dataset``.

    Returns:
        (stage-1 + stage-2 generator loss, stage-1 + stage-2 discriminator loss)
    """
    images, sentence_vectors, word_vectors = batch[0], batch[1], batch[2]
    batch_size = tf.shape(images)[0]

    # Plain Python inside a tf.function runs only while tracing, so calling
    # generate_random_vectors directly would bake ONE fixed batch of
    # "mismatched" captions into the graph. tf.py_function re-runs it each step.
    random_captions = tf.py_function(
        func=lambda n: generate_random_vectors(int(n)),
        inp=[batch_size],
        Tout=tf.float32)
    random_captions.set_shape(sentence_vectors.shape)

    # ---------------- Stage 1: G0 / discriminator1 ----------------
    with tf.GradientTape() as gen_tape1, tf.GradientTape() as disc_tape1:
        noise = tf.random.normal([batch_size, 100])

        generated_images1 = generator([noise, sentence_vectors], training=True)

        fake_output_real_text1 = discriminator1([generated_images1[0], sentence_vectors], training=True)
        real_output_real_text = discriminator1([images, sentence_vectors], training=True)
        real_output_fake_text = discriminator1([images, random_captions], training=True)

        gen_loss1 = generator_loss(fake_output_real_text1)
        disc_loss_1 = discriminator_loss(real_output_real_text, fake_output_real_text1, real_output_fake_text)

    gradients_of_discriminator = disc_tape1.gradient(disc_loss_1, discriminator1.trainable_variables)
    gradients_of_generator = gen_tape1.gradient(gen_loss1, generator.trainable_variables)
    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator1.trainable_variables))

    # ---------------- Stage 2: attention + F1 / discriminator2 ----------------
    # Stage-1 features are used as fixed inputs; no gradient flows back into G0.
    with tf.GradientTape() as gen_tape2, tf.GradientTape() as disc_tape2:
        att_out = att([word_vectors, generated_images1[1]], training=True)
        generated_images2 = F1_block([att_out, generated_images1[1]], training=True)

        fake_output_real_text2 = discriminator2([generated_images2, sentence_vectors], training=True)
        real_output_real_text2 = discriminator2([images, sentence_vectors], training=True)
        real_output_fake_text2 = discriminator2([images, random_captions], training=True)

        gen_loss2 = generator_loss(fake_output_real_text2)
        disc_loss_2 = discriminator_loss(real_output_real_text2, fake_output_real_text2, real_output_fake_text2)

    # The attention layer's weights are part of the stage-2 generator; they
    # were previously never updated.
    stage2_generator_vars = F1_block.trainable_variables + att.trainable_variables
    gradients_of_discriminator2 = disc_tape2.gradient(disc_loss_2, discriminator2.trainable_variables)
    gradients_of_generator2 = gen_tape2.gradient(gen_loss2, stage2_generator_vars)
    generator_optimizer2.apply_gradients(zip(gradients_of_generator2, stage2_generator_vars))
    discriminator_optimizer2.apply_gradients(zip(gradients_of_discriminator2, discriminator2.trainable_variables))

    return gen_loss1+gen_loss2,disc_loss_1+disc_loss_2


def train(data, epochs = 1000, checkpoint_dir='checkpoints'):
    """Train both GAN stages for ``epochs`` epochs over ``data``.

    The loop:
      * resumes from the latest checkpoints, if any;
      * prints the mean losses and saves sample grids every epoch;
      * checkpoints every 30 epochs;
      * writes .h5 snapshots every 50 epochs.

    Args:
        data: iterable of ``(images, sentence_vectors, word_vectors)`` batches.
        epochs: number of epochs to run.
        checkpoint_dir: root directory; stage-1 and stage-2 checkpoints are kept
            in separate ``stage1`` / ``stage2`` sub-directories.

    NOTE: the previous version pointed both CheckpointManagers at the same
    directory with the same 'ckpt' prefix. They overwrote each other's files
    and shared one index, so checkpoints saved that way are not picked up here.
    """
    stage1_dir = os.path.join(checkpoint_dir, 'stage1')
    stage2_dir = os.path.join(checkpoint_dir, 'stage2')
    for directory in (stage1_dir, stage2_dir, 'models/g0', 'models/g1', 'models/att'):
        os.makedirs(directory, exist_ok=True)

    checkpoint1 = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
                                      discriminator_optimizer=discriminator_optimizer,
                                      generator=generator,
                                      discriminator=discriminator1)
    # Stage 2 tracks its own optimizers (previously the stage-1 ones) and the
    # attention module (previously never saved).
    checkpoint2 = tf.train.Checkpoint(generator_optimizer=generator_optimizer2,
                                      discriminator_optimizer=discriminator_optimizer2,
                                      generator=F1_block,
                                      attention=att,
                                      discriminator=discriminator2)

    ckpt_manager1 = tf.train.CheckpointManager(checkpoint1, stage1_dir, max_to_keep=3)
    ckpt_manager2 = tf.train.CheckpointManager(checkpoint2, stage2_dir, max_to_keep=3)
    print(ckpt_manager1.latest_checkpoint,'1-----------------------------------------')
    print(ckpt_manager2.latest_checkpoint,'2-----------------------------------------')
    if ckpt_manager1.latest_checkpoint:
        checkpoint1.restore(ckpt_manager1.latest_checkpoint)
        print ('Latest checkpoint1 restored!!')
    if ckpt_manager2.latest_checkpoint:
        checkpoint2.restore(ckpt_manager2.latest_checkpoint)
        print ('Latest checkpoint2 restored!!')

    for epoch in range(epochs):
        start = time.time()
        genloss = []
        discloss = []
        for batch in data:
            gen_loss, disc_loss = train_step(batch)
            genloss.append(gen_loss)
            discloss.append(disc_loss)
        if genloss:   # guard against an empty dataset (division by zero)
            tf.print('G_loss =====', sum(genloss) / len(genloss))
            tf.print('D_loss =====', sum(discloss) / len(discloss))

        # Sample grids every epoch.
        [z_input, labels_input, word_f] = generate_latent_points(100, 25)
        generate_and_save_images(generator, att, F1_block, epoch + 1, [z_input, labels_input], word_f)

        if (epoch + 1) % 30 == 0:
            ckpt_save_path1 = ckpt_manager1.save()
            ckpt_save_path2 = ckpt_manager2.save()
            print ('Saving checkpoint for epoch {} at {} and {}'.format(epoch + 1, ckpt_save_path1, ckpt_save_path2))

        if (epoch + 1) % 50 == 0:
            clear_output(wait=True)
            generator.save('models/g0/stage_new_gan_animal_model_%03d.h5' % (epoch + 1))
            F1_block.save('models/g1/stage_new_gan_animal_model_%03d.h5' % (epoch + 1))
            att.save_weights('models/att/stage_new_gan_animal_attention_%03d.h5' % (epoch + 1))

        print ('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start))


# Not used by the training code (the losses use `cross_entropy`); kept as before.
binary_cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

# One Adam optimiser per trainable network (stage-1 G/D, stage-2 G/D),
# all with identical hyper-parameters.
(generator_optimizer,
 discriminator_optimizer,
 generator_optimizer2,
 discriminator_optimizer2) = (tf.keras.optimizers.Adam(learning_rate=0.000035, beta_1=0.5)
                              for _ in range(4))

# Build the networks (construction order unchanged).
att = attention_module()
F1_block = define_F1()
generator = define_generator()
discriminator1 = define_discriminator()
discriminator2 = define_discriminator()

BUFFER_SIZE = 11776
BATCH_SIZE = 64
train_dataset = (
    tf.data.Dataset.from_tensor_slices((image_embedding, vector_sentence, vw))
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE)
)

train(train_dataset)