In [3]:
from transformers import (
    AutoConfig,
    AutoTokenizer,
)
from layoutlmft import AutoModelForTokenClassification
from layoutlmft.data.utils import load_image, normalize_bbox
from detectron2.data.transforms import ResizeTransform, TransformList
import os
import json
import torch
import shutil
from glob import glob
import cv2
import pandas as pd
import numpy as np
from pdf2image import convert_from_path
from collections import Counter
import re
import shutil

In [4]:
# Fine-tuned LayoutLMv2 token-classification checkpoint for invoice fields.
# `class_labels` is indexed by predicted label id, so its order must match the
# label order the checkpoint was trained with.
model_path = "/home/minh/Storage/projects/EGS/InvoiceDataExtraction/unilm/20220516_outputs/checkpoint-4000"
class_labels = [
    'O',
    'B-OTHER', 'I-OTHER',
    'B-SUPPLIER_NAME', 'I-SUPPLIER_NAME',
    'B-SUPPLIER_ADDR', 'I-SUPPLIER_ADDR',
    'B-TOTALAMOUNT', 'I-TOTALAMOUNT',
]

# A path containing ".ckpt" is loaded as a TensorFlow checkpoint.
is_tf_checkpoint = ".ckpt" in model_path

config = AutoConfig.from_pretrained(model_path, num_labels=len(class_labels))
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
model = AutoModelForTokenClassification.from_pretrained(
    model_path, from_tf=is_tf_checkpoint, config=config
)

# model_path = "/home/minh/Storage/projects/EGS/VetDoc/20220713_models/checkpoint-1000"
# class_labels = ["O", "B-OTHER", "B-PROVIDER", "B-DATE", "B-NUMBER", "B-RECIPIENT", "I-OTHER", "I-PROVIDER", "I-DATE", "I-NUMBER", "I-RECIPIENT"]

# config = AutoConfig.from_pretrained(
#     model_path,
#     num_labels=len(class_labels),
# )
# tokenizer = AutoTokenizer.from_pretrained(
#     model_path,
#     use_fast=True,
# )
# model = AutoModelForTokenClassification.from_pretrained(
#     model_path,
#     from_tf=bool(".ckpt" in model_path),
#     config=config
# )

RuntimeError: Error(s) in loading state_dict for LayoutLMv2ForTokenClassification:
	size mismatch for classifier.weight: copying a param with shape torch.Size([7, 768]) from checkpoint, the shape in current model is torch.Size([11, 768]).
	size mismatch for classifier.bias: copying a param with shape torch.Size([7]) from checkpoint, the shape in current model is torch.Size([11]).

