In [1]:
import time
import datetime

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import tensorflow_datasets as tfds


from transformers import (TFAutoModelWithLMHead, AutoTokenizer, 
    TFTrainer, TFTrainingArguments, TFT5ForConditionalGeneration, T5Config)

In [2]:
# Show the TensorFlow version in use so the run can be reproduced later.
tf.__version__

'2.5.0-dev20201029'

### Define Model

In [3]:
class TFT5(TFT5ForConditionalGeneration):
    """T5 conditional-generation model with hand-written Keras train/eval steps.

    Both steps read the targets from the ``"labels"`` entry of the feature
    mapping, call the model on that mapping, average the loss it returns and
    feed that value into ``loss_tracker``. Compiled metrics are updated with
    the flattened labels and the logits.
    """

    def __init__(self, *args, log_dir=None, cache_dir=None, **kwargs):
        """Build the parent model and create the running-mean loss metric.

        ``log_dir`` and ``cache_dir`` are accepted by this constructor but are
        neither used here nor forwarded to the parent constructor.
        """
        super().__init__(*args, **kwargs)
        self.loss_tracker = tf.keras.metrics.Mean(name='loss')

    @tf.function
    def train_step(self, data):
        """Run one optimisation step on a batch.

        Args:
            data: A 2-tuple; only its first element (a mapping that must
                contain ``"labels"``) is used, the second is ignored.

        Returns:
            Dict of ``metric name -> current result`` for every metric in
            ``self.metrics``, plus the optimizer's decayed learning rate
            under the key ``'lr'``.
        """
        features, _ = data
        # Labels are flattened to a single column before being handed to the
        # compiled metrics.
        targets = tf.reshape(features["labels"], [-1, 1])

        with tf.GradientTape() as tape:
            model_out = self(features, training=True)
            # First output is treated as the loss, second as the logits.
            batch_loss = tf.reduce_mean(model_out[0])
            logits = model_out[1]

        gradients = tape.gradient(batch_loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
        # Read the learning rate only after the optimizer step has been applied.
        current_lr = self.optimizer._decayed_lr(tf.float32)

        self.loss_tracker.update_state(batch_loss)
        self.compiled_metrics.update_state(targets, logits)

        results = {metric.name: metric.result() for metric in self.metrics}
        results['lr'] = current_lr
        return results

    def test_step(self, data):
        """Evaluate one batch; no gradients are computed or applied.

        Args:
            data: A 2-tuple; only its first element (a mapping that must
                contain ``"labels"``) is used, the second is ignored.

        Returns:
            Dict of ``metric name -> current result`` for every metric in
            ``self.metrics``.
        """
        features, _ = data
        targets = tf.reshape(features["labels"], [-1, 1])

        model_out = self(features, training=False)
        # First output is treated as the loss, second as the logits.
        batch_loss = tf.reduce_mean(model_out[0])
        logits = model_out[1]

        self.loss_tracker.update_state(batch_loss)
        self.compiled_metrics.update_state(targets, logits)
        return {metric.name: metric.result() for metric in self.metrics}

### Tokenizer

In [4]:
# Instantiate the tokenizer from the "t5-base" pretrained identifier.
tokenizer = AutoTokenizer.from_pretrained("t5-base")

### Dataset

In [5]:
# Load the 'train' and 'validation' splits of the 'squad' dataset through
# TensorFlow Datasets. No config is named here, so TFDS picks its first one
# (the log output below reports squad/v1.1).
train_dataset = tfds.load('squad', split = 'train', with_info = False)
valid_dataset = tfds.load('squad', split = 'validation', with_info = False)

INFO:absl:No config specified, defaulting to first: squad/v1.1
INFO:absl:Load dataset info from /home/mirac13/tensorflow_datasets/squad/v1.1/2.0.0
INFO:absl:Reusing dataset squad (/home/mirac13/tensorflow_datasets/squad/v1.1/2.0.0)
INFO:absl:Constructing tf.data.Dataset for split train, from /home/mirac13/tensorflow_datasets/squad/v1.1/2.0.0
INFO:absl:No config specified, defaulting to first: squad/v1.1
INFO:absl:Load dataset info from /home/mirac13/tensorflow_datasets/squad/v1.1/2.0.0
INFO:absl:Reusing dataset squad (/home/mirac13/tensorflow_datasets/squad/v1.1/2.0.0)
INFO:absl:Constructing tf.data.Dataset for split validation, from /home/mirac13/tensorflow_datasets/squad/v1.1/2.0.0


In [6]:
# Pull the first record from the training split and print it to inspect the
# raw structure (nested 'answers', plus 'context', 'id', 'question', ...).
data = next(iter(train_dataset))
print("Example data from the dataset: \n", data)

Example data from the dataset: 
 {'answers': {'answer_start': <tf.Tensor: shape=(1,), dtype=int32, numpy=array([427], dtype=int32)>, 'text': <tf.Tensor: shape=(1,), dtype=string, numpy=array([b'mobile phones'], dtype=object)>}, 'context': <tf.Tensor: shape=(), dtype=string, numpy=b'The difference in the above factors for the case of \xce\xb8=0 is the reason that most broadcasting (transmissions intended for the public) uses vertical polarization. For receivers near the ground, horizontally polarized transmissions suffer cancellation. For best reception the receiving antennas for these signals are likewise vertically polarized. In some applications where the receiving antenna must work in any position, as in mobile phones, the base station antennas use mixed polarization, such as linear polarization at an angle (with both vertical and horizontal components) or circular polarization.'>, 'id': <tf.Tensor: shape=(), dtype=string, numpy=b'57306bf68ab72b1400f9c4dc'>, 'question': <tf.Tensor: 