In [1]:
import time
import datetime

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import tensorflow_datasets as tfds


from transformers import (TFAutoModelWithLMHead, AutoTokenizer, 
    TFTrainer, TFTrainingArguments, TFT5ForConditionalGeneration, T5Config)

In [2]:
tf.version.VERSION

'2.5.0-dev20201029'

### Define Model

In [3]:
class TFT5(TFT5ForConditionalGeneration):
    """T5 conditional-generation model with custom Keras train/test steps.

    Both steps feed the feature dict to the model and take element 0 of the
    output as the loss and element 1 as the logits. They track the mean loss
    in a ``Mean`` metric named ``'loss'`` and update the compiled metrics
    against the label ids. The training step also reports the optimizer's
    current learning rate as ``'lr'``.
    """

    def __init__(self, *args, log_dir=None, cache_dir=None, **kwargs):
        """Build the model, forwarding all other arguments to the parent.

        ``log_dir`` and ``cache_dir`` are keyword-only and are currently
        ignored. They are intercepted here so they are not passed on to the
        parent constructor.
        """
        super().__init__(*args, **kwargs)
        # Running mean of the per-batch loss; appears in self.metrics as 'loss'.
        self.loss_tracker = tf.keras.metrics.Mean(name='loss')

    @tf.function
    def train_step(self, data):
        """Run one optimization step on a batch.

        Args:
            data: ``(features, _)`` pair yielded by the dataset. The second
                element is ignored. ``features`` must contain ``'labels'`` and
                is passed to the model unchanged.

        Returns:
            Dict mapping each name in ``self.metrics`` to its current result,
            plus ``'lr'`` with the optimizer's decayed learning rate.
        """
        x, _ = data
        # Label ids flattened to a single column for the compiled metrics.
        # NOTE(review): logits keep their (batch, seq, vocab) layout; this
        # relies on the compiled metric flattening them — confirm against the
        # compile() call.
        y = tf.reshape(x["labels"], [-1, 1])
        with tf.GradientTape() as tape:
            outputs = self(x, training=True)
            # Element 0 is treated as the loss (presumably computed by the
            # parent model from x["labels"]); reduce it to a scalar.
            loss = tf.reduce_mean(outputs[0])
            logits = outputs[1]
        # Compute gradients outside the tape context so the gradient ops
        # themselves are not recorded on the tape.
        grads = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
        # NOTE(review): _decayed_lr is a private optimizer API and may not
        # exist in other TensorFlow versions.
        lr = self.optimizer._decayed_lr(tf.float32)

        self.loss_tracker.update_state(loss)
        self.compiled_metrics.update_state(y, logits)
        metrics = {m.name: m.result() for m in self.metrics}
        metrics['lr'] = lr
        return metrics

    def test_step(self, data):
        """Evaluate one batch without updating weights.

        Args:
            data: ``(features, _)`` pair as in ``train_step``.

        Returns:
            Dict mapping each name in ``self.metrics`` to its current result.
        """
        x, _ = data
        # Same label flattening as in train_step.
        y = tf.reshape(x["labels"], [-1, 1])
        output = self(x, training=False)
        loss = tf.reduce_mean(output[0])
        logits = output[1]

        self.loss_tracker.update_state(loss)
        self.compiled_metrics.update_state(y, logits)
        return {m.name: m.result() for m in self.metrics}

### Tokenizer

In [4]:
# Pretrained checkpoint whose vocabulary is used to tokenize inputs and targets.
tokenizer_checkpoint = "t5-base"
tokenizer = AutoTokenizer.from_pretrained(tokenizer_checkpoint)

### Dataset

In [5]:
# Load both SQuAD splits in one call; `info` holds the split sizes used below.
(train_dataset, valid_dataset), info = tfds.load(
    'squad', split=['train', 'validation'], with_info=True)

