<a href="https://colab.research.google.com/github/nakamura196/000_tools/blob/main/NDLTSR.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# NDLTSR (NDL Table Structure Recognition)

「資料画像に含まれる表の構造を認識するプログラム」を試すノートブックです。

https://github.com/ndl-lab/ndltsr

## セットアップ

In [None]:
# Install the packages needed to run the API server started below
# (`uvicorn api:app`; presumably a FastAPI app). "progress" is presumably
# required by the ndltsr code itself — TODO confirm against its requirements.
!pip install fastapi
!pip install uvicorn
!pip install progress

In [None]:
CONTENT_DIR = "/content"
%cd {CONTENT_DIR}

!git clone https://github.com/ndl-lab/ndltsr.git

PROJECT_DIR=f"/{CONTENT_DIR }/ndltsr"

%cd {PROJECT_DIR}

!wget -nc https://lab.ndl.go.jp/dataset/ndltsrmodel/model_last.pth -P {PROJECT_DIR}/exp/ctdet_mid/train_wireless_ndl/
!wget -nc https://lab.ndl.go.jp/dataset/ndltsrmodel/processor_last.pth -P {PROJECT_DIR}/exp/ctdet_mid/train_wireless_ndl/

## サーバの起動

In [None]:
import subprocess
import requests
import time

PORT = 8081
invocations_url = f"http://127.0.0.1:{PORT}/invocations"

%cd {PROJECT_DIR}/src

# api.py をバックグランドで実行
process = subprocess.Popen(['uvicorn', 'api:app', '--port', str(PORT)])

# バックグランドプロセスが実行中であることを確認
# print("api.py is running in the background...")

# バックグランドプロセスの終了を待つには、次の行のコメントを解除してください
# process.wait()

# 起動するまで待つ
while True:
  try:
    time.sleep(3)
    r = requests.get(f"{invocations_url}")
    status_code = r.status_code
    if status_code == 405:
      break
  except:
    pass

## 推論の実行

In [None]:
# @title サンプル画像のダウンロード
# Download the sample image at `url` to `img_path`; the inference and
# visualization cells read the image from `img_path`. Both values are Colab
# form fields (`@param`) and can be edited in the form.
img_path = "/content/ndltsr/src/sample.jpg" # @param {type:"string"}
url = "https://dl.ndl.go.jp/api/iiif/1046122/R0000005/1860,1370,1000,548/full/0/default.jpg" # @param {type:"string"}

!wget -O {img_path} {url}

In [None]:
import urllib.request
import msgpack
import json

# Read the sample image and wrap its raw bytes in a msgpack map {"img": ...};
# use_bin_type=True encodes the bytes with the msgpack "bin" type.
with open(img_path, "rb") as image_file:
    request_fields = {"img": image_file.read()}
payload = msgpack.packb(request_fields, use_bin_type=True)

# Build the request; supplying a body makes urllib send it as a POST.
req = urllib.request.Request(
    invocations_url,
    data=payload,
    headers={"Content-Type": "application/x-msgpack"},
)

# Send the request and decode the JSON response body.
with urllib.request.urlopen(req) as res:
    response_data = json.load(res)

## 可視化

In [None]:
from PIL import Image, ImageDraw, ImageFont
from IPython.display import display

font_path = "/usr/share/fonts/truetype/liberation/LiberationSans-Regular.ttf"

# Load the image and always convert it to RGB, so the red/blue annotations
# keep their colors for every input mode (a grayscale "L" image would turn
# them gray) and the result can always be saved as JPEG. convert() returns a
# copy, so the source file can be closed right away.
with Image.open(img_path) as src:
    image = src.convert('RGB')

draw = ImageDraw.Draw(image)

rectangles = response_data["center"]  # per cell: list of (x, y) corner points
logi_data = response_data["logi"]     # per cell: logical row/column indices

# Fonts cached by size so the font file is not re-read for every cell.
fonts = {}


