In [23]:
import json
import os
from pathlib import Path

import pytorch_lightning as pl
import torch
from multipage_classifier.datasets.mosaic_dataset import MosaicDataModule
from multipage_classifier.encoder.swin_encoder import SwinEncoderConfig
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from swin_encoder.lightning_module import SwinEncoderPLModule

In [24]:

# Root directory for all training artefacts. It defaults to the original
# machine-specific location; set MASTER_THESIS_ROOT to run this notebook
# elsewhere without editing the paths below.
TRAINING_ROOT = os.environ.get("MASTER_THESIS_ROOT", "/data/training/master_thesis").rstrip("/")

# Lightning checkpoint loaded into SwinEncoderPLModule.
EMBEDDER_MODEL = f"{TRAINING_ROOT}/lightning_logs/swin_encoder/version_5/checkpoints/best-checkpoint.ckpt"

# Dataset directory handed to MosaicDataModule, and the JSON file whose
# entries become the `classes` list.
DATASET_PATH = f"{TRAINING_ROOT}/datasets/2023-05-23"
CLASS_PATH = f"{TRAINING_ROOT}/datasets/bzuf_classes.json"

MAX_PAGES = 8  # forwarded to MosaicDataModule as `max_pages`
BATCH_SIZE = 1  # forwarded to MosaicDataModule as `batch_size`

# NOTE(review): not referenced by the visible cells; confirm it is still needed.
IMAGE_SIZE = (704, 512)  # height, width

In [25]:
# Load the trained encoder from its Lightning checkpoint and put it in
# evaluation mode. `nn.Module.eval()` returns the module itself, so the call
# can be chained directly onto the load.
encoder_module = SwinEncoderPLModule.load_from_checkpoint(EMBEDDER_MODEL).eval()


Some weights of the model checkpoint at microsoft/swin-tiny-patch4-window7-224 were not used when initializing SwinModel: ['classifier.bias', 'classifier.weight']
- This IS expected if you are initializing SwinModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing SwinModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [26]:
# Define data module.
# The class file is read through Path.read_text so the handle is closed
# immediately (the old `json.load(open(...))` leaked it). It is decoded as
# UTF-8, the JSON standard encoding, regardless of the system locale.
# `list(...)` keeps the original element semantics: list items for a JSON
# array, keys for a JSON object.
classes = list(json.loads(Path(CLASS_PATH).read_text(encoding="utf-8")))
assert classes, f"no classes found in {CLASS_PATH}"

# The encoder's `prepare_input` is handed to the data module. Presumably it
# preprocesses page images into encoder inputs; confirm in MosaicDataModule.
data_module = MosaicDataModule(
    Path(DATASET_PATH),
    classes,
    encoder_module.encoder.prepare_input,
    batch_size=BATCH_SIZE,
    max_pages=MAX_PAGES,
)

# No Trainer drives this module here, so invoke its data hooks manually.
data_module.prepare_data()
data_module.setup()

In [27]:
# Take the first batch of the test split for a qualitative sanity check.
# The default of `next` turns an empty split into a readable assertion
# instead of a bare StopIteration.
pixel_input = next(iter(data_module.test_dataloader()), None)
assert pixel_input is not None, "test dataloader yielded no batches"

In [28]:
# Run the module's `step` on the single batch. This is pure inference, so
# `torch.no_grad()` stops autograd from recording a graph that would never be
# used and would only hold activation memory.
with torch.no_grad():
    pred, gt, loss = encoder_module.step(pixel_input)

In [29]:
# Ground-truth "order" targets, reshaped into rows of width
# len(pixel_input["pixel_values"]) for display.
gt["order"].view((-1, len(pixel_input["pixel_values"])))

tensor([[3, 0],
        [0, 3]])

In [30]:
# Index of the maximum along dim 1 of the predicted "order" scores (presumably
# the predicted order class; confirm against `step`). It is reshaped into rows
# of width len(pixel_input["pixel_values"]) so it can be compared with the
# ground-truth grid.
pred["order"].argmax(dim=1).view((-1, len(pixel_input["pixel_values"])))

tensor([[3, 0],
        [0, 3]])