In [None]:
from transformers import (
    AutoConfig,
    AutoModelForTokenClassification,
    AutoTokenizer,
)
from layoutlmft.data.utils import load_image, normalize_bbox
from detectron2.data.transforms import ResizeTransform, TransformList
import os
import json
import torch
import shutil
from glob import glob
import cv2
import pandas as pd
import numpy as np
from pdf2image import convert_from_path
from collections import Counter
import re
import shutil

## Load Model

In [None]:
# Checkpoint directory that the config, tokenizer and model are all loaded from.
model_path = "/media/minh/Storage/projects/EGS/InvoiceDataExtraction/unilm/20220516_outputs/checkpoint-4000"

# Tag vocabulary; its length fixes the number of output labels of the model.
class_labels = [
    'O',
    'B-OTHER', 'I-OTHER',
    'B-SUPPLIER_NAME', 'I-SUPPLIER_NAME',
    'B-SUPPLIER_ADDR', 'I-SUPPLIER_ADDR',
    'B-TOTALAMOUNT', 'I-TOTALAMOUNT',
]

config = AutoConfig.from_pretrained(model_path, num_labels=len(class_labels))
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
model = AutoModelForTokenClassification.from_pretrained(
    model_path,
    # Only treat the checkpoint as a TensorFlow one when the path mentions ".ckpt".
    from_tf=".ckpt" in model_path,
    config=config,
)

## Create TextractReader

In [None]:
class TextractReader:

    def __init__(self, textract_path, doc_dir):
        """Create the reader and immediately parse the Textract results.

        Args:
            textract_path: Path of the Textract JSON, forwarded to ``parse_textract``.
            doc_dir: Directory holding the source documents; file names found in
                the Textract JSON are joined onto it.
        """
        self.doc_dir = doc_dir
        # Filled by parse_textract: maps a file name to its per-page data.
        self.textract_data = {}
        self.parse_textract(textract_path)

    def parse_textract(self, textract_path):
        textract_data = pd.read_json(textract_path)
        
        for filename, res in textract_data.itertuples(index=False):
            
            # if filename not in ["9508645819.pdf"]: continue
            if res.get("JobStatus") not in ["SUCCEEDED", None]: 
                print(f"SKIP {filename} due to {res.get('JobStatus')} status")
                continue
            doc_path = os.path.join(self.doc_dir, filename)
            if not os.path.exists(doc_path):
                print(f"SKIP {filename} because {doc_path} does not exist")

            page_num = res["DocumentMetadata"]["Pages"]
            blocks = self.parse_blocks(res["Blocks"], page_num)
            imgs = self.read_images(filename)

            pages = {}
            for i in range(page_num):
                page_id = i+1
                page_df = blocks[blocks["page"] == page_id].copy()
                page_df = page_df.reset_index(drop=True)

                if len(page_df) == 0: 
                    continue
                height, width, _ = imgs[i].shape
                page_df[["x0", "x2", "xtl", "xbr"]] *= width
                page_df[["y0", "y2", "ytl", "ybr"]] *= height
                angle = self.check_orientation(page_df)
                page_df = self.refine_orientation(page_df, angle, width, height)
                page_df["page"] = page_id
                page_df = page_df[["xtl", "ytl", "xbr", "ybr", "text", "page"]]
                page_data = {"ocr": page_df, "angle": angle, "width": width, "height": height}
                pages[page_id] = page_data

            self.textract_data[filename] = pages
            
    def get_data(self, filename, return_img=True):
        if self.textract_data.get(filename) is None:
            return None
            
        pages = self.textract_data.get(filename)
        if return_img:
            imgs = self.read_images(filename)
            for page_id, page_data in pages.items():
                img = imgs[page_id - 1]
                angle = page_data["angle"]
                angle_map = {90: cv2.ROTATE_90_CLOCKWISE, 180: cv2.ROTATE_180, 270: cv2.ROTATE_90_COUNTERCLOCKWISE}
                if angle != 0:
                    img = cv2.rotate(img, angle_map[angle])
                pages[page_id]["img"] = img
        return pages

    
    def read_images(self, filename):
        doc_path = os.path.join(self.doc_dir, filename)
        imgs = []
        if filename[-3:] in ["jpg", "png"]:
            img = cv2.imread(doc_path)
            imgs.append(img)
        elif filename[-3:] in ["pdf"]:
            pil_imgs = convert_from_path(doc_path)
            for i, img in enumerate(pil_imgs):
                img = np.array(img)
                img = img[:, :, ::-1]
                imgs.append(img)
        return imgs
            
    def parse_blocks(self, blocks, page_num):
        block_df = pd.DataFrame()
        for block in blocks:
            if block["BlockType"] != "WORD": continue
            conf = block["Confidence"]
            text = block["Text"]
            page = block.get("Page")
            if page is None and page_num > 1:
                raise Exception("Page is None while number of pages > 1")
            if page is None:
                page = 1
            polygon = block["Geometry"]["Polygon"]
            X = [p['X'] for p in polygon]
            Y = [p['Y'] for p in polygon]
            data = {"x0": X[0], "y0": Y[0], "x2": X[2], "y2": Y[2],
                    "xtl": min(X), "ytl": min(Y), "xbr": max(X), "ybr": max(Y),
                    "text": text, "score": conf, "page": page}
            block_df = block_df.append(data, ignore_index=True)
        return block_df

    def check_orientation(self, df, ref_num=10):
        df = df.copy()
        df = df.sort_values(by=["score"], ascending=False)
        df = df[["x0", "y0", "x2", "y2"]]
        angles = []
        for x0, y0, x2, y2 in df.head(ref_num).itertuples(index=False):
            if x0 < x2 and y0 < y2:
                angles.append(0)
            elif x0 < x2 and y0 > y2:
                angles.append(90)
            elif x0 > x2 and y0 < y2:
                angles.append(270)
            else:
                angles.append(180)
        cnt = Counter(angles)
        return cnt.most_common()[0][0]

    def refine_orientation(self, textract_df, angle, width, height):
        if angle == 90:
            textract_df["xtl"] = height - textract_df["y0"]
            textract_df["xbr"] = height - textract_df["y2"]
            textract_df["ytl"] = textract_df["x0"]
            textract_df["ybr"] = textract_df["x2"]
        elif angle == 270:
            textract_df["xtl"] = textract_df["y0"]
            textract_df["xbr"] = textract_df["y2"]
            textract_df["ytl"] = width - textract_df["x0"]
            textract_df["ybr"] = width - textract_df["x2"]
        elif angle == 180:
            textract_df["xtl"] = width - textract_df["x0"]
            textract_df["xbr"] = width - textract_df["x2"]
            textract_df["ytl"] = height - textract_df["y0"]
            textract_df["ybr"] = height - textract_df["y2"]
        return textract_df

