# Import Lib

In [None]:
!nvidia-smi

In [1]:
import os
os.environ['TRANSFORMERS_CACHE'] = '/data/tungtx2/tmp/transformers_hub'
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

In [2]:
import os
import json
from pathlib import Path
import numpy as np
from PIL import Image
import torch

torch.__version__

  from .autonotebook import tqdm as notebook_tqdm


'1.13.1+cu117'

# Prepare Data

In [3]:
train_dir = 'real_data/train_labeled_ocred'
val_dir = 'real_data/val_labeled_ocred'

def find_all_labels(data_dir, disable_marker=False):
    labels = {}
    for jp in Path(data_dir).rglob('*.json'):
        data = json.load(open(jp))
        for shape in data['shapes']:
            if disable_marker and 'marker' in shape['label']:
                label = 'text'
            else:
                label = shape['label']
                
            if label in labels:
                labels[label] += 1
            else:
                labels[label] = 1
                
    return labels

train_labels = find_all_labels(train_dir, disable_marker=False)
val_labels = find_all_labels(val_dir, disable_marker=False)
assert set(train_labels.keys()) == set(val_labels.keys())
for k, v in train_labels.items():
    print(k, ':', v)

text : 64974
account_number : 419
marker_account_number : 585
swift_code : 235
marker_swift_code : 407
bank_name : 1493
marker_bank_name : 426
fax : 566
marker_fax : 314
phone : 963
marker_phone : 490
company_address : 7566
company_name : 3916
marker_company_name : 1352
bank_address : 1228
marker_bank_address : 208
marker_company_address : 505
represented_position : 464
marker_represented_position : 89
represented_name : 1076
marker_represented_name : 612
tax : 57
marker_tax : 99


In [4]:
label_list = list(set(train_labels.keys()))
label2id = {label: idx for idx, label in enumerate(label_list)}
id2label = {idx: label for idx, label in enumerate(label_list)}

print(label2id)
print(id2label)

{'marker_company_address': 0, 'text': 1, 'phone': 2, 'fax': 3, 'bank_address': 4, 'bank_name': 5, 'marker_bank_name': 6, 'marker_fax': 7, 'company_name': 8, 'marker_account_number': 9, 'swift_code': 10, 'marker_represented_name': 11, 'marker_represented_position': 12, 'marker_phone': 13, 'marker_bank_address': 14, 'marker_swift_code': 15, 'company_address': 16, 'represented_position': 17, 'marker_tax': 18, 'account_number': 19, 'tax': 20, 'represented_name': 21, 'marker_company_name': 22}
{0: 'marker_company_address', 1: 'text', 2: 'phone', 3: 'fax', 4: 'bank_address', 5: 'bank_name', 6: 'marker_bank_name', 7: 'marker_fax', 8: 'company_name', 9: 'marker_account_number', 10: 'swift_code', 11: 'marker_represented_name', 12: 'marker_represented_position', 13: 'marker_phone', 14: 'marker_bank_address', 15: 'marker_swift_code', 16: 'company_address', 17: 'represented_position', 18: 'marker_tax', 19: 'account_number', 20: 'tax', 21: 'represented_name', 22: 'marker_company_name'}


# Data Loader

In [5]:
from transformers import LayoutLMv3Processor, LayoutXLMTokenizerFast, LayoutLMv2FeatureExtractor

# feature_extractor = LayoutLMv2FeatureExtractor(apply_ocr=False)
# tokenizer = LayoutXLMTokenizerFast.from_pretrained('microsoft/layoutxlm-base')
# tokenizer.only_label_first_subword = False
# processor = LayoutXLMProcessor(feature_extractor=feature_extractor, tokenizer=tokenizer)
processor = LayoutLMv3Processor.from_pretrained('microsoft/layoutlmv3-base', apply_ocr=False)
processor.tokenizer.only_label_first_subword = False

