In [1]:
import json
import warnings
from collections import defaultdict, deque
from pathlib import Path

import pytorch_lightning as pl
import torch
from multipage_classifier.datasets.mosaic_dataset import MosaicDataModule
from multipage_classifier.encoder.swin_encoder import SwinEncoderConfig
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from sklearn.cluster import DBSCAN
from swin_encoder.lightning_module import SwinEncoderPLModule

  from .autonotebook import tqdm as notebook_tqdm


In [2]:

# Root directory that holds both the checkpoints and the datasets. Every path
# below is derived from it, so relocating the data means editing one line.
TRAINING_ROOT = "/data/training/master_thesis"

# Checkpoint loaded into SwinEncoderPLModule further down.
EMBEDDER_MODEL = f"{TRAINING_ROOT}/lightning_logs/swin_encoder/version_2/checkpoints/best-checkpoint.ckpt"

# Dataset directory and the JSON file listing the document classes.
DATASET_PATH = f"{TRAINING_ROOT}/datasets/2023-05-23"
CLASS_PATH = f"{TRAINING_ROOT}/datasets/bzuf_classes.json"


# Passed to MosaicDataModule as `max_pages` and `batch_size`.
MAX_PAGES = 8
BATCH_SIZE = 1


In [3]:
# Load Embedder
# Restore the encoder from its checkpoint and switch it to evaluation mode in
# one expression (`eval()` hands back the module itself).
encoder_module = SwinEncoderPLModule.load_from_checkpoint(EMBEDDER_MODEL).eval()