In [None]:
class TextractReader:
    """Load AWS Textract WORD blocks for a folder of documents.

    On construction the Textract JSON at ``textract_path`` is parsed. For every
    document that has a readable file in ``doc_dir``, the words of each page
    are converted to pixel coordinates in an upright (de-rotated) page frame.
    Results are cached in ``self.textract_data`` keyed by filename:

        {"pages": [{"data": DataFrame, "angle": int, "page_id": int}, ...],
         "page_num": int}

    Each ``data`` frame has the columns ``xtl, ytl, xbr, ybr, text, page``.
    """

    # Extensions read as one image, or rasterised page-by-page as a PDF (case-insensitive).
    IMAGE_EXTS = (".jpg", ".jpeg", ".png")
    PDF_EXTS = (".pdf",)
    # Column order of the frame returned by parse_blocks (kept even when there are no words).
    BLOCK_COLUMNS = ["x0", "y0", "x2", "y2", "xtl", "ytl", "xbr", "ybr", "text", "score", "page"]

    def __init__(self, textract_path, doc_dir):
        self.doc_dir = doc_dir
        self.textract_data = {}
        self.parse_textract(textract_path)

    def parse_textract(self, textract_path):
        """Parse every (filename, response) row of the Textract JSON into ``self.textract_data``.

        Documents whose job did not succeed, whose file is missing from
        ``doc_dir``, or whose pages cannot be read as images are skipped with
        a printed message.
        """
        textract_data = pd.read_json(textract_path, encoding="utf-8")
        for filename, res in textract_data.itertuples(index=False):
            doc_data = self._parse_document(filename, res)
            if doc_data is not None:
                self.textract_data[filename] = doc_data

    def _parse_document(self, filename, res):
        """Convert one Textract response into the cached per-page structure, or return None to skip it."""
        status = res.get("JobStatus")
        if status not in ["SUCCEEDED", None]:
            print(f"SKIP {filename} due to {status} status")
            return None
        doc_path = os.path.join(self.doc_dir, filename)
        if not os.path.exists(doc_path):
            # Without the source file there is no image to scale or rotate against.
            print(f"SKIP {filename} because {doc_path} does not exist")
            return None

        page_num = res["DocumentMetadata"]["Pages"]
        blocks = self.parse_blocks(res["Blocks"], page_num)
        imgs = self.read_images(filename)
        if len(imgs) < page_num or any(img is None for img in imgs):
            print(f"SKIP {filename} because its {page_num} page(s) could not be read as images")
            return None

        pages = []
        for page_id in range(1, page_num + 1):
            page_df = blocks[blocks["page"] == page_id].copy().reset_index(drop=True)
            if len(page_df) == 0:
                continue
            height, width, _ = imgs[page_id - 1].shape
            # Textract coordinates are ratios of the page size; scale them to pixels.
            page_df[["x0", "x2", "xtl", "xbr"]] *= width
            page_df[["y0", "y2", "ytl", "ybr"]] *= height
            angle = self.check_orientation(page_df)
            page_df = self.refine_orientation(page_df, angle, width, height)
            page_df["page"] = page_id
            page_df = page_df[["xtl", "ytl", "xbr", "ybr", "text", "page"]]
            pages.append({"data": page_df, "angle": angle, "page_id": page_id})
        return {"pages": pages, "page_num": page_num}

    def get_data(self, filename, return_img=True):
        """Return the cached pages of ``filename``, or None if it was not parsed.

        Each item has "anno" (the page's word frame). With ``return_img`` it
        also has "img": the page image rotated by the page's detected angle so
        that it matches the de-rotated word coordinates.
        """
        doc_data = self.textract_data.get(filename)
        if doc_data is None:
            return None

        imgs = []
        if return_img:
            imgs = self.read_images(filename)
            assert len(imgs) == doc_data["page_num"]
            angle_map = {90: cv2.ROTATE_90_CLOCKWISE, 180: cv2.ROTATE_180, 270: cv2.ROTATE_90_COUNTERCLOCKWISE}

        data = []
        for page in doc_data["pages"]:
            item = {"anno": page["data"]}
            if return_img:
                img = imgs[page["page_id"] - 1]
                if page["angle"] != 0:
                    img = cv2.rotate(img, angle_map[page["angle"]])
                item["img"] = img
            data.append(item)
        return data

    def read_images(self, filename):
        """Read a document as a list of BGR image arrays, one per page.

        Images yield one entry (``cv2.imread`` returns None if the file cannot
        be decoded). PDFs are rasterised page by page. Unsupported extensions
        yield an empty list.
        """
        doc_path = os.path.join(self.doc_dir, filename)
        ext = os.path.splitext(filename)[1].lower()
        if ext in self.IMAGE_EXTS:
            return [cv2.imread(doc_path)]
        if ext in self.PDF_EXTS:
            # Reverse the channel axis of the PIL (RGB) pages so they match cv2's BGR layout.
            return [np.array(img)[:, :, ::-1] for img in convert_from_path(doc_path)]
        return []

    def parse_blocks(self, blocks, page_num):
        """Collect Textract WORD blocks into a DataFrame with ``BLOCK_COLUMNS`` columns.

        (x0, y0) and (x2, y2) are the first and third polygon points. The
        ``*tl``/``*br`` columns are the axis-aligned bounding box. All
        coordinates stay in Textract's relative units.

        Raises:
            Exception: a block has no "Page" although the document has several pages.
        """
        records = []
        for block in blocks:
            if block["BlockType"] != "WORD":
                continue
            page = block.get("Page")
            if page is None:
                if page_num > 1:
                    raise Exception("Page is None while number of pages > 1")
                page = 1
            polygon = block["Geometry"]["Polygon"]
            xs = [p['X'] for p in polygon]
            ys = [p['Y'] for p in polygon]
            records.append({"x0": xs[0], "y0": ys[0], "x2": xs[2], "y2": ys[2],
                            "xtl": min(xs), "ytl": min(ys), "xbr": max(xs), "ybr": max(ys),
                            "text": block["Text"], "score": block["Confidence"], "page": page})
        # Building the frame once avoids quadratic DataFrame.append (removed in pandas 2.0).
        return pd.DataFrame(records, columns=self.BLOCK_COLUMNS)

    def check_orientation(self, df, ref_num=10):
        """Vote the page rotation (0/90/180/270) from the ``ref_num`` most confident words.

        Each word votes by the relative position of its first polygon point
        (x0, y0) versus its third (x2, y2).
        """
        top = df.sort_values(by=["score"], ascending=False)[["x0", "y0", "x2", "y2"]].head(ref_num)
        angles = []
        for x0, y0, x2, y2 in top.itertuples(index=False):
            if x0 < x2 and y0 < y2:
                angles.append(0)
            elif x0 < x2 and y0 > y2:
                angles.append(90)
            elif x0 > x2 and y0 < y2:
                angles.append(270)
            else:
                angles.append(180)
        return Counter(angles).most_common()[0][0]

    def refine_orientation(self, textract_df, angle, width, height):
        """Map pixel boxes into the frame of the page rotated by ``angle`` (mutates and returns ``textract_df``).

        ``width``/``height`` are the size of the unrotated image. The mapping
        matches the ``cv2.rotate`` codes used in ``get_data``.
        """
        if angle == 90:
            textract_df["xtl"] = height - textract_df["y0"]
            textract_df["xbr"] = height - textract_df["y2"]
            textract_df["ytl"] = textract_df["x0"]
            textract_df["ybr"] = textract_df["x2"]
        elif angle == 270:
            textract_df["xtl"] = textract_df["y0"]
            textract_df["xbr"] = textract_df["y2"]
            textract_df["ytl"] = width - textract_df["x0"]
            textract_df["ybr"] = width - textract_df["x2"]
        elif angle == 180:
            textract_df["xtl"] = width - textract_df["x0"]
            textract_df["xbr"] = width - textract_df["x2"]
            textract_df["ytl"] = height - textract_df["y0"]
            textract_df["ybr"] = height - textract_df["y2"]
        return textract_df

