In [1]:
import os
import gc
import shutil
import pandas as pd
import numpy as np
import editdistance
import time
import warnings
from PIL import Image
from collections import Counter

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

warnings.filterwarnings("ignore")
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'


def check_system_resources(model_name="microsoft/trocr-small-handwritten"):
    """
    Check system resources to ensure we have enough space for TrOCR.
    """
    try:
        total, used, free = shutil.disk_usage('/')
        free_gb = free / (1024**3)  # Convert bytes to GB
        estimated_size_mb = {"small": 300, "base": 600, "large": 1200}
        if "small" in model_name.lower():
            size_category = "small"
        elif "base" in model_name.lower():
            size_category = "base"
        else:
            size_category = "large"
        model_size_mb = estimated_size_mb.get(size_category, 300)
        model_size_gb = model_size_mb / 1024
        if torch.cuda.is_available():
            gpu_info = torch.cuda.get_device_properties(0)
            gpu_name = gpu_info.name
            total_gpu_memory_mb = gpu_info.total_memory / (1024 * 1024)
            torch.cuda.empty_cache()
            free_gpu_memory_mb = torch.cuda.memory_reserved(0) / (1024 * 1024)
            required_gpu_mb = model_size_mb * 4  
        else:
            gpu_name = "No GPU detected"
            total_gpu_memory_mb = 0
            free_gpu_memory_mb = 0
            required_gpu_mb = model_size_mb * 3
        disk_sufficient = free_gb > (model_size_gb * 2)
        gpu_sufficient = (total_gpu_memory_mb > required_gpu_mb) if torch.cuda.is_available() else False
        resources = {
            "free_disk_gb": free_gb,
            "required_disk_gb": model_size_gb * 2,
            "disk_sufficient": disk_sufficient,
            "gpu_name": gpu_name,
            "total_gpu_memory_mb": total_gpu_memory_mb,
            "required_gpu_mb": required_gpu_mb,
            "gpu_sufficient": gpu_sufficient
        }
        print("\n===== System Resource Check =====")
        print(f"Model: {model_name}")
        print(f"Free Disk Space: {free_gb:.2f} GB (Need: {model_size_gb * 2:.2f} GB)")
        print(f"GPU: {gpu_name}")
        print(f"GPU Memory: {total_gpu_memory_mb:.2f} MB (Need: {required_gpu_mb:.2f} MB)")
        print(f"Disk Space: {'SUFFICIENT' if disk_sufficient else 'INSUFFICIENT'}")
        print(f"GPU Memory: {'SUFFICIENT' if gpu_sufficient else 'INSUFFICIENT'}")
        if not disk_sufficient or not gpu_sufficient:
            print("\WARNING: System resources may be insufficient for TrOCR model")
            if not disk_sufficient:
                print(f"  - Need {model_size_gb * 2:.2f} GB disk space, but only have {free_gb:.2f} GB")
            if not gpu_sufficient and torch.cuda.is_available():
                print(f"  - Need {required_gpu_mb:.2f} MB GPU memory, but only have {total_gpu_memory_mb:.2f} MB")
        return resources
    except Exception as e:
        print(f"Error checking system resources: {e}")
        return None


def cleanup():
    """Clean up memory and GPU cache."""
    for var in ['model', 'processor', 'tokenizer', 'feature_extractor']:
        if var in globals():
            del globals()[var]
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    print("Memory cleaned up")


# Global device setting
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# File paths (adjust as needed)
base_dir = os.getcwd()
train_ground_truth_path = os.path.join(base_dir, 'balinese_transliteration_train.txt')
train_images_dir = os.path.join(base_dir, 'balinese_word_train')
test_ground_truth_path = os.path.join(base_dir, 'balinese_transliteration_test.txt')
test_images_dir = os.path.join(base_dir, 'balinese_word_test')


def load_data(file_path):
    """
    Load data from a file with the format: filename;label.
    Converts text labels to lowercase.
    """
    filenames, labels = [], []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if line:
                parts = line.split(";")
                if len(parts) == 2:
                    img, label = parts
                    label = label.lower()
                    filenames.append(img)
                    labels.append(label.lower())
                else:
                    print(f"Skipping malformed line: {line}")
    return pd.DataFrame({"filename": filenames, "label": labels})


