In [92]:
from transformers import LayoutLMv3ImageProcessor, AutoModelForTokenClassification, LayoutLMv3Processor, LayoutLMv3Tokenizer
from datasets import DatasetDict
from difflib import SequenceMatcher
from PIL import Image, ImageDraw, ImageFont
import torch
import numpy as np

In [5]:
# processor = LayoutLMv3ImageProcessor.from_pretrained("microsoft/layoutlmv3-base", apply_ocr=True)

# Combined LayoutLMv3 processor; later cells reach its parts through
# `processor.image_processor` and `processor.tokenizer`.
# NOTE(review): created with apply_ocr=True, so it is presumably meant to be
# called with an image only (no caller-supplied words/boxes) -- confirm.
processor = LayoutLMv3Processor.from_pretrained("microsoft/layoutlmv3-base", apply_ocr=True)

# Stand-alone image processor loaded from the same checkpoint. It is not
# referenced by the other cells visible in this notebook.
image_processor = LayoutLMv3ImageProcessor.from_pretrained("microsoft/layoutlmv3-base", apply_ocr=True)

# Stand-alone tokenizer loaded from the same checkpoint. It is not referenced
# by the other cells visible in this notebook.
tokenizer = LayoutLMv3Tokenizer.from_pretrained("microsoft/layoutlmv3-base")

# Token-classification model; local_files_only=False allows fetching the
# "bmeisburger/datathon" checkpoint remotely when it is not cached.
model = AutoModelForTokenClassification.from_pretrained("bmeisburger/datathon", local_files_only=False)

# Dataset previously saved to the relative directory "dataset/"; a later cell
# indexes its "test" split.
dataset = DatasetDict.load_from_disk("dataset/")

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Last expression of the cell: its return value (the model) is what the
# notebook echoes as this cell's output.
model.to(device)

