In [1]:
from transformers import (
    AutoConfig,
    AutoModelForTokenClassification,
    AutoTokenizer,
    HfArgumentParser,
    PreTrainedTokenizerFast,
    TrainingArguments,
    set_seed,
)
from layoutlmft.data.utils import load_image, normalize_bbox
import os
import json
import torch
import shutil
from glob import glob
import cv2
import pandas as pd
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Checkpoint directory that the config, tokenizer and model are all loaded from.
model_path = "/media/minh/Storage/projects/EGS/InvoiceDataExtraction/unilm/20220516_outputs/checkpoint-4000"

# BIO tag vocabulary. A tag's position in this list is used as its class id
# (see label_to_id below and the class_labels[...] lookups on predictions).
class_labels = [
    'O',
    'B-OTHER', 'I-OTHER',
    'B-SUPPLIER_NAME', 'I-SUPPLIER_NAME',
    'B-SUPPLIER_ADDR', 'I-SUPPLIER_ADDR',
    'B-TOTALAMOUNT', 'I-TOTALAMOUNT',
]
label_to_id = {tag: index for index, tag in enumerate(class_labels)}

config = AutoConfig.from_pretrained(model_path, num_labels=len(class_labels))
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
model = AutoModelForTokenClassification.from_pretrained(
    model_path,
    # from_tf is switched on only when the path itself contains ".ckpt".
    from_tf=".ckpt" in model_path,
    config=config,
)

In [3]:
def tokenize_and_align_labels(examples):
    """Tokenize pre-split words and expand word-level boxes to token level.

    Uses the notebook-level ``tokenizer``. Inputs are padded to the maximum
    length and truncated; because ``return_overflowing_tokens=True`` is set,
    one input sample may produce several encoded rows.

    Args:
        examples: Mapping with parallel per-sample lists under the keys
            "tokens", "norm_bboxes" and "image" (the structure returned by
            ``generate_example``).

    Returns:
        A tuple ``(tokenized_inputs, overflow_mapping, flat_word_ids)``:
        the tokenizer output with "bbox" and "image" entries added and
        "overflow_to_sample_mapping" removed; the removed mapping (encoded
        row -> index of the sample it came from); and the word ids of all
        rows concatenated into one flat list (``None`` at positions that do
        not belong to a word).
    """
    tokenized_inputs = tokenizer(
        examples["tokens"],
        padding="max_length",
        truncation=True,
        return_overflowing_tokens=True,
        # The inputs are already lists of words rather than raw strings.
        is_split_into_words=True,
    )

    row_boxes = []
    row_images = []
    flat_word_ids = []

    for row in range(len(tokenized_inputs["input_ids"])):
        row_word_ids = tokenized_inputs.word_ids(batch_index=row)
        flat_word_ids.extend(row_word_ids)

        # Map the (possibly overflowed) row back to the sample it was cut from.
        sample_index = tokenized_inputs["overflow_to_sample_mapping"][row]
        sample_boxes = examples["norm_bboxes"][sample_index]
        sample_image = examples["image"][sample_index]

        # Every sub-token reuses the box of its word; positions without a
        # word id get an all-zero box.
        row_boxes.append(
            [
                [0, 0, 0, 0] if word_id is None else sample_boxes[word_id]
                for word_id in row_word_ids
            ]
        )
        row_images.append(sample_image)

    tokenized_inputs["bbox"] = row_boxes
    tokenized_inputs["image"] = row_images

    # Hand the mapping back separately and strip it from the encoded inputs.
    overflow_mapping = tokenized_inputs["overflow_to_sample_mapping"]
    tokenized_inputs.pop("overflow_to_sample_mapping", None)

    return tokenized_inputs, overflow_mapping, flat_word_ids

