In [None]:
import os
from os.path import join
import pandas as pd
import numpy as np
from ast import literal_eval
from dla_pipeline_support_functions import load_mask_registry

from transformers import (
    LayoutLMv3FeatureExtractor,
    LayoutLMv3Tokenizer,
    LayoutLMv3Processor,
    LayoutLMv3ForTokenClassification,
)
import json
import torch

from datasets import Dataset

from PIL.JpegImagePlugin import JpegImageFile

# Raise pandas' display limits so wide/long frames are shown without elision.
for _display_option in ("display.max_rows", "display.max_columns", "display.width"):
    pd.set_option(_display_option, 999)

In [None]:
DATA_DIRECTORY = "/data"

# Expose the project data folder at DATA_DIRECTORY through a symlink.
# `lexists` (rather than `exists`) is used so that an already-present but
# dangling link is not re-created: `os.symlink` would raise FileExistsError
# in that case, whereas the path validation below reports it clearly.
if not os.path.lexists(DATA_DIRECTORY):
    os.symlink("/user/w210/clean_data_is_all_you_need/app/data", DATA_DIRECTORY)

# Pipeline stage directories, all located under DATA_DIRECTORY.
S1_INPUT_PDFS_DIR = join(DATA_DIRECTORY, "s1_input_pdfs")
S2_DLA_INPUTS_DIR = join(DATA_DIRECTORY, "s2_dla_inputs")
S3_OUTPUTS_DIR = join(DATA_DIRECTORY, "s3_outputs")
S4_JSON_TEXT_OUTPUTS_DIR = join(DATA_DIRECTORY, "s4_json_text_output")
PAGE_MASK_DIR = join(S3_OUTPUTS_DIR, "page_masks")

# Model artefacts: fine-tuned checkpoint, processor files and category map.
PRETRAINED_MODEL_DIR = "/user/w210/large_file_repo/models_pretrained"
MODEL_TAG = "layoutlmv3-finetuned-DocLayNet_large_sci_23_12_02-15_50_34/checkpoint-5946"
MODEL_WEIGHTS = join(PRETRAINED_MODEL_DIR, MODEL_TAG)
MODEL_PROCESSOR = join(PRETRAINED_MODEL_DIR, "microsoft-layoutlmv3-base-processor")
MODEL_CATEGORIES_JSON = join(DATA_DIRECTORY, "dla_categories_doclaynet.json")

# Batch size used by `dataset.map` and the maximum token sequence length.
GLOBAL_BATCH_SIZE = 1
MAX_LENGTH = 512

# Fail fast, before any heavy work, if a required path is missing.
for pth in [
    DATA_DIRECTORY,
    S1_INPUT_PDFS_DIR,
    S2_DLA_INPUTS_DIR,
    S3_OUTPUTS_DIR,
    S4_JSON_TEXT_OUTPUTS_DIR,
    PAGE_MASK_DIR,
    MODEL_WEIGHTS,
    MODEL_PROCESSOR,
    MODEL_CATEGORIES_JSON,
]:
    assert os.path.exists(pth), f"PATH NOT FOUND: {pth}"

In [None]:
# Normalize box dimensions to the range 0 to 1000
def normalized_box(box, image_width, image_height):
    """Rescale a pixel-space box to integer coordinates in the range 0..1000.

    Args:
        box: Indexable with at least four numbers; elements 0 and 2 are
            scaled by `image_width`, elements 1 and 3 by `image_height`.
        image_width: Width of the source image, in the same units as `box`.
        image_height: Height of the source image, in the same units as `box`.

    Returns:
        A list of four ints. Each value is truncated with `int()` and then
        clamped to [0, 1000], so a box that slightly overshoots the image
        border cannot produce an out-of-range coordinate.

    Raises:
        ZeroDivisionError: If `image_width` or `image_height` is 0.
    """

    def _scale(value, extent):
        # Same arithmetic as before (truncate toward zero), plus clamping.
        return min(1000, max(0, int(1000 * (value / extent))))

    return [
        _scale(box[0], image_width),
        _scale(box[1], image_height),
        _scale(box[2], image_width),
        _scale(box[3], image_height),
    ]

In [None]:
# Mask registry: one row per page mask, tagged with a sequential id
# (0..len-1) that is distinct for every row.
mask_registry = load_mask_registry(PAGE_MASK_DIR, validate_csvs=False)
mask_registry["umask_id"] = np.arange(len(mask_registry))

# Page image listing written to the S3 outputs directory.
page_image_registry = pd.read_csv(join(S3_OUTPUTS_DIR, "page_images_list.csv"))

