## Caching in Attention Models
This challenge is about applying caching in attention models to speed up inference. We will use the Pix2Struct model.
<br>
First, we export the checkpoint from Hugging Face (HF) using the export task that matches the architecture. <br> <br>
Note: run with Python 3.8.
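
To make the idea of caching concrete, here is a minimal NumPy sketch of single-head attention during step-by-step decoding. It is a toy illustration and not the Pix2Struct decoder: the weights, sizes, and names (`w_q`, `k_cache`, and so on) are made up for the example. It shows that storing the keys and values of earlier tokens gives the same outputs as re-projecting the whole prefix at every step.

```python
import numpy as np

def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

def attend(q, k, v):
    """Scaled dot-product attention of one query over all keys/values."""
    scores = q @ k.T / np.sqrt(q.shape[-1])
    return softmax(scores) @ v

rng = np.random.default_rng(0)
d_model, steps = 8, 5  # hypothetical toy sizes
w_q, w_k, w_v = (rng.standard_normal((d_model, d_model)) for _ in range(3))
tokens = rng.standard_normal((steps, d_model))  # stand-in for token embeddings

# Without a cache: every step re-projects keys/values for the whole prefix.
no_cache = []
for t in range(steps):
    prefix = tokens[: t + 1]
    no_cache.append(attend(tokens[t] @ w_q, prefix @ w_k, prefix @ w_v))

# With a cache: project only the new token and store its key/value.
k_cache = np.empty((steps, d_model))
v_cache = np.empty((steps, d_model))
with_cache = []
for t in range(steps):
    k_cache[t] = tokens[t] @ w_k
    v_cache[t] = tokens[t] @ w_v
    with_cache.append(attend(tokens[t] @ w_q, k_cache[: t + 1], v_cache[: t + 1]))

outputs_match = np.allclose(no_cache, with_cache)
```

The uncached loop performs a number of key/value projections that grows quadratically with the sequence length, while the cached loop performs one projection per token.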

In [None]:
#!python -m pip install optimum
!optimum-cli export onnx --model="google/pix2struct-docvqa-base" \
    --device "cpu" --atol=1e-3 --framework="pt" \
    --task="visual-question-answering-with-past" \
    "./export/original/docvqa/"

In [None]:
!ls ./export/original/docvqa/

## Sample Inference
Now we will run inference on the plain models, encoder_model.onnx and decoder_model.onnx.<br>
You will reuse all the existing routines for image pre-processing, tokenization, rendering the question on top of the image, etc. <br>
You only need to focus on "wiring" the model inputs and outputs to obtain the desired speedup.

In [None]:
import os
import re
import time
import numpy as np

from PIL import Image
from inference import run

In [None]:
#defines the questions
questions = ["What happens from 11:44am to 12:25am?",
             "What is the designated time for Questions and Answers?",
             "When is the Coffee Break?",
             "Who is giving the Introductory Remarks?",
             "Who is going to take part of the individual interviews?",
             "What time do the Exhibits Open?",
             "Where will the Coffee be served?",
             "Who is the TRRF Vice President?",
             "What is the designated time for TRRF Scientific Advisory Council Meeting?",
             "Who is the TRRF Treasurer?"             
           ]

In [None]:
#prepare inputs for the run wrapper present in Inference script
decoderModelPath = "./export/original/docvqa/decoder_model.onnx"
encoderModelPath = "./export/original/docvqa/encoder_model.onnx"

inputs = {}
inputs["encoderPath"] = encoderModelPath
inputs["decoderPath"] = decoderModelPath
inputs["decoderWithCachePath"] = decoderModelPath
inputs["pieceModelPath"] = "./export/original/docvqa/spiece.model"
inputs["fontPath"] = "./resources/Arial.ttf"
inputs["imagePath"] = "./resources/download.png"

In [None]:
# Take a look at the sample image
img = Image.open(inputs["imagePath"])
img

In [None]:
#perform inference
encoderTime = [] 
decoderTime = []
originalAnswers = [] 

for question in questions:
    temp_result = {}
    # Run the encoder and the plain decoder (no cache) for this question
    ques, ans, encoder_time, decoder_time, image_time = run(inputs, question, weightsType=32, cache=False, log=False)
    temp_result["decoded_question"] = ques
    temp_result["decoded_answer"] = ans
    temp_result["encoder_time"] = encoder_time
    temp_result["decoder_time"] = decoder_time

    encoderTime.append(encoder_time)
    decoderTime.append(decoder_time)

    # Normalize the answer: drop non-word characters and lowercase
    cleanedAnswer = re.sub(r'[^\w]', '', ans).lower()
    originalAnswers.append(cleanedAnswer)

    print(temp_result, end="\n\n")

## Your Task
Now the fun begins! You will make caching work in the decoder by using the decoder_model_merged.onnx you obtained in the previous steps. Rules:
* Feel free to modify existing files to accommodate the new caching feature.
* Code quality matters.
* Memory utilization matters.
* Good selection of data structures and algorithmic complexity matter.
* Documentation... well, you guessed it, it matters.
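
As a starting point for the wiring, here is a hedged sketch of the bookkeeping a cached decoding loop needs. It assumes the merged decoder follows the naming convention commonly produced by Optimum exports (`past_key_values.<layer>.<decoder|encoder>.<key|value>` inputs and matching `present.*` outputs) and that the cross-attention (`encoder`) entries only need to be captured on the first step. Both are assumptions: check the real input and output names of your exported graph before relying on them. The helper names `init_past` and `update_past` are hypothetical.

```python
import numpy as np

# Assumed naming convention; verify it against the exported graph.
PAST_PREFIX = "past_key_values."
PRESENT_PREFIX = "present."

def init_past(input_names, num_heads, head_dim, batch=1, dtype=np.float32):
    """Build placeholder past tensors for the first step, when no cache exists.

    Assumption: the no-cache branch ignores these values, so small zero
    tensors with a sequence length of 1 are enough.
    """
    return {
        name: np.zeros((batch, num_heads, 1, head_dim), dtype=dtype)
        for name in input_names
        if name.startswith(PAST_PREFIX)
    }

def update_past(past, output_names, outputs, first_step):
    """Feed each `present.*` output back as the matching past input.

    The dict is updated in place, so no tensor is copied.
    """
    for name, value in zip(output_names, outputs):
        if not name.startswith(PRESENT_PREFIX):
            continue
        # Assumption: encoder (cross-attention) keys/values do not change
        # after the first step, so later outputs for them are skipped.
        if ".encoder." in name and not first_step:
            continue
        past[PAST_PREFIX + name[len(PRESENT_PREFIX):]] = value
    return past
```

In a decoding loop, the `past` dict would be merged into the feed for each `session.run` call, with only the newest token passed as the decoder input after the first step. Merged Optimum exports typically also expect a boolean input that selects the cache branch; confirm its name in the graph rather than taking it from this sketch.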

In [None]:
#prepare inputs for the run wrapper present in Inference script -- keep almost same as above
decoderModelPath = "./export/original/docvqa/decoder_model.onnx"
decoderWithCacheModelPath = "./export/original/docvqa/decoder_model_merged.onnx"
encoderModelPath = "./export/original/docvqa/encoder_model.onnx"

inputs = {}
inputs["encoderPath"] = encoderModelPath
inputs["decoderPath"] = decoderModelPath
inputs["decoderWithCachePath"] = decoderWithCacheModelPath
inputs["pieceModelPath"] = "./export/original/docvqa/spiece.model"
inputs["fontPath"] = "./resources/Arial.ttf"
inputs["imagePath"] = "./resources/download.png"

In [None]:
#perform inference -- keep the same interface as before
# Results are stored under separate names so the baseline results above are not overwritten
cachedEncoderTime = []
cachedDecoderTime = []
cachedAnswers = []

for question in questions:
    temp_result = {}
    # Just change the cache to True
    ques, ans, encoder_time, decoder_time, image_time = run(inputs, question, weightsType=32, cache=True, log=False)
    temp_result["decoded_question"] = ques
    temp_result["decoded_answer"] = ans
    temp_result["encoder_time"] = encoder_time
    temp_result["decoder_time"] = decoder_time

    cachedEncoderTime.append(encoder_time)
    cachedDecoderTime.append(decoder_time)

    # Normalize the answer: drop non-word characters and lowercase
    cleanedAnswer = re.sub(r'[^\w]', '', ans).lower()
    cachedAnswers.append(cleanedAnswer)

    print(temp_result, end="\n\n")
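
To check whether caching pays off, the two runs can be compared. The sketch below assumes the decoder timings and cleaned answers of the uncached and cached runs are kept in separate lists; the function names and the sample values in the usage note are hypothetical, and the time unit is whatever `run` returns.

```python
from statistics import mean

def summarize_speedup(baseline_times, cached_times):
    """Return mean timings and the baseline/cached ratio (above 1 means faster)."""
    if not baseline_times or len(baseline_times) != len(cached_times):
        raise ValueError("both runs need the same, non-zero number of timings")
    base, cached = mean(baseline_times), mean(cached_times)
    return {"baseline_mean": base, "cached_mean": cached, "speedup": base / cached}

def mismatched_answers(baseline_answers, cached_answers):
    """Return the indices of questions whose answers differ between the runs."""
    return [
        i
        for i, (a, b) in enumerate(zip(baseline_answers, cached_answers))
        if a != b
    ]
```

For example, `summarize_speedup([2.0, 4.0], [1.0, 2.0])["speedup"]` evaluates to `2.0`, and an empty list from `mismatched_answers` means the cached decoder reproduced every baseline answer.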

In [None]:
import onnx
from google.protobuf.json_format import MessageToDict

print("Decoder Merged")
model = onnx.load("export/original/docvqa/decoder_model_merged.onnx")
for _input in model.graph.input:
    print(MessageToDict(_input))
print("=====================================")
print("=====================================")
print("Decoder")
model = onnx.load("export/original/docvqa/decoder_model.onnx")
for _input in model.graph.input:
    print(MessageToDict(_input))
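
The raw dictionaries printed above are verbose. As a hedged convenience, the sketch below condenses a `MessageToDict`-style input description into a name and a shape, and lists the inputs that exist only in the merged decoder, which are the ones the cache wiring has to supply. The helper names are hypothetical, and the key names (`tensorType`, `dimParam`, `dimValue`) assume the default camelCase output of `MessageToDict`.

```python
def describe_input(value_info):
    """Return (name, shape) for one MessageToDict-style graph input."""
    tensor = value_info.get("type", {}).get("tensorType", {})
    dims = tensor.get("shape", {}).get("dim", [])
    # Symbolic dimensions carry dimParam; fixed ones carry dimValue.
    shape = [d.get("dimParam", d.get("dimValue", "?")) for d in dims]
    return value_info["name"], shape

def extra_inputs(merged_inputs, plain_inputs):
    """Return (name, shape) pairs present only in the merged decoder."""
    plain_names = {v["name"] for v in plain_inputs}
    return [
        describe_input(v) for v in merged_inputs if v["name"] not in plain_names
    ]
```

With the models loaded as in the cell above, the two argument lists could be built with `[MessageToDict(i) for i in model.graph.input]` for each decoder.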

In [None]:
import onnx
import onnxruntime

options = onnxruntime.SessionOptions()
#options.log_severity_level = 0
# Disable all graph optimizations so the session runs the graph as exported
options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL
#options.enable_mem_pattern = True

#coso = onnxruntime.InferenceSession("./export/original/docvqa/decoder_model_merged.onnx")

# Load the merged decoder and create the session from its serialized bytes
model = onnx.load("./export/original/docvqa/decoder_model_merged.onnx")
coso = onnxruntime.InferenceSession(model.SerializeToString(), options)