In [6]:
# !pip install transformers datasets peft accelerate

In [7]:
# !pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124

In [8]:
import torch

# Environment sanity check: report the PyTorch build and the visible accelerator.
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"CUDA version: {torch.version.cuda}")
# get_device_name(0) raises when no CUDA device is present, so it is only
# queried when CUDA is actually available.
if torch.cuda.is_available():
    print(f"GPU device: {torch.cuda.get_device_name(0)}")
else:
    print("GPU device: none (running on CPU)")

PyTorch version: 2.6.0+cu124
CUDA available: True
CUDA version: 12.4
GPU device: NVIDIA GeForce RTX 3080 Laptop GPU


In [11]:
import os
import pandas as pd
import torch
import numpy as np
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from transformers import (
    TrOCRProcessor, 
    VisionEncoderDecoderModel,
    Seq2SeqTrainer, 
    Seq2SeqTrainingArguments
)
from peft import get_peft_model, LoraConfig, TaskType, PeftModel
import logging
import warnings
import random
from tqdm import tqdm

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
warnings.filterwarnings("ignore")

def set_seed(seed=42):
    """Seed every random number generator used by this notebook.

    Args:
        seed: Integer seed handed to Python's ``random``, NumPy, and PyTorch
            (CPU generator and all CUDA generators).
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

set_seed()

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
logger.info(f"Using device: {device}")

# --- Data locations -------------------------------------------------------
DATA_DIR = "data/handwritten"
RO_CSV_DIR = os.path.join(DATA_DIR, "handwriting_dataset_ro_splits")
EN_CSV_DIR = os.path.join(DATA_DIR, "iam")

# Romanian train / validation / test split files.
RO_TRAIN_CSV, RO_VAL_CSV, RO_TEST_CSV = (
    os.path.join(RO_CSV_DIR, csv_name)
    for csv_name in ("ro_train.csv", "ro_val.csv", "ro_test.csv")
)
# English train / validation / test split files.
EN_TRAIN_CSV, EN_VAL_CSV, EN_TEST_CSV = (
    os.path.join(EN_CSV_DIR, csv_name)
    for csv_name in ("train.csv", "validation.csv", "test.csv")
)
MAX_SAMPLES_PER_LANGUAGE = 1000  # Set to None for full dataset

# --- Model configuration --------------------------------------------------
MODEL_NAME = "microsoft/trocr-base-handwritten"
OUTPUT_DIR = "trocr_lora_bilingual_improved"

# Characters that are commonly typed or encoded *instead of* the correct
# Romanian letters. Written as escapes so the table is immune to editor
# re-encoding. U+00E2 (a-circumflex) and U+00EE (i-circumflex) are
# deliberately absent: they are distinct, valid Romanian letters and must not
# be rewritten.
_ROMANIAN_CHAR_FIXES = str.maketrans({
    '\u015f': '\u0219',  # s-cedilla  -> s-comma-below
    '\u015e': '\u0218',  # S-cedilla  -> S-comma-below
    '\u0163': '\u021b',  # t-cedilla  -> t-comma-below
    '\u0162': '\u021a',  # T-cedilla  -> T-comma-below
    '\u00e3': '\u0103',  # a-tilde    -> a-breve
    '\u00c3': '\u0102',  # A-tilde    -> A-breve
    '\u00e5': '\u0103',  # a-ring     -> a-breve
    '\u00c5': '\u0102',  # A-ring     -> A-breve
})


def normalize_romanian_text(text):
    """Normalize Romanian text to standardize diacritics.

    The text is first composed to Unicode NFC, so a base letter followed by a
    combining mark becomes a single code point. Legacy cedilla forms and
    look-alike vowels are then mapped to the standard Romanian letters, in
    both lower and upper case.

    Unlike the previous implementation, the letters a-circumflex and
    i-circumflex are preserved; collapsing them into a-breve changed the
    spelling of valid Romanian words. Models trained on labels produced by
    the old mapping should be retrained.

    Args:
        text: Input string.

    Returns:
        The normalized string.
    """
    # Function-scope import: only this helper needs unicodedata.
    import unicodedata

    composed = unicodedata.normalize("NFC", text)
    return composed.translate(_ROMANIAN_CHAR_FIXES)

# Image preprocessing for Romanian handwriting
def preprocess_image_for_romanian(image, contrast_factor=1.5):
    """Apply preprocessing optimized for Romanian handwriting.

    The image is converted to grayscale, its contrast is enhanced, and the
    result is converted back to RGB.

    Args:
        image: A PIL image.
        contrast_factor: Factor passed to ``ImageEnhance.Contrast.enhance``.
            Defaults to 1.5, the value that was previously hard-coded.

    Returns:
        A new RGB PIL image.
    """
    # Kept as a function-scope import, as in the original cell.
    from PIL import ImageEnhance

    gray = image.convert('L')
    boosted = ImageEnhance.Contrast(gray).enhance(contrast_factor)
    return boosted.convert('RGB')

logger.info("Loading processor and model...")
processor = TrOCRProcessor.from_pretrained(MODEL_NAME)
# Load the checkpoint and move it to the selected device in one step.
model = VisionEncoderDecoderModel.from_pretrained(MODEL_NAME).to(device)

# Copy the tokenizer's pad and CLS ids onto the model config; the CLS id is
# used as the decoder start token.
model.config.pad_token_id = processor.tokenizer.pad_token_id
model.config.decoder_start_token_id = processor.tokenizer.cls_token_id

# Keep a handle on the encoder's current forward before it is replaced below.
original_encoder_forward = model.encoder.forward

def safe_encoder_forward(self, pixel_values=None, **kwargs):
    """Call the saved encoder forward with ``pixel_values`` only.

    ``self`` and every extra keyword argument are accepted but ignored.
    NOTE(review): presumably this exists to drop keyword arguments the
    encoder does not accept once the model is wrapped by PEFT — confirm.
    """
    return original_encoder_forward(pixel_values=pixel_values)

# Replace the encoder's forward on this instance. The replacement accepts any
# keyword arguments and forwards only pixel_values to the saved forward.
model.encoder.forward = lambda pixel_values=None, **kwargs: safe_encoder_forward(model.encoder, pixel_values=pixel_values)

# Configure LoRA with more capacity for cross-lingual learning
logger.info("Configuring LoRA...")
peft_config = LoraConfig(
    task_type=TaskType.SEQ_2_SEQ_LM,
    r=16,
    lora_alpha=32,
    lora_dropout=0.1,
    # Attention projections and feed-forward layers receive LoRA adapters.
    target_modules=["q_proj", "k_proj", "v_proj", "out_proj", "fc1", "fc2"],
    bias="none",
)

model = get_peft_model(model, peft_config)
# print_trainable_parameters() prints its own summary and returns None, so it
# is called directly; interpolating it into a log message logged
# "Trainable parameters: None".
model.print_trainable_parameters()

class OCRDatasetWithLang(Dataset):
    """CSV-backed OCR dataset whose samples carry a language tag.

    The CSV must contain an ``image`` column (image path) and a ``text``
    column (transcription). Rows missing either value are dropped.
    """

    def __init__(self, csv_file, processor, language_id, max_target_length=128, image_dir=None):
        """Read the CSV and remember how samples should be encoded.

        Args:
            csv_file: Path to the CSV file.
            processor: Object used to encode images and tokenize text.
            language_id: Language tag attached to every sample; ``'ro'``
                enables the Romanian text and image preprocessing.
            max_target_length: Length labels are padded/truncated to.
            image_dir: Optional directory prepended to relative image paths.
        """
        self.df = pd.read_csv(csv_file)
        self.processor, self.language_id = processor, language_id
        self.max_target_length, self.image_dir = max_target_length, image_dir

        assert "image" in self.df.columns, f"'image' column missing in {csv_file}"
        assert "text" in self.df.columns, f"'text' column missing in {csv_file}"

        self.df = self.df.dropna(subset=["image", "text"])

    def __len__(self):
        """Return the number of rows currently held in ``self.df``."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return the encoded sample at position ``idx``, or None on failure.

        Returns:
            A dict with ``pixel_values``, ``labels`` and ``language_id``.
            If loading or encoding raises, the error is logged and None is
            returned.
        """
        row = self.df.iloc[idx]
        img_path = row["image"]
        text = str(row["text"])

        is_romanian = self.language_id == 'ro'
        if is_romanian:
            text = normalize_romanian_text(text)

        # Relative paths are resolved against image_dir when one was given.
        if self.image_dir and not os.path.isabs(img_path):
            img_path = os.path.join(self.image_dir, img_path)

        try:
            image = Image.open(img_path).convert("RGB")
            if is_romanian:
                image = preprocess_image_for_romanian(image)

            encoded_image = self.processor(image, return_tensors="pt")
            pixel_values = encoded_image.pixel_values.squeeze(0)

            tokenized = self.processor.tokenizer(
                text,
                padding="max_length",
                max_length=self.max_target_length,
                truncation=True,
            )
            sample = {
                "pixel_values": pixel_values,
                "labels": torch.tensor(tokenized.input_ids),
                "language_id": self.language_id,
            }
        except Exception as e:
            logger.error(f"Error processing {img_path}: {e}")
            return None

        return sample

class CombinedOCRDataset(Dataset):
    """Concatenation of a Romanian dataset followed by an English dataset.

    Indices ``0 .. len(ro_dataset) - 1`` address the Romanian dataset; the
    remaining indices address the English dataset. The combined length is
    computed once, at construction time.
    """

    def __init__(self, ro_dataset, en_dataset):
        """Store both datasets and cache their combined length."""
        self.ro_dataset, self.en_dataset = ro_dataset, en_dataset
        self.total_len = len(ro_dataset) + len(en_dataset)

    def __len__(self):
        """Return the combined length cached in ``__init__``."""
        return self.total_len

    def __getitem__(self, idx):
        """Return the sample at ``idx`` from whichever dataset owns it."""
        en_offset = idx - len(self.ro_dataset)
        if en_offset >= 0:
            return self.en_dataset[en_offset]
        return self.ro_dataset[idx]

def collate_fn(batch):
    """Collate dataset samples into one batch, skipping failed samples.

    Args:
        batch: List of sample dicts; entries that are None are dropped.

    Returns:
        None when no sample survives. Otherwise a dict with stacked
        ``pixel_values`` and ``labels`` tensors, plus a ``language_id`` list
        when every surviving sample carries that key.
    """
    samples = [sample for sample in batch if sample is not None]
    if len(samples) == 0:
        return None

    collated = {
        tensor_key: torch.stack([sample[tensor_key] for sample in samples])
        for tensor_key in ("pixel_values", "labels")
    }

    # Language tags stay a plain Python list; they are only attached when
    # every sample has one.
    if all('language_id' in sample for sample in samples):
        collated['language_id'] = [sample['language_id'] for sample in samples]

    return collated

# Create datasets
logger.info("Creating datasets...")
IMAGE_DIR = None

# One dataset per split, Romanian first and then English.
train_ro_dataset, val_ro_dataset, test_ro_dataset = (
    OCRDatasetWithLang(split_csv, processor, 'ro', image_dir=IMAGE_DIR)
    for split_csv in (RO_TRAIN_CSV, RO_VAL_CSV, RO_TEST_CSV)
)
train_en_dataset, val_en_dataset, test_en_dataset = (
    OCRDatasetWithLang(split_csv, processor, 'en', image_dir=IMAGE_DIR)
    for split_csv in (EN_TRAIN_CSV, EN_VAL_CSV, EN_TEST_CSV)
)

if MAX_SAMPLES_PER_LANGUAGE is not None:
    # Also limit validation sets: 10% of the training cap, but at least 50 rows.
    val_limit = max(50, int(MAX_SAMPLES_PER_LANGUAGE * 0.1))
    # Keep only the first rows of each train/validation frame; test sets are untouched.
    for _subset, _row_cap in (
        (train_ro_dataset, MAX_SAMPLES_PER_LANGUAGE),
        (train_en_dataset, MAX_SAMPLES_PER_LANGUAGE),
        (val_ro_dataset, val_limit),
        (val_en_dataset, val_limit),
    ):
        _subset.df = _subset.df.head(_row_cap)

logger.info(f"Romanian training samples: {len(train_ro_dataset)}")
logger.info(f"English training samples: {len(train_en_dataset)}")

# Combine datasets using our custom combiner (Romanian rows first, then English)
combined_train_dataset = CombinedOCRDataset(train_ro_dataset, train_en_dataset)
combined_val_dataset = CombinedOCRDataset(val_ro_dataset, val_en_dataset)

logger.info(f"Combined training samples: {len(combined_train_dataset)}")
logger.info(f"Combined validation samples: {len(combined_val_dataset)}")

# Training arguments
training_args = Seq2SeqTrainingArguments(
    # -- bookkeeping --
    output_dir=OUTPUT_DIR,
    report_to="none",
    logging_steps=50,  # more frequent logging
    # -- optimisation --
    num_train_epochs=5,
    learning_rate=2e-5,
    weight_decay=0.01,
    max_grad_norm=1.0,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=4,  # 2 samples x 4 accumulation steps per update
    # -- mixed precision (only when CUDA is available) --
    fp16=torch.cuda.is_available(),
    fp16_opt_level="O1",
    fp16_full_eval=False,
    # -- evaluation and checkpointing, both every 500 steps --
    eval_strategy="steps",
    eval_steps=500,
    save_strategy="steps",
    save_steps=500,
    save_total_limit=3,
    predict_with_generate=True,
    # -- best checkpoint = lowest CER --
    load_best_model_at_end=True,
    metric_for_best_model="cer",
    greater_is_better=False,
)

def compute_metrics(pred):
    """Compute character and word error rates for a prediction batch.

    Args:
        pred: Object exposing ``predictions`` (generated token ids) and
            ``label_ids`` (reference token ids). Negative ids in either
            array are replaced by the pad token id before decoding.

    Returns:
        ``{"cer": float, "wer": float}``. When the metrics cannot be loaded
        or computed, the worst score (1.0) is returned for both. The previous
        0.0 fallback looked like a perfect result and would win
        best-checkpoint selection, since lower CER is better.
    """
    # `evaluate` is imported lazily, as in the original cell.
    try:
        from evaluate import load
        cer_metric = load("cer")
        wer_metric = load("wer")
    except (ImportError, ModuleNotFoundError):
        logger.warning("Could not load metrics.")
        return {"cer": 1.0, "wer": 1.0}

    try:
        pad_token_id = processor.tokenizer.pad_token_id

        pred_ids = pred.predictions
        # Some models return a tuple; the token ids are the first element.
        if isinstance(pred_ids, tuple):
            pred_ids = pred_ids[0]

        # np.where builds new arrays, so the caller's arrays are not mutated.
        pred_ids_fixed = np.where(pred_ids < 0, pad_token_id, pred_ids)
        label_ids_fixed = np.where(pred.label_ids < 0, pad_token_id, pred.label_ids)

        pred_str = processor.batch_decode(pred_ids_fixed, skip_special_tokens=True)
        label_str = processor.batch_decode(label_ids_fixed, skip_special_tokens=True)

        cer = cer_metric.compute(predictions=pred_str, references=label_str)
        wer = wer_metric.compute(predictions=pred_str, references=label_str)

        return {"cer": cer, "wer": wer}
    except Exception as e:
        logger.error(f"Error computing metrics: {e}")
        return {"cer": 1.0, "wer": 1.0}

class BilingualTrainer(Seq2SeqTrainer):
    """Seq2SeqTrainer that handles the extra ``language_id`` batch field.

    * ``language_id`` is removed from the batch before it reaches the model,
      in both the training loss and the prediction/evaluation step.
    * Padding positions in ``labels`` are excluded from the training loss.
    * The training loss is scaled up in proportion to the share of Romanian
      samples in the batch.
    """

    # Loss multiplier for a batch made up entirely of Romanian samples.
    # Mixed batches are scaled linearly between 1.0 and this value.
    RO_SAMPLE_WEIGHT = 2.5

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Return the (optionally Romanian-weighted) loss for one batch.

        Args:
            model: Model to run.
            inputs: Batch dict from the collator; may contain ``language_id``.
            return_outputs: When True, return ``(loss, outputs)``.
            num_items_in_batch: Accepted for signature compatibility; unused.
        """
        model_inputs = dict(inputs)
        language_ids = model_inputs.pop('language_id', None)

        # Labels are padded to a fixed length with the pad token. Replace the
        # padding with -100 (the loss ignore index) so the loss covers real
        # text tokens only.
        labels = model_inputs.get("labels")
        if labels is not None:
            pad_token_id = processor.tokenizer.pad_token_id
            model_inputs["labels"] = labels.masked_fill(labels == pad_token_id, -100)

        outputs = model(**model_inputs)
        loss = outputs.loss

        # Apply weight to Romanian samples if language IDs are available
        if language_ids is not None and 'ro' in language_ids:
            ro_ratio = language_ids.count('ro') / len(language_ids)
            weighted_factor = 1.0 + (ro_ratio * (self.RO_SAMPLE_WEIGHT - 1.0))
            loss = loss * weighted_factor
            # DEBUG rather than INFO: this runs on every training step.
            logger.debug(f"Batch with {ro_ratio*100:.1f}% Romanian samples, loss weight: {weighted_factor:.2f}")

        return (loss, outputs) if return_outputs else loss

    def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None, **gen_kwargs):
        """Run the parent prediction step without the ``language_id`` field.

        The parent implementation forwards the whole batch to the model, and
        ``language_id`` (a list of strings) is not a model argument.
        """
        model_inputs = {key: value for key, value in inputs.items() if key != 'language_id'}
        return super().prediction_step(
            model,
            model_inputs,
            prediction_loss_only,
            ignore_keys=ignore_keys,
            **gen_kwargs,
        )

