In [None]:
import os,sys
import cv2
import math
import numpy as np
import itertools
import torch
import torch.nn.functional as F
from PIL import Image
sys.path.append('../ctpn.pytorch/')
sys.path.append('../crnn.pytorch/')
from ctpn import config
from ctpn.ctpn import CTPN_Model
from ctpn.utils import gen_anchor, transform_bbox, clip_bbox, filter_bbox, nms, TextProposalConnectorOriented
import crnn
from config import cfg
import matplotlib.pyplot as plt

In [None]:
def get_text_boxes(image, display = True, prob_thresh = 0.5):
    """Detect text lines in an image with the CTPN model.

    Relies on the notebook-level globals ``model_ctpn`` and ``device``.

    Args:
        image: H x W x C array; it is transposed with ``(2, 0, 1)``, so a
            channel axis is required. Presumably a BGR image as loaded by
            cv2 -- confirm against ``config.IMAGE_MEAN``.
        display: when true, draw every detected quadrilateral on the
            returned image copy.
        prob_thresh: minimum foreground probability an anchor needs to be
            kept as a text proposal.

    Returns:
        A tuple ``(text, image_c, rescale_fac)``:
        ``text`` is the result of
        ``TextProposalConnectorOriented.get_text_lines`` (the drawing loop
        reads eight corner coordinates from the start of each row);
        ``image_c`` is a copy of the (possibly downscaled) input, with the
        boxes drawn on it when ``display`` is true;
        ``rescale_fac`` is the factor that maps detected coordinates back to
        the input image. It is 1.0 when the input was not downscaled.
    """
    h, w = image.shape[:2]
    # Downscale so that the longest side is at most 1000 px.
    rescale_fac = max(h, w) / 1000
    if rescale_fac > 1.0:
        h = int(h / rescale_fac)
        w = int(w / rescale_fac)
        image = cv2.resize(image, (w,h))
        h, w = image.shape[:2]
    else:
        # No resize happened, so detected coordinates are already in input
        # pixel space. Returning the raw ratio (< 1) here would make callers
        # shrink the boxes when they multiply by the factor.
        rescale_fac = 1.0
    image_c = image.copy()
    image = image.astype(np.float32) - config.IMAGE_MEAN
    image = torch.from_numpy(image.transpose(2, 0, 1)).unsqueeze(0).float().to(device)

    with torch.no_grad():
        cls, regr = model_ctpn(image)
        cls_prob = F.softmax(cls, dim=-1).cpu().numpy()
        regr = regr.cpu().numpy()
        # Anchors are generated for a feature map with a stride of 16 px.
        anchor = gen_anchor((int(h / 16), int(w / 16)), 16)
        bbox = transform_bbox(anchor, regr)
        bbox = clip_bbox(bbox, [h, w])

        # Keep only anchors whose foreground probability passes the threshold.
        fg = np.where(cls_prob[0, :, 1] > prob_thresh)[0]
        select_anchor = bbox[fg, :].astype(np.int32)
        select_score = cls_prob[0, fg, 1]
        keep_index = filter_bbox(select_anchor, 16)

        select_anchor = select_anchor[keep_index]
        select_score = select_score[keep_index]
        select_score = np.reshape(select_score, (select_score.shape[0], 1))
        # nms() receives the boxes with the score appended as a last column.
        nmsbox = np.hstack((select_anchor, select_score))
        keep = nms(nmsbox, 0.3)
        select_anchor = select_anchor[keep]
        select_score = select_score[keep]

        textConn = TextProposalConnectorOriented()
        text = textConn.get_text_lines(select_anchor, select_score, [h, w])

    if display:
        for line in text:
            pts = [int(v) for v in line[:8]]
            corners = [(pts[n], pts[n + 1]) for n in range(0, 8, 2)]
            # Same four edges as before: 0-1, 0-2, 3-1 and 2-3.
            for a, b in ((0, 1), (0, 2), (3, 1), (2, 3)):
                cv2.line(image_c, corners[a], corners[b], (0, 0, 255), 2)

    return text, image_c, rescale_fac
    
