# LayoutLMv3 for Token Classification

Hugging Face model references:
1. https://huggingface.co/docs/transformers/model_doc/layoutlmv3
2. https://huggingface.co/microsoft/layoutlmv3-base

Model Paper: https://arxiv.org/pdf/2204.08387

### We'll start by importing required libraries

In [None]:
import os
import psutil
import json
from datetime import date, datetime
import time
from PIL import Image

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import torch
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
import torch.nn as nn
from torch.optim.lr_scheduler import ReduceLROnPlateau

from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import precision_recall_fscore_support, accuracy_score

from transformers import LayoutLMv3ForTokenClassification, LayoutLMv3Processor, LayoutLMv3Config
from transformers import get_linear_schedule_with_warmup

In [None]:
# Run timestamps: `dat` is today's date, `dat_tim` a date-time string embedded in output file names.
# NOTE(review): `dat_tim` contains ':' characters, which are not valid in Windows file names.
dat = str(date.today())
dat_tim = datetime.today().strftime("%Y-%m-%d_%H:%M:%S")

# GPU details; every value stays None when CUDA is not available.
gpu_available = torch.cuda.is_available()
gpu_count = gpu_variant = one_gpu_size = None
if gpu_available:
    gpu_count = torch.cuda.device_count()
    gpu_variant = torch.cuda.get_device_name(0)
    first_gpu_props = torch.cuda.get_device_properties(0)
    one_gpu_size = round(first_gpu_props.total_memory / 1e9, 2)  # decimal GB of GPU 0

# Host RAM expressed in GiB (1024 ** 3 bytes).
total_ram = psutil.virtual_memory().total
total_ram_gb = round(total_ram / (1024 ** 3), 2)

# Summary stored in the training metadata.
hardware_specs = dict(
    vm_ram_memory_gb=total_ram_gb,
    gpu_available=gpu_available,
    total_gpu_count=gpu_count,
    gpu_variant=gpu_variant,
    one_gpu_size_gb=one_gpu_size,
)
print(hardware_specs)

In [None]:
# Number of this fine-tuning run; it appears in output file names and in the metadata.
finetune_count = 9

# Locations of the pretrained processor and model (placeholders - point them at local copies).
processor_path = '/local/folder/path/where/model/is/stored'
model_name = '/local/folder/path/where/model/is/stored'

# Directory for checkpoints, plots and metadata; the training loop writes to `output_dir`.
model_op_dir_path = "/output/directory/path/to/store/weights/file"
output_dir = model_op_dir_path
os.makedirs(output_dir, exist_ok=True)

# Input data locations.
data_folder_path = "/path/where/image/data/is/stored"
current_working_directory_path = "/current/working/directory/path"
other_supporting_data_path = "/path/of/json/file/"

# Folder names inside each dated data folder (NERDataset reads images from `image_folder`);
# the spelling must match the folders on disk.
image_folder, json_folder = "image_seperated_pdfs", "page_seperated_json"
# Suffix of the field/label mapping file: field_and_labels_mappings_<data_date_range>.json
data_date_range = ""

In [None]:
# Define class for storing metadata about training
class TrainingMetadata:
    """
    Mutable record describing one fine-tuning run, persisted as ``finetune_metadata.json``.

    Every instance attribute is written to the JSON file by ``save_training_metadata``.
    The constructor fills placeholders (-1 / None / empty) that the notebook overwrites.
    """

    def __init__(self, date, model_op_dir_path):
        # Output location and run identity
        self.model_op_dir_path = model_op_dir_path
        self.finetune_count = -1
        self.model_name = ""
        self.date = date
        self.hardware_info = {}
        # Dataset loading timings
        self.dataset_load_start_time = None
        self.dataset_load_end_time = None
        self.dataset_load_time_required = None
        # Dataset sizes
        self.batch_size = -1
        self.validation_dataset_size = -1
        self.dataset_rows = -1
        self.total_dataset_samples = -1
        self.failed_rows_samples = -1
        self.total_train_samples = -1
        self.total_val_samples = -1
        self.total_fields = -1
        self.fields_list = []
        # Training timings and hyper-parameters
        self.train_start_time = None
        self.train_end_time = None
        self.train_time_required = None
        self.num_epochs = -1
        self.patience = -1
        self.default_learning_rate = None
        self.optimizer = "None"
        self.scheduler = "None"
        self.dropout_rate = 0.3
        self.hidden_dropout_prob = 0.3
        self.attention_probs_dropout_prob = 0.3
        self.dropout = 0.3
        # Informational only; nothing in this class applies it to a model config.
        self.hidden_size = 512
        self.visualizer_image_path = ""
        # Per-epoch summaries appended by the training loop (`training_flow` is the name it uses).
        self.training_flow = []
        self.metrics_history = {}

    def get_training_metadata(self):
        """Return a shallow dict copy of all instance attributes."""
        return dict(vars(self))

    @staticmethod
    def _json_default(value):
        """Serialise values exposing ``tolist`` (numpy arrays/scalars, torch tensors); reject others."""
        to_list = getattr(value, "tolist", None)
        if callable(to_list):
            return to_list()
        raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")

    def save_training_metadata(self):
        """
        Write all attributes to ``<model_op_dir_path>/finetune_metadata.json``.

        Errors are reported with a print rather than raised (best-effort logging).
        """
        os.makedirs(self.model_op_dir_path, exist_ok=True)
        finetune_metadata_file_path = os.path.join(self.model_op_dir_path, "finetune_metadata.json")
        try:
            # Serialise before opening the file so a bad value cannot leave a truncated JSON behind.
            payload = json.dumps(self.get_training_metadata(), indent=3, default=self._json_default)
            with open(finetune_metadata_file_path, "w") as json_file:
                json_file.write(payload)
        except Exception as e:
            print("Exception in storing metadata json file.", e)