# Textract results (JSON) and the directory holding the matching source documents.
textract_path = "/media/minh/Storage/projects/EGS/InvoiceDataExtraction/sdmgr_inference/textract_extraction_filtered.json"
doc_dir = "/media/minh/Storage/projects/EGS/InvoiceDataExtraction/sdmgr_inference/20220425_FilteredData"
# Parsing happens inside the constructor, so this line does all the loading work.
textract_reader = TextractReader(textract_path, doc_dir)
# Number of documents that were parsed and kept.
print(len(textract_reader.textract_data))

## LayoutLMv2 Data Pipeline (SDMGR only)

In [None]:
# Output directory referenced by the (currently commented-out) debug code in this notebook.
debug_dir = "/media/minh/Storage/projects/EGS/InvoiceDataExtraction/debug"

def debug(img, df, output_path):
    """Draw the predicted word boxes on a copy of ``img`` and write it to disk.

    Args:
        img: Page image; it is copied, the caller's array is not drawn on.
        df: Frame with at least the columns ``class`` and ``bbox`` (as produced
            by ``refine``); ``bbox`` holds ``[x1, y1, x2, y2]``.
        output_path: Destination passed to ``cv2.imwrite``.
    """
    class_colors = {
        "TOTALAMOUNT": (0, 0, 255),
        "SUPPLIER_NAME": (255, 0, 0),
        "SUPPLIER_ADDR": (0, 255, 0),
    }
    # Classes without a dedicated colour (e.g. "OTHER") previously left `color` unbound.
    fallback_color = (128, 128, 128)

    img = img.copy()
    # Read the columns by name so the function does not depend on the column count of `df`.
    for class_name, box in zip(df["class"], df["bbox"]):
        x1, y1, x2, y2 = map(round, box)
        color = class_colors.get(class_name, fallback_color)
        img = cv2.rectangle(img, (x1, y1), (x2, y2), color, 2)
    cv2.imwrite(output_path, img)