# textract_reader = TextractReader("/media/minh/Storage/projects/EGS/InvoiceDataExtraction/test/commercial-invoice.json", "/media/minh/Storage/projects/EGS/InvoiceDataExtraction/test")
textract_json_path = "/media/minh/Storage/projects/EGS/InvoiceDataExtraction/sdmgr_inference/textract_extraction_filtered.json"
textract_doc_dir = "/media/minh/Storage/projects/EGS/InvoiceDataExtraction/sdmgr_inference/20220425_FilteredData"
textract_reader = TextractReader(textract_json_path, textract_doc_dir)
# Number of documents whose Textract output was kept after parsing.
print(len(textract_reader.textract_data))

In [None]:
# Output folder for optional debug images drawn by `debug` (its call in `infer_one_page` is commented out).
debug_dir = "/media/minh/Storage/projects/EGS/InvoiceDataExtraction/debug"

def debug(img, df, output_path):
    """Draw predicted entity boxes on a copy of ``img`` and write it to ``output_path``.

    ``df`` rows are unpacked positionally as (word_id, tag, prefix, class,
    word, bbox), i.e. the frame returned by ``refine``. Boxes are colour-coded
    by class. Classes without a dedicated colour (e.g. OTHER, which ``refine``
    keeps) are drawn in grey instead of failing with an unbound ``color``.
    """
    # cv2 colours are (B, G, R).
    class_colors = {
        "TOTALAMOUNT": (0, 0, 255),
        "SUPPLIER_NAME": (255, 0, 0),
        "SUPPLIER_ADDR": (0, 255, 0),
    }
    default_color = (128, 128, 128)
    canvas = img.copy()
    for _, _tag, _, class_name, _word, box in df.itertuples(index=False):
        x1, y1, x2, y2 = map(round, box)
        color = class_colors.get(class_name, default_color)
        canvas = cv2.rectangle(canvas, (x1, y1), (x2, y2), color, 2)
    cv2.imwrite(output_path, canvas)

def tokenize_and_align_labels(examples, tokenizer_fn=None):
    """Tokenise a batch of word-level examples for LayoutLMv2 inference.

    Despite the name (inherited from the training script), no labels are
    produced. Each sub-word token inherits the normalised bbox of the word it
    came from, and special tokens (word id None) get ``[0, 0, 0, 0]``. Long
    inputs are split into several overflow windows instead of being dropped.

    Args:
        examples: dict with "tokens" (list of word lists), "norm_bboxes" (one
            box per word, parallel to "tokens") and "image" (one per example).
        tokenizer_fn: tokenizer to call. Defaults to the notebook-level ``tokenizer``.

    Returns:
        (tokenized_inputs, overflow_mapping, word_ids):
          * tokenized_inputs gains "bbox" and "image" (one entry per window)
            and loses "overflow_to_sample_mapping";
          * overflow_mapping is that removed window -> example mapping;
          * word_ids is the flat concatenation of every window's word ids,
            aligned with the flattened token predictions.
    """
    tok = tokenizer if tokenizer_fn is None else tokenizer_fn
    tokenized_inputs = tok(
        examples["tokens"],
        padding="max_length",
        truncation=True,
        return_overflowing_tokens=True,
        is_split_into_words=True,
    )
    overflow_mapping = tokenized_inputs["overflow_to_sample_mapping"]

    bboxes = []
    images = []
    flat_word_ids = []
    for window_index in range(len(tokenized_inputs["input_ids"])):
        word_ids = tokenized_inputs.word_ids(batch_index=window_index)
        flat_word_ids += word_ids
        sample_index = overflow_mapping[window_index]
        word_boxes = examples["norm_bboxes"][sample_index]
        # Every sub-token of a word repeats that word's box; special tokens get a null box.
        bboxes.append([[0, 0, 0, 0] if w is None else word_boxes[w] for w in word_ids])
        images.append(examples["image"][sample_index])

    tokenized_inputs["bbox"] = bboxes
    tokenized_inputs["image"] = images
    tokenized_inputs.pop("overflow_to_sample_mapping", None)
    return tokenized_inputs, overflow_mapping, flat_word_ids