In [None]:
# Create the run's metadata record and fill in the training configuration.
training_metadata = TrainingMetadata(date=dat, model_op_dir_path=model_op_dir_path)
run_settings = {
    "finetune_count": finetune_count,
    "model_name": "LayoutLM_V3_ForTokenClassification",
    "hardware_info": hardware_specs,
    "batch_size": 16,
    "validation_dataset_size": 0.25,  # fraction of samples held out by train_test_split
    "num_epochs": 22,
    "patience": 6,
    "default_learning_rate": 5e-5,
    "optimizer": "AdamW",
    "scheduler": "ReduceLROnPlateau",
    "hidden_dropout_prob": 0.3,
    "attention_probs_dropout_prob": 0.3,
    "dropout": 0.3,
}
for attribute, value in run_settings.items():
    setattr(training_metadata, attribute, value)

# Persist the initial metadata immediately.
training_metadata.save_training_metadata()

### Class to plot training-metric graphs

In [None]:
# Define a class for visualizing training metrics
class TrainingVisualizer:
    """
    Collect per-epoch training statistics and render them as PNG figures.

    Figures are saved in ``checkpoint_dir``; file names embed ``finetune_count`` from
    ``training_metadata`` (``None`` when no metadata is given) and the module-level
    ``dat_tim`` timestamp.
    """

    def __init__(self, checkpoint_dir='model_checkpoints', training_metadata=None):
        # Per-call tracking lists (the training loop calls log_metrics once per epoch)
        self.train_losses = []
        self.val_losses = []
        self.learning_rates = []
        self.train_metrics = []
        self.val_metrics = []
        self.gradient_norms = []        # global L2 gradient norm per log_metrics call
        self.weight_distributions = []  # summary dict (count/mean/std/min/max) per log_metrics call

        self.checkpoint_dir = checkpoint_dir
        os.makedirs(checkpoint_dir, exist_ok=True)
        self.training_metadata = training_metadata

    def _image_path(self, prefix):
        """Return ``<checkpoint_dir>/<prefix>_<finetune_count>_<dat_tim>.png``."""
        finetune_count = getattr(self.training_metadata, "finetune_count", None)
        img_name = f"{prefix}_{finetune_count}_{dat_tim}.png"
        return os.path.join(self.checkpoint_dir, img_name)

    def plot_all_metrics(self):
        """
        Plot loss curves, learning rate, a loss-correlation heatmap and gradient norms.

        Returns:
            str: path of the saved PNG.
        """
        plt.figure(figsize=(20, 15))
        plt.subplots_adjust(left=0.1, right=0.9, bottom=0.1, top=0.9, wspace=0.3, hspace=0.3)

        # Loss curves
        plt.subplot(2, 2, 1)
        plt.plot(self.train_losses, label='Training Loss')
        plt.plot(self.val_losses, label='Validation Loss')
        plt.title('Model Loss over Epochs')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.legend()

        # Learning rate
        plt.subplot(2, 2, 2)
        plt.plot(self.learning_rates)
        plt.title('Learning Rate Schedule')
        plt.xlabel('Epoch')
        plt.ylabel('Learning Rate')

        # Correlation heatmap (needs at least two epochs; otherwise every entry is NaN)
        plt.subplot(2, 2, 3)
        metrics_df = pd.DataFrame({
            'Epoch': range(1, len(self.train_losses) + 1),
            'Train Loss': self.train_losses,
            'Val Loss': self.val_losses,
        })
        if len(metrics_df) >= 2:
            sns.heatmap(metrics_df.corr(), annot=True, cmap='coolwarm', center=0)
        plt.title('Correlation Between Training Metrics')

        # Gradient norms: one value per log_metrics call, i.e. per epoch in the training loop
        plt.subplot(2, 2, 4)
        plt.plot(self.gradient_norms)
        plt.title('Gradient Norms')
        plt.xlabel('Epoch')
        plt.ylabel('Gradient Norm')

        img_path = self._image_path("model_training_stats")
        plt.tight_layout()
        plt.savefig(img_path)
        plt.close()

        return img_path

    @staticmethod
    def _weight_summary(model):
        """Return count/mean/std/min/max over all parameter values without copying them to a list."""
        count, total, total_sq = 0, 0.0, 0.0
        w_min, w_max = float('inf'), float('-inf')
        for param in model.parameters():
            values = param.detach()
            if values.numel() == 0:
                continue
            values = values.float()
            count += values.numel()
            total += values.sum().item()
            total_sq += values.pow(2).sum().item()
            w_min = min(w_min, values.min().item())
            w_max = max(w_max, values.max().item())
        if count == 0:
            return {"count": 0, "mean": None, "std": None, "min": None, "max": None}
        mean = total / count
        variance = max(total_sq / count - mean * mean, 0.0)
        return {"count": count, "mean": mean, "std": variance ** 0.5, "min": w_min, "max": w_max}

    def log_metrics(self, model, epoch, train_loss, val_loss, learning_rate, avg_train_acc, avg_val_acc, additional_metric=None):
        """
        Record one epoch of metrics for later plotting.

        Args:
            model: PyTorch model; its current ``.grad`` values and weights are summarised.
            epoch: current epoch number (not stored).
            train_loss: training loss for the epoch.
            val_loss: validation loss for the epoch.
            learning_rate: current learning rate.
            avg_train_acc, avg_val_acc: accuracies (accepted for interface compatibility; not stored).
            additional_metric: optional metric appended to ``train_metrics``.
        """
        self.train_losses.append(train_loss)
        self.val_losses.append(val_loss)
        self.learning_rates.append(learning_rate)
        if additional_metric is not None:
            self.train_metrics.append(additional_metric)

        with torch.no_grad():
            # Global L2 norm: sqrt of the sum of squared per-parameter norms.
            squared_total = 0.0
            for param in model.parameters():
                if param.grad is not None:
                    squared_total += param.grad.detach().float().norm(2).item() ** 2
            self.gradient_norms.append(squared_total ** 0.5)

            # Store a compact per-call summary instead of every weight value (which would
            # copy the entire model into a Python list on each call).
            self.weight_distributions.append(self._weight_summary(model))

    def plot_all_metrics_in_one_image(self, metrics_history):
        """
        Plot loss, accuracy, precision/recall/F1, gradient and activation curves in one PNG.

        Missing series produce empty panels instead of raising KeyError.

        Args:
            metrics_history (dict): series keyed like the training loop's history
                ('train_loss', 'val_loss', 'train_accuracy', 'val_accuracy', 'train_precision',
                'train_recall', 'train_f1', 'val_precision', 'val_recall', 'val_f1', and
                optionally 'gradients' and 'activations').

        Returns:
            str: path of the saved PNG.
        """
        history = metrics_history or {}

        fig, axes = plt.subplots(3, 2, figsize=(20, 18))
        fig.suptitle('Training and Validation Metrics', fontsize=16)

        # (axis, title, y-label, [(history key, legend label), ...])
        panels = [
            (axes[0, 0], 'Training and Validation Loss', 'Loss',
             [('train_loss', 'Train Loss'), ('val_loss', 'Validation Loss')]),
            (axes[0, 1], 'Training and Validation Accuracy', 'Accuracy',
             [('train_accuracy', 'Train Accuracy'), ('val_accuracy', 'Validation Accuracy')]),
            (axes[1, 0], 'Training Precision, Recall, and F1', 'Score',
             [('train_precision', 'Train Precision'), ('train_recall', 'Train Recall'), ('train_f1', 'Train F1')]),
            (axes[1, 1], 'Validation Precision, Recall, and F1', 'Score',
             [('val_precision', 'Validation Precision'), ('val_recall', 'Validation Recall'), ('val_f1', 'Validation F1')]),
        ]
        for ax, title, ylabel, curves in panels:
            for key, label in curves:
                ax.plot(history.get(key) or [], label=label)
            ax.set_title(title)
            ax.set_xlabel('Epoch')
            ax.set_ylabel(ylabel)
            ax.legend()
            ax.grid(True)

        # Each gradients record is a list of (parameter name, norm) pairs; plot their mean.
        if 'gradients' in history:
            avg_grad_norms = [
                sum(norm for _, norm in record) / len(record) if record else float('nan')
                for record in history['gradients'] or []
            ]
            ax = axes[2, 0]
            ax.plot(avg_grad_norms, label='Average Gradient Norm')
            ax.set_title('Average Gradient Norm')
            ax.set_xlabel('Logged Step')
            ax.set_ylabel('Gradient Norm')
            ax.legend()
            ax.grid(True)

        # Activations may be arrays or scalars; mean |value| works for both.
        if 'activations' in history:
            avg_activation_magnitudes = [float(np.mean(np.abs(act))) for act in history['activations'] or []]
            ax = axes[2, 1]
            ax.plot(avg_activation_magnitudes, label='Average Activation Magnitude')
            ax.set_title('Average Activation Magnitude per Epoch')
            ax.set_xlabel('Epoch')
            ax.set_ylabel('Activation Magnitude')
            ax.legend()
            ax.grid(True)

        img_path = self._image_path("all_metrics_plot")

        plt.tight_layout(rect=[0, 0, 1, 0.96])  # leave room for the suptitle
        plt.savefig(img_path)
        plt.close()

        print(f"All metrics plot saved to {img_path}")
        return img_path

