In [1]:
from pathlib import Path
import json
import torch
from visual_page_classifier.lightning_module import VisualPageClassifierPLModule, VisualPageClassifier
from multipage_transformer.lightning_module import MultipageTransformerPLModule, MultipageTransformer
from page_comparsion_encoder.lightning_module import PageComparisonEncoderPLModule, PageComparisonEncoder


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Root directory of the Lightning evaluation logs. Every checkpoint below lives under it,
# so relocating the notebook only requires changing this one line.
EVALUATION_LOG_DIR = Path("/data/training/master_thesis/evaluation_logs")
# Checkpoint file selected inside each run directory.
CHECKPOINT_FILE = "checkpoints/best-checkpoint.ckpt"

# Kept as plain strings (same values as before) for any later cell that expects str paths.
VISUAL_CLASSIFIER_MODEL_PATH = str(EVALUATION_LOG_DIR / "visual_page_classifier" / "version_1" / CHECKPOINT_FILE)
# (sic) "comparision" is part of the on-disk run directory name; do not "correct" it.
DUAL_CLASSIFIER_MODEL_PATH = str(
    EVALUATION_LOG_DIR
    / "page_comparision_encoder"
    / "microsoft"
    / "swinv2-base-patch4-window8-256"
    / "version_0"
    / CHECKPOINT_FILE
)
TRANSFORMER_MODEL_PATH = str(EVALUATION_LOG_DIR / "multipage_transformer" / "version_1" / CHECKPOINT_FILE)

# Load each Lightning module on CPU, switch it to eval mode, and keep only the wrapped network
# (`.classifier` for the visual classifier, `.model` for the other two).
visual_classifier: VisualPageClassifier = VisualPageClassifierPLModule.load_from_checkpoint(
    VISUAL_CLASSIFIER_MODEL_PATH, map_location="cpu"
).eval().classifier
dual_classifier: PageComparisonEncoder = PageComparisonEncoderPLModule.load_from_checkpoint(
    DUAL_CLASSIFIER_MODEL_PATH, map_location="cpu"
).eval().model
transformer: MultipageTransformer = MultipageTransformerPLModule.load_from_checkpoint(
    TRANSFORMER_MODEL_PATH, map_location="cpu"
).eval().model


