# Image Captioning
This notebook is a hands-on lab provided by Google.

In this notebook, an image captioning model is trained using a visual attention mechanism.

The main goal is to generate text from an input image. The generated text should describe the content of the image as accurately as possible.

In [None]:
import time
from textwrap import wrap

import matplotlib.pylab as plt
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_hub as hub
from tensorflow.keras import Input
from tensorflow.keras.layers import (
    GRU,
    Add,
    AdditiveAttention,
    Attention,
    Concatenate,
    Dense,
    Embedding,
    LayerNormalization,
    Reshape,
    StringLookup,
    TextVectorization,
)

# Show the TensorFlow version in use.
print(tf.version.VERSION)

2.17.1


## Read data.
* **Dataset:** COCO captions.
* **Feature extractor:** `InceptionResNetV2`.

In [None]:
# Tokenizer vocabulary size; this can be changed to trade accuracy for speed.
VOCAB_SIZE = 20000
ATTENTION_DIM = 512
WORD_EMBEDDING_DIM = 128

# Pre-trained convolutional backbone without its classification head.
FEATURE_EXTRACTOR = tf.keras.applications.inception_resnet_v2.InceptionResNetV2(
    include_top=False,
    weights="imagenet"
)
IMG_HEIGHT = 299
IMG_WIDTH = 299
IMG_CHANNELS = 3
# InceptionResNetV2 (include_top=False) yields an 8 x 8 x 1536 feature map for
# 299 x 299 inputs; the channel count has to be exact because the encoder
# reshapes the features using these numbers.
FEATURE_SHAPE = (8, 8, 1536)

## Filter and preprocess
* Resize images to defined shape above
* Rescale pixel values to the [0, 1] range
* Return each example as a dictionary with `image_tensor` and `caption` keys

In [None]:
GCS_DIR = "gs://asl-public/data/tensorflow_datasets/"
BUFFER_SIZE = 1000  # shuffle buffer size

def get_image_label(example):
  """Convert a raw dataset example into a preprocessed image and one caption.

  Args:
    example: dataset element providing an "image" entry and a nested
      "captions" -> "text" entry.

  Returns:
    A dict with "image_tensor" (the image resized to IMG_HEIGHT x IMG_WIDTH
    and divided by 255) and "caption" (the first caption of the image).
  """
  caption = example["captions"]["text"][0] # only first caption per image
  img = example["image"]
  img = tf.image.resize(img, (IMG_HEIGHT, IMG_WIDTH))
  img = img/255
  return {"image_tensor": img, "caption": caption}

trainds = tfds.load("coco_captions", split="train", data_dir=GCS_DIR)

trainds = trainds.map(
    get_image_label,
    num_parallel_calls=tf.data.AUTOTUNE
).shuffle(BUFFER_SIZE)
# Prefetch on the same dataset variable (the original referenced the
# undefined name `train_ds`).
trainds = trainds.prefetch(buffer_size=tf.data.AUTOTUNE)

### Visualize examples

In [None]:
# Plot four training images side by side, each titled with its wrapped caption.
fig, axes = plt.subplots(1, 4, figsize=(20, 5))
for axis, sample in zip(axes, trainds.take(4)):
  caption_text = sample["caption"].numpy().decode("utf-8")
  axis.imshow(sample["image_tensor"].numpy())
  axis.set_title("\n".join(wrap(caption_text, 30)))
  axis.axis("off")

## Text preprocessing
Define special tokens `<start>` and `<end>`

In [None]:
def add_start_end_token(data):
  """Wrap the caption of one example with the `<start>` and `<end>` tokens.

  Args:
    data: dict holding a string tensor under the "caption" key.

  Returns:
    The same dict, with "caption" replaced by "<start> caption <end>".
  """
  start = tf.convert_to_tensor("<start>")
  end = tf.convert_to_tensor("<end>")
  data["caption"] = tf.strings.join(
      [start, data["caption"], end], separator=" "
  )
  return data

