In [1]:
import os
import matplotlib.pyplot as plt
from PIL import Image
from tool.config import Cfg
from tool.translate import build_model, process_input, translate
import torch
import onnxruntime
import numpy as np

In [2]:
config = Cfg.load_config_from_file('./weights/custom_config_01102025.yml')
config['cnn']['pretrained']=False
config['device'] = 'cuda:0'
print(config['seq_modeling'])
model, vocab = build_model(config)
weight_path = './weights/transformerocr.pth'

transformer




In [3]:
# load weight
model.load_state_dict(torch.load(weight_path, map_location=torch.device(config['device'])))
model = model.eval() 

## Export CNN part

In [4]:
def convert_cnn_part(img, save_path, model, max_seq_length=128, sos_token=1, eos_token=2): 
    with torch.no_grad(): 
        src = model.cnn(img)
        torch.onnx.export(model.cnn, img, save_path, export_params=True, opset_version=12, do_constant_folding=True, verbose=True, input_names=['img'], output_names=['output'], dynamic_axes={'img': {3: 'lenght'}, 'output': {0: 'channel'}})
    
    return src

In [5]:
img = torch.rand(1, 3, 32, 475, device=torch.device(config['device']))
src = convert_cnn_part(img, './weight_onnx/cnn.onnx', model)

  torch.onnx.export(model.cnn, img, save_path, export_params=True, opset_version=12, do_constant_folding=True, verbose=True, input_names=['img'], output_names=['output'], dynamic_axes={'img': {3: 'lenght'}, 'output': {0: 'channel'}})


