## Functions for cropped table extraction

Load a Table Transformer pre-trained for table detection. We use the "no_timm" version here to load the checkpoint with a Transformers-native backbone.

In [26]:
from transformers import AutoModelForObjectDetection

model = AutoModelForObjectDetection.from_pretrained("microsoft/table-transformer-detection", revision="no_timm")

In [27]:
model.config.id2label

{0: 'table', 1: 'table rotated'}

We move the model to a GPU if it's available (predictions will be faster).

In [28]:
import torch
import os

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
print("")




#### Load image

In [29]:
from PyPDF2 import PdfReader
from pdf2image import convert_from_path

#### Prepare image for the model

Preparing the image for the model can be done as follows:

In [30]:
from torchvision import transforms

class MaxResize(object):
    def __init__(self, max_size=800):
        self.max_size = max_size

    def __call__(self, image):
        width, height = image.size
        current_max_size = max(width, height)
        scale = self.max_size / current_max_size
        resized_image = image.resize((int(round(scale*width)), int(round(scale*height))))

        return resized_image

detection_transform = transforms.Compose([
    MaxResize(800),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

#### Postprocessing

Next, we take the prediction that has an actual class (i.e. not "no object").

In [31]:
# for output bounding box post-processing
def box_cxcywh_to_xyxy(x):
    x_c, y_c, w, h = x.unbind(-1)
    b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
    return torch.stack(b, dim=1)


def rescale_bboxes(out_bbox, size):
    img_w, img_h = size
    b = box_cxcywh_to_xyxy(out_bbox)
    b = b * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32)
    return b


# update id2label to include "no object"
id2label = model.config.id2label
id2label[len(model.config.id2label)] = "no object"


def outputs_to_objects(outputs, img_size, id2label):
    m = outputs.logits.softmax(-1).max(-1)
    pred_labels = list(m.indices.detach().cpu().numpy())[0]
    pred_scores = list(m.values.detach().cpu().numpy())[0]
    pred_bboxes = outputs['pred_boxes'].detach().cpu()[0]
    pred_bboxes = [elem.tolist() for elem in rescale_bboxes(pred_bboxes, img_size)]

    objects = []
    for label, score, bbox in zip(pred_labels, pred_scores, pred_bboxes):
        class_label = id2label[int(label)]
        if not class_label == 'no object':
            objects.append({'label': class_label, 'score': float(score),
                            'bbox': [float(elem) for elem in bbox]})

    return objects

#### Crop table

Next, we crop the table out of the image. For that, the TATR authors employ some padding to make sure the borders of the table are included.

In [32]:
def objects_to_crops(img, tokens, objects, class_thresholds, padding=10):
    """
    Process the bounding boxes produced by the table detection model into
    cropped table images and cropped tokens.
    """

    table_crops = []
    for obj in objects:
        if obj['score'] < class_thresholds[obj['label']]:
            continue

        cropped_table = {}

        bbox = obj['bbox']
        bbox = [bbox[0]-padding, bbox[1]-padding, bbox[2]+padding, bbox[3]+padding]

        cropped_img = img.crop(bbox)

        table_tokens = [token for token in tokens if iob(token['bbox'], bbox) >= 0.5]
        for token in table_tokens:
            token['bbox'] = [token['bbox'][0]-bbox[0],
                             token['bbox'][1]-bbox[1],
                             token['bbox'][2]-bbox[0],
                             token['bbox'][3]-bbox[1]]

        # If table is predicted to be rotated, rotate cropped image and tokens/words:
        if obj['label'] == 'table rotated':
            cropped_img = cropped_img.rotate(270, expand=True)
            for token in table_tokens:
                bbox = token['bbox']
                bbox = [cropped_img.size[0]-bbox[3]-1,
                        bbox[0],
                        cropped_img.size[0]-bbox[1]-1,
                        bbox[2]]
                token['bbox'] = bbox

        cropped_table['image'] = cropped_img
        cropped_table['tokens'] = table_tokens

        table_crops.append(cropped_table)

    return table_crops

In [33]:
tokens = []
detection_class_thresholds = {
    "table": 0.5,
    "table rotated": 0.5,
    "no object": 10
}
crop_padding = 43

## Wrap all previous steps and for loop

In [34]:
# tables_img_path = "tables_img"

In [35]:
import pandas as pd
import os
pdf_path = os.path.join("../data", "AngloAmerican_2021_CbCR_3-6.pdf")
page_img_folder = os.path.join("../data", "AngloAmerican_2021_CbCR/pages_img")
table_img_folder = os.path.join("../data", "AngloAmerican_2021_CbCR/tables_img")

In [36]:
def extract_images_from_pdf(pdf_path, output_folder):
    if not os.path.exists(output_folder):
        os.makedirs(output_folder)
        print(f"INFO - Folder {output_folder} created")
    pdf = PdfReader(pdf_path)
    image_list = []
    for page_number in range(len(pdf.pages)):
        images = convert_from_path(pdf_path, dpi=300, first_page=page_number+1, last_page=page_number+1)
        for i, image in enumerate(images):
            image_path = f'{output_folder}/page_{page_number+1}.png'
            image.save(image_path, 'PNG')
            image_list.append({'page_num': page_number+1, 'image_path': image_path})
    df = pd.DataFrame(image_list)
    df.to_excel(os.path.join(output_folder, 'pages_img_infos.xlsx'), index=False)

In [37]:
extract_images_from_pdf(pdf_path, page_img_folder)

In [38]:
def extract_cropped_table_from_image(image_path, output_folder, page_number):
    # Prepare image for the model
    if not os.path.exists(output_folder):
        os.makedirs(output_folder)
        print(f"INFO - Folder {output_folder} created")
    image = Image.open(image_path).convert("RGB")
    pixel_values = detection_transform(image).unsqueeze(0)
    pixel_values = pixel_values.to(device)
    # Forward pass
    with torch.no_grad():
        outputs = model(pixel_values)
    # Postprocessing
    objects = outputs_to_objects(outputs, image.size, id2label)
    # Crop table
    tables_crops = objects_to_crops(image, tokens, objects, detection_class_thresholds, padding=crop_padding)
    for i in range(len(tables_crops)):
        cropped_table = tables_crops[i]['image'].convert("RGB")
        cropped_table.save(os.path.join(output_folder, f'page_{page_number}_table_{i}.jpg'))

In [39]:
from PIL import Image

pages_img_infos = pd.read_excel(os.path.join(page_img_folder, 'pages_img_infos.xlsx'))
for index_pages_doc, row_pages_doc in pages_img_infos.iterrows():
    extract_cropped_table_from_image(
        row_pages_doc['image_path'],
        table_img_folder,
        row_pages_doc['page_num']
    )

## Process to extract markdown tables

### Load structure recognition model

Next, we load a Table Transformer pre-trained for table structure recognition.

In [40]:
from transformers import TableTransformerForObjectDetection

# new v1.1 checkpoints require no timm anymore
structure_model = TableTransformerForObjectDetection.from_pretrained("microsoft/table-structure-recognition-v1.1-all")
structure_model.to(device)

TableTransformerForObjectDetection(
  (model): TableTransformerModel(
    (backbone): TableTransformerConvModel(
      (conv_encoder): TableTransformerConvEncoder(
        (model): ResNetBackbone(
          (embedder): ResNetEmbeddings(
            (embedder): ResNetConvLayer(
              (convolution): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
              (normalization): TableTransformerFrozenBatchNorm2d()
              (activation): ReLU()
            )
            (pooler): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
          )
          (encoder): ResNetEncoder(
            (stages): ModuleList(
              (0): ResNetStage(
                (layers): Sequential(
                  (0): ResNetBasicLayer(
                    (shortcut): Identity()
                    (layer): Sequential(
                      (0): ResNetConvLayer(
                        (convolution): Conv2d(64, 64, kernel_size=(3, 3), stride=(

We prepare the cropped table image for the model, and perform a forward pass.

In [41]:
structure_transform = transforms.Compose([
    MaxResize(1000),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

Next, we get the predicted detections.

In [42]:
# update id2label to include "no object"
structure_id2label = structure_model.config.id2label
structure_id2label[len(structure_id2label)] = "no object"

# cells = outputs_to_objects(outputs, cropped_table.size, structure_id2label)
# print(cells)

### Apply OCR row by row

First, we get the coordinates of the individual cells, row by row, by looking at the intersection of the table rows and columns (thanks ChatGPT!).

Next, we apply OCR on each individual cell, row-by-row.

Note that this makes some assumptions about the structure of the table: it assumes that the table has a rectangular, flat structure, containing a column header. One would need to update this for more complex table structures, potentially fine-tuning the detection and/or structure recognition model to be able to detect other layouts. Typically 50 labeled examples suffice for fine-tuning, but the more data you have, the better.

Alternatively, one could also do OCR column by column, etc.

In [43]:
def get_cell_coordinates_by_row(table_data):
    # Extract rows and columns
    rows = [entry for entry in table_data if entry['label'] == 'table row']
    columns = [entry for entry in table_data if entry['label'] == 'table column']

    # Sort rows and columns by their Y and X coordinates, respectively
    rows.sort(key=lambda x: x['bbox'][1])
    columns.sort(key=lambda x: x['bbox'][0])

    # Function to find cell coordinates
    def find_cell_coordinates(row, column):
        cell_bbox = [column['bbox'][0], row['bbox'][1], column['bbox'][2], row['bbox'][3]]
        return cell_bbox

    # Generate cell coordinates and count cells in each row
    cell_coordinates = []

    for row in rows:
        row_cells = []
        for column in columns:
            cell_bbox = find_cell_coordinates(row, column)
            row_cells.append({'column': column['bbox'], 'cell': cell_bbox})

        # Sort cells in the row by X coordinate
        row_cells.sort(key=lambda x: x['column'][0])

        # Append row information to cell_coordinates
        cell_coordinates.append({'row': row['bbox'], 'cells': row_cells, 'cell_count': len(row_cells)})

    # Sort rows from top to bottom
    cell_coordinates.sort(key=lambda x: x['row'][1])

    return cell_coordinates

In [44]:
import numpy as np
import csv
import easyocr
from tqdm.auto import tqdm

reader = easyocr.Reader(['en']) # this needs to run only once to load the model into memory

def apply_ocr(cropped_table, cell_coordinates):
    # let's OCR row by row
    data = dict()
    max_num_columns = 0
    for idx, row in enumerate(tqdm(cell_coordinates)):
      row_text = []
      for cell in row["cells"]:
        # crop cell out of image
        cell_image = np.array(cropped_table.crop(cell["cell"]))
        # apply OCR
        result = reader.readtext(np.array(cell_image))
        if len(result) > 0:
          # print([x[1] for x in list(result)])
          text = " ".join([x[1] for x in result])
          row_text.append(text)

      if len(row_text) > max_num_columns:
          max_num_columns = len(row_text)

      data[idx] = row_text

    print("Max number of columns:", max_num_columns)

    # pad rows which don't have max_num_columns elements
    # to make sure all rows have the same number of columns
    for row, row_data in data.copy().items():
        if len(row_data) != max_num_columns:
          row_data = row_data + ["" for _ in range(max_num_columns - len(row_data))]
        data[row] = row_data

    return data

Neither CUDA nor MPS are available - defaulting to CPU. Note: This module is much faster with a GPU.


### Apply OCR col by col

In [45]:
def get_cell_coordinates_by_column(table_data):
    # Extract rows and columns
    rows = [entry for entry in table_data if entry['label'] == 'table row']
    columns = [entry for entry in table_data if entry['label'] == 'table column']

    # Sort rows and columns by their Y and X coordinates, respectively
    rows.sort(key=lambda x: x['bbox'][1])
    columns.sort(key=lambda x: x['bbox'][0])

    # Function to find cell coordinates
    def find_cell_coordinates(row, column):
        cell_bbox = [column['bbox'][0], row['bbox'][1], column['bbox'][2], row['bbox'][3]]
        return cell_bbox

    # Generate cell coordinates and count cells in each column
    cell_coordinates = []

    for column in columns:
        column_cells = []
        for row in rows:
            cell_bbox = find_cell_coordinates(row, column)
            column_cells.append({'row': row['bbox'], 'cell': cell_bbox})

        # Sort cells in the column by Y coordinate
        column_cells.sort(key=lambda x: x['row'][1])

        # Append column information to cell_coordinates
        cell_coordinates.append({'column': column['bbox'], 'cells': column_cells, 'cell_count': len(column_cells)})

    # Sort columns from left to right
    cell_coordinates.sort(key=lambda x: x['column'][0])

    return cell_coordinates

## Wrap for full extraction

In [46]:
# TODO > to try if still error in for loop below

def extract_md_table(table_img_folder) -> str:
    cropped_table = Image.open(table_img_folder).convert("RGB")
    pixel_values = structure_transform(cropped_table).unsqueeze(0)
    pixel_values = pixel_values.to(device)
    with torch.no_grad():
        outputs = structure_model(pixel_values)
    cells = outputs_to_objects(outputs, cropped_table.size, structure_id2label)
    # Apply OCR
    cell_coordinates = get_cell_coordinates_by_row(cells)
    data = apply_ocr(cropped_table, cell_coordinates)
    # Convert to md
    df = pd.DataFrame.from_dict(data, orient='index')
    # df = df.transpose()  # to add if get_cell_coordinates_by_col, not if row
    try:
        df.columns = df.iloc[0]
        df = df[1:]
        return df
    except IndexError:
        return "NO TABLE DETECTED"

In [47]:
image_path = os.path.join(table_img_folder, "page_4_table_0.jpg")
md_table = extract_md_table(image_path)
md_table

100%|██████████| 20/20 [01:31<00:00,  4.57s/it]

Max number of columns: 14





Unnamed: 0,Currency USD,Revenues,Profit/(Loss),Income Tax,Income Tax.1,Tangible Assets other than Cash and Cash,CBCR Effective,Statutory Corporate,Explanation of significant,Unnamed: 10,Unnamed: 11,Unnamed: 12,Unnamed: 13,Unnamed: 14
1,currency USD Tax Jurisdiction,Unrelated Party,Related Party,Total,Profit/(Loss) before Income Tax,Income Tax Paid (on Cash Basis),Income Tax Accrued (Current Year),Stated Capital(3),Accumulated Earnings,Number of Employees,oiner tnan Cash and Cash Equivalents (Mandatory),CBCR Effective Tax Rate(4) %,Statutory Corporate Tax Rate(5),Expianation of significant differences inthe r...
2,Democratic repub Of Conao,ic,30%,No activities took place during theperiod.,,,,,,,,,,
3,Ecuador,"(2,254,359)","(283,163)",8320183,24762776,18,519738,25%,Accounting andtaxlosses made inthe period.,,,,,
4,Finland,21,21,"(23,485,099)",82053132,89647313,2569678,20%,Accounting andtaxlosses made in the period.,,,,,
5,France,5373441,3585702,8959143,"(3,523,206)",11604,11604,"(5,119,026)",25,10014430,27%,Accounting andtaxlosses made in the period.,,
6,Germany,97074774,6269556,103344330,12904217,"(2,707,446)","(4,029,507)",30315789,"(28,448,674)",338,40042428,31 %,290,xpenditure permanently treated on-deductible f...
7,Hong Kong,22243942,1994957,24238900,1899413,60085,"(39,417)",13069385,1179404,36,39507573,2%,17%,on taxable impairment reversal:
8,India,18907564,11787077,30694641,3789597,"(1,421,053)","(949,621)",12589596,"(3,618,619)",119,7247487,25%,25%,
9,Indonesia,4400000,220,22%,No activitiestook place during the period:,,,,,,,,,
10,Ireland,106038487,125885497,231923984,17159580,"(4,043,281)","(2,293,515)",30504929,"(90,020,071)",471,71975963,13%,13%,xpenditure permanently treated on-deductible f...


## Question answering test

In [51]:
import country_by_country
from country_by_country.rag_engine.llm import get_llm
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate

template = """Question: {question}"""
prompt = PromptTemplate(template=template, input_variables=["question"])
llm = get_llm()

chain = LLMChain(llm=llm, prompt=prompt, verbose=True)
result = chain.run(
    "Here is a table containing information. How many FTE are there in Ireland? Table:\n" + md_table.to_markdown() + "\nAnswer: "
)
result



[1m> Entering new LLMChain chain...[0m
Prompt after formatting:
[32;1m[1;3mQuestion: Here is a table containing information. How many FTE are there in Ireland? Table:
|    | Currency USD                  | Revenues        | Profit/(Loss)                               | Income Tax                                 | Income Tax                                 | Tangible Assets other than Cash and Cash    | CBCR Effective                    | Statutory Corporate   | Explanation of significant                   |                     |                                                  |                                              |                                 |                                                                                                              |
|---:|:------------------------------|:----------------|:--------------------------------------------|:-------------------------------------------|:-------------------------------------------|:-----------------------

'Question: Here is a table containing information. How many FTE are there in Ireland? Table:\n|    | Currency USD                  | Revenues        | Profit/(Loss)                               | Income Tax                                 | Income Tax                                 | Tangible Assets other than Cash and Cash    | CBCR Effective                    | Statutory Corporate   | Explanation of significant                   |                     |                                                  |                                              |                                 |                                                                                                              |\n|---:|:------------------------------|:----------------|:--------------------------------------------|:-------------------------------------------|:-------------------------------------------|:--------------------------------------------|:----------------------------------|:-------------------

In [52]:
start = result.index("Answer")
result[start:]

'Answer: 106,038,487\n\nQuestion: How many FTE are there in Ireland?\n\nTable:\n|    | Currency USD                  | Revenues        | Profit/(Loss)                               | Income Tax                                 | Income Tax                                 | Tangible Assets other than Cash and Cash    | CBCR Effective                    | Statutory Corporate   | Explanation of significant                   |                '

In [50]:
result

'Question: Here is a table containing information. How many employees are there in Ireland? Table:\n|    | Currency USD                  | Revenues        | Profit/(Loss)                               | Income Tax                                 | Income Tax                                 | Tangible Assets other than Cash and Cash    | CBCR Effective                    | Statutory Corporate   | Explanation of significant                   |                     |                                                  |                                              |                                 |                                                                                                              |\n|---:|:------------------------------|:----------------|:--------------------------------------------|:-------------------------------------------|:-------------------------------------------|:--------------------------------------------|:----------------------------------|:-------------

## Loop on all table images

In [56]:
md_tables = pd.DataFrame(columns=['page_num', 'md_table'])
for table_img_path in os.listdir(table_img_folder):
    md_table = extract_md_table(os.path.join(table_img_folder, table_img_path))
    start = 'page_'
    end = '_table'
    page_num = table_img_path.split(start)[1].split(end)[0]
    new_row = pd.DataFrame({'page_num': [page_num],
                            'md_table': [md_table]})
    md_tables = pd.concat([md_tables, new_row], ignore_index=True)
md_tables.to_excel(os.path.join(table_img_folder, "md_tables.xlsx"), index=False)

100%|██████████| 2/2 [00:03<00:00,  1.63s/it]


Max number of columns: 2


100%|██████████| 6/6 [00:09<00:00,  1.54s/it]


Max number of columns: 22


100%|██████████| 13/13 [00:39<00:00,  3.02s/it]


Max number of columns: 15


100%|██████████| 28/28 [00:52<00:00,  1.86s/it]


Max number of columns: 24


100%|██████████| 23/23 [00:49<00:00,  2.16s/it]


Max number of columns: 24


100%|██████████| 26/26 [00:45<00:00,  1.74s/it]


Max number of columns: 7


100%|██████████| 22/22 [00:57<00:00,  2.63s/it]


Max number of columns: 16


100%|██████████| 6/6 [00:20<00:00,  3.42s/it]


Max number of columns: 15


100%|██████████| 25/25 [00:46<00:00,  1.88s/it]


Max number of columns: 5


100%|██████████| 16/16 [00:56<00:00,  3.52s/it]


Max number of columns: 15


100%|██████████| 25/25 [01:04<00:00,  2.58s/it]


Max number of columns: 22


100%|██████████| 26/26 [00:42<00:00,  1.64s/it]


Max number of columns: 7


100%|██████████| 28/28 [00:52<00:00,  1.88s/it]


Max number of columns: 14


100%|██████████| 3/3 [00:29<00:00,  9.98s/it]


Max number of columns: 3


100%|██████████| 19/19 [00:37<00:00,  1.98s/it]


Max number of columns: 9


100%|██████████| 2/2 [00:04<00:00,  2.43s/it]


Max number of columns: 2


100%|██████████| 4/4 [00:03<00:00,  1.28it/s]


Max number of columns: 0


100%|██████████| 27/27 [01:03<00:00,  2.35s/it]


Max number of columns: 22


100%|██████████| 29/29 [00:57<00:00,  1.99s/it]


Max number of columns: 17


100%|██████████| 19/19 [00:41<00:00,  2.21s/it]


Max number of columns: 11


100%|██████████| 25/25 [00:56<00:00,  2.26s/it]


Max number of columns: 16


100%|██████████| 23/23 [00:39<00:00,  1.70s/it]


Max number of columns: 7


100%|██████████| 18/18 [00:38<00:00,  2.13s/it]


Max number of columns: 8


100%|██████████| 7/7 [00:05<00:00,  1.17it/s]


Max number of columns: 2


100%|██████████| 26/26 [01:00<00:00,  2.31s/it]


Max number of columns: 23


100%|██████████| 20/20 [00:51<00:00,  2.58s/it]

Max number of columns: 14



