In [1]:
from pathlib import Path
import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger

import pytorch_lightning as pl
from multipage_classifier.multipage_transformer import MultipageTransformerConfig
from training.transformer.lightning_module import MultipageTransformerPLModule, MultipagePLDataModule


  from .autonotebook import tqdm as notebook_tqdm


In [2]:

# Configuration for this evaluation notebook.
# All artefacts live under one root directory; set the MASTER_THESIS_ROOT environment
# variable to run the notebook against a different mount without editing every path.
import os

DATA_ROOT = os.environ.get("MASTER_THESIS_ROOT", "/data/training/master_thesis")

# Lightning checkpoint of the model to evaluate.
MODEL = f"{DATA_ROOT}/lightning_logs/multipage_transformer/version_0/checkpoints/best-checkpoint.ckpt"

DATASET_PATH = f"{DATA_ROOT}/datasets/2023-05-23"
# NOTE(review): CLASS_PATH is not used by any cell in this notebook — confirm it is still needed.
CLASS_PATH = f"{DATA_ROOT}/datasets/bzuf_classes.json"

# NOTE(review): MAX_PAGES is not passed to MultipagePLDataModule here, and the inference cell
# printed an input of shape (1, 16, 3, 420, 360) — confirm whether a page limit is applied.
MAX_PAGES = 8
NUM_WORKERS = 1

# NOTE(review): MAX_LENGTH and IMAGE_SIZE are not used by any cell in this notebook;
# presumably they must match the checkpoint's training configuration — TODO confirm.
MAX_LENGTH = 768
IMAGE_SIZE = (420, 360)
TASK_PROMPT = "<s_classification>"


In [3]:
# Load the trained module from its Lightning checkpoint and switch it to eval mode.
# Lightning's checkpoint docs call model.eval() explicitly after load_from_checkpoint;
# without it dropout etc. stay active and the inference cells below are non-deterministic.
# (eval() returns the module itself, and assigning it avoids printing the module repr.)
model = MultipageTransformerPLModule.load_from_checkpoint(MODEL).eval()
# TODO: loading /data/training/master_thesis/lightning_logs/last_model.ckpt via torch.load was
# attempted here and is still unresolved (torch.load unpickles — only use on trusted files).

In [4]:
# Build the data module over the dataset directory. The model's inner network is handed
# over as well (presumably for its tokenizer/processor — TODO confirm).
data_module = MultipagePLDataModule(
    Path(DATASET_PATH),
    model.model,
    task_prompt=TASK_PROMPT,
    num_workers=NUM_WORKERS,
)

# Standard LightningDataModule lifecycle: prepare first, then set up the splits.
data_module.prepare_data()
data_module.setup()

In [5]:
import itertools

# Pick a single test batch to inspect. BATCH_NUMBER is 1-based: 10 selects the tenth batch.
BATCH_NUMBER = 10

test_batches = iter(data_module.test_dataloader())
# islice skips the preceding batches. The None default turns "split too small" into a clear
# error instead of a bare StopIteration. `i` holds the batch dict used by the cells below.
i = next(itertools.islice(test_batches, BATCH_NUMBER - 1, None), None)
# Drop the iterator so its DataLoader worker processes can be released.
del test_batches
if i is None:
    raise IndexError(f"test dataloader yields fewer than {BATCH_NUMBER} batches")

In [6]:
# Ground-truth annotations of the inspected batch: one dict (doc_id, doc_class, page_nr) per page.
batch_ground_truth = i["ground_truth"]
batch_ground_truth

