In [1]:
import os
import gc
import shutil
import pandas as pd
import numpy as np
import editdistance
import time 
import warnings
from PIL import Image
from collections import Counter

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

# For system resource checks
def check_system_resources(model_name="microsoft/trocr-small-handwritten"):
    """Check system resources to ensure we have enough space for TrOCR"""
    try:
        # Disk space check
        total, used, free = shutil.disk_usage('/')
        free_gb = free / (1024**3)  # Convert bytes to GB
        
        # Estimate model size based on variant
        estimated_size_mb = {
            "small": 300,    # ~300 MB for trocr-small
            "base": 600,     # ~600 MB for trocr-base
            "large": 1200    # ~1.2 GB for trocr-large
        }
        
        # Determine model size category
        size_category = 'small' if 'small' in model_name.lower() else \
                        'base' if 'base' in model_name.lower() else \
                        'large'
        
        model_size_mb = estimated_size_mb.get(size_category, 300)
        model_size_gb = model_size_mb / 1024
        
        # GPU memory check
        if torch.cuda.is_available():
            gpu_info = torch.cuda.get_device_properties(0)
            gpu_name = gpu_info.name
            total_gpu_memory_mb = gpu_info.total_memory / (1024 * 1024)
            
            # Get current free memory
            torch.cuda.empty_cache()
            free_gpu_memory_mb = torch.cuda.memory_reserved(0) / (1024 * 1024)
            
            # Estimate required memory (very rough estimate)
            required_gpu_mb = model_size_mb * 4  # Model size x4 for training overhead
        else:
            gpu_name = "No GPU detected"
            total_gpu_memory_mb = 0
            free_gpu_memory_mb = 0
            required_gpu_mb = model_size_mb * 3
        
        # Check if system meets requirements
        disk_sufficient = free_gb > (model_size_gb * 2)  # Need extra space for cache
        gpu_sufficient = total_gpu_memory_mb > required_gpu_mb if torch.cuda.is_available() else False
        
        resources = {
            "free_disk_gb": free_gb,
            "required_disk_gb": model_size_gb * 2,
            "disk_sufficient": disk_sufficient,
            "gpu_name": gpu_name,
            "total_gpu_memory_mb": total_gpu_memory_mb,
            "required_gpu_mb": required_gpu_mb,
            "gpu_sufficient": gpu_sufficient
        }
        
        # Print results
        print("\n===== System Resource Check =====")
        print(f"Model: {model_name}")
        print(f"Free Disk Space: {free_gb:.2f} GB (Need: {model_size_gb * 2:.2f} GB)")
        print(f"GPU: {gpu_name}")
        print(f"GPU Memory: {total_gpu_memory_mb:.2f} MB (Need: {required_gpu_mb:.2f} MB)")
        print(f"Disk Space: {'SUFFICIENT' if disk_sufficient else 'INSUFFICIENT'}")
        print(f"GPU Memory: {'SUFFICIENT' if gpu_sufficient else 'INSUFFICIENT'}")
        
        if not disk_sufficient or not gpu_sufficient:
            print("\n⚠️ WARNING: System resources may be insufficient for TrOCR model")
            if not disk_sufficient:
                print(f"  - Need {model_size_gb * 2:.2f} GB disk space, but only have {free_gb:.2f} GB")
            if not gpu_sufficient and torch.cuda.is_available():
                print(f"  - Need {required_gpu_mb:.2f} MB GPU memory, but only have {total_gpu_memory_mb:.2f} MB")
        
        return resources
        
    except Exception as e:
        print(f"Error checking system resources: {e}")
        return None

# Clean up function to free memory
def cleanup():
    """Clean up memory and GPU cache"""
    # Clean up global variables if they exist
    for var in ['model', 'processor', 'tokenizer', 'feature_extractor']:
        if var in globals():
            del globals()[var]
    
    # Run garbage collection
    gc.collect()
    
    # Clean CUDA cache if available
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    
    print("Memory cleaned up")