### Create Custom Dataset

In [None]:
# Define dataset class for Named Entity Recognition (NER)
class NERDataset(Dataset):
    """
    Token-classification training dataset over page records.

    Each record in ``data`` must provide 'date', 'page_file_name', 'words', 'words_labels'
    and 'page_words_bboxes_normalized'. The page image is read from
    ``<data_folder_path>/<date>/<image_folder>/<page_file_name>.png`` (module-level settings)
    and encoded with the words, boxes and integer word labels by ``processor``.

    Args:
        data: list of page records.
        processor: LayoutLMv3 processor.
        labels_map: mapping file contents; ``labels_map["fields_to_label"]`` maps field name -> id.
        max_length: token length the processor pads/truncates to (default 512).
    """

    def __init__(self, data, processor, labels_map, max_length=512):
        self.data = data
        self.processor = processor
        self.labels_map = labels_map
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def _encode_word_labels(self, word_labels):
        """Map each word's field name to its id; raises KeyError for an unknown field."""
        fields_to_label = self.labels_map["fields_to_label"]
        return [fields_to_label[str(label)] for label in word_labels]

    def __getitem__(self, idx):
        """
        Encode one page.

        Returns:
            dict with 'input_ids', 'attention_mask', 'bbox', 'labels', 'pixel_values'
            (batch dimension removed) and 'int_labels' (word-level integer labels).
        """
        item = self.data[idx]
        img_path = os.path.join(data_folder_path, f"{item['date']}/{image_folder}/{item['page_file_name']}.png")

        # Close the file promptly; convert() returns an independent, fully loaded copy.
        with Image.open(img_path) as img:
            image = img.convert("RGB")

        int_labels = self._encode_word_labels(item['words_labels'])

        inputs = self.processor(
            image,
            item['words'],
            boxes=item['page_words_bboxes_normalized'],
            word_labels=int_labels,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=self.max_length
        )

        # squeeze(0) removes only the batch dimension added by return_tensors="pt".
        return {
            'input_ids': inputs['input_ids'].squeeze(0),
            'attention_mask': inputs['attention_mask'].squeeze(0),
            'bbox': inputs['bbox'].squeeze(0),
            'labels': inputs['labels'].squeeze(0),
            'pixel_values': inputs['pixel_values'].squeeze(0),
            'int_labels': int_labels
        }

