# Importing necessary libraries 

In [4]:
import tensorflow as tf
import matplotlib.pyplot as plt
import collections
import random
import numpy as np
import os
import time
import json
from PIL import Image

In [5]:
os.path.abspath('.')  # absolute path of the current working directory; dataset and image paths below are built from it

'/home/dwip.dalal/Openvivo/Text to image generation'

# Loading the dataset

In [6]:
# Path to the COCO 2014 training-caption annotations, relative to the working directory.
dataset = os.path.join(os.path.abspath('.'), "annotations", "captions_train2014.json")
# JSON is UTF-8 by spec; be explicit so reading does not depend on the OS locale.
with open(dataset, 'r', encoding='utf-8') as f:
    data = json.load(f)  # parsed annotations; data['annotations'] is used below

In [None]:
# Top-level keys of the annotations JSON; only data['annotations'] is used below.
data.keys()

### Build a dictionary that maps each image path (key) to the list of its captions (value)

In [None]:
# Map each training image's absolute file path to every caption written for it.
# Each caption is wrapped in <start>/<end> markers for the decoder.
# The image directory is computed once instead of once per annotation.
train_image_dir = os.path.join(os.path.abspath('.'), 'train2014')
image_path_to_caption = collections.defaultdict(list)
for val in data['annotations']:
    caption = f"<start> {val['caption']} <end>"
    image_path = os.path.join(train_image_dir, f"COCO_train2014_{val['image_id']:012d}.jpg")
    image_path_to_caption[image_path].append(caption)

In [None]:
# Peek at one key: the keys of image_path_to_caption are image file paths.
print(next(iter(image_path_to_caption)))

In [None]:
# Number of distinct images that have at least one caption
# (len of the dict directly; no need to copy its keys into a list).
len(image_path_to_caption)

### Defining the training dataset

In [None]:
# Take the first 3000 image paths (dict insertion order), then shuffle them in place.
train_image_paths = list(image_path_to_caption)[:3000]
random.shuffle(train_image_paths)

In [None]:
# Flatten to parallel lists: one entry per caption, with the image path
# repeated once for every caption that image has.
train_captions = []
image_vector = []

for path in train_image_paths:
    captions_for_path = image_path_to_caption[path]
    train_captions += captions_for_path
    image_vector += [path] * len(captions_for_path)

In [None]:
# Spot-check one entry; image paths repeat once per caption in image_vector.
print(image_vector[2])

In [None]:
# Print the first caption, then leave the image as the cell's last expression:
# a notebook only renders the final expression, so the original order never
# displayed the image. Index 0 of both lists refers to the same pair.
print(train_captions[0])
Image.open(image_vector[0])

# Tokenization

In [None]:
def standardize_caption(inputs):
    """Lower-case captions and strip punctuation, but keep '<' and '>'.

    TextVectorization's default standardization also removes '<' and '>',
    which turns the '<start>'/'<end>' markers into plain 'start'/'end'. The
    training and inference code looks up '<start>' and compares predictions
    against '<end>', so the markers must survive tokenization intact.
    """
    inputs = tf.strings.lower(inputs)
    # Same character set as the Keras default strip regex, minus '<' and '>'.
    return tf.strings.regex_replace(
        inputs, r'[!"#$%&()\*\+,\-\./:;=?@\[\\\]^_`{|}~\']', ""
    )


caption_dataset = tf.data.Dataset.from_tensor_slices(train_captions)
max_length = 50         # every caption is padded/truncated to this many tokens
vocabulary_size = 5000  # cap on the vocabulary size
tokenizer = tf.keras.layers.TextVectorization(
    max_tokens=vocabulary_size,
    standardize=standardize_caption,
    output_sequence_length=max_length,
)
tokenizer.adapt(caption_dataset)
caption_vec = caption_dataset.map(tokenizer)

In [None]:
# Token <-> id lookups that share the tokenizer's vocabulary ('' is the padding/mask token).
_caption_vocab = tokenizer.get_vocabulary()
word_to_index = tf.keras.layers.StringLookup(mask_token="", vocabulary=_caption_vocab)
index_to_word = tf.keras.layers.StringLookup(mask_token="", vocabulary=_caption_vocab, invert=True)

