# ONNX Tutorial - BiDAF Model - Inference

This tutorial demonstrates loading trained model weights from BlobHub and using them for inference.

The tutorial uses the Bi-Directional Attention Flow (BiDAF) model for machine comprehension,
an extractive question answering model.

References: 
- BlobHub model  
  https://blobhub.io/onnx-text-models/bi-att-flow
- Model details  
  https://allenai.github.io/bi-att-flow/  

## Table of Contents

- [Install Dependencies](#Install-Dependencies)
- [Download Model from BlobHub](#Download-Model-from-BlobHub)
- [Model Inference](#Model-Inference)
- [Example Questions](#Example-Questions)

## Install Dependencies

The following packages are required for this tutorial:

In [None]:
!pip install blobhub

In [None]:
!pip install onnx

In [None]:
!pip install onnxruntime

In [None]:
!pip install nltk

In [None]:
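import nltk
nltk.download("punkt")
# Recent NLTK releases load the tokenizer data from "punkt_tab" instead of "punkt"
nltk.download("punkt_tab")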

## Download Model from BlobHub

This snippet downloads the model from a public BlobHub blob. Blob reference:

In [None]:
ORG_ID = "onnx-text-models"
BLOB_ID = "bi-att-flow"

Model downloading code:

In [None]:
from blobhub.blob import Blob, Revision
from blobhub.presets.onnx import Onnx, Model

# Find blob
blob = Blob(org_id=ORG_ID, blob_id=BLOB_ID)
revision = blob.revisions.latest()

# Initialize preset (named onnx_preset to avoid clashing with the onnx module imported later)
onnx_preset = Onnx(revision=revision)

# Download and save the model
downloaded_model = onnx_preset.download()
assert downloaded_model is not None

The downloaded model is stored locally at the following path:

In [None]:
downloaded_model.path
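
Before loading the file, it can be worth confirming that the path points at a non-empty file. The following is a minimal sketch using only the standard library. `model_file_size_mb` is a hypothetical helper, not part of BlobHub, and it assumes `downloaded_model.path` is an ordinary filesystem path.

```python
import os

def model_file_size_mb(path):
    """Return the size of the file at `path` in megabytes.

    Hypothetical helper for illustration. Raises FileNotFoundError
    if `path` does not point at a regular file.
    """
    if not os.path.isfile(path):
        raise FileNotFoundError(f"No model file at {path!r}")
    return os.path.getsize(path) / (1024 * 1024)
```

In the notebook this would be called as `model_file_size_mb(downloaded_model.path)`.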

## Model Inference

Verify that the model is well-formed with the ONNX checker:

In [None]:
import onnx

onnx_model = onnx.load(downloaded_model.path)
onnx.checker.check_model(onnx_model)

Initialize an ONNX Runtime inference session:

In [None]:
import onnxruntime

ort_session = onnxruntime.InferenceSession(downloaded_model.path)
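
The inference helper in this tutorial binds inputs by position, so it can help to list the names, shapes, and types the session reports first. The sketch below assumes only that the session exposes `get_inputs()` and `get_outputs()` returning items with `name`, `shape`, and `type` attributes, as `onnxruntime.InferenceSession` does. `describe_io` is a hypothetical helper name.

```python
def describe_io(session):
    """Return (inputs, outputs) as lists of (name, shape, type) tuples.

    Hypothetical helper: works with any object that exposes
    get_inputs() / get_outputs() whose items have .name, .shape and .type.
    """
    inputs = [(i.name, i.shape, i.type) for i in session.get_inputs()]
    outputs = [(o.name, o.shape, o.type) for o in session.get_outputs()]
    return inputs, outputs

# In the notebook:
# inputs, outputs = describe_io(ort_session)
# for name, shape, dtype in inputs:
#     print(name, shape, dtype)
```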

Inference helpers:

In [None]:
import numpy as np
from nltk import word_tokenize

def preprocess(text):
    """Tokenize text into the word and character arrays the model expects."""
    tokens = word_tokenize(text)
    # Lower-cased word tokens as a numpy array of shape (seq, 1)
    words = np.asarray([w.lower() for w in tokens]).reshape(-1, 1)
    # Characters of each token (original case), truncated to 16 and padded
    # with empty strings, as a numpy array of shape (seq, 1, 1, 16)
    chars = [list(t)[:16] for t in tokens]
    chars = [cs + [''] * (16 - len(cs)) for cs in chars]
    chars = np.asarray(chars).reshape(-1, 1, 1, 16)
    return words, chars

def infer(ort_session, context, query):
    """Return the answer span for the query as a list of byte-string tokens."""
    # Prepare input
    cw, cc = preprocess(context)
    qw, qc = preprocess(query)

    # Run inference. Inputs are bound by position, which assumes the model
    # declares them as: context words, context chars, query words, query chars
    ort_inputs = {
        ort_session.get_inputs()[0].name: cw,
        ort_session.get_inputs()[1].name: cc,
        ort_session.get_inputs()[2].name: qw,
        ort_session.get_inputs()[3].name: qc
    }
    ort_outs = ort_session.run(None, ort_inputs)

    # Parse output: start and end token indices of the answer span (end inclusive)
    start = ort_outs[0].item()
    end = ort_outs[1].item()
    answer = [w.encode() for w in cw[start:end+1].reshape(-1)]

    return answer
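
The character handling in `preprocess` can be looked at in isolation: each token is cut to at most 16 characters and then padded with empty strings to exactly 16. The sketch below reproduces that step in plain Python without NumPy. `pad_chars` is a hypothetical name used only for illustration.

```python
def pad_chars(token, width=16):
    """Split `token` into characters, truncated or padded to `width` entries."""
    chars = list(token)[:width]
    return chars + [""] * (width - len(chars))

print(pad_chars("fox", width=5))
# ['f', 'o', 'x', '', '']
```

Tokens longer than `width` lose their trailing characters, so very long words are only partially visible to the character-level input.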

## Example Questions

In [None]:
infer(
    ort_session=ort_session, 
    context="A quick brown fox jumps over the lazy dog.", 
    query="What color is the fox?"
)

In [None]:
infer(
    ort_session=ort_session, 
    context="The tokenized words are in lower case, while chars are not.", 
    query="What is not tokenized?"
)

In [None]:
infer(
    ort_session=ort_session, 
    context=
        "A black hole is a region of spacetime where gravity is so strong that nothing"
        " — no particles or even electromagnetic radiation such as light — can escape from it.", 
    query="What is a black hole?"
)
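
`infer` returns the answer as a list of byte strings, one per token. For display, these can be decoded and joined into a single string. The following is a minimal sketch; `answer_text` is a hypothetical helper, and the input in the example is made up for illustration rather than taken from model output.

```python
def answer_text(answer, encoding="utf-8"):
    """Join a list of byte-string tokens into one space-separated string."""
    return " ".join(token.decode(encoding) for token in answer)

print(answer_text([b"quick", b"brown"]))
# quick brown
```

Joining with spaces is a simplification: punctuation tokens produced by the tokenizer will be separated from the preceding word.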