# Load training data and split into train/validation sets.
data = load_data(train_ground_truth_path)
from sklearn.model_selection import train_test_split
train_data, val_data = train_test_split(data, test_size=0.1, random_state=42)
train_data, val_data = train_data.reset_index(drop=True), val_data.reset_index(drop=True)
print(f"Training pairs: {len(train_data)} | Validation pairs: {len(val_data)}")


def collate_fn(batch):
    """Stack pixel values and labels from the batch."""
    pixel_values = torch.stack([item["pixel_values"] for item in batch])
    labels = torch.stack([item["labels"] for item in batch])
    return {"pixel_values": pixel_values, "labels": labels}


class BalineseOCRDataset(Dataset):
    def __init__(self, df, images_dir, processor, max_target_length=128):
        self.df = df.reset_index(drop=True)
        self.images_dir = images_dir
        self.processor = processor
        self.max_target_length = max_target_length

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        img_path = os.path.join(self.images_dir, row["filename"])
        label = row["label"]
        image = Image.open(img_path).convert("RGB")
        encoding = self.processor(
            images=image,
            text=label,
            padding="max_length",
            max_length=self.max_target_length,
            truncation=True,
            return_tensors="pt"
        )
        encoding = {k: v.squeeze(0) for k, v in encoding.items()}
        return encoding


def setup_trocr(model_name="microsoft/trocr-small-handwritten"):
    """
    Set up the TrOCR model and processor.
    """
    try:
        from transformers import TrOCRProcessor, VisionEncoderDecoderModel
        print(f"Loading processor and model: {model_name}")
        processor = TrOCRProcessor.from_pretrained(model_name)
        model = VisionEncoderDecoderModel.from_pretrained(model_name)
        model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
        model.config.pad_token_id = processor.tokenizer.pad_token_id
        model.config.vocab_size = model.config.decoder.vocab_size
        model.config.eos_token_id = processor.tokenizer.sep_token_id
        model.config.max_length = 64
        model.config.early_stopping = True
        model.config.no_repeat_ngram_size = 3
        model.config.length_penalty = 2.0
        model.config.num_beams = 4
        return model, processor
    except Exception as e:
        print(f"Error setting up TrOCR: {e}")
        return None, None


def create_dataloaders(processor, batch_sizes=(16, 32)):
    """
    Create DataLoaders for training and validation.
    """
    train_dataset = BalineseOCRDataset(train_data, train_images_dir, processor)
    val_dataset = BalineseOCRDataset(val_data, train_images_dir, processor)
    train_loader = DataLoader(train_dataset, batch_size=batch_sizes[0], shuffle=True,
                              num_workers=2, collate_fn=collate_fn)
    val_loader = DataLoader(val_dataset, batch_size=batch_sizes[1], shuffle=False,
                            num_workers=2, collate_fn=collate_fn)
    return train_loader, val_loader


def train_trocr(model, processor, train_loader, val_loader, num_epochs=3, save_path="trocr_model"):
    model.to(device)
    optimizer = optim.AdamW(model.parameters(), lr=5e-5)
    from transformers import get_linear_schedule_with_warmup
    num_training_steps = len(train_loader) * num_epochs
    num_warmup_steps = num_training_steps // 10
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=num_warmup_steps,
                                                  num_training_steps=num_training_steps)
    best_val_loss = float("inf")
    training_losses = []
    validation_losses = []
    for epoch in range(num_epochs):
        model.train()
        total_train_loss = 0.0
        print(f"\nEpoch {epoch+1}/{num_epochs}")
        for step, batch in enumerate(train_loader):
            pixel_values = batch["pixel_values"].to(device)
            labels = batch["labels"].to(device)
            outputs = model(pixel_values=pixel_values, labels=labels)
            loss = outputs.loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()
            total_train_loss += loss.item()
            # if (step + 1) % 100 == 0:
                # print(f"Step {step+1}: Loss = {loss.item():.4f}")
        avg_train_loss = total_train_loss / len(train_loader)
        training_losses.append(avg_train_loss)
        
        model.eval()
        total_val_loss = 0.0
        with torch.no_grad():
            for batch in val_loader:
                pixel_values = batch["pixel_values"].to(device)
                labels = batch["labels"].to(device)
                outputs = model(pixel_values=pixel_values, labels=labels)
                total_val_loss += outputs.loss.item()
        avg_val_loss = total_val_loss / len(val_loader)
        validation_losses.append(avg_val_loss)
        
        print(f"[Epoch {epoch+1}] Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")
        
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            model.save_pretrained(f"{save_path}_best")
            processor.save_pretrained(f"{save_path}_best")
            print("** Best model saved! **")
    
    model.save_pretrained(save_path)
    processor.save_pretrained(save_path)
    print("** Final model saved! **")
    return model, processor, training_losses, validation_losses


