<a href="https://colab.research.google.com/github/minh2210-hq/machine-learning-interview/blob/master/colabs/easyocr_tflite_report.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Setup

In [None]:
%%bash
pip install onnx
pip install onnxruntime
pip install git+https://github.com/onnx/onnx-tensorflow.git

In [None]:
import torch
import onnx
import onnxruntime
import numpy as np
import tensorflow as tf

import torch.nn.functional as F
from onnx_tf.backend import prepare

## Download the converted ONNX model and sample data

In [None]:
%%bash
wget https://github.com/tulasiram58827/ocr_tflite/raw/main/models/easyocr_onnx/sequence_modeller.onnx
wget https://raw.githubusercontent.com/tulasiram58827/ocr_tflite/main/data/en.txt
wget https://github.com/tulasiram58827/ocr_tflite/raw/main/data/feature_extracted.pt

## Utilities

The code below is taken from the [EasyOCR](https://github.com/JaidedAI/EasyOCR) repository.

In [None]:
dict_list = {}
dict_list['en'] = '/content/en.txt'
number = '0123456789'
symbol  = '!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~ '
chars = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyzÀÁÂÃÄÅÆÇÈÉÊËÍÎÑÒÓÔÕÖØÚÛÜÝÞßàáâãäåæçèéêëìíîïðñòóôõöøùúûüýþÿąęĮįıŁłŒœŠšųŽž'
characters = number+ symbol + chars

In [None]:
class CTCLabelConverter(object):
    """ Convert between text-label and text-index """

    def __init__(self, character, separator_list = {}, dict_pathlist = {}):
        # character (str): set of the possible characters.
        dict_character = list(character)

        self.dict = {}
        for i, char in enumerate(dict_character):
            self.dict[char] = i + 1

        self.character = ['[blank]'] + dict_character  # dummy '[blank]' token for CTCLoss (index 0)

        self.separator_list = separator_list
        separator_char = []
        for lang, sep in separator_list.items():
            separator_char += sep
        self.ignore_idx = [0] + [i+1 for i,item in enumerate(separator_char)]

        ####### latin dict
        if len(separator_list) == 0:
            dict_list = []
            for lang, dict_path in dict_pathlist.items():
                try:
                    with open(dict_path, "r", encoding = "utf-8-sig") as input_file:
                        word_count =  input_file.read().splitlines()
                    dict_list += word_count
                except:
                    pass
        else:
            dict_list = {}
            for lang, dict_path in dict_pathlist.items():
                with open(dict_path, "r", encoding = "utf-8-sig") as input_file:
                    word_count =  input_file.read().splitlines()
                dict_list[lang] = word_count

        self.dict_list = dict_list

    def encode(self, text, batch_max_length=25):
        """convert text-label into text-index.
        input:
            text: text labels of each image. [batch_size]
        output:
            text: concatenated text index for CTCLoss.
                    [sum(text_lengths)] = [text_index_0 + text_index_1 + ... + text_index_(n - 1)]
            length: length of each text. [batch_size]
        """
        length = [len(s) for s in text]
        text = ''.join(text)
        text = [self.dict[char] for char in text]

        return (torch.IntTensor(text), torch.IntTensor(length))

    def decode_greedy(self, text_index, length):
        """ convert text-index into text-label. """
        texts = []
        index = 0
        for l in length:
            t = text_index[index:index + l]

            char_list = []
            for i in range(l):
                # removing repeated characters and blank (and separator).
                if t[i] not in self.ignore_idx and (not (i > 0 and t[i - 1] == t[i])):
                    char_list.append(self.character[t[i]])
            text = ''.join(char_list)

            texts.append(text)
            index += l
        return texts
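To make the rule inside `decode_greedy` easier to follow, here is a minimal standalone sketch of CTC greedy decoding on a toy example. The function name `ctc_greedy_decode`, the five-entry alphabet, and the frame sequence are made up for illustration and are not part of EasyOCR; the collapsing rule (skip blanks, skip an index equal to the previous frame) is the same one used in the class above.

```python
def ctc_greedy_decode(indices, alphabet, blank=0):
    """Drop blanks and any index that repeats the previous frame."""
    chars = []
    prev = None
    for idx in indices:
        # Keep a frame only if it is not blank and differs from the frame before it.
        if idx != blank and idx != prev:
            chars.append(alphabet[idx])
        prev = idx
    return "".join(chars)


# Hypothetical toy alphabet; index 0 is the CTC blank, as in CTCLabelConverter.
alphabet = ["[blank]", "h", "e", "l", "o"]

# Per-frame argmax indices. The blank between the two runs of 3 is what
# allows the double "l" to survive the repeat-collapsing step.
frames = [1, 1, 2, 0, 3, 3, 0, 3, 4, 4]

decoded = ctc_greedy_decode(frames, alphabet)
print(decoded)
```

Without the blank separating the two runs of `3`, the repeated index would collapse into a single `l`.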

In [None]:
def post_process(preds, character, separator_list, dict_list, batch_size=1):
    result = []
    ignore_idx = []
    converter = CTCLabelConverter(character, separator_list, dict_list)
    preds_size = torch.IntTensor([preds.size(1)] * batch_size)
    ######## filter ignore_char, rebalance
    preds_prob = F.softmax(preds, dim=2)
    preds_prob = preds_prob.cpu().detach().numpy()
    preds_prob[:,:,ignore_idx] = 0.
    pred_norm = preds_prob.sum(axis=2)
    preds_prob = preds_prob/np.expand_dims(pred_norm, axis=-1)
    preds_prob = torch.from_numpy(preds_prob).float().to('cpu')
    # if decoder == 'greedy':
    # Select the max probability (greedy decoding), then decode indices to characters
    _, preds_index = preds_prob.max(2)
    preds_index = preds_index.view(-1)
    preds_str = converter.decode_greedy(preds_index.data, preds_size.data)
    preds_max_prob, _ = preds_prob.max(dim=2)
    for pred, pred_max_prob in zip(preds_str, preds_max_prob):
        confidence_score = pred_max_prob.cumprod(dim=0)[-1]
        result.append([pred, confidence_score.item()])
    return result
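The confidence score returned by `post_process` is the product, over all time steps, of the highest softmax probability at each step (the last element of `cumprod`). The sketch below reproduces that arithmetic in plain Python on made-up logits; `softmax` and `sequence_confidence` are illustrative helper names, not EasyOCR functions.

```python
import math


def softmax(logits):
    """Numerically stable softmax over one time step."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sequence_confidence(logits_per_step):
    """Multiply the top softmax probability of every time step."""
    confidence = 1.0
    for logits in logits_per_step:
        confidence *= max(softmax(logits))
    return confidence


# Made-up logits: 3 time steps, 2 classes, all equal.
# Each step has a top probability of 0.5, so the product is 0.5 ** 3.
toy_logits = [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]]
print(sequence_confidence(toy_logits))
```

Because every time step contributes a factor of at most 1, including steps that decode to blank or to a repeated character, this score tends to shrink as the sequence gets longer.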

## ONNX Inference

In [None]:
def to_numpy(tensor):
    return tensor.detach().cpu().numpy()

In [None]:
data = torch.load('/content/feature_extracted.pt')

In [None]:
data.shape

torch.Size([1, 41, 512])

In [None]:
# Load the ONNX sequence modeller
onnx_model = onnx.load("sequence_modeller.onnx")
# Check the model
onnx.checker.check_model(onnx_model)
ort_session = onnxruntime.InferenceSession("sequence_modeller.onnx")

ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(data)}
ort_outs = ort_session.run(None, ort_inputs)

final_prediction = ort_outs[0]

final_prediction = torch.from_numpy(final_prediction)
result = post_process(final_prediction, characters, {}, dict_list)
result

[['Available', 0.9877454042434692]]

**The ONNX output matches the original model's output**

## Convert to TensorFlow Graph

In [None]:
onnx_model = onnx.load('sequence_modeller.onnx')
tf_rep = prepare(onnx_model)
tf_rep.export_graph('sequence_modeller.pb')

**Conversion to TensorFlow graph successful**

## TFLite Conversion

In [None]:
loaded = tf.saved_model.load('sequence_modeller.pb')

concrete_func = loaded.signatures[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]

concrete_func.inputs[0].set_shape([1, 100, 512])
converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tf_lite_model = converter.convert()
# Write the converted model to disk, closing the file afterwards
with open('sequence_modeller.tflite', 'wb') as f:
    f.write(tf_lite_model)

ConverterError: ignored

**This is the actual problem: the TFLite conversion fails with a `ConverterError`**