In [None]:
!pip install PyMuPDF ultralytics

In [None]:
!pip install paddlepaddle-gpu paddleocr

In [None]:
import torch
from ultralytics import YOLO

# Fall back to CPU so the notebook still runs on a machine without a CUDA GPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# Fine-tuned YOLOv8 weights, relative to the notebook's working directory.
DETECTION_WEIGHTS = 'yolov8/runs/detect/yolov8s-custom-detection/weights/best.pt'
STRUCTURE_WEIGHTS = 'yolov8/runs/detect/yolov8s-custom-structure-all/weights/best.pt'

# Table detector: locates 'table' / 'table rotated' regions on a page image.
detection_model = YOLO(DETECTION_WEIGHTS).to(device)
# Structure recognizer: locates rows, columns, headers and spanning cells in a table crop.
structure_model = YOLO(STRUCTURE_WEIGHTS).to(device)

In [None]:
import cv2

# imgsz=640

# def table_detection(filename):
#     image = cv2.imread(filename)
#     pred = detection_model(image, size=imgsz)
#     pred = pred.xywhn[0]
#     result = pred.cpu().numpy()
#     return result

# def table_structure(filename):
#     imgsz = 1024
#     image = cv2.imread(filename)
#     pred = structure_model(image, size=imgsz)
#     pred = pred.xywhn[0]
#     result = pred.cpu().numpy()
#     return result

def _predict_xywhn(model, filename, imgsz):
    """Run a YOLO model on one image file and flatten the predicted boxes.

    Returns one list per box: ``[x_center, y_center, width, height, confidence,
    class_id]``. The first four values are YOLO ``xywhn`` coordinates, i.e.
    normalized by the image size.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``filename``. ``cv2.imread``
            returns None in that case. Passing None on to ``model.predict``
            would hide the real cause, because ultralytics may fall back to a
            default source.
    """
    image = cv2.imread(filename)
    if image is None:
        raise FileNotFoundError(f"cannot read image: {filename}")
    boxes = model.predict(image, imgsz=imgsz)[0].boxes.cpu().numpy()
    return [list(boxes.xywhn[i]) + [boxes.conf[i], boxes.cls[i]] for i in range(boxes.shape[0])]


def table_detection(filename, imgsz=800):
    """Detect tables on a page image.

    Args:
        filename: path of the page image.
        imgsz: inference size passed to ``detection_model.predict``.

    Returns:
        List of ``[x_center, y_center, w, h, confidence, class_id]`` (normalized
        coordinates); ``class_id`` indexes ``detection_class_names``.
    """
    return _predict_xywhn(detection_model, filename, imgsz)


def table_structure(filename, imgsz=1024):
    """Detect structure elements (rows, columns, headers, ...) in a table image.

    Args:
        filename: path of the (cropped) table image.
        imgsz: inference size passed to ``structure_model.predict``.

    Returns:
        List of ``[x_center, y_center, w, h, confidence, class_id]`` (normalized
        coordinates); ``class_id`` indexes ``structure_class_names``.
    """
    return _predict_xywhn(structure_model, filename, imgsz)

In [None]:
detection_class_names = ['table', 'table rotated']

def crop_image(filename, detection_result, pad=10):
    """Crop every detected table out of a page image and save each crop beside it.

    Args:
        filename: path of the page image.
        detection_result: rows of ``[x_center, y_center, w, h, score, class_id]``
            with normalized coordinates, as returned by ``table_detection``.
        pad: pixels added on every side of each box before cropping; the crop
            is clamped to the image bounds.

    Returns:
        Paths of the written crops, named ``<stem>_<index>_<class name><ext>``.

    Raises:
        FileNotFoundError: if the image cannot be read.
    """
    import os

    image = cv2.imread(filename)
    if image is None:
        raise FileNotFoundError(f"cannot read image: {filename}")
    height, width = image.shape[:2]
    # splitext copes with extensions of any length (".jpeg", ".webp", ...).
    stem, ext = os.path.splitext(filename)

    crop_filenames = []
    for i, result in enumerate(detection_result):
        class_id = int(result[5])
        # xywhn: (cx, cy) is the box centre; all four values are normalized.
        cx, cy, w, h = result[0], result[1], result[2], result[3]

        x1 = max(0, int((cx - w/2)*width) - pad)
        y1 = max(0, int((cy - h/2)*height) - pad)
        x2 = min(width, int((cx + w/2)*width) + pad)
        y2 = min(height, int((cy + h/2)*height) + pad)

        crop = image[y1:y2, x1:x2, :]
        crop_filename = f"{stem}_{i}_{detection_class_names[class_id]}{ext}"
        crop_filenames.append(crop_filename)
        cv2.imwrite(crop_filename, crop)
    return crop_filenames