def calculate_global_cer(results):
    total_ed, total_refs = 0, 0
    for r in results:
        ref = r['ground_truth_caption']
        hyp = r['predicted_caption']
        total_ed += editdistance.eval(ref, hyp)
        total_refs += len(ref)
    return total_ed / total_refs if total_refs > 0 else 0.0


def print_top_worst_samples(results, n=5):
    results_with_cer = []
    for r in results:
        ref = r['ground_truth_caption']
        hyp = r['predicted_caption']
        cer = editdistance.eval(ref, hyp) / (len(ref) if len(ref) > 0 else 1)
        new_r = r.copy()
        new_r['cer'] = cer
        results_with_cer.append(new_r)
    results_with_cer.sort(key=lambda x: x['cer'], reverse=True)
    worst_samples = results_with_cer[:n]
    print(f"\n=== Top {n} Worst Samples by CER ===")
    for i, sample in enumerate(worst_samples, start=1):
        print(f"{i}) Image: {sample['image_filename']}")
        print(f"   CER: {sample['cer']:.4f}")
        print(f"   Predicted    : {sample['predicted_caption']}")
        print(f"   Ground Truth : {sample['ground_truth_caption']}\n")


csv_file = "test_cer_results.csv"

def log_test_cer(model_name, cer_value):
    try:
        df = pd.read_csv(csv_file)
    except FileNotFoundError:
        df = pd.DataFrame(columns=["model_name", "test_cer"])
    if model_name in df['model_name'].values:
        df.loc[df['model_name'] == model_name, 'test_cer'] = cer_value
    else:
        new_row = pd.DataFrame({"model_name": [model_name], "test_cer": [cer_value]})
        df = pd.concat([df, new_row], ignore_index=True)
    df.to_csv(csv_file, index=False)
    print(f"Logged {model_name}: {cer_value:.4f}")