# Set up device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Check system resources before proceeding
resources = check_system_resources("microsoft/trocr-small-handwritten")
if resources and (not resources["disk_sufficient"] or not resources["gpu_sufficient"]):
    raise RuntimeError("System resources insufficient for TrOCR. See warnings above.")

# Base directory and paths
base_dir = os.getcwd()
ground_truth_path = os.path.join(base_dir, 'balinese_transliteration_train.txt') 
images_dir = os.path.join(base_dir, 'balinese_word_train')

# Data loading function
def load_data():
    """Load and prepare Balinese transliteration data"""
    filenames = []
    labels = []

    with open(ground_truth_path, 'r', encoding='utf-8') as file:
        for line in file:
            line = line.strip()
            if line:  # Ensure the line is not empty
                parts = line.split(';')
                if len(parts) == 2:
                    filename, label = parts
                    label = label.lower()
                    filenames.append(filename)
                    labels.append(label)
                else:
                    print(f"Skipping malformed line: {line}")

    data = pd.DataFrame({
        'filename': filenames,
        'label': labels
    })
    
    print(f"Loaded {len(data)} image-text pairs")
    
    # Calculate label statistics
    label_counts = data['label'].value_counts()
    print(f"Unique labels: {len(label_counts)}")
    
    return data

# Load the data
data = load_data()

# Custom split function for train/val/test
def custom_split(df, rare_label_threshold=3, val_size=0.1, test_size=0.1, random_state=42):
    """
    Split data ensuring rare words are only in training set
    
    Args:
        df: DataFrame with 'label' column
        rare_label_threshold: Words appearing fewer than this many times are considered rare
        val_size: Proportion for validation set
        test_size: Proportion for test set
        random_state: Random seed for reproducibility
    
    Returns:
        train_df, val_df, test_df
    """
    from sklearn.model_selection import train_test_split
    
    # Identify rare labels
    label_counts = df['label'].value_counts()
    rare_labels = label_counts[label_counts < rare_label_threshold].index
    
    # Separate rare and common words
    rare_df = df[df['label'].isin(rare_labels)]
    common_df = df[~df['label'].isin(rare_labels)]
    
    # First split common words into train+val and test
    common_trainval, test_df = train_test_split(
        common_df, 
        test_size=test_size, 
        random_state=random_state,
        stratify=common_df['label'] if len(common_df) > 100 else None
    )
    
    # Split train+val into train and val
    train_common, val_df = train_test_split(
        common_trainval,
        test_size=val_size / (1 - test_size),  # Adjust to get correct proportion
        random_state=random_state,
        stratify=common_trainval['label'] if len(common_trainval) > 100 else None
    )
    
    # Add rare words to train set
    train_df = pd.concat([train_common, rare_df], ignore_index=True)
    
    # Shuffle all sets
    train_df = train_df.sample(frac=1, random_state=random_state).reset_index(drop=True)
    val_df = val_df.reset_index(drop=True)
    test_df = test_df.reset_index(drop=True)
    
    print(f"Training size: {len(train_df)}; Validation size: {len(val_df)}; Test size: {len(test_df)}")
    
    return train_df, val_df, test_df

# Split the data
train_data, val_data, test_data = custom_split(data)

# Create a fixed collate function for batching
def collate_fn(batch):
    """Custom collate function to handle variable length inputs"""
    pixel_values = torch.stack([item["pixel_values"] for item in batch])
    labels = torch.stack([item["labels"] for item in batch])
    return {"pixel_values": pixel_values, "labels": labels}