### Loading InceptionV3 and passing each image through it to get an abstract feature representation of the image

In [None]:
def load_image(image_path):
    """Read a JPEG file and prepare it as InceptionV3 input.

    Returns a ``(preprocessed_image, image_path)`` pair. The image is decoded
    to 3 channels, resized to 299x299, and passed through InceptionV3's
    ``preprocess_input``. The path is passed through unchanged.
    """
    raw_bytes = tf.io.read_file(image_path)
    rgb = tf.io.decode_jpeg(raw_bytes, channels=3)
    resized = tf.keras.layers.Resizing(299, 299)(rgb)  # InceptionV3 input size
    model_ready = tf.keras.applications.inception_v3.preprocess_input(resized)
    return model_ready, image_path

In [None]:
# InceptionV3 pretrained on ImageNet, without its classification head. The output
# of its last layer is used as a spatial feature map for each image.
image_model = tf.keras.applications.InceptionV3(include_top=False, weights='imagenet')
new_input = image_model.input
hidden_layer = image_model.layers[-1].output
image_features_extract_model = tf.keras.Model(new_input, hidden_layer)

# Each distinct training image only needs to be encoded once (sorted for a stable order).
encode_train = sorted(set(image_vector))

image_dataset = (
    tf.data.Dataset.from_tensor_slices(encode_train)
    .map(load_image, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(16)
)

for img, path in image_dataset:
    feature_maps = image_features_extract_model(img)
    # Collapse the two middle axes of the rank-4 output:
    # (batch, d1, d2, channels) -> (batch, d1*d2, channels).
    batch_features = tf.reshape(
        feature_maps, (feature_maps.shape[0], -1, feature_maps.shape[3])
    )
    # Cache each image's features next to it; np.save appends '.npy' to the path.
    for bf, p in zip(batch_features, path):
        np.save(p.numpy().decode("utf-8"), bf.numpy())

# Splitting the dataset

In [None]:
# Group the tokenized captions by image path, keeping first-seen image order.
img_to_caption_vec = collections.defaultdict(list)
for img_path, cap_vector in zip(image_vector, caption_vec):
    img_to_caption_vec[img_path].append(cap_vector)

# Split by image (not by caption) so that no image appears in both sets: 80% train / 20% val.
img_keys = list(img_to_caption_vec.keys())
random.shuffle(img_keys)

split_at = int(len(img_keys) * 0.8)
image_name_train_keys = img_keys[:split_at]
image_name_val_keys = img_keys[split_at:]


def _expand_to_pairs(keys):
    """Return parallel (image_names, captions) lists: one entry per caption of each key."""
    names, caps = [], []
    for key in keys:
        key_caps = img_to_caption_vec[key]
        names.extend([key] * len(key_caps))
        caps.extend(key_caps)
    return names, caps


image_name_train, caption_train = _expand_to_pairs(image_name_train_keys)
image_name_val, cap_val = _expand_to_pairs(image_name_val_keys)

In [None]:
# Sanity check: image names and captions pair up one-to-one, so each split's two counts should match.
len(image_name_train), len(caption_train), len(image_name_val), len(cap_val)

In [None]:
BATCH_SIZE = 64
BUFFER_SIZE = 1000  # shuffle buffer size for the training dataset
embedding_dim = 256  # size of the encoder projection and of the token embeddings
units = 512  # GRU and attention hidden size
num_steps = len(image_name_train) // BATCH_SIZE  # full batches per epoch; a final partial batch is not counted
# Shape of the vector extracted from InceptionV3 is (64, 2048)
# (presumably an 8x8 spatial grid x 2048 channels for 299x299 inputs — verify if input size changes)
features_shape = 2048  # feature depth; not referenced elsewhere in this notebook
attention_features_shape = 64  # number of attended positions; sizes evaluate()'s attention buffer

In [None]:
# Load the numpy files
def map_func(image_name, cap):
    """Load the cached feature array for one image and pass its caption through.

    ``image_name`` arrives as UTF-8 bytes (it is called via ``tf.numpy_function``).
    The features are read from ``<image_name>.npy``.
    """
    feature_file = f"{image_name.decode('utf-8')}.npy"
    return np.load(feature_file), cap

In [None]:
# Number of (image, caption) training pairs.
len(image_name_train)

In [None]:
def _load_cached_features(image_name, caption):
    """Wrap map_func so cached .npy features are loaded inside the tf.data graph."""
    return tf.numpy_function(map_func, [image_name, caption], [tf.float32, tf.int64])


# (image features, caption ids) pairs: load in parallel, then shuffle, batch and prefetch.
dataset = (
    tf.data.Dataset.from_tensor_slices((image_name_train, caption_train))
    .map(_load_cached_features, num_parallel_calls=tf.data.AUTOTUNE)
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE)
    .prefetch(buffer_size=tf.data.AUTOTUNE)
)

