In [24]:
import matplotlib.pyplot as plt
from PIL import Image
from tool.config import Cfg
from tool.translate import build_model, process_input, translate
import torch
import onnxruntime
import numpy as np

In [2]:
config = Cfg.load_config_from_file('./config/vgg-seq2seq.yml')
config['cnn']['pretrained']=False
config['device'] = 'cpu'
model, vocab = build_model(config)
weight_path = './weight/transformerocr.pth'

In [3]:
# load weight
model.load_state_dict(torch.load(weight_path, map_location=torch.device(config['device'])))
model = model.eval() 

## Export CNN part

In [4]:
def convert_cnn_part(img, save_path, model, max_seq_length=128, sos_token=1, eos_token=2): 
    with torch.no_grad(): 
        src = model.cnn(img)
        torch.onnx.export(model.cnn, img, save_path, export_params=True, opset_version=12, do_constant_folding=True, verbose=True, input_names=['img'], output_names=['output'], dynamic_axes={'img': {3: 'lenght'}, 'output': {0: 'channel'}})
    
    return src

In [5]:
img = torch.rand(1, 3, 32, 475)
src = convert_cnn_part(img, './weight/cnn.onnx', model)

graph(%img : Float(1, 3, 32, *, strides=[45600, 15200, 475, 1], requires_grad=0, device=cpu),
      %model.last_conv_1x1.weight : Float(256, 512, 1, 1, strides=[512, 1, 1, 1], requires_grad=1, device=cpu),
      %model.last_conv_1x1.bias : Float(256, strides=[1], requires_grad=1, device=cpu),
      %190 : Float(64, 3, 3, 3, strides=[27, 9, 3, 1], requires_grad=0, device=cpu),
      %191 : Float(64, strides=[1], requires_grad=0, device=cpu),
      %193 : Float(64, 64, 3, 3, strides=[576, 9, 3, 1], requires_grad=0, device=cpu),
      %194 : Float(64, strides=[1], requires_grad=0, device=cpu),
      %196 : Float(128, 64, 3, 3, strides=[576, 9, 3, 1], requires_grad=0, device=cpu),
      %197 : Float(128, strides=[1], requires_grad=0, device=cpu),
      %199 : Float(128, 128, 3, 3, strides=[1152, 9, 3, 1], requires_grad=0, device=cpu),
      %200 : Float(128, strides=[1], requires_grad=0, device=cpu),
      %202 : Float(256, 128, 3, 3, strides=[1152, 9, 3, 1], requires_grad=0, device=cpu),


## Export encoder part

In [6]:
def convert_encoder_part(model, src, save_path): 
    encoder_outputs, hidden = model.transformer.encoder(src) 
    torch.onnx.export(model.transformer.encoder, src, save_path, export_params=True, opset_version=11, do_constant_folding=True, input_names=['src'], output_names=['encoder_outputs', 'hidden'], dynamic_axes={'src':{0: "channel_input"}, 'encoder_outputs': {0: 'channel_output'}}) 
    return hidden, encoder_outputs

In [7]:
hidden, encoder_outputs = convert_encoder_part(model, src, './weight/encoder.onnx')

  "or define the initial states (h0/c0) as inputs of the model. ")


## Export decoder part

In [8]:
def convert_decoder_part(model, tgt, hidden, encoder_outputs, save_path):
    tgt = tgt[-1]
    
    torch.onnx.export(model.transformer.decoder,
        (tgt, hidden, encoder_outputs),
        save_path,
        export_params=True,
        opset_version=11,
        do_constant_folding=True,
        input_names=['tgt', 'hidden', 'encoder_outputs'],
        output_names=['output', 'hidden_out', 'last'],
        dynamic_axes={'encoder_outputs':{0: "channel_input"},
                    'last': {0: 'channel_output'}})

In [9]:
device = img.device
tgt = torch.LongTensor([[1] * len(img)]).to(device)

In [10]:
convert_decoder_part(model, tgt, hidden, encoder_outputs, './weight/decoder.onnx')

  assert (output == hidden).all()


## Load and check model

In [11]:
import onnx

In [12]:
cnn = onnx.load('./weight/cnn.onnx')
decoder = onnx.load('./weight/encoder.onnx')
encoder = onnx.load('./weight/decoder.onnx')

