In [16]:
# Install tf-transformers from github

### T5 + Squad for Span Selection (not Generation)

* The purpose of this tutorial is to demonstrate how we can customize the **decoder** for other tasks
* that it was not originally trained for.

**Note** - T5 works well when we formulate most or all tasks as Text Generation

In [None]:
import collections
import glob
import json
import time

import tensorflow as tf
from absl import logging
from transformers import T5Tokenizer

from tf_transformers.utils.tokenization import BasicTokenizer, SPIECE_UNDERLINE
from tf_transformers.utils import fast_sp_alignment
from tf_transformers.data.squad_utils_sp import (
    read_squad_examples,
    post_clean_train_squad,
    example_to_features_using_fast_sp_alignment_train,
    example_to_features_using_fast_sp_alignment_test, 
    evaluate_v1,
    _get_best_indexes,
)
from tf_transformers.data import TFWriter, TFReader, TFProcessor
from tf_transformers.models import T5Model
from tf_transformers.core import optimization, SimpleTrainer
from tf_transformers.tasks import Span_Selection_Model
logging.set_verbosity("INFO")

### Load Tokenizer

In [5]:
# Pretrained `t5-small` tokenizer from HuggingFace.
tokenizer = T5Tokenizer.from_pretrained('t5-small')
# Case-preserving basic tokenizer, handed to the SQuAD cleaning helper further down.
basic_tokenizer = BasicTokenizer(do_lower_case=False)

# The SQuAD processing utilities read `sep_token` and `cls_token`, so both are
# pointed at the tokenizer's EOS token.
tokenizer.sep_token = tokenizer.eos_token
tokenizer.cls_token = tokenizer.eos_token

### Convert train data to Features

* using Fast Sentence Piece Alignment, we convert text to features (text -> list of sub words)

In [3]:
# NOTE(review): absolute, machine-specific path; adjust to your local copy of SQuAD v1.1.
input_file_path = '/mnt/home/PRE_MODELS/HuggingFace_models/datasets/squadv1.1/train-v1.1.json'

is_training = True

# 1. Read the raw examples and report how long the read took.
start_time = time.time()
train_examples = read_squad_examples(
    input_file=input_file_path,
    is_training=is_training,
    version_2_with_negative=False,
)
end_time = time.time()
print('Time taken {}'.format(end_time-start_time))

# 2. Post-process (clean the text to avoid some unwanted unicode characters).
train_examples_processed, failed_examples = post_clean_train_squad(
    train_examples, basic_tokenizer, is_training=is_training
)

# 3. Convert question, context and answer to features (tokenized sub-words,
#    not yet vocabulary indices). The helper returns a generator, so it is
#    materialised into a list here.
feature_generator = example_to_features_using_fast_sp_alignment_train(
    tokenizer,
    train_examples_processed,
    max_seq_length=384,
    max_query_length=64,
    doc_stride=128,
    SPECIAL_PIECE=SPIECE_UNDERLINE,
)
all_features = list(feature_generator)

# This second timing is measured from the same `start_time`, i.e. it covers
# reading, cleaning and feature conversion together.
end_time = time.time()
print("time taken {} seconds".format(end_time-start_time))

INFO:absl:Time taken 0.06583905220031738


Time taken 0.7573883533477783
time taken 1.051743984222412 seconds


### Convert features to TFRecords using TFWriter

In [7]:
# Convert tokens to id and add type_ids
# input_mask etc
# This is user specific/ tokenizer specific
# Eg: Roberta has input_type_ids = 0, BERT has input_type_ids = [0, 1]

def parse_train():
    """Yield one TFRecord-ready training example per entry of ``all_features``.

    Each feature's token list is split at the first ``tokenizer.sep_token``:
    the part before it (prefixed with the ``'question: '`` prompt) becomes the
    encoder input, the part after it (prefixed with the ``'context: '``
    prompt) becomes the decoder input. The answer span is re-expressed as
    positions inside the decoder input.

    Reads the notebook globals ``all_features`` and ``tokenizer``.

    Yields:
        dict: a fresh dict per feature with the keys ``encoder_input_ids``,
        ``encoder_input_mask``, ``decoder_input_ids``, ``decoder_input_mask``,
        ``start_position`` and ``end_position``.

    Raises:
        AssertionError: if the shifted span does not select the same tokens
            as the original span.
    """
    # The prompts do not depend on the feature, so tokenize them once.
    question_prompt = tokenizer.tokenize('question: ')
    passage_context_prompt = tokenizer.tokenize('context: ')

    for f in all_features:
        # NOTE(review): `cls_token` and `sep_token` are both set to the EOS
        # token in this notebook. If f['input_ids'][0] is that CLS token,
        # .index() returns 0 and the question tokens end up on the decoder
        # side -- confirm against the feature layout.
        question_sep_index = f['input_ids'].index(tokenizer.sep_token)
        # Slice from 1 to skip the leading CLS token.
        question_ids = question_prompt + f['input_ids'][1:question_sep_index + 1]
        # The context prompt is prepended to everything after the separator.
        passage_ids = passage_context_prompt + f['input_ids'][question_sep_index + 1:]

        # Token k of f['input_ids'] (k > question_sep_index) sits at
        # k - (question_sep_index + 1) + len(passage_context_prompt) in
        # passage_ids. Computing the shift this way does not depend on how
        # many pieces the question prompt tokenizes to.
        position_shift = len(passage_context_prompt) - (question_sep_index + 1)
        new_start_position = f['start_position'] + position_shift
        new_end_position = f['end_position'] + position_shift

        assert (
            f['input_ids'][f['start_position']: f['end_position']]
            == passage_ids[new_start_position: new_end_position]
        ), "answer span moved while re-basing it onto the decoder input"

        encoder_input_ids = tokenizer.convert_tokens_to_ids(question_ids)
        decoder_input_ids = tokenizer.convert_tokens_to_ids(passage_ids)

        # Build a new dict on every iteration so consumers that keep a
        # reference to a yielded example never see it overwritten.
        yield {
            'encoder_input_ids': encoder_input_ids,
            'encoder_input_mask': [1] * len(encoder_input_ids),
            'decoder_input_ids': decoder_input_ids,
            'decoder_input_mask': [1] * len(decoder_input_ids),
            'start_position': new_start_position,
            'end_position': new_end_position,
        }
        

