# MobileBERT for Question Answering on the SQuAD dataset

### 4. Creating a Gradio app to deploy the model 

In these notebooks we are going to use [MobileBERT implemented by HuggingFace](https://huggingface.co/docs/transformers/model_doc/mobilebert) for extractive question answering on the [Stanford Question Answering Dataset (SQuAD)](https://rajpurkar.github.io/SQuAD-explorer/). The data consists of questions paired with paragraphs that contain the answers. The model is trained to locate the answer in the context by predicting the positions where the answer starts and ends.

In this notebook we are going to create a custom [Pipeline](https://huggingface.co/docs/transformers/main_classes/pipelines) and deploy it with a [Gradio](https://huggingface.co/gradio) app.

More info from HuggingFace docs:
- [Question Answering](https://huggingface.co/tasks/question-answering)
- [Glossary](https://huggingface.co/transformers/glossary.html#model-inputs)
- [Question Answering chapter of NLP course](https://huggingface.co/learn/nlp-course/chapter7/7?fw=pt)

In [None]:
import gradio as gr
import torch
import torch.nn.functional as F
from transformers import Pipeline
from transformers import AutoTokenizer, MobileBertForQuestionAnswering

In [None]:
class MobileBERTQAPipeline(Pipeline):
    def __init__(self, hf_model_checkpoint):
        tokenizer = AutoTokenizer.from_pretrained(hf_model_checkpoint)
        model = MobileBertForQuestionAnswering.from_pretrained(hf_model_checkpoint)

        # load the fine-tuned weights from the local file 'mobilebertqa_ft' onto the CPU
        model.load_state_dict(
            torch.load('mobilebertqa_ft', map_location=torch.device('cpu'))
        )

        # inference mode: disables dropout
        model.eval()

        super().__init__(model, tokenizer)

    def _sanitize_parameters(self, **kwargs):
        return {}, {}, {}

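    def preprocess(self, inputs):
        # inputs is a (paragraph, question) pair; the paragraph is encoded first
        model_input = self.tokenizer(*inputs, return_tensors="pt")
        # keep the token ids so that postprocess can decode the answer span
        self.context_tokens = model_input['input_ids']
        return model_input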

    def _forward(self, model_inputs):
        with torch.no_grad():
            outputs = self.model(**model_inputs)

        return outputs

    def postprocess(self, model_outputs):
        start_probs = F.softmax(model_outputs.start_logits, dim=-1)[0]
        end_probs   = F.softmax(model_outputs.end_logits,   dim=-1)[0]

        # find the max class that the softmax gives
        start = torch.argmax(start_probs).item()
        end = torch.argmax(end_probs).item()
        
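        # predicted answer; assumes `end` is the index of the last answer token
        # (the usual SQuAD labeling), so the slice must include it
        answer_tokens = self.context_tokens[0][start:end + 1]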
        answer_text = self.tokenizer.decode(answer_tokens, skip_special_tokens=True,
                                            clean_up_tokenization_spaces=True)

        # start position in text: length of the decoded text before the answer,
        # plus one for the separating space (zero if nothing precedes the answer)
        text_before = self.tokenizer.decode(self.context_tokens[0][:start], skip_special_tokens=True,
                                            clean_up_tokenization_spaces=True)
        start_text = len(text_before) + 1 if text_before else 0

        return {'start': start_text,
                'end': start_text + len(answer_text),
                'answer': answer_text}
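
The character positions above are recovered by decoding the tokens before the answer and measuring the length of the resulting string. An alternative worth considering is the offset mapping that fast Hugging Face tokenizers return when called with `return_offsets_mapping=True`: one `(char_start, char_end)` pair per token. The sketch below assumes such a mapping is available and uses a hand-written one for illustration; `token_span_to_char_span` is a hypothetical helper, not part of the pipeline.

```python
def token_span_to_char_span(offsets, start_token, end_token):
    """Map an inclusive token span to a half-open character span [start, end)."""
    char_start = offsets[start_token][0]
    char_end = offsets[end_token][1]
    return char_start, char_end


paragraph = "The sky is blue."
# Hand-written offsets in the shape a fast tokenizer would produce for
# [CLS] the sky is blue . [SEP]; special tokens map to (0, 0).
offsets = [(0, 0), (0, 3), (4, 7), (8, 10), (11, 15), (15, 16), (0, 0)]

char_start, char_end = token_span_to_char_span(offsets, 4, 4)
print(paragraph[char_start:char_end])
```

Because the offsets index into the raw paragraph, this approach would not depend on the decoded text matching the input character for character.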

In [None]:
qa_pipeline = MobileBERTQAPipeline('google/mobilebert-uncased')

In [None]:
def highlight_answer(question, paragraph,
                     tokenizer=AutoTokenizer.from_pretrained('google/mobilebert-uncased')):

    # Rewrite the paragraph the way the tokenizer sees it;
    # otherwise the character positions computed on the decoded
    # text may not line up with the 'raw' input text
    tokens = tokenizer(paragraph)['input_ids']
    paragraph = tokenizer.decode(tokens,
                                 skip_special_tokens=True,
                                 clean_up_tokenization_spaces=True)

    # Use the Hugging Face pipeline to get the answer
    answer = qa_pipeline((paragraph, question))

    # Extract start and end indices from the pipeline output
    start_index = answer['start']
    end_index = answer['end']

    # Highlight the answer within the paragraph
    highlighted_text = (paragraph[:start_index] + 
                        "<span style='background-color: #FFA500;'> <b>" + 
                        paragraph[start_index:end_index] + 
                        "</b> </span>" + 
                        paragraph[end_index:])

    return highlighted_text
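
Since the Gradio output is rendered as HTML, characters such as `<` or `&` in the paragraph could be interpreted as markup. Below is a minimal sketch of the same highlighting with escaping, using only the standard library. `highlight_span` is a hypothetical helper, and the styling is copied from the function above.

```python
import html


def highlight_span(text, start, end):
    """Wrap text[start:end] in a highlighted span, escaping each piece for HTML."""
    before = html.escape(text[:start])
    answer = html.escape(text[start:end])
    after = html.escape(text[end:])
    return (before
            + "<span style='background-color: #FFA500;'><b>"
            + answer
            + "</b></span>"
            + after)


print(highlight_span("if a < b then c", 3, 8))
```

Each piece is escaped separately so that the start and end indices still refer to positions in the unescaped text.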

In [None]:
gr.Interface(
    fn=highlight_answer,
    inputs=["text", "text"],
    outputs="html",
    title="Highlight Answer in Paragraph",
    description="Highlight the answer within the paragraph",
    examples=[
        ["What is the quick animal?", "The quick brown fox jumps over the lazy dog."],
        ["What color is the sky?", "The sky is blue."]
    ]
).launch()