### Custom Model Training Loop

In [None]:
def _active_tokens(predictions, labels):
    """Flatten predictions/labels, drop positions labelled -100, and move both to CPU."""
    flat_labels = labels.reshape(-1)
    keep = flat_labels != -100
    return predictions.reshape(-1)[keep].cpu(), flat_labels[keep].cpu()


def _token_metrics(pred_chunks, label_chunks):
    """Weighted precision, recall, F1 and accuracy over concatenated active tokens (zeros if none)."""
    if not label_chunks:
        return 0.0, 0.0, 0.0, 0.0
    y_true = torch.cat(label_chunks).numpy()
    y_pred = torch.cat(pred_chunks).numpy()
    if y_true.size == 0:
        return 0.0, 0.0, 0.0, 0.0
    precision, recall, f1, _ = precision_recall_fscore_support(y_true, y_pred, average='weighted', zero_division=0)
    accuracy = accuracy_score(y_true, y_pred)
    return float(precision), float(recall), float(f1), float(accuracy)


def custom_train_loop(
    model, 
    train_loader, 
    val_loader, 
    processor,
    device, 
    epochs, 
    learning_rate, 
    patience, 
    class_weights=None,
    visualizer=None,
    training_metadata=None
):
    """
    Train ``model`` on ``train_loader`` and evaluate on ``val_loader`` after every epoch.

    Training and validation use the same (optionally class-weighted) cross-entropy, so the
    two losses are comparable and the LR scheduler / early stopping track the optimised
    objective. Positions labelled -100 are ignored by the loss and by every metric.
    Precision/recall/F1/accuracy are computed per epoch over all active tokens.

    Args:
        model: token-classification model called as
            ``model(input_ids=..., bbox=..., attention_mask=..., pixel_values=...)``,
            returning an object with ``.logits`` and providing ``save_pretrained``.
        train_loader, val_loader: iterables of dict batches with keys
            'input_ids', 'attention_mask', 'bbox', 'labels', 'pixel_values'.
        processor: unused; kept for interface compatibility.
        device: device that batches are moved to.
        epochs: maximum number of epochs.
        learning_rate: initial AdamW learning rate.
        patience: ReduceLROnPlateau patience; early stopping triggers after ``2 * patience``
            epochs without validation-loss improvement.
        class_weights: optional per-class loss weights (on ``device``).
        visualizer: optional TrainingVisualizer receiving per-epoch metrics.
        training_metadata: optional TrainingMetadata; its ``training_flow`` and
            ``metrics_history`` are set to the live lists/dicts updated here.

    Returns:
        CPU copy of the state_dict with the lowest validation loss, or None if no epoch improved.
    """
    optimizer = AdamW(model.parameters(), lr=learning_rate, weight_decay=0.01)
    # `verbose` is deprecated in recent PyTorch; the learning rate is printed every epoch instead.
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=patience)

    # ignore_index defaults to -100, so masked positions do not contribute to the loss.
    criterion = torch.nn.CrossEntropyLoss(weight=class_weights)
    batch_keys = ['input_ids', 'attention_mask', 'bbox', 'labels', 'pixel_values']

    def token_loss(logits, labels):
        # One row per token position for CrossEntropyLoss.
        return criterion(logits.reshape(-1, logits.shape[-1]), labels.reshape(-1))

    best_val_loss = float('inf')
    best_epoch = None
    epochs_no_improve = 0
    best_model_state = None
    epoch = -1  # stays -1 when epochs == 0
    training_flow = []

    metrics_history = {
        'train_loss': [],
        'val_loss': [],
        'train_accuracy': [],
        'val_accuracy': [],
        'train_precision': [],
        'train_recall': [],
        'train_f1': [],
        'val_precision': [],
        'val_recall': [],
        'val_f1': [],
        'gradients': [],     # per epoch: [(param name, mean gradient norm over batches), ...]
        'class_weights': class_weights.tolist() if class_weights is not None else None,
        'activations': []    # per epoch: mean |logit| of the last training batch
    }

    # Expose the live history on the metadata so it is saved even if training is interrupted.
    if training_metadata is not None:
        training_metadata.training_flow = training_flow
        training_metadata.metrics_history = metrics_history

    for epoch in range(epochs):
        print(f"\nTraining epoch: {epoch+1}")

        # ---- Training phase ----
        model.train()
        total_train_loss, train_batches = 0.0, 0
        train_preds, train_labels = [], []
        grad_norm_sums = {}  # parameter name -> running sum of per-batch gradient norms (tensor)
        last_logits = None

        for batch in train_loader:
            input_ids, attention_mask, bbox, labels, pixel_values = [batch[key].to(device) for key in batch_keys]

            optimizer.zero_grad()
            # Labels are not passed to the model: the loss is computed with `criterion` below
            # (which applies class_weights), so the model's own loss would be wasted work.
            outputs = model(input_ids=input_ids, bbox=bbox, attention_mask=attention_mask, pixel_values=pixel_values)
            logits = outputs.logits
            loss = token_loss(logits, labels)
            loss.backward()

            for name, param in model.named_parameters():
                if param.grad is not None:
                    grad_norm_sums[name] = grad_norm_sums.get(name, 0.0) + param.grad.detach().norm()
            optimizer.step()

            total_train_loss += loss.item()
            train_batches += 1
            batch_preds, batch_labels = _active_tokens(logits.detach().argmax(dim=-1), labels)
            train_preds.append(batch_preds)
            train_labels.append(batch_labels)
            last_logits = logits.detach()

        # ---- Validation phase ----
        model.eval()
        total_val_loss, val_batches = 0.0, 0
        val_preds, val_labels = [], []

        with torch.no_grad():
            for batch in val_loader:
                input_ids, attention_mask, bbox, labels, pixel_values = [batch[key].to(device) for key in batch_keys]
                outputs = model(input_ids=input_ids, bbox=bbox, attention_mask=attention_mask, pixel_values=pixel_values)
                # Same (weighted) criterion as training so both losses are comparable.
                total_val_loss += token_loss(outputs.logits, labels).item()
                val_batches += 1
                batch_preds, batch_labels = _active_tokens(outputs.logits.argmax(dim=-1), labels)
                val_preds.append(batch_preds)
                val_labels.append(batch_labels)

        avg_train_loss = total_train_loss / max(train_batches, 1)
        avg_val_loss = total_val_loss / max(val_batches, 1)
        precision, recall, f1, avg_train_acc = _token_metrics(train_preds, train_labels)
        val_precision, val_recall, val_f1, avg_val_acc = _token_metrics(val_preds, val_labels)

        metrics_history['train_loss'].append(avg_train_loss)
        metrics_history['val_loss'].append(avg_val_loss)
        metrics_history['train_accuracy'].append(avg_train_acc)
        metrics_history['val_accuracy'].append(avg_val_acc)
        metrics_history['train_precision'].append(precision)
        metrics_history['train_recall'].append(recall)
        metrics_history['train_f1'].append(f1)
        metrics_history['val_precision'].append(val_precision)
        metrics_history['val_recall'].append(val_recall)
        metrics_history['val_f1'].append(val_f1)
        metrics_history['gradients'].append(
            [(name, float(norm_sum) / train_batches) for name, norm_sum in grad_norm_sums.items()]
        )
        # A scalar (instead of the raw logits array) keeps the history small and JSON-friendly.
        metrics_history['activations'].append(
            float(last_logits.abs().mean()) if last_logits is not None else float('nan')
        )

        current_lr = optimizer.param_groups[0]['lr']

        if visualizer is not None:
            visualizer.log_metrics(model, epoch, avg_train_loss, avg_val_loss, current_lr, avg_train_acc, avg_val_acc)

        print(f"Epoch {epoch+1}/{epochs}")
        print(f"Train Loss: {avg_train_loss:.4f}, Validation Loss: {avg_val_loss:.4f}")
        print(f"Train Accuracy: {avg_train_acc:.4f}, Validation Accuracy: {avg_val_acc:.4f}")
        print(f"Train Precision: {precision:.4f}, Train Recall: {recall:.4f}, Train F1: {f1:.4f}")
        print(f"Validation Precision: {val_precision:.4f}, Validation Recall: {val_recall:.4f}, Validation F1: {val_f1:.4f}")

        scheduler.step(avg_val_loss)

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            best_epoch = epoch
            epochs_no_improve = 0
            # state_dict() returns references to the live tensors; copy them so later
            # optimizer steps do not overwrite the "best" snapshot.
            best_model_state = {
                name: tensor.detach().to('cpu', copy=True) for name, tensor in model.state_dict().items()
            }
            if epoch > 8 and (epoch % 2 == 0 or epoch % 3 == 0 or epoch % 5 == 0):
                os.makedirs(output_dir, exist_ok=True)
                model.save_pretrained(os.path.join(output_dir, f'checkpoint-epoch-{epoch+1}'))
                print(f"Saved new best model at epoch {epoch+1} with validation loss: {best_val_loss:.4f}")
        else:
            epochs_no_improve += 1

        # Record the epoch before a possible early stop so the last epoch is not lost.
        training_flow.append({
            "epoch_count": epoch+1,
            "avg_train_loss": avg_train_loss,
            "avg_val_loss": avg_val_loss,
            "best_val_loss": round(best_val_loss, 4),
            "current_lr": current_lr,
            "train_accuracy": round(avg_train_acc * 100, 2),
            "val_accuracy": round(avg_val_acc * 100, 2)
        })

        if epochs_no_improve >= patience * 2:
            print(f"Early stopping triggered after {epoch+1} epochs")
            break

    if epoch < 0:
        print("\nNo training epochs were run.")
        return best_model_state

    print(f"\nTraining complete at epoch {epoch+1}")
    os.makedirs(output_dir, exist_ok=True)
    model.save_pretrained(os.path.join(output_dir, f'checkpoint-epoch-{epoch+1}'))
    print(f"Saved final model at epoch {epoch+1}")

    # Save the lowest-validation-loss weights; the file name keeps the 0-based epoch index convention.
    if best_model_state is not None:
        try:
            torch.save(best_model_state, os.path.join(output_dir, f'best_model_epoch_{best_epoch}.pth'))
            print(f"Saved best model weights from epoch {best_epoch+1} with validation loss: {best_val_loss:.4f}")
        except Exception as e:
            print("Exception in saving model:", e)

    return best_model_state