# Write the parsed examples to TFRecords with TFWriter.
# (TFProcessor is the alternative for smaller data.)

# Every field is declared as a variable-length list of ints.
schema = {
    field_name: ("var_len", "int")
    for field_name in (
        'encoder_input_ids',
        'encoder_input_mask',
        'decoder_input_ids',
        'decoder_input_mask',
        'start_position',
        'end_position',
    )
}

tfrecord_train_dir = '../OFFICIAL_TFRECORDS/squad/t5_span/train'
tfrecord_filename = 'squad'
tfwriter = TFWriter(
    schema=schema,
    file_name=tfrecord_filename,
    model_dir=tfrecord_train_dir,
    tag='train',
    overwrite=True,
)
# `parse_train()` is a generator; the writer consumes it.
tfwriter.process(parse_fn=parse_train())

INFO:absl:Total individual observations/examples written is 100
INFO:absl:All writer objects closed


### Read TFRecords using TFReader

In [11]:
# Read Data
# The schema is loaded from `schema.json` in the TFRecord directory. Use a
# context manager so the file handle is closed deterministically.
with open("{}/schema.json".format(tfrecord_train_dir)) as schema_file:
    schema = json.load(schema_file)
all_files = glob.glob("{}/*.tfrecord".format(tfrecord_train_dir))
tf_reader = TFReader(schema=schema, 
                    tfrecord_files=all_files)

# x_keys are passed as the model inputs, y_keys as the labels.
x_keys = ['encoder_input_ids', 'encoder_input_mask', 'decoder_input_ids', 'decoder_input_mask']
y_keys = ['start_position', 'end_position']
batch_size = 32
train_dataset = tf_reader.read_record(auto_batch=True, 
                                   keys=x_keys,
                                   batch_size=batch_size, 
                                   x_keys = x_keys, 
                                   y_keys = y_keys,
                                   shuffle=True, 
                                   drop_remainder=True
                                  )



In [None]:
# Peek at a single (inputs, labels) batch to check the dataset structure.
for batch_inputs, batch_labels in train_dataset.take(1):
    print(batch_inputs, batch_labels)

### Load t5 Model

In [13]:
import tensorflow as tf

# Build the `t5-small` architecture, then restore the pretrained weights from
# the local checkpoint directory.
# NOTE(review): absolute, machine-specific checkpoint path.
model_layer, model, config = T5Model(
    model_name='t5-small',
    decoder_mask_mode='user_defined',
)
model.load_checkpoint("/mnt/home/PRE_MODELS/LegacyAI_models/checkpoints/t5-small/")

INFO:absl:We are overwriding `is_training` is False to `is_training` to True with `use_dropout` is False, no effects on your inference pipeline


Kwargs {}


INFO:absl:Initialized Variables
INFO:absl:Succesful: Model checkpoints matched


### Load Span Selection Model

In [15]:

# Wrap the T5 model in the span-selection task and get the model that is
# trained below; its outputs are read as `start_logits` / `end_logits`.
span_selection_layer = Span_Selection_Model(
    model=model,
    is_training=True,
)
span_selection_model = span_selection_layer.get_model()

In [18]:
# Drop the intermediate objects to save memory; only `span_selection_model`
# is used from here on.
del model, model_layer, span_selection_layer

### Define Loss

Loss function is simple.
* labels: 1D (batch_size) # start or end positions
* logits: 2D (batch_size x sequence_length)


In [31]:
def span_loss(position, logits):
    """Mean sparse softmax cross-entropy between positions and logits.

    Args:
        position: integer start or end positions, one per example. Any
            shape holding exactly ``batch_size`` values is accepted.
        logits: 2D tensor (batch_size x sequence_length).

    Returns:
        Scalar tensor: the loss averaged over the batch.
    """
    # Flatten explicitly to 1-D. `tf.squeeze` would also drop the batch axis
    # for a batch of one and hand a scalar label to the loss.
    labels = tf.reshape(position, [-1])
    per_example_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=labels, logits=logits)
    return tf.reduce_mean(per_example_loss)

def start_loss(y_true_dict, y_pred_dict):
    """Span loss for the start position.

    Reads ``'start_position'`` from the label dict and ``'start_logits'`` from
    the model-output dict and passes them to ``span_loss``.
    """
    true_positions = y_true_dict['start_position']
    predicted_logits = y_pred_dict['start_logits']
    return span_loss(true_positions, predicted_logits)