def transform_image(image):
    """Resize an (H, W, C) image array to 224x224 and return it as a (C, H, W) tensor.

    Returns:
        (tensor, (width, height)) where the size is that of the ORIGINAL image,
        which callers use to normalise word boxes.
    """
    height, width = image.shape[:2]
    resize = TransformList([ResizeTransform(h=height, w=width, new_h=224, new_w=224)])
    resized = resize.apply_image(image).copy()  # copy so torch gets a writeable array
    return torch.tensor(resized).permute(2, 0, 1), (width, height)

def generate_example(img, page_df):
    """Build a single-example batch (every list has length 1) for ``tokenize_and_align_labels``.

    ``page_df`` rows are unpacked as (xtl, ytl, xbr, ybr, text, page), with
    pixel coordinates in ``img``'s frame.

    Returns:
        dict with "tokens" (word texts), "bboxes" (pixel boxes, kept for
        reporting), "norm_bboxes" (``normalize_bbox`` output clamped to
        [0, 1000]) and "image" (the 224x224 tensor from ``transform_image``).
    """
    image_tensor, size = transform_image(img)
    tokens, bboxes, norm_bboxes = [], [], []
    for xtl, ytl, xbr, ybr, text, _page in page_df.itertuples(index=False):
        box = [xtl, ytl, xbr, ybr]
        tokens.append(text)
        bboxes.append(box)
        # NOTE(review): normalize_bbox presumably maps onto LayoutLM's 0-1000 grid.
        # Clamp so OCR boxes that spill past the page edge cannot yield negative or
        # oversized position ids.
        norm_bboxes.append([min(max(v, 0), 1000) for v in normalize_bbox(box, size)])
    return {"tokens": [tokens], "bboxes": [bboxes], "norm_bboxes": [norm_bboxes], "image": [image_tensor]}

def convert_to_tensor(inputs):
    """Turn a dict of per-window value lists into batched tensors for the model.

    A list whose first element is a list becomes ``torch.tensor(values)``. A
    list whose first element is a tensor is stacked along a new first
    dimension. Any other element type raises an Exception naming the key.
    """
    def to_tensor(key, values):
        first = values[0]
        if isinstance(first, list):
            return torch.tensor(values)
        if isinstance(first, torch.Tensor):
            return torch.stack(values)
        raise Exception(f"{key} is a list of type {type(first)}")

    return {key: to_tensor(key, values) for key, values in inputs.items()}

def get_class(text):
    """Return the entity class of a BIO tag, e.g. "B-TOTALAMOUNT" -> "TOTALAMOUNT".

    Tags that do not split into exactly two parts on "-" (such as "O") map to "OTHER".
    """
    parts = text.split("-")
    if len(parts) != 2:
        return "OTHER"
    return parts[1]

def refine(words, boxes, tags, word_ids):
    """Collapse token-level predictions to one row per tagged word.

    Args:
        words, boxes: per-word text and box, indexed by word id.
        tags: predicted BIO tag of every token.
        word_ids: word id of every token (None for special tokens), parallel to ``tags``.

    Returns:
        DataFrame with columns word_id, tag, prefix, class, word, bbox (in that
        order), with one row per word whose FIRST sub-token is not tagged "O".
    """
    tokens = pd.DataFrame({"word_id": word_ids, "tag": tags}).dropna()
    first_tokens = tokens.astype({"word_id": "int32"}).drop_duplicates(["word_id"])
    tagged = first_tokens[first_tokens["tag"] != "O"].copy()
    tagged["prefix"] = tagged["tag"].map(lambda tag: tag.split("-")[0])
    tagged["class"] = tagged["tag"].map(get_class)
    tagged["word"] = tagged["word_id"].map(lambda wid: words[wid])
    tagged["bbox"] = tagged["word_id"].map(lambda wid: boxes[wid])
    return tagged