def tokenize_and_align_labels(examples):
    """Tokenize pre-split words and expand the word boxes to one box per token.

    The module-level ``tokenizer`` is called on ``examples["tokens"]`` with
    ``max_length`` padding, truncation and overflowing chunks enabled. Every
    token of a chunk receives the box of the word it belongs to; tokens without
    a word id get ``[0, 0, 0, 0]``.

    Args:
        examples: Dict with the keys ``tokens``, ``norm_bboxes`` and ``image``.

    Returns:
        A tuple ``(encoded, overflow_mapping, word_ids)``: the tokenizer output
        with ``bbox`` and ``image`` added and ``overflow_to_sample_mapping``
        removed, the removed mapping, and the word ids of all chunks
        concatenated into one flat list.
    """
    encoded = tokenizer(
        examples["tokens"],
        padding="max_length",
        truncation=True,
        return_overflowing_tokens=True,
        is_split_into_words=True,
    )

    sample_of_chunk = encoded["overflow_to_sample_mapping"]
    token_boxes = []
    chunk_images = []
    flat_word_ids = []

    for chunk_idx in range(len(encoded["input_ids"])):
        chunk_word_ids = encoded.word_ids(batch_index=chunk_idx)
        flat_word_ids += chunk_word_ids

        sample_idx = sample_of_chunk[chunk_idx]
        word_boxes = examples["norm_bboxes"][sample_idx]

        # Tokens with no word id get a zero box; all others inherit their word's box.
        token_boxes.append([
            [0, 0, 0, 0] if word_idx is None else word_boxes[word_idx]
            for word_idx in chunk_word_ids
        ])
        chunk_images.append(examples["image"][sample_idx])

    encoded["bbox"] = token_boxes
    encoded["image"] = chunk_images
    encoded.pop("overflow_to_sample_mapping", None)

    return encoded, sample_of_chunk, flat_word_ids

def transform_image(image):
    """Resize ``image`` to 224x224 and return it as a tensor with its original size.

    Returns:
        ``(tensor, (width, height))`` where ``tensor`` is the resized image with
        its last axis moved to the front and ``(width, height)`` is taken from
        ``image.shape[1]`` and ``image.shape[0]``.
    """
    height, width = image.shape[0], image.shape[1]
    resize = TransformList([ResizeTransform(h=height, w=width, new_h=224, new_w=224)])
    # The copy makes the resized array writeable before it is wrapped in a tensor.
    resized = resize.apply_image(image).copy()
    return torch.tensor(resized).permute(2, 0, 1), (width, height)

def generate_example(img, page_df):
    """Build a one-sample batch for the tokenizer from a page image and its OCR rows.

    Args:
        img: Page image handed to ``transform_image``.
        page_df: Frame whose rows unpack as ``xtl, ytl, xbr, ybr, text`` plus one
            further column that is ignored.

    Returns:
        Dict with ``tokens``, ``bboxes``, ``norm_bboxes`` and ``image``, each
        wrapped in a single-element list.
    """
    img, size = transform_image(img)

    words_with_boxes = [
        (text, [xtl, ytl, xbr, ybr])
        for xtl, ytl, xbr, ybr, text, _ in page_df.itertuples(index=False)
    ]
    tokens = [text for text, _ in words_with_boxes]
    bboxes = [box for _, box in words_with_boxes]
    norm_bboxes = [normalize_bbox(box, size) for box in bboxes]

    return {"tokens": [tokens], "bboxes": [bboxes], "norm_bboxes": [norm_bboxes], "image": [img]}

def convert_to_tensor(inputs):
    """Convert every value of ``inputs`` into a single tensor.

    A value whose first element is a list goes through ``torch.tensor``; a value
    whose first element is a tensor is stacked with ``torch.stack``.

    Raises:
        Exception: The first element of a value is neither a list nor a tensor.
    """
    def as_tensor(name, values):
        first = values[0]
        if isinstance(first, list):
            return torch.tensor(values)
        if isinstance(first, torch.Tensor):
            return torch.stack(values)
        raise Exception(f"{name} is a list of type {type(first)}")

    return {name: as_tensor(name, values) for name, values in inputs.items()}

def get_class(text):
    """Return the part after the dash of a ``PREFIX-CLASS`` tag, else ``"OTHER"``.

    Only tags that split into exactly two pieces on ``"-"`` yield their second
    piece; every other tag maps to ``"OTHER"``.
    """
    pieces = text.split("-")
    if len(pieces) != 2:
        return "OTHER"
    return pieces[1]