def test_trocr_model(model_path="trocr_balinese_best", processor_path=None):
    """
    Test a trained TrOCR model on test data.
    """
    from transformers import TrOCRProcessor, VisionEncoderDecoderModel
    if processor_path is None:
        processor_path = "microsoft/trocr-small-handwritten"
    test_filenames, test_labels = [], []
    with open(test_ground_truth_path, 'r', encoding='utf-8') as file:
        for line in file:
            line = line.strip()
            if line:
                parts = line.split(';')
                if len(parts) == 2:
                    filename, label = parts
                    test_filenames.append(filename)
                    test_labels.append(label.lower())
                else:
                    print(f"Skipping malformed line: {line}")
    test_data = pd.DataFrame({"filename": test_filenames, "label": test_labels})
    print(f"Loaded {len(test_data)} test image-text pairs")
    try:
        print(f"Loading model from {model_path}")
        model = VisionEncoderDecoderModel.from_pretrained(model_path)
        print(f"Loading processor from {processor_path}")
        processor = TrOCRProcessor.from_pretrained(processor_path)
        model = model.to(device)
        results = []
        batch_size = 16
        model.eval()
        with torch.no_grad():
            for i in range(0, len(test_data), batch_size):
                batch_data = test_data.iloc[i:i+batch_size]
                images = []
                for idx, row in batch_data.iterrows():
                    img_path = os.path.join(test_images_dir, row['filename'])
                    image = Image.open(img_path).convert('RGB')
                    images.append(image)
                pixel_values = processor.image_processor(images, return_tensors="pt").pixel_values.to(device)
                generated_ids = model.generate(pixel_values)
                predictions = processor.batch_decode(generated_ids, skip_special_tokens=True)
                for (_, row), pred in zip(batch_data.iterrows(), predictions):
                    results.append({
                        'image_filename': row['filename'],
                        'predicted_caption': pred,
                        'ground_truth_caption': row['label']
                    })
        cer = calculate_global_cer(results)
        print(f"\nTest CER for {model_path}: {cer:.4f}")
        print_top_worst_samples(results, n=5)
        model_name_only = os.path.basename(model_path)
        log_test_cer(model_name_only, cer)
        return cer, results
    except Exception as e:
        print(f"Error testing TrOCR model: {e}")
        import traceback
        traceback.print_exc()
        return None, None
    finally:
        if 'model' in locals():
            del model
        if 'processor' in locals():
            del processor
        torch.cuda.empty_cache()
        print("Memory cleared after testing")


def run_full_pipeline(model_variant="small", num_epochs=3, batch_size=8):
    if model_variant == "small":
        hf_model = "microsoft/trocr-small-handwritten"
        save_path = "trocr_balinese_small_best"
        proc_path = "microsoft/trocr-small-handwritten"
    elif model_variant == "base":
        hf_model = "microsoft/trocr-base-handwritten"
        save_path = "trocr_balinese_base_best"
        proc_path = "microsoft/trocr-base-handwritten"
    else:
        hf_model = "microsoft/trocr-large-handwritten"
        save_path = "trocr_balinese_large_best"
        proc_path = "microsoft/trocr-large-handwritten"

    print(f"\n=== Running full pipeline for TrOCR {model_variant.upper()} ===")
    cleanup()  # Clear memory before starting the new model run
    check_system_resources(hf_model)
    model, processor = setup_trocr(hf_model)
    if model is None or processor is None:
        raise RuntimeError("Model setup failed.")
    
    train_loader, val_loader = create_dataloaders(processor, (batch_size, batch_size))
    model, processor, training_losses, validation_losses = train_trocr(
        model, processor, train_loader, val_loader, num_epochs=num_epochs, save_path=save_path
    )
    
    cleanup()  # Clean up before testing
    
    test_cer, test_results = test_trocr_model(model_path=save_path, processor_path=proc_path)
    return training_losses, validation_losses, test_cer, test_results



if __name__ == "__main__":
    results = {}
    for variant in ["small", "base"]:
        print(f"\n--- Running {variant.upper()} model ---")
        train_losses, val_losses, test_cer, test_results = run_full_pipeline(model_variant=variant, num_epochs=15, batch_size=2)
        
        results[variant] = {
            "train_losses": train_losses,
            "val_losses": val_losses,
            "test_cer": test_cer,
            "test_results": test_results
        }
        
        print(f"Final Test CER for {variant.upper()}: {test_cer:.4f}" if test_cer is not None else f"Testing for {variant} failed.")
        cleanup()  # Clear memory after testing each model
    
    print("\nAll models have been run. Summary of results:")
    for variant, res in results.items():
        print(f"{variant.upper()} model -> Train Losses: {res['train_losses']}, Val Losses: {res['val_losses']}, Test CER: {res['test_cer']:.4f}")


Using device: cuda
Training pairs: 13519 | Validation pairs: 1503

--- Running SMALL model ---

=== Running full pipeline for TrOCR SMALL ===
Memory cleaned up

===== System Resource Check =====
Model: microsoft/trocr-small-handwritten
Free Disk Space: 21.35 GB (Need: 0.59 GB)
GPU: NVIDIA A40
GPU Memory: 45416.12 MB (Need: 1200.00 MB)
Disk Space: SUFFICIENT
GPU Memory: SUFFICIENT
Loading processor and model: microsoft/trocr-small-handwritten


