In [None]:
!pwd

In [None]:
!nvidia-smi

In [1]:
import os
os.environ['TRANSFORMERS_CACHE'] = '/data/tungtx2/tmp/transformers_hub'
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

In [2]:
import os
import json
from pathlib import Path
import numpy as np
from PIL import Image
import torch
from os import listdir
from torch.utils.data import Dataset
import torch
from PIL import Image
import unidecode
from PIL import Image, ImageDraw, ImageFont
import pdb
import xml.etree.ElementTree as ET
from shapely.geometry import Polygon
import cv2

torch.__version__

  from .autonotebook import tqdm as notebook_tqdm


'1.13.1+cu117'

# Prepare Data

In [3]:
train_dir = 'real_data/train_labeled_ocred'
val_dir = 'real_data/val_labeled_ocred'

def find_all_labels(data_dir):
    labels = {}
    for jp in Path(data_dir).rglob('*.json'):
        data = json.load(open(jp))
        for shape in data['shapes']:
            if shape['label'] in labels:
                labels[shape['label']] += 1
            else:
                labels[shape['label']] = 0
                
    return labels

train_labels = find_all_labels(train_dir)
val_labels = find_all_labels(val_dir)
assert set(train_labels.keys()) == set(val_labels.keys())
for k, v in train_labels.items():
    print(k, ':', v)

text : 64973
account_number : 418
marker_account_number : 584
swift_code : 234
marker_swift_code : 406
bank_name : 1492
marker_bank_name : 425
fax : 565
marker_fax : 313
phone : 962
marker_phone : 489
company_address : 7565
company_name : 3915
marker_company_name : 1351
bank_address : 1227
marker_bank_address : 207
marker_company_address : 504
represented_position : 463
marker_represented_position : 88
represented_name : 1075
marker_represented_name : 611
tax : 56
marker_tax : 98


In [4]:
label_list = list(set(train_labels.keys()))
label2id = {label: idx for idx, label in enumerate(label_list)}
id2label = {idx: label for idx, label in enumerate(label_list)}

print(label2id)
print(id2label)

{'bank_name': 0, 'fax': 1, 'represented_position': 2, 'bank_address': 3, 'swift_code': 4, 'marker_represented_name': 5, 'marker_bank_address': 6, 'marker_represented_position': 7, 'marker_tax': 8, 'account_number': 9, 'marker_company_name': 10, 'marker_bank_name': 11, 'marker_fax': 12, 'marker_company_address': 13, 'text': 14, 'tax': 15, 'represented_name': 16, 'marker_phone': 17, 'company_address': 18, 'phone': 19, 'company_name': 20, 'marker_swift_code': 21, 'marker_account_number': 22}
{0: 'bank_name', 1: 'fax', 2: 'represented_position', 3: 'bank_address', 4: 'swift_code', 5: 'marker_represented_name', 6: 'marker_bank_address', 7: 'marker_represented_position', 8: 'marker_tax', 9: 'account_number', 10: 'marker_company_name', 11: 'marker_bank_name', 12: 'marker_fax', 13: 'marker_company_address', 14: 'text', 15: 'tax', 16: 'represented_name', 17: 'marker_phone', 18: 'company_address', 19: 'phone', 20: 'company_name', 21: 'marker_swift_code', 22: 'marker_account_number'}


# Data Loader

In [5]:
from transformers import LayoutXLMTokenizerFast, LayoutLMv2FeatureExtractor, LayoutXLMProcessor

feature_extractor = LayoutLMv2FeatureExtractor(apply_ocr=False)
tokenizer = LayoutXLMTokenizerFast.from_pretrained('SCUT-DLVCLab/lilt-infoxlm-base')
tokenizer.only_label_first_subword = False
processor = LayoutXLMProcessor(feature_extractor=feature_extractor, tokenizer=tokenizer)
print(processor.tokenizer)
print(processor.feature_extractor)