def refine(words, boxes, tags, scores, word_ids):
    """Collapse token-level predictions into one row per tagged word.

    Rows with a missing value (e.g. a ``None`` word id) are dropped, only the
    first row of every word id is kept, and words tagged ``"O"`` are removed.

    Args:
        words: Word texts, indexed by word id.
        boxes: Word boxes, indexed by word id.
        tags: Predicted tag per token.
        scores: Score per token.
        word_ids: Word id per token.

    Returns:
        DataFrame with the columns ``word_id, tag, conf, prefix, class, word, bbox``.
    """
    token_df = pd.DataFrame({"word_id": word_ids, "tag": tags, "conf": scores})
    word_df = (
        token_df
        .dropna()
        .astype({"word_id": "int32"})
        .drop_duplicates(["word_id"])
    )
    word_df = word_df[word_df["tag"] != "O"].copy()

    def tag_prefix(tag):
        return tag.split("-")[0]

    def word_text(word_id):
        return words[word_id]

    def word_box(word_id):
        return boxes[word_id]

    word_df["prefix"] = word_df["tag"].map(tag_prefix)
    word_df["class"] = word_df["tag"].map(get_class)
    word_df["word"] = word_df["word_id"].map(word_text)
    word_df["bbox"] = word_df["word_id"].map(word_box)
    return word_df

def infer_one_page(page_data):
    """Run the token classifier on one page and return its labelled words.

    Args:
        page_data: Dict with ``"img"`` (page image) and ``"ocr"`` (word frame
            whose ``page`` column holds a single value).

    Returns:
        DataFrame with the columns ``xtl, ytl, xbr, ybr, text, label, conf,
        page`` (always present, also when nothing was found). Only words whose
        class appears in ``class_map`` are reported; other classes such as
        ``OTHER`` are skipped instead of raising ``KeyError``. ``conf`` is the
        maximum logit of the word's first token, not a probability.
    """
    columns = ["xtl", "ytl", "xbr", "ybr", "text", "label", "conf", "page"]
    class_map = {"SUPPLIER_NAME": "supplier_name", "SUPPLIER_ADDR": "supplier_addr", "TOTALAMOUNT": "totalAmount"}

    img = page_data["img"]
    page_df = page_data["ocr"]
    page_id = np.unique(page_df["page"]).item()

    example = generate_example(img, page_df)
    inputs, _, word_ids = tokenize_and_align_labels(example)
    inputs_t = convert_to_tensor(inputs)

    with torch.no_grad():
        outputs = model(**inputs_t)
        scores, preds = torch.max(outputs.logits, -1)
        # Flatten across chunks so the predictions line up with the flat word_ids list.
        scores = scores.reshape([-1]).tolist()
        preds = preds.reshape([-1]).tolist()

    pred_tags = [class_labels[i] for i in preds]
    words = example["tokens"][0]
    boxes = example["bboxes"][0]
    df = refine(words, boxes, pred_tags, scores, word_ids)
    # debug(img, df, os.path.join(debug_dir, "commercial-invoice.jpg"))

    records = []
    for conf, class_name, word, box in zip(df["conf"], df["class"], df["word"], df["bbox"]):
        label = class_map.get(class_name)
        if label is None:
            # B-OTHER / I-OTHER predictions have no output label.
            continue
        x1, y1, x2, y2 = map(round, box)
        records.append({"xtl": x1, "ytl": y1, "xbr": x2, "ybr": y2, "text": word, "label": label, "conf": conf, "page": page_id})
    return pd.DataFrame(records, columns=columns)

## Inference

In [None]:
# if os.path.isdir(debug_dir):
#     shutil.rmtree(debug_dir)
# os.makedirs(debug_dir)

def infer_one_doc(doc_data):
    """Run ``infer_one_page`` on every page of a document and stack the results.

    Args:
        doc_data: Dict mapping page id to the page dict expected by
            ``infer_one_page``.

    Returns:
        DataFrame with the columns ``xtl, ytl, xbr, ybr, text, label, conf,
        page``; empty (but with these columns) when no page produced a row.
    """
    columns = ["xtl", "ytl", "xbr", "ybr", "text", "label", "conf", "page"]
    page_frames = [infer_one_page(page_data) for page_data in doc_data.values()]
    # Empty frames carry no rows and would only disturb dtype inference in concat.
    page_frames = [frame for frame in page_frames if len(frame) > 0]
    if not page_frames:
        return pd.DataFrame(columns=columns)
    # concat aligns by column name; reindex restores the documented column order.
    return pd.concat(page_frames, ignore_index=True).reindex(columns=columns)