# Actually apply the special tokens to the dataset. Without this step the
# tokenizer never sees `<start>`/`<end>`, so they would not be in the
# vocabulary that training and caption prediction rely on.
trainds = trainds.map(add_start_end_token)

In [None]:
# Fixed token length of every caption; it could instead be derived from
# descriptive statistics of the caption lengths.
MAX_CAPTION_LEN = 64

def standardize(inputs):
  """Lowercase the text and remove punctuation characters.

  The angle brackets `<` and `>` are not part of the removed character set.
  """
  lowered = tf.strings.lower(inputs)
  punctuation = r"[!\"#$%&\(\)\*\+.,-/:;=?@\[\\\]^_`{|}~]?"
  return tf.strings.regex_replace(lowered, punctuation, "")

tokenizer = TextVectorization(
    standardize=standardize,
    max_tokens=VOCAB_SIZE,
    output_sequence_length=MAX_CAPTION_LEN,
)

# Build the vocabulary from the caption strings of the training set.
tokenizer.adapt(trainds.map(lambda data: data["caption"]))

In [None]:
# Quick check: tokenize one hand-written caption that includes the special tokens.
tokenizer(["<start> This is a sentence <end>"])

In [None]:
# Collect the raw caption bytes of five training examples.
sample_captions = [d["caption"].numpy() for d in trainds.take(5)]

In [None]:
# Show the sample caption strings as stored in the dataset.
print(sample_captions)
# Show how the first two captions are tokenized.
tokenizer(sample_captions[:2])

In [None]:
# Inverse process: map the token ids of the first caption back to words.
# The caption is wrapped in a list so the tokenizer returns a batch of one
# sequence; `[0]` then selects that sequence of ids. Passing the bare string
# would make `[0]` a single scalar id, which cannot be iterated.
vocabulary = tokenizer.get_vocabulary()  # fetched once, not once per token
for wordid in tokenizer([sample_captions[0]])[0]:
  print(vocabulary[wordid], end=" ")

In [None]:
# Lookup layers converting words to token ids and back, both built on the
# tokenizer's vocabulary with the empty string as mask token.
word_to_index = StringLookup(
    vocabulary=tokenizer.get_vocabulary(), mask_token=""
)
index_to_word = StringLookup(
    vocabulary=tokenizer.get_vocabulary(), mask_token="", invert=True
)

## Create training dataset
Need to have targets in format `"I love cats <end> <padding>"` instead of `"<start> I love cats <end>"`

In [None]:
BATCH_SIZE = 32

def create_ds_fn(data):
  """Build one `((image, caption_tokens), target_tokens)` training example.

  The target is the caption shifted left by one token and padded with 0 at
  the end, so position i of the target holds the token that follows
  position i of the caption.

  Args:
    data: dict with "image_tensor" and a string "caption".

  Returns:
    A tuple `((img_tensor, caption), target)`.
  """
  img_tensor = data["image_tensor"]
  # Convert the caption string into a fixed-length sequence of token ids; the
  # roll/concat below operate on that id sequence, not on a scalar string.
  caption = tokenizer(data["caption"])
  # Shift left by one: the first token wraps around to the last position...
  target = tf.roll(caption, -1, 0)
  # ...where it is replaced by the padding id 0.
  zeros = tf.zeros([1], dtype=tf.int64)
  target = tf.concat([target[:-1], zeros], axis=-1)
  return (img_tensor, caption), target

In [None]:
# Training pipeline: build (inputs, target) pairs, batch them (dropping the
# incomplete last batch) and prefetch.
batched_ds = trainds.map(create_ds_fn)
batched_ds = batched_ds.batch(BATCH_SIZE, drop_remainder=True)
batched_ds = batched_ds.prefetch(buffer_size=tf.data.AUTOTUNE)

In [None]:
# Peek at two batches: print the tensor shapes and the first caption/target pair.
for (images, captions), targets in batched_ds.take(2):
  print(f"Image shape: {images.shape}")
  print(f"Caption shape: {captions.shape}")
  print(f"Label shape: {targets.shape}")
  print(captions[0])
  print(targets[0])