# Text-extraction listing; the JSON file name is derived from the PDF name.
doc_text_registry = pd.read_csv(join(S4_JSON_TEXT_OUTPUTS_DIR, "text_extract.csv"))
doc_text_registry["json_path"] = [
    pdf_name.replace(".pdf", ".json") for pdf_name in doc_text_registry["pdf_file"]
]

# Show the available columns of each registry, separated by blank lines.
print(
    mask_registry.columns,
    "",
    page_image_registry.columns,
    "",
    doc_text_registry.columns,
    sep="\n",
)

In [None]:
def generate_doclaylet_dataset(
    page_image_registry: pd.DataFrame,
    doc_text_registry: pd.DataFrame,
    mask_registry: pd.DataFrame,
    max_text_length: int,
    token_budget_fraction: float = 0.60,
) -> "Dataset":
    """Build a page-level `Dataset` from the extracted text, images and masks.

    For each row of `doc_text_registry` the document's JSON file is loaded.
    For each page image of that document listed in `page_image_registry`,
    one dataset row is produced holding the page image, the (shortened) text
    of every section on the page, the section boxes, and the `umask_id` of
    the matching rows in `mask_registry`.

    Page images are read from the module-level `S2_DLA_INPUTS_DIR`.

    Args:
        page_image_registry: Needs the columns `document`, `page_no` and
            `file_name`.
        doc_text_registry: Needs the columns `pdf_file`, `output_directory`
            and `json_path`.
        mask_registry: Needs the columns `document`, `page_no`, `mask_id`
            and `umask_id`.
        max_text_length: Maximum number of tokens allowed per page.
        token_budget_fraction: Fraction of `max_text_length` that is shared
            out, in words, between the sections of a page. The default keeps
            the previous hard-coded value of 0.60.

    Returns:
        A `Dataset` with the columns `document_id`, `page_no`, `images`,
        `words`, `bboxes`, `umask_id` and `dummy_label` (all-zero integer
        labels, one per section).

    Raises:
        AssertionError: If the number of sections on a page differs from the
            number of masks matched for that page.

    Notes:
        Documents whose JSON holds no sections, and pages without any
        section, are skipped. They previously aborted the whole run with a
        KeyError / ZeroDivisionError.
    """
    dataset_dict = {
        "document_id": [],
        "page_no": [],
        "images": [],
        "words": [],
        "bboxes": [],
        "umask_id": [],
        "dummy_label": [],
    }

    for _, row in doc_text_registry.iterrows():

        # DOCUMENT SPECIFIC VALUES ########################################
        ###################################################################

        doc_id = row["pdf_file"]

        doc_json_path = join(row["output_directory"], row["json_path"])
        with open(doc_json_path, "r") as json_file:
            doc_json = json.load(json_file)

        doc_json_df = pd.DataFrame(doc_json["paper_text"])
        if doc_json_df.empty:
            # No sections at all: the frame has no columns to work with.
            continue

        # Ensure the box is read as numbers (it may be stored as a string).
        doc_json_df["section_im_bbox"] = doc_json_df["section_im_bbox"].apply(
            lambda var: literal_eval(str(var))
        )
        doc_json_df = doc_json_df.sort_values(by=["section_page", "section_id"])

        # Boolean masks instead of f-string queries, so that document names
        # containing quotes cannot break the filter expression.
        doc_image_df = page_image_registry[page_image_registry["document"] == doc_id]
        doc_mask_df = mask_registry[mask_registry["document"] == doc_id]

        # PAGE SPECIFIC VALUES ############################################
        ###################################################################

        for _, im_row in doc_image_df.iterrows():
            page_no = im_row["page_no"]

            # `assign` returns a new frame, so no write happens on a slice
            # of `doc_json_df` (avoids SettingWithCopyWarning).
            doc_page_json_df = doc_json_df[
                doc_json_df["section_page"] == page_no
            ].assign(mask_id=lambda frame: frame["section_id"])

            page_text = doc_page_json_df["section_text"].to_list()
            no_masks = len(page_text)
            if no_masks == 0:
                # Nothing to classify on this page.
                continue

            # Dataset Words
            #   NOTE: The page must produce fewer tokens than the maximum
            #   allowed. The word budget is split evenly so that every
            #   section keeps at least one word.
            words_per_mask = max(
                1, int((max_text_length / no_masks) * token_budget_fraction)
            )
            short_page_text = [
                " ".join(section_text.split(" ")[:words_per_mask])
                for section_text in page_text
            ]

            # Dataset bboxes
            bboxes = doc_page_json_df["section_im_bbox"].to_list()

            # Unique Mask ID: match the page sections to the mask registry.
            page_mask_df = doc_mask_df[doc_mask_df["page_no"] == page_no]
            page_mask_df = pd.merge(
                doc_page_json_df, page_mask_df, on="mask_id", how="inner"
            )
            umask_ids = page_mask_df["umask_id"].to_list()

            assert len(short_page_text) == len(bboxes) == len(umask_ids), (
                f"Section/mask count mismatch for {doc_id} page {page_no}: "
                f"{len(short_page_text)} texts, {len(bboxes)} boxes, "
                f"{len(umask_ids)} masks"
            )

            # The image is opened only once the page is known to be usable,
            # and all columns are appended together to stay aligned.
            page_img = JpegImageFile(join(S2_DLA_INPUTS_DIR, im_row["file_name"]))

            dataset_dict["document_id"].append(doc_id)
            dataset_dict["page_no"].append(page_no)
            dataset_dict["images"].append(page_img)
            dataset_dict["words"].append(short_page_text)
            dataset_dict["bboxes"].append(bboxes)
            dataset_dict["umask_id"].append(umask_ids)
            # Integer placeholders: the labels are later stored as int64.
            dataset_dict["dummy_label"].append([0] * len(umask_ids))

    return Dataset.from_dict(dataset_dict)