# Alternative input directory, kept for switching to a different document set.
# input_dir = "/media/minh/Storage/projects/EGS/InvoiceDataExtraction/test"
input_dir = "/media/minh/Storage/projects/EGS/InvoiceDataExtraction/sdmgr_inference/20220425_FilteredData"

# WARNING: an existing output directory is deleted, so earlier results are lost on every run.
output_dir = "20220530_layoutlmv2_outputs"
if os.path.isdir(output_dir):
    shutil.rmtree(output_dir)
os.makedirs(output_dir)
# Only used by the commented-out filter inside the loop below.
fnames = ["9498101535.jpg"]
# Extension match is case-sensitive: only lower-case .jpg/.png/.pdf files are processed.
filenames = [x for x in os.listdir(input_dir) if os.path.splitext(x)[1] in [".jpg", ".png", ".pdf"]]

for filename in filenames:
    # Uncomment to restrict the run to the files listed in `fnames`.
    # if filename not in fnames: continue
    prefix = os.path.splitext(filename)[0]
    doc_data = textract_reader.get_data(filename)
    if doc_data is None:
        print(f"Skip {filename} because Textract data is missing")
        continue
    doc_df = infer_one_doc(doc_data)
    # One CSV per document, named after the file stem; files sharing a stem overwrite each other.
    doc_df.to_csv(os.path.join(output_dir, f"{prefix}.csv"), sep=",", index=False)

## Post-processing

