In [2]:
import json
from pathlib import Path

import pytorch_lightning as pl
import torch
from multipage_classifier.datasets.mosaic_dataset import MosaicDataModule
from multipage_classifier.encoder.swin_encoder import SwinEncoderConfig
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from swin_encoder.lightning_module import SwinEncoderPLModule

  from .autonotebook import tqdm as notebook_tqdm


In [3]:

EMBEDDER_MODEL = "/data/training/master_thesis/lightning_logs/swin_encoder/version_2/checkpoints/best-checkpoint.ckpt"

DATASET_PATH = "/data/training/master_thesis/datasets/2023-05-23"
CLASS_PATH = "/data/training/master_thesis/datasets/bzuf_classes.json"


MAX_PAGES = 8
BATCH_SIZE = 1


In [4]:
# Load Embedder
encoder_module = SwinEncoderPLModule.load_from_checkpoint(EMBEDDER_MODEL)
encoder_module = encoder_module.eval()


Some weights of SwinModel were not initialized from the model checkpoint at /data/training/master_thesis/models/donut-encoder and are newly initialized: ['layernorm.bias', 'layernorm.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [5]:
# Define data module
classes = [c for c in json.load(open(CLASS_PATH))]
data_module = MosaicDataModule(Path(DATASET_PATH), classes, encoder_module.encoder.prepare_input, batch_size=BATCH_SIZE, max_pages=MAX_PAGES)

data_module.prepare_data()
data_module.setup()

In [6]:
ds = iter(data_module.test_dataloader())
n = 4
for _ in range(n):
    sample = next(ds)
    

In [7]:
sample["pixel_values"].size()

torch.Size([8, 3, 704, 512])

In [8]:
pred, gt, loss = encoder_module.step(sample)

In [9]:
gt["order"].view(-1, len(sample["pixel_values"]))

tensor([[3, 0, 0, 0, 0, 0, 0, 0],
        [0, 3, 0, 0, 0, 0, 0, 0],
        [0, 0, 3, 1, 1, 1, 1, 1],
        [0, 0, 2, 3, 1, 1, 1, 1],
        [0, 0, 2, 2, 3, 1, 1, 1],
        [0, 0, 2, 2, 2, 3, 1, 1],
        [0, 0, 2, 2, 2, 2, 3, 1],
        [0, 0, 2, 2, 2, 2, 2, 3]])

In [10]:
bs = len(pred["class"])
order = torch.exp(pred["order"])
order = order.view(-1, bs, 4)

order.argmax(2)

tensor([[3, 0, 0, 0, 0, 0, 0, 0],
        [0, 3, 0, 0, 0, 0, 0, 0],
        [0, 0, 3, 1, 1, 1, 1, 1],
        [0, 0, 2, 3, 2, 2, 2, 1],
        [0, 0, 2, 1, 3, 1, 1, 1],
        [0, 0, 2, 1, 2, 3, 1, 3],
        [0, 0, 2, 2, 2, 2, 3, 2],
        [0, 0, 2, 2, 2, 3, 1, 3]])

In [11]:
from sklearn.cluster import DBSCAN

cluster_prediction = DBSCAN(
    min_samples=1, metric="precomputed"
)

doc_id_probs = order[:, :, 0] # just take 0 dim "None" and invert

# make symetric
doc_id_probs = (doc_id_probs + doc_id_probs.transpose(1, 0)) / (2)

doc_ids = cluster_prediction.fit_predict(doc_id_probs.cpu().data.numpy())
doc_ids = torch.tensor(doc_ids)

doc_ids

tensor([0, 1, 2, 2, 2, 2, 2, 2])

In [12]:
order.size()

torch.Size([8, 8, 4])

In [22]:
from collections import defaultdict, deque


for doc_id in torch.unique(doc_ids):
    print()
    print("sub")
    indices = torch.nonzero(doc_ids == doc_id).reshape(-1)
    sub_order_mat =  order[indices][:, indices]   

    page_nr_probs = sub_order_mat[:, :, 1:3] 
    # Make "page order symetric": 0 maps to 1 and the other way arround
    bs = len(page_nr_probs)
    for i in range(bs):
        for j in range(i, bs):
            avg = (page_nr_probs[i, j] + page_nr_probs[j, i].flip(0)) / 2
            page_nr_probs[i, j] = avg
            page_nr_probs[j, i] = avg.flip(0)

    page_nr_preds = page_nr_probs.argmax(2)
    print(page_nr_preds)

    # Create a graph representation using a dictionary
    graph = defaultdict(list)
    for i in range(bs):
        for j in range(i, bs):
            if i != j:
                if page_nr_preds[i, j] == 0:  # i is predecessor
                    graph[i].append(j)
                elif page_nr_preds[i, j] == 1:  # i is successor
                    graph[j].append(i)
    print(graph)
    # Count the numbner of incomming edges
    in_degree = [0] * bs
    for node, neighbors in graph.items():
        for neighbor in neighbors:
            in_degree[neighbor] += 1
    
    # List of "start nodes" which have no incoming edges
    queue = deque([node for node, degree in enumerate(in_degree) if degree == 0])
    
    topological_order: list[int] = []
    while queue:
        node = queue.popleft()
        topological_order.append(node)

        for neighbor in graph[node]:
            in_degree[neighbor] -= 1
            if in_degree[neighbor] == 0:
                queue.append(neighbor)
    print(topological_order)
    


sub
tensor([[0]])
defaultdict(<class 'list'>, {})
[0]

sub
tensor([[0]])
defaultdict(<class 'list'>, {})
[0]

sub
tensor([[0, 0, 0, 0, 0, 0],
        [1, 0, 1, 1, 0, 0],
        [1, 0, 0, 0, 0, 0],
        [1, 0, 1, 0, 0, 0],
        [1, 1, 1, 1, 0, 1],
        [1, 1, 1, 1, 0, 0]])
defaultdict(<class 'list'>, {0: [1, 2, 3, 4, 5, 1, 2, 3, 4, 5], 2: [1, 1, 3, 4, 5, 3, 4, 5], 3: [1, 1, 4, 5, 4, 5], 1: [4, 5, 4, 5], 5: [4, 4]})
[0, 2, 3, 1, 5, 4]


In [None]:
order_matrix = sub_mat

In [None]:
# extract "pred" and "succ" prediction. NOTE this changes the indices -> ["Pred", "Succ"]
page_nr_probs = order_matrix[:, :, 1:3] 

# Make "page order symetric": 0 maps to 1 and the other way arround
bs = len(page_nr_probs)
for i in range(bs):
    for j in range(i, bs):
        avg = (page_nr_probs[i, j] + page_nr_probs[j, i].flip(0)) / 2
        page_nr_probs[i, j] = avg
        page_nr_probs[j, i] = avg.flip(0)

page_nr_preds = page_nr_probs.argmax(2)

# Create a graph representation using a dictionary
graph = defaultdict(list)
for i in range(bs):
    for j in range(bs):
        if i != j:
            if page_nr_preds[i, j] == 0:  # i is predecessor
                graph[i].append(j)
            elif page_nr_preds[i, j] == 1:  # i is successor
                graph[j].append(i)

# Perform topological sorting using Kahn's algorithm

# Count the numbner of incomming edges
in_degree = [0] * bs
for node, neighbors in graph.items():
    for neighbor in neighbors:
        in_degree[neighbor] += 1

# List of "start nodes" which have no incoming edges
queue = deque([node for node, degree in enumerate(in_degree) if degree == 0])

topological_order = []
while queue:
    node = queue.popleft()
    topological_order.append(node)

    for neighbor in graph[node]:
        in_degree[neighbor] -= 1
        if in_degree[neighbor] == 0:
            queue.append(neighbor)

torch.tensor(topological_order)

tensor([0, 2, 3, 1, 5, 4])

In [None]:
page_nr_probs = order[:, :, 1:3]

In [None]:
bs = len(sample["pixel_values"])
for i in range(bs):
    for j in range(i, bs):
        avg = (page_nr_probs[i, j] + page_nr_probs[j, i].flip(0)) / 2
        page_nr_probs[i, j] = avg
        page_nr_probs[j, i] = avg.flip(0)


In [None]:
prediction_matrix = page_nr_probs[2:, 2:].argmax(2)
prediction_matrix

tensor([[0, 0, 0, 0, 0, 0],
        [1, 0, 1, 1, 0, 0],
        [1, 0, 0, 0, 0, 0],
        [1, 0, 1, 0, 0, 0],
        [1, 1, 1, 1, 0, 1],
        [1, 1, 1, 1, 0, 0]])

In [None]:
import numpy as np
from collections import defaultdict, deque

n = len(prediction_matrix)

graph = defaultdict(list)
for i in range(n):
    for j in range(i, n): # "symetric" in the page order way: 0 -> 1 and 1 -> 0
        if i != j:
            if prediction_matrix[i, j] == 1:  # Predecessor
                graph[i].append(j)
            elif prediction_matrix[i, j] == 0:  # Successor
                graph[j].append(i)


# Perform topological sorting using Kahn's algorithm


# Count the numbner of incomming edges
in_degree = [0] * n
for node, neighbors in graph.items():
    for neighbor in neighbors:
        in_degree[neighbor] += 1


# List of "start nodes" which have no incoming edges
queue = deque([node for node, degree in enumerate(in_degree) if degree == 0])
topological_order = []

while S:
    node = queue.popleft()
    topological_order.append(node)

    for neighbor in graph[node]:
        in_degree[neighbor] -= 1
        if in_degree[neighbor] == 0:
            queue.append(neighbor)


NameError: name 'S' is not defined

In [None]:
dict(sorted(graph.items()))

{1: [0, 2, 3], 2: [0], 3: [0, 2], 4: [0, 1, 2, 3, 5], 5: [0, 1, 2, 3]}