print(processor.tokenizer)
print(processor.tokenizer.only_label_first_subword)
print(processor.image_processor)

LayoutLMv3TokenizerFast(name_or_path='microsoft/layoutlmv3-base', vocab_size=50265, model_max_length=512, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': AddedToken("<s>", rstrip=False, lstrip=False, single_word=False, normalized=True), 'eos_token': AddedToken("</s>", rstrip=False, lstrip=False, single_word=False, normalized=True), 'unk_token': AddedToken("<unk>", rstrip=False, lstrip=False, single_word=False, normalized=True), 'sep_token': AddedToken("</s>", rstrip=False, lstrip=False, single_word=False, normalized=True), 'pad_token': AddedToken("<pad>", rstrip=False, lstrip=False, single_word=False, normalized=True), 'cls_token': AddedToken("<s>", rstrip=False, lstrip=False, single_word=False, normalized=True), 'mask_token': AddedToken("<mask>", rstrip=False, lstrip=True, single_word=False, normalized=True)})
False
LayoutLMv3ImageProcessor {
  "apply_ocr": false,
  "do_normalize": true,
  "do_rescale": true,
  "do_resize": true,
  "feature_ex

In [6]:
from os import listdir
from torch.utils.data import Dataset
import torch
from PIL import Image
import unidecode
from PIL import Image, ImageDraw, ImageFont
import pdb
import xml.etree.ElementTree as ET
from shapely.geometry import Polygon
import cv2


def normalize_bbox(bbox, width, height):
     return [
         int(1000 * (bbox[0] / width)),
         int(1000 * (bbox[1] / height)),
         int(1000 * (bbox[2] / width)),
         int(1000 * (bbox[3] / height)),
     ]
    
    
def parse_xml(xml_path):
    root = ET.parse(xml_path).getroot()
    objs = root.findall('object')
    boxes, obj_names = [], []
    for obj in objs:
        obj_name = obj.find('name').text
        box = obj.find('bndbox')
        xmin = int(box.find('xmin').text)
        ymin = int(box.find('ymin').text)
        xmax = int(box.find('xmax').text)
        ymax = int(box.find('ymax').text)
        boxes.append([xmin, ymin, xmax, ymax])
        obj_names.append(obj_name)
    return boxes, obj_names


def widen_box(box, percent_x, percent_y):
        xmin, ymin, xmax, ymax = box
        w = xmax - xmin
        h = ymax - ymin
        xmin -= w * percent_x
        ymin -= h * percent_y
        xmax += w * percent_x
        ymax += h * percent_y
        return (int(xmin), int(ymin), int(xmax), int(ymax))

    
def draw_json_on_img(img, json_data):
    labels = list(set(shape['label'] for shape in json_data['shapes']))
    color = {}
    for i in range(len(labels)):
        color[labels[i]] = (np.random.randint(0, 255), np.random.randint(0, 255), np.random.randint(0, 255))
        
    img = img.copy()
    draw = ImageDraw.Draw(img)
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_size = 0.5# Draw the text on the image
    # font = ImageFont.truetype(font.font.family, font_size)
    for i, shape in enumerate(json_data['shapes']):
        polys = shape['points']
        polys = [(int(pt[0]), int(pt[1])) for pt in polys]
        label = shape['label']
        draw.polygon(polys, outline=color[label], width=2)
        # Draw the text on the image
        img = np.array(img)
        cv2.putText(img, shape['label'], (polys[0][0], polys[0][1]-5), font, font_size, color[label], thickness=1)
        img = Image.fromarray(img)
        draw = ImageDraw.Draw(img)
    return img
    
    
def mask_image(img, boxes, json_data, widen_range_x, widen_range_y):
    # widen block
    if isinstance(widen_range_x, list) and isinstance(widen_range_y, list):
        boxes = [widen_box(box, np.random.uniform(widen_range_x[0], widen_range_x[1]), np.random.uniform(widen_range_y[0], widen_range_y[1])) for box in boxes]
    else:
        boxes = [widen_box(box, widen_range_x, widen_range_y) for box in boxes]
        
    
    ls_polys2keep = []
    ls_area2keep = []
    iou_threshold = 0.
    for box_idx, box in enumerate(boxes):
        xmin, ymin, xmax, ymax = box
        box_pts = [(xmin, ymin), (xmax, ymin), (xmax, ymax), (xmin, ymax)]
        p_box = Polygon(box_pts)
        for shape_idx, shape in enumerate(json_data['shapes']):
            if shape_idx in ls_polys2keep:
                continue
            pts = shape['points']
            p_shape = Polygon(pts)
            intersect_area = p_box.intersection(p_shape).area
            if intersect_area / p_shape.area > iou_threshold:
                ls_polys2keep.append(shape_idx)
                pts = [coord for pt in pts for coord in pt]
                poly_xmin = min(pts[::2])
                poly_ymin = min(pts[1::2])
                poly_xmax = max(pts[::2])
                poly_ymax = max(pts[1::2])
                ls_area2keep.append((poly_xmin, poly_ymin, poly_xmax, poly_ymax))

    # mask white all area of image that is not in block
    mask = np.zeros(img.shape[:2], dtype=np.uint8)
    for box in boxes:
        xmin, ymin, xmax, ymax = box
        xmin = max(0, xmin)
        ymin = max(0, ymin)
        xmax = min(img.shape[1], xmax)
        ymax = min(img.shape[0], ymax)
        mask[ymin:ymax, xmin:xmax] = 255

    for area2keep in ls_area2keep:
        xmin, ymin, xmax, ymax = area2keep
        xmin = int(max(0, xmin))
        ymin = int(max(0, ymin))
        xmax = int(min(img.shape[1], xmax))
        ymax = int(min(img.shape[0], ymax))
        mask[ymin:ymax, xmin:xmax] = 255

    # mask white
    img[mask == 0] = 255

    # delete all poly that is not in block
    ls_idx2del = [idx for idx, shape in enumerate(json_data['shapes']) if idx not in ls_polys2keep]
    for idx in sorted(ls_idx2del, reverse=True):
        del json_data['shapes'][idx]

    return img, json_data
        

def gen_annotation_for_img(img_fp, xml_fp, json_fp, masked=False, widen_range_x=[0.1, 0.2], widen_range_y=[0.1, 0.25], disable_marker=False):
    img = Image.open(img_fp).convert("RGB")
    json_data = json.load(open(json_fp))
    
    if masked:
        block_boxes, obj_names = parse_xml(xml_fp)
        img, json_data = mask_image(np.array(img), boxes=block_boxes, json_data=json_data, widen_range_x=widen_range_x, widen_range_y=widen_range_y)
        img = Image.fromarray(img)
    # pdb.set_trace()
        
    words, orig_polys, normalized_boxes, labels = [], [], [], []
    img_h, img_w = json_data['imageHeight'], json_data['imageWidth']
    for i, shape in enumerate(json_data['shapes']):
        if disable_marker and 'marker' in shape['label']:
            current_label = 'text'
        else:
            current_label = shape['label']
            
        words.append(unidecode.unidecode(shape['text'].lower()))
        # words.append(shape['text'].lower())
        labels.append(current_label)
        pts = [coord for pt in shape['points'] for coord in pt]
        xmin = min(pts[0::2])
        xmax = max(pts[0::2])
        ymin = min(pts[1::2])
        ymax = max(pts[1::2])

        xmin = max(xmin, 0)
        ymin = max(ymin, 0)
        xmax = min(img_w, xmax)
        ymax = min(img_h, ymax)

        normalized_boxes.append(normalize_bbox((xmin, ymin, xmax, ymax), img_w, img_h))
        orig_polys.append(tuple([tuple(pt) for pt in shape['points']]))

    return img, words, orig_polys, normalized_boxes, labels


class CORDDataset(Dataset):
    """CORD dataset."""

    def __init__(self, file_paths, processor=None, max_length=512, masked=False, widen_range_x=[0.1, 0.2], widen_range_y=[0.1, 0.25], disable_marker=False):
        """
        Args:
            annotations (List[List]): List of lists containing the word-level annotations (words, labels, boxes).
            image_dir (string): Directory with all the document images.
            processor (LayoutLMv2Processor): Processor to prepare the text + image.
        """
        self.ls_img_fp, self.ls_xml_fp, self.ls_json_fp = file_paths
        assert len(self.ls_img_fp) == len(self.ls_json_fp) == len(self.ls_xml_fp)
        self.processor = processor
        self.masked = masked
        self.widen_range_x = widen_range_x
        self.widen_range_y = widen_range_y
        self.disable_marker = disable_marker

    def __len__(self):
        return len(self.ls_img_fp)

    def __getitem__(self, index):
        # first, take an image
        img_fp = self.ls_img_fp[index]
        xml_fp = self.ls_xml_fp[index]
        json_fp = self.ls_json_fp[index]
        
        img, words, _, boxes, text_labels = gen_annotation_for_img(img_fp, xml_fp, json_fp, masked=self.masked, widen_range_x=self.widen_range_x, widen_range_y=self.widen_range_y, disable_marker=self.disable_marker)
        idx_labels = [label2id[label] for label in text_labels]

        encoded_inputs = self.processor(img, words, boxes=boxes, word_labels=idx_labels, truncation=True, stride =128, 
                            padding="max_length", max_length=512, return_overflowing_tokens=True, return_offsets_mapping=True, return_tensors="pt")  
        
        # print(encoded_inputs.keys())
        overflow_to_sample_mapping = encoded_inputs.pop('overflow_to_sample_mapping')
        offset_mapping = encoded_inputs.pop('offset_mapping')
        # print('overflow_to_sample_mapping: ', overflow_to_sample_mapping)
        # print('offset_mapping: ', offset_mapping)

        # remove batch dimension
        idx = np.random.randint(0, len(encoded_inputs['pixel_values']))
        for k, v in encoded_inputs.items():
            encoded_inputs[k] = v[idx]
      
        return encoded_inputs

In [7]:
def get_file_paths(data_dir):
    ls_img_fp, ls_xml_fp, ls_json_fp = [], [], []
    for img_fp in Path(data_dir).rglob('*.jpg'):
        json_fp = img_fp.with_suffix('.json')
        xml_fp = img_fp.with_suffix('.xml')
        
        ls_img_fp.append(str(img_fp))
        ls_xml_fp.append(str(xml_fp))
        ls_json_fp.append(str(json_fp))
    
    return ls_img_fp, ls_xml_fp, ls_json_fp


train_file_paths = get_file_paths(train_dir)
val_file_paths = get_file_paths(val_dir)

widen_range_x = [0., 0.01]
widen_range_y = [0.15, 0.3]
disable_marker = False
train_dataset = CORDDataset(file_paths=train_file_paths, processor=processor, masked=True, 
                            widen_range_x=widen_range_x, widen_range_y=widen_range_y, disable_marker=disable_marker)
val_dataset = CORDDataset(file_paths=val_file_paths, processor=processor, masked=True, 
                          widen_range_x=0.1, widen_range_y=0.15, disable_marker=disable_marker)

print(len(train_dataset))
print(len(val_dataset))

303
33


In [9]:
encoding = val_dataset[9]
for k,v in encoding.items():
  print(k, v.shape)

input_ids torch.Size([512])
attention_mask torch.Size([512])
bbox torch.Size([512, 4])
labels torch.Size([512])
pixel_values torch.Size([3, 224, 224])


In [10]:
ls_token = [processor.tokenizer.decode(input_id) for input_id in encoding['input_ids']]
ls_label = [id2label[int(label_id)] if label_id != -100 else 'SPECIAL' for label_id in encoding['labels'] ]
ls_bb = list(encoding['bbox'])
for item in zip(ls_token, ls_label, ls_bb):
  print(item)
  # break

('<s>', 'SPECIAL', tensor([0, 0, 0, 0]))
(' mentioned', 'text', tensor([616, 353, 680, 362]))
('.', 'text', tensor([616, 353, 680, 362]))
(' here', 'text', tensor([555, 353, 612, 362]))
('und', 'text', tensor([555, 353, 612, 362]))
('e', 'text', tensor([555, 353, 612, 362]))
(' conditions', 'text', tensor([496, 353, 551, 362]))
(" '", 'text', tensor([332, 353, 369, 362]))
('article', 'text', tensor([332, 353, 369, 362]))
(' and', 'text', tensor([468, 350, 493, 364]))
(' terms', 'text', tensor([434, 350, 468, 364]))
(' the', 'text', tensor([413, 352, 432, 363]))
(' with', 'text', tensor([386, 350, 413, 364]))
(' r', 'text', tensor([372, 352, 386, 364]))
("'", 'text', tensor([372, 352, 386, 364]))
(' at', 'text', tensor([315, 352, 328, 364]))
(' items', 'text', tensor([281, 350, 313, 364]))
(' of', 'text', tensor([263, 352, 277, 363]))
(' sales', 'text', tensor([229, 350, 262, 364]))
(' the', 'text', tensor([207, 350, 230, 364]))
(' agreed', 'text', tensor([149, 350, 189, 363]))
(' on', 

In [11]:
from torch.utils.data import DataLoader

train_dataloader = DataLoader(train_dataset, batch_size=4, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=4, shuffle=True)
# test_dataloader = DataLoader(test_dataset, batch_size=8, shuffle=False)

for item in train_dataloader:
  for k, v in item.items():
    print(k, v.shape)
  break

input_ids torch.Size([4, 512])
attention_mask torch.Size([4, 512])
bbox torch.Size([4, 512, 4])
labels torch.Size([4, 512])
pixel_values torch.Size([4, 3, 224, 224])


# Model

In [12]:
from transformers import LayoutLMv3ForTokenClassification, AdamW
import torch
from tqdm.notebook import tqdm

model = LayoutLMv3ForTokenClassification.from_pretrained('ckpt/masked/real_data/layoutlmv3_pretrained_fake_data/checkpoint-5200-best_f1-0.984-acc-0.995')
# model = LayoutLMv3ForTokenClassification.from_pretrained('microsoft/layoutlmv3-base', id2label=id2label)

print(model)

LayoutLMv3ForTokenClassification(
  (layoutlmv3): LayoutLMv3Model(
    (embeddings): LayoutLMv3TextEmbeddings(
      (word_embeddings): Embedding(50265, 768, padding_idx=1)
      (token_type_embeddings): Embedding(1, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
      (position_embeddings): Embedding(514, 768, padding_idx=1)
      (x_position_embeddings): Embedding(1024, 128)
      (y_position_embeddings): Embedding(1024, 128)
      (h_position_embeddings): Embedding(1024, 128)
      (w_position_embeddings): Embedding(1024, 128)
    )
    (patch_embed): LayoutLMv3PatchEmbeddings(
      (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    )
    (pos_drop): Dropout(p=0.0, inplace=False)
    (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
    (norm): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
    (encoder): LayoutLMv3Encoder

In [13]:
model.config.id2label = id2label
model.config.label2id = label2id

In [None]:
dir(model)

In [None]:
import torch.nn as nn

model.classifier.out_proj = nn.Linear(in_features=768, out_features=len(label_list), bias=True)

model.num_labels = len(label_list)

print(model)
print()
print(model.config)

In [None]:
len(label_list)

# Hugging Face Trainer

In [14]:
label_list

['marker_company_address',
 'text',
 'phone',
 'fax',
 'bank_address',
 'bank_name',
 'marker_bank_name',
 'marker_fax',
 'company_name',
 'marker_account_number',
 'swift_code',
 'marker_represented_name',
 'marker_represented_position',
 'marker_phone',
 'marker_bank_address',
 'marker_swift_code',
 'company_address',
 'represented_position',
 'marker_tax',
 'account_number',
 'tax',
 'represented_name',
 'marker_company_name']

In [15]:
import evaluate

metric = evaluate.load("seqeval")

import numpy as np
from seqeval.metrics import classification_report

return_entity_level_metrics = False

def compute_metrics(p):
    predictions, labels = p
    predictions = np.argmax(predictions, axis=2)

    # Remove ignored index (special tokens)
    true_predictions = [
        [label_list[p] for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(predictions, labels)
    ]
    true_labels = [
        [label_list[l] for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(predictions, labels)
    ]

    results = metric.compute(predictions=true_predictions, references=true_labels)
    if return_entity_level_metrics:
        # Unpack nested dictionaries
        final_results = {}
        for key, value in results.items():
            if isinstance(value, dict):
                for n, v in value.items():
                    final_results[f"{key}_{n}"] = v
            else:
                final_results[key] = value
        return final_results
    else:
        return {
            "precision": results["overall_precision"],
            "recall": results["overall_recall"],
            "f1": results["overall_f1"],
            "accuracy": results["overall_accuracy"],
        }

In [17]:
from transformers import TrainingArguments, Trainer

training_args = TrainingArguments(output_dir="ckpt/masked/real_data/layoutlmv3_pretrained_real_data_widen_x_0_widen_y_0.15_0.3",
                                  num_train_epochs=50,
                                  learning_rate=2e-5,
                                  weight_decay=1e-2,
                                  evaluation_strategy="steps",
                                  save_strategy='steps',
                                  eval_steps=250,
                                  save_steps=250,
                                  save_total_limit=15,
                                  load_best_model_at_end=True,
                                  metric_for_best_model="f1",
                                  warmup_ratio = 0.1,
                                  do_eval=True)

In [18]:
from transformers.data.data_collator import default_data_collator

class CustomTrainer(Trainer):
  def get_train_dataloader(self):
    return train_dataloader

  def get_eval_dataloader(self, eval_dataset = None):
    return val_dataloader

# Initialize our Trainer
trainer = CustomTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    tokenizer=processor.tokenizer,
    compute_metrics=compute_metrics,
)

In [19]:
trainer.train()

  return lib.intersection(a, b, **kwargs)


Step,Training Loss,Validation Loss,Precision,Recall,F1,Accuracy
250,No log,0.027912,0.970765,0.980035,0.975378,0.99408
500,1.819800,0.026692,0.980017,0.979167,0.979592,0.993948
750,1.819800,0.039328,0.983493,0.982639,0.983066,0.994738
1000,0.004900,0.096145,0.985179,0.980903,0.983036,0.994606
1250,0.004900,0.052028,0.980836,0.977431,0.97913,0.993553
1500,0.001800,0.026548,0.984334,0.981771,0.983051,0.994474
1750,0.001800,0.089436,0.979983,0.977431,0.978705,0.993685
2000,0.001300,0.105592,0.97913,0.977431,0.97828,0.993685
2250,0.001300,0.019283,0.980836,0.977431,0.97913,0.993817
2500,0.000800,0.108437,0.983421,0.978299,0.980853,0.99408


  return lib.intersection(a, b, **kwargs)
  return lib.intersection(a, b, **kwargs)
  return lib.intersection(a, b, **kwargs)
  return lib.intersection(a, b, **kwargs)
  return lib.intersection(a, b, **kwargs)
  return lib.intersection(a, b, **kwargs)
  return lib.intersection(a, b, **kwargs)
  return lib.intersection(a, b, **kwargs)
  return lib.intersection(a, b, **kwargs)
  return lib.intersection(a, b, **kwargs)
  return lib.intersection(a, b, **kwargs)
  return lib.intersection(a, b, **kwargs)
  return lib.intersection(a, b, **kwargs)
  return lib.intersection(a, b, **kwargs)
  return lib.intersection(a, b, **kwargs)


TrainOutput(global_step=3800, training_loss=0.2407649511255716, metrics={'train_runtime': 3952.2285, 'train_samples_per_second': 3.833, 'train_steps_per_second': 0.961, 'total_flos': 4021680212121600.0, 'train_loss': 0.2407649511255716, 'epoch': 50.0})

In [None]:
model_dir = 'ckpt/masked/real_data/layoutlmv3_pretrained_fake_data/'
max_f1, max_acc = 0, 0
best_f1_model, best_acc_model = None, None
for model_fn in os.listdir(model_dir):
    if model_fn == 'runs':
        continue
    model_fp = os.path.join(model_dir, model_fn)
    loaded_model = LayoutLMv3ForTokenClassification.from_pretrained(model_fp).to('cuda')
    trainer.model = loaded_model
    res = trainer.evaluate()
    if res['eval_f1'] > max_f1:
        best_f1_model = model_fn
        max_f1 = res['eval_f1']
    if res['eval_accuracy'] > max_acc:
        best_acc_model = model_fn
        max_acc = res['eval_accuracy']

print(f'Best f1 model: {best_f1_model} - {max_f1}')
print(f'Best acc model: {best_acc_model} - {max_acc}')

In [20]:
# delete model
import shutil

model_dir = 'ckpt/masked/real_data/layoutlmv3_pretrained_real_data_widen_x_0_widen_y_0.15_0.3'
for model_fn in os.listdir(model_dir):
    if model_fn == 'runs':
        continue
    model_fp = os.path.join(model_dir, model_fn)
    if model_fn not in ['checkpoint-1000']:
        shutil.rmtree(model_fp)
        print(f'removed {model_fn}')

removed checkpoint-1750
removed checkpoint-3000
removed checkpoint-3750
removed checkpoint-250
removed checkpoint-2250
removed checkpoint-750
removed checkpoint-500
removed checkpoint-3250
removed checkpoint-2500
removed checkpoint-1500
removed checkpoint-2000
removed checkpoint-1250
removed checkpoint-2750
removed checkpoint-3500


# Inference

In [None]:
from transformers import LayoutLMv3ForTokenClassification, LayoutLMv3Processor

model = LayoutLMv3ForTokenClassification.from_pretrained('ckpt/masked/real_data/layoutlmv3_pretrained_fake_data/checkpoint-5200-best_f1-0.984-acc-0.995')
processor = LayoutLMv3Processor.from_pretrained('microsoft/layoutlmv3-base')
processor.tokenizer.only_label_first_subword = False

print(model)
print(processor)

In [None]:
import pdb
from transformers import LayoutXLMProcessor
from collections import Counter
from transformers import LayoutXLMProcessor, LayoutXLMTokenizerFast, LayoutLMv2FeatureExtractor, LayoutLMv3Processor

processor = LayoutLMv3Processor.from_pretrained('microsoft/layoutlmv3-base', apply_ocr=False)
processor.tokenizer.only_label_first_subword = False

device = 'cuda' if torch.cuda.is_available() else 'cpu'

def denormalize(bb, img_w, img_h):
  return (
      int(bb[0] / 1000 * img_w),
      int(bb[1] / 1000 * img_h),
      int(bb[2] / 1000 * img_w),
      int(bb[3] / 1000 * img_h),
  )

def predict(img, words, boxes):
    assert len(words) == len(boxes)

    preds_val = None
    out_label_ids = None
    num_no_split = 0
    num_has_split = 0
    result_dict = {}

    img_w, img_h = img.size
    # encode input for model
    encoded_inputs = processor(img, words, boxes=boxes, truncation=True, stride=128,
                               padding="max_length", max_length=512, return_overflowing_tokens=True, return_offsets_mapping=True, return_tensors="pt")
    overflow_to_sample_mapping = encoded_inputs.pop('overflow_to_sample_mapping')
    offset_mapping = encoded_inputs.pop('offset_mapping')

    n = len(encoded_inputs['pixel_values'])
    print(f'{n} split')

    ls_bb2label = []
    for idx, image in enumerate(encoded_inputs['pixel_values']):
        # prepare input to model
        input_ids = encoded_inputs['input_ids'][idx].unsqueeze(0).to(device)
        bbox = encoded_inputs['bbox'][idx].unsqueeze(0).to(device)
        image = encoded_inputs['pixel_values'][idx].unsqueeze(0).to(device)
        attention_mask = encoded_inputs['attention_mask'][idx].unsqueeze(0).to(device)

        # forward
        outputs = model(input_ids=input_ids, bbox=bbox, pixel_values=image, attention_mask=attention_mask)

        # process output
        preds_val = outputs.logits.detach().cpu().numpy()[0].tolist()

        bbs = bbox.detach().cpu().squeeze().numpy().tolist()
        input_ids = input_ids.cpu().squeeze().numpy().tolist()
        bb2label = {}
        for i, (pred, bb) in enumerate(zip(preds_val, bbs)):
            bb = tuple(bb)
            if bb not in bb2label:
              bb2label[bb] = [np.argmax(pred)]
            else:
              bb2label[bb].append(np.argmax(pred))
        ls_bb2label.append(bb2label)

    # get predictions for all parts
    final_bb2label = {}
    for bb2label in ls_bb2label:
        for bb, label in bb2label.items():
            if bb not in final_bb2label:
                final_bb2label[bb] = label
            else:
                final_bb2label[bb].extend(label)

    # get final predictions
    bb2label = {bb: Counter(label).most_common(1)[0][0] for bb, label in final_bb2label.items()}

    return {denormalize(bb, img_w, img_h): label for bb, label in bb2label.items()}

In [None]:
data_dir = 'real_data/val_labeled_ocred'
result_dict = {}
widen_range_x = 0.1
widen_range_y = 0.2
for jp in Path(data_dir).rglob('*.json'):
    img, words, boxes, labels = gen_annotation_for_img(img_fp=jp.with_suffix('.jpg'), 
                                                       json_fp=jp, 
                                                       xml_fp=jp.with_suffix('.xml'), 
                                                       masked=True,
                                                       widen_range_x=widen_range_x,
                                                       widen_range_y=widen_range_y)
    result_dict[jp.with_suffix('.jpg')] = predict(img, words, boxes)
    print(f'Done {jp}')

In [None]:
id2label = model.config.id2label
print(id2label)

In [None]:
import shutil

dir = 'real_data/val_labeled_ocred'
out_dir = 'real_data/val_labeled_ocred_layoutlmv3_nonmasked_pred'
os.makedirs(out_dir, exist_ok=True)

for img_fp, bb2label in result_dict.items():
    json_fp = img_fp.with_suffix('.json')
    with open(json_fp, 'r') as f:
        data = json.load(f)

    new_shapes = []
    for bb, label in bb2label.items():
        points = [
            [bb[0], bb[1]],
            [bb[2], bb[1]],
            [bb[2], bb[3]],
            [bb[0], bb[3]]
        ]
        shape = {
            'label': id2label[label],
            'points': points,
            'shape_type': 'polygon',
            'flags': {}
        }
        new_shapes.append(shape)
    data['shapes'] = new_shapes

    with open(os.path.join(out_dir, json_fp.name), 'w') as f:
        json.dump(data, f)
    shutil.copy(img_fp, out_dir)
    print(f'done {img_fp}')