In [None]:
# --- What this run does --------------------------------------------------
train_model = True              # fit the model in the training cell
use_learning_schedule = False   # True -> CustomSchedule, False -> fixed learning rate
predict_train = False           # generate and score predictions for the training split
predict_dev = True              # generate and score predictions for the validation split

# --- What this run persists / restores -----------------------------------
save_model = False              # write the fine-tuned model after training
load_model = False              # restore weights from a previous run
save_results = False            # pickle the scored prediction frames
# NOTE(review): the two preprocessing flags are not read anywhere in the
# visible part of the notebook -- confirm whether they are still needed.
save_preprocessing = True
load_preprocessing = False

# --- Data size and training length ---------------------------------------
run_toy = True                  # work on a small head-of-dataset subset
toy_size = 1000
epochs = 3
batch_size = 8

# --- Model / experiment identity (both are used to build the data path) ---
t5_model = 't5-small'
VERSION = "train-on-drop-only"

# --- Sequence lengths and pipeline sizes ---------------------------------
warmup_steps = 10               # previously tried: 1e4
encoder_max_len = 250           # max tokens for question + passage
decoder_max_len = 54            # max tokens for the answer
buffer_size = 1000

# Using T5 on DROP

#### Package installs

#### Check GPU

In [None]:
# Print the NVIDIA driver's view of the available GPUs.
!nvidia-smi

#### Download drop_eval module and set directories

Source of the evaluation script: https://github.com/allenai/allennlp-reading-comprehension/blob/master/allennlp_rc/eval/drop_eval.py

In [None]:
!wget https://raw.githubusercontent.com/allenai/allennlp-reading-comprehension/master/allennlp_rc/eval/drop_eval.py -O drop_eval.py

#set directories
import os
if not os.path.exists('./data'):
    !mkdir data
data_dir = f"./data/{VERSION}/{t5_model}"
if not os.path.exists(data_dir):
    !mkdir $data_dir
else:
    if save_model: print(f'!!!!!{VERSION} directory already created locally -- CAUTION -- this run may overwrite existing data!!!!!')
    else: print('NOTE: save_model == FALSE, this execution will not save the model locally')
    
results_dir = f"{data_dir}/results/"
if not os.path.exists(results_dir):
    !mkdir $results_dir

log_dir = f"{data_dir}/experiments/logs"
save_path = f"{data_dir}/experiments/models"


#### Load packages

In [None]:
# import warnings
# warnings.filterwarnings('ignore')
# warnings.simplefilter('ignore')


# import logging
# logging.getLogger("tensorflow").setLevel(logging.ERROR)
# logging.getLogger("tensorflow").addHandler(logging.NullHandler(logging.ERROR))

from transformers import T5Tokenizer, TFT5ForConditionalGeneration
import tensorflow as tf
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import tensorflow.keras as keras
import drop_eval
import pandas as pd
import numpy as np
import json
from datasets import Dataset, load_dataset
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt
import datetime
import re

from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm.notebook import tqdm,trange

%load_ext tensorboard

assert len(tf.config.list_physical_devices("GPU")) > 0, "No GPU found by Tensorflow"

if(run_toy): print(f'Running on {toy_size:,} records for development run')
    
!nvcc -V

#### Define model class