In [13]:
# confirm model has valid schema
onnx.checker.check_model(cnn)
onnx.checker.check_model(decoder)
onnx.checker.check_model(encoder)

In [14]:
# # Print a human readable representation of the graph
onnx.helper.printable_graph(encoder.graph)

'graph torch-jit-export (\n  %tgt[INT64, 1]\n  %hidden[FLOAT, 1x256]\n  %encoder_outputs[FLOAT, channel_inputx1x512]\n) initializers (\n  %attention.attn.bias[FLOAT, 256]\n  %embedding.weight[FLOAT, 233x256]\n  %fc_out.weight[FLOAT, 233x1024]\n  %fc_out.bias[FLOAT, 233]\n  %116[INT64, 1]\n  %117[INT64, 1]\n  %118[INT64, 1]\n  %119[INT64, 1]\n  %120[FLOAT, 768x256]\n  %121[FLOAT, 256x1]\n  %139[FLOAT, 1x768x768]\n  %140[FLOAT, 1x768x256]\n  %141[FLOAT, 1x1536]\n) {\n  %13 = Unsqueeze[axes = [0]](%tgt)\n  %14 = Gather(%embedding.weight, %13)\n  %15 = Shape(%encoder_outputs)\n  %16 = Constant[value = <Scalar Tensor []>]()\n  %17 = Gather[axis = 0](%15, %16)\n  %18 = Unsqueeze[axes = [1]](%hidden)\n  %22 = Unsqueeze[axes = [0]](%17)\n  %24 = Concat[axis = 0](%116, %22, %117)\n  %26 = Unsqueeze[axes = [0]](%17)\n  %28 = Concat[axis = 0](%118, %26, %119)\n  %29 = Shape(%24)\n  %30 = ConstantOfShape[value = <Tensor>](%29)\n  %31 = Expand(%18, %30)\n  %32 = Tile(%31, %28)\n  %33 = Transpose[pe

## Inference directly

In [15]:
img = Image.open('./sample/35944.png')
img = process_input(img, config['dataset']['image_height'], 
                config['dataset']['image_min_width'], config['dataset']['image_max_width'])  
img = img.to(config['device'])

In [20]:
s = translate(img, model)[0].tolist()
s = vocab.decode(s)
s

'Mâm non: 141 thí sinh'

## Inference with ONNX Runtime's Python API

In [23]:
# create inference session
cnn_session = onnxruntime.InferenceSession("./weight/cnn.onnx")
encoder_session = onnxruntime.InferenceSession("./weight/encoder.onnx")
decoder_session = onnxruntime.InferenceSession("./weight/decoder.onnx")

In [25]:
def translate_onnx(img, session, max_seq_length=128, sos_token=1, eos_token=2):
    """data: BxCxHxW"""
    cnn_session, encoder_session, decoder_session = session
    
    # create cnn input
    cnn_input = {cnn_session.get_inputs()[0].name: img}
    src = cnn_session.run(None, cnn_input)
    
    # create encoder input
    encoder_input = {encoder_session.get_inputs()[0].name: src[0]}
    encoder_outputs, hidden = encoder_session.run(None, encoder_input)
    translated_sentence = [[sos_token] * len(img)]
    max_length = 0

    while max_length <= max_seq_length and not all(
        np.any(np.asarray(translated_sentence).T == eos_token, axis=1)
    ):
        tgt_inp = translated_sentence
        decoder_input = {decoder_session.get_inputs()[0].name: tgt_inp[-1], decoder_session.get_inputs()[1].name: hidden, decoder_session.get_inputs()[2].name: encoder_outputs}

        output, hidden, _ = decoder_session.run(None, decoder_input)
        output = np.expand_dims(output, axis=1)
        output = torch.Tensor(output)

        values, indices = torch.topk(output, 1)
        indices = indices[:, -1, 0]
        indices = indices.tolist()

        translated_sentence.append(indices)
        max_length += 1

        del output

    translated_sentence = np.asarray(translated_sentence).T

    return translated_sentence

In [28]:
session = (cnn_session, encoder_session, decoder_session)
s = translate_onnx(np.array(img), session)[0].tolist()
s = vocab.decode(s)
s

'Mâm non: 141 thí sinh'