In [7]:
import os
import matplotlib.pyplot as plt
from PIL import Image
from tool.config import Cfg
from tool.translate import build_model, process_input, translate
import torch
import onnxruntime
import numpy as np

In [10]:
config = Cfg.load_config_from_file('ocr_model_cds_seq2seq/custom_config_seq2seq_12112025.yml')
config['cnn']['pretrained']=False
config['device'] = 'cpu'
model, vocab = build_model(config)
weight_path = 'ocr_model_cds_seq2seq/seq2seq.pth'



In [11]:
# load weight
model.load_state_dict(torch.load(weight_path, map_location=torch.device(config['device'])))
model = model.eval() 

## Export CNN part

In [4]:
def convert_cnn_part(img, save_path, model, max_seq_length=128, sos_token=1, eos_token=2): 
    with torch.no_grad(): 
        src = model.cnn(img)
        torch.onnx.export(model.cnn, img, save_path, export_params=True, opset_version=12, do_constant_folding=True, verbose=True, input_names=['img'], output_names=['output'], dynamic_axes={'img': {3: 'lenght'}, 'output': {0: 'channel'}})
    
    return src

In [6]:
img = torch.rand(1, 3, 32, 475)
src = convert_cnn_part(img, './converted_weights/cnn.onnx', model)

  torch.onnx.export(model.cnn, img, save_path, export_params=True, opset_version=12, do_constant_folding=True, verbose=True, input_names=['img'], output_names=['output'], dynamic_axes={'img': {3: 'lenght'}, 'output': {0: 'channel'}})


