<a href="https://colab.research.google.com/github/nisarahamedk/kaggle-riid/blob/master/notebooks/RIID_TF_Transformers_TPU.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

#### RIID Transformer on TPU

In [123]:
import pickle

import tensorflow as tf
import tensorflow.keras as keras
import pandas as pd
import numpy as np
from sklearn.model_selection import KFold


np.random.seed(42)
tf.random.set_seed(42)


### *Device* Settings - This needs to be at the TOP¶

In [124]:
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver() # no parameter needed for TPU_NAME env variable is set. This is the case for Kaggle
    print("Running on TPU: ", tpu.master())
except ValueError:
    tpu = None

In [125]:
if tpu:
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.experimental.TPUStrategy(tpu)
else:
    # default strategy with the available hw
    strategy = tf.distribute.get_strategy()
    
REPLICAS = strategy.num_replicas_in_sync
print("REPLICAS: ", REPLICAS)

REPLICAS:  1


### Datasets

In [126]:
DATA_PATH = 'gs://kds-0d2ffe2dbc91ac5f57c8846e2d8c3ef3f2c48c5e264ae5abb3f92918'

In [127]:
n_train_files = len(tf.io.gfile.glob(DATA_PATH + "/tfrec*"))
n_train_files

32

In [128]:
FOLDS = 10

kfold = KFold(n_splits=FOLDS, shuffle=True, random_state=42)
folds_list = list(kfold.split(np.arange(n_train_files)))
folds_list[:2]

[(array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 16, 18,
         19, 20, 21, 22, 23, 25, 26, 27, 28, 30, 31]),
  array([15, 17, 24, 29])),
 (array([ 0,  1,  2,  3,  4,  5,  6,  7, 10, 11, 12, 13, 14, 15, 16, 17, 18,
         19, 20, 21, 22, 23, 24, 26, 27, 28, 29, 31]),
  array([ 8,  9, 25, 30]))]

In [129]:
FOLD = 5

train_folds, valid_folds = folds_list[FOLD]
len(train_folds), len(valid_folds)

(29, 3)

In [130]:
train_files = tf.io.gfile.glob([DATA_PATH + "/tfrec_%d.tfrec" % idx for idx in train_folds])
valid_files = tf.io.gfile.glob([DATA_PATH + "/tfrec_%d.tfrec" % idx for idx in valid_folds])

In [131]:
len(train_files), len(valid_files)

(29, 3)

#### Load TFRecord Datasets

In [132]:
AUTOTUNE = tf.data.experimental.AUTOTUNE

In [133]:
feature_desc = {
    "content_id": tf.io.FixedLenFeature([], tf.string),
    "task_container_id": tf.io.FixedLenFeature([], tf.string),
    "prior_question_elapsed_time": tf.io.FixedLenFeature([], tf.string),
    "part": tf.io.FixedLenFeature([], tf.string),
    "prev_answered_correctly": tf.io.FixedLenFeature([], tf.string),
    "answered_correctly": tf.io.FixedLenFeature([], tf.string),
}

def parse_example(example):
  example = tf.io.parse_single_example(example, feature_desc)

  content_id = tf.io.parse_tensor(example["content_id"], tf.int16)
  task_container_id = tf.io.parse_tensor(example["task_container_id"], tf.int16)
  prior_question_elapsed_time = tf.io.parse_tensor(example["prior_question_elapsed_time"], tf.float32)
  part = tf.io.parse_tensor(example["part"], tf.int16)
  prev_answered_correctly = tf.io.parse_tensor(example["prev_answered_correctly"], tf.int8)
  answered_correctly = tf.io.parse_tensor(example["answered_correctly"], tf.int8)
  
  return tf.stack([
      tf.cast(content_id, tf.float32),
      tf.cast(task_container_id, tf.float32),
      tf.cast(prior_question_elapsed_time, tf.float32),
      tf.cast(part, tf.float32),
      tf.cast(prev_answered_correctly, tf.float32),
      tf.cast(answered_correctly, tf.float32)
  ]) # [features, seq_len] # TODO make it [seq_len, features] so that we dont have to reshape in transformer