# `processing_class` replaces the deprecated `tokenizer` argument in the
# transformers release this notebook runs on (4.51.3 per the cell output).
trainer = BilingualTrainer(
    model=model,
    args=training_args,
    train_dataset=combined_train_dataset,
    eval_dataset=combined_val_dataset,
    processing_class=processor.tokenizer,
    compute_metrics=compute_metrics,
    data_collator=collate_fn,
)
def evaluate_model(model, test_dataset, language, batch_size=4, max_length=64):
    """Generate transcriptions for a test set and report CER / WER.

    Args:
        model: Model exposing ``eval()`` and ``generate()``.
        test_dataset: Dataset compatible with ``collate_fn``.
        language: Label used in the progress bar and log messages.
        batch_size: Evaluation batch size (default 4, the previous constant).
        max_length: Maximum generated sequence length (default 64, the
            previous constant). Raise it if references are longer, otherwise
            predictions are cut short.

    Returns:
        ``(cer, wer)``. ``(1.0, 1.0)`` is returned when the metrics cannot be
        loaded or when no prediction could be generated.
    """
    test_dataloader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        collate_fn=collate_fn,
        shuffle=False
    )

    model.eval()

    try:
        from evaluate import load
        cer_metric = load("cer")
        wer_metric = load("wer")
    except ImportError:
        logger.error("Could not load metrics. Please install jiwer: pip install jiwer")
        return 1.0, 1.0

    pad_token_id = processor.tokenizer.pad_token_id
    all_preds = []
    all_labels = []
    skipped_batches = 0

    with torch.no_grad():
        for batch in tqdm(test_dataloader, desc=f"Evaluating {language}"):
            # collate_fn returns None when every sample in the batch failed.
            if batch is None:
                skipped_batches += 1
                continue

            pixel_values = batch["pixel_values"].to(device)
            labels = batch["labels"]
            # Negative ids (e.g. a loss ignore index) cannot be decoded;
            # treat them as padding, as compute_metrics does.
            labels = labels.masked_fill(labels < 0, pad_token_id)

            try:
                generated_ids = model.generate(pixel_values=pixel_values, max_length=max_length)

                pred_str = processor.batch_decode(generated_ids, skip_special_tokens=True)
                label_str = processor.batch_decode(labels, skip_special_tokens=True)

                all_preds.extend(pred_str)
                all_labels.extend(label_str)
            except Exception as e:
                logger.error(f"Error during evaluation: {e}")
                skipped_batches += 1
                continue

    # Skipped batches are excluded from the scores, so report how many.
    if skipped_batches:
        logger.warning(f"{language}: {skipped_batches} batch(es) were skipped and are not reflected in the scores")

    if all_preds:
        cer = cer_metric.compute(predictions=all_preds, references=all_labels)
        wer = wer_metric.compute(predictions=all_preds, references=all_labels)

        logger.info(f"{language} Test CER: {cer:.4f}")
        logger.info(f"{language} Test WER: {wer:.4f}")

        return cer, wer
    else:
        logger.error(f"No predictions were generated for {language}")
        return 1.0, 1.0