In [None]:
class DataExtractor:
    def __init__(self, country_currency_map_file, symbol_currency_map_file, textract_data):
        """Load the two currency lookup tables and keep a reference to the OCR data.

        Args:
            country_currency_map_file: CSV read into ``self.country_code``; its
                ``CountryCode`` column is cast to ``str``.
            symbol_currency_map_file: CSV read into ``self.money_symbol``.
            textract_data: Stored unchanged as ``self.textract_data``.
        """
        self.country_code = pd.read_csv(country_currency_map_file)
        # Cast to str so the codes can be used in substring tests against address text.
        self.country_code["CountryCode"] = self.country_code["CountryCode"].astype("str")
        self.money_symbol = pd.read_csv(symbol_currency_map_file)
        self.textract_data = textract_data
        
    def __call__(self, doc_kie_df, doc_ocr_data):
        res = {"supplier_name": "", "totalAmount": "", "currencyCode": ""}
        doc_kie_df = self.merge_words(doc_kie_df)
        total_amount, total_amount_page, total_amount_full_text, total_amount_box = self.extract_total_amount(doc_kie_df)
        res['supplier_name'] = self.extract_supplier_name(doc_kie_df, total_amount_page)
        res['totalAmount'] = total_amount
        address = self.extract_supplier_address(doc_kie_df, total_amount_page)
        res['currencyCode'] = self.extract_currency_all_page(doc_kie_df, doc_ocr_data, total_amount, total_amount_page, total_amount_full_text, total_amount_box, address)
        return res
        
    def process_money_text(self, text):
        m = re.search("\d+(?:.\d|,\d)*\d*", text)
        if m is not None:
            subtext = m.group()
            subtext = re.sub(",", ".", subtext)
            subtext = re.sub("[.](?=.*[.]|\d{3})", "", subtext)
            try:
                t = float(subtext)
            except:
                pass
            else:
                return str(float(subtext))
        return None

    def extract_total_amount(self, doc_kie_df):
        total_amount = ""
        total_amount_full_text = ""
        total_amount_box = None
        total_amount_page = None

        for _, page_kie_df in doc_kie_df.groupby("page"):
            max_score = 0
            for text, _, conf, xtl, ytl, xbr, ybr, page_id in page_kie_df[page_kie_df["label"] == "totalAmount"].itertuples(index=False):
                subtext = self.process_money_text(text)
                if subtext is not None and conf > max_score:
                    total_amount = subtext
                    total_amount_full_text = text
                    total_amount_box = [xtl, ytl, xbr, ybr]
                    total_amount_page = page_id
            if total_amount:
                break
        
        return total_amount, total_amount_page, total_amount_full_text, total_amount_box

    def extract_supplier_name(self, doc_kie_df, total_amount_page):
        supplier_name = ""
        name_pool = doc_kie_df[(doc_kie_df["page"] == total_amount_page) & (doc_kie_df["label"] == "supplier_name")]["text"].tolist()
        if name_pool:
            supplier_name = sorted(name_pool, key=lambda x : len(x.split()), reverse=True)[0]
        if not supplier_name:
            for page, post_page_df in doc_kie_df.groupby("page"):
                name_pool = post_page_df[post_page_df["label"] == "supplier_name"]["text"].tolist()
                if name_pool:
                    supplier_name = sorted(name_pool, key=lambda x : len(x.split()), reverse=True)[0]
                    break
        if "Google".lower() in supplier_name.lower():
            supplier_name = "Google"
        elif "Microsoft".lower() in supplier_name.lower():
            supplier_name = "Microsoft"
        elif "Grab".lower() in supplier_name.lower():
            supplier_name = "Grab"
        return supplier_name

    def extract_supplier_address(self, post_df, total_amount_page):
        addresses = post_df[(post_df["label"] == 'supplier_addr') & (post_df["page"] == total_amount_page)]["text"].tolist()
        address = " ".join(addresses)
        return address

    def open_json_file(self, filename, textract_data):
        file_name=[filename[:10]+'.jpg',filename[:10]+'.pdf']
        if file_name[0] in textract_data.keys():
            for i in textract_data[file_name[0]]['pages']:
                if file_name[0][:-4]+'_page_'+str(i['page_id'])==filename:
                    return i['data'], i['width'], i['height']
        if file_name[1] in textract_data.keys():
            for i in textract_data[file_name[1]]['pages']:
                if file_name[0][:-4]+'_page_'+str(i['page_id'])==filename:
                    return i['data'], i['width'], i['height']
                    
    def extract_currency_one_page(self, doc_ocr_data, page, total_amount, total_amount_full_text, total_amount_box, address, currencyCode):
        if currencyCode:
            return currencyCode
        check=False

        page_ocr_data = doc_ocr_data[page]
        if total_amount:
            if 'd' in total_amount_full_text:
                return 'VND'
            for symbol, Code in self.money_symbol.itertuples(index=False):
                if symbol in total_amount_full_text:
                    if symbol == '$':
                        check = True
                    else:
                        return Code

            bbox = total_amount_box
            
            bw = bbox[2] - bbox[0]
            bh = bbox[3] - bbox[1]
            x0 = bbox[0] - bw
            y0 = bbox[1] - 4*bh
            x2 = bbox[2] + bw
            y2 = bbox[3] + 4*bh
            
            
            list_currency_outside=[]
            list_currency_inside=[]
            
            for xtl, ytl, xbr, ybr, text, page in page_ocr_data["ocr"].itertuples(index=False):
                check_box = [xtl, ytl, xbr, ybr]
                
                for symbol, currencyCode in self.money_symbol.itertuples(index=False):
                    pattern = ['[a-zA-Z]+' + symbol, symbol + '[a-zA-Z]+']
                    if symbol not in text: continue
                    elif re.search(pattern[0], text) or re.search(pattern[1], text):
                        continue
                    else:
                        if symbol == '$':
                            check = True
                            
                        if check_box[0] >= x0 and check_box[2] <= x2:
                            list_currency_inside.append(currencyCode)
                        elif check_box[1] >= y0 and check_box[3] <= y2:
                            list_currency_inside.append(currencyCode)
                        else:
                            list_currency_outside.append(currencyCode)

            if check: 
                if 'SGD' in list_currency_inside or 'SGD' in list_currency_outside or 'Singapore'.lower() in address.lower():    
                    return 'SGD'
            if list_currency_inside:
                if check:
                    if len(set(list_currency_inside + list_currency_outside)) == 2 and '$' in list_currency_inside:
                        list_currency = set(list_currency_inside+list_currency_outside)
                        list_currency.remove('$')
                        return list_currency.pop()      
                return self.most_frequent(list_currency_inside)
            elif list_currency_outside:
                return self.most_frequent(list_currency_outside)
        else:
            list_currency_all=[]
            for xtl, ytl, xbr, ybr, text, page in page_ocr_data["ocr"].itertuples(index=False):
                for symbol, currencyCode in self.money_symbol.itertuples(index=False):
                    pattern=['[a-zA-Z]+' + symbol, symbol + '[a-zA-Z]+']
                    if symbol not in text: continue
                    elif re.search(pattern[0], text) or re.search(pattern[1], text):
                        continue
                    else:
                        if symbol == '$':
                            check = True
                        list_currency_all.append(currencyCode)
            if list_currency_all:
                if check: 
                    if 'SGD' in list_currency_all or 'Singapore'.lower() in address.lower():    
                        return 'SGD'
                    if len(set(list_currency_all)) == 2 and '$' in list_currency_all:
                        list_currency = set(list_currency_all)
                        list_currency.remove('$')
                        return list_currency.pop()
            return self.most_frequent(list_currency_all)            
        if address:
            for index in range(len(self.country_code)):
                if str(self.country_code.loc[index].Country) in address or str(self.country_code.loc[index].CountryCode) in address:
                    return self.country_code.loc[index].Code
        return ""

    def extract_currency_all_page(self, doc_kie_df, doc_ocr_data, total_amount, total_amount_page, total_amount_full_text, total_amount_box, address):
        currency = ""
        all_pages = set()
        if total_amount:
            currency = self.extract_currency_one_page(doc_ocr_data, total_amount_page, total_amount, total_amount_full_text, total_amount_box, address, currency)
            all_pages.add(total_amount_page)
            
        if not currency:
            for text, label, conf, xtl, ytl, xbr, ybr, page in doc_kie_df.itertuples(index=False):
                if page in all_pages:
                    continue
                if total_amount != "": 
                    if label == 'totalAmount' and total_amount == self.process_money_text(text):
                        bbox = [xtl, ytl, xbr, ybr]
                        currency = self.extract_currency_one_page(doc_ocr_data, page, total_amount, text, bbox, address, currency)
                        all_pages.add(page)
                        if currency != "":
                            break

        if not currency:
            for text, label, conf, xtl, ytl, xbr, ybr, page in doc_kie_df.itertuples(index=False):
                if page in all_pages:
                    continue
                currency = self.extract_currency_one_page(doc_ocr_data, page, "", "", None, address, currency)
                all_pages.add(page)
                if currency != "":
                    break
        if currency == "$":
            currency = "USD"
        return currency
    
    def most_frequent(self, l):
        if l:
            return Counter(l).most_common()[0][0]
        return None

    def is_adjacent(self, bi, bj):
        if bi["label"] != bj["label"] or bi["page"] != bj["page"]:
            return False
        min_h = min(bi["ybr"] - bi["ytl"], bj["ybr"] - bj["ytl"])
        min_w = min(bi["xbr"] - bi["xtl"], bj["xbr"] - bj["xtl"])
        dx = min(bi["xbr"], bj["xbr"]) - max(bi["xtl"], bj["xtl"])
        dy = min(bi["ybr"], bj["ybr"]) - max(bi["ytl"], bj["ytl"])

        dxr = dx/min_w
        dyr = dy/min_h

        if bi["label"] in ["supplier_name", "supplier_addr"] and dxr > 0.1 and dy > -min_h/2:
            return True
        if dyr > 0.2 and dx > -min_h:
            return True
        return False

    def dfs(self, adj_mat, i, groups, gid):
        if groups[i] > -1:
            return groups
        groups[i] = gid
        for j in range(len(adj_mat[i])):
            if adj_mat[i, j] == 1:
                groups = self.dfs(adj_mat, j, groups, gid)
        return groups

    def merge_words(self, df):
        res = pd.DataFrame(columns=["text", "label", "conf", "xtl", "ytl", "xbr", "ybr", "page"])
        if len(df) == 0:
            return res
        df = df.reset_index(drop=True)
        adj_mat = np.zeros((len(df), len(df)), dtype=np.int32)
        for i, bi in df.iterrows():
            for j, bj in df.iterrows():
                if self.is_adjacent(bi, bj):
                    adj_mat[i][j] = 1
        
        groups = [-1] * len(adj_mat)
        gid = 0
        for i in range(len(adj_mat)):
            groups = self.dfs(adj_mat, i, groups, gid)
            gid += 1
        df["group"] = groups
        for _, group_df in df.groupby("group"):
            label = np.unique(group_df["label"]).item()
            text = " ".join(group_df["text"].tolist())
            conf = group_df["conf"].mean()
            xtl = group_df["xtl"].min()
            ytl = group_df["ytl"].min()
            xbr = group_df["xbr"].max()
            ybr = group_df["ybr"].max()
            page_id = np.unique(group_df["page"]).item()
            res = res.append({"text": text, "label": label, "conf": conf, "xtl": xtl, "ytl": ytl, "xbr": xbr, "ybr": ybr, "page": page_id}, ignore_index=True)
        return res

        