def infer_one_page(page_data):
    """Run the token-classification model on one page and return its entity words.

    Args:
        page_data: dict with "img" (page image array) and "anno" (word frame
            with columns xtl, ytl, xbr, ybr, text, page, all on one page), as
            produced by ``TextractReader.get_data``.

    Returns:
        DataFrame with columns xtl, ytl, xbr, ybr, text, label, page. There is
        one row per word predicted as supplier_name / supplier_addr /
        totalAmount, with rounded pixel coordinates. Words of any other class
        (e.g. OTHER) are dropped instead of raising KeyError.
    """
    columns = ["xtl", "ytl", "xbr", "ybr", "text", "label", "page"]
    class_map = {"SUPPLIER_NAME": "supplier_name", "SUPPLIER_ADDR": "supplier_addr", "TOTALAMOUNT": "totalAmount"}

    img = page_data["img"]
    page_df = page_data["anno"]
    page_id = np.unique(page_df["page"]).item()

    example = generate_example(img, page_df)
    inputs, _overflow_mapping, word_ids = tokenize_and_align_labels(example)
    inputs_t = convert_to_tensor(inputs)

    with torch.no_grad():
        outputs = model(**inputs_t)
    # Flatten the per-window predictions so they line up with the flat word_ids list.
    preds = torch.argmax(outputs.logits, -1).reshape([-1]).tolist()

    pred_tags = [class_labels[i] for i in preds]
    df = refine(example["tokens"][0], example["bboxes"][0], pred_tags, word_ids)
    # debug(img, df, os.path.join(debug_dir, "commercial-invoice.jpg"))
    records = []
    for _, _tag, _, class_name, word, box in df.itertuples(index=False):
        if class_name not in class_map:
            continue
        x1, y1, x2, y2 = map(round, box)
        records.append({"xtl": x1, "ytl": y1, "xbr": x2, "ybr": y2, "text": word,
                        "label": class_map[class_name], "page": page_id})
    # Build once instead of DataFrame.append per row (quadratic, removed in pandas 2.0).
    return pd.DataFrame(records, columns=columns)

In [None]:
# if os.path.isdir(debug_dir):
#     shutil.rmtree(debug_dir)
# os.makedirs(debug_dir)

def infer_one_doc(doc_data):
    """Run ``infer_one_page`` on every page of a document and stack the results.

    Args:
        doc_data: list of page dicts as returned by ``TextractReader.get_data``.

    Returns:
        DataFrame with columns xtl, ytl, xbr, ybr, text, label, page. It is
        empty (with those columns) when there are no pages.
    """
    columns = ["xtl", "ytl", "xbr", "ybr", "text", "label", "page"]
    page_frames = [infer_one_page(page_data) for page_data in doc_data]
    if not page_frames:
        return pd.DataFrame(columns=columns)
    # A single concat replaces per-page DataFrame.append (quadratic, removed in pandas 2.0).
    return pd.concat(page_frames, ignore_index=True).reindex(columns=columns)

# input_dir = "/media/minh/Storage/projects/EGS/InvoiceDataExtraction/test"
input_dir = "/media/minh/Storage/projects/EGS/InvoiceDataExtraction/sdmgr_inference/20220425_FilteredData"

# Start from an empty output folder so stale per-document CSVs are never picked up later.
output_dir = "20220525_layoutlmv2_outputs"
if os.path.isdir(output_dir):
    shutil.rmtree(output_dir)
os.makedirs(output_dir)

fnames = ["9498101535.jpg"]  # debugging subset, used by the commented-out filter below
supported_exts = [".jpg", ".png", ".pdf"]
filenames = [name for name in os.listdir(input_dir) if os.path.splitext(name)[1] in supported_exts]

for filename in filenames:
    # if filename not in fnames: continue
    doc_data = textract_reader.get_data(filename)
    if doc_data is None:
        print(f"Skip {filename} because Textract data is missing")
        continue
    prefix = os.path.splitext(filename)[0]
    doc_df = infer_one_doc(doc_data)
    # One CSV of predicted entity words per document, named after the document's stem.
    doc_df.to_csv(os.path.join(output_dir, f"{prefix}.csv"), sep=",", index=False)

In [None]:
def is_adjacent(bi, bj):
    """Return True when two word boxes should be merged into one phrase.

    Boxes are read via the keys label, xtl, ytl, xbr, ybr, and must share a label.
      * supplier_name / supplier_addr may also merge vertically stacked
        lines: horizontal overlap above 10% of the narrower width and a
        vertical gap smaller than half the smaller height.
      * Any label merges words on the same line: vertical overlap above 20%
        of the smaller height and a horizontal gap smaller than the smaller height.

    A box with zero (or negative) width/height no longer raises
    ZeroDivisionError; its overlap ratio is treated as -inf ("no overlap").
    """
    if bi["label"] != bj["label"]:
        return False
    min_h = min(bi["ybr"] - bi["ytl"], bj["ybr"] - bj["ytl"])
    min_w = min(bi["xbr"] - bi["xtl"], bj["xbr"] - bj["xtl"])
    # Overlap lengths along each axis; negative values are gaps.
    dx = min(bi["xbr"], bj["xbr"]) - max(bi["xtl"], bj["xtl"])
    dy = min(bi["ybr"], bj["ybr"]) - max(bi["ytl"], bj["ytl"])

    dxr = dx / min_w if min_w > 0 else float("-inf")
    dyr = dy / min_h if min_h > 0 else float("-inf")

    if bi["label"] in ["supplier_name", "supplier_addr"] and dxr > 0.1 and dy > -min_h / 2:
        return True
    return bool(dyr > 0.2 and dx > -min_h)