In [None]:
# import os
# import json
# import requests

# url = 'https://[ALL-IN-ONE-AI-URL]/inference?endpoint_name=[ENDPOINT-NAME]'

# headers = {'Content-Type': 'image/png'}

# def ocr(img_path):
#     words_filename = img_path[:-4]+'_words.json'

#     if os.path.exists(words_filename):
#         return words_filename

#     if not img_path.endswith('png'):
#         headers['Content-Type'] = 'image/'+img_path.split('.')[-1]
#     # print(headers)

#     with open(img_path, 'rb') as f:
#         data = f.read()
#     response = requests.post(url, headers=headers, data=data)
#     # print(response)

#     result = json.loads(response.text)
#     # print(result)

#     new_result = []
#     for label, bbox in zip(result['label'], result['bbox']):
#         new_result.append({'bbox': [bbox[0][0], bbox[0][1], bbox[2][0], bbox[2][1]], 'text': label})

#     json.dump(new_result, open(words_filename, 'w'), ensure_ascii=False)
#     return words_filename

In [None]:
import os
import json
from paddleocr import PaddleOCR

ocr_model = PaddleOCR(use_angle_cls=True, lang="ch", det_limit_side_len=1920)  # TODO use large det_limit_side_len to get better OCR result

def ocr(img_path):
    """Run PaddleOCR on an image and save the recognized words as JSON.

    The file ``<stem>_words.json`` is written next to the image. It is always
    regenerated, with no caching. It holds a list of
    ``{'bbox': [x1, y1, x2, y2], 'text': str}`` entries. The bbox is built
    from corners 0 and 2 of each detected text polygon, presumably the
    top-left and bottom-right of PaddleOCR's 4-point box. An empty list is
    written when nothing is detected.

    Returns:
        Path of the written JSON file.
    """
    stem, _ = os.path.splitext(img_path)
    words_filename = stem + '_words.json'

    # Entry [0] is the result for the single image passed in; None when no text was found.
    page = ocr_model.ocr(img_path, cls=True)[0]
    words = []
    if page is not None:
        for line in page:
            bbox, text = line[0], line[1][0]
            words.append({'bbox': [bbox[0][0], bbox[0][1], bbox[2][0], bbox[2][1]], 'text': text})

    # ensure_ascii=False keeps CJK text readable, so the encoding must be explicit.
    with open(words_filename, 'w', encoding='utf-8') as f:
        json.dump(words, f, ensure_ascii=False)

    return words_filename

In [None]:
def visualize_structure(filename, structure_result, score_threshold=0.5):
    """Draw structure-recognition boxes on a copy of an image for inspection.

    Each box with ``score >= score_threshold`` is outlined in red and labelled
    ``"<index>-<class_id>"``. The result is saved as ``<stem>_structure<ext>``
    next to the input.

    Args:
        filename: path of the table image.
        structure_result: rows of ``[x_center, y_center, w, h, score, class_id]``
            with normalized coordinates, as returned by ``table_structure``.
        score_threshold: minimum confidence for a box to be drawn.

    Returns:
        Path of the annotated image.

    Raises:
        FileNotFoundError: if the image cannot be read.
    """
    import os

    image = cv2.imread(filename)
    if image is None:
        raise FileNotFoundError(f"cannot read image: {filename}")
    height, width = image.shape[:2]
    red = (0, 0, 255)  # OpenCV colours are BGR

    for i, result in enumerate(structure_result):
        score = float(result[4])
        if score < score_threshold:
            continue
        class_id = int(result[5])
        # xywhn: (cx, cy) is the box centre; all four values are normalized.
        cx, cy, w, h = result[0], result[1], result[2], result[3]

        x1 = int((cx - w/2)*width)
        y1 = int((cy - h/2)*height)
        x2 = int((cx + w/2)*width)
        y2 = int((cy + h/2)*height)

        cv2.rectangle(image, (x1, y1), (x2, y2), color=red)
        cv2.putText(image, f"{i}-{class_id}", (x1-10, y1), cv2.FONT_HERSHEY_SIMPLEX, fontScale=1, color=red)

    stem, ext = os.path.splitext(filename)
    new_filename = stem + '_structure' + ext
    cv2.imwrite(new_filename, image)
    return new_filename

In [None]:
import json
import postprocess

structure_class_names = [
    'table', 'table column', 'table row', 'table column header',
    'table projected row header', 'table spanning cell', 'no object'
]
structure_class_map = {k: v for v, k in enumerate(structure_class_names)}
structure_class_thresholds = {
    "table": 0.5,
    "table column": 0.5,
    "table row": 0.5,
    "table column header": 0.5,
    "table projected row header": 0.5,
    "table spanning cell": 0.5,
    "no object": 10
}