def end_loss(y_true_dict, y_pred_dict):
    """Span loss for the end position.

    Reads ``'end_position'`` from the label dict and ``'end_logits'`` from the
    model-output dict and passes them to ``span_loss``.
    """
    true_positions = y_true_dict['end_position']
    predicted_logits = y_pred_dict['end_logits']
    return span_loss(true_positions, predicted_logits)


def joint_loss(y_true_dict, y_pred_dict):
    """Average of the start-position and end-position span losses.

    Args:
        y_true_dict: label dict with ``'start_position'`` and ``'end_position'``.
        y_pred_dict: model-output dict with ``'start_logits'`` and ``'end_logits'``.

    Returns:
        ``(start span loss + end span loss) / 2.0``.
    """
    # Local names are chosen so they do not shadow the module-level
    # `start_loss` / `end_loss` functions.
    loss_at_start = span_loss(y_true_dict['start_position'], y_pred_dict['start_logits'])
    loss_at_end = span_loss(y_true_dict['end_position'], y_pred_dict['end_logits'])
    return (loss_at_start + loss_at_end)/2.0

### Define Optimizer

In [29]:
# Training schedule constants, consumed by the SimpleTrainer call below.
EPOCHS = 4
train_data_size = 89000

# AdamW optimizer with a fixed learning rate.
optimizer = optimization.AdamWeightDecay(learning_rate=0.001)

INFO:absl:using Adamw optimizer


### Train Using Keras :-)

- ```compile2``` lets you pass the model outputs as well as the batch dataset outputs directly into the loss function, without any further complexity.

Note: For ```compile2```, ```loss``` must be None and ```custom_loss``` must be provided. Metrics are not supported for the time being.

In [27]:
# Keras fit.
# One custom loss per model output; each loss receives the label dict and the
# model-output dict.
keras_loss_fn = {
    'start_logits': start_loss,
    'end_logits': end_loss,
}
span_selection_model.compile2(
    optimizer=optimizer,
    loss=None,
    custom_loss=keras_loss_fn,
    run_eagerly=False,
)
# Short demonstration run: 2 epochs of 10 steps each.
history = span_selection_model.fit(train_dataset, epochs=2, steps_per_epoch=10)

Epoch 1/2
Start logits (16, None)








Start logits (16, None)














<tensorflow.python.keras.callbacks.History at 0x7f90801d0c70>

### Train using SimpleTrainer (part of tf-transformers)

In [None]:
# Custom training loop provided by tf-transformers.
history = SimpleTrainer(
    model=span_selection_model,
    optimizer=optimizer,
    loss_fn=joint_loss,
    dataset=train_dataset.repeat(EPOCHS+1), # This is important
    epochs=EPOCHS,
    num_train_examples=train_data_size,
    batch_size=batch_size,
    steps_per_call=100,
)

INFO:absl:Global Steps 165
  0%|          | 0/165 [00:00<?, ?it/s]









### Save Models 

You can save models as checkpoints using the ```.save_checkpoint``` method, which is a part of all ```LegacyModels```.

In [9]:
# Directory that receives the fine-tuned checkpoint; later cells also export
# the SavedModel and TFLite files underneath it.
model_save_dir = '../OFFICIAL_MODELS/squad/t5_span_selection'
span_selection_model.save_checkpoint(
    model_save_dir,
    overwrite=True,
)

INFO:absl:Succesful: Model checkpoints matched


### Parse validation data

We use ```TFProcessor``` to create validation data, because dev data is small

In [8]:
# NOTE(review): absolute, machine-specific path; adjust to your local copy of SQuAD v1.1.
dev_input_file_path = '/mnt/home/PRE_MODELS/HuggingFace_models/datasets/squadv1.1/dev-v1.1.json'

is_training = False

# Read the dev examples and report how long the read took.
start_time = time.time()
dev_examples = read_squad_examples(
    input_file=dev_input_file_path,
    is_training=is_training,
    version_2_with_negative=False,
)
end_time = time.time()
print('Time taken {}'.format(end_time-start_time))

# Clean the text, then convert to features. Unlike the train helper, the test
# helper returns a pair: (per-question info, list of features).
dev_examples_cleaned = post_clean_train_squad(dev_examples, basic_tokenizer, is_training=False)
qas_id_info, dev_features = example_to_features_using_fast_sp_alignment_test(
    tokenizer,
    dev_examples_cleaned,
    max_seq_length=384,
    max_query_length=64,
    doc_stride=128,
    SPECIAL_PIECE=SPIECE_UNDERLINE,
)