In [None]:
# Select the compute device once: CUDA when available, otherwise CPU.
# (The device was previously chosen and printed twice with identical logic.)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device is set to:", device)

In [None]:
# Field <-> label mapping for this data range; it defines the model's label space.
field_and_labels_mapping_file_path = os.path.join(
    other_supporting_data_path, f"field_and_labels_mappings_{data_date_range}.json"
)
with open(field_and_labels_mapping_file_path, 'r') as mapping_file:
    label_map = json.load(mapping_file)

unique_fields = label_map["unique_fields"]
training_metadata.total_fields = len(unique_fields)
training_metadata.fields_list = unique_fields
print(training_metadata.total_fields, training_metadata.fields_list)
training_metadata.save_training_metadata()

# Start the dataset-preparation timer (string for metadata, datetime for the duration).
training_metadata.dataset_load_start_time = datetime.today().strftime("%Y-%m-%d_%H:%M:%S")
start_time = datetime.today()

# Read the list of single-page training records.
single_page_json_file_data_path = os.path.join(other_supporting_data_path, "single_page_files_list_train_3.json")
with open(single_page_json_file_data_path, 'r') as records_file:
    data_list = json.load(records_file)
    

In [None]:
# Alternative: read a separately prepared validation list instead of splitting later.
# data_list_val = None
# single_page_json_file_data_path = os.path.join(other_supporting_data_path, f"single_page_files_list_val_2.json")
# with open (single_page_json_file_data_path, 'r') as f:
#     data_list_val = json.load(f)