Exported graph: graph(%img : Float(1, 3, 32, *, strides=[45600, 15200, 475, 1], requires_grad=0, device=cpu),
      %model.last_conv_1x1.weight : Float(256, 512, 1, 1, strides=[512, 1, 1, 1], requires_grad=1, device=cpu),
      %model.last_conv_1x1.bias : Float(256, strides=[1], requires_grad=1, device=cpu),
      %onnx::Conv_180 : Float(64, 3, 3, 3, strides=[27, 9, 3, 1], requires_grad=0, device=cpu),
      %onnx::Conv_181 : Float(64, strides=[1], requires_grad=0, device=cpu),
      %onnx::Conv_183 : Float(64, 64, 3, 3, strides=[576, 9, 3, 1], requires_grad=0, device=cpu),
      %onnx::Conv_184 : Float(64, strides=[1], requires_grad=0, device=cpu),
      %onnx::Conv_186 : Float(128, 64, 3, 3, strides=[576, 9, 3, 1], requires_grad=0, device=cpu),
      %onnx::Conv_187 : Float(128, strides=[1], requires_grad=0, device=cpu),
      %onnx::Conv_189 : Float(128, 128, 3, 3, strides=[1152, 9, 3, 1], requires_grad=0, device=cpu),
      %onnx::Conv_190 : Float(128, strides=[1], requires_grad=0,

## Export encoder part

In [7]:
def convert_encoder_part(model, src, save_path): 
    encoder_outputs, hidden = model.transformer.encoder(src) 
    torch.onnx.export(model.transformer.encoder, src, save_path, export_params=True, opset_version=11, do_constant_folding=True, input_names=['src'], output_names=['encoder_outputs', 'hidden'], dynamic_axes={'src':{0: "channel_input"}, 'encoder_outputs': {0: 'channel_output'}}) 
    return hidden, encoder_outputs

In [8]:
hidden, encoder_outputs = convert_encoder_part(model, src, './converted_weights/encoder.onnx')

  torch.onnx.export(model.transformer.encoder, src, save_path, export_params=True, opset_version=11, do_constant_folding=True, input_names=['src'], output_names=['encoder_outputs', 'hidden'], dynamic_axes={'src':{0: "channel_input"}, 'encoder_outputs': {0: 'channel_output'}})


## Export decoder part

In [9]:
def convert_decoder_part(model, tgt, hidden, encoder_outputs, save_path):
    tgt = tgt[-1]
    
    torch.onnx.export(model.transformer.decoder,
        (tgt, hidden, encoder_outputs),
        save_path,
        export_params=True,
        opset_version=11,
        do_constant_folding=True,
        input_names=['tgt', 'hidden', 'encoder_outputs'],
        output_names=['output', 'hidden_out', 'last'],
        dynamic_axes={'encoder_outputs':{0: "channel_input"},
                    'last': {0: 'channel_output'}})

In [10]:
device = img.device
tgt = torch.LongTensor([[1] * len(img)]).to(device)

In [11]:
convert_decoder_part(model, tgt, hidden, encoder_outputs, './converted_weights/decoder.onnx')

  torch.onnx.export(model.transformer.decoder,
  assert (output == hidden).all()


## Load and check model

In [1]:
import onnx

In [2]:
cnn = onnx.load('./converted_weights/cnn.onnx')
decoder = onnx.load('./converted_weights/encoder.onnx')
encoder = onnx.load('./converted_weights/decoder.onnx')

In [3]:
# confirm model has valid schema
onnx.checker.check_model(cnn)
onnx.checker.check_model(decoder)
onnx.checker.check_model(encoder)

In [4]:
# # Print a human readable representation of the graph
onnx.helper.printable_graph(encoder.graph)

  onnx.helper.printable_graph(encoder.graph)


'graph main_graph (\n  %tgt[INT64, 1]\n  %hidden[FLOAT, 1x256]\n  %encoder_outputs[FLOAT, channel_inputx1x512]\n) initializers (\n  %attention.attn.bias[FLOAT, 256]\n  %embedding.weight[FLOAT, 233x256]\n  %fc_out.weight[FLOAT, 233x1024]\n  %fc_out.bias[FLOAT, 233]\n  %onnx::MatMul_118[FLOAT, 768x256]\n  %onnx::MatMul_119[FLOAT, 256x1]\n  %onnx::GRU_137[FLOAT, 1x768x768]\n  %onnx::GRU_138[FLOAT, 1x768x256]\n  %onnx::GRU_139[FLOAT, 1x1536]\n) {\n  %/Unsqueeze_output_0 = Unsqueeze[axes = [0]](%tgt)\n  %/embedding/Gather_output_0 = Gather(%embedding.weight, %/Unsqueeze_output_0)\n  %/attention/Shape_output_0 = Shape(%encoder_outputs)\n  %/attention/Constant_output_0 = Constant[value = <Scalar Tensor []>]()\n  %/attention/Gather_output_0 = Gather[axis = 0](%/attention/Shape_output_0, %/attention/Constant_output_0)\n  %/attention/Unsqueeze_output_0 = Unsqueeze[axes = [1]](%hidden)\n  %/attention/Constant_1_output_0 = Constant[value = <Tensor>]()\n  %/attention/Unsqueeze_1_output_0 = Unsqueez

## Inference directly

In [12]:
img = Image.open('./sample/35944.png')
img = process_input(img, config['dataset']['image_height'], 
                config['dataset']['image_min_width'], config['dataset']['image_max_width'])  
img = img.to(config['device'])

In [13]:
s = translate(img, model)[0].tolist()
s = vocab.decode(s)
s

'Mầm non: 141 thí sinh'

## Inference with ONNX Runtime's Python API

In [14]:
import onnxruntime

# Check available providers
available_providers = onnxruntime.get_available_providers()
print(f"Available ONNX Runtime providers: {available_providers}")

# Function to create session with fallback to CPU if GPU fails
def create_session_with_fallback(model_path, preferred_providers=None):
    """Create ONNX session, trying GPU first, falling back to CPU if GPU fails"""
    if preferred_providers is None:
        # Try GPU providers first
        if 'CUDAExecutionProvider' in available_providers:
            preferred_providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
        elif 'TensorrtExecutionProvider' in available_providers:
            preferred_providers = ['TensorrtExecutionProvider', 'CPUExecutionProvider']
        else:
            preferred_providers = ['CPUExecutionProvider']
    
    try:
        # Try to create session with preferred providers
        session = onnxruntime.InferenceSession(model_path, providers=preferred_providers)
        # Check which provider was actually used
        actual_provider = session.get_providers()[0]
        if 'CUDA' in actual_provider or 'Tensorrt' in actual_provider:
            print(f"✓ {model_path}: Using {actual_provider} (GPU)")
        else:
            print(f"✓ {model_path}: Using {actual_provider} (CPU)")
        return session
    except Exception as e:
        # If GPU fails, fall back to CPU
        print(f"⚠ GPU initialization failed for {model_path}, falling back to CPU")
        print(f"  Error: {str(e)[:100]}...")
        session = onnxruntime.InferenceSession(model_path, providers=['CPUExecutionProvider'])
        print(f"✓ {model_path}: Using CPUExecutionProvider (CPU)")
        return session

# Create inference sessions with automatic GPU/CPU fallback
print("\nCreating ONNX inference sessions...")
cnn_session = create_session_with_fallback("./converted_weights/cnn.onnx")
encoder_session = create_session_with_fallback("./converted_weights/encoder.onnx")
decoder_session = create_session_with_fallback("./converted_weights/decoder.onnx")
print("\nAll sessions created successfully!")

Available ONNX Runtime providers: ['TensorrtExecutionProvider', 'CUDAExecutionProvider', 'CPUExecutionProvider']

Creating ONNX inference sessions...
✓ ./converted_weights/cnn.onnx: Using CUDAExecutionProvider (GPU)
✓ ./converted_weights/encoder.onnx: Using CUDAExecutionProvider (GPU)
✓ ./converted_weights/decoder.onnx: Using CUDAExecutionProvider (GPU)

All sessions created successfully!


In [None]:
def translate_onnx(img, session, max_seq_length=128, sos_token=1, eos_token=2):
    """data: BxCxHxW"""
    cnn_session, encoder_session, decoder_session = session

    cnn_input = {cnn_session.get_inputs()[0].name: img}
    src = cnn_session.run(None, cnn_input)[0]

    encoder_input = {encoder_session.get_inputs()[0].name: src}
    encoder_outputs, hidden = encoder_session.run(None, encoder_input)

    translated_sentence = [[sos_token] * img.shape[0]]
    max_length = 0

    while max_length <= max_seq_length and not all(
        np.any(np.asarray(translated_sentence).T == eos_token, axis=1)
    ):
        # Get the last token from the translated sentence (tgt should be the last token)
        # translated_sentence[-1] is a list of tokens for all batch items at the last timestep
        tgt_inp = np.asarray(translated_sentence[-1], dtype=np.int64)
        # Ensure it's 1D: (batch_size,)
        if len(tgt_inp.shape) > 1:
            tgt_inp = tgt_inp.flatten()
        
        decoder_input = {
            decoder_session.get_inputs()[0].name: tgt_inp,
            decoder_session.get_inputs()[1].name: hidden,
            decoder_session.get_inputs()[2].name: encoder_outputs,
        }

        logits, hidden, _ = decoder_session.run(None, decoder_input)
        output = torch.from_numpy(logits)

        values, indices = torch.topk(output, 1)
        # Get the prediction for the current token
        # indices shape is (batch_size, 1) after topk, so we take the first (and only) column
        if len(indices.shape) == 2:
            indices = indices[:, 0]
        else:
            indices = indices.squeeze()
        indices = indices.tolist()

        translated_sentence.append(indices)
        max_length += 1

    translated_sentence = np.asarray(translated_sentence).T

    return translated_sentence

In [17]:
import time

session = (cnn_session, encoder_session, decoder_session)

start = time.perf_counter()

img_np = img.detach().cpu().numpy()
s = translate_onnx(img_np, session)[0].tolist()
s = vocab.decode(s)

end = time.perf_counter()
print(f"Time taken: {end - start} seconds")

s

Time taken: 0.1838204250088893 seconds


'Mầm non: 141 thí sinh'