[{'doc_id': tensor([0]), 'doc_class': ['anschreiben'], 'page_nr': tensor([0])},
 {'doc_id': tensor([0]), 'doc_class': ['anschreiben'], 'page_nr': tensor([1])},
 {'doc_id': tensor([1]),
  'doc_class': ['antrag.formblattantrag.hilfe_zur_pflege'],
  'page_nr': tensor([0])},
 {'doc_id': tensor([1]),
  'doc_class': ['antrag.formblattantrag.hilfe_zur_pflege'],
  'page_nr': tensor([1])},
 {'doc_id': tensor([1]),
  'doc_class': ['antrag.formblattantrag.hilfe_zur_pflege'],
  'page_nr': tensor([2])},
 {'doc_id': tensor([1]),
  'doc_class': ['antrag.formblattantrag.hilfe_zur_pflege'],
  'page_nr': tensor([3])},
 {'doc_id': tensor([1]),
  'doc_class': ['antrag.formblattantrag.hilfe_zur_pflege'],
  'page_nr': tensor([4])},
 {'doc_id': tensor([1]),
  'doc_class': ['antrag.formblattantrag.hilfe_zur_pflege'],
  'page_nr': tensor([5])},
 {'doc_id': tensor([2]),
  'doc_class': ['vermoegen.vermoegenserklaerung'],
  'page_nr': tensor([0])},
 {'doc_id': tensor([2]),
  'doc_class': ['vermoegen.vermoegenserk

In [14]:
from torch.nn.utils.rnn import pad_sequence

# Decoder prompt per sample: its decoder_input_ids cut off after prompt_end_index
# (presumably the position of the task-prompt token — TODO confirm), stacked into one
# batch tensor. pad_sequence right-pads shorter prompts with 0.
# NOTE(review): confirm 0 is the tokenizer's pad id (this does not matter for batch size 1).
decoder_prompts = pad_sequence(
    [
        input_ids[: end_idx + 1]
        for input_ids, end_idx in zip(i["decoder_input_ids"], i["prompt_end_index"])
    ],
    batch_first=True,
)

# Pure inference: no_grad stops autograd from recording a graph (and keeping activations)
# through the forward pass.
with torch.no_grad():
    o = model.model.inference(i["pixel_values"], decoder_prompts, return_json=False)

torch.Size([1, 16, 3, 420, 360])
torch.Size([1, 16, 768])
torch.Size([1, 1])


In [23]:
# Scratch check: .item() on a one-element float tensor yields a plain Python float (0.0).
torch.tensor([0.0]).item()

0.0

In [25]:
import itertools

# Side-by-side view of the inspected batch, one row per page:
#   <pred doc_id> <gt doc_id> - <pred page_nr> <gt page_nr> - <pred doc_class> <gt doc_class>
# zip_longest instead of zip: if the model emits more or fewer page entries than the ground
# truth has, the extra rows show MISSING instead of being silently dropped.
MISSING = "<missing>"

predicted_pages = model.model.token2json(o["predictions"][0])
for pred, gt in itertools.zip_longest(predicted_pages, i["ground_truth"]):
    # A malformed generation may presumably yield non-dict entries or dicts lacking keys.
    pred = pred if isinstance(pred, dict) else {}
    if gt is None:
        gt_doc_id = gt_page_nr = gt_doc_class = MISSING
    else:
        gt_doc_id = gt["doc_id"].item()
        gt_page_nr = gt["page_nr"].item()
        gt_doc_class = gt["doc_class"][0]
    print(
        pred.get("doc_id", MISSING), gt_doc_id,
        "-", pred.get("page_nr", MISSING), gt_page_nr,
        "-", pred.get("doc_class", MISSING), gt_doc_class,
    )

0 0 - 0 0 - antrag.formlos anschreiben
0 0 - 1 1 - antrag.formlos anschreiben
1 1 - 0 0 - antrag.formblattantrag.einzelintegration antrag.formblattantrag.hilfe_zur_pflege
1 1 - 1 1 - antrag.formblattantrag.einzelintegration antrag.formblattantrag.hilfe_zur_pflege
1 1 - 2 2 - antrag.formblattantrag.einzelintegration antrag.formblattantrag.hilfe_zur_pflege
1 1 - 3 3 - antrag.formblattantrag.einzelintegration antrag.formblattantrag.hilfe_zur_pflege
1 1 - 4 4 - antrag.formblattantrag.einzelintegration antrag.formblattantrag.hilfe_zur_pflege
2 1 - 0 5 - antrag.formblattantrag.einzelintegration antrag.formblattantrag.hilfe_zur_pflege
2 2 - 1 0 - antrag.formblattantrag.einzelintegration vermoegen.vermoegenserklaerung
3 2 - 0 1 - antrag.formblattantrag.einzelintegration vermoegen.vermoegenserklaerung
3 3 - 1 0 - antrag.einzelintegration vermoegen.vermoegenserklaerung
3 3 - 2 1 - antrag.formblattantrag.einzelintegration vermoegen.vermoegenserklaerung
4 4 - 3 0 - antrag.einzelintegration vollmac

In [9]:
# Scratch: a small 2x4 int64 tensor holding 0..7 row-major, used to check view semantics.
# (torch is already imported in the first cell.)
t = torch.arange(8).reshape(2, 4)

In [10]:
# Scratch: a 1-D view over all elements of t, equivalent to t.view(-1). view never copies,
# so t's strides must allow it (a contiguous tensor always does).
t.view(t.numel())

tensor([0, 1, 2, 3, 4, 5, 6, 7])

In [11]:
# Scratch: torch.cat on a one-element list returns a tensor shaped like that element, so
# no extra leading dimension is added (torch.stack would give (1, 1, 3, 5)).
# Uses its own name instead of rebinding `t`. Rebinding `t` would break re-running the
# t.view(...) cell, which expects a tensor.
zero_blocks = [torch.zeros((1, 3, 5))]
torch.cat(zero_blocks).size()

torch.Size([1, 3, 5])