row_count = len(data_list)
training_metadata.dataset_rows = row_count
print("training data sample rows: ", training_metadata.dataset_rows)
print("row keys: ", data_list[0].keys())

# print("val len: ", len(data_list_val))

# Words and boxes come from the data records, so the processor's built-in OCR is disabled.
processor = LayoutLMv3Processor.from_pretrained(processor_path, apply_ocr=False)

In [None]:
shuffle = True
batch_size = training_metadata.batch_size

dataset = NERDataset(data_list, processor, label_map)
# dataset_val = NERDataset(data_list_val, processor, label_map)

# Pre-encode every page once. Successful items go to `datasetlst` (model inputs) and their
# word-level integer labels to `labels_lst` (used for class weights). Pages that raise, or
# that come back with unexpected keys, are collected in `failed_files`.
expected_keys = ['input_ids', 'attention_mask', 'bbox', 'labels', 'pixel_values', 'int_labels']
labels_lst = []
datasetlst = []
failed_files = []

data_samples_len = len(data_list)
# data_samples_len = 50  # small subset for quick experiments
for i in range(data_samples_len):
    try:
        item = dataset[i]
    except Exception as e:
        # Record on a copy (data_list stays untouched) and keep the error as text so the
        # failure list stays JSON-serialisable.
        failed_files.append({**data_list[i], 'exception': repr(e)})
        continue

    if list(item.keys()) != expected_keys:
        failed_files.append({**data_list[i], 'exception': f"unexpected item keys: {list(item.keys())}"})
        continue

    labels_lst.append(item.pop("int_labels"))
    datasetlst.append(item)

training_metadata.total_dataset_samples = len(datasetlst)
training_metadata.failed_rows_samples = len(failed_files)

print("Total dataset samples for training: ", training_metadata.total_dataset_samples)
print("Total dataset failed: ", training_metadata.failed_rows_samples)
if failed_files:
    print("First failure: ", failed_files[0]['exception'])

In [None]:
# Split the encoded samples into train/validation sets (fixed seed for reproducibility).
print("train_test_split")
train_subset, val_subset = train_test_split(
    datasetlst, 
    test_size=training_metadata.validation_dataset_size, 
    random_state=42
)

# Alternative when a separately loaded validation list is used instead of a split:
# train_subset = datasetlst
# val_subset = datasetlst_val

print("train loader")
train_dataloader = DataLoader(
    train_subset, 
    batch_size=training_metadata.batch_size, 
    shuffle=shuffle
)