Exported graph: graph(%img : Float(1, 3, 32, *, strides=[45600, 15200, 475, 1], requires_grad=0, device=cuda:0),
      %model.last_conv_1x1.weight : Float(256, 512, 1, 1, strides=[512, 1, 1, 1], requires_grad=1, device=cuda:0),
      %model.last_conv_1x1.bias : Float(256, strides=[1], requires_grad=1, device=cuda:0),
      %onnx::Conv_180 : Float(64, 3, 3, 3, strides=[27, 9, 3, 1], requires_grad=0, device=cuda:0),
      %onnx::Conv_181 : Float(64, strides=[1], requires_grad=0, device=cuda:0),
      %onnx::Conv_183 : Float(64, 64, 3, 3, strides=[576, 9, 3, 1], requires_grad=0, device=cuda:0),
      %onnx::Conv_184 : Float(64, strides=[1], requires_grad=0, device=cuda:0),
      %onnx::Conv_186 : Float(128, 64, 3, 3, strides=[576, 9, 3, 1], requires_grad=0, device=cuda:0),
      %onnx::Conv_187 : Float(128, strides=[1], requires_grad=0, device=cuda:0),
      %onnx::Conv_189 : Float(128, 128, 3, 3, strides=[1152, 9, 3, 1], requires_grad=0, device=cuda:0),
      %onnx::Conv_190 : Float(128,

In [6]:
class EncoderWrapper(torch.nn.Module):
    def __init__(self, transformer):
        super().__init__()
        self.transformer = transformer

    def forward(self, src):
        return self.transformer.forward_encoder(src)


class DecoderWrapper(torch.nn.Module):
    def __init__(self, transformer):
        super().__init__()
        self.transformer = transformer

    def forward(self, tgt, memory):
        return self.transformer.forward_decoder(tgt, memory)



## Export encoder part

In [7]:
def convert_encoder_part(model, src, save_path):
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    encoder_module = EncoderWrapper(model.transformer)
    encoder_module.eval()
    encoder_module.to(src.device)
    with torch.no_grad():
        memory = encoder_module(src)
        torch.onnx.export(
            encoder_module,
            src,
            save_path,
            export_params=True,
            opset_version=11,
            do_constant_folding=True,
            input_names=['src'],
            output_names=['memory'],
            dynamic_axes={
                'src': {0: 'seq_len', 1: 'batch'},
                'memory': {0: 'seq_len', 1: 'batch'},
            },
        )
    return memory.detach()

In [None]:
memory = convert_encoder_part(model, src, './weight_onnx/encoder.onnx')

## Export decoder part

In [None]:
def convert_decoder_part(model, tgt, memory, save_path):
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    decoder_module = DecoderWrapper(model.transformer)
    decoder_module.eval()
    decoder_module.to(tgt.device)
    memory = memory.to(tgt.device)
    with torch.no_grad():
        torch.onnx.export(
            decoder_module,
            (tgt, memory),
            save_path,
            export_params=True,
            opset_version=11,
            do_constant_folding=True,
            input_names=['tgt', 'memory'],
            output_names=['logits', 'memory_out'],
            dynamic_axes={
                'tgt': {0: 'tgt_seq', 1: 'batch'},
                'memory': {0: 'src_seq', 1: 'batch'},
                'logits': {0: 'batch', 1: 'tgt_seq'},
                'memory_out': {0: 'src_seq', 1: 'batch'},
            },
        )

In [None]:
device = img.device
tgt = torch.full((1, img.shape[0]), 1, dtype=torch.long, device=device)

In [None]:
convert_decoder_part(model, tgt, memory, './weight_onnx/decoder.onnx')

  assert (output == hidden).all()


## Load and check model

In [None]:
import onnx

In [None]:
cnn = onnx.load('./weight/cnn.onnx')
decoder = onnx.load('./weight/encoder.onnx')
encoder = onnx.load('./weight/decoder.onnx')

In [None]:
# confirm model has valid schema
onnx.checker.check_model(cnn)
onnx.checker.check_model(decoder)
onnx.checker.check_model(encoder)

In [None]:
# # Print a human readable representation of the graph
onnx.helper.printable_graph(encoder.graph)

'graph torch-jit-export (\n  %tgt[INT64, 1]\n  %hidden[FLOAT, 1x256]\n  %encoder_outputs[FLOAT, channel_inputx1x512]\n) initializers (\n  %attention.attn.bias[FLOAT, 256]\n  %embedding.weight[FLOAT, 233x256]\n  %fc_out.weight[FLOAT, 233x1024]\n  %fc_out.bias[FLOAT, 233]\n  %116[INT64, 1]\n  %117[INT64, 1]\n  %118[INT64, 1]\n  %119[INT64, 1]\n  %120[FLOAT, 768x256]\n  %121[FLOAT, 256x1]\n  %139[FLOAT, 1x768x768]\n  %140[FLOAT, 1x768x256]\n  %141[FLOAT, 1x1536]\n) {\n  %13 = Unsqueeze[axes = [0]](%tgt)\n  %14 = Gather(%embedding.weight, %13)\n  %15 = Shape(%encoder_outputs)\n  %16 = Constant[value = <Scalar Tensor []>]()\n  %17 = Gather[axis = 0](%15, %16)\n  %18 = Unsqueeze[axes = [1]](%hidden)\n  %22 = Unsqueeze[axes = [0]](%17)\n  %24 = Concat[axis = 0](%116, %22, %117)\n  %26 = Unsqueeze[axes = [0]](%17)\n  %28 = Concat[axis = 0](%118, %26, %119)\n  %29 = Shape(%24)\n  %30 = ConstantOfShape[value = <Tensor>](%29)\n  %31 = Expand(%18, %30)\n  %32 = Tile(%31, %28)\n  %33 = Transpose[pe

## Inference directly

In [None]:
img = Image.open('./sample/35944.png')
img = process_input(img, config['dataset']['image_height'], 
                config['dataset']['image_min_width'], config['dataset']['image_max_width'])  
img = img.to(config['device'])

In [None]:
s = translate(img, model)[0].tolist()
s = vocab.decode(s)
s

'Mâm non: 141 thí sinh'

## Inference with ONNX Runtime's Python API

In [None]:
# create inference session
cnn_session = onnxruntime.InferenceSession("./weight_onnx/cnn.onnx")
encoder_session = onnxruntime.InferenceSession("./weight_onnx/encoder.onnx")
decoder_session = onnxruntime.InferenceSession("./weight_onnx/decoder.onnx")

In [None]:
def translate_onnx(img, session, max_seq_length=128, sos_token=1, eos_token=2):
    """data: BxCxHxW"""
    cnn_session, encoder_session, decoder_session = session

    cnn_input = {cnn_session.get_inputs()[0].name: img}
    src = cnn_session.run(None, cnn_input)[0]

    encoder_input = {encoder_session.get_inputs()[0].name: src}
    memory = encoder_session.run(None, encoder_input)[0]

    translated_sentence = [[sos_token] * img.shape[0]]
    max_length = 0

    while max_length <= max_seq_length and not all(
        np.any(np.asarray(translated_sentence).T == eos_token, axis=1)
    ):
        tgt_inp = np.asarray(translated_sentence, dtype=np.int64)
        decoder_input = {
            decoder_session.get_inputs()[0].name: tgt_inp,
            decoder_session.get_inputs()[1].name: memory,
        }

        logits, memory = decoder_session.run(None, decoder_input)
        output = torch.from_numpy(logits)

        values, indices = torch.topk(output, 1)
        indices = indices[:, -1, 0]
        indices = indices.tolist()

        translated_sentence.append(indices)
        max_length += 1

    translated_sentence = np.asarray(translated_sentence).T

    return translated_sentence

In [None]:
session = (cnn_session, encoder_session, decoder_session)
img_np = img.detach().cpu().numpy()
s = translate_onnx(img_np, session)[0].tolist()
s = vocab.decode(s)
s

'Mâm non: 141 thí sinh'