# Model Architecture

In [None]:
class BahdanauAttention(tf.keras.Model):
    """Additive (Bahdanau-style) attention of a hidden state over feature vectors."""

    def __init__(self, units):
        super().__init__()
        self.W1 = tf.keras.layers.Dense(units)
        self.W2 = tf.keras.layers.Dense(units)
        self.V = tf.keras.layers.Dense(1)

    def call(self, features, hidden):
        """Return ``(context_vector, attention_weights)``.

        The weights are a softmax over axis 1 of ``features``. The context vector
        is the weighted sum of ``features`` over that axis.
        """
        query = tf.expand_dims(hidden, 1)  # add an axis so it broadcasts against features
        energies = self.V(tf.nn.tanh(self.W1(features) + self.W2(query)))
        weights = tf.nn.softmax(energies, axis=1)
        context = tf.reduce_sum(weights * features, axis=1)
        return context, weights

In [None]:
class CNN(tf.keras.Model):
    """Encoder head: one dense projection of the cached image features, then ReLU."""

    def __init__(self, embedding_dim):
        super().__init__()
        self.fc = tf.keras.layers.Dense(embedding_dim)

    def call(self, x):
        projected = self.fc(x)
        return tf.nn.relu(projected)

In [None]:
class RNN(tf.keras.Model):
  """Attention-based GRU decoder that predicts the next caption token.

  Args:
    embedding_dim: size of the token embeddings.
    units: hidden size of the GRU and of the attention layer.
    vocab_size: number of output logits per position.
  """

  def __init__(self, embedding_dim, units, vocab_size):
    super(RNN, self).__init__()
    self.units = units
    self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
    self.gru = tf.keras.layers.GRU(self.units, return_sequences=True, return_state=True, recurrent_initializer='glorot_uniform')
    self.fc1 = tf.keras.layers.Dense(self.units)
    self.fc2 = tf.keras.layers.Dense(vocab_size)
    self.attention = BahdanauAttention(self.units)

  def call(self, x, features, hidden):
    """Run one decoding step.

    Args:
      x: token ids; callers in this notebook pass shape (batch, 1).
      features: encoder output that the attention attends over.
      hidden: previous GRU state, shape (batch, units) (see reset_state).

    Returns:
      (logits of shape (batch * steps, vocab_size), new GRU state, attention weights).
    """
    context_vector, attention_weights = self.attention(features, hidden)
    x = self.embedding(x)
    x = tf.concat([tf.expand_dims(context_vector, 1), x], axis=-1)
    # Continue the recurrence from the previous step's state. Without
    # initial_state the GRU restarts from zeros on every call, so `hidden`
    # would only feed the attention query, never the recurrence itself.
    output, state = self.gru(x, initial_state=hidden)
    x = self.fc1(output)
    x = tf.reshape(x, (-1, x.shape[2]))
    x = self.fc2(x)

    return x, state, attention_weights

  def reset_state(self, batch_size):
    """Return an all-zero initial GRU state of shape (batch_size, units)."""
    return tf.zeros((batch_size, self.units))