# Create dataset class for TrOCR - CORRECTED VERSION
class BalineseOCRDataset(Dataset):
    def __init__(self, df, images_dir, processor=None, max_target_length=128):
        self.data = df.reset_index(drop=True)
        self.images_dir = images_dir
        self.processor = processor
        self.max_target_length = max_target_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img_name = self.data.loc[idx, 'filename']
        text = self.data.loc[idx, 'label']  # Raw text, not encoded
        
        img_path = os.path.join(self.images_dir, img_name)
        # Load as PIL image - don't transform it here
        image = Image.open(img_path).convert('RGB')
        
        if self.processor:
            # Process image and text separately
            pixel_values = self.processor.image_processor(image, return_tensors="pt").pixel_values
            labels = self.processor.tokenizer(
                text,
                padding="max_length",
                max_length=self.max_target_length,
                truncation=True,
                return_tensors="pt"
            ).input_ids
            
            # Combine in a dict
            encoding = {
                "pixel_values": pixel_values.squeeze(),
                "labels": labels.squeeze()
            }
            return encoding
        else:
            return image, text

# TrOCR implementation
def setup_trocr():
    """Set up TrOCR model, processor, and tokenizer"""
    try:
        # First check for sentencepiece
        try:
            import sentencepiece
        except ImportError:
            print("Installing sentencepiece...")
            import subprocess
            subprocess.check_call(['pip', 'install', 'sentencepiece'])
            import sentencepiece
            
        # Import necessary components
        from transformers import TrOCRProcessor, VisionEncoderDecoderModel
        
        print("Loading TrOCR model and processor...")
        
        # Load model and processor
        processor = TrOCRProcessor.from_pretrained("microsoft/trocr-small-handwritten")
        model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-small-handwritten")
        
        # Configure for generation
        model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
        model.config.pad_token_id = processor.tokenizer.pad_token_id
        model.config.vocab_size = model.config.decoder.vocab_size
        
        # Set decoding parameters
        model.config.eos_token_id = processor.tokenizer.sep_token_id
        model.config.max_length = 64
        model.config.early_stopping = True
        model.config.no_repeat_ngram_size = 3
        model.config.length_penalty = 2.0
        model.config.num_beams = 4
        
        return model, processor
    
    except Exception as e:
        print(f"Error setting up TrOCR: {e}")
        return None, None

# Create datasets and dataloaders
def create_dataloaders(processor, batch_sizes=(16, 32, 32)):
    """Create datasets and dataloaders for training, validation and testing"""
    # Create datasets - NO CUSTOM TRANSFORMS
    train_dataset = BalineseOCRDataset(
        train_data, 
        images_dir, 
        processor=processor
    )
    
    val_dataset = BalineseOCRDataset(
        val_data, 
        images_dir, 
        processor=processor
    )
    
    test_dataset = BalineseOCRDataset(
        test_data, 
        images_dir, 
        processor=processor
    )
    
    # Create dataloaders with custom collate function
    train_batch_size, val_batch_size, test_batch_size = batch_sizes
    
    train_dataloader = DataLoader(
        train_dataset, 
        batch_size=train_batch_size, 
        shuffle=True, 
        num_workers=2,
        collate_fn=collate_fn
    )
    
    val_dataloader = DataLoader(
        val_dataset, 
        batch_size=val_batch_size, 
        shuffle=False, 
        num_workers=2,
        collate_fn=collate_fn
    )
    
    test_dataloader = DataLoader(
        test_dataset, 
        batch_size=test_batch_size, 
        shuffle=False, 
        num_workers=2,
        collate_fn=collate_fn
    )
    
    return train_dataloader, val_dataloader, test_dataloader