print("val loader")
# Validation is not shuffled: evaluation gains nothing from random order, and a fixed batch
# composition keeps batch-dependent validation statistics reproducible across epochs.
val_dataloader = DataLoader(
    val_subset, 
    batch_size=training_metadata.batch_size, 
    shuffle=False
)

# Stop the dataset-preparation timer and record sizes.
training_metadata.dataset_load_end_time = datetime.today().strftime("%Y-%m-%d_%H:%M:%S")
end_time = datetime.today()
time_difference = end_time - start_time
time_difference_minutes = time_difference.total_seconds() / 60
training_metadata.dataset_load_time_required = round(time_difference_minutes, 3)

training_metadata.total_train_samples = len(train_subset)
training_metadata.total_val_samples = len(val_subset)

training_metadata.save_training_metadata()


In [None]:
# Per-class loss weights from the word-level training labels, using sklearn's "balanced"
# heuristic (n_samples / (n_present_classes * class_count)). CrossEntropyLoss needs exactly one
# weight per model label in label-id order, so the weights are scattered into a vector of
# length `total_fields`. Label ids absent from the training labels get a neutral weight of 1.0
# instead of being dropped, which would shorten the vector and misalign every later id.
num_labels = training_metadata.total_fields
labels = np.concatenate(labels_lst)  # Flatten all labels
present_classes = np.unique(labels)
if present_classes.min() < 0 or present_classes.max() >= num_labels:
    raise ValueError(f"Label ids {present_classes.tolist()} fall outside [0, {num_labels}).")
full_class_weights = np.ones(num_labels, dtype=np.float64)
full_class_weights[present_classes] = compute_class_weight('balanced', classes=present_classes, y=labels)
class_weights = torch.tensor(full_class_weights, dtype=torch.float).to(device)

print("class weights: ", class_weights)

config = LayoutLMv3Config.from_pretrained(model_name)

config.num_labels = num_labels



In [None]:
# Apply the run's dropout settings (from training_metadata) to the model config.
config.hidden_dropout_prob = training_metadata.hidden_dropout_prob      # Dropout for hidden states
config.attention_probs_dropout_prob = training_metadata.attention_probs_dropout_prob  # Dropout for attention probabilities
# Use `dropout`, the attribute the metadata setup cell assigns, rather than `dropout_rate`,
# which is only ever left at its constructor default.
# NOTE(review): confirm LayoutLMv3 reads `config.dropout` at all; its layers may only use
# hidden_dropout_prob / attention_probs_dropout_prob (and classifier_dropout).
config.dropout = training_metadata.dropout



In [None]:
# Token-classification model from `model_name` with the customised config, placed on `device`.
model = LayoutLMv3ForTokenClassification.from_pretrained(model_name, config=config).to(device)
print("Model Device is set to:", model.device)

# Plots are written to the same output directory as the checkpoints.
visualizer = TrainingVisualizer(checkpoint_dir=model_op_dir_path, training_metadata=training_metadata)

# Start the training timer (string for metadata, datetime for the duration).
training_metadata.train_start_time = datetime.today().strftime("%Y-%m-%d_%H:%M:%S")
start_time = datetime.today()

In [None]:
# Final model training
best_model_state = custom_train_loop(
    model=model,
    train_loader=train_dataloader,
    val_loader=val_dataloader,
    processor=processor,
    device=device,
    epochs=training_metadata.num_epochs,
    learning_rate=training_metadata.default_learning_rate,
    patience=training_metadata.patience,
    class_weights=class_weights,
    visualizer=visualizer,
    training_metadata=training_metadata
)

# Record training time and render the figures.
training_metadata.train_end_time = datetime.today().strftime("%Y-%m-%d_%H:%M:%S")
end_time = datetime.today()
training_metadata.train_time_required = round((end_time - start_time).total_seconds() / 60, 3)
training_metadata.visualizer_image_path = visualizer.plot_all_metrics()
# metrics_history stays empty unless the training loop stored its history on the metadata,
# and the combined plot expects the per-epoch series keys; skip it when nothing was recorded.
if training_metadata.metrics_history:
    visualizer.plot_all_metrics_in_one_image(training_metadata.metrics_history)
else:
    print("No metrics_history recorded on training_metadata; skipping combined metrics plot.")
training_metadata.save_training_metadata()


In [None]:
##### inference
import os
import torch
import json
import re
from PIL import Image
from torch.utils.data import Dataset
from transformers import LayoutLMv3ForTokenClassification, LayoutLMv3Processor

# Inference locations (machine-specific; adjust before running).
processor_model_name = '/home/jovyan/work/Sagar/layoutlmv3_base'
model_name = "/home/jovyan/work/Sagar/layoutlmv3_kvp_playground/layoutlmv3_kvp_fine_tuned/finetune_6_2025-01-16/checkpoint-epoch-20"
data_folder_path = "/home/jovyan/work/Sagar/p2p_august_5_to_11_data"
# JSON mappings and record lists live in a sub-folder of the data folder.
other_supporting_data_path = os.path.join(data_folder_path, "other_supporting_data")
image_folder = "image_seperated_pdfs"
data_date_range = "5_to_11_august"