def convert_stucture(words_filename, filename, structure_result):
    """Combine structure detections and OCR words into table cells.

    The steps are:

    1. Convert each normalized ``[x_center, y_center, w, h, score, class_id]``
       row into an object ``{'bbox': [x1, y1, x2, y2], 'score', 'label'}``
       with pixel coordinates.
    2. Pick the table box: the highest-scoring 'table' object, or the whole
       image when there is none.
    3. Keep the OCR words whose ``postprocess.iob`` with that box is >= 0.5.
    4. Pass everything to ``postprocess.objects_to_cells``.

    Args:
        words_filename: JSON list of ``{'bbox', 'text'}`` words written by ``ocr``.
        filename: image the detections refer to; only its size is used.
        structure_result: detections as returned by ``table_structure``.

    Returns:
        ``(table_structures, cells, confidence_score)`` from
        ``postprocess.objects_to_cells``.

    Raises:
        FileNotFoundError: if the image cannot be read.
    """
    image = cv2.imread(filename)
    if image is None:
        raise FileNotFoundError(f"cannot read image: {filename}")
    height, width = image.shape[:2]

    table_objects = []
    for result in structure_result:
        # xywhn: (cx, cy) is the box centre; all four values are normalized.
        cx, cy, w, h = result[0], result[1], result[2], result[3]
        table_objects.append({
            'bbox': [int((cx - w/2)*width), int((cy - h/2)*height),
                     int((cx + w/2)*width), int((cy + h/2)*height)],
            'score': float(result[4]),
            'label': int(result[5]),
        })

    table = {'objects': table_objects, 'page_num': 0}

    table_label = structure_class_map['table']
    table_candidates = [obj for obj in table_objects if obj['label'] == table_label]
    if table_candidates:
        table_bbox = list(max(table_candidates, key=lambda obj: obj['score'])['bbox'])
    else:
        # No 'table' object: in this pipeline the input is already a table crop,
        # so use the whole image rather than an arbitrary fixed-size box.
        table_bbox = [0, 0, width, height]

    with open(words_filename, 'r', encoding='utf-8') as f:
        page_tokens = json.load(f)
    tokens_in_table = [token for token in page_tokens if postprocess.iob(token['bbox'], table_bbox) >= 0.5]

    table_structures, cells, confidence_score = postprocess.objects_to_cells(table, table_objects, tokens_in_table, structure_class_names, structure_class_thresholds)

    return table_structures, cells, confidence_score

In [None]:
import numpy as np
import pandas as pd
from PIL import Image, ImageDraw, ImageFont

def _wrap_text(text, chars_per_line):
    """Hard-wrap ``text`` every ``chars_per_line`` characters, ending with a newline.

    Breaks by character count, not at spaces, so CJK text without spaces also
    wraps. At least one character goes on each line, so text is never dropped.
    """
    chars_per_line = max(1, chars_per_line)
    lines = [text[k:k + chars_per_line] for k in range(0, len(text), chars_per_line)]
    return '\n'.join(lines) + '\n'