Some weights of VisionEncoderDecoderModel were not initialized from the model checkpoint at microsoft/trocr-small-handwritten and are newly initialized: ['encoder.pooler.dense.bias', 'encoder.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.



Epoch 1/15
[Epoch 1] Train Loss: 0.5513 | Val Loss: 0.0709
** Best model saved! **

Epoch 2/15
[Epoch 2] Train Loss: 0.0694 | Val Loss: 0.0608
** Best model saved! **

Epoch 3/15
[Epoch 3] Train Loss: 0.0478 | Val Loss: 0.0391
** Best model saved! **

Epoch 4/15
[Epoch 4] Train Loss: 0.0303 | Val Loss: 0.0295
** Best model saved! **

Epoch 5/15
[Epoch 5] Train Loss: 0.0212 | Val Loss: 0.0245
** Best model saved! **

Epoch 6/15
[Epoch 6] Train Loss: 0.0155 | Val Loss: 0.0239
** Best model saved! **

Epoch 7/15
[Epoch 7] Train Loss: 0.0116 | Val Loss: 0.0208
** Best model saved! **

Epoch 8/15
[Epoch 8] Train Loss: 0.0082 | Val Loss: 0.0208
** Best model saved! **

Epoch 9/15
[Epoch 9] Train Loss: 0.0057 | Val Loss: 0.0203
** Best model saved! **

Epoch 10/15
[Epoch 10] Train Loss: 0.0037 | Val Loss: 0.0201
** Best model saved! **

Epoch 11/15
[Epoch 11] Train Loss: 0.0024 | Val Loss: 0.0187
** Best model saved! **

Epoch 12/15
[Epoch 12] Train Loss: 0.0014 | Val Loss: 0.0191

Epoch 13/

Some weights of VisionEncoderDecoderModel were not initialized from the model checkpoint at microsoft/trocr-base-handwritten and are newly initialized: ['encoder.pooler.dense.bias', 'decoder.output_projection.weight', 'encoder.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.



Epoch 1/15
[Epoch 1] Train Loss: 0.2832 | Val Loss: 0.2599
** Best model saved! **

Epoch 2/15
[Epoch 2] Train Loss: 0.1665 | Val Loss: 9.9288

Epoch 3/15
[Epoch 3] Train Loss: 0.1137 | Val Loss: 11.6090

Epoch 4/15
[Epoch 4] Train Loss: 0.1045 | Val Loss: 13.1801

Epoch 5/15



KeyboardInterrupt



In [None]:
test_ground_truth_path = os.path.join(base_dir, 'balinese_transliteration_test.txt')
test_images_dir        = os.path.join(base_dir, 'balinese_word_test')

test_filenames = []
test_labels    = []

def test_trocr_model(model_path="trocr_balinese", processor_path=None):
    """
    Test a trained TrOCR model on test data
    
    Args:
        model_path: Path to the saved model
        processor_path: Path to the processor (if different from model_path)
    """
    from transformers import TrOCRProcessor, VisionEncoderDecoderModel
    
    # If processor_path is not specified, try using the original processor
    if processor_path is None:
        processor_path = "microsoft/trocr-small-handwritten"
    
    # Load test data
    test_filenames = []
    test_labels = []
    
    with open(test_ground_truth_path, 'r', encoding='utf-8') as file:
        for line in file:
            line = line.strip()
            if line:
                parts = line.split(';')
                if len(parts) == 2:
                    filename, label = parts
                    label = label.lower()
                    test_filenames.append(filename)
                    test_labels.append(label)
                else:
                    print(f"Skipping malformed line: {line}")
    
    test_data = pd.DataFrame({
        'filename': test_filenames,
        'label': test_labels
    })
    
    print(f"Loaded {len(test_data)} test image-text pairs")
    
    try:
        # Load model and processor (from separate paths if needed)
        print(f"Loading model from {model_path}")
        model = VisionEncoderDecoderModel.from_pretrained(model_path)
        
        print(f"Loading processor from {processor_path}")
        processor = TrOCRProcessor.from_pretrained(processor_path)
        
        # Move model to device
        model = model.to(device)
        
        # Results for storing predictions
        results = []
        batch_size = 16
        
        # Process test data in batches
        model.eval()
        with torch.no_grad():
            for i in range(0, len(test_data), batch_size):
                batch_data = test_data.iloc[i:i+batch_size]
                
                # Process images
                images = []
                for idx, row in batch_data.iterrows():
                    img_path = os.path.join(test_images_dir, row['filename'])
                    image = Image.open(img_path).convert('RGB')
                    images.append(image)
                
                # Get model predictions
                pixel_values = processor.image_processor(images, return_tensors="pt").pixel_values.to(device)
                generated_ids = model.generate(pixel_values)
                predictions = processor.batch_decode(generated_ids, skip_special_tokens=True)
                
                # Store results
                for j, (pred, (_, row)) in enumerate(zip(predictions, batch_data.iterrows())):
                    results.append({
                        'image_filename': row['filename'],
                        'predicted_caption': pred,
                        'ground_truth_caption': row['label']
                    })
                
                # print(f"Processed {min(i+batch_size, len(test_data))}/{len(test_data)} test samples")
        
        # Calculate CER
        cer = calculate_global_cer(results)
        print(f"TrOCR Model Test CER: {cer:.4f}")
        
        # Print worst samples
        print_top_worst_samples(results, n=5)
        
        # Log to CSV
        model_name = os.path.basename(model_path)
        log_test_cer(model_name, cer)
        
        return cer, results
    
    except Exception as e:
        print(f"Error testing TrOCR model: {e}")
        import traceback
        traceback.print_exc()
        return None, None
    finally:
        # Clean up
        if 'model' in locals():
            del model
        if 'processor' in locals():
            del processor
        torch.cuda.empty_cache()
        print("Memory cleared")

In [None]:
def calculate_global_cer(results):
    total_ed   = 0
    total_refs = 0
    for r in results:
        ref = r['ground_truth_caption']
        hyp = r['predicted_caption']
        dist = editdistance.eval(ref, hyp)
        total_ed   += dist
        total_refs += len(ref)
    if total_refs == 0:
        return 0.0
    return total_ed / total_refs

In [None]:
def print_top_worst_samples(results, n=5):
    # Calculate CER for each sample
    results_with_cer = []
    for r in results:
        ref = r['ground_truth_caption']
        hyp = r['predicted_caption']
        dist = editdistance.eval(ref, hyp)
        length = len(ref)
        cer = dist / length if length > 0 else 0
        # Copy the record and add cer
        new_r = r.copy()
        new_r['cer'] = cer
        results_with_cer.append(new_r)

    # Sort by CER (descending) and take the top N
    results_with_cer.sort(key=lambda x: x['cer'], reverse=True)
    worst_samples = results_with_cer[:n]

    print(f"\n=== Top {n} Worst Samples by CER ===")
    for i, sample in enumerate(worst_samples, start=1):
        print(f"{i}) Image: {sample['image_filename']}")
        print(f"   CER: {sample['cer']:.4f}")
        print(f"   Predicted       : {sample['predicted_caption']}")
        print(f"   Ground Truth    : {sample['ground_truth_caption']}")
        print()

In [None]:
csv_file = "test_cer_results.csv"

def log_test_cer(model_name, cer_value):
    df = pd.read_csv(csv_file)
    # Check if model_name exists
    if model_name in df['model_name'].values:
        # Update existing row
        df.loc[df['model_name'] == model_name, 'test_cer'] = cer_value
    else:
        # Add new row - use concat instead of append
        new_row = pd.DataFrame({"model_name": [model_name], "test_cer": [cer_value]})
        df = pd.concat([df, new_row], ignore_index=True)
    
    df.to_csv(csv_file, index=False)
    print(f"Logged {model_name}: {cer_value:.4f}")

In [None]:
cer, results = test_trocr_model("trocr_balinese_best", "microsoft/trocr-small-handwritten")