In [134]:
def load_dataset_from_tfrecord(filenames, ds_type="train", cache_to=None):
        # Since we are reading dataset from multiple files. and we dont care about the order.
        # set deterministic reading to False.
        ignore_order = tf.data.Options()
        if ds_type == "train":
            ignore_order.experimental_deterministic = False
            
        dataset = tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTOTUNE)
        if not cache_to:
            dataset = dataset.cache() # cache to RAM
        else:
            dataset = dataset.cache(cache_to) # cache to file given by self.cache_to 
        if ds_type == "train":
            dataset = dataset.repeat() # repeat individual item, so that we have full batch at every step.
        dataset.with_options(ignore_order)
        dataset = dataset.map(parse_example, num_parallel_calls=AUTOTUNE)
        return dataset

In [135]:
dataset = load_dataset_from_tfrecord(train_files)

In [136]:
SEQ_LEN = 128

In [137]:
@tf.function
def pad(a, seq_len, max_seq_len):
  s = max_seq_len - seq_len
  # making [[0, 0], [s, 0]]
  r = tf.stack([s, tf.constant(0)])
  t = tf.stack([tf.constant([0, 0]), r])
  
  return tf.pad(a, t) # ,1 to debug

@tf.function
def trim(a, seq_len,  max_seq_len):
  start = tf.squeeze(tf.random.uniform((1,), maxval=(seq_len-max_seq_len), dtype=tf.int32))
  # https://www.quora.com/How-does-tf-slice-work-in-TensorFlow
  begin = tf.stack([tf.constant(0), start])
  size = tf.stack([tf.shape(a)[0], max_seq_len])
  
  return tf.slice(a, begin, size) # , start - to debug

@tf.function
def pad_or_trim(a):
  seq_len = tf.shape(a)[-1]
  max_seq_len = SEQ_LEN
  fn = tf.cond(tf.less_equal(seq_len, max_seq_len), lambda: pad(a, seq_len, max_seq_len), lambda: trim(a, seq_len, max_seq_len))
  return fn

In [138]:
dataset = dataset.map(pad_or_trim, num_parallel_calls=AUTOTUNE) # every sample is padded if len < SEQ_LEN or randomly trimmed to SEQ_LEN

In [139]:
dataset = dataset.map(lambda x: (x[:-1, :], x[-1, :])) # x and y

In [140]:
for x, y in dataset.take(1):
  print(x.shape)
  print(y.shape)

(5, 128)
(128,)


In [141]:
dataset = dataset.shuffle(int(1024 * REPLICAS))

In [142]:
BATCH_SIZE = 128

In [143]:
dataset = dataset.batch(BATCH_SIZE)

In [144]:
for xb, yb in dataset.take(1):
  print(xb.shape)
  print(yb.shape)

(128, 5, 128)
(128, 128)


In [145]:
dataset = dataset.prefetch(AUTOTUNE)

In [146]:
# len(list(iter(dataset.take(1000)))) # dataset repeats indefinitely

1000

### Model

##### Positional Encoding