def parse_dev():
    """Yield one model-input dict per entry of ``dev_features``.

    Mirrors the training parser without labels: tokens before the first
    ``tokenizer.sep_token`` (prefixed with the ``'question: '`` prompt) go to
    the encoder, tokens after it (prefixed with the ``'context: '`` prompt) go
    to the decoder.

    Side effect: stores the decoder input ids on each feature as
    ``f['passage_ids']`` so predicted spans can be decoded later.

    Reads the notebook globals ``dev_features`` and ``tokenizer``.

    Yields:
        dict: a fresh dict per feature with the keys ``encoder_input_ids``,
        ``encoder_input_mask``, ``decoder_input_ids`` and
        ``decoder_input_mask``.
    """
    # The prompts do not depend on the feature, so tokenize them once.
    question_prompt = tokenizer.tokenize('question: ')
    passage_context_prompt = tokenizer.tokenize('context: ')

    for f in dev_features:
        # NOTE(review): `cls_token` and `sep_token` are both set to the EOS
        # token in this notebook. If f['input_ids'][0] is that CLS token,
        # .index() returns 0 and the question tokens end up on the decoder
        # side -- confirm against the feature layout.
        question_sep_index = f['input_ids'].index(tokenizer.sep_token)
        # Slice from 1 to skip the leading CLS token.
        question_ids = question_prompt + f['input_ids'][1:question_sep_index + 1]
        # The context prompt is prepended to everything after the separator.
        passage_ids = passage_context_prompt + f['input_ids'][question_sep_index + 1:]

        encoder_input_ids = tokenizer.convert_tokens_to_ids(question_ids)
        decoder_input_ids = tokenizer.convert_tokens_to_ids(passage_ids)

        f['passage_ids'] = decoder_input_ids  # indices

        # Build a new dict on every iteration so consumers that keep a
        # reference to a yielded example never see it overwritten.
        yield {
            'encoder_input_ids': encoder_input_ids,
            'encoder_input_mask': [1] * len(encoder_input_ids),
            'decoder_input_ids': decoder_input_ids,
            'decoder_input_mask': [1] * len(decoder_input_ids),
        }
        

# The dev set is small, so TFProcessor is used instead of writing TFRecords;
# the processed dataset is then batched.
tf_processor = TFProcessor()
dev_dataset = tf_processor.auto_batch(
    tf_processor.process(parse_fn=parse_dev()),
    batch_size=32,
)

Time taken 0.07536649703979492


### Evaluate Exact Match

* Make Predictions
* Extract Answers
* Evaluate

### Make Batch Predictions

In [28]:
def extract_from_dict(dict_items, key):
    """Collect ``item[key]`` from every dict in ``dict_items``, in order.

    Args:
        dict_items: iterable of mappings.
        key: key looked up in each mapping.

    Returns:
        list: one value per item.

    Raises:
        KeyError: if an item does not contain ``key``.
    """
    return [item[key] for item in dict_items]
# Per-feature bookkeeping, aligned index-by-index with `dev_features`.
qas_id_list = extract_from_dict(dev_features, 'qas_id')
doc_offset_list = extract_from_dict(dev_features, 'doc_offset')

# Run the model over the dev set batch by batch, keeping the batched start
# and end logits.
per_layer_start_logits = []
per_layer_end_logits = []

start_time = time.time()
for batch_inputs in dev_dataset:
    model_outputs = span_selection_model(batch_inputs)
    per_layer_start_logits.append(model_outputs['start_logits'])
    per_layer_end_logits.append(model_outputs['end_logits'])

end_time = time.time()
print('Time taken {}'.format(end_time-start_time))

# Post-processing settings: number of top start/end candidates per feature
# and the maximum answer length in tokens.
n_best_size = 20
max_answer_length = 30
# Ground-truth dev data for the final evaluation. Use a context manager so
# the file handle is closed.
with open(dev_input_file_path) as dev_file:
    squad_dev_data = json.load(dev_file)['data']
layer_results = []

# Unstack the batched logits into one tensor per feature.
start_logits_unstcaked = []
end_logits_unstacked = []
for batch_start_logits in per_layer_start_logits:
    start_logits_unstcaked.extend(tf.unstack(batch_start_logits))
for batch_end_logits in per_layer_end_logits:
    end_logits_unstacked.extend(tf.unstack(batch_end_logits))

# Group the per-feature outputs by question id: a long context is split into
# several features that all share one `qas_id`.
qas_id_logits = {}
for i, qas_id in enumerate(qas_id_list):
    example = qas_id_info[qas_id]
    feature = dev_features[i]
    assert qas_id == feature['qas_id']
    grouped = qas_id_logits.setdefault(qas_id, {
        'feature_length': [],
        'doc_offset': [],
        'passage_start_pos': [],
        'start_logits': [],
        'end_logits': [],
    })
    # Length of the decoder input for this feature; logits beyond it are
    # sliced off in the next step.
    grouped['feature_length'].append(len(feature['passage_ids']))
    grouped['doc_offset'].append(doc_offset_list[i])
    # The decoder input holds no question tokens to skip, so the offset is 0.
    grouped['passage_start_pos'].append(0)
    grouped['start_logits'].append(start_logits_unstcaked[i])
    grouped['end_logits'].append(end_logits_unstacked[i])
    
        
    


# One candidate answer span: the feature (chunk) it came from, its inclusive
# token boundaries inside that feature's `passage_ids`, and the start/end
# scores. Defined once rather than once per question.
_PrelimPrediction = collections.namedtuple(  # pylint: disable=invalid-name
    "PrelimPrediction",
    ["feature_index", "start_index", "end_index",
     "start_log_prob", "end_log_prob"])