Some weights of MBartForCausalLM were not initialized from the model checkpoint at facebook/bart-base and are newly initialized: ['lm_head.weight', 'decoder.layer_norm.weight', 'decoder.layer_norm.bias', 'decoder.embed_tokens.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of MBartForCausalLM were not initialized from the model checkpoint at facebook/bart-base and are newly initialized because the shapes did not match:
- decoder.embed_positions.weight: found shape torch.Size([1026, 768]) in the checkpoint and torch.Size([770, 1024]) in the model instantiated
- decoder.layernorm_embedding.bias: found shape torch.Size([768]) in the checkpoint and torch.Size([1024]) in the model instantiated
- decoder.layernorm_embedding.weight: found shape torch.Size([768]) in the checkpoint and torch.Size([1024]) in the model instantiated
- decoder.layers.0.encoder_attn.k_proj.bias: found shape torch.Size([768]) in the 

In [3]:
# Evaluation sample folder: page PNGs plus a ground_truth.json (both read by the loading cell).
SAMPLE_PATH = "/data/training/master_thesis/datasets/evaluation/" "ma_sample"

In [4]:
CLASS_PATH = "/data/training/master_thesis/datasets/bzuf_classes.json"
# Use a context manager so the file handle is closed right away. JSON text is UTF-8,
# so don't rely on the platform's default encoding.
with open(CLASS_PATH, encoding="utf-8") as class_file:
    classes = list(json.load(class_file))
# Map the integer class index predicted by the models (argmax over `doc_class`)
# to its label name; the index is the label's position in the JSON list.
id2class = {idx: str(label) for idx, label in enumerate(classes)}


In [5]:
# Load images
from PIL import Image
import os

folder_path = Path(SAMPLE_PATH)  # Replace with the path to your folder

# Collect the page images (every *.png file) in the sample folder.
file_list = [file for file in os.listdir(folder_path) if file.endswith(".png")]

# Sort by page number. The key parses the text between the first "_" and the
# following "." as an int, so names are expected to look like "<prefix>_<n>.png"
# (NOTE(review): confirm the naming scheme; other names raise IndexError/ValueError).
file_list.sort(key=lambda x: int(x.split("_")[1].split(".")[0]))

# Read each page eagerly. Image.open is lazy and keeps the file descriptor open
# until pixel data is accessed; .load() reads it now and, for single-frame files
# opened by path, lets Pillow close the underlying file instead of leaking it.
images = []
for file_name in file_list:
    image = Image.open(os.path.join(folder_path, file_name))
    image.load()
    images.append(image)

# Per-page annotations; the report cell reads "doc_id", "doc_class" and "page_nr" from each entry.
with open(folder_path / "ground_truth.json", encoding="utf-8") as ground_truth_file:
    ground_truth = json.load(ground_truth_file)

In [6]:
def visual_classifier_inference(images: list[Image.Image]):
    """Run the visual page classifier on page images, without gradient tracking.

    Each image is converted by ``visual_classifier.encoder.page_encoder.prepare_input``
    and the results are stacked into one batch for ``visual_classifier.predict``.

    Returns:
        (doc_ids, class_names, page_numbers):
        ``.item()`` of each ``doc_id`` entry, the ``id2class`` name of the argmax
        over ``doc_class``, and the argmax over ``page_nr``.
    """
    batch = torch.stack(
        [visual_classifier.encoder.page_encoder.prepare_input(page) for page in images]
    )

    with torch.no_grad():
        prediction = visual_classifier.predict(batch)

    doc_ids = [value.item() for value in prediction["doc_id"]]
    class_names = [id2class[int(index.item())] for index in prediction["doc_class"].argmax(1)]
    page_numbers = [index.item() for index in prediction["page_nr"].argmax(1)]
    return doc_ids, class_names, page_numbers

def dual_classifier_inference(images: list[Image.Image]):
    """Run the page-comparison encoder on page images, without gradient tracking.

    Each image is converted by ``dual_classifier.encoder.prepare_input``; the
    stacked batch goes through ``forward`` and then ``postprocess``.

    Returns:
        (doc_ids, class_names, page_numbers):
        ``.item()`` of each ``doc_id`` entry, the ``id2class`` name of the argmax
        over ``doc_class``, and ``.item()`` of each ``page_nr`` entry. ``page_nr``
        is used as-is, with no argmax; presumably ``postprocess`` already yields
        page indices (TODO confirm).
    """
    batch = torch.stack([dual_classifier.encoder.prepare_input(page) for page in images])

    with torch.no_grad():
        raw_output = dual_classifier.forward(batch)
        decoded = dual_classifier.postprocess(raw_output)

    doc_ids = [value.item() for value in decoded["doc_id"]]
    class_names = [id2class[int(index.item())] for index in decoded["doc_class"].argmax(1)]
    page_numbers = [value.item() for value in decoded["page_nr"]]
    return doc_ids, class_names, page_numbers

def transformer_inference(images: list[Image.Image]):
    """Run the multipage transformer on page images and flatten its predictions.

    ``transformer.predict`` may return a list (one entry per element) or a single
    dict. Any other output contributes no entries.
    - A dict entry yields its "doc_id", "doc_class" and "page_nr" values; a missing key becomes "N/A".
    - A non-dict entry yields -1 in all three columns.
    Each column is then padded with "N/A" up to the number of input images.
    Longer outputs are not truncated.

    Returns:
        (doc_ids, class_names, page_numbers) as three lists.
    """
    batch = torch.stack([transformer.encoder.page_encoder.prepare_input(page) for page in images])

    with torch.no_grad():
        prediction = transformer.predict(batch)

    # Normalise the model output into a flat list of entries.
    if isinstance(prediction, list):
        entries = prediction
    elif isinstance(prediction, dict):
        entries = [prediction]
    else:
        entries = []

    columns = {"doc_id": [], "doc_class": [], "page_nr": []}
    for entry in entries:
        for key, column in columns.items():
            # Note the two sentinels: -1 marks a malformed (non-dict) entry, "N/A" a missing value.
            column.append(entry.get(key, "N/A") if isinstance(entry, dict) else -1)

    page_count = len(batch)
    doc_ids, class_names, page_numbers = (
        column + ["N/A"] * (page_count - len(column)) for column in columns.values()
    )
    return doc_ids, class_names, page_numbers

# Run all three models on the same ordered page images (visual, dual, transformer - in that order).
(v_id, v_class, v_nr), (d_id, d_class, d_nr), (t_id, t_class, t_nr) = (
    visual_classifier_inference(images),
    dual_classifier_inference(images),
    transformer_inference(images),
)




In [7]:

def latex_cell(value) -> str:
    """Format one table cell: escape underscores for LaTeX and replace every dot with a dash."""
    return str(value).replace("_", "\\_").replace(".", "-")


def latex_row(label, doc_id, doc_class, page_nr, end=" \\\\") -> str:
    """Return "label & doc_id & doc_class & page_nr" (each cell via latex_cell) followed by `end`."""
    return " & ".join(latex_cell(cell) for cell in (label, doc_id, doc_class, page_nr)) + end


# Model rows in table order: (row label, doc-id preds, class preds, page-number preds).
# The label matches the class name PageComparisonEncoder (it was previously misspelled).
model_rows = [
    ("PageComparisonEncoder", d_id, d_class, d_nr),
    ("VisualPageClassifier", v_id, v_class, v_nr),
    ("MultipageTransformer", t_id, t_class, t_nr),
]

for idx, page_gt in enumerate(ground_truth):
    print(f"Page {idx + 1} &  &  & \\\\")
    print(latex_row("Ground Truth", page_gt["doc_id"], page_gt["doc_class"], page_gt["page_nr"], end="\\\\"))
    for row_pos, (label, pred_ids, pred_classes, pred_nrs) in enumerate(model_rows):
        # The last model row of each page closes the page's group with \hline.
        row_end = " \\\\ \\hline" if row_pos == len(model_rows) - 1 else " \\\\"
        print(latex_row(label, pred_ids[idx], pred_classes[idx], pred_nrs[idx], end=row_end))
    print()



Page 1 &  &  & \\
Ground Truth & 0 & antrag-formlos & 0\\
PageComparisionEncoder & 0 & antrag-formlos & 0 \\
VisualPageClassifier & 0 & anschreiben & 0 \\
MultipageTransformer & N/A & N/A & N/A \\ \hline

Page 2 &  &  & \\
Ground Truth & 1 & antrag-formblattantrag-hilfe\_zur\_pflege & 0\\
PageComparisionEncoder & 1 & antrag-formblattantrag-heilpaedagogische\_tagesstaette & 1 \\
VisualPageClassifier & 1 & antrag-formblattantrag-ambulant\_betreutes\_wohnen & 0 \\
MultipageTransformer & N/A & N/A & N/A \\ \hline

Page 3 &  &  & \\
Ground Truth & 1 & antrag-formblattantrag-hilfe\_zur\_pflege & 1\\
PageComparisionEncoder & 1 & antrag-formblattantrag-heilpaedagogische\_tagesstaette & 0 \\
VisualPageClassifier & 1 & antrag-formblattantrag-hilfe\_zur\_pflege & 1 \\
MultipageTransformer & N/A & N/A & N/A \\ \hline

Page 4 &  &  & \\
Ground Truth & 1 & antrag-formblattantrag-hilfe\_zur\_pflege & 2\\
PageComparisionEncoder & 1 & antrag-formblattantrag-heilpaedagogische\_tagesstaette & 1 \\
Visual