In [1]:
# Standard-library and third-party dependencies used throughout the notebook.
import sys
import warnings
import onnx
import torch
import torch.onnx

# Silence every Python warning for the rest of the session.
# NOTE(review): this presumably also hides warnings emitted while tracing/exporting
# with torch.onnx.export — confirm that is intended.
warnings.filterwarnings("ignore")
# Add the parent directory to the import path so the `src` package below resolves.
sys.path.append('..')

from src.model import DefaultModel

In [2]:
# Prefer the GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cpu


## Load model networks and weights

In [3]:
# Text detector: YAML config and weights checkpoint.
detector_cfg, detector_model = (
    '../configs/craft_config.yaml',
    '../models/text_detector/craft_mlt_25k.pth',
)
# Text recognizer: YAML config and weights checkpoint.
recognizer_cfg, recognizer_model = (
    '../configs/star_config.yaml',
    '../models/text_recognizer/TPS-ResNet-BiLSTM-Attn-case-sensitive.pth',
)

In [4]:
# Build the combined model from the detector and recognizer configs/checkpoints.
# The exported sub-networks are reached later via `model.detector` and
# `model.recognizer`.
model = DefaultModel(
    detector_cfg,
    detector_model,
    recognizer_cfg,
    recognizer_model,
)

Loading weights from checkpoint (../models/text_detector/craft_mlt_25k.pth)
Loading weights from checkpoint (../models/text_recognizer/TPS-ResNet-BiLSTM-Attn-case-sensitive.pth)


# Detector

## Export Model
Dummy input layout: batch size × channels × height × width

In [5]:
# Random stand-in image batch used to trace the detector,
# laid out as (batch, channels, height, width).
detector_dummy_input = torch.randn((1, 3, 1280, 720))

In [6]:
# Run one forward pass on the dummy input as a sanity check before exporting;
# as the cell's last expression, the detector's return value is displayed.
# NOTE(review): consider wrapping this in torch.no_grad() and showing only the
# output shapes — the full tensor repr is very long.
model.detector(detector_dummy_input)

(tensor([[[[0.0010, 0.0002],
           [0.0085, 0.0020],
           [0.0010, 0.0002],
           ...,
           [0.0010, 0.0002],
           [0.0019, 0.0016],
           [0.0010, 0.0002]],
 
          [[0.0022, 0.0016],
           [0.0010, 0.0002],
           [0.0010, 0.0002],
           ...,
           [0.0010, 0.0002],
           [0.0010, 0.0002],
           [0.0010, 0.0002]],
 
          [[0.0010, 0.0002],
           [0.0010, 0.0002],
           [0.0010, 0.0002],
           ...,
           [0.0010, 0.0002],
           [0.0010, 0.0002],
           [0.0010, 0.0002]],
 
          ...,
 
          [[0.0010, 0.0002],
           [0.0010, 0.0002],
           [0.0010, 0.0002],
           ...,
           [0.0010, 0.0002],
           [0.0010, 0.0002],
           [0.0010, 0.0002]],
 
          [[0.0128, 0.0013],
           [0.0056, 0.0013],
           [0.0065, 0.0015],
           ...,
           [0.0010, 0.0002],
           [0.0010, 0.0002],
           [0.0020, 0.0006]],
 
          [[0.0065

In [7]:
# Destination file for the exported detector graph.
out_detector_model = "../models/text_detector/craft.onnx"

In [8]:
# Trace the detector with the dummy input and serialize it to ONNX.
# - export_params=True stores the trained weights inside the file.
# - do_constant_folding=True pre-computes constant sub-graphs at export time.
# - dynamic_axes: any axis not listed is fixed to the size seen in the dummy
#   trace. The input's batch, height and width are dynamic, so the output's
#   spatial axes are declared dynamic too; otherwise the exported output shape
#   stays pinned to what the 1280x720 dummy produced.
#   NOTE(review): output axes 1 and 2 are taken to be the spatial axes (the
#   displayed output is 4-D with an innermost size of 2, i.e. channels-last)
#   — confirm against the detector's forward().
# NOTE(review): the detector returns a tuple but only one output name is given;
# any further outputs keep exporter-generated names — confirm that is intended.
torch.onnx.export(model.detector,
                  detector_dummy_input,
                  out_detector_model,
                  export_params=True,
                  opset_version=13,
                  do_constant_folding=True,
                  input_names=['input'],
                  output_names=['output'],
                  dynamic_axes={'input': {0: 'batch_size', 2: 'height', 3: 'width'},
                                'output': {0: 'batch_size', 1: 'out_height', 2: 'out_width'}})

## Inspecting Model

In [9]:
# Read the exported detector graph back from disk.
onnx_model = onnx.load(out_detector_model)

# Validate the model's IR structure.
onnx.checker.check_model(onnx_model)

# Render the graph as readable text and show it.
detector_graph_text = onnx.helper.printable_graph(onnx_model.graph)
print(detector_graph_text)

graph torch-jit-export (
  %input[FLOAT, batch_sizex3xheightxwidth]
) initializers (
  %basenet.slice5.1.weight[FLOAT, 1024x512x3x3]
  %basenet.slice5.1.bias[FLOAT, 1024]
  %basenet.slice5.2.weight[FLOAT, 1024x1024x1x1]
  %basenet.slice5.2.bias[FLOAT, 1024]
  %conv_cls.0.weight[FLOAT, 32x32x3x3]
  %conv_cls.0.bias[FLOAT, 32]
  %conv_cls.2.weight[FLOAT, 32x32x3x3]
  %conv_cls.2.bias[FLOAT, 32]
  %conv_cls.4.weight[FLOAT, 16x32x3x3]
  %conv_cls.4.bias[FLOAT, 16]
  %conv_cls.6.weight[FLOAT, 16x16x1x1]
  %conv_cls.6.bias[FLOAT, 16]
  %conv_cls.8.weight[FLOAT, 2x16x1x1]
  %conv_cls.8.bias[FLOAT, 2]
  %299[FLOAT, 64x3x3x3]
  %300[FLOAT, 64]
  %302[FLOAT, 64x64x3x3]
  %303[FLOAT, 64]
  %305[FLOAT, 128x64x3x3]
  %306[FLOAT, 128]
  %308[FLOAT, 128x128x3x3]
  %309[FLOAT, 128]
  %311[FLOAT, 256x128x3x3]
  %312[FLOAT, 256]
  %314[FLOAT, 256x256x3x3]
  %315[FLOAT, 256]
  %317[FLOAT, 256x256x3x3]
  %318[FLOAT, 256]
  %320[FLOAT, 512x256x3x3]
  %321[FLOAT, 512]
  %323[FLOAT, 512x512x3x3]
  %324[FLOAT

# Recognizer

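**Known issue (unresolved upstream when this notebook was written):** exporting the recognizer at ONNX opset 13 fails because the `grid_sampler` operator is not supported by the exporter at that opset.
See https://github.com/pytorch/pytorch/issues/27212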

## Export Model
Dummy input layout: batch size × channels × height × width

In [10]:
# Stand-in inputs for tracing the recognizer: a random batch of 100
# single-channel 32x100 images and an all-zero int64 tensor of shape (100, 26).
# NOTE(review): the second tensor is presumably the text/token input of the
# decoder — confirm against the recognizer's forward().
recognizer_dummy_input = torch.randn((100, 1, 32, 100))
recognizer_dummy_text = torch.zeros((100, 26), dtype=torch.long)

In [11]:
# Run one forward pass of the recognizer's wrapped `.module` on both dummy
# tensors as a sanity check before exporting; as the cell's last expression,
# the returned tensor is displayed.
# NOTE(review): `.module` suggests the recognizer sits inside a wrapper such as
# DataParallel — confirm. Consider torch.no_grad() and showing only the shape.
model.recognizer.module(recognizer_dummy_input, recognizer_dummy_text)

tensor([[[ -6.9061,  -7.5613,  -5.4793,  ...,  -2.4262,  -4.9112,  -4.1666],
         [-11.9275, -13.6367, -11.6911,  ..., -10.5464, -11.6802, -11.4881],
         [-15.8177, -16.6920, -15.7986,  ..., -15.0340, -15.5713, -15.5707],
         ...,
         [-13.8261,  -2.1892, -13.1530,  ..., -13.9286, -13.0425, -13.7967],
         [-13.8304,  -2.1870, -13.1481,  ..., -13.9520, -13.0352, -13.8002],
         [-13.8197,  -2.1547, -13.1142,  ..., -13.9470, -13.0154, -13.7876]],

        [[ -8.2746,  -8.6355,  -6.8791,  ...,  -6.2013,  -8.2145,  -6.9954],
         [-13.3459, -18.1712, -12.9131,  ..., -12.4395, -12.8243, -11.7817],
         [-15.9471, -18.3970, -13.9544,  ..., -14.9989, -15.8966, -14.6564],
         ...,
         [-13.8625,  -3.7098, -11.3744,  ..., -13.9167, -12.4986, -14.3229],
         [-13.8717,  -3.6356, -11.3872,  ..., -13.9434, -12.5487, -14.3171],
         [-13.8459,  -3.6648, -11.3504,  ..., -13.9306, -12.5294, -14.2871]],

        [[ -9.2234,  -8.9115,  -7.3134,  ...

In [12]:
# Destination file for the exported recognizer graph.
out_recognizer_model = "../models/text_recognizer/star.onnx"

In [13]:
# Export the model
torch.onnx.export(model.recognizer.module,            
                  (recognizer_dummy_input, recognizer_dummy_text),
                  out_recognizer_model,
                  export_params=True,
                  opset_version=13,
                  do_constant_folding=True,
                  input_names = ['input'],
                  output_names = ['output'],
                  dynamic_axes={'input' : {0:'batch_size', 2:'height', 3:'width'},
                                'output' : {0:'batch_size'}})

RuntimeError: Exporting the operator grid_sampler to ONNX opset version 13 is not supported. Please feel free to request support or submit a pull request on PyTorch GitHub.

## Inspecting Model

In [None]:
# Read the exported recognizer graph back from disk.
onnx_model = onnx.load(out_recognizer_model)

# Validate the model's IR structure.
onnx.checker.check_model(onnx_model)

# Render the graph as readable text and show it.
recognizer_graph_text = onnx.helper.printable_graph(onnx_model.graph)
print(recognizer_graph_text)