qas_id_answer = {}
skipped = []
skipped_null = []
# Running index into `dev_features`; relies on the features of one question
# being contiguous and in the same order as `qas_id_logits` (asserted below).
global_counter = 0
p_texts = []
for qas_id in qas_id_logits:

    current_example = qas_id_logits[qas_id]

    prelim_predictions = []
    example_features = []
    for i in range(len(current_example['start_logits'])):
        f = dev_features[global_counter]
        assert f['qas_id'] == qas_id
        example_features.append(f)
        global_counter += 1
        feature_length = current_example['feature_length'][i]  # non-masked length

        # Keep only the scores that correspond to this feature's own tokens.
        start_log_prob_list = current_example['start_logits'][i].numpy().tolist()[:feature_length]
        end_log_prob_list = current_example['end_logits'][i].numpy().tolist()[:feature_length]
        start_indexes = _get_best_indexes(start_log_prob_list, n_best_size)
        end_indexes = _get_best_indexes(end_log_prob_list, n_best_size)

        for start_index in start_indexes:
            for end_index in end_indexes:
                # Throw out invalid candidates: an end before the start, or
                # a span longer than the allowed answer length.
                if end_index < start_index:
                    continue
                length = end_index - start_index + 1
                if length > max_answer_length:
                    continue

                prelim_predictions.append(
                    _PrelimPrediction(
                        feature_index=i,
                        start_index=start_index,
                        end_index=end_index,
                        start_log_prob=start_log_prob_list[start_index],
                        end_log_prob=end_log_prob_list[end_index]))

    # Highest combined start + end score first.
    prelim_predictions = sorted(
        prelim_predictions,
        key=lambda x: (x.start_log_prob + x.end_log_prob),
        reverse=True)

    if prelim_predictions:
        # The best candidate is element 0 of the sorted list. Its
        # `feature_index` only selects which feature's tokens to slice; it
        # must not be used as an index into `prelim_predictions`.
        best_prediction = prelim_predictions[0]
        passage_ids = example_features[best_prediction.feature_index]['passage_ids']
        predicted_ids = passage_ids[best_prediction.start_index: best_prediction.end_index + 1]
        predicted_text = tokenizer.decode(predicted_ids)
        qas_id_answer[qas_id] = predicted_text
        print("Predicted text", predicted_text)
        p_texts.append(predicted_text)

    else:
        # No valid span for this question: record an empty answer.
        qas_id_answer[qas_id] = ""
        skipped_null.append(qas_id)


eval_results = evaluate_v1(squad_dev_data, qas_id_answer)

# {'exact_match': 63.27341532639546, 'f1': 75.86691833684937}

### Save as Serialized version 

- Now we can use ```save_as_serialize_module``` to save a model directly to saved_model

In [None]:
# Export the fine-tuned model as a SavedModel, then load it back through the
# plain TensorFlow API.
span_selection_model.save_as_serialize_module(
    "{}/saved_model".format(model_save_dir),
    overwrite=True,
)

# The loaded object is what the pipeline cell further down receives as `model`.
span_selection_model_serialized = tf.saved_model.load(
    "{}/saved_model".format(model_save_dir)
)

### TFLite Conversion

TFlite conversion requires:
- static batch size
- static sequence length

In [None]:
# TFLite needs static shapes, so the model is rebuilt with
#   batch_size = 1
#   encoder sequence length = 64
#   decoder sequence length = 384
# and the fine-tuned weights are restored into it before converting.

model_layer, model, config = T5Model(model_name='t5-small', 
                                     decoder_mask_mode='user_defined', 
                                     batch_size=1, 
                                     encoder_sequence_length=64, 
                                     decoder_sequence_length=384
                                     )



span_selection_layer = Span_Selection_Model(model=model,
                                      is_training=False)
# NOTE: this rebinds `span_selection_model` to the static-shape inference
# model; later cells see this one, not the model trained above.
span_selection_model = span_selection_layer.get_model()
span_selection_model.load_checkpoint(model_save_dir)

# Save to .pb format , we need it for tflite

span_selection_model.save_as_serialize_module("{}/saved_model_for_tflite".format(model_save_dir))

converter = tf.lite.TFLiteConverter.from_saved_model("{}/saved_model_for_tflite".format(model_save_dir)) # path to the SavedModel directory
converter.experimental_new_converter = True

tflite_model = converter.convert()

# Write the flatbuffer through a context manager so the file is flushed and
# closed before the sanity-check cell reads it back.
with open("{}/converted_model.tflite".format(model_save_dir), "wb") as tflite_file:
    tflite_file.write(tflite_model)

### **In production**

- We can use either a ```tf.keras.Model``` or a ```saved_model```. I recommend the saved_model, which is much faster and does not require the architecture code.

**Note** - As of now, a pipeline cannot be overridden. But you can write your own, with minimal changes,
by looking at the existing pipelines.

### Custom Pipeline for T5

In [None]:

import tensorflow as tf
import collections
from tf_transformers.data.squad_utils_sp import *
from tf_transformers.data.squad_utils_sp import _get_best_indexes, _compute_softmax
from tf_transformers.utils.tokenization import BasicTokenizer



def extract_from_dict(dict_items, key):
    """Collect ``item[key]`` from every dict in ``dict_items``, in order.

    Args:
        dict_items: iterable of mappings.
        key: key looked up in each mapping.

    Returns:
        list: one value per item.

    Raises:
        KeyError: if an item does not contain ``key``.
    """
    return [item[key] for item in dict_items]

