In [None]:
!apt install --allow-change-held-packages libcudnn8=8.6.0.163-1+cuda11.8
!pip uninstall -y tensorflow estimator keras
!pip install -U tensorflow_text tensorflow tensorflow_datasets
!pip install einops

In [None]:
import concurrent.futures
import collections
import dataclasses
import hashlib
import itertools
import json
import math
import os
import pathlib
import random
import re
import string
import time
import urllib.request

import einops
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from PIL import Image
import requests
import tqdm

import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_text as text
import tensorflow_datasets as tfds

Get dataset

In [None]:
def conceptual_captions(*, data_dir='conceptual_captions', num_train, num_val):
  """Download a subset of Conceptual Captions as (image_path, captions) datasets.

  Fetches the train/validation index TSVs (one "caption<TAB>url" per line),
  downloads the images of the first `num_train` / `num_val` entries in
  parallel, drops entries whose download failed and returns two
  `tf.data.Dataset`s whose elements are `(path: string, captions: string[1])`.

  Args:
    data_dir: directory holding the index files and the downloaded images.
    num_train: number of training index lines to use.
    num_val: number of validation index lines to use.

  Returns:
    (train_raw, test_raw) datasets.

  Raises:
    RuntimeError: if none of the images of a split could be downloaded.
  """
  def iter_index(index_path):
    # Each index line is "caption<TAB>url".
    with open(index_path) as f:
      for line in f:
        caption, url = line.strip().split('\t')
        yield caption, url

  def download_image_urls(data_dir, urls):
    def save_image(url):
      # Name the file after the SHA-1 of its URL so re-runs reuse the cache.
      digest = hashlib.sha1(url.encode()).hexdigest()
      file_path = data_dir/f'{digest}.jpeg'

      if file_path.exists():
        return file_path  # download each file only once

      try:
        response = requests.get(url, timeout=5)
        # Don't store HTTP error pages (404, 403, ...) as if they were images.
        response.raise_for_status()
      except Exception:
        return None  # best effort: a failed download just skips this entry

      file_path.write_bytes(response.content)
      return file_path

    # The context manager shuts the worker threads down once all downloads finish.
    with concurrent.futures.ThreadPoolExecutor(max_workers=100) as ex:
      out_paths = ex.map(save_image, urls)
      return list(tqdm.tqdm(out_paths, total=len(urls)))

  def ds_from_index_file(index_path, data_dir, count):
    data_dir.mkdir(parents=True, exist_ok=True)

    index = list(itertools.islice(iter_index(index_path), count))
    captions = [caption for caption, _ in index]
    urls = [url for _, url in index]

    paths = download_image_urls(data_dir, urls)

    # Keep only the (path, caption) pairs whose download succeeded.
    kept = [(str(path), cap) for cap, path in zip(captions, paths) if path is not None]
    if not kept:
      raise RuntimeError(f'No images could be downloaded from {index_path}')

    new_paths = [path for path, _ in kept]
    new_captions = [cap for _, cap in kept]

    ds = tf.data.Dataset.from_tensor_slices((new_paths, new_captions))
    ds = ds.map(lambda path, cap: (path, cap[tf.newaxis]))  # 1 caption per image

    return ds

  data_dir = pathlib.Path(data_dir)
  train_index_path = tf.keras.utils.get_file(
      origin='https://storage.googleapis.com/gcc-data/Train/GCC-training.tsv',
      cache_subdir=data_dir,
      cache_dir='.'
  )
  val_index_path = tf.keras.utils.get_file(
      origin='https://storage.googleapis.com/gcc-data/Validation/GCC-1.1.0-Validation.tsv',
      cache_subdir=data_dir,
      cache_dir='.'
  )

  train_raw = ds_from_index_file(train_index_path, data_dir=data_dir/'train', count=num_train)
  test_raw = ds_from_index_file(val_index_path, data_dir=data_dir/'val', count=num_val)

  return train_raw, test_raw

train_raw, test_raw = conceptual_captions(num_train=100, num_val=50)

In [None]:
# Each element is (image path string, captions tensor with one caption per image).
train_raw.element_spec

In [None]:
# Peek at one (path, captions) example; `ex_path` is reused below to test load_image.
for ex_path, ex_caps in train_raw.take(1):
  print(ex_path, '\n', ex_caps)

Image feature extractor (pretrained MobileNet)

In [None]:
# Input size fed to the image backbone (height, width, channels).
IMAGE_SHAPE = (224, 224, 3)

# Pretrained MobileNetV3Small used as a frozen image feature extractor.
# include_top=False drops the classification head so we get spatial feature maps;
# include_preprocessing=True puts the input scaling inside the model
# (NOTE(review): presumably expects raw 0-255 pixels as produced by load_image — confirm).
mobilenet = tf.keras.applications.MobileNetV3Small(
    input_shape=IMAGE_SHAPE,
    include_top=False,
    include_preprocessing=True
)
# Frozen: its outputs never change, which is what makes feature caching possible.
mobilenet.trainable = False

In [None]:
def load_image(path):
  """Read a JPEG file and return it as a 3-channel image resized to IMAGE_SHAPE[:-1]."""
  raw_bytes = tf.io.read_file(path)
  decoded = tf.io.decode_jpeg(raw_bytes, channels=3)
  return tf.image.resize(decoded, IMAGE_SHAPE[:-1])

# Sanity check: one image through the backbone, showing input and feature-map shapes.
test_img_batch = load_image(ex_path)[tf.newaxis, :]
print(test_img_batch.shape, '\n', mobilenet(test_img_batch).shape)

Tokenizer/Vectorizer

In [None]:
def standardize(s):
  """Lower-case, strip punctuation and wrap each caption in '[START] ... [END]'."""
  s = tf.strings.lower(s)
  s = tf.strings.regex_replace(s, f'[{re.escape(string.punctuation)}]', '')
  # `separator` is the tf.strings.join keyword (there is no `separate`).
  s = tf.strings.join(['[START]', s, '[END]'], separator=' ')

  return s

vocab_size = 5_000  # use top 5k words
tokenizer = tf.keras.layers.TextVectorization(
    max_tokens=vocab_size,
    standardize=standardize,
    ragged=True
)
# learn the vocab from captions
tokenizer.adapt(train_raw.map(lambda fp, txt: txt).unbatch().batch(1024))

t = tokenizer([['a cat in a hat'], ['a robot dog']])
print(t)

# String <-> id lookups sharing the tokenizer's vocabulary ('' is the mask/padding token).
word_to_idx = tf.keras.layers.StringLookup(
    mask_token='',
    vocabulary=tokenizer.get_vocabulary()
)
idx_to_word = tf.keras.layers.StringLookup(
    mask_token='',
    vocabulary=tokenizer.get_vocabulary(),
    invert=True
)

w = idx_to_word(t)
print(w.to_list())
print(tf.strings.reduce_join(w, separator=' ', axis=-1).numpy())

Prepare dataset

In [None]:
# train and test contain 1 img -> many captions, so turn it into 1:1
def match_shapes(imgs, caps):
  """Flatten a (batch, captions) caption tensor and repeat the images to match.

  Each image is repeated once per caption, so the outputs pair 1 image : 1 caption.
  """
  num_caps = einops.parse_shape(caps, 'b c')['c']
  flat_caps = einops.rearrange(caps, 'b c -> (b c)')
  repeated_imgs = einops.repeat(imgs, 'b ... -> (b c) ...', c=num_caps)
  return repeated_imgs, flat_caps

# Sanity check: shapes of one batch before and after match_shapes.
for ex_paths, ex_captions in train_raw.batch(32).take(1):
  print(ex_paths.shape, '\n', ex_captions.shape)
  ex_paths, ex_captions = match_shapes(ex_paths, ex_captions)
  print('\n', ex_paths.shape, '\n', ex_captions.shape)

  break


# for keras, dataset should be (inputs, labels) pairs
# for text gen, tokens = both input and labels, but shifted by 1 step
def prep_txt(imgs, txts):
  """Tokenize captions into teacher-forcing (input, label) token pairs.

  Uses the module-level `tokenizer`.

  Returns:
    `((imgs, inp_toks), label_toks)`, where `label_toks` is `inp_toks` shifted
    one position to the left.
  """
  toks = tokenizer(txts)
  inp_toks = toks[..., :-1]   # every token except the last one
  label_toks = toks[..., 1:]  # every token except the first ([START])

  return (imgs, inp_toks), label_toks

def prep_dataset(ds, tokenizer, batch_size=32, shuffle_buffer=1000):
  """Build padded `((images, in_toks), out_toks)` batches from (path, captions) data.

  Args:
    ds: dataset of `(image_path, captions)` elements.
    tokenizer: kept for API compatibility; tokenization goes through
      `prep_txt`, which uses the module-level `tokenizer`.
    batch_size: batch size used both before and after flattening captions.
    shuffle_buffer: buffer size of the shuffle applied after flattening.

  Returns:
    A dataset of dense (zero-padded) token tensors ready for `model.fit`.
  """
  ds = (ds
        .shuffle(10_000)
        # load imgs, ignore those that fail
        .map(lambda path, cap: (load_image(path), cap))
        .apply(tf.data.experimental.ignore_errors())
        .batch(batch_size)
        )

  def to_tensor(inps, labels):
    (imgs, in_tok), out_tok = inps, labels
    return (imgs, in_tok.to_tensor()), out_tok.to_tensor()

  return (ds
          # replicate imgs to match number of captions
          .map(match_shapes, tf.data.AUTOTUNE)
          .unbatch()
          # shuffle and rebatch
          .shuffle(shuffle_buffer)
          .batch(batch_size)
          # tokenize and add label tokens
          .map(prep_txt, tf.data.AUTOTUNE)
          # convert from RaggedTensor to padded dense Tensor
          .map(to_tensor, tf.data.AUTOTUNE)
          )

# Build the uncached (image-loading) pipelines and inspect their element structure.
train_ds = prep_dataset(train_raw, tokenizer)
test_ds = prep_dataset(test_raw, tokenizer)

print(train_ds.element_spec, '\n', test_ds.element_spec)

Cache the image features (MobileNet is frozen, so its outputs never change and only need to be computed once)

In [None]:
def save_ds(ds, path, img_model, tokenizer, shards=10, batch_size=32):
  """Precompute image features with `img_model` and save the result to `path`.

  Args:
    ds: dataset of `(image_path, captions)` elements.
    path: directory passed to `Dataset.save`.
    img_model: feature extractor applied to each image batch.
    tokenizer: unused here; tokenization goes through `prep_txt`, which uses
      the module-level `tokenizer`.
    shards: number of shard files, chosen by `index % shards`.
    batch_size: batch size used when running `img_model`.
  """
  img_batches = ds.map(lambda p, c: (load_image(p), c))
  img_batches = img_batches.apply(tf.data.experimental.ignore_errors())
  img_batches = img_batches.batch(batch_size)

  def feature_batches():
    # Run the feature extractor batch by batch and pair 1 feature map : 1 caption.
    for imgs, caps in tqdm.tqdm(img_batches):
      yield match_shapes(img_model(imgs), caps)

  signature = (
      tf.TensorSpec(shape=img_model.output_shape),
      tf.TensorSpec(shape=(None,), dtype=tf.string),
  )
  features = tf.data.Dataset.from_generator(feature_batches, output_signature=signature)

  examples = features.map(prep_txt, tf.data.AUTOTUNE).unbatch().shuffle(1_000)

  # Elements are spread over `shards` files according to their enumeration index.
  examples.enumerate().save(path, shard_func=lambda idx, _: idx % shards)

def load_ds(path, batch_size=32, shuffle=1000, cycle_length=2):
  """Load a dataset written by `save_ds` and return shuffled, padded batches.

  Args:
    path: directory previously written by `save_ds`.
    batch_size: size of the padded batches.
    shuffle: shuffle buffer applied to individual examples.
    cycle_length: number of shard files read concurrently by `interleave`.

  Returns:
    A prefetched dataset of `((features, in_toks), out_toks)` padded batches.
  """
  def custom_reader_func(datasets):
    # `datasets` yields one dataset per shard: shuffle their order, then interleave.
    datasets = datasets.shuffle(1_000)
    return datasets.interleave(lambda x: x, cycle_length=cycle_length)

  ds = tf.data.Dataset.load(path, reader_func=custom_reader_func)

  def drop_idx(i, x):
    # Discard the index that `enumerate()` added before saving.
    return x

  ds = (ds
        .map(drop_idx, tf.data.AUTOTUNE)
        .shuffle(shuffle)
        .padded_batch(batch_size)
        .prefetch(tf.data.AUTOTUNE)
        )

  return ds

# Run MobileNet over every image once and write the features + tokens to disk.
save_ds(train_raw, 'train_cache', mobilenet, tokenizer)
save_ds(test_raw, 'test_cache', mobilenet, tokenizer)

In [None]:
# Reload the cached feature datasets as shuffled, padded batches.
train_ds = load_ds('train_cache')
test_ds = load_ds('test_cache')

train_ds.element_spec

In [None]:
# Inspect one cached batch: feature maps, input tokens and label tokens.
for (inps, ex_labels) in train_ds.take(1):
  (ex_img, ex_in_tok) = inps
  print(ex_img.shape, '\n', ex_in_tok.shape, '\n', ex_labels.shape)
  # these are shifted by 1 step
  print('\n', ex_in_tok[0].numpy(), '\n', ex_labels[0].numpy())
  break

Transformer Decoder Model

In [None]:
class SeqEmb(tf.keras.layers.Layer):
  """Token embedding plus a learned positional embedding.

  Args:
    vocab_size: number of distinct token ids.
    max_len: longest sequence the positional table can index.
    depth: embedding width.
  """

  def __init__(self, vocab_size, max_len, depth):
    super().__init__()

    self.pos_emb = tf.keras.layers.Embedding(input_dim=max_len, output_dim=depth)
    # mask_zero=True: token id 0 (padding) produces a Keras mask for downstream layers.
    self.tok_emb = tf.keras.layers.Embedding(input_dim=vocab_size, output_dim=depth, mask_zero=True)
    self.add = tf.keras.layers.Add()

  def call(self, seq):
    tok_vecs = self.tok_emb(seq)                                 # (batch, seq, depth)
    positions = tf.range(tf.shape(tok_vecs)[1])[tf.newaxis, :]   # (1, seq)
    pos_vecs = self.pos_emb(positions)                           # (1, seq, depth)
    # Add layer (not `+`) so the token mask is carried through.
    return self.add([tok_vecs, pos_vecs])

In [None]:
class CausalSelfAttn(tf.keras.layers.Layer):
  """Causal self-attention followed by a residual Add and LayerNormalization.

  Keyword arguments are forwarded to `tf.keras.layers.MultiHeadAttention`.
  """

  def __init__(self, **kwargs):
    super().__init__()
    self.mha = tf.keras.layers.MultiHeadAttention(**kwargs)
    # Add layer instead of `+` so the Keras mask propagates.
    self.add = tf.keras.layers.Add()
    self.layernorm = tf.keras.layers.LayerNormalization()

  def call(self, x):
    # use_causal_mask=True: a position cannot attend to later positions.
    attended = self.mha(query=x, value=x, use_causal_mask=True)
    return self.layernorm(self.add([x, attended]))

In [None]:
class CrossAttn(tf.keras.layers.Layer):
  """Attention from the query sequence `x` to the context `y`, with residual + LayerNorm.

  The attention scores of the most recent call are kept in `self.last_attn_scores`.
  Keyword arguments are forwarded to `tf.keras.layers.MultiHeadAttention`.
  """

  def __init__(self, **kwargs):
    super().__init__()
    self.mha = tf.keras.layers.MultiHeadAttention(**kwargs)
    self.add = tf.keras.layers.Add()
    self.layernorm = tf.keras.layers.LayerNormalization()

  def call(self, x, y, **kwargs):
    attended, scores = self.mha(query=x, value=y, return_attention_scores=True)
    # Cached so the attention maps can be plotted after generation.
    self.last_attn_scores = scores
    return self.layernorm(self.add([x, attended]))

In [None]:
class FFwd(tf.keras.layers.Layer):
  """Feed-forward block with a residual connection and LayerNormalization.

  Dense(2*units, relu) -> Dense(units) -> Dropout, added back to the input.
  Assumes the input's last dimension equals `units` so the residual shapes match.
  """

  def __init__(self, units, dropout_rate=0.1):
    super().__init__()
    ff_layers = [
        tf.keras.layers.Dense(units=2*units, activation='relu'),
        tf.keras.layers.Dense(units=units),
        tf.keras.layers.Dropout(rate=dropout_rate),
    ]
    self.seq = tf.keras.Sequential(ff_layers)
    self.layernorm = tf.keras.layers.LayerNormalization()

  def call(self, x):
    # Dense acts on the last axis only, so it is applied at every (batch, seq) position.
    ff_out = self.seq(x)
    return self.layernorm(x + ff_out)

In [None]:
class DecLayer(tf.keras.layers.Layer):
  """One decoder block: causal self-attention -> cross-attention to the image -> FFwd.

  After each call, `self.last_attn_scores` holds the cross-attention scores.
  """

  def __init__(self, units, num_heads=1, dropout_rate=0.1):
    super().__init__()
    attn_kwargs = dict(num_heads=num_heads, key_dim=units, dropout=dropout_rate)
    self.self_attn = CausalSelfAttn(**attn_kwargs)
    self.cross_attn = CrossAttn(**attn_kwargs)
    self.ff = FFwd(units, dropout_rate)

  def call(self, inps, training=False):
    """`inps` is `(image_sequence, text_sequence)`; `training` is accepted but unused here."""
    img_seq, txt_seq = inps

    txt_seq = self.self_attn(txt_seq)
    # the text queries attend over the image sequence
    txt_seq = self.cross_attn(txt_seq, img_seq)
    self.last_attn_scores = self.cross_attn.last_attn_scores

    return self.ff(txt_seq)

Output needs a Dense layer at minimum to get logit-predictions
<br>Can improve it:
 - handle bad tokens:
  - padding `''`, unknown `'[UNK]'`, start `'[START]'`
  - the model should never generate these, so set their bias to a large negative value and ignore them in the loss function

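- smart initialization:
 - with its default initialization, the Dense layer initially predicts every token with almost uniform likelihood, far from the actual token distribution
 - add an `adapt()` method that counts the tokens and sets the optimal initial bias (the log-frequency of each token)
 - this reduces the initial loss from the entropy of the uniform distribution, $\log(\text{vocab\_size})$, to the marginal entropy of the token distribution, $-\sum_i p_i \log p_i$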

In [None]:
class TokOut(tf.keras.layers.Layer):
  """Final Dense projection to vocabulary logits plus a frequency-based bias.

  `adapt()` sets a fixed bias equal to the log unigram frequency of each token
  and -1e9 for banned or unseen tokens, so those are effectively never predicted.
  Before `adapt()` is called the layer returns plain Dense logits.

  Args:
    tokenizer: TextVectorization layer providing the vocabulary.
    banned: tokens that must never be generated.
    **kwargs: forwarded to the inner `Dense` layer.
  """

  def __init__(self, tokenizer, banned=('', '[UNK]', '[START]'), **kwargs):
    super().__init__()

    self.dense = tf.keras.layers.Dense(
        units=tokenizer.vocabulary_size(), **kwargs
    )
    self.tokenizer = tokenizer
    self.banned = banned

    # Set by adapt(); None means "no bias yet".
    self.bias = None

  def adapt(self, ds):
    """Count the token ids in `ds` and set `self.bias` to their log frequency.

    Args:
      ds: dataset yielding integer token tensors (e.g. the label batches).

    Raises:
      ValueError: if `ds` contains no non-banned tokens.
    """
    counts = collections.Counter()
    vocab_dict = {
        name: idx
        for idx, name in enumerate(self.tokenizer.get_vocabulary())
    }

    for toks in tqdm.tqdm(ds):
      counts.update(toks.numpy().flatten())

    counts_arr = np.zeros((self.tokenizer.vocabulary_size(),))
    counts_arr[np.array(list(counts.keys()), dtype=np.int32)] = list(counts.values())

    for tok in self.banned:
      if tok in vocab_dict:  # skip banned tokens the vocabulary doesn't contain
        counts_arr[vocab_dict[tok]] = 0

    total = counts_arr.sum()
    if total == 0:
      raise ValueError('TokOut.adapt() found no (non-banned) tokens in the dataset.')

    p = counts_arr / total
    p[counts_arr == 0] = 1.0
    log_p = np.log(p)  # log(1) = 0, so unseen tokens add nothing to the entropy

    entropy = -(log_p*p).sum()

    print(f'\nUniform entropy: {np.log(self.tokenizer.vocabulary_size()):0.2f}',
          f'\nMarginal entropy: {entropy:0.2f}')

    self.bias = log_p
    self.bias[counts_arr == 0] = -1e9

  def call(self, x):
    x = self.dense(x)
    if self.bias is None:
      # adapt() has not been run: plain Dense logits.
      return x
    # An Add layer doesn't work here because the shapes differ.
    # `+` clears the Keras mask, which is fine: it keeps Keras from rescaling the losses.
    return x + self.bias

out_layer = TokOut(tokenizer)
# Set the output bias from the label-token frequencies of the cached training set.
out_layer.adapt(train_ds.map(lambda inps, labels: labels))

In [None]:
class Captioner(tf.keras.Model):
  """Image captioner: image feature extractor + Transformer decoder + output layer.

  Args:
    tokenizer: TextVectorization layer defining the vocabulary.
    feature_extractor: image model mapping images to (batch, h, w, c) feature maps.
    out_layer: layer producing vocabulary logits (e.g. TokOut).
    num_layers: number of DecLayer blocks.
    units: model width.
    max_len: longest token sequence supported by the positional embedding.
    num_heads: attention heads per block.
    dropout_rate: dropout rate used in the decoder blocks.
  """

  def __init__(self, tokenizer, feature_extractor, out_layer, num_layers=1,
               units=256, max_len=50, num_heads=1, dropout_rate=0.1):
    super().__init__()

    self.feature_extractor = feature_extractor
    self.tokenizer = tokenizer

    self.word_to_idx = tf.keras.layers.StringLookup(mask_token='', vocabulary=tokenizer.get_vocabulary())
    self.idx_to_word = tf.keras.layers.StringLookup(mask_token='', vocabulary=tokenizer.get_vocabulary(), invert=True)

    # vocabulary_size is a method: call it to get the integer size.
    self.seq_emb = SeqEmb(vocab_size=tokenizer.vocabulary_size(), depth=units, max_len=max_len)
    self.dec_layers = [
        DecLayer(units, num_heads=num_heads, dropout_rate=dropout_rate)
        for _ in range(num_layers)
    ]
    self.out_layer = out_layer

  def call(self, inps):
    """Return vocabulary logits for `inps = (img, txt)`.

    `img` is either RGB images (last dim 3) or precomputed feature maps;
    `txt` is either strings or token ids.
    """
    img, txt = inps

    # if RGB, then apply feature extractor
    # else assume it's already applied
    if img.shape[-1] == 3:
      img = self.feature_extractor(img)

    # flatten the spatial grid into a sequence of image positions
    img = einops.rearrange(img, 'b h w c -> b (h w) c')

    # if string, apply this model's tokenizer
    # else assume it's already applied
    if txt.dtype == tf.string:
      txt = self.tokenizer(txt)

    txt = self.seq_emb(txt)

    # look at the img
    for dec_layer in self.dec_layers:
      txt = dec_layer(inps=(img, txt))

    return self.out_layer(txt)

  def simple_gen(self, img, temp=1):
    """Generate a caption for a single image, one token at a time.

    temp=0 is greedy decoding (always the most likely token); temp=1 samples
    from the softmax of the logits; temp >> 1 approaches uniform sampling.

    Returns:
      The caption as a space-separated string, without [START] and without
      the final generated token (normally [END]).
    """
    initial = self.word_to_idx([['[START]']])  # (batch, seq)
    end_id = int(self.word_to_idx(['[END]'])[0])

    # extract img features
    img_features = self.feature_extractor(img[tf.newaxis, ...])

    # initialize output tokens with [START]
    toks = initial
    for _ in range(50):
      # pass img features + tokens to the model, get logits
      preds = self((img_features, toks)).numpy()  # (batch, seq, vocab)
      preds = preds[:, -1, :]  # (batch, vocab)

      # choose next token based on logits
      if temp == 0:
        next_tok = tf.argmax(preds, axis=-1)[:, tf.newaxis]  # (batch, 1)
      else:
        next_tok = tf.random.categorical(preds/temp, num_samples=1)  # (batch, 1)

      # add to list of tokens and continue
      toks = tf.concat([toks, next_tok], axis=1)  # (batch, seq)

      # end when the '[END]' token is generated
      if int(next_tok[0, 0]) == end_id:
        break

    words = self.idx_to_word(toks[0, 1:-1])
    result = tf.strings.reduce_join(words, axis=-1, separator=' ')

    return result.numpy().decode()

In [None]:
# 2 decoder blocks, 2 heads each, width 256, with heavy dropout (0.5) for this small dataset.
model = Captioner(
    tokenizer, feature_extractor=mobilenet, out_layer=out_layer,
    units=256, dropout_rate=0.5, num_layers=2, num_heads=2
)

Generate Captions

In [None]:
# Example image used to inspect generation before, during and after training.
img_url = 'https://tensorflow.org/images/surf.jpg'
img_path = tf.keras.utils.get_file('surf.jpg', origin=img_url)
img = load_image(img_path)

In [None]:
# Generate with the untrained model at three temperatures (greedy, mild, full sampling).
for t in (0.0, 0.5, 1.0):
  # model is untrained + we initialized w/ frequency of tokens
  # so greedy output (t = 0.0) would only contain most common tokens (a, ., [END])
  result = model.simple_gen(img, temp=t)
  print(result)

Train

In [None]:
def masked_loss(labels, preds):
  """Mean sparse cross-entropy over the kept token positions.

  Positions whose label is 0 (padding) are excluded, as are positions whose
  loss is >= 1e8 — the artificial, impossibly high losses that come from the
  -1e9 bias TokOut assigns to banned tokens.
  """
  per_tok = tf.nn.sparse_softmax_cross_entropy_with_logits(labels, preds)

  keep = tf.logical_and(labels != 0, per_tok < 1e8)
  keep = tf.cast(keep, per_tok.dtype)

  return tf.reduce_sum(per_tok * keep) / tf.reduce_sum(keep)

def masked_acc(labels, preds):
  """Token accuracy computed only over non-padding positions (label != 0).

  Args:
    labels: integer label tokens.
    preds: logits with the vocabulary on the last axis.
  """
  mask = tf.cast(labels != 0, tf.float32)
  preds = tf.argmax(preds, axis=-1)
  labels = tf.cast(labels, tf.int64)

  # Cast to the mask's dtype: TF does not auto-promote int64 * float32.
  matched = tf.cast(preds == labels, mask.dtype)
  acc = tf.reduce_sum(matched * mask) / tf.reduce_sum(mask)

  return acc

In [None]:
# for feedback during training
class GenText(tf.keras.callbacks.Callback):
  """Print sample captions of the surf image at the end of every epoch."""

  def __init__(self):
    # Initialise the base Callback state before adding our own attributes.
    super().__init__()
    image_url = 'https://tensorflow.org/images/surf.jpg'
    image_path = tf.keras.utils.get_file('surf.jpg', origin=image_url)
    self.image = load_image(image_path)

  def on_epoch_end(self, epochs=None, logs=None):
    print('\n\n')
    for t in (0.0, 0.5, 1.0):
      result = self.model.simple_gen(self.image, temp=t)
      print(result)
    print('\n')

g = GenText()
# set_model() is the public Callback API for attaching the model.
g.set_model(model)
g.on_epoch_end(0)

In [None]:
# GenText prints sample captions each epoch; EarlyStopping stops after `patience=5`
# epochs without improvement of its monitored metric (Keras default — presumably
# val_loss, confirm) and restores the best weights.
callbacks = [
    GenText(),
    tf.keras.callbacks.EarlyStopping(
        patience=5, restore_best_weights=True
    )
]

In [None]:
# Masked loss/accuracy so padding (and banned tokens, for the loss) are ignored.
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    loss=masked_loss,
    metrics=[masked_acc]
)

In [None]:
history = model.fit(
    # for more frequent reporting, repeat() the datasets and set the *_steps values explicitly
    train_ds.repeat(),
    steps_per_epoch=100,
    validation_data=test_ds.repeat(),
    validation_steps=20,
    epochs=100,
    callbacks=callbacks
)

Visualization

In [None]:
# Loss and accuracy on separate axes: they have different scales and units.
fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 4))

loss_ax.plot(history.history['loss'], label='loss')
loss_ax.plot(history.history['val_loss'], label='val_loss')
loss_ax.set_ylim([0, max(loss_ax.get_ylim())])
loss_ax.set(xlabel='Epoch #', ylabel='CE/token')
loss_ax.legend()

acc_ax.plot(history.history['masked_acc'], label='accuracy')
acc_ax.plot(history.history['val_masked_acc'], label='val_accuracy')
acc_ax.set_ylim([0, max(acc_ax.get_ylim())])
acc_ax.set(xlabel='Epoch #', ylabel='Accuracy')
acc_ax.legend();

Attention Plots

In [None]:
# Greedy caption for the surf image; this also refreshes each DecLayer's cached attention scores.
result = model.simple_gen(img, temp=0.0)
result

In [None]:
# Caption words, plus an '[END]' label for the final generated position (which simple_gen strips).
str_toks = result.split() + ['[END]']

# Each DecLayer caches its CrossAttn scores, shape (batch=1, heads, seq, img).
per_layer_scores = [dec.last_attn_scores for dec in model.dec_layers]
print([scores.shape for scores in per_layer_scores])

# Stack the layers along the batch axis, average over (batch, heads) and
# split the flattened image axis back into a 7x7 grid.
attn_maps = einops.reduce(
    tf.concat(per_layer_scores, axis=0),
    'batch heads seq (height width) -> seq height width',
    height=7, width=7,
    reduction='mean'
)
# One map per sequence position; each map should sum to 1.
print(einops.reduce(attn_maps, 'seq height width -> seq', reduction='sum'))

In [None]:
def plot_attn_maps(image, str_toks, attn_map):
  """Overlay one attention map per token on `image`, in a 3-row subplot grid.

  Args:
    image: image to draw under the maps (callers pass `img / 255`).
    str_toks: tokens used as panel titles, one per attention map.
    attn_map: per-token attention maps indexed as `attn_map[i]` -> (height, width).
  """
  fig = plt.figure(figsize=(16, 9))

  len_result = len(str_toks)
  # 3 rows x grid_size columns always holds at least len_result panels.
  grid_size = max(int(np.ceil(len_result / 2)), 2)

  for i in range(len_result):
    attn = attn_map[i]

    ax = fig.add_subplot(3, grid_size, i+1)
    ax.set_title(str_toks[i])
    shown = ax.imshow(image)
    # Stretch the small map over the image's extent; scale its contrast per token.
    ax.imshow(attn, cmap='gray', alpha=0.6, extent=shown.get_extent(),
              clim=[0.0, np.max(attn)])

  plt.tight_layout()

plot_attn_maps(img / 255, str_toks, attn_maps)

In [None]:
# put it together
def run_and_show_attn(model, img, temp=0.0):
  """Caption `img` with `model` and plot each token's cross-attention map.

  Args:
    model: a Captioner whose DecLayers cache `last_attn_scores`.
    img: RGB image as returned by `load_image`.
    temp: sampling temperature passed to `simple_gen` (0.0 = greedy).
  """
  result_txt = model.simple_gen(img, temp)
  str_toks = result_txt.split()
  str_toks.append('[END]')

  attn_maps = [layer.last_attn_scores for layer in model.dec_layers]
  attn_maps = tf.concat(attn_maps, axis=0)
  # The size keywords must use the same axis names as the pattern (h, w).
  attn_maps = einops.reduce(
      attn_maps,
      'b head s (h w) -> s h w',
      h=7, w=7,
      reduction='mean'
  )

  plot_attn_maps(img / 255, str_toks, attn_maps)

  t = plt.suptitle(result_txt)
  t.set_y(1.05)


In [None]:
# Caption a new image and show where the model looked for each word.
image_url = 'https://tensorflow.org/images/bedroom_hrnet_tutorial.jpg'
image_path = tf.keras.utils.get_file('bedroom_hrnet_tutorial.jpg', origin=image_url)
image = load_image(image_path)

run_and_show_attn(model, image)