## imports etc

In [1]:
import time
import torch
import numpy as np

from transformers import BertTokenizer, BertForQuestionAnswering
from onnxruntime import InferenceSession

  from .autonotebook import tqdm as notebook_tqdm


In [17]:
def run_n_times(func, desc, n = 100):
    """Call ``func`` ``n`` times back to back and print the total elapsed time.

    Args:
        func: Zero-argument callable to benchmark; its return value is ignored.
        desc: Label inserted into the printed report line.
        n: Number of consecutive calls to make (default 100).

    Returns:
        None. The measurement is only printed, as
        ``Execution time for <desc> : <seconds> seconds``.
    """
    # perf_counter() is monotonic and high-resolution. time.time() follows the
    # wall clock, which can jump (NTP sync, manual clock change) during a run
    # and yield skewed or even negative durations.
    start = time.perf_counter()
    for _ in range(n):
        func()
    elapsed = time.perf_counter() - start
    print('Execution time for', desc, ':', elapsed, 'seconds')

## Torch

In [3]:
# One checkpoint identifier shared by the tokenizer and the QA model.
CHECKPOINT = "deepset/bert-base-cased-squad2"
tokenizer = BertTokenizer.from_pretrained(CHECKPOINT)
model = BertForQuestionAnswering.from_pretrained(CHECKPOINT)

# The single (question, context) example; the ONNX section re-encodes the same pair.
question = "Who was Jim Henson?"
text = "Jim Henson was a nice puppet"
# Encode the pair as PyTorch tensors ("pt").
inputs = tokenizer(question, text, return_tensors="pt")

### run

In [4]:
@torch.no_grad()  # gradient tracking is disabled for the whole call
def run_torch():
    """Run one forward pass of ``model`` on ``inputs``.

    The result is stored in the module-level ``outputs_torch`` so that
    later cells can read the start/end logits from it.
    """
    global outputs_torch
    outputs_torch = model(**inputs)


run_torch()

### get answer

In [5]:
# Positions of the highest start / end logits (argmax over the flattened logits).
answer_start_index = torch.argmax(outputs_torch.start_logits)
answer_end_index = torch.argmax(outputs_torch.end_logits)

# Inclusive token span taken from row 0 of the encoded batch.
answer_span = slice(answer_start_index, answer_end_index + 1)
predict_answer_tokens = inputs.input_ids[0, answer_span]
tokenizer.decode(predict_answer_tokens)

'a nice puppet'

### save to disk

In [6]:
# Tokenizer first, then model, both written into the same local directory.
for artifact in (tokenizer, model):
    artifact.save_pretrained("local-pt-checkpoint-squad2")

## ONNX

In [21]:
# NOTE(review): "onnx/model.onnx" is not created by any cell shown in this
# notebook; presumably it was exported beforehand from the checkpoint saved
# above. Confirm, and document the export step.
session = InferenceSession("onnx/model.onnx")

# Keep return_tensors="pt": `inputs` is rebound here, and run_torch() still
# passes this same global to the PyTorch model when it is timed later.
inputs = tokenizer(question, text, return_tensors="pt")
# Convert every tokenizer output to an int64 NumPy array for the ONNX session.
onnx_inputs = {
    name: np.array(tensor, dtype=np.int64)
    for name, tensor in inputs.items()
}

### run

In [22]:
def run_onnx():
    """Run one inference pass of the ONNX ``session`` on ``onnx_inputs``.

    The result is stored in the module-level ``outputs_onnx``, in the order
    the outputs are requested: start logits first, then end logits.
    """
    global outputs_onnx
    requested = ["start_logits", "end_logits"]
    feed = dict(onnx_inputs)  # pass a plain-dict copy of the prepared inputs
    outputs_onnx = session.run(requested, feed)


run_onnx()

### get answer

In [9]:
# outputs_onnx holds [start_logits, end_logits]; argmax along axis 1 gives
# one position per batch row.
answer_start_index = outputs_onnx[0].argmax(axis=1)
answer_end_index = outputs_onnx[1].argmax(axis=1)

# .item() extracts the single position as a Python int and, like the
# one-element tensor conversion it replaces, fails if there is more than one row.
first_token = answer_start_index.item()
last_token = answer_end_index.item()
predict_answer_tokens = inputs.input_ids[0, first_token : last_token + 1]
tokenizer.decode(predict_answer_tokens)

'a nice puppet'

## Quick benchmark (just for fun)

In [18]:
# Time both backends with the same helper and the default repetition count.
for label, runner in (("torch", run_torch), ("onnx", run_onnx)):
    run_n_times(runner, desc=label)

Execution time for torch : 4.1156511306762695 seconds
Execution time for onnx : 1.5255463123321533 seconds