class Span_Extraction_Pipeline():
    """Question-answering pipeline: (question, context) pairs in, ranked answer spans out.

    Calling the instance runs four stages:

    1. ``convert_to_features``: examples to features.
    2. ``convert_features_to_dataset``: features to a batched dataset.
    3. ``run``: batched forward passes, producing per-feature logits.
    4. ``post_process``: logits to the top ``n_best`` answers per question.
    """

    # One candidate answer span: the feature (chunk) it came from, its
    # inclusive token boundaries inside that feature's ``passage_ids``, and
    # the start/end scores.
    _PrelimPrediction = collections.namedtuple(  # pylint: disable=invalid-name
        "PrelimPrediction",
        ["feature_index", "start_index", "end_index",
         "start_log_prob", "end_log_prob"])

    def __init__(self, model, 
                tokenizer, 
                tokenizer_fn, 
                SPECIAL_PIECE, 
                n_best_size, 
                n_best, 
                max_answer_length, 
                max_seq_length,
                max_query_length, 
                doc_stride,
                call_fn = None,
                batch_size=32):
        """Store the configuration and resolve how the model is called.

        Args:
            model: a ``tf.keras.Model`` or a loaded SavedModel object.
            tokenizer: tokenizer passed to the feature converter; its
                ``decode`` method turns predicted ids back into text.
            tokenizer_fn: callable mapping one feature dict to the model-input
                dict; the result must contain ``'decoder_input_ids'``.
            SPECIAL_PIECE: forwarded to the feature converter.
            n_best_size: number of top start/end indexes considered per feature.
            n_best: maximum number of answers returned per question.
            max_answer_length: maximum answer span length in tokens.
            max_seq_length: forwarded to the feature converter.
            max_query_length: forwarded to the feature converter.
            doc_stride: forwarded to the feature converter.
            call_fn: accepted for backward compatibility; currently unused.
            batch_size: batch size used when building the dataset.

        Raises:
            ValueError: if ``model`` is neither a Keras model nor a SavedModel.
        """
        self.get_model_fn(model)
        self.tokenizer = tokenizer
        self.tokenizer_fn = tokenizer_fn
        self.SPECIAL_PIECE = SPECIAL_PIECE
        self.n_best_size = n_best_size
        self.n_best = n_best
        self.max_answer_length = max_answer_length

        self.basic_tokenizer = BasicTokenizer(do_lower_case=False)
        self.max_seq_length = max_seq_length
        self.max_query_length = max_query_length
        self.doc_stride = doc_stride
        self.batch_size = batch_size

    def get_model_fn(self, model):
        """Set ``self.model_fn`` to a callable taking a dict of batched inputs.

        A Keras model is called directly. For a SavedModel the
        ``serving_default`` signature is wrapped so the dict is expanded into
        keyword arguments.

        Raises:
            ValueError: if ``model`` is of neither kind.
        """
        self.model_fn = None
        if isinstance(model, tf.keras.Model):
            # keras Model
            self.model_fn = model
        elif "saved_model" in str(type(model)):
            # Inspect the `model` argument: no `self.model` attribute is ever
            # assigned, so reading it here would raise AttributeError.
            self.model_pb = model.signatures['serving_default']

            def model_fn(x):
                return self.model_pb(**x)

            self.model_fn = model_fn
        if self.model_fn is None:
            raise ValueError("Please check the type of your model")

    def run(self, dataset):
        """Run the model over ``dataset`` and return per-example logits.

        Returns:
            tuple: ``(start_logits_unstacked, end_logits_unstacked)``, two
            lists with one tensor per example, in dataset order.
        """
        start_logits_unstacked = []
        end_logits_unstacked = []
        for batch_inputs in dataset:
            model_outputs = self.model_fn(batch_inputs)
            # Split each batch along its first axis into per-example tensors.
            start_logits_unstacked.extend(tf.unstack(model_outputs['start_logits']))
            end_logits_unstacked.extend(tf.unstack(model_outputs['end_logits']))
        return start_logits_unstacked, end_logits_unstacked

    def convert_to_features(self, dev_examples):
        """Convert examples to features.

        Returns:
            tuple: ``(qas_id_examples, qas_id_info, dev_features)`` where
            ``qas_id_examples`` maps each example's ``'qas_id'`` to the example
            and the other two come from the feature converter.
        """
        qas_id_examples = {ex['qas_id']: ex for ex in dev_examples}
        qas_id_info, dev_features = example_to_features_using_fast_sp_alignment_test(self.tokenizer,
            dev_examples, self.max_seq_length, self.max_query_length, self.doc_stride, self.SPECIAL_PIECE
        )
        return qas_id_examples, qas_id_info, dev_features

    def convert_features_to_dataset(self, dev_features):
        """Turn features into a batched dataset.

        Side effects: stores the decoder input ids on every feature as
        ``f['passage_ids']`` (needed by ``post_process``) and keeps the
        resulting dataset on ``self.dev_dataset``.
        """
        # Generator consumed by TFProcessor.
        def local_parser():
            for f in dev_features:
                # Use the tokenizer function given to this instance, not a
                # same-named global.
                result = self.tokenizer_fn(f)
                f['passage_ids'] = result['decoder_input_ids']
                yield result

        tf_processor = TFProcessor()
        dev_dataset = tf_processor.process(parse_fn=local_parser())
        self.dev_dataset = dev_dataset = tf_processor.auto_batch(dev_dataset, batch_size=self.batch_size)
        return dev_dataset

    def post_process(self, qas_id_examples, dev_features, qas_id_info, start_logits_unstacked, end_logits_unstacked):
        """Turn per-feature logits into ranked answers per question.

        Args:
            qas_id_examples: mapping ``qas_id -> example dict``. Each example
                dict is updated in place with an ``'answers'`` entry.
            dev_features: list of feature dicts; each needs ``'qas_id'`` and
                ``'passage_ids'``.
            qas_id_info: accepted for interface compatibility; not used here.
            start_logits_unstacked: one start-logits tensor per feature.
            end_logits_unstacked: one end-logits tensor per feature.

        Returns:
            dict: ``qas_id -> example dict``, where
            ``example['answers'] == {qas_id: [{'text': ..., 'probability': ...}, ...]}``
            holds at most ``self.n_best`` answers, best first. The list is
            empty when no valid span exists.
        """
        # Group per-feature outputs by question id: a long context is split
        # into several features that share one qas_id.
        per_question = {}
        for i, feature in enumerate(dev_features):
            group = per_question.setdefault(
                feature['qas_id'],
                {'features': [], 'start_logits': [], 'end_logits': []})
            group['features'].append(feature)
            group['start_logits'].append(start_logits_unstacked[i])
            group['end_logits'].append(end_logits_unstacked[i])

        final_result = {}
        for qas_id, group in per_question.items():
            example_features = group['features']
            prelim_predictions = []
            for feature_index, feature in enumerate(example_features):
                # Keep only the scores for this feature's own tokens.
                feature_length = len(feature['passage_ids'])
                start_log_prob_list = group['start_logits'][feature_index].numpy().tolist()[:feature_length]
                end_log_prob_list = group['end_logits'][feature_index].numpy().tolist()[:feature_length]
                start_indexes = _get_best_indexes(start_log_prob_list, self.n_best_size)
                end_indexes = _get_best_indexes(end_log_prob_list, self.n_best_size)

                for start_index in start_indexes:
                    for end_index in end_indexes:
                        # Throw out invalid candidates: an end before the
                        # start, or a span longer than allowed.
                        if end_index < start_index:
                            continue
                        if end_index - start_index + 1 > self.max_answer_length:
                            continue
                        prelim_predictions.append(
                            self._PrelimPrediction(
                                feature_index=feature_index,
                                start_index=start_index,
                                end_index=end_index,
                                start_log_prob=start_log_prob_list[start_index],
                                end_log_prob=end_log_prob_list[end_index]))

            # Highest combined start + end score first.
            prelim_predictions.sort(
                key=lambda x: (x.start_log_prob + x.end_log_prob),
                reverse=True)

            answers = []
            total_scores = []
            # Slicing caps the result at n_best without failing when fewer
            # candidates survived the filters above.
            for prediction in prelim_predictions[:self.n_best]:
                passage_ids = example_features[prediction.feature_index]['passage_ids']
                predicted_ids = passage_ids[prediction.start_index: prediction.end_index + 1]
                answers.append({'text': self.tokenizer.decode(predicted_ids)})
                total_scores.append(prediction.start_log_prob + prediction.end_log_prob)

            if total_scores:
                # Normalise the scores of the returned answers.
                for answer, probability in zip(answers, _compute_softmax(total_scores)):
                    answer['probability'] = probability

            # Register the example for every question, so a question without
            # any valid span still gets an (empty) answers entry.
            final_result[qas_id] = qas_id_examples[qas_id]
            final_result[qas_id]['answers'] = {qas_id: answers}
        return final_result

    def __call__(self, questions, contexts, qas_ids=[]):
        """Answer each question from its paired context.

        Args:
            questions: list of question strings.
            contexts: list of context strings, one per question.
            qas_ids: optional list of ids, one per question. The default is
                never mutated; when it is empty, the positions
                ``0..len(questions) - 1`` are used as ids.

        Returns:
            dict: the result of ``post_process``.

        Raises:
            AssertionError: if the three lists differ in length.
        """
        # If no ids were given, use the positions as ids.
        if not qas_ids:
            qas_ids = list(range(len(questions)))
        # Each question needs exactly one context and one id.
        assert len(questions) == len(contexts) == len(qas_ids), \
            "questions, contexts and qas_ids must have the same length"

        dev_examples = convert_question_context_to_standard_format(questions, contexts, qas_ids)
        qas_id_examples, qas_id_info, dev_features = self.convert_to_features(dev_examples)
        dev_dataset = self.convert_features_to_dataset(dev_features)
        start_logits_unstacked, end_logits_unstacked = self.run(dev_dataset)
        final_result = self.post_process(qas_id_examples, dev_features, qas_id_info, start_logits_unstacked, end_logits_unstacked)
        return final_result