LayoutLMv3ForTokenClassification(
  (layoutlmv3): LayoutLMv3Model(
    (embeddings): LayoutLMv3TextEmbeddings(
      (word_embeddings): Embedding(50265, 768, padding_idx=1)
      (token_type_embeddings): Embedding(1, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
      (position_embeddings): Embedding(514, 768, padding_idx=1)
      (x_position_embeddings): Embedding(1024, 128)
      (y_position_embeddings): Embedding(1024, 128)
      (h_position_embeddings): Embedding(1024, 128)
      (w_position_embeddings): Embedding(1024, 128)
    )
    (patch_embed): LayoutLMv3PatchEmbeddings(
      (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    )
    (pos_drop): Dropout(p=0.0, inplace=False)
    (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
    (norm): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
    (encoder): LayoutLMv3Encoder

In [152]:
import string
# Characters that are kept when OCR words are cleaned up.
printable = set(string.printable)

# Outline/text colour used for each lower-cased entity name.
label2color = {
    'total': 'blue',
    'company': 'green',
    'date': 'orange',
    'address': 'violet',
    'other': 'grey',
}

# Class index produced by the model -> tag string.
id2label = {
    index: tag
    for index, tag in enumerate(('S-TOTAL', 'S-DATE', 'S-ADDRESS', 'S-COMPANY', 'O'))
}

# Pseudo-words that are not real OCR output.
bad_words = {"<s>"}

def unnormalize_box(bbox, width, height):
    """
    Rescale a box whose coordinates are expressed in thousandths of the
    image size.

    The first four entries of `bbox` are read in order; entries 0 and 2 are
    scaled by `width`, entries 1 and 3 by `height`.

    Returns a new list of four scaled coordinates.
    """
    extents = (width, height, width, height)
    return [extent * (bbox[axis] / 1000) for axis, extent in enumerate(extents)]


def iob_to_label(label):
    """
    Drop the two-character tag prefix (e.g. 'S-') from `label`.

    Returns the remainder, or 'other' when nothing is left after the prefix
    (as is the case for the bare 'O' tag).
    """
    entity = label[2:]
    return entity if entity else 'other'


def to_dict(text, predictions):
    """
    Group OCR words by the entity type predicted for them.

    Parameters
    ----------
    text : iterable of str
        Words produced by OCR.
    predictions : iterable of str
        Tag predicted for each word, pairwise aligned with `text`
        (e.g. 'S-TOTAL', 'O'). If the two iterables differ in length the
        surplus items of the longer one are ignored.

    Returns
    -------
    dict
        Keys 'total', 'date', 'company' and 'address', each mapping to the
        list of words (tokens) classified as that entity, in input order.
        Words tagged as 'other' are dropped. Duplicate 'total' words are
        removed, keeping the first occurrence of each.

    Raises
    ------
    KeyError
        If a prediction resolves to a label other than the four keys above
        or 'other'.
    """
    im_info = {"total": [],
               "date": [],
               "company": [],
               "address": []}

    # Use distinct loop names so the `text` parameter is not shadowed.
    for word, tag in zip(text, predictions):
        pred_label = iob_to_label(tag).lower()
        if pred_label != 'other':
            im_info[pred_label].append(word)

    # De-duplicate totals while keeping first-seen order. Going through a
    # plain set() would make the order (and therefore the final output
    # string) depend on hash randomisation between kernel restarts.
    seen = set()
    unique_totals = []
    for word in im_info['total']:
        if word not in seen:
            seen.add(word)
            unique_totals.append(word)
    im_info['total'] = unique_totals

    return im_info


def concat_dict(dict):
    """
    Flatten a mapping of word lists into a single string.

    Every non-empty value is joined with single spaces and appended to the
    result preceded by one space, in the mapping's iteration order, so a
    non-empty result always starts with a space. Returns '' when every
    value is empty.

    Note: the parameter name hides the builtin `dict` inside this function.
    """
    pieces = [' ' + ' '.join(group) for group in dict.values() if len(group) > 0]
    return ''.join(pieces)


# converts a list of subtokens and their offset mappings to
# a list of words that *should* be the same length as bboxes/preds
# def tokens_to_phrases(tokens, offset_mapping):
#     last_y = -1
#     word = ""
#     words = []
#     for token, offset_pair in zip(tokens, offset_mapping):
#         if offset_pair == [0, 0]:
#             words.append(token)
#             last_y = -1
#         elif offset_pair[0] == last_y:
#             word += token
#             last_y = offset_pair[1]
#         else:
#             words.append(word)
#             word = token
#             last_y = offset_pair[1]

#     words = [word.strip() for word in words]
#     words = [''.join(filter(lambda x: x in printable, w)) for w in words]
#     words = [word for word in words if word not in bad_words]

#     for i, word in enumerate(words[1:]):
#         prev_word = words[i]
#         # print("(", prev_word, ",", word, ")")
#         if len(word) > 3 and (SequenceMatcher(a=prev_word, b=word).ratio() > 0.7):
#             words[i] = '\n'

#     no_dupes = list(filter(lambda x: x != '\n', words))

#     print(no_dupes)
#     print(len(no_dupes))
#     print(words)
#     print(len(words))
    
#     return words

In [None]:
def process_example(example):
    """
    Encode one labelled dataset example for the token-classification model.

    Parameters
    ----------
    example : mapping
        Must provide the keys 'image', 'words', 'bboxes' and 'ner_tags'.
        NOTE(review): presumably 'bboxes' are already normalised to the
        0-1000 range the tokenizer expects -- confirm against the dataset.

    Returns
    -------
    BatchEncoding
        Tokenizer output (padded to max length, truncated, PyTorch tensors)
        with the image's 'pixel_values' added.
    """
    image = example['image']

    # The shared `processor` is created with apply_ocr=True, and in that mode
    # it rejects caller-supplied boxes / word labels. Run its two halves
    # separately instead, with OCR switched off for this call, so the
    # annotated words and boxes from the dataset are the ones encoded.
    pixel_values = processor.image_processor(
        image, apply_ocr=False, return_tensors="pt"
    )["pixel_values"]

    encoded_inputs = processor.tokenizer(
        text=example['words'],
        boxes=example['bboxes'],
        word_labels=example['ner_tags'],
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    encoded_inputs["pixel_values"] = pixel_values

    # Hand the encoding back to the caller instead of discarding it.
    return encoded_inputs
    

In [153]:
def process_image(image):
    """
    Run OCR and token classification on an image, show an annotated copy,
    and return the extracted entity text.

    Parameters
    ----------
    image : PIL.Image.Image
        Input image. It is copied before drawing, so the caller's image is
        left untouched.

    Returns
    -------
    str
        The words classified as total / date / company / address, flattened
        by `concat_dict(to_dict(...))`.

    Side effects
    ------------
    Opens the annotated copy with `Image.show()`.

    Notes
    -----
    Only real OCR words are reported: special tokens such as "<s>" are
    skipped, and each word takes the prediction and box of its first
    sub-token.
    """
    # don't draw on original image!
    image = image.copy()

    width, height = image.size

    # Run the image processor (and with it OCR) exactly once, then reuse its
    # words and boxes for the tokenizer and the overlay. This mirrors what
    # calling `processor(image, ...)` does, without a second OCR pass just to
    # recover the word list.
    features = processor.image_processor(image, return_tensors="pt")
    ocr_words = features["words"][0]

    if not ocr_words:
        # Nothing was recognised: show the untouched copy, return empty text.
        image.show()
        return concat_dict(to_dict([], []))

    encoding = processor.tokenizer(
        text=features["words"],
        boxes=features["boxes"],
        truncation=True,
        return_tensors="pt",
    )
    encoding["pixel_values"] = features["pixel_values"]

    # Token -> OCR-word index (None for special tokens). Read this and the
    # boxes before the tensors are moved to the model's device.
    word_ids = encoding.word_ids(0)
    token_boxes = encoding["bbox"][0].tolist()

    # forward pass: inputs must be on the same device as the model, and no
    # autograd graph is needed for inference
    with torch.no_grad():
        outputs = model(**encoding.to(model.device))

    # get predictions (index the single batch item explicitly)
    predictions = outputs.logits.argmax(-1)[0].tolist()

    # only keep the first sub-token of every word, so that words,
    # predictions and boxes stay aligned by construction
    words, true_predictions, true_boxes = [], [], []
    last_word_id = None
    for idx, word_id in enumerate(word_ids):
        if word_id is None or word_id == last_word_id:
            continue
        last_word_id = word_id
        words.append(''.join(ch for ch in ocr_words[word_id] if ch in printable))
        true_predictions.append(id2label[predictions[idx]])
        true_boxes.append(unnormalize_box(token_boxes[idx], width, height))

    # draw predictions over the image
    draw = ImageDraw.Draw(image)
    font = ImageFont.load_default()
    for word, prediction, box in zip(words, true_predictions, true_boxes):
        color = label2color[iob_to_label(prediction).lower()]
        draw.rectangle(box, outline=color)
        draw.text((box[0]+10, box[1]-10), text=word, fill=color, font=font)

    image.show()

    return concat_dict(to_dict(words, true_predictions))

In [None]:
# 13, 1, 15
example = dataset["test"][25]
image = example['image']

# Bind the result to a descriptive name. Calling it `dict` would rebind the
# builtin for the rest of the kernel session and break any later `dict(...)`
# call in other cells.
extracted_text = process_image(image)

print(extracted_text)

# for word, pred in zip(words, preds):
#     label = iob_to_label(pred).lower()
#     print("Word:", word)
#     print("Label:", label)
#     print()