def generate_example(image_path, json_path):
    """Build a one-sample batch from an image and its annotation JSON.

    The JSON is expected to hold a "form" list whose items carry a "label"
    and a "words" list; each word has "text" and "box". Words whose text is
    blank after stripping are ignored, and items left without words are
    skipped. Items labelled "other" are tagged "O"; for any other label the
    first word is tagged "B-<LABEL>" and the rest "I-<LABEL>" (label
    upper-cased).

    Args:
        image_path: Image file passed to ``load_image``.
        json_path: UTF-8 JSON annotation file.

    Returns:
        dict with keys "tokens", "bboxes", "norm_bboxes", "ner_tags" and
        "image"; every value is wrapped in a single-element list so the
        result looks like a batch of size one.
    """
    tokens, bboxes, norm_bboxes, ner_tags = [], [], [], []

    with open(json_path, "r", encoding="utf8") as handle:
        annotation = json.load(handle)
    image, size = load_image(image_path)

    for entry in annotation["form"]:
        raw_words, label = entry["words"], entry["label"]
        kept_words = [word for word in raw_words if word["text"].strip() != ""]
        if not kept_words:
            continue

        for position, word in enumerate(kept_words):
            if label == "other":
                tag = "O"
            elif position == 0:
                tag = "B-" + label.upper()
            else:
                tag = "I-" + label.upper()

            tokens.append(word["text"])
            ner_tags.append(tag)
            bboxes.append(word["box"])
            norm_bboxes.append(normalize_bbox(word["box"], size))

    return {
        "tokens": [tokens],
        "bboxes": [bboxes],
        "norm_bboxes": [norm_bboxes],
        "ner_tags": [ner_tags],
        "image": [image],
    }

def convert_to_tensor(inputs):
    """Convert a mapping of per-row Python values into batched tensors.

    The type of the first row decides how a feature is converted: lists are
    passed to ``torch.tensor`` and tensors are combined with ``torch.stack``.

    Args:
        inputs: Mapping (anything exposing ``.items()``) from feature name to
            a non-empty sequence of rows, where each row is either a list or
            a ``torch.Tensor``.

    Returns:
        dict mapping each feature name to a single ``torch.Tensor``.

    Raises:
        ValueError: If a feature has no rows (previously an opaque
            ``IndexError``).
        TypeError: If the rows are neither lists nor tensors. ``TypeError``
            is an ``Exception`` subclass, so existing ``except Exception``
            handlers keep working; the message text is unchanged.
    """
    inputs_t = {}
    for key, rows in inputs.items():
        if len(rows) == 0:
            raise ValueError(f"{key} is empty; nothing to convert to a tensor")

        first = rows[0]
        if isinstance(first, list):
            inputs_t[key] = torch.tensor(rows)
        elif isinstance(first, torch.Tensor):
            inputs_t[key] = torch.stack(rows)
        else:
            raise TypeError(f"{key} is a list of type {type(first)}")
    return inputs_t

In [71]:
# Single-sample run: one test image and its matching annotation file.
image_path = "/media/minh/Storage/projects/EGS/InvoiceDataExtraction/invoice_data/testing_data/images/9508519816_page_1.jpg"
json_path = "/media/minh/Storage/projects/EGS/InvoiceDataExtraction/invoice_data/testing_data/annotations/9508519816_page_1.json"
# Alternative sample from the training split (swap in to compare):
# image_path = "/media/minh/Storage/projects/EGS/InvoiceDataExtraction/invoice_data/training_data/images/9533207517.jpg"
# json_path = "/media/minh/Storage/projects/EGS/InvoiceDataExtraction/invoice_data/training_data/annotations/9533207517.json"
# Dataset feature schema kept for reference (presumably the training dataset's features -- TODO confirm):
# {'id': Value(dtype='string', id=None), 'tokens': Sequence(feature=Value(dtype='string', id=None), length=-1, id=None), 'bboxes': Sequence(feature=Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None), length=-1, id=None), 'ner_tags': Sequence(feature=ClassLabel(num_classes=9, names=['O', 'B-OTHER', 'I-OTHER', 'B-SUPPLIER_NAME', 'I-SUPPLIER_NAME', 'B-SUPPLIER_ADDR', 'I-SUPPLIER_ADDR', 'B-TOTALAMOUNT', 'I-TOTALAMOUNT'], id=None), length=-1, id=None), 'image': Array3D(shape=(3, 224, 224), dtype='uint8', id=None)}

