In [1]:
from datasets import load_dataset, Dataset
from transformers import (
    LayoutLMv3Tokenizer,
    LayoutLMv3ForSequenceClassification,
    TrainingArguments,
    Trainer
)
import torch
import random
import numpy as np
from sklearn.metrics import classification_report, accuracy_score
import re
from PIL import Image
import os
import json
import logging

from collections import Counter
import torch.nn as nn

os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1"
Image.MAX_IMAGE_PIXELS = None

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)




In [2]:
import nltk
nltk.download('punkt_tab')

[nltk_data] Downloading package punkt_tab to
[nltk_data]     C:\Users\soham\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!


True

In [3]:
ds = load_dataset("ds4sd/DocLayNet-v1.1", streaming=True)
print(ds)

Resolving data files:   0%|          | 0/29 [00:00<?, ?it/s]

IterableDatasetDict({
    train: IterableDataset({
        features: ['image', 'bboxes', 'category_id', 'segmentation', 'area', 'pdf_cells', 'metadata'],
        num_shards: 29
    })
    test: IterableDataset({
        features: ['image', 'bboxes', 'category_id', 'segmentation', 'area', 'pdf_cells', 'metadata'],
        num_shards: 2
    })
    val: IterableDataset({
        features: ['image', 'bboxes', 'category_id', 'segmentation', 'area', 'pdf_cells', 'metadata'],
        num_shards: 3
    })
})


In [4]:
# Define label map and categories of interest
label_map = {
    1: "Caption",
    2: "Footnote",
    3: "Formula",
    4: "List-item",
    5: "Page-footer",
    6: "Page-header",
    7: "Picture",
    8: "Section-header",
    9: "Table",
    10: "Text",
    11: "Title"
}
categories_of_interest = [1, 8, 10, 11]

def preprocess_sample(sample):
    texts = []
    bboxes = []
    category_ids = []
    
    # Extract text cells from pdf_cells
    for cell_list in sample['pdf_cells']:
        if isinstance(cell_list, dict):
            cell_list = [cell_list]
        for cell in cell_list:
            cell_bbox = cell['bbox']
            cell_text = cell['text'].strip() if isinstance(cell['text'], str) else ""
            if not cell_text:
                continue

            # Match text cell to annotated objects
            matched_category_id = -1
            for obj_bbox, obj_cat_id in zip(sample['bboxes'], sample['category_id']):
                if is_overlapping(cell_bbox, obj_bbox):
                    matched_category_id = obj_cat_id
                    break

            texts.append(cell_text)
            bboxes.append(cell_bbox)
            category_ids.append(matched_category_id)

    return {
        "texts": texts,
        "bboxes": bboxes,
        "category_ids": category_ids
    }

def is_overlapping(bbox1, bbox2):
    x1, y1, w1, h1 = bbox1
    x2, y2, w2, h2 = bbox2
    return (x1 < x2 + w2 and x1 + w1 > x2 and y1 < y2 + h2 and y1 + h1 > y2)

In [5]:
# train_dataset = ds['train']
# ds1 = train_dataset.take(1)
# print(ds1)
# for index, sample in enumerate(ds1):
#     out = sample
#     output = preprocess_sample(sample)
#     print(output)
# a = preprocess_sample(out)
# for x in a:
#     print(a[x])

In [6]:
# Initialize tokenizer
tokenizer = LayoutLMv3Tokenizer.from_pretrained("microsoft/layoutlmv3-base")

def normalize_bbox(bbox, width=1000, height=1000):
    x, y, w, h = bbox
    return [
        int(1000 * x / width),
        int(1000 * y / height),
        int(1000 * (x + w) / width),
        int(1000 * (y + h) / height)
    ]