In [None]:
class T5forDrop(TFT5ForConditionalGeneration):
    """T5 conditional-generation model with custom Keras train/test steps.

    Each batch is a dict that already contains ``labels``; the forward pass is
    asked for the loss directly, so no Keras loss is compiled. A running mean
    of that loss is tracked in ``self.loss_tracker`` and reported with the
    compiled metrics.
    """

    def __init__(self, *args, log_dir=None, cache_dir= None, **kwargs):
        # log_dir and cache_dir are accepted but not used by this class; they
        # are deliberately not forwarded to the parent constructor.
        super().__init__(*args, **kwargs)
        self.loss_tracker= tf.keras.metrics.Mean(name='loss')

    @tf.function
    def train_step(self, data):
        """Run one optimisation step on a batch.

        Args:
            data: dict batch; must contain ``labels`` and whatever other
                inputs the forward pass expects.

        Returns:
            Dict mapping each metric name to its current value, plus ``lr``.
        """
        x = data
        # Flatten labels to a column so they line up with the per-token logits
        # when the compiled (sparse) metrics are updated.
        y = tf.reshape(x["labels"], [-1, 1])
        with tf.GradientTape() as tape:
            outputs = self(x, training=True)
            loss = tf.reduce_mean(outputs[0])
            logits = outputs[1]

        # Differentiate after leaving the tape context so the tape does not
        # also record the gradient computation itself.
        grads = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
        # NOTE(review): _decayed_lr is a private optimizer method; confirm it
        # exists in the Keras/TensorFlow version this notebook is pinned to.
        lr = self.optimizer._decayed_lr(tf.float32)

        self.loss_tracker.update_state(loss)
        self.compiled_metrics.update_state(y, logits)
        metrics = {m.name: m.result() for m in self.metrics}
        metrics.update({'lr': lr})

        return metrics

    def test_step(self, data):
        """Evaluate one batch without updating weights.

        Args:
            data: dict batch; must contain ``labels``.

        Returns:
            Dict mapping each metric name to its current value.
        """
        x = data
        y = tf.reshape(x["labels"], [-1, 1])
        output = self(x, training=False)
        loss = tf.reduce_mean(output[0])
        logits = output[1]

        self.loss_tracker.update_state(loss)
        self.compiled_metrics.update_state(y, logits)
        return {m.name: m.result() for m in self.metrics}



#### Import model and tokenizer

In [None]:
# Tokenizer for the chosen checkpoint.
tokenizer = T5Tokenizer.from_pretrained(t5_model)

# Register every digit, plus the <ss>/<sv> markers, as special tokens so that
# numbers are tokenized one digit at a time.
numbers = {'additional_special_tokens': list('1234567890') + ['<ss>', '<sv>']}
num_tokens_added = tokenizer.add_special_tokens(numbers)

# NOTE(review): the model's embedding matrix is not resized after tokens are
# added. Confirm len(tokenizer) still fits within the checkpoint's vocabulary
# size for every checkpoint this notebook is run with.
model = T5forDrop.from_pretrained(t5_model)

#### Import data

In [None]:
def make_toy(dataset,toy_size=1000):
    """Return a new Dataset holding only the first `toy_size` rows of `dataset`.

    The rows are taken in their existing order by going through a pandas
    DataFrame and back.
    """
    head_frame = dataset.to_pandas().head(toy_size)
    return Dataset.from_pandas(head_frame)

In [None]:
# Load both DROP splits in full.
train_dataset_full = load_dataset('drop', split='train')
valid_dataset_full = load_dataset('drop', split='validation')

print('Dataset features: ',train_dataset_full.features)

#reduce data to toy size if run_toy flag is set
if run_toy:
    # Pass the configured size explicitly; without it make_toy silently falls
    # back to its own default and the toy_size setting has no effect.
    train_dataset = make_toy(train_dataset_full, toy_size=toy_size)
    valid_dataset = make_toy(valid_dataset_full, toy_size=toy_size)
else:
    train_dataset = train_dataset_full
    valid_dataset = valid_dataset_full

#check out one record
data = next(iter(valid_dataset))
print("\n\nExample data from the dataset: \n", data)

#### Set parameters