In [None]:
# The encoder projects the cached image features; the decoder emits logits over the tokenizer vocabulary.
encoder = CNN(embedding_dim=embedding_dim)
decoder = RNN(embedding_dim=embedding_dim, units=units, vocab_size=tokenizer.vocabulary_size())

In [None]:
optimizer = tf.keras.optimizers.Adam()
# reduction='none' keeps per-token losses so padded positions can be zeroed out.
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction='none')


def loss_function(real, pred):
    """Cross-entropy where padding targets (id 0) contribute zero.

    The mean is taken over every position in the batch, padded ones included.
    """
    per_token_loss = loss_object(real, pred)
    keep = tf.cast(tf.not_equal(real, 0), dtype=per_token_loss.dtype)
    return tf.reduce_mean(per_token_loss * keep)

In [None]:
# Track encoder, decoder and optimizer state; keep only the 5 newest checkpoints on disk.
checkpoint_path = "./checkpoints/train"
ckpt = tf.train.Checkpoint(encoder=encoder, decoder=decoder, optimizer=optimizer)
ckpt_manager = tf.train.CheckpointManager(ckpt, directory=checkpoint_path, max_to_keep=5)

In [None]:
# Resume from the newest checkpoint, if one exists. The numeric suffix of its
# path ('<prefix>-N') becomes the starting epoch.
# NOTE(review): N is only an epoch count if the training loop saves with an
# epoch-based checkpoint_number — confirm against the ckpt_manager.save() call.
latest_ckpt = ckpt_manager.latest_checkpoint
start_epoch = int(latest_ckpt.split('-')[-1]) if latest_ckpt else 0
if latest_ckpt:
    ckpt.restore(latest_ckpt)

In [None]:
# Per-epoch training loss, appended by the training loop and plotted afterwards.
loss_plot = []

In [None]:
@tf.function
def train_step(img_tensor, target):
    """One teacher-forced optimisation step over a batch.

    Returns ``(summed_loss, summed_loss / sequence_length)``.
    """
    batch_size = target.shape[0]
    seq_len = target.shape[1]
    hidden = decoder.reset_state(batch_size=batch_size)
    dec_input = tf.expand_dims([word_to_index('<start>')] * batch_size, 1)
    loss = 0

    with tf.GradientTape() as tape:
        features = encoder(img_tensor)
        for t in range(1, seq_len):
            predictions, hidden, _ = decoder(dec_input, features, hidden)
            loss += loss_function(target[:, t], predictions)
            # Teacher forcing: feed the ground-truth token as the next input.
            dec_input = tf.expand_dims(target[:, t], 1)

    variables = encoder.trainable_variables + decoder.trainable_variables
    grads = tape.gradient(loss, variables)
    optimizer.apply_gradients(zip(grads, variables))

    return loss, (loss / int(seq_len))

In [None]:
# for (batch, (img_tensor, target)) in enumerate(dataset):
#         print(img_tensor.shape)

In [None]:
EPOCHS = 25

for epoch in range(start_epoch, EPOCHS):
    start = time.time()
    total_loss = 0.0
    num_batches = 0  # counted directly: the dataset also yields a final partial batch

    for (batch, (img_tensor, target)) in enumerate(dataset):
        batch_loss, t_loss = train_step(img_tensor, target)
        total_loss += t_loss
        num_batches += 1

        if batch % 100 == 0:
            average_batch_loss = batch_loss.numpy() / int(target.shape[1])
            print(f'Epoch {epoch+1} Batch {batch} Loss {average_batch_loss:.4f}')

    # Mean loss over the batches actually seen (a guard avoids division by zero on an empty dataset).
    epoch_loss = float(total_loss) / max(num_batches, 1)
    # Store the epoch's loss as a plain float for plotting later.
    loss_plot.append(epoch_loss)

    # Number each checkpoint by the count of completed epochs: the resume cell
    # parses this suffix as the epoch to start from. Always save after the last epoch.
    if epoch % 5 == 0 or epoch == EPOCHS - 1:
        ckpt_manager.save(checkpoint_number=epoch + 1)

    print(f'Epoch {epoch+1} Loss {epoch_loss:.6f}')
    print(f'Time taken for 1 epoch {time.time()-start:.2f} sec\n')