### Use pipeline

In [None]:
def tokenizer_fn(features):
    """Convert one feature dict into the model-input dict.

    Tokens before the first ``tokenizer.sep_token`` (prefixed with the
    ``'question: '`` prompt) go to the encoder; tokens after it (prefixed with
    the ``'context: '`` prompt) go to the decoder.

    Args:
        features: feature dict whose ``'input_ids'`` entry is a list of tokens.

    Returns:
        dict with ``encoder_input_ids``, ``encoder_input_mask``,
        ``decoder_input_ids`` and ``decoder_input_mask``.
    """
    result = {}
    # NOTE(review): `cls_token` and `sep_token` are both set to the EOS token
    # in this notebook. If features['input_ids'][0] is that CLS token,
    # .index() returns 0 and the question tokens end up on the decoder side
    # -- confirm against the feature layout.
    question_sep_index =  features['input_ids'].index(tokenizer.sep_token) # We dont want sep-token in the starting
    question_ids = tokenizer.tokenize('question: ') + features['input_ids'][1:question_sep_index+1] # 1 to avoid CLS token

    passage_context_prompt = tokenizer.tokenize('context: ')
    passage_ids = passage_context_prompt + features['input_ids'][question_sep_index +1 :] # Adding context as prompt

    result['encoder_input_ids'] = tokenizer.convert_tokens_to_ids(question_ids)
    result['encoder_input_mask'] = [1] * len(result['encoder_input_ids'])

    result['decoder_input_ids'] = tokenizer.convert_tokens_to_ids(passage_ids)
    result['decoder_input_mask'] = [1] * len(result['decoder_input_ids'])

    return result