def _font_for(size):
    """Return the annotation font at *size*, falling back to PIL's default."""
    if size not in fonts:
        try:
            fonts[size] = ImageFont.truetype(font_path, size)
        except OSError:
            # Font file not available (e.g. outside Colab): use the built-in font.
            fonts[size] = ImageFont.load_default()
    return fonts[size]


for rect, logi in zip(rectangles, logi_data):
    # Outline the cell.
    draw.polygon([tuple(point) for point in rect], outline='red')

    # Scale the label with the cell width (10% of it), with a minimum of 12.
    xs = [point[0] for point in rect]
    width = max(xs) - min(xs)
    font_size = max(int(width * 0.10), 12)

    # Write the logical indices just above the cell's first corner point.
    text_position = (rect[0][0], rect[0][1] - font_size)
    draw.text(text_position, str(logi), fill="blue", font=_font_for(font_size))

# Save the annotated image and render it inline in the notebook.
image.save('output_image.jpg')
display(image)

## その他：認識結果をHTMLの表に変換する

In [None]:
import os
import json
import numpy as np
from io import BytesIO
from PIL import Image
import base64
import re
import logging
import urllib
import msgpack
import json
import pandas as pd
import requests

def download_file(url, dst_path):
    """Download *url* and write its body to *dst_path*.

    The whole response is read before the destination file is opened, so a
    failed fetch never creates or truncates *dst_path*.

    return: True on success; False if the URL could not be fetched (the error
    is printed). Errors while writing *dst_path* still propagate.
    """
    # "import urllib" alone does not load these submodules; import them
    # explicitly rather than relying on another cell having done so.
    import urllib.error
    import urllib.request

    try:
        with urllib.request.urlopen(url) as web_file:
            data = web_file.read()
    except urllib.error.URLError as e:
        print(e)
        return False
    with open(dst_path, mode='wb') as local_file:
        local_file.write(data)
    return True
def check_iou(a, b, thr=0.6):
    """Return True if boxes *a* and *b* overlap by more than *thr*.

    Despite the name, this is not intersection-over-union: the overlap ratio
    is the intersection area divided by the area of the *smaller* box, so a
    small box lying entirely inside a large one scores 1.0.

    a: [xmin, ymin, xmax, ymax]
    b: [xmin, ymin, xmax, ymax]
    thr: the ratio must be strictly greater than this value.

    return: bool. A box with zero (or negative) area never matches.
    """
    b = np.asarray(b)
    a_area = (a[2] - a[0]) * (a[3] - a[1])
    b_area = (b[2] - b[0]) * (b[3] - b[1])
    min_area = min(a_area, b_area)
    if min_area <= 0:
        # Degenerate box: the ratio would be 0/0 (nan plus a numpy
        # RuntimeWarning), so report "no overlap" explicitly.
        return False
    intersection_w = np.maximum(0, np.minimum(a[2], b[2]) - np.maximum(a[0], b[0]))
    intersection_h = np.maximum(0, np.minimum(a[3], b[3]) - np.maximum(a[1], b[1]))
    return bool(intersection_w * intersection_h / min_area > thr)

def tdcreate(rpos, cpos, flagmap, text):
    """Render the <td> for the cell whose top-left grid slot is (rpos, cpos).

    flagmap: 2-D array holding a cell id per grid slot (-1 where no cell).
        The colspan is the run of slots in row *rpos*, starting at *cpos*,
        that carry the same id; the rowspan is the matching run downward in
        column *cpos*.
    text: cell text. It is HTML-escaped so OCR output containing "<" or "&"
        cannot break the table markup.

    return: '<td>text</td>' for a 1x1 cell, otherwise
        '<td colspan="C" rowspan="R">text</td>'.
    """
    import html  # local import: the notebook's import cell does not load it

    rowsize, colsize = flagmap.shape
    cell_id = flagmap[rpos, cpos]
    deltac = 0
    for ct in range(cpos, colsize):
        if flagmap[rpos, ct] != cell_id:
            break
        deltac += 1
    deltar = 0
    for rt in range(rpos, rowsize):
        if flagmap[rt, cpos] != cell_id:
            break
        deltar += 1
    body = html.escape(str(text), quote=False)
    if deltac == 1 and deltar == 1:
        return '<td>{}</td>'.format(body)
    return '<td colspan="{}" rowspan="{}">{}</td>'.format(deltac, deltar, body)