In [None]:
# Training-loss curve: one point per entry in loss_plot (one per epoch).
plt.plot(loss_plot)
plt.title('Loss Plot')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.show()

In [None]:
def evaluate(image):
    """Generate a caption for the image file at path ``image`` by sampling tokens.

    Returns:
        ``(result, attention_plot)``. ``result`` is the list of predicted words
        (including '<end>' if it was produced). ``attention_plot`` has one row
        of attention weights per word in ``result``, on every exit path.
    """
    attention_plot = np.zeros((max_length, attention_features_shape))

    hidden = decoder.reset_state(batch_size=1)
    temp_input = tf.expand_dims(load_image(image)[0], 0)
    img_tensor_val = image_features_extract_model(temp_input)
    img_tensor_val = tf.reshape(img_tensor_val, (img_tensor_val.shape[0], -1, img_tensor_val.shape[3]))

    features = encoder(img_tensor_val)

    dec_input = tf.expand_dims([word_to_index('<start>')], 0)
    result = []

    for i in range(max_length):
        predictions, hidden, attention_weights = decoder(dec_input, features, hidden)

        attention_plot[i] = tf.reshape(attention_weights, (-1, )).numpy()

        # Sample the next token from the predicted distribution (not argmax).
        predicted_id = tf.random.categorical(predictions, 1)[0][0].numpy()
        predicted_word = tf.compat.as_text(index_to_word(predicted_id).numpy())
        result.append(predicted_word)

        if predicted_word == '<end>':
            break

        dec_input = tf.expand_dims([predicted_id], 0)

    # Trim unused rows on both the '<end>' path and the max_length path.
    return result, attention_plot[:len(result), :]

In [None]:
def plot_attention(image, result, attention_plot):
    """Overlay each predicted word's attention map on the input image.

    Args:
        image: path to the image file.
        result: list of predicted words; one subplot is drawn per word.
        attention_plot: one row of attention weights per word in ``result``.
            Each row is reshaped to 8x8 for display (presumably the encoder's
            spatial grid — verify if the feature extractor changes).
    """
    # Read the pixels and close the file handle immediately.
    with Image.open(image) as pil_image:
        temp_image = np.array(pil_image)

    fig = plt.figure(figsize=(10, 10))

    len_result = len(result)
    # Square grid with at least len_result cells (and at least 2x2).
    grid_size = max(int(np.ceil(len_result / 2)), 2)
    for i in range(len_result):
        temp_att = np.resize(attention_plot[i], (8, 8))
        ax = fig.add_subplot(grid_size, grid_size, i + 1)
        ax.set_title(result[i])
        img = ax.imshow(temp_image)
        ax.imshow(temp_att, cmap='gray', alpha=0.6, extent=img.get_extent())

    plt.tight_layout()
    plt.show()

In [None]:
# Caption one random validation image and compare it with one of its reference captions.
rid = np.random.randint(0, len(image_name_val))
image = image_name_val[rid]

# Decode the reference caption token by token, skipping padding (id 0).
real_tokens = [tf.compat.as_text(index_to_word(i).numpy()) for i in cap_val[rid] if int(i) != 0]
# Every caption was wrapped in start/end markers. Dropping them by token (rather
# than slicing characters) works whether the tokenizer kept them as
# '<start>'/'<end>' or stripped them to 'start'/'end'.
if real_tokens and real_tokens[0] in ('<start>', 'start'):
    real_tokens = real_tokens[1:]
if real_tokens and real_tokens[-1] in ('<end>', 'end'):
    real_tokens = real_tokens[:-1]
real_caption = ' '.join(real_tokens)

result, attention_plot = evaluate(image)
print(result)
print('Real Caption:', real_caption)
# Hide the terminating marker in the human-readable prediction.
predicted_tokens = result[:-1] if result and result[-1] == '<end>' else result
print('Prediction Caption:', ' '.join(predicted_tokens))
Image.open(image)