def load_image(image):
    """Turn an image crop into a normalised CRNN input array.

    The shorter side is resized to 32 px with the aspect ratio preserved.
    Crops that are at least as wide as they are tall are treated as
    horizontal text and transposed, so the long axis always comes first.
    Square crops used to be skipped by both resize branches; they now take
    the horizontal path.

    Args:
        image: 2-D or 3-D uint8 array holding the crop.

    Returns:
        float32 array of shape ``[1, 1, L, 32]`` scaled to ``[-1, 1]``,
        where ``L`` is the resized length of the long side.

    Raises:
        ValueError: if the crop has zero height or width.
    """
    h, w = image.shape[:2]
    if h == 0 or w == 0:
        # An empty crop used to fail with a ZeroDivisionError or a cv2 error.
        raise ValueError('load_image got an empty image of shape {}'.format(image.shape))

    horizontal = h <= w
    if horizontal:
        if h != 32:
            image = cv2.resize(image, (int(w * 32 / h), 32))
    elif w != 32:
        image = cv2.resize(image, (32, int(h * 32 / w)))

    # NOTE(review): PIL's convert('L') weights channels assuming RGB order,
    # while crops coming from cv2 are BGR -- confirm this matches how the
    # CRNN was trained before changing it.
    image = np.array(Image.fromarray(image).convert('L'))
    if horizontal:
        image = image.T  # [W,H]
    image = image.astype(np.float32) / 255.
    image -= 0.5
    image /= 0.5
    return image[np.newaxis, np.newaxis, :, :]  # [B,C,W,H]

def get_minAreaRect(image):
    """Fit a minimum-area rotated rectangle around the foreground of an image.

    The image is converted from BGR to grayscale, inverted and binarised
    with Otsu's threshold. The indices of all non-zero pixels are then
    passed to ``cv2.minAreaRect``, whose result is returned unchanged.

    Note that the points are given in ``np.where`` order, i.e.
    (row, column) rather than (x, y).
    """
    inverted = cv2.bitwise_not(cv2.cvtColor(image, cv2.COLOR_BGR2GRAY))
    _, binary = cv2.threshold(
        inverted, 0, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)
    rows, cols = np.nonzero(binary > 0)
    return cv2.minAreaRect(np.column_stack((rows, cols)))

def rotate_bound(image, angle):
    """Rotate ``image`` by ``-angle`` degrees about its centre.

    The output canvas is widened to ``h*|sin| + w*|cos|`` while the height
    is kept at the input height; borders are filled by replicating edge
    pixels and cubic interpolation is used.

    NOTE(review): the height is not enlarged the way the width is, so tall
    content can be cut off for large angles -- confirm this is intended.
    """
    height, width = image.shape[:2]
    center_x, center_y = width // 2, height // 2
    matrix = cv2.getRotationMatrix2D((center_x, center_y), -angle, 1.0)

    abs_cos = np.abs(matrix[0, 0])
    abs_sin = np.abs(matrix[0, 1])
    out_w = int((height * abs_sin) + (width * abs_cos))
    out_h = height

    # Shift the transform so the rotated content is centred on the new canvas.
    matrix[0, 2] += (out_w / 2) - center_x
    matrix[1, 2] += (out_h / 2) - center_y
    return cv2.warpAffine(
        image,
        matrix,
        (out_w, out_h),
        flags=cv2.INTER_CUBIC,
        borderMode=cv2.BORDER_REPLICATE,
    )