In [147]:
def get_angles(pos, i, d_model):
  angle_rates = 1 / np.power(10000, (2 * (i//2)) / np.float32(d_model))
  return pos * angle_rates

In [148]:
def positional_encoding(position, d_model):
  angle_rads = get_angles(np.arange(position)[:, np.newaxis],
                          np.arange(d_model)[np.newaxis, :],
                          d_model)

  # apply sin to even indices in the array; 2i
  angle_rads[:, 0::2] = np.sin(angle_rads[:, 0::2])

  # apply cos to odd indices in the array; 2i+1
  angle_rads[:, 1::2] = np.cos(angle_rads[:, 1::2])

  pos_encoding = angle_rads[np.newaxis, ...]

  return tf.cast(pos_encoding, dtype=tf.float32)

##### Look ahead mask

In [149]:
def create_look_ahead_mask(size):
  mask = 1 - tf.linalg.band_part(tf.ones((size, size)), -1, 0)
  return mask  # (seq_len, seq_len)

##### Scaled Dot Product Attention

In [150]:
def scaled_dot_product_attention(q, k, v, mask):
  """Calculate the attention weights.
  q, k, v must have matching leading dimensions.
  k, v must have matching penultimate dimension, i.e.: seq_len_k = seq_len_v.
  The mask has different shapes depending on its type(padding or look ahead) 
  but it must be broadcastable for addition.

  Args:
    q: query shape == (..., seq_len_q, depth)
    k: key shape == (..., seq_len_k, depth)
    v: value shape == (..., seq_len_v, depth_v)
    mask: Float tensor with shape broadcastable 
          to (..., seq_len_q, seq_len_k). Defaults to None.

  Returns:
    output, attention_weights
  """

  matmul_qk = tf.matmul(q, k, transpose_b=True)  # (..., seq_len_q, seq_len_k)

  # scale matmul_qk
  dk = tf.cast(tf.shape(k)[-1], tf.float32)
  scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)

  # add the mask to the scaled tensor.
  if mask is not None:
    scaled_attention_logits += (mask * -1e9)  

  # softmax is normalized on the last axis (seq_len_k) so that the scores
  # add up to 1.
  attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)  # (..., seq_len_q, seq_len_k)

  output = tf.matmul(attention_weights, v)  # (..., seq_len_q, depth_v)

  return output, attention_weights

##### Multi Head Attention

In [151]:
class MultiHeadAttention(tf.keras.layers.Layer):
  def __init__(self, d_model, num_heads):
    super(MultiHeadAttention, self).__init__()
    self.num_heads = num_heads
    self.d_model = d_model

    assert d_model % self.num_heads == 0

    self.depth = d_model // self.num_heads

    self.wq = tf.keras.layers.Dense(d_model)
    self.wk = tf.keras.layers.Dense(d_model)
    self.wv = tf.keras.layers.Dense(d_model)

    self.dense = tf.keras.layers.Dense(d_model)

  def split_heads(self, x, batch_size):
    """Split the last dimension into (num_heads, depth).
    Transpose the result such that the shape is (batch_size, num_heads, seq_len, depth)
    """
    x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
    return tf.transpose(x, perm=[0, 2, 1, 3])

  def call(self, v, k, q, mask):
    batch_size = tf.shape(q)[0]

    q = self.wq(q)  # (batch_size, seq_len, d_model)
    k = self.wk(k)  # (batch_size, seq_len, d_model)
    v = self.wv(v)  # (batch_size, seq_len, d_model)

    q = self.split_heads(q, batch_size)  # (batch_size, num_heads, seq_len_q, depth)
    k = self.split_heads(k, batch_size)  # (batch_size, num_heads, seq_len_k, depth)
    v = self.split_heads(v, batch_size)  # (batch_size, num_heads, seq_len_v, depth)

    # scaled_attention.shape == (batch_size, num_heads, seq_len_q, depth)
    # attention_weights.shape == (batch_size, num_heads, seq_len_q, seq_len_k)
    scaled_attention, attention_weights = scaled_dot_product_attention(
        q, k, v, mask)

    scaled_attention = tf.transpose(scaled_attention, perm=[0, 2, 1, 3])  # (batch_size, seq_len_q, num_heads, depth)

    concat_attention = tf.reshape(scaled_attention, 
                                  (batch_size, -1, self.d_model))  # (batch_size, seq_len_q, d_model)

    output = self.dense(concat_attention)  # (batch_size, seq_len_q, d_model)

    return output, attention_weights

##### Pointwise FeedForward Network

In [152]:
def point_wise_feed_forward_network(d_model, dff):
  return tf.keras.Sequential([
      tf.keras.layers.Dense(dff, activation='relu'),  # (batch_size, seq_len, dff)
      tf.keras.layers.Dense(d_model)  # (batch_size, seq_len, d_model)
  ])

##### EncoderLayer

In [153]:
class EncoderLayer(tf.keras.layers.Layer):
  def __init__(self, d_model, num_heads, dff, rate=0.1):
    super(EncoderLayer, self).__init__()

    self.mha = MultiHeadAttention(d_model, num_heads)
    self.ffn = point_wise_feed_forward_network(d_model, dff)

    self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
    self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)

    self.dropout1 = tf.keras.layers.Dropout(rate)
    self.dropout2 = tf.keras.layers.Dropout(rate)

  def call(self, x, training, mask):

    attn_output, _ = self.mha(x, x, x, mask)  # (batch_size, input_seq_len, d_model)
    attn_output = self.dropout1(attn_output, training=training)
    out1 = self.layernorm1(x + attn_output)  # (batch_size, input_seq_len, d_model)

    ffn_output = self.ffn(out1)  # (batch_size, input_seq_len, d_model)
    ffn_output = self.dropout2(ffn_output, training=training)
    out2 = self.layernorm2(out1 + ffn_output)  # (batch_size, input_seq_len, d_model)

    return out2

