In [1]:
from pathlib import Path

import pytorch_lightning as pl
from lightning_module import MultipageClassifierPLModule
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger

from multipage_classifier.datasets.ucsf_dataset import UCSFDataModule
from multipage_classifier.page_classifier import (MultipageClassifier,
                                                  MultipageClassifierConfig)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:

# --- Run identity and checkpoint to restore ---------------------------------
NAME = "multipage_classifier"
MODEL_PATH = "../lightning_logs/multipage_classifier/version_1/checkpoints/best-checkpoint.ckpt"

# --- Dataset location: root directory plus the index file inside it ---------
# Both are handed to UCSFDataModule in the setup cell (the directory as a Path).
DATASET_PATH = "../../dataset/ucsf-idl-resized-without_emails/"
DS_FILE = "IDL-less_20_pages.json"

# --- Model input geometry ---------------------------------------------------
# Passed as `input_size` to MultipageClassifierConfig.
# NOTE(review): axis order (height/width) is not visible from here - confirm.
IMAGE_SIZE = [512, 704]
# Shared by MultipageClassifierConfig and UCSFDataModule (`max_pages`).
MAX_PAGES = 64
# Not referenced by the cells shown in this notebook section.
MAX_LENGTH = 768

# --- Loader / schedule settings ---------------------------------------------
NUM_WORKERS = 8  # forwarded to UCSFDataModule(num_workers=...)
N_EPOCHS = 100   # not referenced by the cells shown in this notebook section

In [3]:
# Define Model
# NOTE(review): `config` is built here but is not passed to
# `load_from_checkpoint` below, so within this cell it has no effect on the
# restored model. Presumably the checkpoint carries its own hyperparameters -
# confirm before relying on IMAGE_SIZE / MAX_PAGES to shape the model.
config = MultipageClassifierConfig(
    input_size=IMAGE_SIZE,
    max_pages=MAX_PAGES
)

# Restore the trained Lightning module from the checkpoint on disk.
model = MultipageClassifierPLModule.load_from_checkpoint(MODEL_PATH)
# The following cells only run inference, so switch to evaluation mode:
# this stops dropout / batch-norm style layers from using training behaviour.
# (A later `trainer.fit` would switch the module back to train mode itself.)
model.eval()

# Data module over the UCSF dataset; the classifier's own `prepare_input` is
# used as the preprocessing hook so the data matches what the model expects.
data_module = UCSFDataModule(
    Path(DATASET_PATH),
    DS_FILE,
    # TODO maybe call this in PLModule
    prepare_function=model.classifier.prepare_input,
    split=[0.8, 0.2],
    max_pages=MAX_PAGES,
    num_workers=NUM_WORKERS
)

# Run the data module's preparation and setup hooks by hand, since no Trainer
# drives them in this notebook.
data_module.prepare_data()
data_module.setup()


In [4]:
# Pull a single batch from the test split to inspect predictions on.
test_loader = data_module.test_dataloader()
model_input = next(iter(test_loader))

In [5]:
# Run the classifier on the batch's page images; `predict` returns a pair
# that is unpacked into predictions and a second value named `prob`.
pixel_values = model_input["pixel_values"]
pred, prob = model.classifier.predict(pixel_values)

In [6]:
# Print ground-truth vs. predicted document id, one line per entry in the batch.
# Convert both sequences to plain Python lists once, up front: previously
# `.tolist()` was re-evaluated for both of them on every loop iteration.
gt_doc_ids = model_input["doc_id"].tolist()
pred_doc_ids = pred.tolist()

# Index into the predictions (rather than zip) so that a prediction list that
# is shorter than the ground truth still fails loudly instead of being
# silently truncated.
for i, gt_id in enumerate(gt_doc_ids):
    print(f"GT: {gt_id}, ",
          f"Pred: {pred_doc_ids[i]}, ")

GT: 0,  Pred: 0, 
GT: 0,  Pred: 0, 
GT: 1,  Pred: 0, 
GT: 1,  Pred: 0, 
GT: 1,  Pred: 0, 
GT: 1,  Pred: 0, 
GT: 1,  Pred: 0, 
GT: 1,  Pred: 0, 
GT: 2,  Pred: 1, 
GT: 2,  Pred: 2, 
GT: 2,  Pred: 2, 
GT: 2,  Pred: 2, 
GT: 2,  Pred: 2, 
GT: 2,  Pred: 2, 
GT: 3,  Pred: 2, 
GT: 3,  Pred: 2, 
GT: 3,  Pred: 2, 
GT: 3,  Pred: 3, 
GT: 3,  Pred: 4, 
GT: 4,  Pred: 5, 
GT: 5,  Pred: 6, 
GT: 5,  Pred: 7, 
GT: 5,  Pred: 8, 
GT: 5,  Pred: 9, 
GT: 5,  Pred: 10, 
GT: 5,  Pred: 10, 
GT: 5,  Pred: 10, 
GT: 5,  Pred: 10, 
GT: 5,  Pred: 10, 
GT: 5,  Pred: 10, 
GT: 5,  Pred: 10, 
GT: 6,  Pred: 11, 
GT: 6,  Pred: 12, 
GT: 6,  Pred: 12, 
GT: 6,  Pred: 13, 
GT: 7,  Pred: 14, 
GT: 7,  Pred: 14, 
GT: 7,  Pred: 15, 
GT: 7,  Pred: 16, 
GT: 7,  Pred: 16, 
GT: 7,  Pred: 16, 
GT: 7,  Pred: 16, 
GT: 7,  Pred: 16, 
GT: 7,  Pred: 16, 
GT: 7,  Pred: 16, 
GT: 7,  Pred: 16, 
GT: 7,  Pred: 16, 
GT: 8,  Pred: 16, 
GT: 8,  Pred: 17, 
GT: 9,  Pred: 18, 
GT: 9,  Pred: 18, 
GT: 9,  Pred: 19, 
GT: 9,  Pred: 20, 
GT: 9,  Pred: -1,