1. Import Libraries and Define Label Mapping

In this section, we import the required libraries and define the entity label mapping. The label_map dictionary assigns an integer ID to each field type; these IDs are the class indices that the model predicts.

In [1]:
import os
import pandas as pd
import torch

In [2]:
# Define the label mapping for the entity labels in your dataset
label_map = {
    'employerName': 0,
    'employerAddressStreet_name': 1,
    'employerAddressCity': 2,
    'employerAddressState': 3,
    'employerAddressZip': 4,
    'einEmployerIdentificationNumber': 5,
    'employeeName': 6,
    'ssnOfEmployee': 7,
    'box1WagesTipsAndOtherCompensations': 8,
    'box2FederalIncomeTaxWithheld': 9,
    'box3SocialSecurityWages': 10,
    'box4SocialSecurityTaxWithheld': 11,
    'box16StateWagesTips': 12,
    'box17StateIncomeTax': 13,
    'taxYear': 14,
    'OTHER': 15  # Label for non-entity tokens
}


2. Define Paths and Create Output Directory

Here, paths for the test dataset and output directory are specified. os.makedirs with exist_ok=True creates the output directory if it does not already exist, so saving the predictions does not fail.

In [3]:
#  Paths to test folder and output folder
test_folder_path = 'dataset/val/boxes_transcripts'
output_folder_path = 'output file'
os.makedirs(output_folder_path, exist_ok=True)

3. Initialize Tokenizer

The tokenizer from the transformers library is initialized using a pretrained BERT model (bert-base-uncased). This tokenizer will convert text data into tokens that the BERT model can process.

In [4]:
import torch
import torch.nn as nn
from transformers import BertTokenizer, BertModel

# Initialize the tokenizer
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
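
As a rough illustration of what padding=True produces, the sketch below uses a toy whitespace tokenizer with a made-up vocabulary. It is not the real BertTokenizer, which uses WordPiece and adds [CLS] and [SEP] tokens. The names toy_vocab and toy_encode are hypothetical and exist only for this example.

```python
# Toy illustration only: a whitespace "tokenizer" with a made-up vocabulary,
# standing in for BertTokenizer to show padding and attention masks.
PAD_ID = 0
UNK_ID = 1
toy_vocab = {"wages": 2, "tips": 3, "federal": 4, "income": 5, "tax": 6}

def toy_encode(texts):
    # Map each lowercased word to an ID, using UNK_ID for unknown words
    ids = [[toy_vocab.get(w, UNK_ID) for w in t.lower().split()] for t in texts]
    max_len = max(len(seq) for seq in ids)
    # Pad every sequence to the length of the longest one in the batch
    input_ids = [seq + [PAD_ID] * (max_len - len(seq)) for seq in ids]
    # 1 marks a real token, 0 marks padding the model should ignore
    attention_mask = [[1] * len(seq) + [0] * (max_len - len(seq)) for seq in ids]
    return {"input_ids": input_ids, "attention_mask": attention_mask}

batch = toy_encode(["Wages tips", "Federal income tax withheld"])
print(batch["input_ids"])
print(batch["attention_mask"])
```

The real tokenizer returns the same two keys (as tensors when return_tensors="pt" is set), which is why the later cells read tokens['input_ids'] and tokens['attention_mask'].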



4. Define the BERT + Coordinates Model

We define a custom PyTorch model class, BERTWithCoords, which combines BERT text embeddings with bounding box coordinates. The pooled BERT embedding of each transcript is concatenated with its four bounding box values, and a single linear layer maps the result to entity-label logits.

   > __init__ loads the pretrained BERT model and creates a fully connected layer (fc) whose input size is the BERT hidden size plus 4.

   > forward runs BERT to get the pooled text embedding (pooler_output), concatenates it with the bounding box features, and applies the classification layer to produce logits.

In [5]:
# Define the BERT + Coordinates model
class BERTWithCoords(nn.Module):
    def __init__(self, num_labels):
        super(BERTWithCoords, self).__init__()
        self.bert = BertModel.from_pretrained("bert-base-uncased")
        self.fc = nn.Linear(self.bert.config.hidden_size + 4, num_labels)  # +4 for bbox coordinates

    def forward(self, input_ids, attention_mask, bbox):
        # Get BERT embeddings
        bert_output = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        
        # Concatenate BERT embeddings with bounding box features
        combined_features = torch.cat((bert_output.pooler_output, bbox), dim=1)
        
        # Classification layer
        logits = self.fc(combined_features)
        return logits
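
The shapes involved in the concatenation can be checked without loading BERT. The following is a minimal sketch in which a random tensor stands in for bert_output.pooler_output; the batch size of 3 and the coordinate values are made up for illustration.

```python
import torch
import torch.nn as nn

hidden_size = 768   # hidden size of bert-base-uncased
num_labels = 16     # number of entries in label_map
batch_size = 3      # hypothetical number of rows in one file

# Stand-in for bert_output.pooler_output: one vector per input row
pooled = torch.randn(batch_size, hidden_size)

# Made-up bounding boxes: x_top_left, y_top_left, x_bottom_right, y_bottom_right
bbox = torch.tensor([
    [10.0, 20.0, 110.0, 40.0],
    [12.0, 60.0, 200.0, 80.0],
    [15.0, 90.0, 180.0, 110.0],
])

# Concatenate along the feature dimension: 768 + 4 = 772 features per row
combined = torch.cat((pooled, bbox), dim=1)

# The linear layer maps the 772 features to one logit per label
fc = nn.Linear(hidden_size + 4, num_labels)
logits = fc(combined)

print(combined.shape, logits.shape)
```

Because the pooled output is one vector per input sequence, the model produces one prediction per row of the input file rather than one per token.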

5. Function to Prepare Test Data

The prepare_test_data function tokenizes the transcript text and converts bounding box coordinates into a PyTorch tensor. This function handles both the text and spatial data needed for prediction.

In [6]:
# Function to tokenize and prepare bounding boxes
def prepare_test_data(data, tokenizer):
    tokens = tokenizer(data['transcript'].tolist(), return_tensors="pt", padding=True, truncation=True)
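    # Cast to float32 so the coordinates can be concatenated with the BERT embeddings
    bbox = torch.tensor(data[['x_top_left', 'y_top_left', 'x_bottom_right', 'y_bottom_right']].values, dtype=torch.float32)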
    return tokens, bbox
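
The bounding box conversion can be tried in isolation. This is a small sketch with a hand-built DataFrame whose values are invented; it assumes the coordinates should be float32 so that they match the dtype of the BERT embeddings they are later concatenated with.

```python
import pandas as pd
import torch

bbox_columns = ['x_top_left', 'y_top_left', 'x_bottom_right', 'y_bottom_right']

# Made-up rows in the same column layout the notebook uses
data = pd.DataFrame({
    'x_top_left': [10, 12],
    'y_top_left': [20, 60],
    'x_bottom_right': [110, 200],
    'y_bottom_right': [40, 80],
    'transcript': ['Wages', 'unknown'],
})

# Integer columns would otherwise become an int64 tensor; request float32 explicitly
bbox = torch.tensor(data[bbox_columns].values, dtype=torch.float32)

print(bbox.shape, bbox.dtype)
```

Each row of the resulting tensor lines up with the transcript in the same row of the DataFrame, which is what lets the model pair text with position.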


6. Load Model and Set to Evaluation Mode

The BERTWithCoords model is instantiated, its weights are loaded from a saved state dictionary, and the model is set to evaluation mode. Evaluation mode disables dropout, so the same input always produces the same prediction.

In [7]:
# Load model and set to evaluation mode
model = BERTWithCoords(num_labels=len(label_map))
model.load_state_dict(torch.load("bert_with_coords_model.pth"))
model.eval()


BERTWithCoords(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwis
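
To illustrate why model.eval() matters, the minimal sketch below compares a standalone dropout layer in training and evaluation mode. It uses p=0.1, the same rate that appears in the model printout; the input tensor is arbitrary.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

dropout = nn.Dropout(p=0.1)
x = torch.ones(1000)

# Training mode: elements are randomly zeroed and the rest are scaled by 1 / (1 - p)
dropout.train()
train_out = dropout(x)

# Evaluation mode: dropout is a no-op and the input passes through unchanged
dropout.eval()
eval_out = dropout(x)

print(torch.equal(train_out, x), torch.equal(eval_out, x))
```

Calling model.eval() on BERTWithCoords switches every dropout layer inside the BERT encoder to this pass-through behavior.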

7. Reverse Label Map

This dictionary maps the numeric labels back to their original entity names, allowing us to decode the model’s predictions into readable labels.

In [8]:
# Reverse label map for decoding predictions
reverse_label_map = {v: k for k, v in label_map.items()}
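
Inverting a dictionary this way is only safe when every integer ID is unique. The short sketch below checks that on an abbreviated subset of the label map and decodes a few made-up predictions.

```python
# Abbreviated subset of the label map, for illustration only
label_map = {'employerName': 0, 'taxYear': 14, 'OTHER': 15}

# Swap keys and values so integer IDs map back to field names
reverse_label_map = {v: k for k, v in label_map.items()}

# If two names shared an ID, the reversed map would be shorter than the original
assert len(reverse_label_map) == len(label_map)

# Decode a made-up list of predicted IDs
decoded = [reverse_label_map[i] for i in [15, 0, 14]]
print(decoded)
```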


8. Prediction Function

The predict_for_file function processes a single file for prediction:

   > Load Data: Reads the file as comma-separated values (the files have a .tsv extension, but the code passes sep=','), assigns column names, and fills missing values in the ‘transcript’ column with “unknown.”

   > Prepare Inputs: Tokenizes the transcripts and prepares the bounding box coordinates.

   > Model Prediction: Passes the inputs through the model to get logits, then uses torch.max to get the predicted class for each row.

   > Add Predictions: Creates a new column in the data to store the predicted labels, mapped back to entity names.

In [9]:
def predict_for_file(model, file_path):
    # Load the test data and ensure correct column names
    data = pd.read_csv(file_path, sep=',', names=[
        'start_index', 'end_index', 'x_top_left', 'y_top_left', 
        'x_bottom_right', 'y_bottom_right', 'transcript'
    ])
    
    # Fill missing values in 'transcript' and convert to string type
    data['transcript'] = data['transcript'].fillna("unknown").astype(str)

    # Prepare inputs
    tokens, bbox = prepare_test_data(data, tokenizer)
    
    with torch.no_grad():
        input_ids = tokens['input_ids']
        attention_mask = tokens['attention_mask']
        
        # Forward pass
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, bbox=bbox)
        
        # Get predictions and map to labels
        _, predictions = torch.max(outputs, dim=1)
        predicted_labels = [reverse_label_map[pred.item()] for pred in predictions]
        
    # Add predictions to the data frame
    data['predicted_field'] = predicted_labels
    return data
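
The step that turns logits into label names can be sketched on its own. The example below uses a hypothetical three-class label map and hand-written logits rather than real model output.

```python
import torch

# Hypothetical three-class mapping, much smaller than the real label_map
reverse_label_map = {0: 'employerName', 1: 'employeeName', 2: 'OTHER'}

# Made-up logits: one row per transcript, one column per class
logits = torch.tensor([
    [2.0, 0.5, -1.0],
    [0.1, 0.2, 3.0],
])

# torch.max with dim=1 returns (max values, indices); the indices are the class IDs
max_values, predictions = torch.max(logits, dim=1)

# Convert each predicted class ID back to its field name
predicted_labels = [reverse_label_map[pred.item()] for pred in predictions]
print(predicted_labels)
```

torch.argmax(logits, dim=1) returns the same indices and can be used when the maximum values themselves are not needed.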


9. Run Predictions on All Test Files

This final section iterates through all .tsv files in the test folder, makes predictions using predict_for_file, and saves each output file with predicted labels in the specified output directory.

In [10]:
for file_name in os.listdir(test_folder_path):
    if file_name.endswith('.tsv'):
        file_path = os.path.join(test_folder_path, file_name)
        output_path = os.path.join(output_folder_path, file_name)
        
        # Run prediction for the file
        predicted_data = predict_for_file(model, file_path)
        
        # Save the result to the output folder
        # predicted_data.to_csv(output_path, sep='\t', index=False)
        predicted_data.to_csv(output_path, sep=',', index=False)
        print(f"Predictions saved for {file_name} to {output_path}")

print("All predictions completed.")

Predictions saved for 033ae477-99aa-4047-953d-4a951fc5a498_document-3_page-1.tsv to output file\033ae477-99aa-4047-953d-4a951fc5a498_document-3_page-1.tsv
Predictions saved for 033ae477-99aa-4047-953d-4a951fc5a498_document-4_page-1.tsv to output file\033ae477-99aa-4047-953d-4a951fc5a498_document-4_page-1.tsv
Predictions saved for 03ca8d34-d060-49ee-b6cb-125e82305045_document-4_page-1.tsv to output file\03ca8d34-d060-49ee-b6cb-125e82305045_document-4_page-1.tsv
Predictions saved for 03ca8d34-d060-49ee-b6cb-125e82305045_document-5_page-1.tsv to output file\03ca8d34-d060-49ee-b6cb-125e82305045_document-5_page-1.tsv
Predictions saved for 053f994b-599a-4d25-a72f-4ee9f88b4136_document-4_page-1.tsv to output file\053f994b-599a-4d25-a72f-4ee9f88b4136_document-4_page-1.tsv
Predictions saved for 05cf86f4-299a-4b93-8ba3-b59c88499280_document-1_page-1.tsv to output file\05cf86f4-299a-4b93-8ba3-b59c88499280_document-1_page-1.tsv
Predictions saved for 05cf86f4-299a-4b93-8ba3-b59c88499280_document-4_
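
The file selection logic in the loop can be tried without the dataset. The sketch below builds a temporary folder with made-up file names and uses pathlib instead of os.listdir; sorting the matches is an optional extra that makes the processing order predictable.

```python
import tempfile
from pathlib import Path

with tempfile.TemporaryDirectory() as tmp:
    # Hypothetical folders mirroring the notebook's layout
    test_dir = Path(tmp) / "boxes_transcripts"
    out_dir = Path(tmp) / "output file"
    test_dir.mkdir()
    out_dir.mkdir(parents=True, exist_ok=True)

    # Two input files plus one file that should be skipped
    (test_dir / "b.tsv").write_text("0,1,10,20,110,40,Wages\n")
    (test_dir / "a.tsv").write_text("0,1,12,60,200,80,Tips\n")
    (test_dir / "notes.txt").write_text("ignored\n")

    # Select only .tsv files and pair each with its output path
    selected = []
    for file_path in sorted(test_dir.glob("*.tsv")):
        output_path = out_dir / file_path.name
        selected.append((file_path.name, output_path.parent.name))

print(selected)
```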