# Training function
def train_trocr(model, processor, train_dataloader, val_dataloader, num_epochs=3, save_path="trocr_balinese"):
    """Train TrOCR model on Balinese data"""
    # Move model to device
    model.to(device)
    
    # Set up optimizer
    optimizer = optim.AdamW(model.parameters(), lr=5e-5)
    
    # Set up scheduler (linear warmup then decay)
    from transformers import get_linear_schedule_with_warmup
    num_training_steps = len(train_dataloader) * num_epochs
    num_warmup_steps = num_training_steps // 10
    scheduler = get_linear_schedule_with_warmup(
        optimizer, 
        num_warmup_steps=num_warmup_steps, 
        num_training_steps=num_training_steps
    )
    
    # Training loop
    best_val_loss = float('inf')
    
    for epoch in range(num_epochs):
        model.train()
        train_loss = 0.0
        
        print(f"\nEpoch {epoch+1}/{num_epochs}")
        
        for batch_idx, batch in enumerate(train_dataloader):
            # Extract inputs and move to device
            pixel_values = batch["pixel_values"].to(device)
            labels = batch["labels"].to(device)
            
            # Forward pass
            outputs = model(pixel_values=pixel_values, labels=labels)
            loss = outputs.loss
            
            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()
            
            train_loss += loss.item()
            
            if (batch_idx + 1) % 50 == 0:
                print(f"Batch {batch_idx+1}/{len(train_dataloader)}, Loss: {loss.item():.4f}")
        
        avg_train_loss = train_loss / len(train_dataloader)
        
        # Validation
        model.eval()
        val_loss = 0.0
        
        with torch.no_grad():
            for batch_idx, batch in enumerate(val_dataloader):
                pixel_values = batch["pixel_values"].to(device)
                labels = batch["labels"].to(device)
                
                outputs = model(pixel_values=pixel_values, labels=labels)
                loss = outputs.loss
                
                val_loss += loss.item()
        
        avg_val_loss = val_loss / len(val_dataloader)
        
        print(f"Epoch {epoch+1}: Train Loss = {avg_train_loss:.4f}, Val Loss = {avg_val_loss:.4f}")
        
        # Save best model
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            model.save_pretrained(f"{save_path}_best")
            print(f"Saved best model at epoch {epoch+1}")
    
    # Save final model
    model.save_pretrained(save_path)
    processor.save_pretrained(save_path)
    
    return model, processor

# Evaluation function
def evaluate_trocr(model, processor, test_dataloader):
    """Evaluate TrOCR model and compute CER"""
    model.eval()
    total_cer = 0
    total_samples = 0
    predictions = []
    
    with torch.no_grad():
        for batch_idx, batch in enumerate(test_dataloader):
            # Get images
            pixel_values = batch["pixel_values"].to(device)
            
            # Generate text
            generated_ids = model.generate(pixel_values)
            
            # Decode text
            generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)
            
            # Get ground truth
            labels = batch["labels"]
            ground_truth = processor.batch_decode(labels, skip_special_tokens=True)
            
            # Compute CER
            for pred, gt in zip(generated_text, ground_truth):
                edit_distance = editdistance.eval(pred, gt)
                cer = edit_distance / max(len(gt), 1)
                
                total_cer += cer
                total_samples += 1
                
                predictions.append({
                    "prediction": pred,
                    "ground_truth": gt,
                    "cer": cer
                })
    
    # Calculate average CER
    avg_cer = total_cer / total_samples if total_samples > 0 else 0
    
    print(f"\nTest CER: {avg_cer:.4f}")
    
    # Display worst examples
    predictions.sort(key=lambda x: x["cer"], reverse=True)
    print("\nWorst 5 predictions:")
    for i, pred in enumerate(predictions[:5]):
        print(f"{i+1}. Pred: '{pred['prediction']}' | GT: '{pred['ground_truth']}' | CER: {pred['cer']:.4f}")
    
    # Save results to CSV
    csv_file = "test_cer_results.csv"
    if os.path.exists(csv_file):
        df = pd.read_csv(csv_file)
    else:
        df = pd.DataFrame(columns=["model_name", "test_cer"])
    
    model_name = "trocr-small-handwritten"
    if model_name in df["model_name"].values:
        df.loc[df["model_name"] == model_name, "test_cer"] = avg_cer
    else:
        new_row = pd.DataFrame({"model_name": [model_name], "test_cer": [avg_cer]})
        df = pd.concat([df, new_row], ignore_index=True)
    
    df.to_csv(csv_file, index=False)
    print(f"Results saved to {csv_file}")
    
    return avg_cer, predictions