def tokenize_and_align_labels(examples):
    tokenized_inputs = {
        "input_ids": [],
        "attention_mask": [],
        "bbox": [],
        "labels": [],
        "category_ids": []
    }
    
    for text, bbox, category_id in zip(
        examples['texts'], examples['bboxes'], examples['category_ids']
    ):
        if not isinstance(text, str):
            logger.warning(f"Skipping invalid text: {text} (type: {type(text)})")
            continue
        
        if not text.strip() or text.strip() in [",", ";"]:
            logger.warning(f"Skipping empty or invalid text: {text}")
            continue
        
        try:
            words = nltk.word_tokenize(text)
            norm_bbox = normalize_bbox(bbox)
            word_bboxes = [norm_bbox for _ in words]
            encoding = tokenizer(
                words,
                boxes=word_bboxes,
                truncation=True,
                max_length=512,
                padding="max_length",
                return_tensors="pt",
                is_split_into_words=True
            )
        except Exception as e:
            logger.error(f"Tokenization failed for text: {text}, error: {str(e)}")
            continue
        
        token_bboxes = [norm_bbox for _ in range(len(encoding['input_ids'][0]))]
        
        # Assign label: 0 (Caption), 1 (Section-header), 2 (Text), 3 (Title), 4 (Other)
        if category_id == 1:
            label = 0
        elif category_id == 8:
            label = 1
        elif category_id == 10:
            label = 2
        elif category_id == 11:
            label = 3
        else:
            label = 4
        
        tokenized_inputs['input_ids'].append(encoding['input_ids'][0])
        tokenized_inputs['attention_mask'].append(encoding['attention_mask'][0])
        tokenized_inputs['bbox'].append(token_bboxes)
        tokenized_inputs['labels'].append(label)
        tokenized_inputs['category_ids'].append(category_id)
    
    return tokenized_inputs

def process_dataset(dataset, num_samples=100):
    processed_data = {
        "texts": [],
        "bboxes": [],
        "category_ids": []
    }
    for i, sample in enumerate(dataset.take(num_samples)):
        try:
            result = preprocess_sample(sample)
            processed_data['texts'].extend(result['texts'])
            processed_data['bboxes'].extend(result['bboxes'])
            processed_data['category_ids'].extend(result['category_ids'])
        except Exception as e:
            logger.error(f"Error processing sample {i}: {str(e)}")
            continue
    return tokenize_and_align_labels(processed_data)

def create_dataset(tokenized_data):
    return Dataset.from_dict({
        'input_ids': tokenized_data['input_ids'],
        'attention_mask': tokenized_data['attention_mask'],
        'bbox': tokenized_data['bbox'],
        'labels': tokenized_data['labels'],
        'category_ids': tokenized_data['category_ids']
    })

In [7]:
p = {
        "texts": [],
        "bboxes": [],
        "category_ids": []
    }

for i, sample in enumerate(ds['train'].take(2)):
    try:
        result = preprocess_sample(sample)
        print(result)
    except Exception as e:
        logger.error(f"Error processing sample {i}: {str(e)}")
        continue