# Assemble the page-level dataset from the three registries loaded above.
dataset = generate_doclaylet_dataset(
    page_image_registry=page_image_registry,
    doc_text_registry=doc_text_registry,
    mask_registry=mask_registry,
    max_text_length=MAX_LENGTH,
)

In [None]:
# Read the category mapping (id -> label name) stored as JSON.
with open(MODEL_CATEGORIES_JSON, "r") as json_file:
    categories_dict = json.load(json_file)

# Build both lookup directions in one pass, converting each key with int().
id2label = {}
label2id = {}
for raw_id, label_name in categories_dict.items():
    id2label[int(raw_id)] = label_name
    label2id[label_name] = int(raw_id)

print(id2label)
print(label2id)

# Load Model and Processor

In [None]:
# Select the GPU when available. Note: `device` is only assigned here and is
# not referenced again in the visible cells.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# apply_ocr=False: words and boxes are supplied by the caller
# (see `prepare_examples`) rather than produced by OCR on the page images.
processor = LayoutLMv3Processor.from_pretrained(MODEL_PROCESSOR, apply_ocr=False)
# Load the fine-tuned token-classification checkpoint together with the
# category id <-> name mappings read from MODEL_CATEGORIES_JSON.
model = LayoutLMv3ForTokenClassification.from_pretrained(
    MODEL_WEIGHTS, id2label=id2label, label2id=label2id
)

# Prepare Dataset for inference


In [None]:
def prepare_examples(examples):
    """Encode a batch of dataset rows with the LayoutLMv3 processor.

    Args:
        examples: Batch mapping with the keys `images`, `words`, `bboxes`
            and `dummy_label`, as produced by `generate_doclaylet_dataset`.

    Returns:
        The processor encoding with the `offset_mapping` and
        `overflow_to_sample_mapping` entries removed.

    Notes:
        Uses the module-level `processor`, `normalized_box` and
        `MAX_LENGTH`, so the sequence length stays in sync with the value
        used to build the dataset instead of a separate hard-coded 512.
    """
    images = examples["images"]
    words = examples["words"]
    word_labels = examples["dummy_label"]

    # Rescale every page's boxes using that page's own image size.
    # `image.size` is (width, height).
    normalized_boxes = [
        [normalized_box(box, *image.size) for box in page_boxes]
        for image, page_boxes in zip(images, examples["bboxes"])
    ]

    # https://github.com/huggingface/transformers/issues/19190
    encoding = processor(
        images,
        words,
        boxes=normalized_boxes,
        word_labels=word_labels,
        max_length=MAX_LENGTH,
        padding="max_length",
        stride=128,
        truncation=True,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
    )

    # These two entries are requested above but are not part of the returned
    # record, so they are dropped from the encoding.
    encoding.pop("offset_mapping")
    encoding.pop("overflow_to_sample_mapping")

    return encoding


from datasets import Features, Sequence, ClassLabel, Value, Array2D, Array3D

# We need to define custom features for `set_format` (used later on) to work
# properly. The bbox array length follows MAX_LENGTH, the same value used when
# encoding, rather than a second hard-coded 512.
features = Features(
    {
        "pixel_values": Array3D(dtype="float32", shape=(3, 224, 224)),
        "input_ids": Sequence(feature=Value(dtype="int64")),
        "attention_mask": Sequence(Value(dtype="int64")),
        "bbox": Array2D(dtype="int64", shape=(MAX_LENGTH, 4)),
        "labels": Sequence(feature=Value(dtype="int64")),
    }
)