# Main execution function
def run_trocr_pipeline(num_epochs=3, batch_size=8):
    """Run complete TrOCR pipeline from setup to evaluation"""
    try:
        # Clean up before starting
        cleanup()
        
        # Set up TrOCR
        model, processor = setup_trocr()
        
        if not model or not processor:
            raise RuntimeError("Failed to initialize TrOCR model or processor")
        
        # Create dataloaders
        train_dataloader, val_dataloader, test_dataloader = create_dataloaders(
            processor, 
            batch_sizes=(batch_size, batch_size, batch_size)
        )
        
        # Train model
        model, processor = train_trocr(
            model, 
            processor, 
            train_dataloader, 
            val_dataloader, 
            num_epochs=num_epochs, 
            save_path="trocr_balinese"
        )
        
        # Evaluate model
        cer, predictions = evaluate_trocr(model, processor, test_dataloader)
        
        # Final cleanup
        cleanup()
        
        return True
    
    except Exception as e:
        print(f"Error in TrOCR pipeline: {e}")
        cleanup()
        return False

# Run the pipeline if script is executed directly
if __name__ == "__main__":
    run_trocr_pipeline(num_epochs=15, batch_size=8)

Using device: cuda

===== System Resource Check =====
Model: microsoft/trocr-small-handwritten
Free Disk Space: 24.06 GB (Need: 0.59 GB)
GPU: NVIDIA A40
GPU Memory: 45416.12 MB (Need: 1200.00 MB)
Disk Space: SUFFICIENT
GPU Memory: SUFFICIENT
Loaded 15022 image-text pairs
Unique labels: 4702
Training size: 12922; Validation size: 1050; Test size: 1050
Memory cleaned up
Loading TrOCR model and processor...


2025-03-28 12:46:49.884671: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F AVX512_VNNI AVX512_BF16 AVX_VNNI AMX_TILE AMX_INT8 AMX_BF16 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2025-03-28 12:46:49.998070: I tensorflow/core/util/util.cc:169] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-03-28 12:46:50.026452: E tensorflow/stream_executor/cuda/cuda_blas.cc:2981] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-03-28 12:46:50.778956: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Co


Epoch 1/15
Batch 50/1616, Loss: 14.4848
Batch 100/1616, Loss: 1.9431
Batch 150/1616, Loss: 0.5997
Batch 200/1616, Loss: 0.4329
Batch 250/1616, Loss: 0.4850
Batch 300/1616, Loss: 0.3088
Batch 350/1616, Loss: 0.2684
Batch 400/1616, Loss: 0.2979
Batch 450/1616, Loss: 0.1864
Batch 500/1616, Loss: 0.2160
Batch 550/1616, Loss: 0.1454
Batch 600/1616, Loss: 0.1908
Batch 650/1616, Loss: 0.1282
Batch 700/1616, Loss: 0.0913
Batch 750/1616, Loss: 0.1236
Batch 800/1616, Loss: 0.1153
Batch 850/1616, Loss: 0.1203
Batch 900/1616, Loss: 0.1266
Batch 950/1616, Loss: 0.1162
Batch 1000/1616, Loss: 0.1137
Batch 1050/1616, Loss: 0.0884
Batch 1100/1616, Loss: 0.1314
Batch 1150/1616, Loss: 0.1028
Batch 1200/1616, Loss: 0.0487
Batch 1250/1616, Loss: 0.0333
Batch 1300/1616, Loss: 0.0982
Batch 1350/1616, Loss: 0.1533
Batch 1400/1616, Loss: 0.1326
Batch 1450/1616, Loss: 0.1077
Batch 1500/1616, Loss: 0.1184
Batch 1550/1616, Loss: 0.0896
Batch 1600/1616, Loss: 0.1277
Epoch 1: Train Loss = 0.9813, Val Loss = 0.0584




Test CER: 0.0823

Worst 5 predictions:
1. Pred: 'turandupa' | GT: '.' | CER: 9.0000
2. Pred: 'sang' | GT: '8' | CER: 4.0000
3. Pred: 'tena' | GT: '2' | CER: 4.0000
4. Pred: 'len' | GT: '.' | CER: 3.0000
5. Pred: 'sul' | GT: 's' | CER: 2.0000
Results saved to test_cer_results.csv
Memory cleaned up
