### Importing Packages

In [None]:
from utils import tokenizer as tk
import openvino.runtime as ov
import warnings
from pathlib import Path
import numpy as np
import time
import torch

### Defining the checkpoint

Here we define the pretrained Hugging Face model used for the fill-mask task. This notebook uses the `bert-large-uncased-whole-word-masking` checkpoint. You can pick any other model from the [list](https://huggingface.co/models?pipeline_tag=fill-mask) instead.

*Note: some models may need slightly different preprocessing and postprocessing steps.*

In [None]:
checkpoint = "bert-large-uncased-whole-word-masking"

### Serialization

Transformers provides a **transformers.onnx package** that converts model checkpoints to an ONNX graph using configuration objects. These configuration objects come ready-made for a number of model architectures and are designed to be easily extended to other architectures. The export cell below uses the masked-LM head (`--feature masked-lm`) and writes `model.onnx` to the `model/` directory.
More details about serialization and supported models can be found [here](https://huggingface.co/docs/transformers/serialization).

In [None]:
!python -m transformers.onnx -h

In [None]:
serialize_command = f"python -m transformers.onnx \
    -m {checkpoint} \
    --feature masked-lm model/"
! $serialize_command
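The `!` shell escape only works inside Jupyter/IPython. Outside a notebook, the same export can be launched from Python with `subprocess`, passing the arguments as a list so no shell quoting is involved. This is a minimal sketch: `build_onnx_export_cmd` is a hypothetical helper, and it only reuses the flags from the cell above.

```python
import subprocess
import sys
from typing import List


def build_onnx_export_cmd(checkpoint: str, output_dir: str = "model/") -> List[str]:
    """Hypothetical helper: the transformers.onnx call from the cell above as an argv list."""
    return [
        sys.executable, "-m", "transformers.onnx",
        "-m", checkpoint,          # model ID on the Hugging Face Hub
        "--feature", "masked-lm",  # export the model with its masked-LM head
        output_dir,                # directory where model.onnx is written
    ]


cmd = build_onnx_export_cmd("bert-large-uncased-whole-word-masking")
print(" ".join(cmd))
# To actually run the export (this downloads the checkpoint), uncomment:
# subprocess.run(cmd, check=True)
```

`check=True` makes `subprocess.run` raise if the export fails, instead of silently continuing to the Model Optimizer step.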


### Model Optimization

Model Optimizer is a cross-platform command-line tool that facilitates the transition between training and deployment environments, performs static model analysis, and adjusts deep learning models for optimal execution on end-point target devices. See the [Model Optimizer developer guide](https://docs.openvino.ai/latest/openvino_docs_MO_DG_Deep_Learning_Model_Optimizer_DevGuide.html) for details and features.

Here we convert the ONNX model to OpenVINO IR, fixing each of the three inputs (`input_ids`, `attention_mask`, `token_type_ids`) to the static shape [1, 128], i.e. batch size 1 and a maximum sequence length of 128 tokens.

In [None]:
MODEL_DIR = "model/"
onnx_model_path = Path(MODEL_DIR) / "model.onnx"

optimizer_command = f"mo \
    --input_model {onnx_model_path} \
    --output_dir {MODEL_DIR} \
    --model_name {checkpoint} \
    --input input_ids,attention_mask,token_type_ids \
    --input_shape [1,128],[1,128],[1,128]"
! $optimizer_command
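The `--input` and `--input_shape` values must list the same inputs in the same order. A small sketch that derives both from one mapping, so the sequence length (128 here) is set in a single place. `mo_input_args` is a hypothetical helper, not part of the Model Optimizer. Passing the resulting list to `subprocess` also keeps the shell from interpreting the square brackets.

```python
from typing import Dict, List


def mo_input_args(shapes: Dict[str, List[int]]) -> List[str]:
    """Hypothetical helper: turn {input_name: shape} into Model Optimizer's
    --input / --input_shape arguments, keeping names and shapes in the same order."""
    names = ",".join(shapes)
    dims = ",".join("[" + ",".join(str(d) for d in shape) + "]" for shape in shapes.values())
    return ["--input", names, "--input_shape", dims]


SEQ_LEN = 128  # the fixed sequence length used in this notebook
shapes = {
    "input_ids": [1, SEQ_LEN],
    "attention_mask": [1, SEQ_LEN],
    "token_type_ids": [1, SEQ_LEN],
}
print(" ".join(mo_input_args(shapes)))
```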

### Inference Request

Next, we compile the IR model from its graph file (`.xml`; the weights are read from the matching `.bin` file) and create an [inference request](https://docs.openvino.ai/latest/openvino_docs_OV_UG_Infer_request.html) for it.

In [None]:
warnings.filterwarnings("ignore")

core = ov.Core()
ir_model_xml = str((Path(MODEL_DIR) / checkpoint).with_suffix(".xml"))
# Compile the IR for the CPU device
compiled_model = core.compile_model(ir_model_xml, "CPU")
infer_request = compiled_model.create_infer_request()
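`infer_request.infer` takes a dictionary that maps each input name to an array of shape [1, 128]. The sketch below shows, in plain Python, the layout the notebook's tokenizer is assumed to produce: ids padded to 128 with `[PAD]` (id 0), an attention mask of 1s for real tokens and 0s for padding, and all-zero token type ids for a single sentence. `make_bert_inputs` is a hypothetical stand-in, not the code in `tokenizer.py`; in practice each value would be converted to a NumPy array (e.g. with `np.array`) before inference.

```python
from typing import Dict, List

PAD_ID = 0     # [PAD] in the BERT uncased vocabulary
MAX_LEN = 128  # must match the --input_shape used for the IR model


def make_bert_inputs(token_ids: List[int], max_len: int = MAX_LEN) -> Dict[str, List[List[int]]]:
    """Hypothetical stand-in for the tokenizer output: pad or truncate one
    sequence of token ids to max_len and build the three model inputs."""
    # Naive truncation; a real tokenizer truncates before adding [CLS]/[SEP]
    ids = token_ids[:max_len]
    n_pad = max_len - len(ids)
    return {
        "input_ids": [ids + [PAD_ID] * n_pad],              # batch of 1
        "attention_mask": [[1] * len(ids) + [0] * n_pad],   # 1 = real token, 0 = padding
        "token_type_ids": [[0] * max_len],                  # single segment
    }


# [CLS] [MASK] [SEP] using the bert uncased ids 101, 103, 102
inputs = make_bert_inputs([101, 103, 102])
print({name: len(value[0]) for name, value in inputs.items()})
```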

### Softmax Function

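A numerically stable softmax converts the model's output logits into probabilities. Subtracting the maximum before exponentiating avoids overflow without changing the result. Because `np.max` and `.sum()` reduce over the whole array, this function expects the 1-D logits vector for a single token position.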

In [None]:
def softmax(x):
    e_x = np.exp(x - np.max(x))
    return e_x / e_x.sum()
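Subtracting the maximum changes nothing mathematically, since softmax(x - c) = softmax(x) for any constant c, but it keeps every exponent at or below zero so large logits cannot overflow. A pure-Python sketch of the same idea, handy for checking the behaviour without NumPy; `stable_softmax` is a hypothetical name.

```python
import math
from typing import List


def stable_softmax(x: List[float]) -> List[float]:
    """Same max-subtraction trick as `softmax` above, in pure Python."""
    m = max(x)
    exps = [math.exp(v - m) for v in x]  # every exponent is <= 0, so no overflow
    total = sum(exps)
    return [e / total for e in exps]


logits = [1000.0, 999.0, 998.0]
# A naive version would call math.exp(1000.0), which raises OverflowError.
probs = stable_softmax(logits)
print([round(p, 4) for p in probs])
```

Shifting the logits by a constant gives identical probabilities, e.g. `stable_softmax([1.0, 2.0, 3.0])` and `stable_softmax([1001.0, 1002.0, 1003.0])`.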

### Postprocessing
Here we define the `postprocess` function, which:
1. Preprocesses the input text with the custom tokenizer built on the bert-base-uncased vocab file (the uncased BERT-large checkpoints use the same vocabulary).
2. Checks for irregularities:
    i. If there is more than one [MASK] token, the sentence is rejected, because this Fill Mask setup supports exactly one [MASK] token per sentence.
    ii. If there is no [MASK] token in the sentence, a [MASK] token is appended at the end of the sentence.
3. Runs inference and applies the softmax function to the output logits at the [MASK] position, giving a probability for every token in the vocabulary.
4. Returns the top 10 predictions, each with the input sentence rewritten so that the [MASK] token is replaced by the predicted token.
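The mask checks in step 2 can be sketched on raw strings as follows. In the notebook these checks go through `tk.check_mask_token` and `tk.preprocess_text`, whose internals live in `tokenizer.py`; `prepare_text` is a hypothetical stand-in that only mirrors the behaviour described above, and appending the mask with a separating space is an assumption.

```python
from typing import Optional

MASK = "[MASK]"


def prepare_text(text: str, mask: str = MASK) -> Optional[str]:
    """Hypothetical stand-in for the checks described above:
    - more than one mask token -> rejected (returns None)
    - no mask token           -> the mask is appended at the end
    - exactly one             -> the stripped text is returned unchanged
    """
    text = text.strip()
    count = text.count(mask)
    if count > 1:
        return None
    if count == 0:
        return f"{text} {mask}"
    return text


print(prepare_text("Have you seen my   "))
```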

For details on the individual tokenizer functions, see the [tokenizer.py file](../utils/tokenizer.py).

In [None]:
def postprocess(text):
    output = []
    special_tokens = tk.special_tokens_list()
    mask_token = special_tokens["mask_token"]
    err_text = (
        f"ERROR: Too many {mask_token} tokens in sentence. "
        f"Sentence should have exactly one {mask_token} token for Fill Mask task"
    )

    # Reject sentences with more than one [MASK] token
    if tk.check_mask_token(text) > 1:
        return err_text

    # Tokenize and pad/truncate to the 128-token input shape of the IR model
    inputs = tk.preprocess_text(text, 128)
    if inputs == -2:
        return err_text

    result = infer_request.infer(inputs)
    input_ids = inputs["input_ids"][0]

    # Locate the [MASK] token in the tokenized sequence
    mask_token_id = tk.word_to_token(mask_token)
    masked_index = next(
        i for i, token_id in enumerate(input_ids) if token_id == mask_token_id
    )

    # The masked-lm model has a single output: logits of shape [1, 128, vocab_size]
    outputs = next(iter(result.values()))
    logits = outputs[0, masked_index, :]
    prob = torch.from_numpy(softmax(logits))
    value, prediction = prob.topk(10)

    # Split the sentence around the [MASK] token so each prediction can be spliced in
    text = text.strip()
    mask_exists = mask_token in text
    if mask_exists:
        part1, _, part2 = text.partition(mask_token)

    for v, p in zip(value.tolist(), prediction.tolist()):
        # Convert the predicted token id back to its token string
        word = tk.tokens_to_ids(p)
        if mask_exists:
            sequence = part1 + word + part2
        else:
            # [MASK] was appended at the end, so append the prediction
            sequence = text + " " + word
        output.append({"Sequence": sequence, "token": word, "score": "%.5f" % v})

    return output
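`torch` is only used here for `topk`. If you would rather drop that dependency, here is a pure-Python sketch of the last two steps: picking the most likely token ids and splicing each prediction back into the sentence. `top_k` and `fill` are hypothetical helpers; on a NumPy array, `np.argsort(prob)[::-1][:10]` gives the same indices (up to ties).

```python
import heapq
from typing import List, Tuple


def top_k(probs: List[float], k: int = 10) -> Tuple[List[float], List[int]]:
    """Return the k largest probabilities and their indices (token ids), highest first."""
    idx = heapq.nlargest(k, range(len(probs)), key=probs.__getitem__)
    return [probs[i] for i in idx], idx


def fill(text: str, word: str, mask: str = "[MASK]") -> str:
    """Splice a predicted word into the mask position, or append it if there is no mask."""
    before, found, after = text.strip().partition(mask)
    return before + word + after if found else f"{before} {word}"


values, ids = top_k([0.1, 0.5, 0.05, 0.35], k=2)
print(values, ids)
print(fill("I haven't [MASK] you.", "seen"))
```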

`getResults` calls the `postprocess` function and records the total time taken (tokenization, inference, and postprocessing).

In [None]:
def getResults(text):
    print("Original Text: ", text)
    start_time = time.perf_counter()
    result = postprocess(text)
    total_time = time.perf_counter() - start_time

    # postprocess returns an error string or a list of predictions;
    # the timing is only appended to a list of predictions
    if isinstance(result, list):
        result.append({"total_time": f"{total_time:.2f} seconds"})
    return result
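`getResults` times a single call, and a single measurement can be noisy (for example, the first inference on a freshly compiled model can be slower). A sketch that repeats the call and reports the median; `time_call` is a hypothetical helper, and in this notebook you would pass `postprocess` and a sentence instead of the `sorted` stand-in used here.

```python
import statistics
import time
from typing import Any, Callable, List, Tuple


def time_call(fn: Callable[..., Any], *args: Any, repeats: int = 5) -> Tuple[Any, List[float]]:
    """Hypothetical helper: call fn(*args) `repeats` times using time.perf_counter
    and return the last result together with each run's duration in seconds."""
    durations = []
    result = None
    for _ in range(repeats):
        start = time.perf_counter()
        result = fn(*args)
        durations.append(time.perf_counter() - start)
    return result, durations


result, durations = time_call(sorted, list(range(10_000, 0, -1)), repeats=3)
print(f"median {statistics.median(durations):.6f} s over {len(durations)} runs")
```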

### Results

Time for results! Replace the sentence below with a sentence of your choice.

#### 1. Exactly one [MASK] token inside the sentence

In [None]:
text = "How are you? I haven't [MASK] you in a while."
result = getResults(text)
result

#### 2. No [MASK] token

In [None]:
text = "Have you seen my            "
result = getResults(text)
result

#### 3. Multiple [MASK] tokens in a single sentence

In [None]:
text = "How are [MASK] ? I haven't [MASK] you in a [MASK] ."
result = getResults(text)
result
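The raw list of dictionaries is easy to inspect but hard to scan. A small sketch that renders the output of `getResults` as plain-text lines: one line per prediction (score, token, rewritten sentence) and a final timing line, or the error string unchanged. `format_results` is a hypothetical helper; it relies only on the `Sequence`, `token`, `score`, and `total_time` keys produced above.

```python
from typing import Dict, List, Union

Result = Union[str, List[Dict[str, str]]]


def format_results(result: Result) -> str:
    """Hypothetical helper: render the output of getResults as plain-text lines.
    getResults returns either an error string or a list of prediction dicts
    followed by one {"total_time": ...} entry."""
    if isinstance(result, str):
        return result
    lines = []
    for row in result:
        if "total_time" in row:
            lines.append(f"total time: {row['total_time']}")
        else:
            lines.append(f"{row['score']}  {row['token']:<12} {row['Sequence']}")
    return "\n".join(lines)
```

In the notebook, `print(format_results(getResults(text)))` prints the predictions for the current sentence.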