# Post-processing extractor: arguments are the country->currency CSV, the symbol->currency CSV
# and the reader's parsed Textract data (passed by reference, not copied).
extractor = DataExtractor("/media/minh/Storage/projects/EGS/InvoiceDataExtraction/sdmgr_inference/country_currency_mapping.csv", "/media/minh/Storage/projects/EGS/InvoiceDataExtraction/sdmgr_inference/currency_symbol_mapping_2.csv", textract_reader.textract_data)

In [None]:
# Post-process the per-document prediction CSVs into one result table.
result_file = "20220530_layoutlmv2_results.csv"
csv_dir = "20220530_layoutlmv2_outputs"
filenames = os.listdir(doc_dir)

# fnames = ["9508645819"]
result_columns = ["code", "supplier_name", "totalAmount", "currencyCode"]
# Rows are collected in a list and turned into a frame once (DataFrame.append is gone in pandas 2).
result_rows = []

for filename in filenames:
    prefix = os.path.splitext(filename)[0]
    # if prefix not in fnames: continue
    csv_path = os.path.join(csv_dir, f"{prefix}.csv")
    if not os.path.exists(csv_path):
        print(f"Skip {filename} because {csv_path} does not exists")
        continue
    print(filename)
    doc_ocr_data = textract_reader.get_data(filename, return_img=False)
    # Read `text` as str and switch off NA detection: otherwise OCR words such as "NA" or "null"
    # are parsed as NaN and end up as the literal text "nan".
    doc_kie_df = pd.read_csv(csv_path, dtype={"text": str}, keep_default_na=False)
    doc_out = {"code": prefix}
    ext_out = extractor(doc_kie_df, doc_ocr_data)
    doc_out.update(ext_out)
    print(doc_out)
    result_rows.append(doc_out)