def dfs(adj_mat, i, groups, gid):
    """Assign group id ``gid`` to every still-unassigned node reachable from node ``i``.

    Args:
        adj_mat: square 2-D numpy array; ``adj_mat[a, b] == 1`` is an edge a -> b.
        i: start node index.
        groups: list of group ids per node (-1 = unassigned), updated in place.
        gid: group id to assign.

    Returns:
        ``groups`` (the same list object).

    An explicit stack replaces recursion so large components cannot exceed
    Python's recursion limit.
    """
    stack = [i]
    while stack:
        node = stack.pop()
        if groups[node] > -1:
            continue
        groups[node] = gid
        neighbours = np.flatnonzero(np.asarray(adj_mat[node]) == 1).tolist()
        stack.extend(j for j in neighbours if groups[j] == -1)
    return groups

def merge_words_one_page(page_df):
    """Merge adjacent same-label words on one page into entity phrases.

    Words are linked when ``is_adjacent`` says so, and each connected
    component becomes one row. Its text is the component's words joined by
    spaces (in page_df row order) and its box is the union of the word boxes.

    Args:
        page_df: word predictions for one page with at least the columns
            text, label, xtl, ytl, xbr, ybr, page.

    Returns:
        DataFrame with columns text, label, xtl, ytl, xbr, ybr, page (empty
        when ``page_df`` is empty).
    """
    columns = ["text", "label", "xtl", "ytl", "xbr", "ybr", "page"]
    page_df = page_df.reset_index(drop=True)
    if page_df.empty:
        return pd.DataFrame(columns=columns)

    # Plain dicts are much cheaper to compare pairwise than iterrows() Series.
    rows = page_df.to_dict("records")
    n = len(rows)
    adj_mat = np.zeros((n, n), dtype=np.int32)
    for i, bi in enumerate(rows):
        for j, bj in enumerate(rows):
            if is_adjacent(bi, bj):
                adj_mat[i][j] = 1

    groups = [-1] * n
    for start in range(n):
        # The start index doubles as the group id; ids need not be contiguous.
        groups = dfs(adj_mat, start, groups, start)
    page_df["group"] = groups
    page_id = np.unique(page_df["page"]).item()

    records = []
    for _, group_df in page_df.groupby("group"):
        records.append({
            "text": " ".join(group_df["text"].tolist()),
            # is_adjacent only links equal labels, so each group has exactly one.
            "label": np.unique(group_df["label"]).item(),
            "xtl": group_df["xtl"].min(),
            "ytl": group_df["ytl"].min(),
            "xbr": group_df["xbr"].max(),
            "ybr": group_df["ybr"].max(),
            "page": page_id,
        })
    # Build once instead of DataFrame.append per group (quadratic, removed in pandas 2.0).
    return pd.DataFrame(records, columns=columns)

def is_aligned(b, x):
    """Return True if box ``x`` overlaps (with positive area) the search region around box ``b``.

    Boxes are [x1, y1, x2, y2]. The region spans horizontally from 0 to 50
    box-widths right of b's right edge. Vertically it spans from 50
    box-heights above b's top down to b's bottom.
    """
    width = b[2] - b[0]
    height = b[3] - b[1]
    left, top, right, bottom = 0, b[1] - 50 * height, b[2] + 50 * width, b[3]
    overlap_w = max(min(right, x[2]) - max(left, x[0]), 0)
    overlap_h = max(min(bottom, x[3]) - max(top, x[1]), 0)
    return overlap_w * overlap_h > 0

