# Custom OCR Model Training Notebook

This notebook demonstrates how to train a custom OCR model using deep learning.
We'll build a CNN-LSTM architecture with CTC loss for end-to-end text recognition.

## 🎯 Learning Objectives:
- Understand deep learning-based OCR architecture
- Learn about CTC (Connectionist Temporal Classification) loss
- Train a custom OCR model from scratch
- Evaluate model performance with proper metrics
- Compare custom model with pretrained solutions

## 🏗️ Model Architecture:
1. **CNN Feature Extractor**: Extracts visual features from images
2. **LSTM Sequence Processor**: Models sequential dependencies
3. **CTC Head**: Emits per-timestep character scores, so variable-length text is learned without character-level alignment
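
To make the three stages concrete, here is a minimal sketch of a CNN-LSTM-CTC network. `TinyCRNN` and its layer sizes are hypothetical and chosen only for illustration; the model actually built by `create_model` in this project may be structured differently.

```python
import torch
import torch.nn as nn


class TinyCRNN(nn.Module):
    """Hypothetical minimal CRNN for illustration; not the project's create_model()."""

    def __init__(self, num_classes: int, img_height: int = 32, hidden_size: int = 64):
        super().__init__()
        # 1. CNN feature extractor: two conv blocks, each halving height and width.
        self.cnn = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(32, 64, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
        )
        feature_height = img_height // 4
        # 2. Bidirectional LSTM reads the feature map one column at a time.
        self.lstm = nn.LSTM(64 * feature_height, hidden_size,
                            bidirectional=True, batch_first=True)
        # 3. CTC head: per-timestep class scores (index 0 is reserved for blank).
        self.head = nn.Linear(2 * hidden_size, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        feats = self.cnn(x)                                    # [B, C, H', W']
        b, c, h, w = feats.shape
        seq = feats.permute(0, 3, 1, 2).reshape(b, w, c * h)   # [B, W', C*H']
        out, _ = self.lstm(seq)                                # [B, W', 2*hidden]
        return self.head(out).log_softmax(dim=-1)              # [B, W', num_classes]


tiny_model = TinyCRNN(num_classes=40)
log_probs = tiny_model(torch.randn(2, 1, 32, 128))
print(log_probs.shape)  # torch.Size([2, 32, 40])
```

With a 32x128 input and two 2x2 poolings, the width shrinks to 32 columns, so the network emits 32 time steps per image.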

## 📋 Prerequisites:
- PyTorch and related libraries installed
- Training dataset prepared
- GPU recommended for faster training

## 1. Import Required Libraries

First, let's import all necessary libraries for deep learning and OCR training.

In [None]:
# Standard libraries
import os
import sys
import time
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')

# Data handling
import pandas as pd
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import seaborn as sns

# Deep learning
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

# Progress tracking
from tqdm.notebook import tqdm

# Add project root to path
project_root = Path.cwd().parent
sys.path.append(str(project_root))

# Import custom modules
from scripts.custom_model import create_model, count_parameters
from utils.dataset import CharacterMapping, OCRDataset, ctc_collate_fn, create_sample_dataset
from utils.metrics import calculate_detailed_metrics, print_metrics_report, MetricsTracker

print("✅ All libraries imported successfully!")
print(f"📁 Project root: {project_root}")
print(f"🔧 PyTorch version: {torch.__version__}")
print(f"💻 CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"🎮 GPU: {torch.cuda.get_device_name(0)}")

## 2. Configuration and Setup

Let's define our training configuration and set up paths.

In [None]:
# Training configuration
config = {
    # Data paths
    'train_csv': project_root / 'data' / 'train' / 'dataset.csv',
    'val_csv': project_root / 'data' / 'val' / 'dataset.csv',
    'train_dir': project_root / 'data' / 'train',
    'val_dir': project_root / 'data' / 'val',
    
    # Model parameters
    'img_height': 32,
    'img_width': 128,
    'lstm_hidden_size': 256,
    'lstm_num_layers': 2,
    
    # Training parameters
    'num_epochs': 50,
    'batch_size': 16,  # Smaller batch size for notebook
    'learning_rate': 0.001,
    'weight_decay': 1e-4,
    'optimizer': 'adam',
    
    # Other parameters
    'num_workers': 2,
    'device': 'cuda' if torch.cuda.is_available() else 'cpu'
}

# Create output directories
output_dirs = {
    'models': project_root / 'models',
    'checkpoints': project_root / 'models' / 'checkpoints',
    'logs': project_root / 'models' / 'logs',
    'results': project_root / 'results'
}

for name, path in output_dirs.items():
    path.mkdir(parents=True, exist_ok=True)
    print(f"📁 {name.capitalize()} directory: {path}")

device = torch.device(config['device'])
print(f"\n🔧 Using device: {device}")

## 3. Create Sample Dataset (if needed)

If you don't have a training dataset yet, the next cell generates a small synthetic one (200 training and 50 validation samples) for demonstration.

In [None]:
# Check if training data exists
if not config['train_csv'].exists():
    print("📊 Creating sample training dataset...")
    create_sample_dataset(str(config['train_dir']), num_samples=200)
    print("✅ Sample training dataset created!")
else:
    print("✅ Training dataset found!")

if not config['val_csv'].exists():
    print("📊 Creating sample validation dataset...")
    create_sample_dataset(str(config['val_dir']), num_samples=50)
    print("✅ Sample validation dataset created!")
else:
    print("✅ Validation dataset found!")

# Load and inspect datasets
train_df = pd.read_csv(config['train_csv'])
val_df = pd.read_csv(config['val_csv'])

print(f"\n📈 Dataset Statistics:")
print(f"Training samples: {len(train_df)}")
print(f"Validation samples: {len(val_df)}")
print(f"\n📋 Sample training data:")
print(train_df.head())

## 4. Initialize Character Mapping and Model

Create character mapping and initialize our custom OCR model.
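
As a rough illustration of what a character mapping for CTC does, the sketch below reserves index 0 for the blank symbol and maps each character to a positive index. `SimpleCharMap` and its character set are assumptions made for this example; the project's `CharacterMapping` class may use a different alphabet or handle unknown characters differently.

```python
class SimpleCharMap:
    """Hypothetical stand-in for CharacterMapping; index 0 is the CTC blank."""

    def __init__(self, characters: str = "abcdefghijklmnopqrstuvwxyz0123456789 "):
        self.characters = characters
        # Real characters start at 1 so that 0 stays free for the blank.
        self.char_to_idx = {ch: i + 1 for i, ch in enumerate(characters)}
        self.idx_to_char = {i: ch for ch, i in self.char_to_idx.items()}
        self.num_classes = len(characters) + 1  # +1 for the blank

    def encode(self, text):
        # Characters outside the alphabet are silently skipped in this sketch.
        return [self.char_to_idx[ch] for ch in text.lower() if ch in self.char_to_idx]

    def decode(self, indices):
        return "".join(self.idx_to_char[i] for i in indices if i != 0)


cmap = SimpleCharMap()
ids = cmap.encode("hello world 123")
print(ids[:5])           # [8, 5, 12, 12, 15]
print(cmap.decode(ids))  # hello world 123
```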

In [None]:
# Initialize character mapping
char_mapping = CharacterMapping()
print(f"📝 Character mapping created with {char_mapping.num_classes} classes")
print(f"Characters: {char_mapping.characters}")

# Test character mapping
test_text = "hello world 123"
encoded = char_mapping.encode(test_text)
decoded = char_mapping.decode(encoded)
print(f"\n🧪 Character mapping test:")
print(f"Original: '{test_text}'")
print(f"Encoded: {encoded}")
print(f"Decoded: '{decoded}'")

# Create model
model = create_model(
    num_classes=char_mapping.num_classes,
    img_height=config['img_height'],
    img_width=config['img_width'],
    lstm_hidden_size=config['lstm_hidden_size'],
    lstm_num_layers=config['lstm_num_layers']
)

model = model.to(device)
print(f"\n🧠 Model created with {count_parameters(model):,} trainable parameters")

# Test model with dummy input
dummy_input = torch.randn(2, 1, config['img_height'], config['img_width']).to(device)
with torch.no_grad():
    dummy_output = model(dummy_input)
print(f"🔍 Model test - Input: {dummy_input.shape}, Output: {dummy_output.shape}")

## 5. Create Data Loaders

Set up data loaders for training and validation.
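
CTC training needs more than images and labels: the loss also requires the length of each label and the number of time steps the model emits per image. The sketch below shows one way a collate function could assemble such a batch. `simple_ctc_collate`, the per-sample keys (`image`, `label`, `text`), and the `downsample` factor are assumptions for illustration; the project's `ctc_collate_fn` may differ.

```python
import torch


def simple_ctc_collate(samples, downsample=4):
    """Hypothetical collate function; sample keys and `downsample` are assumptions."""
    images = torch.stack([s["image"] for s in samples])  # [B, 1, H, W]
    label_lengths = torch.tensor([len(s["label"]) for s in samples], dtype=torch.long)

    # Pad every label to the longest one in the batch. CTCLoss only reads the
    # first label_lengths[i] entries of row i, so the padding value is ignored.
    labels = torch.zeros(len(samples), int(label_lengths.max()), dtype=torch.long)
    for i, s in enumerate(samples):
        labels[i, : len(s["label"])] = s["label"]

    # All images share one width, so every sequence has the same number of steps.
    seq_len = images.shape[-1] // downsample
    input_lengths = torch.full((len(samples),), seq_len, dtype=torch.long)

    return {
        "images": images,
        "labels": labels,
        "label_lengths": label_lengths,
        "input_lengths": input_lengths,
        "texts": [s["text"] for s in samples],
    }


samples = [
    {"image": torch.zeros(1, 32, 128), "label": torch.tensor([3, 1, 20]), "text": "cat"},
    {"image": torch.zeros(1, 32, 128), "label": torch.tensor([8, 15, 18, 19, 5]), "text": "horse"},
]
batch = simple_ctc_collate(samples)
print(batch["images"].shape, batch["labels"].shape, batch["input_lengths"])
```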

In [None]:
# Create datasets
train_dataset = OCRDataset(
    csv_file=str(config['train_csv']),
    image_dir=str(config['train_dir']),
    char_mapping=char_mapping,
    img_height=config['img_height'],
    img_width=config['img_width'],
    is_training=True
)

val_dataset = OCRDataset(
    csv_file=str(config['val_csv']),
    image_dir=str(config['val_dir']),
    char_mapping=char_mapping,
    img_height=config['img_height'],
    img_width=config['img_width'],
    is_training=False
)

# Create data loaders
train_loader = DataLoader(
    train_dataset,
    batch_size=config['batch_size'],
    shuffle=True,
    num_workers=config['num_workers'],
    collate_fn=ctc_collate_fn,
    pin_memory=(device.type == 'cuda')  # pinning only helps when copying to a GPU
)

val_loader = DataLoader(
    val_dataset,
    batch_size=config['batch_size'],
    shuffle=False,
    num_workers=config['num_workers'],
    collate_fn=ctc_collate_fn,
    pin_memory=(device.type == 'cuda')  # pinning only helps when copying to a GPU
)

print(f"📦 Data loaders created:")
print(f"Training batches: {len(train_loader)}")
print(f"Validation batches: {len(val_loader)}")

# Visualize sample batch
sample_batch = next(iter(train_loader))
print(f"\n🔍 Sample batch:")
print(f"Images shape: {sample_batch['images'].shape}")
print(f"Labels shape: {sample_batch['labels'].shape}")
print(f"Sample texts: {sample_batch['texts'][:3]}")

## 6. Visualize Sample Data

Let's look at some sample images from our dataset.

In [None]:
# Visualize sample images
fig, axes = plt.subplots(2, 3, figsize=(15, 8))
axes = axes.flatten()

sample_batch = next(iter(train_loader))
images = sample_batch['images']
texts = sample_batch['texts']

for i in range(min(6, len(images))):
    # Convert tensor to numpy and denormalize
    img = images[i].squeeze().numpy()
    img = (img * 0.5) + 0.5  # Denormalize from [-1,1] to [0,1]
    
    axes[i].imshow(img, cmap='gray')
    axes[i].set_title(f"Text: '{texts[i]}'")
    axes[i].axis('off')

plt.tight_layout()
plt.suptitle('Sample Training Images', fontsize=16, y=1.02)
plt.show()

# Show dataset statistics
text_lengths = [len(text) for text in train_df['label'].astype(str)]  # astype(str) guards against purely numeric labels
plt.figure(figsize=(10, 4))
plt.subplot(1, 2, 1)
plt.hist(text_lengths, bins=20, alpha=0.7)
plt.title('Distribution of Text Lengths')
plt.xlabel('Number of Characters')
plt.ylabel('Frequency')

plt.subplot(1, 2, 2)
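char_freq = {}
for text in train_df['label'].astype(str):
    for char in text.lower():
        char_freq[char] = char_freq.get(char, 0) + 1

# Sort by frequency so the plot really shows the 20 most common characters
top_chars = sorted(char_freq.items(), key=lambda item: item[1], reverse=True)[:20]
chars = [char for char, _ in top_chars]
freqs = [count for _, count in top_chars]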
plt.bar(chars, freqs)
plt.title('Character Frequency (Top 20)')
plt.xlabel('Characters')
plt.ylabel('Frequency')
plt.xticks(rotation=45)

plt.tight_layout()
plt.show()

## 7. Initialize Training Components

Set up optimizer, loss function, and metrics tracking.
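
Before wiring the loss into the training loop, it helps to see the tensor shapes `nn.CTCLoss` expects. The sketch below uses random data only; the sizes `T`, `N`, `C`, and `S` are arbitrary values picked for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
T, N, C, S = 32, 4, 38, 10  # time steps, batch size, classes (incl. blank), max label length

# CTCLoss expects log-probabilities shaped [T, N, C] (time first, not batch first).
log_probs = torch.randn(T, N, C).log_softmax(dim=2).detach().requires_grad_()

# Targets are padded label indices; 0 is the blank, so real labels start at 1.
targets = torch.randint(low=1, high=C, size=(N, S), dtype=torch.long)
input_lengths = torch.full((N,), T, dtype=torch.long)
target_lengths = torch.randint(low=5, high=S + 1, size=(N,), dtype=torch.long)

ctc = nn.CTCLoss(blank=0, reduction="mean", zero_infinity=True)
loss = ctc(log_probs, targets, input_lengths, target_lengths)
loss.backward()
print(loss.item())
```

Two details are easy to miss: the input must already be log-probabilities (for example from `log_softmax`), and each input length must be at least as long as its target, otherwise the loss becomes infinite (which `zero_infinity=True` then replaces with zero).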

In [None]:
# Initialize optimizer
optimizer = optim.Adam(
    model.parameters(),
    lr=config['learning_rate'],
    weight_decay=config['weight_decay']
)

# Initialize learning rate scheduler
scheduler = optim.lr_scheduler.StepLR(
    optimizer,
    step_size=20,
    gamma=0.5
)

# Initialize CTC loss function
criterion = nn.CTCLoss(blank=0, reduction='mean', zero_infinity=True)

# Initialize metrics tracking
metrics_tracker = MetricsTracker()
train_losses = []
val_losses = []
val_accuracies = []

# Initialize TensorBoard logging (optional)
log_dir = output_dirs['logs'] / f"experiment_{int(time.time())}"
writer = SummaryWriter(log_dir=str(log_dir))

print(f"🔧 Training components initialized:")
print(f"Optimizer: {config['optimizer']}")
print(f"Learning rate: {config['learning_rate']}")
print(f"Batch size: {config['batch_size']}")
print(f"Number of epochs: {config['num_epochs']}")
print(f"📊 TensorBoard logs: {log_dir}")

## 8. Training Loop

Now let's train our custom OCR model!

In [None]:
def train_epoch(model, train_loader, optimizer, criterion, device, epoch):
    """Train model for one epoch."""
    model.train()
    epoch_loss = 0.0
    num_batches = len(train_loader)
    
    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}")
    
    for batch_idx, batch in enumerate(pbar):
        # Move data to device
        images = batch['images'].to(device)
        labels = batch['labels'].to(device)
        label_lengths = batch['label_lengths'].to(device)
        input_lengths = batch['input_lengths'].to(device)
        
        # Zero gradients
        optimizer.zero_grad()
        
        # Forward pass
        outputs = model(images)
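        # nn.CTCLoss expects log-probabilities; this assumes the model's forward
        # pass already applies log_softmax over the class dimension.
        outputs = outputs.transpose(0, 1)  # [seq_len, batch_size, num_classes]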
        
        # Calculate loss
        loss = criterion(outputs, labels, input_lengths, label_lengths)
        
        # Backward pass
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        
        # Update metrics
        epoch_loss += loss.item()
        
        # Update progress bar
        pbar.set_postfix({
            'Loss': f'{loss.item():.4f}',
            'Avg': f'{epoch_loss/(batch_idx+1):.4f}'
        })
    
    return epoch_loss / num_batches


def validate_epoch(model, val_loader, criterion, char_mapping, device):
    """Validate model for one epoch."""
    model.eval()
    epoch_loss = 0.0
    all_predictions = []
    all_targets = []
    
    with torch.no_grad():
        for batch in tqdm(val_loader, desc="Validation"):
            images = batch['images'].to(device)
            labels = batch['labels'].to(device)
            label_lengths = batch['label_lengths'].to(device)
            input_lengths = batch['input_lengths'].to(device)
            texts = batch['texts']
            
            # Forward pass
            outputs = model(images)
            
            # Calculate loss
            outputs_for_loss = outputs.transpose(0, 1)
            loss = criterion(outputs_for_loss, labels, input_lengths, label_lengths)
            epoch_loss += loss.item()
            
            # Get predictions
            predictions = model.predict(images)
            
            # Decode predictions
            for i, pred in enumerate(predictions):
                pred_text = char_mapping.ctc_decode(pred.cpu().numpy())
                target_text = texts[i]
                
                all_predictions.append(pred_text)
                all_targets.append(target_text)
    
    avg_loss = epoch_loss / len(val_loader)
    metrics = calculate_detailed_metrics(all_predictions, all_targets)
    
    return avg_loss, metrics, all_predictions, all_targets


# Training loop
print("🚀 Starting training...")
best_val_accuracy = -1.0  # below any real accuracy, so the first epoch always saves a checkpoint
start_time = time.time()

for epoch in range(config['num_epochs']):
    # Training phase
    train_loss = train_epoch(model, train_loader, optimizer, criterion, device, epoch)
    train_losses.append(train_loss)
    
    # Validation phase
    val_loss, val_metrics, val_predictions, val_targets = validate_epoch(
        model, val_loader, criterion, char_mapping, device
    )
    val_losses.append(val_loss)
    val_accuracies.append(val_metrics['accuracy'])
    
    # Update metrics tracker
    metrics_tracker.update(epoch, val_predictions, val_targets)
    
    # Update learning rate
    scheduler.step()
    
    # Log to TensorBoard
    writer.add_scalar('Training/Loss', train_loss, epoch)
    writer.add_scalar('Validation/Loss', val_loss, epoch)
    writer.add_scalar('Validation/Accuracy', val_metrics['accuracy'], epoch)
    writer.add_scalar('Validation/CER', val_metrics['cer'], epoch)
    writer.add_scalar('Validation/WER', val_metrics['wer'], epoch)
    
    # Save best model
    if val_metrics['accuracy'] > best_val_accuracy:
        best_val_accuracy = val_metrics['accuracy']
        
        # Save checkpoint
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'best_val_accuracy': best_val_accuracy,
            'config': config,
            'char_mapping': char_mapping
        }
        
        torch.save(checkpoint, output_dirs['checkpoints'] / 'best_model.pth')
        print(f"💾 Saved best model with accuracy: {best_val_accuracy:.4f}")
    
    # Print epoch summary
    print(f"\nEpoch {epoch+1}/{config['num_epochs']}:")
    print(f"  Train Loss: {train_loss:.4f}")
    print(f"  Val Loss: {val_loss:.4f}")
    print(f"  Val Accuracy: {val_metrics['accuracy']:.4f}")
    print(f"  CER: {val_metrics['cer']:.4f}")
    print(f"  WER: {val_metrics['wer']:.4f}")
    print(f"  Best Accuracy: {best_val_accuracy:.4f}")
    print("-" * 50)

# Training completed
total_time = time.time() - start_time
print(f"\n🎉 Training completed in {total_time/60:.2f} minutes")
print(f"📊 Best validation accuracy: {best_val_accuracy:.4f}")

writer.close()

## 9. Visualize Training Progress

Let's plot the training and validation metrics.
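
A learning-rate curve can also be recorded directly from the scheduler rather than rebuilt by hand. The following is a minimal sketch using a dummy parameter and the same `StepLR` settings as this notebook (`step_size=20`, `gamma=0.5`); the names `opt`, `sched`, and `lr_log` are hypothetical and separate from the notebook's own variables.

```python
import torch
import torch.optim as optim

# A single dummy parameter stands in for the real model.
params = [torch.nn.Parameter(torch.zeros(1))]
opt = optim.Adam(params, lr=0.001)
sched = optim.lr_scheduler.StepLR(opt, step_size=20, gamma=0.5)

lr_log = []
for epoch in range(50):
    lr_log.append(sched.get_last_lr()[0])  # learning rate in effect for this epoch
    opt.step()    # placeholder for the real training epoch
    sched.step()  # advance the schedule once per epoch

print(lr_log[0], lr_log[20], lr_log[40])
```

Logging the value inside the training loop keeps the plot correct even if the scheduler type or its settings change later.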

In [None]:
# Plot training progress
epochs = range(1, len(train_losses) + 1)

fig, axes = plt.subplots(2, 2, figsize=(15, 10))

# Loss plot
axes[0, 0].plot(epochs, train_losses, 'b-', label='Training Loss')
axes[0, 0].plot(epochs, val_losses, 'r-', label='Validation Loss')
axes[0, 0].set_title('Training and Validation Loss')
axes[0, 0].set_xlabel('Epoch')
axes[0, 0].set_ylabel('Loss')
axes[0, 0].legend()
axes[0, 0].grid(True)

# Accuracy plot
axes[0, 1].plot(epochs, val_accuracies, 'g-', label='Validation Accuracy')
axes[0, 1].set_title('Validation Accuracy')
axes[0, 1].set_xlabel('Epoch')
axes[0, 1].set_ylabel('Accuracy')
axes[0, 1].legend()
axes[0, 1].grid(True)

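# Learning rate plot (rebuilds the StepLR schedule by hand: step_size=20, gamma=0.5;
# keep these values in sync with the scheduler defined in section 7)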
lr_history = []
for epoch in range(len(train_losses)):
    if epoch == 0:
        lr_history.append(config['learning_rate'])
    elif epoch % 20 == 0:
        lr_history.append(lr_history[-1] * 0.5)
    else:
        lr_history.append(lr_history[-1])

axes[1, 0].plot(epochs, lr_history, 'm-', label='Learning Rate')
axes[1, 0].set_title('Learning Rate Schedule')
axes[1, 0].set_xlabel('Epoch')
axes[1, 0].set_ylabel('Learning Rate')
axes[1, 0].legend()
axes[1, 0].grid(True)
axes[1, 0].set_yscale('log')

# Best metrics summary
best_metrics = metrics_tracker.get_best_metrics()
metrics_text = f"""Best Model Performance:
Accuracy: {best_metrics['accuracy']:.4f}
CER: {best_metrics['cer']:.4f}
WER: {best_metrics['wer']:.4f}
BLEU: {best_metrics['bleu_score']:.4f}
Epoch: {best_metrics['epoch']}

Training Statistics:
Total Epochs: {len(train_losses)}
Training Time: {total_time/60:.1f} min
Final Train Loss: {train_losses[-1]:.4f}
Final Val Loss: {val_losses[-1]:.4f}"""

axes[1, 1].text(0.1, 0.5, metrics_text, transform=axes[1, 1].transAxes, 
                fontsize=12, verticalalignment='center',
                bbox=dict(boxstyle='round', facecolor='lightblue', alpha=0.8))
axes[1, 1].set_title('Training Summary')
axes[1, 1].axis('off')

plt.tight_layout()
plt.savefig(output_dirs['results'] / 'training_progress.png', dpi=300, bbox_inches='tight')
plt.show()

# Print detailed final metrics
print_metrics_report(best_metrics, "Final Model Evaluation")

## 10. Test Model Predictions

Let's test our trained model on some validation samples.
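
Turning per-timestep predictions into text requires CTC decoding. A minimal sketch of greedy (best-path) decoding is shown here: collapse consecutive repeats, then drop blanks. `greedy_ctc_decode` and the tiny `idx_to_char` table are hypothetical; the project's `char_mapping.ctc_decode` may be implemented differently.

```python
def greedy_ctc_decode(indices, idx_to_char, blank=0):
    """Collapse repeated indices, then remove blanks (hypothetical helper)."""
    chars = []
    prev = blank
    for idx in indices:
        # Emit a character only when it is not blank and differs from the previous frame.
        if idx != blank and idx != prev:
            chars.append(idx_to_char[idx])
        prev = idx
    return "".join(chars)


idx_to_char = {1: "c", 2: "a", 3: "t", 4: "o"}

# Frame-wise argmax indices, e.g. from log_probs.argmax(dim=-1) for one image.
print(greedy_ctc_decode([0, 1, 1, 0, 2, 2, 2, 3, 0, 0], idx_to_char))  # cat
# A blank between two identical indices is what allows a doubled letter.
print(greedy_ctc_decode([3, 4, 0, 4], idx_to_char))  # too
print(greedy_ctc_decode([3, 4, 4], idx_to_char))     # to
```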

In [None]:
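# Load best model. The checkpoint stores config (Path objects) and char_mapping,
# so PyTorch >= 2.6 needs weights_only=False; only load checkpoints you trust.
best_checkpoint = torch.load(
    output_dirs['checkpoints'] / 'best_model.pth',
    map_location=device,
    weights_only=False
)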
model.load_state_dict(best_checkpoint['model_state_dict'])
model.eval()

# Get sample batch from validation set
sample_batch = next(iter(val_loader))
images = sample_batch['images'].to(device)
texts = sample_batch['texts']

# Get predictions
with torch.no_grad():
    predictions = model.predict(images)

# Decode predictions
pred_texts = []
for pred in predictions:
    pred_text = char_mapping.ctc_decode(pred.cpu().numpy())
    pred_texts.append(pred_text)

# Visualize predictions
fig, axes = plt.subplots(3, 2, figsize=(15, 12))
axes = axes.flatten()

for i in range(min(6, len(images))):
    # Convert tensor to numpy and denormalize
    img = images[i].squeeze().cpu().numpy()
    img = (img * 0.5) + 0.5  # Denormalize
    
    # Check if prediction is correct
    is_correct = pred_texts[i].lower().strip() == texts[i].lower().strip()
    color = 'green' if is_correct else 'red'
    status = '✓' if is_correct else '✗'
    
    axes[i].imshow(img, cmap='gray')
    axes[i].set_title(f"{status} Target: '{texts[i]}'\nPrediction: '{pred_texts[i]}'", 
                     color=color, fontsize=10)
    axes[i].axis('off')

plt.tight_layout()
plt.suptitle('Model Predictions on Validation Data', fontsize=16, y=1.02)
plt.savefig(output_dirs['results'] / 'sample_predictions.png', dpi=300, bbox_inches='tight')
plt.show()

# Calculate accuracy for this batch
correct = sum(1 for pred, target in zip(pred_texts, texts) 
              if pred.lower().strip() == target.lower().strip())
batch_accuracy = correct / len(pred_texts)
print(f"\n📊 Batch accuracy: {batch_accuracy:.4f} ({correct}/{len(pred_texts)})")

# Show some example predictions
print(f"\n📝 Example predictions:")
for i in range(min(5, len(pred_texts))):
    status = '✓' if pred_texts[i].lower().strip() == texts[i].lower().strip() else '✗'
    print(f"  {status} Target: '{texts[i]}' → Prediction: '{pred_texts[i]}'")

## 11. Compare with Pretrained Models

Let's compare our custom model with EasyOCR and Pytesseract on the same validation data (the first 20 validation images).
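
The comparison relies on character error rate (CER) and word error rate (WER). As a reference point, both can be computed from the Levenshtein edit distance, as sketched below. The helper names are hypothetical, and `calculate_detailed_metrics` in this project may normalize or preprocess text differently.

```python
def edit_distance(ref, hyp):
    """Levenshtein distance between two sequences (strings or lists of words)."""
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        curr = [i]
        for j, h in enumerate(hyp, start=1):
            cost = 0 if r == h else 1
            curr.append(min(prev[j] + 1,          # deletion
                            curr[j - 1] + 1,      # insertion
                            prev[j - 1] + cost))  # substitution or match
        prev = curr
    return prev[-1]


def cer(target, prediction):
    # Edits per target character; max(..., 1) avoids division by zero.
    return edit_distance(target, prediction) / max(len(target), 1)


def wer(target, prediction):
    ref_words = target.split()
    return edit_distance(ref_words, prediction.split()) / max(len(ref_words), 1)


print(cer("hello", "helo"))                # 0.2
print(wer("hello world", "hello word"))    # 0.5
```

Both rates can exceed 1.0 when the prediction is much longer than the target, which is worth keeping in mind when reading bar charts of CER and WER.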

In [None]:
# Import comparison modules (from original OCR project)
try:
    import easyocr
    import pytesseract
    from PIL import Image as PILImage
    
    # Initialize EasyOCR
    print("🔄 Initializing EasyOCR...")
    easy_reader = easyocr.Reader(['en'])
    
    # Get validation images for comparison
    val_image_paths = []
    val_target_texts = []
    
    for idx in range(min(20, len(val_dataset))):  # Test on first 20 images
        row = val_df.iloc[idx]
        image_path = config['val_dir'] / row['imagename']
        if image_path.exists():
            val_image_paths.append(str(image_path))
            val_target_texts.append(row['label'])
    
    # Get predictions from all three models
    print(f"\n🔍 Comparing models on {len(val_image_paths)} validation images...")
    
    # Custom model predictions
    custom_predictions = []
    for img_path in tqdm(val_image_paths, desc="Custom model"):
        # Preprocess image with the validation transforms (no training-time augmentation)
        img = PILImage.open(img_path).convert('L')
        img_array = np.array(img)
        transformed = val_dataset.transforms(image=img_array)
        img_tensor = transformed['image'].unsqueeze(0).to(device)
        
        # Predict
        with torch.no_grad():
            pred = model.predict(img_tensor)
            pred_text = char_mapping.ctc_decode(pred[0].cpu().numpy())
            custom_predictions.append(pred_text)
    
    # EasyOCR predictions
    easyocr_predictions = []
    for img_path in tqdm(val_image_paths, desc="EasyOCR"):
        try:
            results = easy_reader.readtext(img_path)
            text = ' '.join([result[1] for result in results])
            easyocr_predictions.append(text.strip())
        except Exception:  # avoid a bare except, which would also swallow KeyboardInterrupt
            easyocr_predictions.append("")
    
    # Pytesseract predictions
    pytesseract_predictions = []
    for img_path in tqdm(val_image_paths, desc="Pytesseract"):
        try:
            img = PILImage.open(img_path)
            text = pytesseract.image_to_string(img, lang='eng')
            pytesseract_predictions.append(text.strip())
        except Exception:  # avoid a bare except, which would also swallow KeyboardInterrupt
            pytesseract_predictions.append("")
    
    # Calculate metrics for all models
    custom_metrics = calculate_detailed_metrics(custom_predictions, val_target_texts)
    easyocr_metrics = calculate_detailed_metrics(easyocr_predictions, val_target_texts)
    pytesseract_metrics = calculate_detailed_metrics(pytesseract_predictions, val_target_texts)
    
    # Create comparison table
    comparison_df = pd.DataFrame({
        'Model': ['Custom CRNN', 'EasyOCR', 'Pytesseract'],
        'Accuracy': [custom_metrics['accuracy'], easyocr_metrics['accuracy'], pytesseract_metrics['accuracy']],
        'CER': [custom_metrics['cer'], easyocr_metrics['cer'], pytesseract_metrics['cer']],
        'WER': [custom_metrics['wer'], easyocr_metrics['wer'], pytesseract_metrics['wer']],
        'BLEU': [custom_metrics['bleu_score'], easyocr_metrics['bleu_score'], pytesseract_metrics['bleu_score']]
    })
    
    print("\n🏆 Model Comparison Results:")
    print("=" * 60)
    print(comparison_df.to_string(index=False, float_format='%.4f'))
    
    # Visualize comparison
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    
    metrics = ['Accuracy', 'CER', 'WER']
    for i, metric in enumerate(metrics):
        values = comparison_df[metric]
        bars = axes[i].bar(comparison_df['Model'], values, 
                          color=['blue', 'green', 'orange'], alpha=0.7)
        axes[i].set_title(f'{metric} Comparison')
        axes[i].set_ylabel(metric)
        
        # Add value labels on bars
        for bar, value in zip(bars, values):
            axes[i].text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01,
                        f'{value:.3f}', ha='center', va='bottom')
    
    plt.tight_layout()
    plt.savefig(output_dirs['results'] / 'model_comparison.png', dpi=300, bbox_inches='tight')
    plt.show()
    
    # Save detailed comparison
    detailed_comparison = pd.DataFrame({
        'image_path': val_image_paths,
        'target': val_target_texts,
        'custom_prediction': custom_predictions,
        'easyocr_prediction': easyocr_predictions,
        'pytesseract_prediction': pytesseract_predictions
    })
    detailed_comparison.to_csv(output_dirs['results'] / 'detailed_comparison.csv', index=False)
    print(f"\n💾 Detailed comparison saved to: {output_dirs['results'] / 'detailed_comparison.csv'}")
    
except ImportError as e:
    print(f"⚠️ Cannot compare with pretrained models: {e}")
    print("Install easyocr and pytesseract to enable comparison")

## 12. Model Analysis and Insights

Let's analyze our model's performance and provide insights.
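
Per-character error counts depend on how target and prediction are aligned. One option is to align them with `difflib.SequenceMatcher` from the standard library, so that a single dropped character does not shift every later position. The sketch below is an illustration under that assumption; `char_error_counts` is a hypothetical helper, and it attributes only substituted or deleted target characters (extra predicted characters are not counted).

```python
from collections import Counter
from difflib import SequenceMatcher


def char_error_counts(target, prediction):
    """Count, per target character, how often it was replaced or dropped."""
    totals = Counter(target)
    errors = Counter()
    matcher = SequenceMatcher(None, target, prediction, autojunk=False)
    for tag, i1, i2, _j1, _j2 in matcher.get_opcodes():
        # 'replace' and 'delete' cover target characters that were not reproduced.
        if tag in ("replace", "delete"):
            errors.update(target[i1:i2])
    return totals, errors


totals, errors = char_error_counts("hello", "helo")
print(dict(errors))  # {'l': 1}
```

A position-by-position comparison of the same pair would report two errors (the second `l` and the final `o`), even though only one character is missing.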

In [None]:
# Analyze model performance by text length
val_sample_metrics = best_metrics['sample_metrics']

# Group by text length
length_analysis = {}
for sample in val_sample_metrics:
    length = len(sample['target'])
    if length not in length_analysis:
        length_analysis[length] = {'correct': 0, 'total': 0, 'cer_sum': 0}
    
    length_analysis[length]['total'] += 1
    length_analysis[length]['cer_sum'] += sample['cer']
    if sample['match']:
        length_analysis[length]['correct'] += 1

# Calculate accuracy by length
lengths = sorted(length_analysis.keys())
accuracies_by_length = []
cer_by_length = []

for length in lengths:
    data = length_analysis[length]
    accuracy = data['correct'] / data['total']
    avg_cer = data['cer_sum'] / data['total']
    accuracies_by_length.append(accuracy)
    cer_by_length.append(avg_cer)

# Plot performance by text length
fig, axes = plt.subplots(1, 2, figsize=(15, 5))

axes[0].plot(lengths, accuracies_by_length, 'bo-', linewidth=2, markersize=8)
axes[0].set_title('Accuracy by Text Length')
axes[0].set_xlabel('Text Length (characters)')
axes[0].set_ylabel('Accuracy')
axes[0].grid(True, alpha=0.3)

axes[1].plot(lengths, cer_by_length, 'ro-', linewidth=2, markersize=8)
axes[1].set_title('Character Error Rate by Text Length')
axes[1].set_xlabel('Text Length (characters)')
axes[1].set_ylabel('CER')
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(output_dirs['results'] / 'performance_by_length.png', dpi=300, bbox_inches='tight')
plt.show()

# Character-level error analysis
char_errors = {}
char_total = {}

for sample in val_sample_metrics:
    target = sample['target'].lower()
    prediction = sample['prediction'].lower()
    
    # Simple position-by-position comparison (no alignment, so one inserted or
    # dropped character marks every later position as an error)
    max_len = max(len(target), len(prediction))
    
    for i in range(max_len):
        if i < len(target):
            char = target[i]
            char_total[char] = char_total.get(char, 0) + 1
            
            if i >= len(prediction) or target[i] != prediction[i]:
                char_errors[char] = char_errors.get(char, 0) + 1

# Calculate character-level accuracy
char_accuracy = {}
for char in char_total:
    if char_total[char] > 5:  # Only consider characters that appear more than 5 times
        errors = char_errors.get(char, 0)
        char_accuracy[char] = 1 - (errors / char_total[char])

# Sort by accuracy
sorted_chars = sorted(char_accuracy.items(), key=lambda x: x[1])

print("\n📊 Character-level Analysis:")
print("=" * 40)
print("\n❌ Most Difficult Characters:")
for char, acc in sorted_chars[:5]:
    print(f"  '{char}': {acc:.3f} accuracy ({char_errors.get(char, 0)}/{char_total[char]} errors)")

print("\n✅ Most Accurate Characters:")
for char, acc in sorted_chars[-5:]:
    print(f"  '{char}': {acc:.3f} accuracy ({char_errors.get(char, 0)}/{char_total[char]} errors)")

# Model insights and recommendations
print("\n🔍 Model Insights and Recommendations:")
print("=" * 50)

avg_accuracy = best_metrics['accuracy']
avg_cer = best_metrics['cer']

if avg_accuracy > 0.9:
    print("✅ Excellent performance! Your custom model is working very well.")
elif avg_accuracy > 0.7:
    print("👍 Good performance! There's room for improvement.")
else:
    print("⚠️ Performance needs improvement. Consider the recommendations below.")

print(f"\n📈 Improvement Suggestions:")
if avg_cer > 0.1:
    print("• Increase training data size")
    print("• Add more data augmentation")
    print("• Experiment with different model architectures")

if len(lengths) > 1 and max(accuracies_by_length) - min(accuracies_by_length) > 0.3:
    print("• Model performance varies significantly with text length")
    print("• Consider training on more balanced text length distribution")

print("\n🚀 Next Steps:")
print("• Try transfer learning from pretrained models")
print("• Experiment with attention mechanisms")
print("• Use larger and more diverse datasets")
print("• Implement beam search decoding")
print("• Fine-tune on domain-specific data")

## 13. Save Final Model and Summary

Let's save our final model and create a comprehensive summary.
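
Checkpoints that contain only tensors and plain Python types are easier to reload, because they work with `torch.load(..., weights_only=True)`. The sketch below shows one way to do that, storing paths as strings and the alphabet as a plain string instead of pickling custom objects. The `nn.Linear` stand-in, the `portable` dictionary layout, and the `characters` key are assumptions for illustration, not this project's checkpoint format.

```python
import tempfile
from pathlib import Path

import torch
import torch.nn as nn

net = nn.Linear(4, 2)  # stand-in for the trained OCR model
cfg = {"img_height": 32, "img_width": 128, "train_csv": Path("data/train/dataset.csv")}

portable = {
    "model_state_dict": net.state_dict(),
    # Convert Path objects to strings so the file holds only plain types.
    "config": {k: str(v) if isinstance(v, Path) else v for k, v in cfg.items()},
    # Store the alphabet itself rather than a pickled mapping object.
    "characters": "abcdefghijklmnopqrstuvwxyz0123456789 ",
}

with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "portable_model.pth"
    torch.save(portable, path)
    # weights_only=True refuses arbitrary pickled objects, which is the safer default.
    restored = torch.load(path, map_location="cpu", weights_only=True)

net.load_state_dict(restored["model_state_dict"])
print(restored["config"]["train_csv"])
```

The character mapping can then be rebuilt from the stored alphabet at load time, which avoids needing the original class definition on the import path just to open the file.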

In [None]:
# Save final model (weights, config, and training history) in a single file
final_model_path = output_dirs['models'] / 'final_ocr_model.pth'
torch.save({
    'model_state_dict': model.state_dict(),
    'config': config,
    'char_mapping': char_mapping,
    'training_history': {
        'train_losses': train_losses,
        'val_losses': val_losses,
        'val_accuracies': val_accuracies
    },
    'best_metrics': best_metrics,
    'total_training_time': total_time
}, final_model_path)

# Create training summary report
summary_report = f"""# Custom OCR Model Training Summary

## Model Configuration
- Architecture: CNN + LSTM + CTC
- Image Size: {config['img_height']}x{config['img_width']}
- LSTM Hidden Size: {config['lstm_hidden_size']}
- LSTM Layers: {config['lstm_num_layers']}
- Character Classes: {char_mapping.num_classes}
- Total Parameters: {count_parameters(model):,}

## Training Configuration
- Epochs: {config['num_epochs']}
- Batch Size: {config['batch_size']}
- Learning Rate: {config['learning_rate']}
- Optimizer: {config['optimizer']}
- Training Samples: {len(train_df)}
- Validation Samples: {len(val_df)}

## Performance Results
- Best Accuracy: {best_metrics['accuracy']:.4f}
- Character Error Rate (CER): {best_metrics['cer']:.4f}
- Word Error Rate (WER): {best_metrics['wer']:.4f}
- BLEU Score: {best_metrics['bleu_score']:.4f}
- Best Epoch: {best_metrics['epoch']}

## Training Statistics
- Total Training Time: {total_time/60:.2f} minutes
- Final Training Loss: {train_losses[-1]:.4f}
- Final Validation Loss: {val_losses[-1]:.4f}
- Device Used: {device}

## Model Files
- Best Model: models/checkpoints/best_model.pth
- Final Model: models/final_ocr_model.pth
- Training Logs: {log_dir}
- Results: results/

## Usage Instructions
To use the trained model for inference:

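```python
# Load model (weights_only=False because the checkpoint stores Python objects)
checkpoint = torch.load('models/checkpoints/best_model.pth',
                        map_location='cpu', weights_only=False)
cfg = checkpoint['config']
model = create_model(
    num_classes=checkpoint['char_mapping'].num_classes,
    img_height=cfg['img_height'],
    img_width=cfg['img_width'],
    lstm_hidden_size=cfg['lstm_hidden_size'],
    lstm_num_layers=cfg['lstm_num_layers']
)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

# Or use the prediction script from a shell:
# python scripts/predict.py --model models/checkpoints/best_model.pth --image path/to/image.jpg
```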

Generated on: {pd.Timestamp.now().strftime('%Y-%m-%d %H:%M:%S')}
"""

# Save summary report
with open(output_dirs['results'] / 'training_summary.md', 'w') as f:
    f.write(summary_report)

print("💾 Model and summary saved successfully!")
print(f"📄 Training summary: {output_dirs['results'] / 'training_summary.md'}")
print(f"🧠 Final model: {final_model_path}")
print(f"🏆 Best model: {output_dirs['checkpoints'] / 'best_model.pth'}")
print(f"📊 TensorBoard logs: {log_dir}")

# Display final summary
print("\n" + "="*60)
print("🎉 TRAINING COMPLETED SUCCESSFULLY! 🎉")
print("="*60)
print(f"📊 Final Results:")
print(f"   Best Accuracy: {best_metrics['accuracy']:.4f} ({best_metrics['accuracy']*100:.2f}%)")
print(f"   Character Error Rate: {best_metrics['cer']:.4f}")
print(f"   Word Error Rate: {best_metrics['wer']:.4f}")
print(f"   Training Time: {total_time/60:.2f} minutes")
print(f"\n🚀 Your custom OCR model is ready for use!")
print(f"📁 Check the results folder for detailed analysis and comparisons.")