LayoutXLMTokenizerFast(name_or_path='SCUT-DLVCLab/lilt-infoxlm-base', vocab_size=250002, model_max_length=1000000000000000019884624838656, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<unk>', 'sep_token': '</s>', 'pad_token': '<pad>', 'cls_token': '<s>', 'mask_token': AddedToken("<mask>", rstrip=False, lstrip=True, single_word=False, normalized=False)})
LayoutLMv2FeatureExtractor {
  "apply_ocr": false,
  "do_resize": true,
  "image_processor_type": "LayoutLMv2FeatureExtractor",
  "ocr_lang": null,
  "resample": 2,
  "size": {
    "height": 224,
    "width": 224
  },
  "tesseract_config": ""
}



In [6]:
def normalize_bbox(bbox, width, height):
     return [
         int(1000 * (bbox[0] / width)),
         int(1000 * (bbox[1] / height)),
         int(1000 * (bbox[2] / width)),
         int(1000 * (bbox[3] / height)),
     ]
    
    
def parse_xml(xml_path):
    root = ET.parse(xml_path).getroot()
    objs = root.findall('object')
    boxes, obj_names = [], []
    for obj in objs:
        obj_name = obj.find('name').text
        box = obj.find('bndbox')
        xmin = int(box.find('xmin').text)
        ymin = int(box.find('ymin').text)
        xmax = int(box.find('xmax').text)
        ymax = int(box.find('ymax').text)
        boxes.append([xmin, ymin, xmax, ymax])
        obj_names.append(obj_name)
    return boxes, obj_names


def widen_box(box, percent_x, percent_y):
        xmin, ymin, xmax, ymax = box
        w = xmax - xmin
        h = ymax - ymin
        xmin -= w * percent_x
        ymin -= h * percent_y
        xmax += w * percent_x
        ymax += h * percent_y
        return (int(xmin), int(ymin), int(xmax), int(ymax))

    
def draw_json_on_img(img, json_data):
    labels = list(set(shape['label'] for shape in json_data['shapes']))
    color = {}
    for i in range(len(labels)):
        color[labels[i]] = (np.random.randint(0, 255), np.random.randint(0, 255), np.random.randint(0, 255))
        
    img = img.copy()
    draw = ImageDraw.Draw(img)
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_size = 0.5
    for i, shape in enumerate(json_data['shapes']):
        polys = shape['points']
        polys = [(int(pt[0]), int(pt[1])) for pt in polys]
        label = shape['label']
        draw.polygon(polys, outline=color[label], width=2)
        # Draw the text on the image
        img = np.array(img)
        cv2.putText(img, shape['label'], (polys[0][0], polys[0][1]-5), font, font_size, color[label], thickness=1)
        img = Image.fromarray(img)
        draw = ImageDraw.Draw(img)
    return img
    
    
def mask_image(img, boxes, json_data, widen_range_x, widen_range_y):
    # widen block
    if isinstance(widen_range_x, list) and isinstance(widen_range_y, list):
        boxes = [widen_box(box, np.random.uniform(widen_range_x[0], widen_range_x[1]), np.random.uniform(widen_range_y[0], widen_range_y[1])) for box in boxes]
    else:
        boxes = [widen_box(box, widen_range_x, widen_range_y) for box in boxes]
        
    
    ls_polys2keep = []
    ls_area2keep = []
    iou_threshold = 0.
    for box_idx, box in enumerate(boxes):
        xmin, ymin, xmax, ymax = box
        box_pts = [(xmin, ymin), (xmax, ymin), (xmax, ymax), (xmin, ymax)]
        p_box = Polygon(box_pts)
        for shape_idx, shape in enumerate(json_data['shapes']):
            if shape_idx in ls_polys2keep:
                continue
            pts = shape['points']
            p_shape = Polygon(pts)
            intersect_area = p_box.intersection(p_shape).area
            if intersect_area / p_shape.area > iou_threshold:
                ls_polys2keep.append(shape_idx)
                pts = [coord for pt in pts for coord in pt]
                poly_xmin = min(pts[::2])
                poly_ymin = min(pts[1::2])
                poly_xmax = max(pts[::2])
                poly_ymax = max(pts[1::2])
                ls_area2keep.append((poly_xmin, poly_ymin, poly_xmax, poly_ymax))

    # mask white all area of image that is not in block
    mask = np.zeros(img.shape[:2], dtype=np.uint8)
    for box in boxes:
        xmin, ymin, xmax, ymax = box
        xmin = max(0, xmin)
        ymin = max(0, ymin)
        xmax = min(img.shape[1], xmax)
        ymax = min(img.shape[0], ymax)
        mask[ymin:ymax, xmin:xmax] = 255

    for area2keep in ls_area2keep:
        xmin, ymin, xmax, ymax = area2keep
        xmin = int(max(0, xmin))
        ymin = int(max(0, ymin))
        xmax = int(min(img.shape[1], xmax))
        ymax = int(min(img.shape[0], ymax))
        mask[ymin:ymax, xmin:xmax] = 255

    # mask white
    img[mask == 0] = 255

    # delete all poly that is not in block
    ls_idx2del = [idx for idx, shape in enumerate(json_data['shapes']) if idx not in ls_polys2keep]
    for idx in sorted(ls_idx2del, reverse=True):
        del json_data['shapes'][idx]

    return img, json_data
        

def gen_annotation_for_img(img_fp, xml_fp, json_fp, masked=False, widen_range_x=[0.1, 0.2], widen_range_y=[0.1, 0.25]):
    img = Image.open(img_fp).convert("RGB")
    boxes, obj_names = parse_xml(xml_fp)
    json_data = json.load(open(json_fp))
    
    if masked:
        img, json_data = mask_image(np.array(img), boxes=boxes, json_data=json_data, widen_range_x=widen_range_x, widen_range_y=widen_range_y)
        img = Image.fromarray(img)
    # pdb.set_trace()
        
    words, boxes, labels = [], [], []
    img_h, img_w = json_data['imageHeight'], json_data['imageWidth']
    for i, shape in enumerate(json_data['shapes']):
      # words.append(unidecode.unidecode(shape['text'].lower()))
      words.append(shape['text'].lower())
      labels.append(shape['label'])
      pts = [coord for pt in shape['points'] for coord in pt]
      xmin = min(pts[0::2])
      xmax = max(pts[0::2])
      ymin = min(pts[1::2])
      ymax = max(pts[1::2])

      xmin = max(xmin, 0)
      ymin = max(ymin, 0)
      xmax = min(img_w, xmax)
      ymax = min(img_h, ymax)

      boxes.append(normalize_bbox((xmin, ymin, xmax, ymax), img_w, img_h))

    return img, words, boxes, labels


class CORDDataset(Dataset):
    def __init__(self, file_paths, processor=None, max_length=512, masked=False, widen_range_x=[0.1, 0.2], widen_range_y=[0.1, 0.25]):
        """
        Args:
            annotations (List[List]): List of lists containing the word-level annotations (words, labels, boxes).
            image_dir (string): Directory with all the document images.
            processor (LayoutLMv2Processor): Processor to prepare the text + image.
        """
        self.ls_img_fp, self.ls_xml_fp, self.ls_json_fp = file_paths
        assert len(self.ls_img_fp) == len(self.ls_json_fp) == len(self.ls_xml_fp)
        self.processor = processor
        self.masked = masked
        self.widen_range_x = widen_range_x
        self.widen_range_y = widen_range_y
        

    def __len__(self):
        return len(self.ls_img_fp)

    def __getitem__(self, index):
        # first, take an image
        img_fp = self.ls_img_fp[index]
        xml_fp = self.ls_xml_fp[index]
        json_fp = self.ls_json_fp[index]
        
        img, words, boxes, text_labels = gen_annotation_for_img(img_fp, xml_fp, json_fp, masked=self.masked, widen_range_x=self.widen_range_x, widen_range_y=self.widen_range_y)
        idx_labels = [label2id[label] for label in text_labels]

        encoded_inputs = processor(img, words, boxes=boxes, word_labels=idx_labels, truncation=True, stride =128, 
                            padding="max_length", max_length=512, return_overflowing_tokens=True, return_offsets_mapping=True, return_tensors="pt")  
        
        # print(encoded_inputs.keys())
        overflow_to_sample_mapping = encoded_inputs.pop('overflow_to_sample_mapping')
        offset_mapping = encoded_inputs.pop('offset_mapping')
        encoded_inputs.pop('image')
        # print('overflow_to_sample_mapping: ', overflow_to_sample_mapping)
        # print('offset_mapping: ', offset_mapping)

        # remove batch dimension
        idx = np.random.randint(0, len(encoded_inputs['bbox']))
        for k, v in encoded_inputs.items():
            encoded_inputs[k] = v[idx]
      
        return encoded_inputs

In [9]:
def get_file_paths(data_dir):
    ls_img_fp, ls_xml_fp, ls_json_fp = [], [], []
    for img_fp in Path(data_dir).rglob('*.jpg'):
        json_fp = img_fp.with_suffix('.json')
        xml_fp = img_fp.with_suffix('.xml')
        
        ls_img_fp.append(str(img_fp))
        ls_xml_fp.append(str(xml_fp))
        ls_json_fp.append(str(json_fp))
    
    return ls_img_fp, ls_xml_fp, ls_json_fp


train_file_paths = get_file_paths(train_dir)
val_file_paths = get_file_paths(val_dir)

widen_range_x = [0.1, 0.2]
widen_range_y = [0.1, 0.25]
train_dataset = CORDDataset(file_paths=train_file_paths, processor=processor, masked=True, 
                            widen_range_x=widen_range_x, widen_range_y=widen_range_y)
val_dataset = CORDDataset(file_paths=val_file_paths, processor=processor, masked=True, 
                          widen_range_x=0.15, widen_range_y=0.2)

print(len(train_dataset))
print(len(val_dataset))

303
33


In [10]:
encoding = val_dataset[9]
for k,v in encoding.items():
  print(k, v.shape)

input_ids torch.Size([512])
attention_mask torch.Size([512])
bbox torch.Size([512, 4])
labels torch.Size([512])


In [12]:
ls_token = [processor.tokenizer.decode(input_id) for input_id in encoding['input_ids']]
ls_label = [id2label[int(label_id)] if label_id != -100 else 'SPECIAL' for label_id in encoding['labels'] ]
ls_bb = list(encoding['bbox'])
for item in zip(ls_token, ls_label, ls_bb):
  print(item)
  # break

('<s>', 'SPECIAL', tensor([0, 0, 0, 0]))
('mentioned', 'text', tensor([616, 353, 680, 362]))
('.', 'text', tensor([616, 353, 680, 362]))
('here', 'text', tensor([555, 353, 612, 362]))
('unde', 'text', tensor([555, 353, 612, 362]))
('conditions', 'text', tensor([496, 353, 551, 362]))
("'", 'text', tensor([332, 353, 369, 362]))
('article', 'text', tensor([332, 353, 369, 362]))
('and', 'text', tensor([468, 350, 493, 364]))
('terms', 'text', tensor([434, 350, 468, 364]))
('the', 'text', tensor([413, 352, 432, 363]))
('with', 'text', tensor([386, 350, 413, 364]))
('r', 'text', tensor([372, 352, 386, 364]))
("'", 'text', tensor([372, 352, 386, 364]))
('at', 'text', tensor([315, 352, 328, 364]))
('items', 'text', tensor([281, 350, 313, 364]))
('of', 'text', tensor([263, 352, 277, 363]))
('sales', 'text', tensor([229, 350, 262, 364]))
('the', 'text', tensor([207, 350, 230, 364]))
('agreed', 'text', tensor([149, 350, 189, 363]))
('on', 'text', tensor([191, 350, 207, 363]))
('have', 'text', tens

In [13]:
from torch.utils.data import DataLoader

train_dataloader = DataLoader(train_dataset, batch_size=4, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=4, shuffle=True)
# test_dataloader = DataLoader(test_dataset, batch_size=8, shuffle=False)

for item in train_dataloader:
  for k, v in item.items():
    print(k, v.shape)
  break

input_ids torch.Size([4, 512])
attention_mask torch.Size([4, 512])
bbox torch.Size([4, 512, 4])
labels torch.Size([4, 512])


# Model

In [14]:
from transformers import LiltForTokenClassification

model = LiltForTokenClassification.from_pretrained('ckpt/masked/fake_data/lilt/checkpoint-34000')

In [15]:
model

LiltForTokenClassification(
  (lilt): LiltModel(
    (embeddings): LiltTextEmbeddings(
      (word_embeddings): Embedding(250002, 768, padding_idx=1)
      (position_embeddings): Embedding(514, 768, padding_idx=1)
      (token_type_embeddings): Embedding(1, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (layout_embeddings): LiltLayoutEmbeddings(
      (x_position_embeddings): Embedding(1024, 128)
      (y_position_embeddings): Embedding(1024, 128)
      (h_position_embeddings): Embedding(1024, 128)
      (w_position_embeddings): Embedding(1024, 128)
      (box_position_embeddings): Embedding(514, 192, padding_idx=1)
      (box_linear_embeddings): Linear(in_features=768, out_features=192, bias=True)
      (LayerNorm): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): LiltEncoder(
      (layer): ModuleList(
        (0): LiltLayer(

In [16]:
import torch.nn as nn

model.classifier = nn.Linear(in_features=768, out_features=len(label_list), bias=True)
model.num_labels = len(label_list)
model.config.label2id = label2id
model.config.id2label = id2label

# Hugging Face Trainer

In [17]:
label_list

['bank_name',
 'fax',
 'represented_position',
 'bank_address',
 'swift_code',
 'marker_represented_name',
 'marker_bank_address',
 'marker_represented_position',
 'marker_tax',
 'account_number',
 'marker_company_name',
 'marker_bank_name',
 'marker_fax',
 'marker_company_address',
 'text',
 'tax',
 'represented_name',
 'marker_phone',
 'company_address',
 'phone',
 'company_name',
 'marker_swift_code',
 'marker_account_number']

In [18]:
import evaluate

metric = evaluate.load("seqeval")

import numpy as np
from seqeval.metrics import classification_report

return_entity_level_metrics = False

def compute_metrics(p):
    predictions, labels = p
    predictions = np.argmax(predictions, axis=2)

    # Remove ignored index (special tokens)
    true_predictions = [
        [label_list[p] for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(predictions, labels)
    ]
    true_labels = [
        [label_list[l] for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(predictions, labels)
    ]

    results = metric.compute(predictions=true_predictions, references=true_labels)
    if return_entity_level_metrics:
        # Unpack nested dictionaries
        final_results = {}
        for key, value in results.items():
            if isinstance(value, dict):
                for n, v in value.items():
                    final_results[f"{key}_{n}"] = v
            else:
                final_results[key] = value
        return final_results
    else:
        return {
            "precision": results["overall_precision"],
            "recall": results["overall_recall"],
            "f1": results["overall_f1"],
            "accuracy": results["overall_accuracy"],
        }

In [19]:
from transformers import TrainingArguments, Trainer

training_args = TrainingArguments(output_dir="ckpt/masked/real_data/lilt_pretrain_fake_data",
                                  num_train_epochs=100,
                                  learning_rate=2e-5,
                                  weight_decay=1e-2,
                                  evaluation_strategy="steps",
                                  save_strategy='steps',
                                  eval_steps=250,
                                  save_steps=250,
                                  save_total_limit=15,
                                  load_best_model_at_end=True,
                                  metric_for_best_model="f1",
                                  warmup_ratio = 0.1,
                                  do_eval=True)

In [20]:
from transformers.data.data_collator import default_data_collator

class CustomTrainer(Trainer):
  def get_train_dataloader(self):
    return train_dataloader

  def get_eval_dataloader(self, eval_dataset = None):
    return val_dataloader

# Initialize our Trainer
trainer = CustomTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    tokenizer=processor.tokenizer,
    compute_metrics=compute_metrics,
)

In [21]:
trainer.train()

  return lib.intersection(a, b, **kwargs)


Step,Training Loss,Validation Loss,Precision,Recall,F1,Accuracy
250,No log,0.577315,0.727691,0.681897,0.70405,0.903361
500,1.231500,0.117179,0.910122,0.899138,0.904597,0.977377
750,1.231500,0.047563,0.955691,0.948276,0.951969,0.989015
1000,0.092200,0.067816,0.960242,0.957759,0.958999,0.990715
1250,0.092200,0.044698,0.970715,0.971552,0.971133,0.992415
1500,0.032100,0.025956,0.973253,0.972414,0.972833,0.992285
1750,0.032100,0.031847,0.981802,0.976724,0.979257,0.992677
2000,0.016000,0.049548,0.981881,0.981034,0.981458,0.9932
2250,0.016000,0.026038,0.986171,0.983621,0.984894,0.995554
2500,0.010300,0.044616,0.981865,0.980172,0.981018,0.995031


  _warn_prf(average, modifier, msg_start, len(result))
  return lib.intersection(a, b, **kwargs)
  _warn_prf(average, modifier, msg_start, len(result))
  return lib.intersection(a, b, **kwargs)
  _warn_prf(average, modifier, msg_start, len(result))
  return lib.intersection(a, b, **kwargs)
  return lib.intersection(a, b, **kwargs)
  return lib.intersection(a, b, **kwargs)
  return lib.intersection(a, b, **kwargs)
  return lib.intersection(a, b, **kwargs)
  return lib.intersection(a, b, **kwargs)
  return lib.intersection(a, b, **kwargs)
  return lib.intersection(a, b, **kwargs)
  return lib.intersection(a, b, **kwargs)
  return lib.intersection(a, b, **kwargs)
  return lib.intersection(a, b, **kwargs)
  return lib.intersection(a, b, **kwargs)
  return lib.intersection(a, b, **kwargs)
  return lib.intersection(a, b, **kwargs)
  return lib.intersection(a, b, **kwargs)
  return lib.intersection(a, b, **kwargs)
  return lib.intersection(a, b, **kwargs)
  return lib.intersection(a, b, **kwa

TrainOutput(global_step=7600, training_loss=0.09240561253832359, metrics={'train_runtime': 7981.7555, 'train_samples_per_second': 3.796, 'train_steps_per_second': 0.952, 'total_flos': 8429477794099200.0, 'train_loss': 0.09240561253832359, 'epoch': 100.0})

In [23]:
# delete model
import shutil

model_dir = 'ckpt/masked/real_data/lilt_pretrain_fake_data/'
for model_fn in os.listdir(model_dir):
    if model_fn == 'runs':
        continue
    model_fp = os.path.join(model_dir, model_fn)
    if model_fn not in ['checkpoint-7500']:
        shutil.rmtree(model_fp)
        print(f'removed {model_fn}')

removed checkpoint-7000
removed checkpoint-5250
removed checkpoint-6250
removed checkpoint-5500
removed checkpoint-6000
removed checkpoint-2250
removed checkpoint-4500
removed checkpoint-4750
removed checkpoint-7250
removed checkpoint-4250
removed checkpoint-5000
removed checkpoint-6750
removed checkpoint-5750
removed checkpoint-6500