In [None]:
# Batches needed to see every record once, rounding the last partial batch up
# (integer ceiling division).
steps = -(-len(train_dataset) // batch_size)
valid_steps = -(-len(valid_dataset) // batch_size)

# Summary of the run size.
print('\n'.join([
    'Training datset size: {:,} records'.format(len(train_dataset)),
    'Validation datset size: {:,} records'.format(len(valid_dataset)),
    'Batch size: {}'.format(batch_size),
    "Total Steps: {:,}".format(steps),
    "Total Validation Steps: {:,}".format(valid_steps),
]))

#### Preprocess data

In [None]:
def encode(example,
           encoder_max_len=encoder_max_len, decoder_max_len=decoder_max_len):
    """Tokenize one DROP example into fixed-length model inputs and labels.

    The encoder text is ``"answer_me: <question> context: <passage>"``; the
    target text is all answer spans joined with ``", "``. Both are truncated
    and padded to their respective maximum lengths.

    Args:
        example: mapping with ``passage``, ``question`` and
            ``answers_spans['spans']``.
        encoder_max_len: token length of the encoder input.
        decoder_max_len: token length of the target sequence.

    Returns:
        Dict with ``input_ids``, ``attention_mask``, ``labels`` and
        ``decoder_attention_mask``, each the first (only) row of the
        tokenizer's batched output.
    """
    source_text = f"answer_me: {example['question']} context: {example['passage']}"
    target_text = ', '.join(example['answers_spans']['spans'])

    # padding='max_length' replaces the deprecated pad_to_max_length=True and
    # pads to the same fixed length.
    encoder_inputs = tokenizer(source_text, truncation=True,
                               return_tensors='tf', max_length=encoder_max_len,
                               padding='max_length')
    decoder_inputs = tokenizer(target_text, truncation=True,
                               return_tensors='tf', max_length=decoder_max_len,
                               padding='max_length')

    # NOTE(review): padding positions in `labels` keep the pad token id, so
    # they take part in the loss and metrics -- confirm this is intended.
    return {
        'input_ids': encoder_inputs['input_ids'][0],
        'attention_mask': encoder_inputs['attention_mask'][0],
        'labels': decoder_inputs['input_ids'][0],
        'decoder_attention_mask': decoder_inputs['attention_mask'][0],
    }
    
def to_tf_dataset(dataset):
    '''Wrap an encoded (Arrow-backed) dataset in a tf.data.Dataset.

    Switches `dataset` (in place) to TensorFlow output format restricted to the
    four model columns, then exposes it through a generator. Every column is
    declared as a 1-D int32 tensor of unspecified length.
    '''
    feature_names = ['input_ids', 'attention_mask', 'labels', 'decoder_attention_mask']
    dataset.set_format(type='tensorflow', columns=feature_names)

    dtypes = {name: tf.int32 for name in feature_names}
    shapes = {name: tf.TensorShape([None]) for name in feature_names}
    return tf.data.Dataset.from_generator(lambda : dataset, dtypes, shapes)

def create_dataset(dataset, cache_path=None, batch_size=batch_size, 
                   buffer_size= 1000, shuffling=True):
    '''Optionally cache and shuffle `dataset`, then return it as padded batches.

    cache_path: when given, the dataset is cached at that path first.
    buffer_size: shuffle buffer size, used only when `shuffling` is true.
    Prefetching is not applied.
    '''
    pipeline = dataset if cache_path is None else dataset.cache(cache_path)
    if shuffling:
        pipeline = pipeline.shuffle(buffer_size)
    return pipeline.padded_batch(batch_size)



In [None]:
#Preprocess data
train_ds = train_dataset.map(encode)
valid_ds = valid_dataset.map(encode)

tf_train_ds = to_tf_dataset(train_ds)
tf_train_ds = tf_train_ds.repeat(epochs)

tf_valid_ds = to_tf_dataset(valid_ds)

tf_train_ds= create_dataset(tf_train_ds, batch_size=batch_size, 
                         shuffling=True, cache_path = None)
tf_valid_ds = create_dataset(tf_valid_ds, batch_size=batch_size, 
                         shuffling=False, cache_path = None)

print('dataset schema:')
tf_train_ds.element_spec

#### Callbacks and checkpoints

In [None]:
# Model-width constant fed to the learning-rate schedule.
# NOTE(review): confirm this matches the hidden size of the checkpoint in use.
d_model = 128

class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
    """Warm-up then inverse-square-root learning-rate schedule.

    lr(step) = d_model**-0.5 * min(step**-0.5, step * warmup_steps**-1.5)

    The rate rises linearly for the first `warmup_steps` steps and then decays
    proportionally to 1/sqrt(step).
    """

    def __init__(self, d_model, warmup_steps=4000):
        super(CustomSchedule, self).__init__()

        self.d_model = tf.cast(d_model, tf.float32)
        self.warmup_steps = warmup_steps

    def __call__(self, step):
        # Optimizers may pass their integer iteration counter; rsqrt needs a
        # floating-point input.
        step = tf.cast(step, tf.float32)
        arg1 = tf.math.rsqrt(step)
        arg2 = step * (self.warmup_steps ** -1.5)

        return tf.math.rsqrt(self.d_model) * tf.math.minimum(arg1, arg2)

    def get_config(self):
        """Return the constructor arguments so the schedule can be serialized."""
        return {"d_model": float(self.d_model), "warmup_steps": self.warmup_steps}
    
# Preview the schedule (with its default warm-up) over the first 40,000 steps.
temp_learning_rate_schedule = CustomSchedule(d_model)

plt.plot(
    temp_learning_rate_schedule(tf.range(40000, dtype=tf.float32))
)
plt.ylabel("Learning Rate")
plt.xlabel("Train Step")

In [None]:
# Profile a window of 100 batches starting 10 batches after the first epoch's
# worth of steps.
start_profile_batch = steps+10
stop_profile_batch = start_profile_batch + 100
profile_range = f"{start_profile_batch},{stop_profile_batch}"

# One TensorBoard log directory per run, named by start time.
log_path = log_dir + "/" + datetime.datetime.now().strftime("%Y-%m-%d_%H:%M:%S")
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_path, histogram_freq=1,
                                                     update_freq=20,profile_batch=profile_range)

# Keep only checkpoints that improve the validation loss.
checkpoint_filepath = save_path + "/" + "T5-{epoch:04d}-{val_loss:.4f}.ckpt"
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath,
    save_weights_only=False,
    monitor='val_loss',
    mode='min',
    save_best_only=True)