# Evaluate and log the results
# if os.path.exists(OUTPUT_DIR):
#     logger.info("Evaluating on test sets...")
#     ro_cer, ro_wer = evaluate_model(model, test_ro_dataset, "Romanian")
#     en_cer, en_wer = evaluate_model(model, test_en_dataset, "English")
# 
#     logger.info("Training and evaluation complete!")
#     logger.info(f"Final Romanian CER: {ro_cer:.4f}, WER: {ro_wer:.4f}")
#     logger.info(f"Final English CER: {en_cer:.4f}, WER: {en_wer:.4f}")

INFO:__main__:Using device: cuda
INFO:__main__:Loading processor and model...
Config of the encoder: <class 'transformers.models.vit.modeling_vit.ViTModel'> is overwritten by shared encoder config: ViTConfig {
  "attention_probs_dropout_prob": 0.0,
  "encoder_stride": 16,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.0,
  "hidden_size": 768,
  "image_size": 384,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "model_type": "vit",
  "num_attention_heads": 12,
  "num_channels": 3,
  "num_hidden_layers": 12,
  "patch_size": 16,
  "pooler_act": "tanh",
  "pooler_output_size": 768,
  "qkv_bias": false,
  "torch_dtype": "float32",
  "transformers_version": "4.51.3"
}