def extractfromocr(coordobj, rectminx, rectminy):
    """Turn OCR word records into de-duplicated, offset-corrected text boxes.

    coordobj: iterable of dicts with "xmin", "ymin", "xmax", "ymax" and
        "contenttext" keys.
    rectminx, rectminy: origin subtracted from every coordinate.

    Each box is normalized so that min <= max on both axes. A box that
    overlaps an already-kept box by more than 0.95 (see check_iou) is treated
    as a duplicate and dropped.

    return: list of {"bbox": [xmin, ymin, xmax, ymax], "text": str}, ordered
        by xmin, then ymin (ties keep input order).
    """
    kept_boxes = []
    entries = []
    for item in coordobj:
        x1 = int(item["xmin"]) - int(rectminx)
        y1 = int(item["ymin"]) - int(rectminy)
        x2 = int(item["xmax"]) - int(rectminx)
        y2 = int(item["ymax"]) - int(rectminy)
        left, right = sorted((x1, x2))
        top, bottom = sorted((y1, y2))
        box = [left, top, right, bottom]
        if any(check_iou(box, prev, thr=0.95) for prev in kept_boxes):
            continue
        kept_boxes.append(box)
        entries.append({"bbox": box, "text": item["contenttext"]})
    return sorted(entries, key=lambda entry: (entry["bbox"][0], entry["bbox"][1]))

def dupmerge(conv_atrobjlist, textbboxlist):
    """Merge duplicate LORE cells, then attach OCR text to each merged cell.

    conv_atrobjlist: list of [rowmin, rowmax, colmin, colmax, bbox], where
        bbox is [xmin, ymin, xmax, ymax].
    textbboxlist: list of {"bbox": [xmin, ymin, xmax, ymax], "text": str}
        (as produced by extractfromocr).

    Two cells whose bboxes overlap (check_iou with its default threshold) are
    merged: the bbox becomes the union box, and the logical span becomes the
    union span. Each cell absorbs each other cell at most once.

    return: list of [rowmin, rowmax, colmin, colmax, text]. text is the
        concatenation (in textbboxlist order) of every OCR box overlapping the
        cell by more than 0.1. The input lists are not modified.
    """
    # 1) Clean up the LORE output by merging overlapping duplicate cells.
    merged = []
    used = set()
    for idx1, cell1 in enumerate(conv_atrobjlist):
        if idx1 in used:
            continue
        lbox1 = list(cell1[:4])
        bbox1 = cell1[4]
        for idx2 in range(idx1 + 1, len(conv_atrobjlist)):
            if idx2 in used:
                # Already absorbed by an earlier cell; merging it again would
                # duplicate it across two output cells.
                continue
            cell2 = conv_atrobjlist[idx2]
            lbox2 = cell2[:4]
            bbox2 = cell2[4]
            if check_iou(bbox1, bbox2):
                used.add(idx2)
                bbox1 = [min(bbox1[0], bbox2[0]), min(bbox1[1], bbox2[1]),
                         max(bbox1[2], bbox2[2]), max(bbox1[3], bbox2[3])]
                # lbox is [rowmin, rowmax, colmin, colmax]: take the union of
                # the row span and of the column span.
                lbox1 = [min(lbox1[0], lbox2[0]), max(lbox1[1], lbox2[1]),
                         min(lbox1[2], lbox2[2]), max(lbox1[3], lbox2[3])]
        merged.append((lbox1, bbox1))

    # 2) Attach the text of every OCR box that overlaps each merged cell.
    reslist = []
    for lbox, bbox in merged:
        restext = "".join(
            textobj["text"]
            for textobj in textbboxlist
            if check_iou(bbox, textobj["bbox"], 0.1)
        )
        reslist.append(list(lbox) + [restext])
    return reslist