# Checkpointing writes model files to disk, so it is enabled only when the
# run is configured to save the model.
callbacks = [tensorboard_callback]
if save_model:
    callbacks.append(model_checkpoint_callback)

# Reported under the name 'accuracy', but this is top-5 token accuracy.
metrics = [tf.keras.metrics.SparseTopKCategoricalAccuracy(k=5,name='accuracy') ]#[drop_eval.get_metrics]

#### Compile and run model

In [None]:
# Learning rate: either the warm-up schedule or a fixed value.
if use_learning_schedule:
    # Pass the configured warm-up length; otherwise the schedule silently uses
    # its own default and the warmup_steps setting has no effect.
    learning_rate = CustomSchedule(d_model, warmup_steps=warmup_steps)
else:
    learning_rate = 0.0005

optimizer = tf.keras.optimizers.Adam(learning_rate, beta_1=0.9, beta_2=0.98,
                                     epsilon=1e-9)
# No loss is passed to compile: the model's train/test steps take the loss
# from the forward pass.
model.compile(optimizer=optimizer, metrics=metrics)
model.summary()

In [None]:
# Open TensorBoard inline, reading from the experiment log directory.
%tensorboard --logdir $log_dir

In [None]:
# Fine-tune for `epochs` epochs of `steps` batches each, validating after every
# epoch; optionally persist the result with save_pretrained.
if train_model:
    model.fit(
        tf_train_ds,
        epochs=epochs,
        steps_per_epoch=steps,
        validation_data=tf_valid_ds,
        validation_steps=valid_steps,
        callbacks=callbacks,
        verbose=1,
    )
    if save_model:
        model.save_pretrained(save_path)
        print('Training complete, model saved')

