In [34]:
from pathlib import Path
import torch
import pytorch_lightning as pl
from lightning_module import SwinEncoderPLModule
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger

from multipage_classifier.datasets.ucsf_dataset import UCSFDataModule
from multipage_classifier.encoder.swin_encoder import SwinEncoderConfig
from lightning_module import SwinEncoderPLModule
from multipage_classifier.preprocessor import ImageProcessor


In [35]:

# --- Data-loading parameters ---------------------------------------------
MAX_PAGES = 8
NUM_WORKERS = 8

# Image size handed to ImageProcessor(img_size=...); axis order is defined by
# that class and is not visible in this notebook.
IMAGE_SIZE = (512, 704)

# --- Checkpoint and dataset locations (relative paths) --------------------
EMBEDDER_MODEL = "../lightning_logs/multipage_swin_encoder/version_3/checkpoints/best-checkpoint.ckpt"
DATASET_PATH = "../../dataset/ucsf-idl-resized-without_emails/"
DS_FILE = "IDL-less_20_pages.json"

In [36]:
# Build the image preprocessor; its `prepare_input` method is handed to the
# data module as the per-sample preparation hook.
image_processor = ImageProcessor(img_size=IMAGE_SIZE)

# Assemble the data module arguments separately so they are easy to tweak.
dataset_root = Path(DATASET_PATH)
data_module_options = {
    "prepare_function": image_processor.prepare_input,
    "split": [0.8, 0.2],  # presumably split fractions -- confirm in UCSFDataModule
    "max_pages": MAX_PAGES,
    "num_workers": NUM_WORKERS,
}
data_module = UCSFDataModule(dataset_root, DS_FILE, **data_module_options)

# Run the data module's lifecycle hooks by hand, since no Trainer is used in
# this notebook to call them.
data_module.prepare_data()
data_module.setup()

In [37]:
# Restore the encoder module from the checkpoint and switch it to eval mode.
# The trailing semicolon keeps the notebook from echoing the return value of
# `eval()` as the cell output.
embedder = SwinEncoderPLModule.load_from_checkpoint(EMBEDDER_MODEL)
embedder.eval();

In [38]:
# Grab only the first batch produced by the test dataloader for inspection.
test_loader = data_module.test_dataloader()
pixel_input = next(iter(test_loader))

In [39]:
# Run a single step on the sampled batch purely for inspection. No backward
# pass follows, so autograd is disabled to avoid building a graph and keeping
# intermediate activations in memory.
with torch.no_grad():
    pred, gt, loss = embedder.step(pixel_input)

In [40]:
# Show the ground-truth labels as a 2-D grid whose row width is the number of
# entries in pixel_input["pixel_values"].
row_width = len(pixel_input["pixel_values"])
gt.view(-1, row_width)

tensor([[3, 1, 1, 1],
        [2, 3, 1, 1],
        [2, 2, 3, 1],
        [2, 2, 2, 3]])

In [41]:
# Show the predicted class (argmax over dim 1) in the same grid layout as the
# ground truth, with row width len(pixel_input["pixel_values"]).
# NOTE(review): the recorded output is all zeros while the recorded ground
# truth contains labels 1..3 -- verify the checkpoint and label mapping.
row_width = len(pixel_input["pixel_values"])
torch.argmax(pred, dim=1).view(-1, row_width)

tensor([[0, 0, 0, 0],
        [0, 0, 0, 0],
        [0, 0, 0, 0],
        [0, 0, 0, 0]])