def extractfromlore(resultobj, textbboxlist):
    """Convert an NDLTSR (LORE) recognition result into an HTML table string.

    resultobj: the response from /invocations. Reads resultobj["center"]
        (per cell, at least four (x, y) corner points) and resultobj["logi"]
        (per cell, [row_a, row_b, col_a, col_b] logical indices; each pair is
        reordered here so that min <= max).
    textbboxlist: OCR text boxes ({"bbox": [...], "text": str}, see
        extractfromocr). Pass [] to build a table without text.

    return: '<table  border="1">...</table>' with one <tr> per logical row.
        Grid slots that no cell covers become empty <td></td>. Spanning cells
        are emitted once, at their top-left slot, with colspan/rowspan
        (see tdcreate).
    """
    bndobjlist = []
    atrobjlist = []
    axis_set_row = set()
    axis_set_col = set()
    for bndobj, logiobj in zip(resultobj["center"], resultobj["logi"]):
        # Axis-aligned bounding box of the cell's first four corner points.
        xs = [bndobj[k][0] for k in range(4)]
        ys = [bndobj[k][1] for k in range(4)]
        bndobjlist.append([int(min(xs)), int(min(ys)), int(max(xs)), int(max(ys))])

        rowmin, rowmax = sorted((int(logiobj[0]), int(logiobj[1])))
        colmin, colmax = sorted((int(logiobj[2]), int(logiobj[3])))
        axis_set_row.update((rowmin, rowmax))
        axis_set_col.update((colmin, colmax))
        atrobjlist.append([rowmin, rowmax, colmin, colmax])

    # Remap the logical indices that actually occur onto consecutive 0..N-1.
    row2idx = {val: idx for idx, val in enumerate(sorted(axis_set_row))}
    col2idx = {val: idx for idx, val in enumerate(sorted(axis_set_col))}

    conv_atrobjlist = [
        [row2idx[r0], row2idx[r1], col2idx[c0], col2idx[c1], bbox]
        for (r0, r1, c0, c1), bbox in zip(atrobjlist, bndobjlist)
    ]
    conv_atrobjlist = dupmerge(conv_atrobjlist, textbboxlist)
    sorted_data = sorted(conv_atrobjlist, key=lambda x: (x[0], x[2], x[1], x[3]))

    rowsize = len(row2idx)
    colsize = len(col2idx)
    # flagmap[r, c] = index into sorted_data of the cell covering slot (r, c),
    # or -1 if none. Where cells overlap, the later one wins. The spare last
    # row/column stays -1, which stops tdcreate's span scans at the edge.
    flagmap = np.zeros((rowsize + 1, colsize + 1)) - 1
    tmpid2text = {}
    for tmpidx, (rowmin, rowmax, colmin, colmax, text) in enumerate(sorted_data):
        flagmap[rowmin:rowmax + 1, colmin:colmax + 1] = tmpidx
        tmpid2text[tmpidx] = text

    emitted = set()
    rows_html = []
    for r in range(rowsize):
        cells_html = []
        for c in range(colsize):
            cell_id = flagmap[r, c]
            if cell_id == -1:
                cells_html.append('<td></td>')
            elif cell_id not in emitted:
                # First (top-left) slot of this cell in row-major order: emit
                # it once; tdcreate derives its colspan/rowspan.
                emitted.add(cell_id)
                cells_html.append(tdcreate(r, c, flagmap, tmpid2text[int(cell_id)]))
        rows_html.append("<tr>" + "".join(cells_html) + "</tr>")
    return '<table  border="1">' + "".join(rows_html) + "</table>"

In [None]:
from IPython.display import HTML

# Convert the recognition result into an HTML table. No OCR text boxes are
# supplied, so every cell is rendered without text.
textbboxlist = []
tablestr = extractfromlore(response_data, textbboxlist)

# Optional: convert the table to TSV with pandas, e.g.
#   from io import StringIO
#   df = pd.read_html(StringIO(tablestr))[0]
#   tsv_string = df.to_csv(index=None, header=False, sep="\t")

# Render the table inline (the cell's last expression is displayed).
HTML(tablestr)