# Step 1: read the annotation and image into a batch-of-one example dict.
example = generate_example(image_path, json_path)
# example
# Step 2: tokenize; also returns the row -> sample mapping and the flattened word ids.
inputs, overflow_mapping, word_ids = tokenize_and_align_labels(example)
# inputs.pop("labels", None)
# Step 3: turn every feature into a tensor so it can be passed to the model as keyword arguments.
inputs_t = convert_to_tensor(inputs)


In [72]:
# Inference only: run the forward pass without autograd tracking so no
# computation graph is built or kept alive for the logits.
with torch.no_grad():
    outputs = model(**inputs_t)

  // self.config.image_feature_pool_shape[1]
  // self.config.image_feature_pool_shape[0]


In [4]:
# Folder that debug() writes annotated images into. Note: the batch cell
# deletes and recreates this directory when it runs.
debug_dir = "/media/minh/Storage/projects/EGS/InvoiceDataExtraction/debug"

def debug(image_path, df, output_dir=None):
    """Draw the predicted entity boxes on an image and save the result.

    Args:
        image_path: Path of the image to annotate. The output file keeps the
            same basename.
        df: DataFrame whose rows unpack positionally as
            (word_id, tag, prefix, class, word, bbox), i.e. the column order
            produced by ``refine``. ``bbox`` is ``(x1, y1, x2, y2)``.
        output_dir: Directory to write into. Defaults to the notebook-level
            ``debug_dir`` when omitted.

    Raises:
        OSError: If OpenCV cannot load ``image_path``.
    """
    # OpenCV colours are BGR tuples.
    class_colors = {
        "TOTALAMOUNT": (0, 0, 255),
        "SUPPLIER_NAME": (255, 0, 0),
        "SUPPLIER_ADDR": (0, 255, 0),
    }
    # Used for any class without a dedicated colour (e.g. "OTHER"). Without
    # it such a row either raised UnboundLocalError or silently reused the
    # colour of the previous row.
    fallback_color = (128, 128, 128)

    img = cv2.imread(image_path)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise OSError(f"cv2.imread could not load image: {image_path}")

    for _, tag, _, class_name, _, box in df.itertuples(index=False):
        # Drawing functions need integer pixel coordinates.
        x1, y1, x2, y2 = (int(value) for value in box)
        color = class_colors.get(class_name, fallback_color)
        img = cv2.rectangle(img, (x1, y1), (x2, y2), color, 2)
        img = cv2.putText(img, tag, (x1, y1-10), cv2.FONT_HERSHEY_SIMPLEX,
                   0.2, color, 2, cv2.LINE_AA)

    target_dir = debug_dir if output_dir is None else output_dir
    output_path = os.path.join(target_dir, os.path.basename(image_path))
    cv2.imwrite(output_path, img)

def same_box(b1, b2):
    """Tell whether two boxes agree at every paired position.

    Values are paired with ``zip``, so when the inputs differ in length only
    the positions both of them have are compared.
    """
    return not any(left != right for left, right in zip(b1, b2))

def get_class(text):
    """Return the entity class encoded in a BIO-style tag.

    The tag is split on its first hyphen only: "B-SUPPLIER_NAME" gives
    "SUPPLIER_NAME", and a class name that itself contains a hyphen
    ("B-TAX-ID" -> "TAX-ID") is kept whole instead of collapsing to "OTHER".
    A tag with no hyphen at all (such as "O") gives "OTHER".

    Args:
        text: Tag string, e.g. "O", "B-TOTALAMOUNT", "I-SUPPLIER_ADDR".

    Returns:
        The class-name part of the tag, or "OTHER" when the tag has no hyphen.
    """
    _, separator, class_name = text.partition("-")
    return class_name if separator else "OTHER"