## Model
### The Image encoder
1. Extract features with `InceptionResNetV2`.
2. Reshape vector to (Batch size, 64, 1536)
3. Squash it to a length of `ATTENTION_DIM` with a Dense layer and return (Batch size, 64, `ATTENTION_DIM`)
4. The attention layer attends over the image to predict the next word.

In [None]:
# Freeze the pre-trained feature extractor so its weights are not updated.
FEATURE_EXTRACTOR.trainable = False

image_input = Input(shape=(IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS))
image_features = FEATURE_EXTRACTOR(image_input)

# Flatten the spatial grid into a sequence of feature vectors. The input is
# `image_features` (the original referenced the undefined `img_features`).
x = Reshape(
    (FEATURE_SHAPE[0] * FEATURE_SHAPE[1], FEATURE_SHAPE[2])
)(image_features)
# Project every feature vector to ATTENTION_DIM.
encoder_output = Dense(ATTENTION_DIM, activation="relu")(x)

In [None]:
# Wrap the encoder graph (image in, projected feature sequence out) as a model.
encoder = tf.keras.Model(inputs=image_input, outputs=encoder_output)
encoder.summary()

### The Caption decoder
1. Receives a word tokens batch
2. Embeds the word tokens into `ATTENTION_DIM` dimensions
3. Passes the embeddings to a GRU, which returns its outputs and final state
4. A dot-product (Luong-style) `Attention` layer attends over the encoder's output features, using the GRU output as the query
5. Performs a skip connection by adding the GRU output (step 3) to the attention output, then applies layer normalization
6. Generates logit preds for next token

In [None]:
# 1. Word-token input: one sequence of MAX_CAPTION_LEN token ids per example.
word_input = Input(shape=(MAX_CAPTION_LEN,), name="words")
# 2. Embed every token id into an ATTENTION_DIM-sized vector.
embed_x = Embedding(VOCAB_SIZE, ATTENTION_DIM)(word_input)
# 3. GRU over the embedded tokens; returns the per-step outputs and the final
#    state. The layer object is kept so it can be reused at inference time.
decoder_gru = GRU(
    ATTENTION_DIM,
    return_sequences=True,
    return_state=True,
    name="gru"
)
gru_output, gru_state = decoder_gru(embed_x)
# 4. Attention with the GRU outputs as query and the encoder output as value.
#    Note: `Attention` is the dot-product (Luong-style) layer.
decoder_attention = Attention()
context_vector = decoder_attention([gru_output, encoder_output])
# 5. Skip connection (GRU output + attention context) followed by layer
#    normalization over the last axis.
addition = Add()([gru_output, context_vector])
layer_norm = LayerNormalization(axis=-1)
layer_norm_out = layer_norm(addition)
# 6. Project to VOCAB_SIZE logits (no activation) for the next-token prediction.
decoder_output_dense = Dense(VOCAB_SIZE)
decoder_output = decoder_output_dense(layer_norm_out)

In [None]:
# Stand-alone decoder: caption tokens plus encoder features in, logits out.
decoder = tf.keras.Model(
    inputs=[word_input, encoder_output],
    outputs=decoder_output
)
# Render the decoder's layer graph.
tf.keras.utils.plot_model(decoder)

In [None]:
# Print the decoder's layers and parameter counts.
decoder.summary()

## Model training

In [None]:
# Per-token cross-entropy computed on raw logits; no reduction is applied
# here, the reduction is done in the custom loss function.
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
    reduction="none", from_logits=True
)

# Single end-to-end model used for compiling and training: image and caption
# tokens in, next-token logits out.
image_caption_train_model = tf.keras.Model(
    outputs=decoder_output, inputs=[image_input, word_input]
)