# Encode the dataset batch by batch. All original columns are removed so that
# only the encoded fields declared in `features` remain; `column_names` is a
# plain list, which is the documented type for `remove_columns`.
eval_dataset = dataset.map(
    prepare_examples,
    batched=True,
    remove_columns=dataset.column_names,
    features=features,
    batch_size=GLOBAL_BATCH_SIZE,
)

In [None]:
# Display the encoded dataset produced by `dataset.map` above.
eval_dataset

In [None]:
# Have the dataset return torch tensors when rows are accessed.
eval_dataset.set_format("torch")

In [None]:
# Release cached, currently unused GPU memory before running inference.
torch.cuda.empty_cache()

In [None]:
from transformers.data.data_collator import default_data_collator
from transformers import TrainingArguments, Trainer

# Trainer instance used for prediction in the following cell. No
# TrainingArguments are passed here, so the Trainer's defaults apply.
trainer = Trainer(
    model=model,
    tokenizer=processor,
    data_collator=default_data_collator,
)

In [None]:
report_dataset = eval_dataset

# Run inference over the encoded dataset.
predictions_full = trainer.predict(report_dataset)

# Highest-scoring class per token (class scores are on axis 2).
predictions = predictions_full.predictions.argmax(axis=2)
labels = report_dataset["labels"]

# Keep only the positions whose label is not -100, collecting the surviving
# predictions and labels side by side.
filtered_predictions = []
filtered_labels = []
for prediction, label in zip(predictions, labels):
    kept_pairs = [(p, l) for (p, l) in zip(prediction, label) if l != -100]
    filtered_predictions.append([p for (p, _) in kept_pairs])
    filtered_labels.append([l for (_, l) in kept_pairs])

assert len(filtered_predictions) == len(filtered_labels)

In [None]:
# Display the page-level dataset built by `generate_doclaylet_dataset`.
dataset

In [None]:
# Display the per-row lists of predicted class ids kept after filtering.
filtered_predictions

In [None]:
def update_mask_registry_with_new_labels(mask_registry: pd.DataFrame, dataset: "Dataset", filtered_predictions: list, id2label: dict):
    """Write the predicted category of every mask into `mask_registry`.

    A `new_category` column is (re)created on `mask_registry`, in place.
    Masks that received a prediction get the predicted class id; all other
    masks get -1.

    Args:
        mask_registry: Registry with a `umask_id` column; modified in place.
        dataset: Object whose `["umask_id"]` yields one list of mask ids per
            row, parallel to `filtered_predictions`.
        filtered_predictions: One list of predicted class ids per row.
        id2label: Class id -> label name mapping. The entry
            `-1 -> "Unknown"` is added to it as a side effect.

    Returns:
        The same `mask_registry` object, for convenience.

    Raises:
        AssertionError: If a row has a different number of mask ids and
            predictions.
        IndexError: If a predicted `umask_id` is absent from
            `mask_registry` (same exception type as before, but raised
            before anything is written and with an explicit message).

    Notes:
        Rows are paired with `zip`, so surplus rows on either side are
        ignored. If a `umask_id` occurs more than once, the last prediction
        wins.
    """
    unknown_id = -1
    id2label[unknown_id] = "Unknown"

    # Collect umask_id -> predicted class id in one pass.
    prediction_by_umask = {}
    for um_list, p_list in zip(dataset["umask_id"], filtered_predictions):
        assert len(um_list) == len(p_list), (
            f"Row has {len(um_list)} mask ids but {len(p_list)} predictions"
        )
        for umask_id, pred in zip(um_list, p_list):
            prediction_by_umask[umask_id] = int(pred)

    known_umask_ids = set(mask_registry["umask_id"].tolist())
    missing = [u for u in prediction_by_umask if u not in known_umask_ids]
    if missing:
        raise IndexError(f"umask_id(s) not found in mask_registry: {missing}")

    # Single lookup per registry row instead of a full boolean scan per mask.
    mask_registry["new_category"] = (
        mask_registry["umask_id"]
        .map(lambda umask_id: prediction_by_umask.get(umask_id, unknown_id))
        .astype(int)
    )

    return mask_registry

            
            




    


# Fill mask_registry["new_category"] (in place) with the predicted class ids.
update_mask_registry_with_new_labels(mask_registry, dataset, filtered_predictions, id2label)

# Display the updated registry.
mask_registry

In [None]:
# Debug check: show the container type of `filtered_predictions`.
type(filtered_predictions)