{'texts': ['NOTES TO THE FINANCIAL STATEMENTS', 'Finance receivables that originated outside the U.S. were $52.7 billion and $47.5 billion at December 31, 2004 and 2003,', 'respectively. Other finance receivables consisted primarily of real estate, commercial and other collateralized loans and', 'accrued interest.', 'Included in net finance and other receivables at December 31, 2004 and 2003 were $16.9 billion and $14.3 billion,', 'respectively, of receivables that have been sold for legal purposes to consolidated securitization SPEs and are available only', 'for repayment of debt issued by those entities, and to pay other securitization investors and other participants; they are not', 'available to pay our other obligations or the claims of our other creditors.', 'Future maturities, exclusive of the effects of SFAS No. 133,', 'Accounting for Derivative Instruments and Hedging Activities', ', of', 'total finance receivables including minimum lease rentals are as follows (in billions): 

In [8]:
p['texts'].extend(result['texts'])
p['bboxes'].extend(result['bboxes'])
p['category_ids'].extend(result['category_ids'])
p

{'texts': ['attention to detail and ability to recognize',
  'what makes games compelling.',
  'While the success of the',
  'Grand Theft',
  'Auto',
  'franchise is extremely rewarding,',
  'creating a blockbuster of this magnitude',
  'also affords Take-Two an invaluable base',
  'of knowledge and expertise. During fiscal',
  '2003, Take-Two took significant steps to',
  'share and leverage internal resources and',
  'experiences to create a more integrated',
  'and seamless publishing operation. Our',
  'Rockstar Games, Gathering and Global',
  'Star Software publishing labels have been',
  'streamlined to tap Rockstar’s knowledge,',
  'Rockstar’s unique market position is com-',
  'plemented by Gathering’s focus on pub-',
  'lishing premium and mid-priced products',
  'on PC, console and handheld platforms.',
  'We firmly believe demand for our premi-',
  'um priced games such as',
  'Grand Theft',
  'Auto, Max Payne, Midnight Club,',
  'Manhunt,',
  'and',
  'Mafia',
  'will conti

In [9]:
model = LayoutLMv3ForSequenceClassification.from_pretrained(
    "microsoft/layoutlmv3-base",
    num_labels=5  # For Caption, Section-header, Text, Title, Other
)

Some weights of LayoutLMv3ForSequenceClassification were not initialized from the model checkpoint at microsoft/layoutlmv3-base and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [10]:
t1 = process_dataset(ds['train'], num_samples=200)
v1 = process_dataset(ds['val'], num_samples=50)



In [11]:
print(v1.keys(), t1.keys(), sep='\n')

dict_keys(['input_ids', 'attention_mask', 'bbox', 'labels', 'category_ids'])
dict_keys(['input_ids', 'attention_mask', 'bbox', 'labels', 'category_ids'])


In [12]:
train_dataset = create_dataset(t1)
val_dataset = create_dataset(v1)

In [13]:
print(torch.__version__)
print(torch.cuda.is_available())
print(torch.cuda.get_device_name(0) if torch.cuda.is_available() else "No GPU")

2.7.1+cu128
True
NVIDIA GeForce RTX 3050 6GB Laptop GPU


In [14]:
len(train_dataset)

20177

In [15]:
training_args = TrainingArguments(
    output_dir="./layoutlmv3-headings",
    num_train_epochs=5,
    per_device_train_batch_size=4,  # Further reduced for 6GB GPU
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=8,  # Effective batch size = 2 * 4 = 8
    eval_strategy="epoch",
    save_strategy="epoch",
    learning_rate=5e-5,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    logging_dir="./logs",
    logging_steps=10,
    fp16=torch.cuda.is_available(),
    max_grad_norm=1.0,
    save_total_limit=2,
    remove_unused_columns=False,
)

In [16]:
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    class_labels = [0, 1, 2, 3, 4]  # All possible class indices
    return {
        "accuracy": accuracy_score(labels, predictions),
        "classification_report": classification_report(
            labels,
            predictions,
            labels=class_labels,  # <-- Add this line
            target_names=["Caption", "Section-header", "Text", "Title", "Other"],
            output_dict=True
        )
    }

# Custom Trainer with class weights
class WeightedTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        labels = inputs.pop("labels")
        # Remove keys not expected by the model
        if "category_ids" in inputs:
            inputs.pop("category_ids")
        outputs = model(**inputs)
        logits = outputs.logits
        loss_fct = nn.CrossEntropyLoss(weight=class_weights)
        loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1))
        return (loss, outputs) if return_outputs else loss

# Compute class weights
label_counts = Counter(train_dataset['labels'])
total = sum(label_counts.values())
print("Label counts:", dict(label_counts))

# Avoid division by zero by adding a small constant
class_weights = torch.tensor(
    [total / (len(label_counts) * (label_counts.get(i, 0) + 1e-6)) for i in range(5)],
    dtype=torch.float
).to('cuda' if torch.cuda.is_available() else 'cpu')
print("Class weights:", class_weights)

# Initialize Trainer
trainer = WeightedTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics,
)

# Train the model
trainer.train()

Label counts: {4: 11407, 2: 8098, 1: 546, 0: 83, 3: 43}
Class weights: tensor([48.6193,  7.3908,  0.4983, 93.8465,  0.3538], device='cuda:0')




Epoch,Training Loss,Validation Loss,Accuracy,Classification Report
0,0.4925,0.394606,0.876201,"{'Caption': {'precision': 0.0, 'recall': 0.0, 'f1-score': 0.0, 'support': 93.0}, 'Section-header': {'precision': 0.5357142857142857, 'recall': 0.17543859649122806, 'f1-score': 0.2643171806167401, 'support': 171.0}, 'Text': {'precision': 0.790187217559716, 'recall': 0.8583450210378681, 'f1-score': 0.8228571428571428, 'support': 1426.0}, 'Title': {'precision': 0.0, 'recall': 0.0, 'f1-score': 0.0, 'support': 3.0}, 'Other': {'precision': 0.916206987128973, 'recall': 0.9378865286367303, 'f1-score': 0.9269200106298167, 'support': 3719.0}, 'accuracy': 0.8762010347376201, 'macro avg': {'precision': 0.44842169808059495, 'recall': 0.3943340292331653, 'f1-score': 0.40281886682073986, 'support': 5412.0}, 'weighted avg': {'precision': 0.854727993390604, 'recall': 0.8762010347376201, 'f1-score': 0.8621226982875159, 'support': 5412.0}}"
1,0.334,0.452233,0.872875,"{'Caption': {'precision': 0.0, 'recall': 0.0, 'f1-score': 0.0, 'support': 93.0}, 'Section-header': {'precision': 0.5319148936170213, 'recall': 0.29239766081871343, 'f1-score': 0.37735849056603776, 'support': 171.0}, 'Text': {'precision': 0.7625075346594334, 'recall': 0.8870967741935484, 'f1-score': 0.820097244732577, 'support': 1426.0}, 'Title': {'precision': 0.0, 'recall': 0.0, 'f1-score': 0.0, 'support': 3.0}, 'Other': {'precision': 0.9316753211259907, 'recall': 0.9166442592094649, 'f1-score': 0.9240986717267552, 'support': 3719.0}, 'accuracy': 0.8728750923872876, 'macro avg': {'precision': 0.44521954988048906, 'recall': 0.41922773884434533, 'f1-score': 0.4243108814050739, 'support': 5412.0}, 'weighted avg': {'precision': 0.8579441445861828, 'recall': 0.8728750923872876, 'f1-score': 0.8630284429096913, 'support': 5412.0}}"
2,0.2258,0.497221,0.883962,"{'Caption': {'precision': 0.0, 'recall': 0.0, 'f1-score': 0.0, 'support': 93.0}, 'Section-header': {'precision': 0.5952380952380952, 'recall': 0.29239766081871343, 'f1-score': 0.39215686274509803, 'support': 171.0}, 'Text': {'precision': 0.816408876933423, 'recall': 0.8513323983169705, 'f1-score': 0.8335049776862341, 'support': 1426.0}, 'Title': {'precision': 0.0, 'recall': 0.0, 'f1-score': 0.0, 'support': 3.0}, 'Other': {'precision': 0.9164280135381411, 'recall': 0.9464909922022049, 'f1-score': 0.9312169312169312, 'support': 3719.0}, 'accuracy': 0.8839615668883961, 'macro avg': {'precision': 0.46561499714193183, 'recall': 0.4180442102675778, 'f1-score': 0.4313757543296527, 'support': 5412.0}, 'weighted avg': {'precision': 0.8636697256358318, 'recall': 0.8839615668883961, 'f1-score': 0.871920304675859, 'support': 5412.0}}"
3,0.1662,0.597988,0.878049,"{'Caption': {'precision': 0.2857142857142857, 'recall': 0.021505376344086023, 'f1-score': 0.04, 'support': 93.0}, 'Section-header': {'precision': 0.5625, 'recall': 0.3157894736842105, 'f1-score': 0.4044943820224719, 'support': 171.0}, 'Text': {'precision': 0.8213538032100488, 'recall': 0.8253856942496494, 'f1-score': 0.8233648128716334, 'support': 1426.0}, 'Title': {'precision': 0.0, 'recall': 0.0, 'f1-score': 0.0, 'support': 3.0}, 'Other': {'precision': 0.9081290322580645, 'recall': 0.9462221027157838, 'f1-score': 0.926784303397419, 'support': 3719.0}, 'accuracy': 0.8780487804878049, 'macro avg': {'precision': 0.5155394242364798, 'recall': 0.4217805293987459, 'f1-score': 0.43892869965830494, 'support': 5412.0}, 'weighted avg': {'precision': 0.8631451077081855, 'recall': 0.8780487804878049, 'f1-score': 0.8672796723606417, 'support': 5412.0}}"
4,0.0803,0.660009,0.878234,"{'Caption': {'precision': 0.3076923076923077, 'recall': 0.043010752688172046, 'f1-score': 0.07547169811320754, 'support': 93.0}, 'Section-header': {'precision': 0.5714285714285714, 'recall': 0.21052631578947367, 'f1-score': 0.3076923076923077, 'support': 171.0}, 'Text': {'precision': 0.790920716112532, 'recall': 0.8674614305750351, 'f1-score': 0.8274247491638796, 'support': 1426.0}, 'Title': {'precision': 0.0, 'recall': 0.0, 'f1-score': 0.0, 'support': 3.0}, 'Other': {'precision': 0.9217714134181915, 'recall': 0.9346598547996773, 'f1-score': 0.9281708945260347, 'support': 3719.0}, 'accuracy': 0.8782335550628233, 'macro avg': {'precision': 0.5183626017303206, 'recall': 0.41113167077047164, 'f1-score': 0.42775192989908595, 'support': 5412.0}, 'weighted avg': {'precision': 0.8651608458995557, 'recall': 0.8782335550628233, 'f1-score': 0.8668531968939261, 'support': 5412.0}}"


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
Trainer is attempting to log a value of "{'Caption': {'precision': 0.0, 'recall': 0.0, 'f1-score': 0.0, 'support': 93.0}, 'Section-header': {'precision': 0.5357142857142857, 'recall': 0.17543859649122806, 'f1-score': 0.2643171806167401, 'support': 171.0}, 'Text': {'precision': 0.790187217559716, 'recall': 0.8583450210378681, 'f1-score': 0.8228571428571428, 'support': 1426.0}, 'Title': {'precision': 0.0, 'recall': 0.0, 'f1-score': 0.0, 'support': 3.0}, 'Other': {'precision': 0.916206987128973, 'recall': 0.9378865286367303, 'f1-score': 0.9269200106298167, 'support': 3719.0}, 'accuracy': 0.8762010347376201, 'macro avg': {'precision': 0.44842169808059495, 'recall': 0.3943340292331653, 'f1-score': 0.40281886682073986, 'support': 5412.0}, 'weighted avg': {'precision': 0.854727

TrainOutput(global_step=3150, training_loss=0.3009083567630677, metrics={'train_runtime': 12150.3394, 'train_samples_per_second': 8.303, 'train_steps_per_second': 0.259, 'total_flos': 2.675807028436992e+16, 'train_loss': 0.3009083567630677, 'epoch': 4.999008919722497})

In [None]:
def test_on_document(trainer, dataset, num_samples=1):
    # Process one test document
    test_processed_data = process_dataset(dataset['test'], num_samples=num_samples)
    
    # Check if processed data is empty
    if not test_processed_data['labels']:
        logger.error("No valid texts in test_processed_data, cannot proceed")
        return [], {'accuracy': 0.0, 'classification_report': {}}
    
    test_dataset = create_dataset(test_processed_data)
    
    # Verify dataset structure
    logger.info("Verifying test_dataset structure")
    print("Sample test_dataset entry:", test_dataset[0])
    
    # Debug texts
    logger.info("Inspecting test_processed_data['labels']")
    print("First few texts:", test_processed_data['labels'][:5])
    
    # Get predictions
    predictions = trainer.predict(test_dataset)
    logits = predictions.predictions
    predicted_labels = np.argmax(logits, axis=-1)
    
    # Ground truth labels
    ground_truth_labels = test_dataset['labels']
    ground_truth_category_ids = test_dataset['category_ids']
    
    # Map label indices to names
    label_map = {0: "Caption", 1: "Section-header", 2: "Text", 3: "Title", 4: "Other"}
    predicted_labels_named = [label_map[label] for label in predicted_labels]
    ground_truth_labels_named = [label_map[label] for label in ground_truth_labels]
    
    # Extract texts and handle non-string cases
    texts = test_processed_data['labels']
    formatted_texts = []
    for text in texts:
        if not isinstance(text, str):
            logger.warning(f"Non-string text found: {text}, converting to string")
            text = str(text)
        formatted_texts.append(text)
    
    # Print results
    print("\nTest Document Results:")
    print(f"{'Text':<60} {'Predicted':<15} {'Ground Truth':<15} {'Category ID':<10}")
    print("-" * 100)
    for text, pred, gt, cat_id in zip(formatted_texts, predicted_labels_named, ground_truth_labels_named, ground_truth_category_ids):
        print(f"{text[:57]:<60} {pred:<15} {gt:<15} {cat_id:<10}")
    
    # Compute metrics
    metrics = compute_metrics((logits, ground_truth_labels))
    print("\nMetrics on Test Document:")
    print(f"Accuracy: {metrics['accuracy']:.4f}")
    print("Classification Report:")
    for label, scores in metrics['classification_report'].items():
        if isinstance(scores, dict):
            print(f"{label}:")
            print(f"  Precision: {scores['precision']:.4f}")
            print(f"  Recall: {scores['recall']:.4f}")
            print(f"  F1-Score: {scores['f1-score']:.4f}")
            print(f"  Support: {scores['support']}")
    
    # Save results to JSON
    results = []
    for text, pred, gt, cat_id in zip(formatted_texts, predicted_labels_named, ground_truth_labels_named, ground_truth_category_ids):
        results.append({
            "text": text,
            "predicted_label": pred,
            "ground_truth_label": gt,
            "category_id": int(cat_id)
        })
    
    output_file = "./test_predictions.json"
    with open(output_file, "w") as f:
        json.dump(results, f, indent=4)
    print(f"\nResults saved to {output_file}")
    
    return results, metrics

# Test on one document from the test split
results, metrics = test_on_document(trainer, ds, num_samples=1)
print(results, metrics)

NameError: name 'trainer' is not defined