In [None]:
import cv2
import math
import glob
import inspect
import json
import mlflow
import torch
import os
import time
import sys

import lightning as L
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from pytorch_lightning.loggers import MLFlowLogger
from lightning.pytorch.callbacks import ModelCheckpoint

In [None]:
class CustomDataset(Dataset):
    """Page-layout segmentation dataset read from JSON annotation exports.

    Every JSON file under ``<root>/labels/`` is loaded and its top-level items
    (presumably a list of annotated elements) are collected. Each element
    references an image under ``<root>/images/`` and carries rectangle
    annotations whose geometry is expressed in percent of the image size.

    ``__getitem__`` rasterises the rectangles into a class-index mask. It then
    resizes image and mask down to a whole number of ``patch_size`` tiles and
    samples ``n_patches`` tiles independently (with replacement).

    Args:
        root: dataset root directory.
        train: keep the first ``train_fraction`` of elements if True, else the rest.
        n_patches: tiles sampled per item (default 16).
        patch_size: side length of each square tile in pixels (default 256).
        train_fraction: share of elements assigned to the training split (default 0.8).
    """

    def __init__(self, root, train=True, n_patches=16, patch_size=256, train_fraction=0.8):
        self.root = root
        self.n_patches = n_patches
        self.patch_size = patch_size

        # Class name -> index painted into the label mask.
        self.labels = {
            "Background": 0,
            "Text": 1,
            "List": 2,
            "Picture": 3,
            "Header": 4,
            "Table": 5,
            "Caption": 6,
            "Footnote": 7,
            "Formula": 8,
        }

        # glob() returns files in filesystem order, which is not guaranteed to be
        # stable. Sort so the split below is reproducible, and so the train and
        # test instances partition the same ordered list.
        self.elements = []
        for filename in sorted(glob.glob(os.path.join(root, "labels/*.json"))):
            with open(filename, "r") as f:
                self.elements.extend(json.load(f))

        n_train = int(len(self.elements) * train_fraction)
        self.elements = self.elements[:n_train] if train else self.elements[n_train:]

    def __len__(self):
        return len(self.elements)

    def __getitem__(self, idx):
        """Build one training sample.

        Returns a dict with:
            img:            (3, H, W) float32 RGB in [0, 1], H/W multiples of patch_size
            label_img:      (1, H, W) float64 class indices
            coords:         (n_patches, 4) tile bounds as (xl, xu, yl, yu)
            patches:        (n_patches, 3, p, p) float32
            label_patches:  (n_patches, 1, p, p) int64
            pos_embeddings: (n_patches, 1, p, p) float32, product of normalised x and y

        Raises:
            FileNotFoundError: the referenced image cannot be read.
            ValueError: the image is smaller than one patch in either dimension.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        element = self.elements[idx]

        img = self._load_image(element["data"]["image"])
        h, w, _ = img.shape
        label_img = self._rasterize_annotations(element["annotations"][0]["result"], h, w)

        p = self.patch_size
        nH, nW = h // p, w // p
        if nH == 0 or nW == 0:
            raise ValueError(f"image of size {h}x{w} is smaller than one {p}x{p} patch")
        H, W = nH * p, nW * p

        img = cv2.resize(img, (W, H))
        # Nearest-neighbour keeps the mask made of valid class indices. The default
        # (bilinear) interpolation blends neighbouring indices along every box
        # edge into values that, once cast to int, name unrelated classes.
        label_img = cv2.resize(label_img, (W, H), interpolation=cv2.INTER_NEAREST)

        patches, label_patches, pos_embeddings, coords = self._sample_patches(img, label_img)

        sample = dict()
        sample["img"] = img.transpose(2, 0, 1)
        sample["label_img"] = label_img[None]
        sample["coords"] = coords
        sample["patches"] = patches.transpose(0, 3, 1, 2)
        sample["label_patches"] = label_patches.transpose(0, 3, 1, 2)
        sample["pos_embeddings"] = pos_embeddings.transpose(0, 3, 1, 2)

        return sample

    def _load_image(self, image_ref):
        """Read the referenced page image as float32 RGB scaled to [0, 1]."""
        # Keep the basename, then only the part before the first "-".
        # NOTE(review): presumably strips a prefix/suffix added by the annotation
        # tool; confirm it matches the file names under images/.
        filename = image_ref.split("/")[-1].split("-")[0]
        path = os.path.join(self.root, "images", filename)
        img = cv2.imread(path)
        if img is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise FileNotFoundError(f"could not read image: {path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        return np.asarray(img / 255, dtype=np.float32)

    def _rasterize_annotations(self, annotations, h, w):
        """Paint rectangle annotations into an (h, w) float mask of class indices.

        Where boxes overlap, later annotations overwrite earlier ones.
        """
        label_img = np.zeros([h, w])
        for annotation in annotations:
            value = annotation["value"]
            # Geometry is in percent. Scale it by the loaded image's own size so
            # the box lands on the pixels that are actually indexed below.
            xl = int((value["x"] * w) / 100)
            yl = int((value["y"] * h) / 100)
            xu = xl + int((value["width"] * w) / 100)
            yu = yl + int((value["height"] * h) / 100)
            label_idx = self.labels[value["rectanglelabels"][0]]
            # Clamp the origin: a negative start index would wrap around in numpy slicing.
            label_img[max(yl, 0):yu, max(xl, 0):xu] = label_idx
        return label_img

    def _sample_patches(self, img, label_img):
        """Sample n_patches aligned tiles from an image whose size is a multiple of patch_size.

        Returns (patches, label_patches, pos_embeddings, coords) in HWC layout.
        """
        N, p = self.n_patches, self.patch_size
        H, W = label_img.shape
        nH, nW = H // p, W // p

        grid = np.asarray(np.meshgrid(np.linspace(0, 1, W), np.linspace(0, 1, H)),
                          dtype=np.float32)
        embedding = grid[0] * grid[1]

        patches = np.zeros([N, p, p, 3], dtype=np.float32)
        label_patches = np.zeros([N, p, p, 1], dtype=np.int64)
        pos_embeddings = np.zeros([N, p, p, 1], dtype=np.float32)

        coords = []
        for i in range(N):
            # randint's upper bound is exclusive. Use nH/nW (not nH - 1) so every
            # tile, including the last row/column, can be drawn, and so
            # single-tile images do not raise.
            rh = np.random.randint(0, nH)
            rw = np.random.randint(0, nW)
            xl, yl = rw * p, rh * p
            xu, yu = xl + p, yl + p

            coords.append((xl, xu, yl, yu))

            patches[i] = img[yl:yu, xl:xu]
            label_patches[i] = label_img[yl:yu, xl:xu, None]
            pos_embeddings[i] = embedding[yl:yu, xl:xu, None]

        return patches, label_patches, pos_embeddings, np.asarray(coords)

In [None]:
# Paths and data loading
ROOT_PATH = "../data"
BATCH_SIZE = 16
NUM_WORKERS = 0

# Training
DEVICE = "cuda"
LR = 1e-3
N_CLASSES = 9  # should equal the number of entries in CustomDataset.labels
N_EPOCHS = 10
MODEL_IDENTIFIER = "multi-seg-documents"

# Reproducibility: CustomDataset samples patches with np.random, while torch drives
# weight initialisation and DataLoader shuffling.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# Optional path to a Lightning checkpoint (None = no checkpoint).
# NOTE(review): presumably meant for resuming via `trainer.fit(..., ckpt_path=checkpoint)`;
# confirm that the training cell actually consumes it.
checkpoint = None

In [None]:
# Same file list, split 80/20 by position: train gets the head, test the tail.
train_dataset = CustomDataset(ROOT_PATH, train=True)
test_dataset = CustomDataset(ROOT_PATH, train=False)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, pin_memory=True, shuffle=True, num_workers=NUM_WORKERS)
# Shuffling an evaluation set adds nothing (the loss is averaged over every batch),
# and Lightning warns when a val/test loader shuffles, so keep its order fixed.
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, pin_memory=True, shuffle=False, num_workers=NUM_WORKERS)

print("Train:", len(train_dataset))
print("Test:", len(test_dataset))

In [None]:
# Draw one training batch and report the shape of every field it carries.
element = next(iter(train_loader))

batch_fields = ("img", "label_img", "coords", "patches", "label_patches", "pos_embeddings")
img, label_img, coords, patches, label_patches, pos_embeddings = (element[k] for k in batch_fields)

for field in batch_fields:
    print(field, element[field].shape)

In [None]:
@torch.no_grad
def test_model_config():
    batch = next(iter(train_loader))
    
    img = batch["img"]
    label_img = batch["label_img"]
    coords = batch["coords"]
    patches = batch["patches"].to(DEVICE)
    label_patches = batch["label_patches"].to(DEVICE)
    pos_embeddings = batch["pos_embeddings"].to(DEVICE)

    BS, N, D, H, W = patches.shape

    print(f"input")
    print(f"  {patches.shape}")
    print(f"  {pos_embeddings.shape}")
        
    pred = model(patches, pos_embeddings)

    print(f"output")
    print(f"  {pred.shape}")
    print()
    
    pred = pred.permute(0, 1, 3, 4, 2)
    pred = pred.reshape(-1, N_CLASSES)

    label_patches = label_patches.permute(0, 1, 3, 4, 2)
    label_patches = label_patches.reshape(-1)
    
    loss = F.cross_entropy(pred, label_patches)
    print(loss)

In [None]:
class DecoderBlock(nn.Module):
    """Upsampling block: 4x4 stride-2 transposed conv, then two 3x3 convs, each ReLU-activated.

    Spatial size: the transposed conv maps S -> 2S + 2, the padded 3x3 conv keeps
    it, and the final unpadded 3x3 conv trims it back to exactly 2S.
    Channels go ``f_in -> f_mid -> f_mid -> f_out``.
    """

    def __init__(self, f_in, f_mid, f_out):
        super().__init__()
        layers = [
            nn.ConvTranspose2d(f_in, f_mid, 4, 2),
            nn.ReLU(),
            nn.Conv2d(f_mid, f_mid, 3, 1, padding=1),
            nn.ReLU(),
            nn.Conv2d(f_mid, f_out, 3, 1, padding=0),
            nn.ReLU(),
        ]
        self.cnn = nn.Sequential(*layers)

    def forward(self, x):
        return self.cnn(x)
class Decoder(nn.Module):
    """Stack of ``n_layers`` DecoderBlocks joined by additive skip connections.

    ``forward`` decodes ``xs[0]`` with the first block. Each subsequent block i
    receives the previous output plus ``xs[i]``, so ``xs[i]`` must match that
    output's shape. Only the first ``n_layers`` entries of ``xs`` are read.
    """

    def __init__(self, f_in, n_layers):
        super().__init__()

        self.n_layers = n_layers
        blocks = []
        for _ in range(n_layers):
            blocks.append(DecoderBlock(f_in, 2 * f_in, f_in))
        self.blocks = nn.ModuleList(blocks)

    def forward(self, xs):
        out = self.blocks[0](xs[0])
        for idx in range(1, self.n_layers):
            out = self.blocks[idx](out + xs[idx])
        return out
class EncoderBlock(nn.Module):
    """Downsampling block: 3x3 conv -> ReLU -> 3x3 conv -> 2x2 average pool -> ReLU.

    Both convolutions are padded, so only the pooling changes the spatial size
    (halved, rounding down). Channels go ``f_in -> f_mid -> f_out``.
    """

    def __init__(self, f_in, f_mid, f_out):
        super().__init__()
        stages = [
            nn.Conv2d(f_in, f_mid, 3, 1, padding=1),
            nn.ReLU(),
            nn.Conv2d(f_mid, f_out, 3, 1, padding=1),
            nn.AvgPool2d((2, 2)),
            nn.ReLU(),
        ]
        self.cnn = nn.Sequential(*stages)

    def forward(self, x):
        return self.cnn(x)
class Encoder(nn.Module):
    """Chain of ``n_layers`` EncoderBlocks (width ``f_in`` in and out).

    ``forward`` returns the output of every block, in application order, so
    each successive entry is half the spatial size of the previous one.
    """

    def __init__(self, f_in, n_layers):
        super().__init__()

        self.blocks = nn.ModuleList(
            EncoderBlock(f_in, 2 * f_in, f_in) for _ in range(n_layers)
        )

    def forward(self, x):
        features = []
        h = x
        for blk in self.blocks:
            h = blk(h)
            features.append(h)
        return features
class Model(L.LightningModule):
    """Encoder-decoder that segments a stack of image patches in one pass.

    The ``N`` patches of a sample, each with its positional-embedding channel
    appended, are folded into the channel axis. The first encoder convolution
    therefore sees ``N * D`` channels. With 16 patches of 3 RGB channels plus
    1 embedding channel that is 64, matching the default ``f=64``; other
    settings must keep ``f == n_patches * D``. H and W should be divisible by
    ``2 ** n_layers`` so the decoder's skip additions line up.

    Uses manual optimisation (``automatic_optimization = False``).

    Args:
        f: channel width of the encoder/decoder (also the expected input channels).
        n_patches: number of patches per sample.
        n_classes: number of segmentation classes.
        n_layers: number of encoder/decoder blocks.
        lr: Adam learning rate; ``None`` falls back to the notebook-level ``LR``.
    """

    def __init__(self, f=64, n_patches=16, n_classes=9, n_layers=4, lr=None):
        super().__init__()

        self.automatic_optimization = False

        self.n_classes = n_classes
        self.lr = lr

        self.encoder = Encoder(f, n_layers)
        self.decoder = Decoder(f, n_layers)
        self.out = nn.Conv2d(f, n_patches * n_classes, 3, 1, padding=1)

    def forward(self, x, pos_emb):
        """Predict per-pixel class logits for every patch.

        Args:
            x: patches, shape (BS, N, D_img, H, W).
            pos_emb: positional embedding, shape (BS, N, D_pos, H, W).

        Returns:
            Logits of shape (BS, N, n_classes, H, W) when N == n_patches.
        """
        x = torch.cat([x, pos_emb], dim=2)
        BS, N, D, H, W = x.shape
        # Fold patches into channels: (BS, N*D, H, W).
        x = x.view(BS, N * D, H, W)
        xs = self.encoder(x)
        # Coarsest feature map first. The full-resolution input ends up last and
        # is not consumed, because Decoder reads only the first n_layers entries.
        xs = [x, *xs][::-1]
        x = self.decoder(xs)
        x = self.out(x)
        x = x.view(BS, N, -1, H, W)

        return x

    def configure_optimizers(self):
        lr = LR if self.lr is None else self.lr
        return torch.optim.Adam(self.parameters(), lr=lr)

    def _compute_loss(self, batch):
        """Pixel-wise cross-entropy between predicted logits and patch labels."""
        pred = self(batch["patches"], batch["pos_embeddings"])
        # (BS, N, C, H, W) -> (BS*N*H*W, C) to align with the flattened labels.
        pred = pred.permute(0, 1, 3, 4, 2).reshape(-1, self.n_classes)
        target = batch["label_patches"].permute(0, 1, 3, 4, 2).reshape(-1)
        return F.cross_entropy(pred, target)

    def training_step(self, batch, batch_idx):
        # Reuse the optimizer Lightning built from configure_optimizers(). Building
        # a new Adam on every step (as before) threw away its moment estimates.
        optimizer = self.optimizers()

        loss = self._compute_loss(batch)

        optimizer.zero_grad(set_to_none=True)
        # manual_backward lets Lightning apply precision scaling / strategy hooks.
        self.manual_backward(loss)
        optimizer.step()

        self.log("train_loss", loss, prog_bar=True, on_step=False, on_epoch=True)

        return loss

    def on_train_epoch_end(self):
        # Emit a newline after each epoch (presumably so the progress-bar line is kept).
        print()

    def test_step(self, batch, batch_idx):
        loss = self._compute_loss(batch)

        self.log("test_loss", loss, prog_bar=True, on_step=False, on_epoch=True)

        return loss

    def validation_step(self, batch, batch_idx):
        loss = self._compute_loss(batch)

        self.log("val_loss", loss, prog_bar=True, on_step=False, on_epoch=True)

        return loss

    # Backward-compatible alias: the original hook name `val_step` is not one that
    # Lightning calls, so validation never ran. `validation_step` is the real hook.
    val_step = validation_step

In [None]:
# Lightning moves the module to the configured accelerator itself. The explicit
# .to(DEVICE) keeps `model` usable for ad-hoc calls outside the Trainer.
model = Model().to(DEVICE)

# `lightning` and `pytorch_lightning` are separate packages. Take the logger from
# the same package as L.Trainer and ModelCheckpoint (lightning.pytorch) rather
# than mixing the two class hierarchies.
from lightning.pytorch.loggers import MLFlowLogger

mlflow_logger = MLFlowLogger(
    experiment_name="default",
    tracking_uri="http://localhost:8080"
)

trainer = L.Trainer(devices=1,
                    accelerator="gpu",
                    max_epochs=N_EPOCHS,
                    enable_checkpointing=True,
                    benchmark=True,
                    logger=mlflow_logger)

# Evaluate on the held-out split each epoch. Lightning skips the val loop, with a
# warning, if the model defines no validation_step. Resume from `checkpoint`
# when it is set; None trains from scratch.
trainer.fit(model=model,
            train_dataloaders=train_loader,
            val_dataloaders=test_loader,
            ckpt_path=checkpoint)