In [None]:
# Restore previously saved weights from the model directory when requested.
if load_model:
    model.load_weights(f"{save_path}/tf_model.h5")
    print(f'Model loaded from {save_path}')

#### Predict & Evaluate

In [None]:
def batch_predict(ds,model,tokenizer, **generate_kwargs):
    """Generate and decode one prediction string per example in `ds`.

    Args:
        ds: iterable of dict batches; each must contain ``input_ids`` and may
            contain ``attention_mask``.
        model: object with a ``generate(input_ids, **kwargs)`` method.
        tokenizer: object with a ``decode(ids)`` method.
        **generate_kwargs: extra keyword arguments forwarded to
            ``model.generate`` (for example a maximum output length).

    Returns:
        List of cleaned prediction strings, in iteration order of `ds`.
    """
    preds = []

    # The bar counts examples as they are decoded. No total is given: working
    # it out would need a full extra pass over the dataset.
    with tqdm(desc='predicting', unit='ex') as bar:
        for batch in ds:
            call_kwargs = dict(generate_kwargs)
            if 'attention_mask' in batch:
                # Inputs are padded; hand the mask over explicitly unless the
                # caller supplied one.
                call_kwargs.setdefault('attention_mask', batch['attention_mask'])
            output = model.generate(batch['input_ids'], **call_kwargs)

            for i in range(output.shape[0]):
                single_pred = tokenizer.decode(output[i])
                # Strip these two markers by hand rather than skipping all
                # special tokens: the digits are special tokens too and must
                # survive decoding.
                single_pred = single_pred.replace('<pad>','').replace('</s>','').strip()
                # Join digits that were decoded as separate tokens. Lookarounds
                # do not consume the digits, so runs of any length collapse
                # ("1 2 3 4" -> "1234"), not just pairs.
                single_pred = re.sub(r'(?<=\d)\s+(?=\d)', '', single_pred)
                preds.append(single_pred)
                bar.update(1)
    return preds

def evaluate(df):
    """Score each prediction against its gold spans and add EM/F1 columns.

    For every row, the prediction is scored against each gold span with
    ``drop_eval.get_metrics`` and the best exact-match and best F1 are kept.
    The two maxima are tracked independently, so an exact match is never lost
    because an earlier span had already reached the same F1.

    Args:
        df: DataFrame with a ``predicted`` column and an ``answers_spans``
            column whose values have a ``spans`` entry.

    Returns:
        The same DataFrame (modified in place) with ``EM`` and ``F1`` columns.
        The mean of both is printed.
    """
    EM = []
    F1 = []

    for predicted,gold in tqdm(zip(df['predicted'],df['answers_spans']), total=len(df)):
        best_EM = 0
        best_F1 = 0

        for potential_answer in gold['spans']:
            scores = drop_eval.get_metrics(predicted=predicted,gold=potential_answer)
            best_EM = max(best_EM, scores[0])
            best_F1 = max(best_F1, scores[1])

        EM.append(best_EM)
        F1.append(best_F1)

    df['EM'] = EM
    df['F1'] = F1

    print('Exact Match: {:0.4f}, F1: {:0.4f}'.format(df.EM.mean(),df.F1.mean()))
    return df


In [None]:
# Display the predict_train flag so the output shows whether the train-set
# branch of the next cell will run.
predict_train