Config of the decoder: <class 'transformers.models.trocr.modeling_trocr.TrOCRForCausalLM'> is overwritten by shared decoder config: TrOCRConfig {
  "activation_dropout": 0.0,
  "activation_function": "gelu",
  "add_cross_attention": true,
  "attention_dropout": 0.0,
  "bos_token_id": 

trainable params: 5,013,504 || all params: 338,935,296 || trainable%: 1.4792


In [17]:
def load_and_use_model(model_path, base_model_name="microsoft/trocr-base-handwritten"):
    """Load a saved LoRA adapter on top of its base TrOCR model.

    This function only loads the model; it does not run inference.

    Args:
        model_path: Directory containing the saved adapter. It is converted
            to an absolute path and loaded with ``local_files_only=True``.
        base_model_name: Checkpoint used for both the processor and the base
            model. Defaults to the checkpoint that was previously hard-coded.

    Returns:
        ``(model, processor)`` with the model moved to ``device`` and put in
        evaluation mode.
    """
    # Ensure we're using a local path
    model_path = os.path.abspath(model_path)

    processor = TrOCRProcessor.from_pretrained(base_model_name)
    base_model = VisionEncoderDecoderModel.from_pretrained(base_model_name)

    # Set required token IDs
    base_model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
    base_model.config.pad_token_id = processor.tokenizer.pad_token_id

    # Same encoder patch as in the training cell: the replacement forward
    # accepts any keyword arguments and passes only pixel_values on.
    original_encoder_forward = base_model.encoder.forward
    def safe_encoder_forward(self, pixel_values=None, **kwargs):
        return original_encoder_forward(pixel_values=pixel_values)
    base_model.encoder.forward = lambda pixel_values=None, **kwargs: safe_encoder_forward(base_model.encoder, pixel_values=pixel_values)

    model = PeftModel.from_pretrained(
        base_model,
        model_path,
        local_files_only=True
    )
    model.to(device)
    # Inference only: make sure dropout layers are disabled.
    model.eval()

    return model, processor

def ocr_inference(model, processor, image_path, is_romanian=False, max_length=64):
    """Perform OCR inference on a single image.

    Args:
        model: Model exposing ``generate()``.
        processor: Processor used to encode the image and decode token ids.
        image_path: Path of the image file to transcribe.
        is_romanian: When True, apply the Romanian image preprocessing before
            inference and the Romanian text normalization afterwards.
        max_length: Maximum generated sequence length (default 64, the
            previous constant).

    Returns:
        The decoded text for the image.
    """
    # Load the image; the context manager releases the file handle as soon
    # as the RGB copy has been made.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert("RGB")

    # Apply Romanian-specific preprocessing if needed
    if is_romanian:
        image = preprocess_image_for_romanian(image)

    # Get model predictions
    pixel_values = processor(image, return_tensors="pt").pixel_values.to(device)
    generated_ids = model.generate(pixel_values=pixel_values, max_length=max_length)
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

    # Apply Romanian normalization if needed
    if is_romanian:
        generated_text = normalize_romanian_text(generated_text)

    return generated_text



In [22]:
# Load the saved adapter from disk and transcribe one sample image.
# Note: this rebinds the notebook-level `model` and `processor` names.
model, processor = load_and_use_model(model_path="models/trocr_lora_bilingual_improved")
text = ocr_inference(model, processor, image_path="img.png", is_romanian=True)
print(f"OCR Result: {text}")

Config of the encoder: <class 'transformers.models.vit.modeling_vit.ViTModel'> is overwritten by shared encoder config: ViTConfig {
  "attention_probs_dropout_prob": 0.0,
  "encoder_stride": 16,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.0,
  "hidden_size": 768,
  "image_size": 384,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "model_type": "vit",
  "num_attention_heads": 12,
  "num_channels": 3,
  "num_hidden_layers": 12,
  "patch_size": 16,
  "pooler_act": "tanh",
  "pooler_output_size": 768,
  "qkv_bias": false,
  "torch_dtype": "float32",
  "transformers_version": "4.51.3"
}

Config of the decoder: <class 'transformers.models.trocr.modeling_trocr.TrOCRForCausalLM'> is overwritten by shared decoder config: TrOCRConfig {
  "activation_dropout": 0.0,
  "activation_function": "gelu",
  "add_cross_attention": true,
  "attention_dropout": 0.0,
  "bos_token_id": 0,
  "classifier_dropout": 0.0,
  "cross_attention_hidden_size": 768,
  "d_mod

OCR Result: Stone- Grea


Încă nu funcționează bine; trebuie să adaug scris real la db.