Some weights of SwinModel were not initialized from the model checkpoint at /data/training/master_thesis/models/donut-encoder and are newly initialized: ['layernorm.bias', 'layernorm.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [4]:
# Define data module
# Read the class list inside a context manager so the file handle is closed
# as soon as the JSON has been parsed.
with open(CLASS_PATH) as class_file:
    classes = list(json.load(class_file))

# The encoder's own `prepare_input` is handed to the data module together with
# the batch size and page limit from the configuration cell.
data_module = MosaicDataModule(
    Path(DATASET_PATH),
    classes,
    encoder_module.encoder.prepare_input,
    batch_size=BATCH_SIZE,
    max_pages=MAX_PAGES,
)

data_module.prepare_data()
data_module.setup()

In [5]:
# Pull `n` batches from the test dataloader; the last one pulled stays in
# `sample` and is the batch inspected in the following cells.
n = 4
ds = iter(data_module.test_dataloader())
remaining = n
while remaining > 0:
    sample = next(ds)
    remaining -= 1
    

In [6]:
# Dimensions of the image tensor in the selected sample.
sample["pixel_values"].shape

torch.Size([8, 3, 704, 512])

In [7]:
# Run the encoder's step on the sample. The results are only inspected, never
# backpropagated, so autograd tracking is switched off to avoid building a graph.
with torch.no_grad():
    pred, gt, loss = encoder_module.step(sample)

In [8]:
# Reshape the ground-truth order labels so that each row has one entry per
# page in the sample.
num_pages = len(sample["pixel_values"])
gt["order"].view(-1, num_pages)

tensor([[3, 0, 0, 0, 0, 0, 0, 0],
        [0, 3, 0, 0, 0, 0, 0, 0],
        [0, 0, 3, 1, 1, 1, 1, 1],
        [0, 0, 2, 3, 1, 1, 1, 1],
        [0, 0, 2, 2, 3, 1, 1, 1],
        [0, 0, 2, 2, 2, 3, 1, 1],
        [0, 0, 2, 2, 2, 2, 3, 1],
        [0, 0, 2, 2, 2, 2, 2, 3]])

In [29]:
bs = len(pred["class"])

# Exponentiate the raw order output and lay it out as (rows, bs, 4), i.e. four
# scores per page pair. NOTE(review): presumably pred["order"] holds
# log-probabilities, since the exponentiated rows shown further down sum to ~1 -- confirm.
order = torch.exp(pred["order"]).view(-1, bs, 4)

# Most likely of the four classes for every page pair.
order.argmax(dim=2)

tensor([[3, 0, 0, 0, 0, 0, 0, 0],
        [0, 3, 0, 0, 0, 0, 0, 0],
        [0, 0, 3, 1, 1, 1, 1, 1],
        [0, 0, 2, 3, 2, 2, 2, 1],
        [0, 0, 2, 1, 3, 1, 1, 1],
        [0, 0, 2, 1, 2, 3, 1, 3],
        [0, 0, 2, 2, 2, 2, 3, 2],
        [0, 0, 2, 2, 2, 3, 1, 3]])

In [38]:
# Group pages into documents with DBSCAN on a precomputed pairwise distance
# matrix. `eps` is spelled out (0.5 is scikit-learn's default) because it is
# the decision threshold: pages closer than this end up in the same cluster.
# With min_samples=1 every page is a core point, so no page is labelled noise.
cluster_prediction = DBSCAN(eps=0.5, min_samples=1, metric="precomputed")

# Channel 0 is the "None" class. Its probability is used directly as the
# distance between two pages; no inversion is applied.
doc_id_probs = order[:, :, 0]

# Make symmetric: average entry (i, j) with entry (j, i).
doc_id_probs = (doc_id_probs + doc_id_probs.transpose(1, 0)) / 2

# `detach()` drops any autograd history before the hand-over to NumPy.
doc_ids = cluster_prediction.fit_predict(doc_id_probs.detach().cpu().numpy())
doc_ids = torch.tensor(doc_ids)

doc_ids

tensor([0, 1, 2, 2, 2, 2, 2, 2])

In [57]:
# Display the full `order` tensor built above (exp of pred["order"], viewed as (-1, bs, 4)).
order

tensor([[[3.8803e-04, 1.6957e-03, 1.6336e-03, 9.9628e-01],
         [9.9861e-01, 1.8358e-05, 1.3709e-03, 1.0216e-21],
         [9.3510e-01, 7.5570e-04, 6.4141e-02, 1.9910e-08],
         [7.6831e-01, 3.9955e-02, 1.9174e-01, 7.2950e-11],
         [9.7364e-01, 1.8778e-03, 2.4481e-02, 1.7639e-08],
         [9.7512e-01, 3.9372e-03, 2.0940e-02, 5.4863e-09],
         [9.4619e-01, 1.9510e-02, 3.4298e-02, 5.0965e-09],
         [9.5388e-01, 1.8119e-02, 2.7999e-02, 4.7300e-08]],

        [[9.9912e-01, 8.7407e-04, 1.0278e-05, 1.2901e-21],
         [9.8162e-03, 2.6957e-03, 2.6751e-03, 9.8481e-01],
         [9.9861e-01, 1.0378e-03, 3.5038e-04, 1.1848e-17],
         [9.8898e-01, 1.0231e-02, 7.8926e-04, 4.8051e-14],
         [9.9835e-01, 1.5134e-03, 1.3214e-04, 3.1528e-20],
         [9.9848e-01, 1.4530e-03, 6.4232e-05, 2.3590e-22],
         [9.9849e-01, 1.4770e-03, 3.6317e-05, 4.3878e-23],
         [9.9263e-01, 7.2585e-03, 1.0918e-04, 2.9285e-21]],

        [[9.4910e-01, 5.0281e-02, 6.1583e-04, 2.7272

In [62]:
# For every predicted document, cut the pairwise order probabilities down to
# the pages of that document and show the most likely class per page pair.
# After the loop `sub_mat` holds the sub-matrix of the last document.
for doc_id in torch.unique(doc_ids):
    # Positions of the pages that were assigned to this document.
    indices = torch.nonzero(doc_ids == doc_id).reshape(-1)
    # Select the same pages on both page axes (rows first, then columns).
    sub_mat = order[indices][:, indices]
    print(indices)
    print(sub_mat.argmax(2))

tensor([0])
tensor([[3]])
tensor([1])
tensor([[3]])
tensor([2, 3, 4, 5, 6, 7])
tensor([[3, 1, 1, 1, 1, 1],
        [2, 3, 2, 2, 2, 1],
        [2, 1, 3, 1, 1, 1],
        [2, 1, 2, 3, 1, 3],
        [2, 2, 2, 2, 3, 2],
        [2, 2, 2, 3, 1, 3]])


In [67]:
# Continue with the sub-matrix currently stored in `sub_mat`.
# NOTE(review): this relies on kernel state left behind by the per-document
# loop (presumably the last document processed) -- confirm `sub_mat` is defined
# when the notebook is run top to bottom.
order_matrix = sub_mat

In [68]:
# Extract the "pred" and "succ" predictions (channels 1 and 2).
# NOTE: this changes the indices -> ["Pred", "Succ"], i.e. 0 = pred, 1 = succ.
# Cloned so the in-place symmetrisation below does not write through the slice
# into `order_matrix`.
page_nr_probs = order_matrix[:, :, 1:3].clone()

# Make the page order symmetric: "i is predecessor of j" and "j is successor
# of i" describe the same relation, so entry (i, j) is averaged with the
# flipped entry (j, i) and both are overwritten with the result.
bs = len(page_nr_probs)
for i in range(bs):
    for j in range(i, bs):
        avg = (page_nr_probs[i, j] + page_nr_probs[j, i].flip(0)) / 2
        page_nr_probs[i, j] = avg
        page_nr_probs[j, i] = avg.flip(0)

page_nr_preds = page_nr_probs.argmax(2)

# Build a directed graph with one edge per unordered page pair, pointing from
# the earlier page to the later one. Only the upper triangle is visited: after
# the symmetrisation the lower triangle carries the same information, and
# visiting it as well would insert every edge twice (and turn an exact tie
# into a 2-cycle).
graph = defaultdict(list)
for i in range(bs):
    for j in range(i + 1, bs):
        if page_nr_preds[i, j] == 0:  # i is predecessor of j
            graph[i].append(j)
        else:  # i is successor of j
            graph[j].append(i)

# Perform topological sorting using Kahn's algorithm.

# Count the number of incoming edges per node.
in_degree = [0] * bs
for neighbors in graph.values():
    for neighbor in neighbors:
        in_degree[neighbor] += 1

# "Start nodes" are the ones without incoming edges.
queue = deque(node for node, degree in enumerate(in_degree) if degree == 0)

topological_order = []
while queue:
    node = queue.popleft()
    topological_order.append(node)

    # `.get` avoids inserting empty entries into the defaultdict for sink nodes.
    for neighbor in graph.get(node, ()):
        in_degree[neighbor] -= 1
        if in_degree[neighbor] == 0:
            queue.append(neighbor)

# Contradictory pairwise predictions form a cycle, and Kahn's algorithm never
# emits the pages on it. Report that instead of silently returning a shorter
# order, and keep those pages at the end in index order.
if len(topological_order) < bs:
    placed = set(topological_order)
    unplaced = [node for node in range(bs) if node not in placed]
    warnings.warn(
        f"Pairwise page-order predictions contain a cycle; pages {unplaced} "
        "could not be ordered and were appended in index order."
    )
    topological_order.extend(unplaced)

torch.tensor(topological_order)

tensor([0, 2, 3, 1, 5, 4])

In [65]:
# Channels 1 and 2 of the full order tensor. Basic slicing returns a view, and
# the symmetrisation loop that follows assigns into `page_nr_probs` in place,
# so take a copy to leave `order` untouched.
page_nr_probs = order[:, :, 1:3].clone()

In [16]:
# Symmetrise the pairwise probabilities in place: each pair (row, col) is
# averaged with the channel-flipped entry of (col, row), then both entries are
# overwritten so that they stay mirror images of each other.
bs = len(sample["pixel_values"])
for row in range(bs):
    for col in range(row, bs):
        mirrored = page_nr_probs[col, row].flip(0)
        mean_probs = (page_nr_probs[row, col] + mirrored) / 2
        page_nr_probs[row, col] = mean_probs
        page_nr_probs[col, row] = mean_probs.flip(0)


In [17]:
# Drop the first two pages on both axes and take the more likely of the two
# remaining channels for every page pair that is left.
trailing_probs = page_nr_probs[2:, 2:]
prediction_matrix = trailing_probs.argmax(dim=2)
prediction_matrix

tensor([[0, 0, 0, 0, 0, 0],
        [1, 0, 1, 1, 0, 0],
        [1, 0, 0, 0, 0, 0],
        [1, 0, 1, 0, 0, 0],
        [1, 1, 1, 1, 0, 1],
        [1, 1, 1, 1, 0, 0]])