In [None]:
if predict_train:
    print('Making Train Predictions...')
    # tf_train_ds is shuffled and repeated for training, so its batches neither
    # follow the row order of train_dataset nor cover it exactly once. Predict
    # on an ordered, single-pass view of the same encoded data instead.
    train_predict_ds = create_dataset(to_tf_dataset(train_ds), batch_size=batch_size,
                                      shuffling=False, cache_path=None)
    preds = batch_predict(ds=train_predict_ds,model=model,tokenizer=tokenizer)
    train_df = train_dataset.to_pandas()
    assert len(train_df) == len(preds), "count mismatch, something went wrong"
    train_df['predicted'] = preds
    print('Evaluating Train Predictions...')
    train_df = evaluate(train_df)
    if save_results:
        # %b (abbreviated month name) is the portable spelling of %h.
        train_results_path = results_dir+'drop_train'+datetime.datetime.now().strftime('%H%M-%b%d')+'.pkl'
        train_df.to_pickle(train_results_path)
        print('results for predictions on the training data saved to:\n',train_results_path)

if predict_dev:
    print('Making Dev Predictions...')
    preds = batch_predict(ds=tf_valid_ds,model=model,tokenizer=tokenizer)
    valid_df = valid_dataset.to_pandas()
    # Check the count before assigning so a mismatch is reported by this
    # message rather than by the column assignment.
    assert len(valid_df) == len(preds), "count mismatch, something went wrong"
    valid_df['predicted'] = preds
    print('Evaluating Dev Predictions...')
    valid_df = evaluate(valid_df)
    if save_results:
        valid_results_path = results_dir+'drop_validation'+datetime.datetime.now().strftime('%H%M-%b%d')+'.pkl'
        valid_df.to_pickle(valid_results_path)
        print('results for predictions on the validation data saved to:\n',valid_results_path)
 

In [None]:
# Show up to 10 random scored validation rows; the size is capped at the frame
# length because sampling more rows than exist raises an error.
valid_df[['query_id','passage','question','answers_spans','predicted','EM','F1']].sample(min(10, len(valid_df)))

In [None]:
def print_example(query_id,df):
    """Print question, passage, prediction, gold answers and scores for one query.

    Args:
        query_id: value to look up in ``df.query_id``; the first matching row
            is used.
        df: DataFrame with ``query_id``, ``question``, ``passage``,
            ``predicted``, ``answers_spans``, ``F1`` and ``EM`` columns.

    Raises:
        IndexError: if no row has the given ``query_id``.
    """
    matches = df.loc[df.query_id == query_id]
    if matches.empty:
        raise IndexError(f"query_id {query_id!r} not found in the DataFrame")

    # Look the row up once instead of filtering the frame for every field.
    row = matches.iloc[0]
    print('question: ',row['question'])
    print('passage: ',row['passage'])
    print('\npredicted answer: ',row['predicted'])
    print('True answers: ',row['answers_spans'])
    print('F1 score: ',row['F1'])
    print('EM score: ',row['EM'])
    
    
# Hard-coded example to inspect. The lookup raises if this id is not present
# in valid_df (presumably possible when only a subset of the split is loaded
# -- verify).
query_id = '0686d1f9-4a8e-4031-b665-49d425afb777'
print_example(query_id,valid_df)

In [None]:
# Second hard-coded example; the id must be present in valid_df.
query_id = '86dd1721-6bf4-45fa-b01e-de47e4f7301d'
print_example(query_id,valid_df)

In [None]:
# Third hard-coded example; the id must be present in valid_df.
query_id = 'ad19857f-cd76-4d01-ba29-1a589cfee053'
print_example(query_id,valid_df)

In [None]:
# Print the layer/parameter summary again (also shown right after compile).
model.summary()

# Cells to explore the model a bit

(Convert the raw cells below into code cells to run them and explore the model's inputs and outputs.)

In [None]:
# Take the first validation batch only and run one forward pass on it; the
# loop exits after a single iteration.
for item in tf_valid_ds:
    inputs = item
    outputs = model(inputs)
    break

In [None]:
# Names of the tensors in the batch that was fed to the model.
inputs.keys()

In [None]:
# Class of the object returned by the forward pass.
type(outputs)

In [None]:
# Fields available on the forward-pass output.
outputs.keys()

In [None]:
# Display the logits of the forward pass for the sampled batch.
outputs['logits']