# Inference dataset: encodes one page per item and keeps what is needed to map
# token-level predictions back to the original words.
class NERDataset(Dataset):
    """
    Page records -> processor encodings plus bookkeeping for inference.

    Each record must provide 'date', 'page_file_name', 'words', 'words_labels',
    'page_words_bboxes_normalized' and 'field_create'. The page image is read from
    ``<data_folder_path>/<date>/<image_folder>/<page_file_name>.png`` (module-level settings).

    Args:
        data: list of page records.
        processor: LayoutLMv3 processor (created below with apply_ocr=False).
        labels_map: mapping file contents; ``labels_map["fields_to_label"]`` maps field -> id.
        normalize_words: when True, strip non-alphanumeric characters and lower-case each word
            before encoding. Defaults to False because the training-time dataset in this
            notebook encodes the raw words, and inference inputs should match training inputs.
        max_length: token length the processor pads/truncates to.
    """

    def __init__(self, data, processor, labels_map, normalize_words=False, max_length=512):
        self.data = data
        self.processor = processor
        self.labels_map = labels_map
        self.normalize_words = normalize_words
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        img_path = os.path.join(data_folder_path, f"{item['date']}/{image_folder}/{item['page_file_name']}.png")
        # Close the file promptly; convert() returns an independent, fully loaded copy.
        with Image.open(img_path) as img:
            image = img.convert("RGB")

        # Unknown field names map to 0.
        fields_to_label = self.labels_map["fields_to_label"]
        int_labels = [fields_to_label.get(str(lbl), 0) for lbl in item['words_labels']]

        words = item['words']
        if self.normalize_words:
            words = [re.sub(r'[^A-Za-z0-9 ]+', '', w).lower() for w in words]

        inputs = self.processor(
            image, words, boxes=item['page_words_bboxes_normalized'],
            return_tensors="pt", padding="max_length", truncation=True, max_length=self.max_length
        )

        return {
            # Unbatched tensors (the leading batch dimension of size 1 is removed).
            "encodings": {k: v.squeeze(0) for k, v in inputs.items()},
            # Token position -> index into `words` (None for special/padding tokens).
            # NOTE(review): requires the fast tokenizer, which from_pretrained loads by default.
            "word_ids": inputs.word_ids(0),
            "image": img_path,
            "words": item['words'],
            "words_bboxes": item['page_words_bboxes_normalized'],
            "int_labels": int_labels,
            "field_create": item["field_create"]
        }

# Load data
with open(os.path.join(other_supporting_data_path, "single_page_files_list_test_3.json"), 'r') as f:
    data_list = json.load(f)

with open(os.path.join(other_supporting_data_path, f"field_and_labels_mappings_{data_date_range}.json"), 'r') as f:
    label_map = json.load(f)

# Initialize processor and model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
processor = LayoutLMv3Processor.from_pretrained(processor_model_name, apply_ocr=False)
model = LayoutLMv3ForTokenClassification.from_pretrained(model_name).to(device)

# Load model checkpoint (torch.load unpickles; only load trusted checkpoints).
checkpoint_path = os.path.join(model_name, "best_model_epoch_19.pth")
checkpoint = torch.load(checkpoint_path, map_location=device)
# strict=False tolerates key mismatches; report them instead of ignoring them silently.
load_result = model.load_state_dict(checkpoint, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print("Checkpoint key mismatch - missing:", load_result.missing_keys,
          "unexpected:", load_result.unexpected_keys)
model.eval()

# Create dataset
dataset = NERDataset(data_list, processor, label_map)


def inference(item, model, processor, device=device, label_map=label_map):
    """
    Predict field values for one dataset item.

    Token predictions are mapped back to words through ``item["word_ids"]``. Each word takes
    the prediction of its first sub-token, and special/padding tokens are skipped. Words
    predicted as label 0 are not reported (label 0 is treated as "no field", as before).
    Words predicted as the same field are joined with single spaces in word order.

    Args:
        item: an element of NERDataset (unbatched encodings).
        model: token-classification model.
        processor: unused; kept for interface compatibility.
        device: device to run on.
        label_map: mapping file contents; ``label_map["label_to_field"]`` maps str(id) -> field.

    Returns:
        dict with 'image' (path), 'predictions' ({field: text}) and 'actual' (item['field_create']).
    """
    model.eval()
    # Dataset items are unbatched; the model expects a leading batch dimension.
    inputs = {k: v.unsqueeze(0).to(device) for k, v in item["encodings"].items()}
    with torch.no_grad():
        outputs = model(**inputs)

    predictions = outputs.logits.argmax(-1)[0].tolist()  # one label id per token position

    label_to_field = label_map["label_to_field"]
    words = item["words"]
    field_words = {}
    previous_word = None
    for token_pos, word_idx in enumerate(item["word_ids"]):
        # Skip special/padding tokens and every sub-token after a word's first one.
        if word_idx is None or word_idx == previous_word:
            continue
        previous_word = word_idx
        label_id = predictions[token_pos]
        if label_id != 0:
            field = label_to_field.get(str(label_id), "Unknown")
            field_words.setdefault(field, []).append(words[word_idx])

    extracted_fields = {field: " ".join(field_text) for field, field_text in field_words.items()}

    return {
        "image": item["image"],
        "predictions": extracted_fields,
        "actual": item["field_create"]
    }

# Run inference on first item
y = inference(dataset[0], model, processor)
print("Actual:", y["actual"])
print("Predictions:", y["predictions"])