# Build the pipeline at module level. If this statement is indented under
# the `return` above it never runs, and `pipeline` is undefined when the
# questions are submitted.
pipeline = Span_Extraction_Pipeline(model = span_selection_model_serialized,
            tokenizer = tokenizer, 
            tokenizer_fn = tokenizer_fn, 
            SPECIAL_PIECE = SPIECE_UNDERLINE, 
            n_best_size = 20, 
            n_best = 5, 
            max_answer_length = 30, 
            max_seq_length = 384, 
            max_query_length=64, 
            doc_stride=128)

# Sample questions for the context below. Each assignment replaces the
# previous one, so only the LAST `questions` list is actually sent to the
# pipeline; the others are alternatives to try by reordering.
questions = ['When was Kerala formed?']
questions = ['What was prominent in Kerala?']
questions = ['How many districts are there in Kerala']
questions = ['When was Kerala formed?']

# One context per question (the pipeline asserts that both lists have the same length).
contexts = ['''Kerala (English: /ˈkɛrələ/; Malayalam: [ke:ɾɐɭɐm] About this soundlisten (help·info)) is a
state on the southwestern Malabar Coast of India. It was formed on 1 November 1956, 
following the passage of the States Reorganisation Act, by combining Malayalam-speaking regions of 
the erstwhile states of Travancore-Cochin and Madras. 
Spread over 38,863 km2 (15,005 sq mi), Kerala is the twenty-first largest Indian state by area. 
It is bordered by Karnataka to the north and northeast, Tamil Nadu to the east and south, and the Lakshadweep Sea[14] to the west. With 33,387,677 inhabitants as per the 2011 Census, Kerala is the thirteenth-largest Indian state by population. It is divided into 14 districts with the capital being Thiruvananthapuram. Malayalam is the most widely spoken language and is also the official language of the state.[15]

The Chera Dynasty was the first prominent kingdom based in Kerala. The Ay kingdom in the deep south and the Ezhimala kingdom in the north formed the other kingdoms in the early years of the Common Era (CE). The region had been a prominent spice exporter since 3000 BCE. The region's prominence in trade was noted in the works of Pliny as well as the Periplus around 100 CE. In the 15th century, the spice trade attracted Portuguese traders to Kerala, and paved the way for European colonisation of India. At the time of Indian independence movement in the early 20th century, there were two major princely states in Kerala-Travancore State and the Kingdom of Cochin. They united to form the state of Thiru-Kochi in 1949. The Malabar region, in the northern part of Kerala, had been a part of the Madras province of British India, which later became a part of the Madras State post-independence. After the States Reorganisation Act, 1956, the modern-day state of Kerala was formed by merging the Malabar district of Madras State (excluding Gudalur taluk of Nilgiris district, Lakshadweep Islands, Topslip, the Attappadi Forest east of Anakatti), the state of Thiru-Kochi (excluding four southern taluks of Kanyakumari district, Shenkottai and Tenkasi taluks), and the taluk of Kasaragod (now Kasaragod District) in South Canara (Tulunad) which was a part of Madras State.''']

# No qas_ids are passed, so the pipeline uses the question positions as ids.
result = pipeline(questions, contexts)

        

### Sanity Check TFlite 

In [None]:
#### Sanity check: the Keras model and the TFLite model should agree on the same inputs.

# Random token ids with the static shapes used for the TFLite export:
# batch 1, encoder length 64, decoder length 384.
encoder_input_ids = tf.random.uniform(minval=0, maxval=100, shape=(1, 64), dtype=tf.int32)
decoder_input_ids = tf.random.uniform(minval=0, maxval=100, shape=(1, 384), dtype=tf.int32)

sample_inputs = {
    'encoder_input_ids': encoder_input_ids,
    'encoder_input_mask': tf.ones_like(encoder_input_ids),
    'decoder_input_ids': decoder_input_ids,
    'decoder_input_mask': tf.ones_like(decoder_input_ids),
}

# Reference outputs from the Keras model.
model_outputs = span_selection_model(sample_inputs)

# Load the TFLite model and allocate tensors.
interpreter = tf.lite.Interpreter(model_path="{}/converted_model.tflite".format(model_save_dir))
interpreter.allocate_tensors()

# Get input and output tensors.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# NOTE(review): the positional mapping below (inputs 0..3, outputs 0..1) is
# taken as-is from the original cell; verify it against the `name` fields of
# `input_details` / `output_details` for your converted model.
tflite_input_order = [
    'decoder_input_ids',
    'decoder_input_mask',
    'encoder_input_ids',
    'encoder_input_mask',
]
for position, input_name in enumerate(tflite_input_order):
    interpreter.set_tensor(input_details[position]['index'], sample_inputs[input_name])
interpreter.invoke()

end_logits = interpreter.get_tensor(output_details[0]['index'])
start_logits = interpreter.get_tensor(output_details[1]['index'])

# Compare the summed logits of both models.
print("Start logits", tf.reduce_sum(model_outputs['start_logits']), tf.reduce_sum(start_logits))
print("End logits", tf.reduce_sum(model_outputs['end_logits']), tf.reduce_sum(end_logits))

# We are good :-)

# We are good :-)