In [0]:
import tensorflow as tf
import tensorflow_hub as hub

In [2]:
!pip install bert-for-tf2
import bert
FullTokenizer = bert.bert_tokenization.FullTokenizer

Collecting bert-for-tf2
[?25l  Downloading https://files.pythonhosted.org/packages/35/5c/6439134ecd17b33fe0396fb0b7d6ce3c5a120c42a4516ba0e9a2d6e43b25/bert-for-tf2-0.14.4.tar.gz (40kB)
[K     |████████                        | 10kB 20.1MB/s eta 0:00:01[K     |████████████████▏               | 20kB 6.4MB/s eta 0:00:01[K     |████████████████████████▎       | 30kB 7.6MB/s eta 0:00:01[K     |████████████████████████████████| 40kB 3.7MB/s 
[?25hCollecting py-params>=0.9.6
  Downloading https://files.pythonhosted.org/packages/a4/bf/c1c70d5315a8677310ea10a41cfc41c5970d9b37c31f9c90d4ab98021fd1/py-params-0.9.7.tar.gz
Collecting params-flow>=0.8.0
  Downloading https://files.pythonhosted.org/packages/a9/95/ff49f5ebd501f142a6f0aaf42bcfd1c192dc54909d1d9eb84ab031d46056/params-flow-0.8.2.tar.gz
Building wheels for collected packages: bert-for-tf2, py-params, params-flow
  Building wheel for bert-for-tf2 (setup.py) ... [?25l[?25hdone
  Created wheel for bert-for-tf2: filename=bert_for_tf2

In [3]:
from google.colab import drive
drive.mount('/content/drive')

Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3aietf%3awg%3aoauth%3a2.0%3aoob&response_type=code&scope=email%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdocs.test%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive.photos.readonly%20https%3a%2f%2fwww.googleapis.com%2fauth%2fpeopleapi.readonly

Enter your authorization code:
··········
Mounted at /content/drive


In [0]:
import os
import json
root = "./drive/My Drive/question_rewrite"

import numpy as np
from tqdm import tqdm

In [0]:
def read_data(cate):
  # read data into three lists
  # q: original q (long)
  # new_q: shorten q
  # img: img_path
  path = os.path.join(root, cate, 'data.json')
  q, new_q, img = [], [], []
  with open(path, 'r') as f:
    data = json.load(f)
  for item in data:
    q.append(item['rewrite_q'])
    new_q.append(item['new_q'])
    img.append(item['img_path'])
  return q, new_q, img

In [0]:
def get_model(name, fc):
  if name == 'vgg19' and not fc:
    model = tf.keras.applications.VGG19(include_top=False,
                                        weights='imagenet')
    output = model.layers[-1].output
  elif name == 'vgg19' and fc:
    model = tf.keras.applications.VGG19(include_top=True,
                                        weights='imagenet')
    output = model.layers[-3].output
  elif name == 'res50' and not fc:
    model = tf.keras.applications.ResNet50(include_top=False,
                                           weights='imagenet')
    output = model.layers[-1].output
  else:
    print("Illegal Config.")
  # build new image model
  input = model.input
  img_model = tf.keras.Model(input, output)
  return img_model


def feature_path(img_path, name, fc):
  # "./drive/My Drive/questions_rewrite/auto_annot/image/000002.jpg"
  # if cate=res50 and fc=False
  # "./drive/My Drive/questions_rewrite/auto_annot/image/000002_res50_nonfc.npy""
  p, img_name = img_path.rsplit('/', 1)
  new_name = img_name.split('.')[0] + "_{}_{}.npy"
  fc_ = 'fc' if fc else 'nonfc'
  new_name = new_name.format(name, fc_)
  return os.path.join(p, new_name)


def extract_img_feat(cate, img_data, name='res50', fc=False):
  # set the image model accoding to the name and fc config
  img_model = get_model(name, fc)

  # load img function
  def load_image(image_path):
    img = tf.io.read_file(image_path)
    img = tf.image.decode_jpeg(img, channels=3)
    img = tf.image.resize(img, (224, 224))
    if name == 'vgg19':
      img = tf.keras.applications.vgg19.preprocess_input(img)
    elif name == 'res50':
      img = tf.keras.applications.resnet.preprocess_input(img)
    else:
      print("Illegal Name.")
    return img, image_path

  # get unique images
  unique_img = list(set(img_data))
  unique_img = list(map(lambda x: os.path.join(root, cate, x), unique_img))

  # build dataset for unique images
  image_dataset = tf.data.Dataset.from_tensor_slices(unique_img)
  image_dataset = image_dataset.map(
      load_image,
      num_parallel_calls=tf.data.experimental.AUTOTUNE).batch(64)

  print("Model and Dataset Done.")

  # extract features and save
  for img, path in image_dataset:
      batch_features = img_model(img)
      # reshpae image features
      # vgg19 without top: [7*7, 512]
      # vgg19 with top: [4096]
      # res50 without top: [7*7, 2048]
      if not fc:
        batch_features = tf.reshape(batch_features, 
                                      (batch_features.shape[0],
                                      -1, batch_features.shape[-1]))
      for f, p in zip(batch_features, path):
        path_of_feature = feature_path(p.numpy().decode("utf-8"), name, fc)
        np.save(path_of_feature, f.numpy())

  print("Extraction Done.")

  # exchange img_data with feature path
  for i in range(len(img_data)):
    img_data[i] = feature_path(os.path.join(root, cate, img_data[i]), name, fc)
  
  return img_data

In [7]:
cate = "human_annot"
img_model_name = 'res50'
fc_top = False
q_data, new_q_data, img_data = read_data(cate)
img_data = extract_img_feat(cate,
                            img_data,
                            name=img_model_name,
                            fc=fc_top)

Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/resnet/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5
Model and Dataset Done.
Extraction Done.


In [0]:
bert_layer = hub.KerasLayer("https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/2",
                            trainable=True)

vocab_file = bert_layer.resolved_object.vocab_file.asset_path.numpy()
do_lower_case = bert_layer.resolved_object.do_lower_case.numpy()
bert_tokenizer = FullTokenizer(vocab_file, do_lower_case)

In [0]:
def get_masks(tokens, max_seq_length):
  if len(tokens) > max_seq_length:
    raise IndexError("Token longer than max_seq_length.")
  return [1] * len(tokens) + [0] * (max_seq_length - len(tokens))


def get_segments(tokens, max_seq_length):
  if len(tokens) > max_seq_length:
    raise IndexError("Token longer than max_seq_length.")
  segments = []
  current_segment_id = 0
  for token in tokens:
    segments.append(current_segment_id)
    if token == "[SEP]":
      current_segment_id = 1
  return segments + [0] * (max_seq_length - len(tokens))


def get_ids(tokens, tokenizer, max_seq_length):
  token_ids = tokenizer.convert_tokens_to_ids(tokens)
  input_ids = token_ids + [0] * (max_seq_length - len(token_ids))
  return input_ids

In [0]:
import re
import time

In [0]:
def preprocess_sentence(w):
  w = re.sub(r"([?.!,¿])", r" \1 ", w)
  w = re.sub(r'[" "]+', " ", w)
  w = re.sub(r"[^a-zA-Z?.!,¿]+", " ", w)
  w = w.strip()
  # w = "[CLS] " + w + " [SEP]"
  return w

In [0]:
def bert_preprocess(qs, tokenizer):
   max_length = 0
   qs_tokens = []
   for q in qs:
     q_data = preprocess_sentence(q)
     tokens = tokenizer.tokenize(q)
     tokens = ["[CLS]"] + tokens + ["[SEP]"]
     max_length = max(max_length, len(tokens))
     qs_tokens.append(tokens)

   ids = []
   masks = []
   segments = []
   for q_tokens in qs_tokens:
     ids.append(get_ids(q_tokens, tokenizer, max_length))
     masks.append(get_masks(q_tokens, max_length))
     segments.append(get_segments(q_tokens, max_length))
   return ids, masks, segments, max_length

In [13]:
input_ids, input_masks, input_segments, input_max_length = bert_preprocess(new_q_data, bert_tokenizer)
print(len(input_ids), len(input_ids[0]))

241 31


In [0]:
def output_preprocess(qs):
  qs = [preprocess_sentence(q) for q in qs]
  qs = ['[CLS] ' + q + ' [SEP]' for q in qs]
  output_tokenizer = tf.keras.preprocessing.text.Tokenizer(filters='',
                                                           lower=False)
  output_tokenizer.fit_on_texts(qs)
  output_tokenizer.word_index['<pad>'] = 0
  output_tokenizer.index_word[0] = '<pad>'
  tokens = output_tokenizer.texts_to_sequences(qs)
  tokens = tf.keras.preprocessing.sequence.pad_sequences(tokens,
                                                         padding='post')
  return tokens, output_tokenizer

In [15]:
target_ids, target_tokenizer = output_preprocess(q_data)
print(len(target_ids), len(target_ids[0]))
print(target_ids[0])
print(q_data[0])
print(target_tokenizer.sequences_to_texts(target_ids)[0])

241 49
[ 2 20 16  5 17  7 21 24 25 60  4 23  6  3  0  0  0  0  0  0  0  0  0  0
  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
  0]
beautiful kitchen. can you please tell me dimensions? thanks!
[CLS] beautiful kitchen . can you please tell me dimensions ? thanks ! [SEP] <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>


In [16]:
print(target_tokenizer.word_index['[CLS]'])

2


In [0]:
def train_val_split(data, rate=0.8):
  len_ = len(data)
  return data[:int(len_ * rate)], data[int(len_ * rate):]

In [0]:
train_input_ids, val_input_ids = train_val_split(input_ids)
train_input_masks, val_input_masks = train_val_split(input_masks)
train_input_segments, val_input_segments = train_val_split(input_segments)

train_target_ids, val_target_ids = train_val_split(target_ids)

train_img, val_img = train_val_split(img_data)

In [0]:
def load_func(ids, masks, segs, targ, img):
  img_tensor = np.load(img.decode('utf-8'))
  return ids, masks, segs, targ, img_tensor

In [0]:
buffer_size = len(train_input_ids)
batch_size = 64
steps_per_epoch = len(train_input_ids) // batch_size
vocab_tar_size = len(target_tokenizer.word_index) + 1
embedding_dim = 256
units = 768


dataset = tf.data.Dataset.from_tensor_slices((train_input_ids,
                                              train_input_masks,
                                              train_input_segments,
                                              train_target_ids,
                                              train_img))
dataset = dataset.map(lambda ids, masks, segs, targ, img:
                          tf.numpy_function(load_func,
                                            [ids, masks, segs, targ, img],
                                            [tf.int32, tf.int32, tf.int32, tf.int32, tf.float32]),
                      num_parallel_calls=tf.data.experimental.AUTOTUNE)
dataset = dataset.shuffle(buffer_size).batch(batch_size)

In [0]:
class TextAttention(tf.keras.layers.Layer):
  def __init__(self, units):
    super(TextAttention, self).__init__()
    self.W1 = tf.keras.layers.Dense(units)
    self.W2 = tf.keras.layers.Dense(units)
    self.V = tf.keras.layers.Dense(1)

  def call(self, query, values):
    query_with_time_axis = tf.expand_dims(query, 1)
    score = self.V(tf.nn.tanh(self.W1(query_with_time_axis) + self.W2(values)))
    attention_weights = tf.nn.softmax(score, axis=1)

    context_vector = attention_weights * values
    context_vector = tf.reduce_sum(context_vector, axis=1)

    return context_vector, attention_weights

In [0]:
class ImageAttention(tf.keras.layers.Layer):
  def __init__(self, units):
    super(ImageAttention, self).__init__()
    self.W_I = tf.keras.layers.Dense(units)
    self.W_Q = tf.keras.layers.Dense(units)
    self.W_P = tf.keras.layers.Dense(1)
    
  def call(self, v_I, v_Q):
    v_I_att = self.W_I(v_I)
    v_Q_att = self.W_Q(v_Q)
    v_Q_att = tf.expand_dims(v_Q_att, axis=1)
    p_I = self.W_P(tf.tanh(v_I_att + v_Q_att))
    attention_weights = tf.nn.softmax(p_I, axis=1)

    v_att = p_I * v_I
    v_att = tf.reduce_sum(v_att, axis=1)

    return v_att, attention_weights

In [0]:
class QRewriteModel(tf.keras.Model):
  def __init__(self,
               vocab_tar_size,
               embedding_dim,
               dec_units,
               batch_size):
    super(QRewriteModel, self).__init__()
    self.batch_size = batch_size
    self.dec_units = dec_units
    self.dec_embedding = tf.keras.layers.Embedding(vocab_tar_size, embedding_dim)
    self.gru = tf.keras.layers.GRU(self.dec_units,
                                   return_sequences=True,
                                   return_state=True,
                                   recurrent_initializer='glorot_uniform')
    self.fc = tf.keras.layers.Dense(vocab_tar_size)

    # self.encoder = hub.KerasLayer("https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/1",
    #                              trainable=False)
    self.text_attention = TextAttention(self.dec_units)
    self.img_attention = ImageAttention(self.dec_units)
    self.img_fc = tf.keras.layers.Dense(self.dec_units)

  def call(self, x, hidden, enc_output, img_feature):
    text_vector, text_weights = self.text_attention(hidden, enc_output)
    img_feature = self.img_fc(img_feature)
    image_vector, image_weights = self.img_attention(img_feature, hidden)

    x = self.dec_embedding(x)

    x = tf.concat([tf.expand_dims(text_vector, 1),
                   tf.expand_dims(image_vector, 1),
                   x], axis=-1)
    output, state =  self.gru(x)
    output = tf.reshape(output, (-1, output.shape[2]))

    x = self.fc(output)
    return x, state, text_weights, image_weights

In [0]:
decoder = QRewriteModel(vocab_tar_size,
                        embedding_dim,
                        units,
                        batch_size)

In [0]:
input_word_ids = tf.keras.layers.Input(shape=(input_max_length,),
                                       dtype=tf.int32,
                                       name="input_word_ids")
input_mask = tf.keras.layers.Input(shape=(input_max_length,),
                                   dtype=tf.int32,
                                   name="input_mask")
segment_ids = tf.keras.layers.Input(shape=(input_max_length,),
                                    dtype=tf.int32,
                                    name="segment_ids")

pooled_output, sequence_output = bert_layer([input_word_ids, input_mask, segment_ids])
encoder = tf.keras.Model(inputs=[input_word_ids, input_mask, segment_ids], outputs=[pooled_output, sequence_output])

In [0]:
optimizer = tf.keras.optimizers.Adam()
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True,
                                                            reduction='none')

def loss_function(real, pred):
  mask = tf.math.logical_not(tf.math.equal(real, 0))
  loss_ = loss_object(real, pred)
  mask = tf.cast(mask, dtype=loss_.dtype)
  loss_ *= mask
  return tf.reduce_mean(loss_)

In [0]:
exp_name = img_model_name

In [0]:
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, 'ckpt')
checkpoint = tf.train.Checkpoint(optimizer=optimizer,
                                 encoder=encoder,
                                 decoder=decoder)

In [47]:
example_input_ids, example_mask, example_seg, example_output_ids, example_img = next(iter(dataset))
print("input_ids:", example_input_ids.shape)
print("mask:", example_mask.shape)
print("segments:", example_seg.shape)
print("output_ids:", example_output_ids.shape)
print("img:", example_img.shape)

input_ids: (64, 31)
mask: (64, 31)
segments: (64, 31)
output_ids: (64, 49)
img: (64, 49, 2048)


In [0]:
@tf.function
def train_step(enc_hidden, enc_output, img, targ):
  loss = 0

  with tf.GradientTape() as tape:
    dec_hidden = enc_hidden
    dec_input = tf.expand_dims([target_tokenizer.word_index['[CLS]']] * batch_size, 1)

    for t in range(1, targ.shape[1]):
      predictions, dec_hidden, _, _ = decoder(dec_input,
                                              dec_hidden,
                                              enc_output,
                                              img)
      loss += loss_function(targ[:, t], predictions)
      dec_input = tf.expand_dims(targ[:, t], 1)

  batch_loss = (loss / int(targ.shape[1]))
  variables = decoder.trainable_variables
  gradients = tape.gradient(loss, variables)
  optimizer.apply_gradients(zip(gradients, variables))
  return batch_loss

In [54]:
EPOCHS = 10

for epoch in range(EPOCHS):
  start = time.time()

  total_loss = 0

  for batch, (ids, masks, segments, targ, img) in enumerate(dataset):
    enc_hidden, enc_output = encoder([ids, masks, segments])
    batch_loss = train_step(enc_hidden, enc_output, img, targ)
    total_loss += batch_loss

    print('Epoch {} Batch {} Loss {:.4f}'.format(epoch + 1,
                                                 batch,
                                                 batch_loss.numpy()))
    
  if (epoch + 1) % 2 == 0:
    checkpoint.save(file_prefix=checkpoint_prefix)
  
  print('Epoch {} Loss {:.4f}'.format(epoch + 1, total_loss / steps_per_epoch))
  print('Time taken for 1 epoch {} sec\n'.format(time.time() - start))

Epoch 1 Batch 0 Loss 2.8392
Epoch 1 Batch 1 Loss 2.5437
Epoch 1 Batch 2 Loss 2.3330
Epoch 1 Loss 2.5720
Time taken for 1 epoch 50.0321159362793 sec

Epoch 2 Batch 0 Loss 2.3370
Epoch 2 Batch 1 Loss 2.4786
Epoch 2 Batch 2 Loss 2.2778
Epoch 2 Loss 2.3645
Time taken for 1 epoch 4.656791925430298 sec

Epoch 3 Batch 0 Loss 2.2813
Epoch 3 Batch 1 Loss 2.0684
Epoch 3 Batch 2 Loss 2.0645
Epoch 3 Loss 2.1381
Time taken for 1 epoch 2.5919201374053955 sec

Epoch 4 Batch 0 Loss 1.9049
Epoch 4 Batch 1 Loss 1.9456
Epoch 4 Batch 2 Loss 2.0907
Epoch 4 Loss 1.9804
Time taken for 1 epoch 4.500836133956909 sec

Epoch 5 Batch 0 Loss 1.9739
Epoch 5 Batch 1 Loss 1.9271
Epoch 5 Batch 2 Loss 1.8593
Epoch 5 Loss 1.9201
Time taken for 1 epoch 2.6015875339508057 sec

Epoch 6 Batch 0 Loss 2.0298
Epoch 6 Batch 1 Loss 1.8870
Epoch 6 Batch 2 Loss 1.8434
Epoch 6 Loss 1.9201
Time taken for 1 epoch 4.518457412719727 sec

Epoch 7 Batch 0 Loss 1.7868
Epoch 7 Batch 1 Loss 1.8547
Epoch 7 Batch 2 Loss 2.1086
Epoch 7 Loss 1.