class DataExtractor:
    """Turn merged entity predictions into the final invoice fields of one document.

    Inputs to ``__call__``:
      * ``doc_pred_df``: merged predictions with columns
        text, label, xtl, ytl, xbr, ybr, page (read positionally in that order);
      * ``doc_ocr_data``: page list from ``TextractReader.get_data``, whose
        "anno" frames have columns xtl, ytl, xbr, ybr, text, page.

    Lookup tables are read positionally:
      * country CSV rows: (country name, CountryCode, <unused>, currency code);
      * symbol CSV rows: (currency symbol, currency code).
    """

    # NOTE(review): the unescaped "." matches ANY character, not only a decimal
    # point. That is what lets digit groups such as "234" in "1,234.56" match
    # in full, so it is kept as-is. Texts like "10 20" also match and are then
    # rejected by the float() check.
    AMOUNT_RE = re.compile(r"\d+(?:.\d|,\d)*\d*")
    # A "." followed later by another ".", or directly by 3+ digits, is a thousands separator.
    THOUSANDS_DOT_RE = re.compile(r"[.](?=.*[.]|\d{3})")
    # Currencies accepted when the only symbol found is a bare "$".
    DOLLAR_CURRENCIES = ["USD", "SGD", "HKD", "NZD", "AUD", "CAD", "TWD"]

    def __init__(self, country_currency_map_file, symbol_currency_map_file):
        self.country_df = pd.read_csv(country_currency_map_file)
        self.country_df["CountryCode"] = self.country_df["CountryCode"].astype("str")
        self.symbol_df = pd.read_csv(symbol_currency_map_file)

    def __call__(self, doc_pred_df, doc_ocr_data):
        """Return {"supplier_name", "totalAmount", "currencyCode"} for one document ("" when not found)."""
        res = {"supplier_name": "", "totalAmount": "", "currencyCode": ""}
        total_amount, total_amount_full_text, total_amount_box, total_amount_page = self.extract_total_amount(doc_pred_df)
        res["totalAmount"] = total_amount
        res["currencyCode"] = self.extract_currency(
            total_amount_full_text, total_amount_box, total_amount_page, doc_ocr_data, doc_pred_df)
        res["supplier_name"] = self._extract_supplier_name(doc_pred_df, total_amount_page)
        return res

    @staticmethod
    def _longest_name(names):
        # Most words wins; max() keeps the earliest candidate on ties.
        return max(names, key=lambda name: len(name.split()))

    def _extract_supplier_name(self, doc_pred_df, total_amount_page):
        """Longest supplier_name on the total amount's page, else on the first page that has one."""
        is_name = doc_pred_df["label"] == "supplier_name"
        same_page = doc_pred_df[(doc_pred_df["page"] == total_amount_page) & is_name]["text"].tolist()
        if same_page:
            name = self._longest_name(same_page)
            if name:
                return name
        for _, page_df in doc_pred_df.groupby("page"):
            pool = page_df[page_df["label"] == "supplier_name"]["text"].tolist()
            if pool:
                return self._longest_name(pool)
        return ""

    def extract_total_amount(self, doc_pred_df):
        """Pick the first totalAmount prediction that parses as a number.

        Returns:
            (amount, full_text, box, page). ``amount`` is the normalised
            numeric string (e.g. "1,234.56" -> "1234.56"), ``full_text`` the
            whole predicted text, ``box`` its [xtl, ytl, xbr, ybr] and ``page``
            its page id. Returns ("", "", None, None) when nothing parses.
        """
        candidates = doc_pred_df[doc_pred_df["label"] == "totalAmount"]
        for text, _, xtl, ytl, xbr, ybr, page_id in candidates.itertuples(index=False):
            match = self.AMOUNT_RE.search(text)
            if match is None:
                continue
            # Read "," as a decimal point, then strip dots that look like thousands separators.
            amount = self.THOUSANDS_DOT_RE.sub("", match.group().replace(",", "."))
            try:
                float(amount)
            except ValueError:
                continue
            return amount, text, [xtl, ytl, xbr, ybr], page_id
        return "", "", None, None

    def _vote_symbols(self, page_dfs, total_amount_box):
        """Collect (symbol, code) hits in OCR words, split by alignment with the total amount box."""
        aligned, non_aligned = [], []
        for page_df in page_dfs:
            for xtl, ytl, xbr, ybr, text, _page in page_df.itertuples(index=False):
                box = [xtl, ytl, xbr, ybr]
                for sym, code in self.symbol_df.itertuples(index=False):
                    if sym in text:
                        if total_amount_box is not None and is_aligned(total_amount_box, box):
                            aligned.append((sym, code))
                        else:
                            non_aligned.append((sym, code))
        return aligned, non_aligned

    @staticmethod
    def _pick_symbol(votes):
        """Most frequent (symbol, code). A "$" winner takes the runner-up's symbol when both share a code."""
        ranked = votes.most_common()
        symbol, code = ranked[0][0]
        if symbol == "$" and len(ranked) > 1:
            runner_symbol, runner_code = ranked[1][0]
            if runner_code == code:
                symbol = runner_symbol
        return symbol, code

    def extract_currency(self, total_amount_full_text, total_amount_box, total_amount_page, doc_ocr_data, doc_pred_df):
        """Resolve the currency code, trying progressively weaker evidence.

        1. A symbol inside the total amount text (scanning continues past "$"
           in case a more specific symbol also matches).
        2. Symbol votes in the OCR words: first the total amount's page, then
           the whole document. Words aligned with the total box are preferred.
        3. If the symbol is missing or a bare "$": the country named in a
           supplier address on the total amount's page.
        """
        symbol, currency_code = "", ""
        if total_amount_full_text:
            for sym, code in self.symbol_df.itertuples(index=False):
                if sym in total_amount_full_text:
                    symbol, currency_code = sym, code
                    if symbol != "$":
                        break

        if not currency_code:
            total_page_dfs = [page_data["anno"] for page_data in doc_ocr_data
                              if np.unique(page_data["anno"]["page"]).item() == total_amount_page]
            aligned, non_aligned = self._vote_symbols(total_page_dfs, total_amount_box)
            if not aligned and not non_aligned:
                all_page_dfs = [page_data["anno"] for page_data in doc_ocr_data]
                aligned, non_aligned = self._vote_symbols(all_page_dfs, total_amount_box)
            aligned, non_aligned = Counter(aligned), Counter(non_aligned)
            if aligned:
                symbol, currency_code = self._pick_symbol(aligned)
            elif non_aligned:
                symbol, currency_code = self._pick_symbol(non_aligned)

        if symbol in ["", "$"]:
            is_addr = (doc_pred_df["label"] == "supplier_addr") & (doc_pred_df["page"] == total_amount_page)
            codes = Counter(
                code
                for address in doc_pred_df[is_addr]["text"]
                for country, _country_code, _, code in self.country_df.itertuples(index=False)
                if country.lower() in address.lower()
            )
            if codes:
                code = codes.most_common()[0][0]
                if symbol == "" or code in self.DOLLAR_CURRENCIES:
                    currency_code = code
        return currency_code
        