def visualize_cells(filename, table_structures, cells, font_path="SimSong.ttc", font_size=10):
    """Render recognized cells and export their text to Excel.

    Writes, next to ``filename``:
      * ``<stem>_cells<ext>``: the input image with each cell outlined in green
        and labelled ``"<row>-<col>"``.
      * ``<stem>_reconstruct<ext>``: a white canvas of the same size with the
        cell grid and wrapped cell text redrawn. It is drawn with PIL so the
        TrueType font can render CJK text.
      * ``<stem>.xlsx``: the cell texts as a grid, with no header or index.

    The text of a cell spanning several rows or columns goes into its first
    row and first column.

    Args:
        filename: path of the table image.
        table_structures: dict with at least 'rows' and 'columns' lists.
        cells: cells with 'bbox', 'row_nums', 'column_nums' and 'spans'.
        font_path: TrueType font used for the reconstruction image.
        font_size: font size in pixels. Text wraps at about one glyph per
            ``font_size`` pixels of cell width.

    Returns:
        Path of the ``_cells`` image.

    Raises:
        FileNotFoundError: if the image cannot be read.
    """
    import os

    image = cv2.imread(filename)
    if image is None:
        raise FileNotFoundError(f"cannot read image: {filename}")
    height, width = image.shape[:2]
    stem, ext = os.path.splitext(filename)

    canvas = Image.new('RGB', (width, height), (255, 255, 255))
    draw = ImageDraw.Draw(canvas)
    font = ImageFont.truetype(font_path, font_size, encoding="utf-8")

    num_cols = len(table_structures['columns'])
    num_rows = len(table_structures['rows'])
    data_rows = [['' for _ in range(num_cols)] for _ in range(num_rows)]

    for cell in cells:
        bbox = cell['bbox']
        x1, y1, x2, y2 = int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3])
        row_num = cell['row_nums'][0]
        col_num = cell['column_nums'][0]
        label = f"{row_num}-{col_num}"
        text = ''.join(span['text'] for span in cell['spans'] if 'text' in span)
        data_rows[row_num][col_num] = text

        # Annotated copy of the input (OpenCV, BGR colours).
        cv2.rectangle(image, (x1, y1), (x2, y2), color=(0, 255, 0))
        cv2.putText(image, label, (x1, y1+30), cv2.FONT_HERSHEY_SIMPLEX, fontScale=1, color=(0, 0, 255))

        # Reconstruction (PIL, RGB colours): white fill, green outline.
        draw.rectangle([(x1, y1), (x2, y2)], (255, 255, 255), (0, 255, 0))
        draw.text((x1-20, y1), label, (255, 0, 0), font=font)
        draw.text((x1, y1), _wrap_text(text, (x2 - x1) // font_size), (0, 0, 255), font=font)

    new_filename = stem + '_cells' + ext
    cv2.imwrite(new_filename, image)

    reconstruct_filename = stem + '_reconstruct' + ext
    cv2.imwrite(reconstruct_filename, cv2.cvtColor(np.asarray(canvas), cv2.COLOR_RGB2BGR))

    pd.DataFrame(data_rows).to_excel(stem + '.xlsx', index=False, header=False)
    return new_filename

In [None]:
import xml.etree.ElementTree as ET

def cells_to_html(cells):
    """Serialize recognized table cells to an HTML ``<table>`` string.

    Rows and cells:

    * Cells are ordered by (first row number, first column number).
    * A new row starts whenever the first row number increases.
    * If the first cell of a row has a truthy ``'column header'`` value, the
      row is emitted as ``<thead><tr>`` with ``<th>`` cells.
    * Any other row is a ``<tr>`` of ``<td>`` cells.
    * Cells spanning several rows or columns get ``rowspan``/``colspan``.

    Text and styling:

    * A cell's text is each span's ``'text'`` followed by a newline.
    * Spans without ``'text'`` are skipped.
    * Inline styles draw 1px black borders.
    """
    # Equivalent to a stable sort by column, then by row.
    ordered = sorted(cells, key=lambda c: (min(c['row_nums']), min(c['column_nums'])))

    table = ET.Element("table")
    table.set('style', 'border-collapse: collapse;')
    current_row = -1
    row = None
    cell_tag = "td"

    for cell in ordered:
        this_row = min(cell['row_nums'])
        if this_row > current_row:
            current_row = this_row
            # Test the flag's value, not its presence: cells may carry the
            # key with False for ordinary rows.
            if cell.get('column header'):
                cell_tag = "th"
                row = ET.SubElement(ET.SubElement(table, "thead"), "tr")
            else:
                cell_tag = "td"
                row = ET.SubElement(table, "tr")
            row.set('style', 'border: 1px solid black;')

        attrib = {}
        colspan = len(cell['column_nums'])
        if colspan > 1:
            attrib['colspan'] = str(colspan)
        rowspan = len(cell['row_nums'])
        if rowspan > 1:
            attrib['rowspan'] = str(rowspan)

        tcell = ET.SubElement(row, cell_tag, attrib=attrib)
        tcell.set('style', 'border: 1px solid black; padding: 5px;')
        tcell.text = ''.join(span['text'] + '\n' for span in cell['spans'] if 'text' in span)

    return ET.tostring(table, encoding="unicode", short_empty_elements=False)

In [None]:
import os

sample_filename = 'zh_val_0.jpg'
# sample_filename = 'demo/image.png'

# 1. Find the tables on the page and save each one as its own crop.
sample_detection_result = table_detection(sample_filename)
sample_crop_filenames = crop_image(sample_filename, sample_detection_result)

for crop_filename in sample_crop_filenames:
    # 2. OCR the crop and recognize its row/column structure.
    words_filename = ocr(crop_filename)
    structure_result = table_structure(crop_filename)
    structure_filename = visualize_structure(crop_filename, structure_result)

    # 3. Merge words and structure into cells, then export them.
    table_structures, cells, confidence_score = convert_stucture(words_filename, crop_filename, structure_result)
    cells_filename = visualize_cells(crop_filename, table_structures, cells)

    html = cells_to_html(cells)
    html_filename = os.path.splitext(crop_filename)[0] + '.html'
    # Cell text may be non-ASCII (e.g. Chinese OCR output), so write UTF-8 explicitly.
    with open(html_filename, 'w', encoding='utf-8') as f:
        f.write(html)