def refine(words, boxes, tags, word_ids):
    """Collapse token-level predictions into one row per tagged word.

    Tokens without a word id are dropped, only the first token of every word
    is kept, and words tagged "O" are removed.

    Args:
        words: Sequence of word strings, indexable by word id.
        boxes: Sequence of boxes, indexable by word id.
        tags: Predicted tag per token position; same length as ``word_ids``.
        word_ids: Word id per token position, ``None`` where the position
            belongs to no word.

    Returns:
        DataFrame with the columns, in this order: "word_id", "tag",
        "prefix" (text before the first "-"), "class" (from ``get_class``),
        "word" and "bbox". The index still refers to the original token
        positions.
    """
    first_tokens = (
        pd.DataFrame({"word_id": word_ids, "tag": tags})
        .dropna()
        .astype({"word_id": "int32"})
        .drop_duplicates(["word_id"])
    )

    # .copy() detaches the filtered rows from their parent frame so the
    # column assignments below operate on an independent DataFrame and do
    # not raise SettingWithCopyWarning.
    df = first_tokens[first_tokens["tag"] != "O"].copy()

    df["prefix"] = df["tag"].map(lambda tag: tag.split("-")[0])
    df["class"] = df["tag"].map(get_class)
    df["word"] = df["word_id"].map(lambda word_id: words[word_id])
    df["bbox"] = df["word_id"].map(lambda word_id: boxes[word_id])
    return df

In [74]:
# Token-level boxes flattened to rows of four values (not used further in this cell).
input_boxes = inputs_t["bbox"].reshape(-1, 4).tolist()

# Highest-scoring class id per token position, flattened across all encoded rows
# so it lines up with the flat word_ids list.
preds = outputs.logits.argmax(dim=-1).flatten().tolist()
pred_tags = [class_labels[class_id] for class_id in preds]

# Word-level text and original boxes of the single sample in the example.
words, boxes = example["tokens"][0], example["bboxes"][0]

df = refine(words, boxes, pred_tags, word_ids)
debug(image_path, df)




In [5]:
# Start from an empty output directory so images from earlier runs are not
# mixed with the new ones. WARNING: this deletes everything under debug_dir.
if os.path.isdir(debug_dir):
    shutil.rmtree(debug_dir)
os.makedirs(debug_dir)

image_dir = "/media/minh/Storage/projects/EGS/InvoiceDataExtraction/20220516_invoice_data/testing_data/images"
json_dir = "/media/minh/Storage/projects/EGS/InvoiceDataExtraction/20220516_invoice_data/testing_data/annotations"

# glob returns files in arbitrary order; sort so every run processes the
# images in the same sequence.
image_paths = sorted(glob(os.path.join(image_dir, "*.jpg")))
for image_path in image_paths:
    # The annotation file shares the image's basename.
    prefix = os.path.splitext(os.path.basename(image_path))[0]
    json_path = os.path.join(json_dir, f"{prefix}.json")

    example = generate_example(image_path, json_path)
    inputs, overflow_mapping, word_ids = tokenize_and_align_labels(example)
    inputs_t = convert_to_tensor(inputs)

    # Inference only: skip autograd tracking so no computation graph is
    # built for each image.
    with torch.no_grad():
        outputs = model(**inputs_t)

    input_boxes = inputs_t["bbox"].reshape([-1, 4]).tolist()
    # Highest-scoring class id per token position, flattened to match word_ids.
    preds = torch.argmax(outputs.logits, -1).reshape([-1]).tolist()
    pred_tags = [class_labels[i] for i in preds]
    words = example["tokens"][0]
    boxes = example["bboxes"][0]

    df = refine(words, boxes, pred_tags, word_ids)
    debug(image_path, df)

  // self.config.image_feature_pool_shape[1]
  // self.config.image_feature_pool_shape[0]
