# TableTransformer - Dataset Extractor

### Imports

In [1]:
import os
import time
import pathlib

import numpy as np
import cv2
import torch
from torchvision import transforms
from PIL import Image, ImageDraw, ImageTk
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.patches import Patch

from transformers import AutoModelForObjectDetection, TableTransformerForObjectDetection

import tkinter as tk
from tkinter import ttk
from IPython.display import clear_output

In [2]:
# CUDA Device: use the GPU when PyTorch can see one, otherwise fall back to the CPU.
# NOTE(review): `device` is not referenced by the cells below, which move tensors to "cpu" explicitly.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Device: {device}")

Device: cuda


### Table Detection and Processing

Function: box_cxcywh_to_xyxy(x) :
Converts bounding boxes from center-and-size format (cx, cy, w, h) to corner format (x1, y1, x2, y2).

Function: rescale_bboxes(out_bbox, size) :
Converts normalized (cx, cy, w, h) bounding boxes to corner format and scales them to pixel coordinates for an image of the given (width, height).

Function: outputs_to_objects(outputs, img_size, id2label) :
Converts model outputs to a list of detected objects (label, score, and pixel-space bounding box), dropping predictions classified as "no object".

Function: fig2img(fig) :
Converts a Matplotlib figure to a PIL Image.

Function: visualize_detected_tables(img, det_tables, out_path=None) :
Visualizes detected tables on an image, highlighting them with different colors based on their type.

Function: objects_to_crops(img, tokens, objects, class_thresholds, padding=10) :
Processes detected bounding boxes into cropped table images and their respective tokens.

Function: get_cell_coordinates_by_row(table_data) :
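Builds cell coordinates from the detected rows and columns. Each cell is the intersection of a row's vertical extent with a column's horizontal extent. Cells are grouped row by row, top to bottom, and ordered left to right within a row.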

In [3]:
class MaxResize:
    """Resize an image so that its longer side equals ``max_size``.

    The aspect ratio is preserved. Images smaller than ``max_size`` are
    scaled up and larger ones are scaled down. The image must expose
    ``.size`` as ``(width, height)`` and a ``.resize((w, h))`` method (a PIL
    image in this notebook).
    """

    def __init__(self, max_size=800):
        self.max_size = max_size

    def __call__(self, image):
        w, h = image.size
        factor = self.max_size / max(w, h)
        target = (round(factor * w), round(factor * h))
        return image.resize(target)


# for output bounding box post-processing
def box_cxcywh_to_xyxy(x):
    """Convert boxes from (cx, cy, w, h) to corner format (x1, y1, x2, y2).

    Args:
        x: tensor whose last dimension has size 4 and holds (cx, cy, w, h).
            Any leading dimensions, including none, are preserved.

    Returns:
        Tensor of the same shape holding (x1, y1, x2, y2) in the last dimension.
    """
    x_c, y_c, w, h = x.unbind(-1)
    corners = [x_c - 0.5 * w, y_c - 0.5 * h, x_c + 0.5 * w, y_c + 0.5 * h]
    # Stack on the last axis so (4,), (N, 4) and (B, N, 4) inputs all keep their layout.
    return torch.stack(corners, dim=-1)


def rescale_bboxes(out_bbox, size):
    """Convert normalized (cx, cy, w, h) boxes to pixel (x1, y1, x2, y2) boxes.

    Args:
        out_bbox: tensor of boxes, last dimension (cx, cy, w, h), presumably
            normalized to [0, 1] as the model predicts them.
        size: ``(width, height)`` of the target image in pixels.

    Returns:
        Tensor of corner boxes scaled by the image width/height.
    """
    img_w, img_h = size
    b = box_cxcywh_to_xyxy(out_bbox)
    # Build the scale on the boxes' device so GPU tensors work too.
    scale = torch.tensor(
        [img_w, img_h, img_w, img_h], dtype=torch.float32, device=b.device
    )
    return b * scale


# Object detection
def outputs_to_objects(outputs, img_size, id2label):
    """Turn raw model outputs for the first image of a batch into object dicts.

    Each query is assigned its most probable class (softmax over the class
    logits). Queries whose class maps to "no object" in ``id2label`` are
    dropped.

    Args:
        outputs: model output exposing ``.logits`` and ``["pred_boxes"]``.
        img_size: ``(width, height)`` used to scale boxes to pixels.
        id2label: mapping from class index to label name, including "no object".

    Returns:
        List of ``{"label": str, "score": float, "bbox": [x1, y1, x2, y2]}``.
    """
    best = outputs.logits.softmax(-1).max(-1)
    # Index 0 selects the first (and here only) image in the batch.
    labels = best.indices.detach().cpu().numpy()[0]
    scores = best.values.detach().cpu().numpy()[0]
    raw_boxes = outputs["pred_boxes"].detach().cpu()[0]
    boxes = rescale_bboxes(raw_boxes, img_size).tolist()

    detections = []
    for label_idx, score, box in zip(labels, scores, boxes):
        name = id2label[int(label_idx)]
        if name == "no object":
            continue
        detections.append(
            {
                "label": name,
                "score": float(score),
                "bbox": [float(coord) for coord in box],
            }
        )
    return detections


def fig2img(fig):
    """Render a Matplotlib figure into memory and return it as a PIL Image.

    The figure is written with ``fig.savefig`` using Matplotlib's default save
    format. The returned image reads from that in-memory buffer.
    """
    from io import BytesIO

    stream = BytesIO()
    fig.savefig(stream)
    stream.seek(0)
    return Image.open(stream)


def visualize_detected_tables(img, det_tables, out_path=None):
    """Draw detected table regions over an image and return the figure.

    Only detections labelled "table" or "table rotated" are drawn; any other
    label is skipped. Each drawn box is stretched vertically by 5% of its
    height at both the top and the bottom. It is rendered as three overlaid
    rectangles: a faint fill, an outline, and a hatch pattern.

    Args:
        img: image accepted by ``plt.imshow``.
        det_tables: detections such as those returned by ``outputs_to_objects``
            (dicts with "label" and "bbox" = [x1, y1, x2, y2]). They are not
            modified.
        out_path: if given, the figure is also saved there (``dpi=150``).

    Returns:
        The current Matplotlib figure. Drawing happens on ``plt.gcf()``, and the
        final figure size is 10x10 inches.
    """
    # Per-label colour; both classes share alpha, line width and hatch.
    label_colors = {
        "table": (1, 0, 0.45),
        "table rotated": (0.95, 0.6, 0.1),
    }
    alpha = 0.3
    linewidth = 2
    hatch = "//////"

    plt.imshow(img, interpolation="lanczos")
    fig = plt.gcf()
    fig.set_size_inches(20, 20)
    ax = plt.gca()

    for det_table in det_tables:
        color = label_colors.get(det_table["label"])
        if color is None:
            continue

        # Unpack into locals so the caller's bbox list is never modified.
        x1, y1, x2, y2 = det_table["bbox"]
        extend_height = (y2 - y1) * 0.05
        y1 -= extend_height  # extend the top edge
        y2 += extend_height  # extend the bottom edge
        origin, width, height = (x1, y1), x2 - x1, y2 - y1

        # Faint filled background.
        ax.add_patch(
            patches.Rectangle(
                origin,
                width,
                height,
                linewidth=linewidth,
                edgecolor="none",
                facecolor=color,
                alpha=0.1,
            )
        )
        # Solid outline.
        ax.add_patch(
            patches.Rectangle(
                origin,
                width,
                height,
                linewidth=linewidth,
                edgecolor=color,
                facecolor="none",
                linestyle="-",
                alpha=alpha,
            )
        )
        # Hatch overlay (line width 0, so only the hatch pattern shows).
        ax.add_patch(
            patches.Rectangle(
                origin,
                width,
                height,
                linewidth=0,
                edgecolor=color,
                facecolor="none",
                linestyle="-",
                hatch=hatch,
                alpha=0.2,
            )
        )

    plt.xticks([], [])
    plt.yticks([], [])

    legend_elements = [
        Patch(
            facecolor=label_colors["table"],
            edgecolor=label_colors["table"],
            label="Table",
            hatch=hatch,
            alpha=alpha,
        ),
        Patch(
            facecolor=label_colors["table rotated"],
            edgecolor=label_colors["table rotated"],
            label="Table (rotated)",
            hatch=hatch,
            alpha=alpha,
        ),
    ]
    plt.legend(
        handles=legend_elements,
        bbox_to_anchor=(0.5, -0.02),
        loc="upper center",
        borderaxespad=0,
        fontsize=10,
        ncol=2,
    )
    # Overrides the 20x20 size set above; kept to match the original output.
    plt.gcf().set_size_inches(10, 10)
    plt.axis("off")

    if out_path is not None:
        plt.savefig(out_path, bbox_inches="tight", dpi=150)

    return fig


def _iob(bbox1, bbox2):
    """Return the area of intersection of two [x1, y1, x2, y2] boxes over the area of bbox1.

    Returns 0.0 when the boxes do not overlap or bbox1 has no positive area.
    """
    inter_w = min(bbox1[2], bbox2[2]) - max(bbox1[0], bbox2[0])
    inter_h = min(bbox1[3], bbox2[3]) - max(bbox1[1], bbox2[1])
    if inter_w <= 0 or inter_h <= 0:
        return 0.0
    area1 = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1])
    if area1 <= 0:
        return 0.0
    return (inter_w * inter_h) / area1


def objects_to_crops(img, tokens, objects, class_thresholds, padding=10):
    """
    Process the bounding boxes produced by the table detection model into
    cropped table images and cropped tokens.

    Args:
        img: image supporting ``.crop(box)`` and ``.rotate(angle, expand=True)``
            (a PIL image in this notebook).
        tokens: dicts with a "bbox" [x1, y1, x2, y2] in ``img`` coordinates. They
            are not modified; each crop receives shifted copies.
        objects: detections with "label", "score" and "bbox".
        class_thresholds: minimum score per label. Detections below it are skipped.
        padding: pixels added on every side of each detected box before cropping.

    Returns:
        List of ``{"image": cropped image, "tokens": tokens in crop coordinates}``,
        one per kept detection, in the order of ``objects``.
    """
    table_crops = []
    for obj in objects:
        if obj["score"] < class_thresholds[obj["label"]]:
            continue

        x1, y1, x2, y2 = obj["bbox"]
        crop_box = [x1 - padding, y1 - padding, x2 + padding, y2 + padding]
        cropped_img = img.crop(crop_box)

        # Keep tokens that lie mostly (>= 50% of their area) inside the crop.
        # Shift copies so overlapping tables don't shift the same token twice.
        table_tokens = [
            {
                **token,
                "bbox": [
                    token["bbox"][0] - crop_box[0],
                    token["bbox"][1] - crop_box[1],
                    token["bbox"][2] - crop_box[0],
                    token["bbox"][3] - crop_box[1],
                ],
            }
            for token in tokens
            if _iob(token["bbox"], crop_box) >= 0.5
        ]

        # If table is predicted to be rotated, rotate cropped image and tokens/words:
        if obj["label"] == "table rotated":
            cropped_img = cropped_img.rotate(270, expand=True)
            # Width after the rotation, which equals the crop's height before it.
            rotated_width = cropped_img.size[0]
            table_tokens = [
                {
                    **token,
                    "bbox": [
                        rotated_width - token["bbox"][3] - 1,
                        token["bbox"][0],
                        rotated_width - token["bbox"][1] - 1,
                        token["bbox"][2],
                    ],
                }
                for token in table_tokens
            ]

        table_crops.append({"image": cropped_img, "tokens": table_tokens})

    return table_crops


def get_cell_coordinates_by_row(table_data):
    """Build a row-by-row grid of cell boxes from row and column detections.

    Each cell is the intersection of one "table row" (its top and bottom) with
    one "table column" (its left and right). Entries with any other label are
    ignored.

    Args:
        table_data: detections with "label" and "bbox" = [x1, y1, x2, y2].

    Returns:
        List ordered top to bottom. Each item is
        ``{"row": row_bbox, "cells": [...], "cell_count": int}``, where every
        cell is ``{"column": column_bbox, "cell": [x1, y1, x2, y2]}`` and cells
        are ordered left to right.
    """
    def by_top(entry):
        return entry["bbox"][1]

    def by_left(entry):
        return entry["bbox"][0]

    rows = sorted((e for e in table_data if e["label"] == "table row"), key=by_top)
    columns = sorted(
        (e for e in table_data if e["label"] == "table column"), key=by_left
    )

    grid = []
    for row in rows:
        row_box = row["bbox"]
        cells = sorted(
            (
                {
                    "column": col["bbox"],
                    "cell": [col["bbox"][0], row_box[1], col["bbox"][2], row_box[3]],
                }
                for col in columns
            ),
            key=lambda cell: cell["column"][0],
        )
        grid.append({"row": row_box, "cells": cells, "cell_count": len(cells)})

    return sorted(grid, key=lambda entry: entry["row"][1])

### Table Detection Model
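Loads a pre-trained object detection model for locating tables in a page image. It also defines the input preprocessing: resize so the longer side is 800 px, convert to a tensor, and apply ImageNet normalization.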

### Table Structure Recognition Model
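Loads a pre-trained model for recognizing the structure (rows, columns, etc.) of a cropped table. It also defines the input preprocessing: resize so the longer side is 1000 px, convert to a tensor, and apply ImageNet normalization.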

In [4]:
# Table detection model (locates tables in a full page image).
model = AutoModelForObjectDetection.from_pretrained(
    "microsoft/table-transformer-detection", revision="no_timm"
)

print("Model config: ", model.config.id2label)

# Resize so the longer side is 800 px, then convert to a tensor normalized with
# ImageNet mean/std.
detection_transform = transforms.Compose(
    [
        MaxResize(800),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)

# Label map with the extra trailing "no object" class the model predicts.
# Built as a copy so model.config is not mutated and re-running this cell
# does not keep appending entries.
id2label = {**model.config.id2label, len(model.config.id2label): "no object"}


#  Structure Model (recognizes rows/columns inside a cropped table).
structure_model = TableTransformerForObjectDetection.from_pretrained(
    "microsoft/table-structure-recognition-v1.1-all"
)
structure_model.to("cpu")

# Same preprocessing as detection, but with a 1000 px longest side.
structure_transform = transforms.Compose(
    [
        MaxResize(1000),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)

# Copy of the structure label map plus "no object" (config left untouched).
structure_id2label = {
    **structure_model.config.id2label,
    len(structure_model.config.id2label): "no object",
}

Model config:  {0: 'table', 1: 'table rotated'}


### GUI for Cell Classification
#### Function: create_classification_gui(cell_crops, file_name_without_extension) :
Creates a GUI for classifying cell images into three categories: "true", "false", and "none". Each classified image is converted to grayscale and saved under `dataset/train/<category>/`.

In [5]:
def create_classification_gui(cell_crops, file_name_without_extension):
    """Open a Tk window for labelling cell crops as "true", "false" or "none".

    The crops in ``cell_crops`` are shown one at a time, resized to 300x200 for
    display. Clicking a button converts the current crop to grayscale and saves
    it as
    ``dataset/train/<label>/<label>-<file_name_without_extension>-cell-<n>.png``,
    where ``n`` is the crop's 1-based position in ``cell_crops``. The window
    closes after the last crop. The call blocks until then, or until the user
    closes the window.

    Args:
        cell_crops: list of RGB image arrays. Assumes 0-255 pixel values (uint8),
            as produced from a PIL RGB image in ``perform_extraction``.
        file_name_without_extension: source image name used in saved file names.
    """
    dataset_path = pathlib.Path("dataset/train")
    sub_paths = ["true", "false", "none"]

    for sub_path in sub_paths:
        (dataset_path / sub_path).mkdir(parents=True, exist_ok=True)

    if not cell_crops:
        return  # nothing to label; don't open an empty window

    def save_classified_image(image_array, classification, cell_number):
        new_file_name = f"{file_name_without_extension}-cell-{cell_number}.png"
        image_path = dataset_path / classification / f"{classification}-{new_file_name}"

        # Convert to grayscale
        grayscale_image = cv2.cvtColor(image_array, cv2.COLOR_RGB2GRAY)

        # Pin the colour limits to the full 8-bit range. Without vmin/vmax,
        # imsave stretches every image to its own min/max. That amplifies noise
        # and maps uniform (e.g. blank) cells to black.
        plt.imsave(str(image_path), grayscale_image, cmap="gray", vmin=0, vmax=255)
        print(f"Saved grayscale image to {image_path}")

    root = tk.Tk()
    root.title("Cell Classification")

    frame = ttk.Frame(root, padding="10")
    frame.grid(row=0, column=0, sticky=(tk.W, tk.E, tk.N, tk.S))

    image_label = ttk.Label(frame)
    image_label.grid(row=0, column=0, columnspan=3, pady=10)

    cell_info_label = ttk.Label(frame, text="")
    cell_info_label.grid(row=1, column=0, columnspan=3, pady=5)

    current_cell = 0  # index into cell_crops of the crop currently shown

    def resize_image(image_array, target_size=(300, 200)):
        """Return the crop as a PIL image resized for display."""
        return Image.fromarray(image_array).resize(target_size, Image.LANCZOS)

    def update_image():
        """Show the current crop, or close the window once all are labelled."""
        if current_cell >= len(cell_crops):
            root.destroy()
            return
        photo = ImageTk.PhotoImage(resize_image(cell_crops[current_cell]))
        image_label.config(image=photo)
        image_label.image = photo  # keep a reference so Tk doesn't drop the image
        cell_info_label.config(text=f"Cell {current_cell + 1} of {len(cell_crops)}")

    def on_classification(classification):
        nonlocal current_cell
        if current_cell >= len(cell_crops):
            return  # ignore clicks arriving after the last crop was labelled
        save_classified_image(cell_crops[current_cell], classification, current_cell + 1)
        current_cell += 1
        update_image()  # shows the next crop, or closes after the last one

    buttons = (("TRUE", "true"), ("FALSE", "false"), ("NONE", "none"))
    for column, (text, label) in enumerate(buttons):
        ttk.Button(
            frame,
            text=text,
            # Bind `label` now; a plain closure would see only the last value.
            command=lambda label=label: on_classification(label),
        ).grid(row=2, column=column, padx=5, pady=10)

    update_image()
    root.mainloop()

### Extractor
#### Function: perform_extraction(image_path)
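Detects the first table in an image, recognizes its rows and columns, and crops the right-most cell of each row. The header row is skipped when the table has more than 10 rows. The crops are intended for the cell classification GUI. Launching the GUI is currently disabled, so only the number of crops is printed.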
<br/>
<br/>
<img src="./public/img/CellClassification-GUI.png" alt="Cell classification GUI" height="200" width="200">

In [8]:
def _last_column_crops(table_np, cell_coordinates):
    """Crop the right-most cell of every row from a table image array.

    The first row is treated as a header and skipped when there are more than
    10 rows. Rows without cells and boxes that are empty after clipping to the
    image are skipped.
    """
    img_h, img_w = table_np.shape[:2]
    skip_header = len(cell_coordinates) > 10
    crops = []
    for i, row in enumerate(cell_coordinates):
        if i == 0 and skip_header:
            continue
        if not row["cells"]:
            continue  # no columns were detected, so there is no cell to crop
        x1, y1, x2, y2 = (int(v) for v in row["cells"][-1]["cell"])
        # Clamp to the image. Predicted boxes can poke past the edges, and a
        # negative start index would wrap around and produce an empty/wrong slice.
        x1, x2 = max(0, x1), min(img_w, x2)
        y1, y2 = max(0, y1), min(img_h, y2)
        if x2 <= x1 or y2 <= y1:
            continue
        crops.append(table_np[y1:y2, x1:x2])
    return crops


def perform_extraction(image_path, launch_gui=False):
    """Detect a table in an image and crop the last-column cell of each row.

    Pipeline: table detection (``model``) -> crop of the first detection
    scoring at least 0.5 (in model query order) -> structure recognition
    (``structure_model``) -> row/column grid -> right-most cell of each row.
    Prints progress. Returns ``None``.

    Args:
        image_path: path to the image file (str or ``pathlib.Path``).
        launch_gui: if True, pass the crops to ``create_classification_gui``.
            The default False only prints the number of crops.
    """
    image_path = pathlib.Path(image_path)
    file_name = image_path.name
    # NOTE: `stem` drops only the last suffix ("a.b.jpg" -> "a.b").
    file_name_without_extension = image_path.stem
    print("Extracting data for:", file_name)

    with Image.open(image_path) as source:
        image = source.convert("RGB")

    pixel_values = detection_transform(image).unsqueeze(0).to("cpu")
    with torch.no_grad():
        outputs = model(pixel_values)
    objects = outputs_to_objects(outputs, image.size, id2label)

    detection_class_thresholds = {"table": 0.5, "table rotated": 0.5, "no object": 10}
    # No OCR tokens are used here; padding=0 crops exactly to the detected box.
    tables_crops = objects_to_crops(
        image, [], objects, detection_class_thresholds, padding=0
    )
    if not tables_crops:
        print("No tables detected")
        return
    # Only the first kept table is processed.
    cropped_table = tables_crops[0]["image"].convert("RGB")

    pixel_values = structure_transform(cropped_table).unsqueeze(0).to("cpu")
    with torch.no_grad():
        outputs = structure_model(pixel_values)
    cells = outputs_to_objects(outputs, cropped_table.size, structure_id2label)

    cell_coordinates = get_cell_coordinates_by_row(cells)
    cell_crops = _last_column_crops(np.array(cropped_table), cell_coordinates)

    print("Number of cell crops:", len(cell_crops))
    if launch_gui:
        create_classification_gui(cell_crops, file_name_without_extension)



# # Viewing here itself for 1 image
# perform_extraction("data/Sample_Data/20240328_145219.jpg")

Extracting data for: 20240328_145219.jpg
Number of cell crops: 10


### Runner Code

In [7]:
import time
from IPython.display import clear_output

folder_path = pathlib.Path("data/Sample_Data")

# Sort for a deterministic processing order (glob order depends on the filesystem).
for image_path in sorted(folder_path.glob("*.jpg")):
    # Clear the previous image's output *before* printing. With wait=True the
    # clear happens on the next output, so calling it after the print would
    # always wipe the "Reading" line.
    clear_output(wait=True)
    image_path = image_path.as_posix()
    print(f"Reading: {image_path}")
    perform_extraction(image_path)
    time.sleep(1)  # brief pause so each image's output stays readable