# Currency lookup tables, resolved relative to the notebook's working directory.
country_currency_csv = "../../../sdmgr_inference/country_currency_mapping.csv"
currency_symbol_csv = "../../../sdmgr_inference/currency_symbol_mapping.csv"
extractor = DataExtractor(country_currency_csv, currency_symbol_csv)


In [None]:
result_file = "20220525_layoutlmv2_results.csv"
kie_output_dir = "20220525_layoutlmv2_outputs"
doc_dir = "/media/minh/Storage/projects/EGS/InvoiceDataExtraction/sdmgr_inference/20220425_FilteredData"
pred_columns = ["text", "label", "xtl", "ytl", "xbr", "ybr", "page"]
# fnames = ["9506922289"]
from glob import glob
import requests

# Map each document stem back to its file name. The per-document CSVs are
# named f"{stem}.csv", so an exact stem match avoids the substring
# false-positives of `prefix in name` and never reuses a stale filename.
doc_by_stem = {
    os.path.splitext(name)[0]: name
    for name in os.listdir(doc_dir)
    if os.path.splitext(name)[1] in [".jpg", ".png", ".pdf"]
}

out_rows = []
for csv_path in glob(os.path.join(kie_output_dir, "*csv")):
    prefix = os.path.splitext(os.path.basename(csv_path))[0]
    # if prefix not in fnames: continue
    print(csv_path)
    filename = doc_by_stem.get(prefix)
    doc_ocr_data = textract_reader.get_data(filename, return_img=False) if filename else None
    if doc_ocr_data is None:
        print(f"Skip {prefix} because its document or Textract data is missing")
        continue
    doc_kie_df = pd.read_csv(csv_path).astype({"text": "str"})
    # Merge words into phrases page by page, then stack once (no DataFrame.append in a loop).
    page_frames = [merge_words_one_page(page_kie_df) for _, page_kie_df in doc_kie_df.groupby("page")]
    doc_pred_df = pd.concat(page_frames, ignore_index=True) if page_frames else pd.DataFrame()
    doc_pred_df = doc_pred_df.reindex(columns=pred_columns)
    doc_out = {"code": prefix}
    doc_out.update(extractor(doc_pred_df, doc_ocr_data))
    out_rows.append(doc_out)

out_df = pd.DataFrame(out_rows, columns=["code", "supplier_name", "totalAmount", "currencyCode"])
out_df.to_csv(result_file, index=False)

In [None]:
# Side-by-side prediction vs. ground-truth sheet. The empty correct_* columns
# are presumably filled in by hand during review.
prd_df = pd.read_csv("20220525_layoutlmv2_results.csv")
prd_df.columns = ["code", "pred_supplier_name", "pred_totalAmount", "pred_currencyCode"]
gt_df = (
    pd.read_csv("/media/minh/Storage/projects/EGS/InvoiceDataExtraction/metadata.csv")
    .rename(columns={"supplier_name": "true_supplier_name",
                     "totalAmount": "true_totalAmount",
                     "currencyCode": "true_currencyCode"})
    [["code", "true_supplier_name", "true_totalAmount", "true_currencyCode"]]
)
df = prd_df.merge(gt_df, on="code", how="inner")
df["correct_supplier_name"] = ""
df["correct_totalAmount"] = ""
df["correct_currencyCode"] = ""
# Every field gets its pred / true / correct triple (correct_totalAmount was previously dropped).
df = df[["code",
         "pred_supplier_name", "true_supplier_name", "correct_supplier_name",
         "pred_totalAmount", "true_totalAmount", "correct_totalAmount",
         "pred_currencyCode", "true_currencyCode", "correct_currencyCode"]]

# df.to_excel("20220525_layoutlmv2_results_eval.xlsx", index=False)
df.to_csv("20220525_layoutlmv2_results_eval.csv", index=False)

In [None]:
# Count documents whose predicted currency / total amount differs from the ground truth.
# NaN on either side (e.g. an empty prediction) counts as a mismatch.
d = pd.read_csv("20220525_layoutlmv2_results_eval.csv")
currency_mismatches = (d["pred_currencyCode"] != d["true_currencyCode"]).sum()
amount_mismatches = (d["pred_totalAmount"] != d["true_totalAmount"]).sum()
print(currency_mismatches, amount_mismatches)