INFO:absl:No config specified, defaulting to first: squad/v1.1
INFO:absl:Load dataset info from /home/mirac13/tensorflow_datasets/squad/v1.1/2.0.0
INFO:absl:Reusing dataset squad (/home/mirac13/tensorflow_datasets/squad/v1.1/2.0.0)
INFO:absl:Constructing tf.data.Dataset for split train, from /home/mirac13/tensorflow_datasets/squad/v1.1/2.0.0
INFO:absl:No config specified, defaulting to first: squad/v1.1
INFO:absl:Load dataset info from /home/mirac13/tensorflow_datasets/squad/v1.1/2.0.0
INFO:absl:Reusing dataset squad (/home/mirac13/tensorflow_datasets/squad/v1.1/2.0.0)
INFO:absl:Constructing tf.data.Dataset for split validation, from /home/mirac13/tensorflow_datasets/squad/v1.1/2.0.0


In [6]:
# Peek at one raw (not yet tokenized) training record to see its structure.
# A dedicated name keeps this record distinct from the tokenized batch
# inspected later in the notebook.
raw_example = next(iter(train_dataset))
print("Example data from the dataset: \n", raw_example)

Example data from the dataset: 
 {'answers': {'answer_start': <tf.Tensor: shape=(1,), dtype=int32, numpy=array([427], dtype=int32)>, 'text': <tf.Tensor: shape=(1,), dtype=string, numpy=array([b'mobile phones'], dtype=object)>}, 'context': <tf.Tensor: shape=(), dtype=string, numpy=b'The difference in the above factors for the case of \xce\xb8=0 is the reason that most broadcasting (transmissions intended for the public) uses vertical polarization. For receivers near the ground, horizontally polarized transmissions suffer cancellation. For best reception the receiving antennas for these signals are likewise vertically polarized. In some applications where the receiving antenna must work in any position, as in mobile phones, the base station antennas use mixed polarization, such as linear polarization at an angle (with both vertical and horizontal components) or circular polarization.'>, 'id': <tf.Tensor: shape=(), dtype=string, numpy=b'57306bf68ab72b1400f9c4dc'>, 'question': <tf.Tensor: 

### Hyperparameters

In [7]:
# Not used in the cells above; presumably consumed by a learning-rate
# warmup schedule later on — TODO confirm.
warmup_steps = 1e4
batch_size = 4
encoder_max_len = 250   # pad/truncate length for the "answer_me: ... context: ..." input
decoder_max_len = 54    # pad/truncate length for the answer target
buffer_size = 1000      # intended shuffle buffer size for the training pipeline

ntrain = info.splits["train"].num_examples
nvalid = info.splits["validation"].num_examples
# Ceiling division in integer arithmetic: a final partial batch counts as a step.
steps = (ntrain + batch_size - 1) // batch_size
valid_steps = (nvalid + batch_size - 1) // batch_size
print("Total Steps: ", steps)
print("Total Validation Steps: ", valid_steps)

Total Steps:  21900
Total Validation Steps:  2643


### Data preprocessing

In [8]:
def encode(context, question, answer,
           encoder_max_len=encoder_max_len, decoder_max_len=decoder_max_len):
    """Tokenize one SQuAD (context, question, answers) example for T5.

    Runs eagerly (it calls ``.numpy()`` on its arguments), so it is meant to
    be invoked through ``tf.py_function``.

    Args:
        context: scalar string tensor holding the passage as UTF-8 bytes.
        question: scalar string tensor holding the question.
        answer: 1-D string tensor of answer texts; all entries are joined
            with ', ' into a single target.
        encoder_max_len: pad/truncate length for the input sequence.
        decoder_max_len: pad/truncate length for the target sequence.

    Returns:
        ``(input_ids, input_attention, target_ids, target_attention)``,
        produced by the global ``tokenizer`` with the leading batch
        dimension removed.
    """
    def as_text(tensor):
        return tensor.numpy().decode('utf-8')

    # Prompt layout: "answer_me: <question> context: <context>  </s>".
    # NOTE(review): '</s>' is appended by hand; depending on the transformers
    # version the tokenizer may add its own EOS too — check for duplicates.
    question_plus = (f"answer_me: {as_text(question)}"
                     f" context: {as_text(context)}  </s>")

    joined_answers = ', '.join(item.decode('utf-8') for item in answer.numpy())
    answer_plus = f"{joined_answers} </s>"

    shared_options = dict(truncation=True, return_tensors='tf',
                          padding='max_length')
    source = tokenizer(question_plus, max_length=encoder_max_len,
                       **shared_options)
    target = tokenizer(answer_plus, max_length=decoder_max_len,
                       **shared_options)

    # Each tokenizer call returns a batch of one; [0] strips that dimension.
    return (source['input_ids'][0], source['attention_mask'][0],
            target['input_ids'][0], target['attention_mask'][0])

In [9]:
def encode_tf(inputs):
    """Graph-mode wrapper around ``encode`` for use in ``Dataset.map``.

    Args:
        inputs: one raw SQuAD record providing ``'context'``, ``'question'``
            and ``['answers']['text']``.

    Returns:
        ``(features, None)``. ``features`` maps ``input_ids``, ``labels``,
        ``attention_mask`` and ``decoder_attention_mask`` to 1-D int32
        tensors.
    """
    # encode() needs eager tensors (.numpy()), so run it via py_function.
    encoded = tf.py_function(
        encode,
        [inputs['context'], inputs['question'], inputs['answers']['text']],
        [tf.int32, tf.int32, tf.int32, tf.int32],
    )
    # py_function outputs carry no static shape; declare them as 1-D.
    for tensor in encoded:
        tensor.set_shape([None])
    input_ids, input_attention, target_ids, target_attention = encoded

    features = {
        'input_ids': input_ids,
        'labels': target_ids,
        'attention_mask': input_attention,
        'decoder_attention_mask': target_attention,
    }
    return features, None

In [10]:
def create_dataset(source_dataset, cache_path = None, batch_size = 4,
                   buffer_size = 1000, shuffling = True, seed = None):
    """Build a tokenized, batched and prefetched input pipeline.

    Args:
        source_dataset: dataset of raw records accepted by ``encode_tf``.
        cache_path: if not None, passed to ``Dataset.cache`` so the encoded
            examples are tokenized only once.
        batch_size: number of examples per batch.
        buffer_size: shuffle buffer size (only used when ``shuffling``).
        shuffling: whether to shuffle examples before batching.
        seed: optional shuffle seed for a reproducible example order. The
            default None keeps the previous unseeded behavior.

    Returns:
        The resulting ``tf.data.Dataset``.
    """
    dataset = source_dataset.map(encode_tf, num_parallel_calls = tf.data.experimental.AUTOTUNE)

    # Cache before shuffling so the cached data is shuffled anew on each pass.
    if cache_path is not None:
        dataset = dataset.cache(cache_path)
    if shuffling:
        dataset = dataset.shuffle(buffer_size, seed = seed)
    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
    return dataset

In [11]:
# Forward the buffer_size hyperparameter so editing it actually changes the
# training shuffle. Validation is not shuffled, so its buffer is irrelevant.
train_ds = create_dataset(train_dataset, batch_size = batch_size,
                          buffer_size = buffer_size,
                          shuffling = True, cache_path = None)
valid_ds = create_dataset(valid_dataset, batch_size = batch_size,
                          shuffling = False, cache_path = None)

In [12]:
# Inspect one tokenized training batch (features dict, None).
data = next(iter(train_ds.take(1)))
data



({'input_ids': <tf.Tensor: shape=(4, 250), dtype=int32, numpy=
  array([[ 1525,   834,   526,    10,   262, 15420,    31,     7, 23907,
           2854,  1775,    19,   859,   149,   186,    13,   165,   773,
             16, 25101,    58,  2625,    10,    37, 23907,  2854,  1775,
             44,   262, 15420,    19,     8,   163,    80,    16, 25101,
             11,    65,     8,   508,  3620,     6,  1341,    12,     3,
             23,    52,    52,  5883,   257,   685,   173,  2197,   114,
            576,  3849,    17, 19516,  1391,    11, 18597,    29, 11638,
           3026,     5,     1,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,   