out_df = pd.DataFrame(result_rows, columns=result_columns)
out_df.to_csv(result_file, index=False)

In [None]:
# Join the predictions with the evaluation template and write the review spreadsheet.
result_file = "20220530_layoutlmv2_results.csv"
template_file = "/media/minh/Storage/projects/EGS/InvoiceDataExtraction/20220519_results_eval.xlsx"
df = pd.read_csv(result_file)
tp = pd.read_excel(template_file)

# Inner join: documents missing from either side are dropped from the evaluation.
merged = tp.merge(df, on="code", how="inner")
merged["pred_supplier_name"] = merged["supplier_name"]
merged["pred_totalAmount"] = merged["totalAmount"]
merged["pred_currencyCode"] = merged["currencyCode"]
# The three "Correct ..." columns are written empty here.
merged['Correct Amount (TRUE/FALSE)'] = ""
merged['Correct Supplier (TRUE/FALSE)'] = ""
merged['Correct Currency (TRUE/FALSE)'] = ""
merged = merged[['code', 'pred_supplier_name', 'true_supplier_name', 'Supplier name present (TRUE/FALSE)', 'Correct Supplier (TRUE/FALSE)', 'pred_totalAmount', 'true_totalAmount', 'total amount present (TRUE/FALSE)', 'Correct Amount (TRUE/FALSE)', 'Autopub', 'pred_currencyCode', 'true_currencyCode', 'Correct Currency (TRUE/FALSE)', 'Autopublish']]
merged.to_excel("20220530_layoutlmv2_results_eval.xlsx", index=False)
# Exact-match accuracy of total amount and currency over the joined rows.
# NOTE(review): this divides by len(merged); an empty join gives a division by zero.
print(np.sum(merged["pred_totalAmount"] == merged["true_totalAmount"])/len(merged))
print(np.sum(merged["pred_currencyCode"] == merged["true_currencyCode"])/len(merged))

In [None]:
# 915 minus the number of rows whose predicted currency equals the ground truth.
# NOTE(review): 915 is hard-coded, presumably the evaluation-set size; confirm it equals len(merged).
915 - np.sum(merged["pred_currencyCode"] == merged["true_currencyCode"])

In [None]:
# Print a summary of the evaluation template frame loaded with pd.read_excel.
tp.info()