def rotate_image(img):
    """Deskew a BGR image using the orientation of its foreground pixels.

    The image is converted to grayscale, inverted and binarised with Otsu's
    threshold. The angle of the minimum-area rectangle around the non-zero
    pixels is turned into a correction in ``[-45, 45]`` degrees, and the
    image is rotated about its centre by that amount. The output has the
    same size as the input; borders are filled by replicating edge pixels.

    Args:
        img: 3-channel BGR image array.

    Returns:
        The rotated image, or an unmodified copy of ``img`` when no
        foreground pixels are found.
    """
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    gray = cv2.bitwise_not(gray)
    thresh = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)[1]
    coords = np.column_stack(np.where(thresh > 0))
    if coords.shape[0] == 0:
        # No foreground pixels: there is nothing to estimate an angle from,
        # and cv2.minAreaRect would fail on an empty point set.
        return img.copy()

    angle = cv2.minAreaRect(coords)[-1]
    # minAreaRect reports the angle in [-90, 0) on older OpenCV builds and
    # in (0, 90] from 4.5.1 on. Both are folded into the same correction;
    # the first and last branches are the original legacy handling.
    if angle < -45:
        angle = -(90 + angle)
    elif angle >= 45:
        angle = 90 - angle
    else:
        angle = -angle

    (h, w) = img.shape[:2]
    center = (w // 2, h // 2)
    M = cv2.getRotationMatrix2D(center, angle, 1.0)
    rotated = cv2.warpAffine(img, M, (w, h), flags=cv2.INTER_CUBIC, borderMode=cv2.BORDER_REPLICATE)
    return rotated

In [None]:
# Device shared by the model-loading cells and by get_text_boxes().
device = torch.device('cpu')
# NOTE(review): absolute, machine-specific checkpoint path -- consider
# deriving it from a configurable base directory.
weights_ctpn = '/notebook/aigo-ocr/ctpn.pytorch/weights/ctpn.pth'
model_ctpn = CTPN_Model().to(device)
# The checkpoint is a mapping; the weights are read from its
# 'model_state_dict' entry.
# NOTE(review): torch.load unpickles the file, so only load checkpoints
# from a trusted source.
model_ctpn.load_state_dict(torch.load(weights_ctpn, map_location=device)['model_state_dict'])
# Put the detector in evaluation mode. As the last expression of the cell,
# this also displays the module.
model_ctpn.eval()

In [None]:
# NOTE(review): absolute, machine-specific checkpoint path -- consider
# deriving it from a configurable base directory.
weights_crnn = '/notebook/aigo-ocr/crnn.pytorch/models/crnn.horizontal.061.pth'
# The recogniser's class list: its length sets the number of output classes,
# and the recognition cell maps predicted class ids back through it.
alpha = cfg.word.get_all_words()
# NOTE(review): unlike model_ctpn, this model is not moved to `device`;
# that is harmless while `device` is the CPU.
model_crnn = crnn.CRNN(num_classes=len(alpha))
# The weights are read from the checkpoint's 'model' entry.
# NOTE(review): torch.load unpickles the file, so only load checkpoints
# from a trusted source.
model_crnn.load_state_dict(torch.load(weights_crnn, map_location=device)['model'])
# Put the recogniser in evaluation mode. As the last expression of the cell,
# this also displays the module.
model_crnn.eval()

In [None]:
# NOTE(review): absolute, machine-specific input path -- consider making it
# configurable.
img_path = '/notebook/images/insurance/IMG_5620.jpg'
img = cv2.imread(img_path)
# cv2.imread returns None instead of raising when the file is missing or
# cannot be decoded; fail here with a clear message.
if img is None:
    raise FileNotFoundError('Could not read image: {}'.format(img_path))
# Deskew first, then detect text lines on the deskewed image.
img_rotated = rotate_image(img)
text, out_img, scale = get_text_boxes(img_rotated, prob_thresh=0.1)
plt.figure(figsize=(20,40))
# OpenCV images are BGR while matplotlib expects RGB; convert for display
# only, so out_img itself stays BGR.
plt.imshow(cv2.cvtColor(out_img, cv2.COLOR_BGR2RGB))

In [None]:
padding = 10
# get_text_boxes only downscales when its factor exceeds 1, so coordinates
# never need to be shrunk when mapping them back to the full-size image.
scale_back = max(scale, 1.0)
# The boxes were detected on the deskewed image, so crops must come from it
# too; cropping the un-rotated `img` misaligns them whenever the deskew
# angle is non-zero.
img_h, img_w = img_rotated.shape[:2]
for k in range(text.shape[0]):
    X = text[k][:8][::2]
    Y = text[k][:8][1::2]
    # Widen the box by `padding` horizontally and trim it by 5 px vertically.
    # Clamp to the image bounds, because a negative start index would wrap
    # around in the slice below.
    x1 = max(int(min(X) * scale_back) - padding, 0)
    y1 = max(int(min(Y) * scale_back) + 5, 0)
    x2 = min(int(max(X) * scale_back) + padding, img_w)
    y2 = min(int(max(Y) * scale_back) - 5, img_h)
    if x2 <= x1 or y2 <= y1:
        # The box collapsed after trimming or clamping; nothing to recognise.
        continue
    img_crop = img_rotated[y1:y2, x1:x2]
    img_crnn = load_image(img_crop)
    img_crnn = torch.FloatTensor(img_crnn)
    with torch.no_grad():
        predict = model_crnn(img_crnn)[0].detach().numpy()
    # Greedy decoding: take the best class per step, collapse consecutive
    # repeats, then drop the ' ' entries.
    label = np.argmax(predict[:], axis=1)
    label = [alpha[class_id] for class_id in label]
    label = [char for char, _ in itertools.groupby(label)]
    label = ''.join(label).replace(' ', '')
    print(label)
    plt.figure(figsize=(20,40))
    # Convert BGR -> RGB so the crop is shown with correct colours.
    plt.imshow(cv2.cvtColor(img_crop, cv2.COLOR_BGR2RGB))