##### Encoder

In [197]:
class Encoder(tf.keras.layers.Layer):
  def __init__(self, num_layers, d_model, num_heads, dff, maximum_position_encoding, embed_size_dict, rate=0.1):
    super(Encoder, self).__init__()

    self.d_model = d_model
    self.num_layers = num_layers

    self.content_id_emb = tf.keras.layers.Embedding(embed_size_dict["content_id"] + 1, d_model)
    self.task_container_id_emb = tf.keras.layers.Embedding(embed_size_dict["task_container_id"] + 1, d_model)
    self.part_emb = tf.keras.layers.Embedding(embed_size_dict["part"] + 2, d_model)
    self.prev_answered_emb = tf.keras.layers.Embedding(4, d_model)
    self.pos_encoding = positional_encoding(maximum_position_encoding, 
                                            self.d_model)


    self.enc_layers = [EncoderLayer(d_model, num_heads, dff, rate) 
                       for _ in range(num_layers)]

    self.dropout = tf.keras.layers.Dropout(rate)

  def call(self, x, training, mask):

    seq_len = tf.shape(x)[1]

    # adding embeddings and position encoding.
    c_emb = self.content_id_emb(x[...,0])  # (batch_size, input_seq_len, d_model)
    t_emb = self.task_container_id_emb(x[...,1])
    pt_emb = self.part_emb(x[...,3])
    pv_emb = self.prev_answered_emb(x[...,4])
    x = c_emb + t_emb + pt_emb + pv_emb
    print(x.shape)
    x *= tf.math.sqrt(tf.cast(self.d_model, tf.float32))
    x += self.pos_encoding[:, :seq_len, :]

    x = self.dropout(x, training=training)

    for i in range(self.num_layers):
      x = self.enc_layers[i](x, training, mask)

    return x  # (batch_size, input_seq_len, d_model)

###### Embedding Sizes

In [198]:
embed_sizes = pickle.loads(tf.io.read_file(DATA_PATH + "/emb_sz.pkl").numpy())

In [199]:
embed_sizes

{'content_id': 32736, 'part': 7, 'task_container_id': 9999}

In [200]:
with strategy.scope():
  model = Encoder(
      num_layers=1,
      d_model=512,
      num_heads=8,
      dff=1024,
      maximum_position_encoding=128,
      embed_size_dict=embed_sizes
  )

In [175]:
for xb, yb in dataset.take(1):
  pass

In [201]:
xb = tf.transpose(xb, (0, 2, 1))
xb.shape

TensorShape([128, 5, 128])

In [202]:
model(xb, mask=None)

InvalidArgumentError: ignored

In [196]:
xb[...,3]

<tf.Tensor: shape=(128, 128), dtype=float32, numpy=
array([[0., 0., 0., ..., 6., 6., 6.],
       [3., 3., 3., ..., 6., 6., 6.],
       [6., 6., 6., ..., 6., 6., 6.],
       ...,
       [0., 0., 0., ..., 6., 6., 6.],
       [0., 0., 0., ..., 6., 6., 6.],
       [0., 0., 0., ..., 3., 3., 3.]], dtype=float32)>

In [203]:
xb[..., 1]

<tf.Tensor: shape=(128, 5), dtype=float32, numpy=
array([[0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00],
       [6.2100000e+02, 4.7000000e+01, 2.5000000e-01, 3.0000000e+00,
        2.0000000e+00],
       [4.2170000e+03, 2.0000000e+00, 4.8333332e-01, 6.0000000e+00,
        2.0000000e+00],
       [0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00],
       [6.3810000e+03, 6.6000000e+01, 7.1666664e-01, 6.0000000e+00,
        1.0000000e+00],
       [1.0687000e+04, 6.0000000e+01, 3.1666666e-01, 3.0000000e+00,
        2.0000000e+00],
       [0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00],
       [0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00],
       [4.4670000e+03, 1.0800000e+02, 1.4166666e+00, 6.0000000e+00,
        1.0000000e+00],
       [0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00],
       [0.0000000e+00, 0.00000