def loss_function(real, pred):
  """Cross-entropy averaged over the non-padding tokens of each caption.

  The previous implementation summed the padding mask over the whole batch
  and used that count to slice the batch axis, so padded positions were
  never excluded from the mean. Here the mask is applied per token instead.

  Args:
    real: target token ids, where id 0 marks padding
      (assumed shape (batch, sequence) -- TODO confirm).
    pred: logits for every position of every caption.

  Returns:
    One loss value per caption (reduction over axis 1, as before).
  """
  loss_ = loss_object(real, pred)
  # 1 for real word ids, 0 for padding.
  mask = tf.cast(tf.math.not_equal(real, 0), dtype=loss_.dtype)
  # Guard against division by zero for a caption consisting only of padding.
  token_count = tf.maximum(tf.reduce_sum(mask, axis=1), 1.0)
  return tf.reduce_sum(loss_ * mask, axis=1) / token_count

In [None]:
# Configure training: Adam optimizer with the custom caption loss.
image_caption_train_model.compile(loss=loss_function, optimizer="adam")

In [None]:
%%time
# Train for one pass over the dataset; the keyword is `epochs` (plural).
history = image_caption_train_model.fit(batched_ds, epochs=1)

In [None]:
# Extra input carrying the GRU state from the previous decoding step.
gru_state_input = Input(shape=(ATTENTION_DIM,), name="gru_state_input")

# Reuse the trained GRU, now called with an explicit initial state.
gru_output, gru_state = decoder_gru(embed_x, initial_state=gru_state_input)

# Reuse the remaining trained layers as well.
context_vector = decoder_attention([gru_output, encoder_output])
addition_output = Add()([gru_output, context_vector])
layer_norm_output = layer_norm(addition_output)

decoder_output = decoder_output_dense(layer_norm_output)

# Define the prediction model with state input and output, so the GRU state
# can be fed back in at every decoding step.
decoder_pred_model = tf.keras.Model(
    inputs=[word_input, gru_state_input, encoder_output],
    outputs=[decoder_output, gru_state],
)

## Predict captions

In [None]:
# NOTE: not referenced by the code below; name kept as-is for compatibility.
MINIMUM_SENTENCE_LENGHT = 5

def predict_caption(filename):
  """Generate a caption for an image file by sampling one word at a time.

  Args:
    filename: path of a JPEG image.

  Returns:
    A tuple `(img, result)`: the preprocessed image tensor and the list of
    generated words. The list ends with "<end>" when that token was sampled
    before MAX_CAPTION_LEN steps were reached.
  """
  gru_state = tf.zeros((1, ATTENTION_DIM))

  img = tf.image.decode_jpeg(tf.io.read_file(filename), channels=IMG_CHANNELS)
  # Match the training preprocessing: the encoder is built for
  # IMG_HEIGHT x IMG_WIDTH inputs, so arbitrary-sized images must be resized.
  img = tf.image.resize(img, (IMG_HEIGHT, IMG_WIDTH))
  img = img / 255

  features = encoder(tf.expand_dims(img, axis=0))
  dec_input = tf.expand_dims([word_to_index("<start>")], 1)

  # Loop-invariant lookups, hoisted out of the decoding loop.
  vocabulary = tokenizer.get_vocabulary()
  end_id = word_to_index("<end>")

  result = []
  for _ in range(MAX_CAPTION_LEN):
    predictions, gru_state = decoder_pred_model(
        [dec_input, gru_state, features],
    )
    # Keep the 10 highest-scoring tokens and sample among them, treating
    # their scores as logits.
    top_probs, top_idxs = tf.math.top_k(
        input=predictions[0][0],
        k=10,
        sorted=False
    )
    chosen_id = tf.random.categorical([top_probs], 1)[0].numpy()
    predicted_id = top_idxs.numpy()[chosen_id][0]

    result.append(vocabulary[predicted_id])

    if predicted_id == end_id:
      return img, result

    # Feed the sampled token back in as the next decoder input.
    dec_input = tf.expand_dims([predicted_id], 1)

  return img, result

In [None]:
# Try the captioner on a sample image: print five sampled captions, then
# display the image itself.
filename = "../sample_images/baseball.jpeg"

for _ in range(5):
  image, caption = predict_caption(filename)
  sentence = " ".join(caption[:-1])  # drop the last generated token
  print(sentence + ".")

img = tf.image.decode_jpeg(tf.io.read_file(filename), channels=IMG_CHANNELS)
plt.imshow(img)
plt.show()