In [None]:
'''
This notebook loads the easyocr module and converts the ocr model to inferentia compatible neuron model.
'''

In [14]:
import cv2
import torch
import easyocr
import torch_neuron
from matplotlib import pyplot as plt
from easyocr.detection import get_detector

def save_model():
    '''
    This function initiates the OCR module. The original OCR model is converted to neuron model and saved to local directory.
    '''
    ocr_reader = easyocr.Reader(['en'], detector=True, recognizer=False, gpu=False, 
                                download_enabled=False, model_storage_directory='model_file', 
                                user_network_directory='user_network')
    #load original model
    model = get_detector(r'model_file/craft_mlt_25k.pth')
    dummy_image = torch.zeros([1, 3, 224, 224], dtype=torch.float32)
    #compile original model to neuron model and save it
    neuron_model = torch.neuron.trace(model, example_inputs=[dummy_image])
    neuron_model.save("ocr_neuron.pt")
    print('Model Saved')
    return ocr_reader

In [15]:
def infer_image(image, ocr_reader):
    '''
    This function calls the customized detect function and passes the neuron model to generate inference.
    '''
    #load neuron model
    neuron_model = torch.jit.load('ocr_neuron.pt')
    resized_image = cv2.resize(image, (224,224), interpolation = cv2.INTER_AREA)
    #customized detect function that takes neuron model as input
    result = ocr_reader.detect(resized_image, net=neuron_model)
    return result

In [16]:
def plot_bbox(image, result):
    '''
    This functions plots bounding boxes around detected texts in an image.
    '''
    resized_image = cv2.resize(image, (224,224), interpolation = cv2.INTER_AREA)
    for coord in result[0][0]:
        image = cv2.rectangle(resized_image, (coord[0],coord[2]), (coord[1],coord[3]), (0, 255, 0), 2)
    plt.imshow(resized_image)
    plt.title('Detected Text')
    plt.show()

In [17]:
#save the model
neuron_model, ocr_reader = save_model()

Using CPU. Note: This module is much faster with a GPU.


AttributeError: module 'torch' has no attribute 'neuron'

In [None]:
image = cv2.imread('images/chinese.jpg')
#detect text
result = infer_image(image, ocr_reader)
#plot bounding box
plot_bbox(image,result)