# 🧾 Vietnamese Receipt OCR Training Pipeline
## MC-OCR 2021 Dataset | CRNN Model | Mobile Deployment Ready

---

### 📋 Notebook Overview

**Goal:** Train an OCR model that recognizes Vietnamese text on receipts

**Dataset:** [Vietnamese Receipts MC-OCR 2021](https://www.kaggle.com/datasets/domixi1989/vietnamese-receipts-mc-ocr-2021)

**Architecture:** CRNN (CNN + LSTM + CTC Loss)

**Export:** ONNX and TFLite for the Flutter mobile app

---

### 🎯 Steps:

1. ✅ **Install libraries** (PyTorch, Kaggle API, OpenCV)
2. ✅ **Download the dataset** from Kaggle
3. ✅ **Analyze the dataset structure** (images + annotations)
4. ✅ **Preprocess the data** (resize, normalize, augmentation)
5. ✅ **Build the CRNN model**
6. ✅ **Train with CTC Loss**
7. ✅ **Validation & Evaluation**
8. ✅ **Export to ONNX/TFLite**
9. ✅ **Inference demo**

---

**Author:** KHANH - FinTracker OCR Integration  
**Date:** November 17, 2025  
**Platform:** Google Colab with GPU (T4/V100)

## 📦 PART 1: Install Libraries

Install the libraries needed to train the OCR model on Google Colab.

In [None]:
# Cell 1: Check the GPU and install the core libraries
import sys
import subprocess

# Check the GPU
print("🔍 Checking GPU...")
!nvidia-smi

# Install the required libraries
print("\n📦 Installing libraries...")
!pip install -q torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
!pip install -q opencv-python-headless pillow matplotlib tqdm
!pip install -q kaggle
!pip install -q python-Levenshtein editdistance
# Note: tf2onnx converts TensorFlow -> ONNX; going from ONNX to TFLite usually needs a separate converter (e.g. onnx2tf)
!pip install -q onnx onnxruntime tf2onnx tensorflow

print("✅ Libraries installed!")

In [None]:
# Cell 2: Import the installed libraries
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torch.nn.functional as F

import cv2
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import json
import os
from pathlib import Path
import random
from tqdm import tqdm
import string
import re
from collections import Counter
import warnings
warnings.filterwarnings('ignore')

# Check PyTorch and CUDA
print(f"🔥 PyTorch version: {torch.__version__}")
print(f"🎮 CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"🎮 CUDA device: {torch.cuda.get_device_name(0)}")
    print(f"🎮 Current device: {torch.cuda.current_device()}")
    device = torch.device('cuda')
else:
    print("⚠️ Running on CPU (this will be slower)")
    device = torch.device('cpu')

print(f"\n✅ Device: {device}")

## 🔑 PART 2: Connect the Kaggle API & Download the Dataset

### How to upload `kaggle.json`:

1. Log in to Kaggle → **Account** → **Create New API Token**
2. Download the `kaggle.json` file
3. Upload it to Colab with the code cell below
4. The dataset is then downloaded and extracted automatically
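
If you would rather not use the upload dialog, you can also write the credentials file from Python. This is a minimal sketch. `write_kaggle_credentials` is a hypothetical helper, and it assumes only the standard `kaggle.json` layout (`{"username": ..., "key": ...}`) and the `~/.kaggle` location used in Cell 3. The Kaggle CLI can also read the `KAGGLE_USERNAME` and `KAGGLE_KEY` environment variables in place of the file.

```python
import json
import os
import tempfile
from pathlib import Path


def write_kaggle_credentials(username: str, key: str, config_dir=None) -> Path:
    """Hypothetical helper: write kaggle.json and restrict it to owner read/write."""
    config_dir = Path(config_dir) if config_dir else Path.home() / ".kaggle"
    config_dir.mkdir(parents=True, exist_ok=True)
    cred_path = config_dir / "kaggle.json"
    cred_path.write_text(json.dumps({"username": username, "key": key}))
    os.chmod(cred_path, 0o600)  # same permissions as `chmod 600` in Cell 3
    return cred_path


# Demo with placeholder values in a throwaway directory; in Colab, omit config_dir
demo_path = write_kaggle_credentials("your_username", "your_api_key", tempfile.mkdtemp())
print(demo_path, oct(demo_path.stat().st_mode & 0o777))
```

In Colab you would call `write_kaggle_credentials(username, key)` with your real values and then run Cell 4 as usual.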

In [None]:
# Cell 3: Upload kaggle.json from your computer
from google.colab import files

print("📁 Please upload your kaggle.json file...")
print("   (Download it from: Kaggle → Account → Create New API Token)")
uploaded = files.upload()

# Create the .kaggle directory and copy the file into it
!mkdir -p ~/.kaggle
!cp kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json

print("✅ Kaggle API configured!")

In [None]:
# Cell 4: Download the dataset from Kaggle
print("📥 Downloading Vietnamese Receipts MC-OCR 2021 dataset...")
print("   (Dataset size: ~1.5GB, this may take a few minutes)\n")

# Create the data directory
!mkdir -p /content/data

# Download the dataset
!kaggle datasets download -d domixi1989/vietnamese-receipts-mc-ocr-2021 -p /content/data

print("\n📦 Extracting dataset...")
!unzip -q /content/data/vietnamese-receipts-mc-ocr-2021.zip -d /content/data

# Delete the zip file to save disk space
!rm /content/data/vietnamese-receipts-mc-ocr-2021.zip

print("✅ Dataset downloaded and extracted successfully!")

## 🔍 PART 3: Analyze the Dataset Structure

Before training, we need a clear picture of:
- The directory structure
- The annotation format
- The number of train/test images
- The OCR task type (detection, recognition, or both)

In [None]:
# Cell 5: Explore the dataset directory structure
import os
from pathlib import Path

data_dir = Path('/content/data')

print("📂 MC-OCR 2021 DATASET STRUCTURE")
print("="*60)

# List all directories and files
# (`filenames` rather than `files`, so the google.colab `files` module from Cell 3 is not shadowed)
for root, dirs, filenames in os.walk(data_dir):
    level = root.replace(str(data_dir), '').count(os.sep)
    indent = ' ' * 2 * level
    print(f'{indent}{os.path.basename(root)}/')
    subindent = ' ' * 2 * (level + 1)
    for file in sorted(filenames)[:10]:  # Show only the first 10 files
        print(f'{subindent}{file}')
    if len(filenames) > 10:
        print(f'{subindent}... and {len(filenames) - 10} more files')

print("\n" + "="*60)

In [None]:
# Cell 6: Detailed structure analysis
# MC-OCR 2021 is expected to have train/test folders containing images and annotations

# Find the top-level directories
main_folders = [f for f in data_dir.iterdir() if f.is_dir()]
print("📁 Top-level directories:")
for folder in main_folders:
    print(f"   - {folder.name}")

# Auto-detect train/ and test/ (or similarly named) directories; glob also matches files, so keep directories only
possible_train_dirs = [p for p in list(data_dir.glob('**/train*')) + list(data_dir.glob('**/Train*')) if p.is_dir()]
possible_test_dirs = [p for p in list(data_dir.glob('**/test*')) + list(data_dir.glob('**/Test*')) if p.is_dir()]

print(f"\n🔍 Found {len(possible_train_dirs)} train directories")
print(f"🔍 Found {len(possible_test_dirs)} test directories")

# Count image and annotation files
image_extensions = ['.jpg', '.jpeg', '.png', '.JPG', '.JPEG', '.PNG']
annotation_extensions = ['.txt', '.json', '.xml', '.csv']

total_images = sum(1 for f in data_dir.rglob('*') if f.suffix in image_extensions)
total_annotations = sum(1 for f in data_dir.rglob('*') if f.suffix in annotation_extensions)

print(f"\n📊 STATISTICS:")
print(f"   Total images: {total_images}")
print(f"   Total annotation files: {total_annotations}")

# Find sample annotation files
sample_annotations = (list(data_dir.rglob('*.json'))[:5] + list(data_dir.rglob('*.txt'))[:5]
                      + list(data_dir.rglob('*.csv'))[:5])
if sample_annotations:
    print(f"\n📄 Sample annotation files:")
    for ann in sample_annotations[:3]:
        print(f"   {ann.name}")
else:
    print("\n⚠️ No annotation files found!")

In [None]:
# Cell 7: Read and inspect the annotation format
# MC-OCR 2021 may use JSON with bounding boxes and text (the Kaggle download actually uses CSV; see Cell 16)

# Find the first JSON annotation file
json_files = list(data_dir.rglob('*.json'))

if json_files:
    sample_json = json_files[0]
    print(f"📄 Reading sample annotation file: {sample_json.name}\n")
    
    with open(sample_json, 'r', encoding='utf-8') as f:
        annotation_data = json.load(f)
    
    print("📋 ANNOTATION STRUCTURE:")
    print(json.dumps(annotation_data, indent=2, ensure_ascii=False)[:1500])  # Show only the first 1500 characters
    print("\n...")
    
    # Inspect keys (only a dict has .keys(); a top-level JSON list would crash here)
    if isinstance(annotation_data, dict):
        print(f"\n🔑 Top-level keys: {list(annotation_data.keys())}")
        # Look for a nested list of items
        for key, value in annotation_data.items():
            if isinstance(value, list) and len(value) > 0:
                print(f"\n📌 Key '{key}' contains {len(value)} items")
                print(f"   First item: {value[0]}")
                break
    elif isinstance(annotation_data, list):
        print(f"\n📌 Top-level list with {len(annotation_data)} items")
else:
    print("⚠️ No JSON files found. Checking TXT files...")
    txt_files = list(data_dir.rglob('*.txt'))
    if txt_files:
        sample_txt = txt_files[0]
        print(f"📄 Reading sample TXT file: {sample_txt.name}\n")
        with open(sample_txt, 'r', encoding='utf-8') as f:
            content = f.read()
        print(content[:500])

In [None]:
# Cell 8: Visualize sample images with annotations
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from PIL import Image

# Find sample images
image_files = list(data_dir.rglob('*.jpg')) + list(data_dir.rglob('*.png'))

if image_files and json_files:
    # Take the first 3 images
    fig, axes = plt.subplots(1, 3, figsize=(18, 6))
    
    for idx, img_path in enumerate(image_files[:3]):
        # Load the image
        img = Image.open(img_path).convert('RGB')
        
        # Find the matching annotation
        # The annotation file usually shares the image's name (only the extension differs)
        ann_path = img_path.with_suffix('.json')
        
        axes[idx].imshow(img)
        title = f'{img_path.name}\nSize: {img.size}'
        axes[idx].axis('off')
        
        # If an annotation exists, append a text preview to the title
        # (axis('off') hides x-labels, so set_xlabel would not show; no boxes are drawn here)
        if ann_path.exists():
            with open(ann_path, 'r', encoding='utf-8') as f:
                ann = json.load(f)
                
            # Extract the text (depends on the dataset format)
            # MC-OCR typically uses keys such as 'anno_texts' or 'words'
            if 'anno_texts' in ann:
                text_preview = str(ann['anno_texts'])[:100]
                title += f'\nText: {text_preview}...'
        
        axes[idx].set_title(title, fontsize=8)
    
    plt.tight_layout()
    plt.show()
    
    print("✅ Displayed 3 sample images from the dataset")
else:
    print("⚠️ Not enough images or annotations to visualize")

## 📊 DATASET ANALYSIS: CONCLUSIONS

Based on the analysis, **MC-OCR 2021** is organized as follows:

### 1️⃣ **Directories:**
- `train/` - Training data
- `test/` - Test data (public or private)
- In the Kaggle download, these are `train_images/train_images/` and `val_images/val_images/` (nested folders; see Cell 16)

### 2️⃣ **Annotation Format:**
- **File type:** JSON was expected, but the Kaggle download ships CSV files (`mcocr_train_df.csv`, `mcocr_val_sample_df.csv`; see Cell 16)
- **Structure:** One annotation record per image (the `img_id` column holds the image file name)
- **Keys:** 
  - `anno_texts` - Full text of the receipt
  - `anno_boxes` - Bounding boxes for each word (if present)
  - Possibly also: `words`, `boxes`, `labels`

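### 3️⃣ **OCR Task Type:**
- **Text Recognition** (recognizing text in pre-cropped regions)
- Or **Full Scene Text Recognition** (recognizing all the text on the receipt)

### 4️⃣ **Approach:**
We will train a **CRNN** model with:
- Input: a receipt image (or a single text line)
- Output: a Vietnamese text string
- Caveat: CTC needs at least as many output time steps as label characters. A whole receipt resized to a 64 px height becomes very narrow and leaves only a few time steps for hundreds of characters, so line-level crops are the better fit for a CRNN.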
---

**Next step:** Preprocess the data and build the DataLoader

## 🔧 PART 4: Data Preprocessing

Prepare the data for CRNN training:
- Resize images to a fixed height (32 px or 64 px)
- Normalize pixel values
- Augmentation: rotation, brightness, noise
- Character encoding (Vietnamese + digits + punctuation)

In [None]:
# Cell 9: Build the character set (vocabulary) for Vietnamese
import string

# Full Vietnamese character set
vietnamese_chars = (
    "aàảãáạăằẳẵắặâầẩẫấậ"
    "eèẻẽéẹêềểễếệ"
    "iìỉĩíị"
    "oòỏõóọôồổỗốộơờởỡớợ"
    "uùủũúụưừửữứự"
    "yỳỷỹýỵ"
    "AÀẢÃÁẠĂẰẲẴẮẶÂẦẨẪẤẬ"
    "EÈẺẼÉẸÊỀỂỄẾỆ"
    "IÌỈĨÍỊ"
    "OÒỎÕÓỌÔỒỔỖỐỘƠỜỞỠỚỢ"
    "UÙỦŨÚỤƯỪỬỮỨỰ"
    "YỲỶỸÝỴ"
    "dđDĐ"
)

# Digits and special characters commonly found on receipts
numbers = "0123456789"
punctuation = ".,;:!?-/()[]{}\"'@#$%&*+=<>| "

# Build the character vocabulary
charset = sorted(set(vietnamese_chars + string.ascii_letters + numbers + punctuation))
charset = ['<BLANK>'] + charset  # CTC blank token at index 0

char_to_idx = {char: idx for idx, char in enumerate(charset)}
idx_to_char = {idx: char for char, idx in char_to_idx.items()}

print(f"📚 Vocabulary size: {len(charset)}")
print(f"🔤 Vietnamese characters: {len(vietnamese_chars)}")
print(f"\n🔠 Character set preview:")
print(''.join(charset[:100]) + '...')
print(f"\n✅ Built a vocabulary of {len(charset)} characters")
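
Vietnamese letters can be stored precomposed (NFC, e.g. `ó` as a single code point) or decomposed (NFD, `o` followed by a combining accent). That matters here because the `encode_text` methods in this notebook silently skip characters missing from `char_to_idx`. With the full `charset`, which contains a plain `o`, an NFD label like `hóa` would quietly become `hoa`. This notebook does not check whether any MC-OCR labels are decomposed. Normalizing with Python's standard `unicodedata.normalize("NFC", ...)` before lookup is a cheap safeguard. The sketch uses a tiny demo vocabulary instead of the full `charset`, and `encode_naive`/`encode_nfc` are hypothetical names.

```python
import unicodedata


def encode_naive(text, char_to_idx):
    """Mirrors the notebook's encode_text: characters missing from the vocabulary are dropped."""
    return [char_to_idx[ch] for ch in text if ch in char_to_idx]


def encode_nfc(text, char_to_idx):
    """Same lookup, but after composing base letters with their combining marks (NFC)."""
    return encode_naive(unicodedata.normalize("NFC", text), char_to_idx)


# "hóa đơn" written with escapes so the literal is guaranteed to be precomposed (NFC)
WORD = "h\u00f3a \u0111\u01a1n"

# Tiny demo vocabulary; index 0 is reserved for the CTC blank, as in Cell 9
demo_charset = ["<BLANK>"] + sorted(set(WORD))
demo_char_to_idx = {ch: i for i, ch in enumerate(demo_charset)}

decomposed = unicodedata.normalize("NFD", WORD)  # 'ó' -> 'o' + U+0301, 'ơ' -> 'o' + U+031B
print(len(decomposed),
      len(encode_naive(decomposed, demo_char_to_idx)),
      len(encode_nfc(decomposed, demo_char_to_idx)))
```

The naive lookup keeps only 5 of the 7 letters and spaces here. The NFC version recovers all 7.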

In [None]:
# Cell 10: Preprocessing functions
import cv2
import numpy as np
from PIL import Image
import torchvision.transforms as transforms

class OCRPreprocessor:
    """Image preprocessing for CRNN OCR"""
    
    def __init__(self, img_height=64, img_width=None):
        self.img_height = img_height
        self.img_width = img_width  # None = keep the aspect ratio (variable width)
        
    def resize_with_aspect_ratio(self, img):
        """Resize to a fixed height keeping the aspect ratio; pad or crop to img_width if it is set"""
        h, w = img.shape[:2]
        
        # Scale by height
        scale = self.img_height / h
        new_w = max(1, int(w * scale))  # keep at least 1 px for extremely narrow images
        new_h = self.img_height
        
        # Resize
        img_resized = cv2.resize(img, (new_w, new_h))
        
        # If a fixed width is set, pad or crop
        if self.img_width:
            if new_w < self.img_width:
                # Pad on the right with white (255)
                pad_width = self.img_width - new_w
                img_resized = np.pad(img_resized, ((0, 0), (0, pad_width), (0, 0)), 
                                    mode='constant', constant_values=255)
            elif new_w > self.img_width:
                # Crop
                img_resized = img_resized[:, :self.img_width, :]
        
        return img_resized
    
    def normalize(self, img):
        """Normalize to [0, 1]"""
        return img.astype(np.float32) / 255.0
    
    def augment(self, img):
        """Data augmentation for training"""
        # Random rotation ±3 degrees
        if random.random() < 0.3:
            angle = random.uniform(-3, 3)
            h, w = img.shape[:2]
            matrix = cv2.getRotationMatrix2D((w/2, h/2), angle, 1)
            img = cv2.warpAffine(img, matrix, (w, h), 
                                borderMode=cv2.BORDER_CONSTANT, 
                                borderValue=(255, 255, 255))
        
        # Random brightness
        if random.random() < 0.3:
            factor = random.uniform(0.8, 1.2)
            img = np.clip(img * factor, 0, 255).astype(np.uint8)
        
        # Random noise
        if random.random() < 0.2:
            noise = np.random.normal(0, 5, img.shape)
            img = np.clip(img + noise, 0, 255).astype(np.uint8)
        
        return img
    
    def __call__(self, img, augment=False):
        """Full pipeline"""
        # Convert to numpy if PIL Image
        if isinstance(img, Image.Image):
            img = np.array(img)
        
        # Convert grayscale to RGB if needed
        if len(img.shape) == 2:
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
        
        # Augmentation
        if augment:
            img = self.augment(img)
        
        # Resize
        img = self.resize_with_aspect_ratio(img)
        
        # Normalize
        img = self.normalize(img)
        
        # Convert to torch tensor (C, H, W)
        img = torch.from_numpy(img).permute(2, 0, 1).float()
        
        return img

# Test the preprocessor
preprocessor = OCRPreprocessor(img_height=64, img_width=None)
print("✅ OCRPreprocessor is ready!")
print(f"   - Image height: {preprocessor.img_height}px")
print(f"   - Keep aspect ratio: {preprocessor.img_width is None}")

In [None]:
# Cell 11: Custom Dataset Class
from torch.utils.data import Dataset
import glob

class VietnameseReceiptDataset(Dataset):
    """Dataset for Vietnamese receipts with per-image JSON annotations"""
    
    def __init__(self, data_dir, preprocessor, char_to_idx, augment=False, max_samples=None):
        self.data_dir = Path(data_dir)
        self.preprocessor = preprocessor
        self.char_to_idx = char_to_idx
        self.augment = augment
        
        # Find all images
        self.image_paths = []
        for ext in ['jpg', 'jpeg', 'png', 'JPG', 'JPEG', 'PNG']:
            self.image_paths.extend(glob.glob(str(self.data_dir / f'**/*.{ext}'), recursive=True))
        
        # Keep only images that have a matching annotation file
        self.samples = []
        for img_path in self.image_paths[:max_samples] if max_samples else self.image_paths:
            json_path = Path(img_path).with_suffix('.json')
            if json_path.exists():
                self.samples.append((img_path, json_path))
        
        print(f"✅ Loaded {len(self.samples)} samples from {data_dir}")
    
    def __len__(self):
        return len(self.samples)
    
    def encode_text(self, text):
        """Encode text into a list of character indices"""
        encoded = []
        for char in text:
            if char in self.char_to_idx:
                encoded.append(self.char_to_idx[char])
            # Skip unknown characters
        return encoded
    
    def __getitem__(self, idx):
        img_path, json_path = self.samples[idx]
        
        # Load image
        img = Image.open(img_path).convert('RGB')
        img = self.preprocessor(img, augment=self.augment)
        
        # Load annotation
        with open(json_path, 'r', encoding='utf-8') as f:
            ann = json.load(f)
        
        # Extract the text (depends on the dataset format)
        # MC-OCR usually has 'anno_texts' or 'img_info' -> 'text'
        text = ''
        if 'anno_texts' in ann:
            text = ann['anno_texts']
        elif 'img_info' in ann and 'text' in ann['img_info']:
            text = ann['img_info']['text']
        elif 'text' in ann:
            text = ann['text']
        else:
            # Fallback: use the first non-empty string value
            for key in ann.keys():
                if isinstance(ann[key], str) and len(ann[key]) > 0:
                    text = ann[key]
                    break
        
        # Encode text
        text_encoded = self.encode_text(text)
        text_length = len(text_encoded)
        
        return img, torch.LongTensor(text_encoded), text_length

# Test dataset
print("🔄 Creating dataset...")
# data_dir will be set once the dataset structure has been explored

In [None]:
# Cell 16: Create DataLoaders using the actual MC-OCR 2021 dataset layout
import pandas as pd
from torch.nn.utils.rnn import pad_sequence

# Actual MC-OCR 2021 dataset layout:
# data/
#   ├── mcocr_train_df.csv
#   ├── mcocr_val_sample_df.csv
#   ├── train_images/train_images/  ← Nested folder!
#   └── val_images/val_images/      ← Nested folder!

# Read the CSV annotations
train_csv_path = '/content/data/mcocr_train_df.csv'
val_csv_path = '/content/data/mcocr_val_sample_df.csv'

print("📊 Reading annotation files...")
print(f"Train CSV: {train_csv_path}")
print(f"Val CSV: {val_csv_path}")

train_df = pd.read_csv(train_csv_path)
val_df = pd.read_csv(val_csv_path)

print(f"✅ Train samples: {len(train_df)}")
print(f"✅ Val samples: {len(val_df)}")
print(f"\n📋 Train DataFrame columns: {train_df.columns.tolist()}")
print(f"\n🔍 Sample train annotation:")
print(train_df.head(2))

# Collate function (required by the DataLoader); pads images to the widest one in the batch
def collate_fn(batch):
    """Collate a batch whose images (width) and texts (length) vary"""
    images, texts, text_lengths = zip(*batch)
    
    # Find the maximum width in the batch
    max_width = max(img.shape[2] for img in images)
    
    # Pad all images to the same width
    padded_images = []
    for img in images:
        # img shape: (C, H, W)
        c, h, w = img.shape
        if w < max_width:
            # Pad on the right with 1.0 (white after normalizing to [0, 1]),
            # matching the white padding used by OCRPreprocessor
            pad_width = max_width - w
            # F.pad pads from the last dimension backwards: (left, right, top, bottom)
            padded_img = torch.nn.functional.pad(img, (0, pad_width, 0, 0), value=1.0)
            padded_images.append(padded_img)
        else:
            padded_images.append(img)
    
    # Stack images
    images = torch.stack(padded_images, dim=0)
    
    # Pad texts
    texts_padded = pad_sequence(texts, batch_first=True, padding_value=0)
    
    # Convert to tensors
    text_lengths = torch.LongTensor(text_lengths)
    
    return images, texts_padded, text_lengths

# Custom Dataset for the MC-OCR CSV format
class MCOCRDataset(Dataset):
    """Dataset cho MC-OCR 2021 v·ªõi CSV annotations"""
    
    def __init__(self, df, img_dir, preprocessor, char_to_idx, augment=False):
        self.df = df.reset_index(drop=True)
        self.img_dir = Path(img_dir)
        self.preprocessor = preprocessor
        self.char_to_idx = char_to_idx
        self.augment = augment
        
        # Keep only rows whose image exists and whose label encodes to at least one character
        valid_samples = []
        print(f"\n🔍 Verifying image files in: {self.img_dir}")
        
        for idx in range(len(self.df)):
            row = self.df.iloc[idx]
            # The CSV may name the image column 'img_id', 'file_name', 'image_name', or 'img'
            img_name = row.get('img_id', row.get('file_name', row.get('image_name', row.get('img', None))))
            if not img_name:
                continue
            
            img_path = self.img_dir / img_name
            if not img_path.exists():
                if idx < 5:  # Debug the first 5 missing files
                    print(f"   ⚠️ Not found: {img_path}")
                continue
            
            # Skip empty labels instead of substituting a dummy: index 0 is the CTC blank,
            # which a CTC target may not contain
            if len(self.encode_text(self._get_text(row))) == 0:
                continue
            valid_samples.append(idx)
        
        self.valid_indices = valid_samples
        print(f"✅ Valid samples: {len(self.valid_indices)} / {len(self.df)}")
        
        if len(self.valid_indices) == 0:
            print("❌ ERROR: No valid samples found!")
            print(f"   Check if images exist in: {self.img_dir}")
            print(f"   Available columns: {self.df.columns.tolist()}")
    
    def __len__(self):
        return len(self.valid_indices)
    
    @staticmethod
    def _get_text(row):
        """Text annotation column: 'anno_texts', 'text', 'label', or 'annotation'"""
        return row.get('anno_texts', row.get('text', row.get('label', row.get('annotation', ''))))
    
    def encode_text(self, text):
        """Encode text into a list of character indices (unknown characters are skipped)"""
        if pd.isna(text):
            return []
        encoded = []
        for char in str(text):
            if char in self.char_to_idx:
                encoded.append(self.char_to_idx[char])
        return encoded
    
    def __getitem__(self, idx):
        real_idx = self.valid_indices[idx]
        row = self.df.iloc[real_idx]
        
        # Get image name
        img_name = row.get('img_id', row.get('file_name', row.get('image_name', row.get('img', ''))))
        img_path = self.img_dir / img_name
        
        # Load image
        try:
            img = Image.open(img_path).convert('RGB')
            img = self.preprocessor(img, augment=self.augment)
        except Exception as e:
            print(f"Error loading {img_path}: {e}")
            # Fall back to a placeholder image
            img = torch.zeros(3, 64, 256)
        
        # Encode the text annotation (rows with empty labels were filtered out in __init__)
        text_encoded = self.encode_text(self._get_text(row))
        text_length = len(text_encoded)
        
        return img, torch.LongTensor(text_encoded), text_length

# Create the datasets with the NESTED paths (train_images/train_images/)
print("\n🔄 Creating datasets...")
train_dataset = MCOCRDataset(
    df=train_df,
    img_dir='/content/data/train_images/train_images',  # Nested folder!
    preprocessor=preprocessor,
    char_to_idx=char_to_idx,
    augment=True
)

val_dataset = MCOCRDataset(
    df=val_df,
    img_dir='/content/data/val_images/val_images',  # Nested folder!
    preprocessor=preprocessor,
    char_to_idx=char_to_idx,
    augment=False
)

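# Create the dataloaders
# BATCH_SIZE is set again in Cell 14; defining it here lets this cell run in notebook order
BATCH_SIZE = 32
print("\n🔄 Creating dataloaders...")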
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2,
    collate_fn=collate_fn,
    pin_memory=True
)

val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=2,
    collate_fn=collate_fn,
    pin_memory=True
)

print(f"\n✅ DataLoaders created!")
print(f"   Train batches: {len(train_loader)}")
print(f"   Val batches: {len(val_loader)}")

# Test one batch
if len(train_loader) > 0:
    print("\n📦 Testing sample batch...")
    sample_batch = next(iter(train_loader))
    images, texts, text_lengths = sample_batch
    print(f"   Images shape: {images.shape}")
    print(f"   Texts shape: {texts.shape}")
    print(f"   Text lengths: {text_lengths[:5].tolist()}")
    
    # Decode sample text
    sample_text_indices = texts[0][:min(10, len(texts[0]))].tolist()
    sample_text_chars = [idx_to_char.get(idx, '?') for idx in sample_text_indices]
    print(f"   Sample text (first 10 chars): {''.join(sample_text_chars)}")
    
    print("\n✅ Dataset preparation complete!")
else:
    print("\n❌ ERROR: No data loaded! Check paths and CSV structure.")

## 🏗️ PART 5: Build the CRNN Model

**CRNN Architecture:**
1. **CNN Backbone:** Extract visual features (ResNet or VGG; the model below is VGG-style)
2. **RNN Layers:** Sequence modeling (LSTM/GRU)
3. **CTC Decoder:** Sequence-to-text decoding

This model suits OCR because:
- It needs no explicit alignment between image positions and characters
- It handles text of variable length
- CTC Loss learns the alignment implicitly
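
"No alignment needed" still has a practical limit. CTC can only produce a label if the model emits enough time steps: at least one per character, plus one blank between identical neighbours. The sketch below derives the time-step count of the Cell 12 CRNN from its layers. The two 2×2 max-pools halve the width, the `(2, 1)` pools leave it unchanged, and the final 2×2 convolution without padding removes one column. It also reproduces the aspect-ratio resize of `OCRPreprocessor`. The helper names are hypothetical, and the arithmetic assumes exactly the layer configuration in Cell 12.

```python
def crnn_output_steps(width: int) -> int:
    """Time steps T produced by the Cell 12 CRNN for an input of the given width (pixels)."""
    w = width // 2   # Conv Block 1: MaxPool2d(2, 2)
    w = w // 2       # Conv Block 2: MaxPool2d(2, 2)
    # Conv Blocks 3-4: MaxPool2d((2, 1)) pools height only; 3x3 convs with padding=1 keep the width
    w = w - 1        # Conv Block 5: kernel_size=2, padding=0
    return max(w, 0)


def resized_width(orig_w: int, orig_h: int, img_height: int = 64) -> int:
    """Width after OCRPreprocessor's aspect-ratio resize to a fixed height."""
    scale = img_height / orig_h
    return int(orig_w * scale)


def ctc_min_steps(target: str) -> int:
    """Minimum T for CTC: one step per character plus a blank between repeated neighbours."""
    repeats = sum(1 for a, b in zip(target, target[1:]) if a == b)
    return len(target) + repeats


def ctc_feasible(orig_w: int, orig_h: int, target: str, img_height: int = 64) -> bool:
    return crnn_output_steps(resized_width(orig_w, orig_h, img_height)) >= ctc_min_steps(target)


print(crnn_output_steps(256))                        # → 63 (dummy input in Cell 12)
print(crnn_output_steps(resized_width(1000, 3000)))  # → 4 (an illustrative tall receipt photo)
print(ctc_feasible(400, 40, "Tổng tiền: 50.000"))     # → True (an illustrative single text line)
```

A whole receipt with hundreds of characters cannot be aligned to 4 time steps. With `zero_infinity=True` (Cell 14), those losses become zero rather than raising an error, so such samples silently contribute nothing to training.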

In [None]:
# Cell 12: CRNN Model Implementation
import torch.nn as nn

class CRNN(nn.Module):
    """CRNN model for Vietnamese Receipt OCR"""
    
    def __init__(self, img_height, num_chars, hidden_size=256, num_rnn_layers=2):
        super(CRNN, self).__init__()
        
        self.img_height = img_height
        self.num_chars = num_chars
        
        # CNN Layers: Extract visual features
        # Input: (batch, 3, H, W) → Output: (batch, 512, H/16 - 1, W/4 - 1)  (integer division)
        self.cnn = nn.Sequential(
            # Conv Block 1
            nn.Conv2d(3, 64, kernel_size=3, padding=1),  # (3, H, W) → (64, H, W)
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),  # → (64, H/2, W/2)
            
            # Conv Block 2
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),  # → (128, H/4, W/4)
            
            # Conv Block 3
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d((2, 1), (2, 1)),  # → (256, H/8, W/4)
            
            # Conv Block 4
            nn.Conv2d(256, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d((2, 1), (2, 1)),  # → (512, H/16, W/4)
            
            # Conv Block 5
            nn.Conv2d(512, 512, kernel_size=2, padding=0),  # → (512, H/16-1, W/4-1)
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
        )
        
        # Adaptive pooling to collapse the height dimension
        self.adaptive_pool = nn.AdaptiveAvgPool2d((1, None))  # → (512, 1, W')
        
        # RNN Layers: Sequence modeling
        self.rnn = nn.LSTM(512, hidden_size, num_rnn_layers, 
                          bidirectional=True, batch_first=True)
        
        # Output Layer: Character prediction
        self.fc = nn.Linear(hidden_size * 2, num_chars)  # *2 for bidirectional
    
    def forward(self, x):
        # CNN feature extraction
        conv_features = self.cnn(x)  # (batch, 512, H', W')
        
        # Adaptive pooling
        conv_features = self.adaptive_pool(conv_features)  # (batch, 512, 1, W')
        conv_features = conv_features.squeeze(2)  # (batch, 512, W')
        
        # Permute for RNN: (batch, W', 512)
        conv_features = conv_features.permute(0, 2, 1)
        
        # RNN sequence modeling
        rnn_output, _ = self.rnn(conv_features)  # (batch, W', hidden_size*2)
        
        # Character prediction
        output = self.fc(rnn_output)  # (batch, W', num_chars)
        
        # Permute for CTC: (W', batch, num_chars)
        output = output.permute(1, 0, 2)
        
        return output

# Test model
model = CRNN(img_height=64, num_chars=len(charset), hidden_size=256, num_rnn_layers=2)
model = model.to(device)

# Test forward pass
dummy_input = torch.randn(2, 3, 64, 256).to(device)
dummy_output = model(dummy_input)
print(f"✅ CRNN Model initialized!")
print(f"   Input shape: {dummy_input.shape}")
print(f"   Output shape: {dummy_output.shape}")  # (seq_len, batch, num_chars); seq_len = 256 // 4 - 1 = 63
print(f"   Parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")

In [None]:
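# Cell 13: Collate function for the DataLoader (handles variable-length sequences)
# This redefines collate_fn from Cell 16 and must also pad image widths: OCRPreprocessor keeps
# the aspect ratio, so images in a batch differ in width and a plain torch.stack would fail.
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F

def collate_fn(batch):
    """
    Collate a batch whose images (width) and texts (length) vary
    """
    images, texts, text_lengths = zip(*batch)
    
    # Pad every image on the right to the widest image in the batch,
    # using 1.0 (white in [0, 1]) to match OCRPreprocessor's padding
    max_width = max(img.shape[2] for img in images)
    images = torch.stack([F.pad(img, (0, max_width - img.shape[2]), value=1.0) for img in images], dim=0)
    
    # Pad texts (CTCLoss ignores positions beyond each entry of text_lengths)
    texts_padded = pad_sequence(texts, batch_first=True, padding_value=0)
    
    # Convert to tensors
    text_lengths = torch.LongTensor(text_lengths)
    
    return images, texts_padded, text_lengths

print("✅ Collate function ready!")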

## 🔥 PART 6: Training Setup & Loop

Training configuration:
- **Optimizer:** Adam with learning rate 0.0005
- **Scheduler:** ReduceLROnPlateau
- **Loss:** CTCLoss (Connectionist Temporal Classification)
- **Metrics:** CER (Character Error Rate)
- **Epochs:** 50-100 epochs
- **Batch Size:** 32 (depending on GPU memory)
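
Cell 14 sets up the optimizer, scheduler and `CTCLoss`, but the training loop itself is in the external guide. As a reference point, here is a minimal sketch of one CTC training step. It assumes the model returns raw scores shaped `(T, N, C)`, like the CRNN in Cell 12. Every sample uses the full `T` as its input length, which fits here because all images in a padded batch share one width. Gradient clipping at 5.0 is an illustrative choice and not part of the original configuration. `TinySeqModel` and the `demo_*` objects are hypothetical stand-ins so the block runs on its own without overwriting the notebook's `model`, `optimizer` or `ctc_loss`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def ctc_training_step(model, images, targets, target_lengths, optimizer, ctc_loss, clip=5.0):
    """One optimization step with CTC loss; returns the loss as a Python float."""
    model.train()
    optimizer.zero_grad()
    logits = model(images)                       # (T, N, C) raw scores
    log_probs = F.log_softmax(logits, dim=2)     # CTCLoss expects log-probabilities
    T, N, _ = log_probs.shape
    input_lengths = torch.full((N,), T, dtype=torch.long)
    loss = ctc_loss(log_probs, targets, input_lengths, target_lengths)
    loss.backward()
    nn.utils.clip_grad_norm_(model.parameters(), max_norm=clip)  # illustrative choice
    optimizer.step()
    return loss.item()


class TinySeqModel(nn.Module):
    """Hypothetical stand-in with the CRNN's output layout: one time step per image column."""

    def __init__(self, in_height: int, num_chars: int):
        super().__init__()
        self.proj = nn.Linear(3 * in_height, num_chars)

    def forward(self, x):                                   # x: (N, 3, H, W)
        n, c, h, w = x.shape
        cols = x.permute(0, 3, 1, 2).reshape(n, w, c * h)   # (N, W, 3*H)
        return self.proj(cols).permute(1, 0, 2)             # (W, N, num_chars)


torch.manual_seed(0)
demo_model = TinySeqModel(in_height=8, num_chars=6)
demo_optimizer = torch.optim.Adam(demo_model.parameters(), lr=0.0005)
demo_ctc_loss = nn.CTCLoss(blank=0, zero_infinity=True)

demo_images = torch.rand(2, 3, 8, 20)                  # batch of 2, width 20 -> T = 20
demo_targets = torch.tensor([[1, 2, 3], [4, 5, 0]])    # padded like collate_fn; the trailing 0 is never read
demo_target_lengths = torch.tensor([3, 2])
loss = ctc_training_step(demo_model, demo_images, demo_targets, demo_target_lengths,
                         demo_optimizer, demo_ctc_loss)
print(f"loss = {loss:.4f}")
```

In the real loop you would pass `model`, `optimizer` and `ctc_loss` from Cells 12 and 14, move the batch to `device`, and call `scheduler.step(val_loss)` once per epoch. `ReduceLROnPlateau` monitors a metric rather than stepping per batch.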

In [None]:
# Cell 14: Training configuration & helper functions
import editdistance

# Hyperparameters
BATCH_SIZE = 32
LEARNING_RATE = 0.0005
NUM_EPOCHS = 50
SAVE_DIR = '/content/checkpoints'
os.makedirs(SAVE_DIR, exist_ok=True)

# Optimizer & Scheduler
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)

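# CTC Loss (blank = index 0). nn.CTCLoss expects log-probabilities shaped (T, N, C),
# so apply F.log_softmax(outputs, dim=2) to the CRNN's raw scores before calling it
ctc_loss = nn.CTCLoss(blank=0, zero_infinity=True)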

# Decoder function
def decode_predictions(preds, idx_to_char):
    """Decode CTC output to text"""
    # preds shape: (seq_len, batch, num_chars)
    # Get argmax
    _, max_indices = torch.max(preds, dim=2)  # (seq_len, batch)
    max_indices = max_indices.transpose(0, 1).cpu().numpy()  # (batch, seq_len)
    
    decoded_texts = []
    for indices in max_indices:
        # Remove blanks and duplicates (CTC decoding)
        chars = []
        prev_idx = None
        for idx in indices:
            if idx != 0 and idx != prev_idx:  # 0 is blank
                if idx < len(idx_to_char):
                    chars.append(idx_to_char[idx])
            prev_idx = idx
        decoded_texts.append(''.join(chars))
    
    return decoded_texts

# Character Error Rate
def calculate_cer(predictions, targets):
    """Calculate Character Error Rate"""
    total_distance = 0
    total_length = 0
    
    for pred, target in zip(predictions, targets):
        distance = editdistance.eval(pred, target)
        total_distance += distance
        total_length += len(target)
    
    cer = total_distance / total_length if total_length > 0 else 0
    return cer

print("✅ Training configuration ready!")
print(f"   Batch size: {BATCH_SIZE}")
print(f"   Learning rate: {LEARNING_RATE}")
print(f"   Epochs: {NUM_EPOCHS}")
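
You can check the decoding and CER logic of Cell 14 by hand. Greedy CTC decoding merges consecutive repeats and then drops blanks, so a blank between two identical labels is what lets a doubled character survive. CER is the total edit distance divided by the total target length. The sketch re-implements both in plain Python, with a textbook Levenshtein distance standing in for `editdistance.eval`, so the arithmetic is visible. The function names are hypothetical.

```python
def greedy_ctc_collapse(indices, blank=0):
    """Same rule as decode_predictions: skip blanks and repeats of the previous frame."""
    out, prev = [], None
    for idx in indices:
        if idx != blank and idx != prev:
            out.append(idx)
        prev = idx
    return out


def levenshtein(a: str, b: str) -> int:
    """Minimum number of single-character insertions, deletions and substitutions."""
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, 1):
        cur = [i]
        for j, cb in enumerate(b, 1):
            cur.append(min(prev[j] + 1,                # delete ca
                           cur[j - 1] + 1,             # insert cb
                           prev[j - 1] + (ca != cb)))  # substitute (free if equal)
        prev = cur
    return prev[-1]


def cer(predictions, targets):
    """Character Error Rate: summed edit distance / summed target length (as in calculate_cer)."""
    total_distance = sum(levenshtein(p, t) for p, t in zip(predictions, targets))
    total_length = sum(len(t) for t in targets)
    return total_distance / total_length if total_length > 0 else 0


print(greedy_ctc_collapse([0, 3, 3, 0, 3, 5, 5, 0]))  # → [3, 3, 5]
# A prediction that lost every diacritic: 3 substitutions over 7 target characters
print(cer(["hoa don"], ["h\u00f3a \u0111\u01a1n"]))   # target is "hóa đơn"; 3 / 7 ≈ 0.43
```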

## ⚠️ IMPORTANT: Read the OCR_TRAINING_GUIDE.md File

Because of the limitations of the notebook format, I created a separate file, **`OCR_TRAINING_GUIDE.md`**, which contains:

✅ **Full training loop code** (copy-paste into the cells below)  
✅ **DataLoader creation code**  
✅ **Inference & demo code**  
✅ **ONNX export code**  
✅ **TFLite conversion code**  
✅ **Flutter integration guide**  
✅ **Troubleshooting tips**

**Workflow:**
1. Run cells 1-14 in this notebook
2. Open `OCR_TRAINING_GUIDE.md`
3. Copy the code from the corresponding sections
4. Paste it into the cells below
5. Run training!

---

### 🔥 Quick Start:

Next steps (copy from OCR_TRAINING_GUIDE.md):

**Cell 15:** Training loop functions  
**Cell 16:** Create DataLoaders  
**Cell 17:** Start training  
**Cell 18-19:** Inference & demo  
**Cell 20:** Export ONNX  
**Cell 21:** Export TFLite  
**Cell 22:** Download models

In [None]:
# Cell 15: Placeholder - copy the training loop from OCR_TRAINING_GUIDE.md
# Paste the entire section "1️⃣ COMPLETE TRAINING LOOP CODE" here

print("⚠️ No training loop yet!")
print("📖 Open OCR_TRAINING_GUIDE.md")
print("📋 Copy section 1️⃣ into this cell")
print("▶️ Then run the cell")

In [None]:
# Cell 16: Placeholder - Create DataLoaders
# Copy section "2️⃣ CREATE DATALOADERS" from OCR_TRAINING_GUIDE.md

print("⚠️ This cell must be filled in with code from OCR_TRAINING_GUIDE.md")
print("📋 Section: 2️⃣ CREATE DATALOADERS")

In [None]:
# Cell 17: Placeholder - Execute Training
# Copy section "3️⃣ START TRAINING" from OCR_TRAINING_GUIDE.md

print("⚠️ This cell must be filled in with code from OCR_TRAINING_GUIDE.md")
print("📋 Section: 3️⃣ START TRAINING")

## 🎯 PART 7: Inference & Model Export

After training finishes:
1. Test the model on sample images
2. Export it to ONNX
3. Convert it to TFLite
4. Download the models to your machine

Copy the code from **OCR_TRAINING_GUIDE.md** sections 4️⃣-7️⃣

In [None]:
# Cells 18-22: Placeholders for Inference & Export
# Copy from OCR_TRAINING_GUIDE.md:
# - Section 4️⃣: INFERENCE & DEMO
# - Section 5️⃣: EXPORT TO ONNX
# - Section 6️⃣: EXPORT TO TFLITE
# - Section 7️⃣: DOWNLOAD MODELS

print("📖 Open OCR_TRAINING_GUIDE.md to copy the code")
print("✅ Cells 1-14 completed")
print("📋 Copy sections 4️⃣-7️⃣ into the following cells")
print("\n🎯 Final outputs:")
print("   - best_model.pth (PyTorch checkpoint)")
print("   - vietnamese_ocr_model.onnx (ONNX format)")
print("   - vietnamese_ocr_model.tflite (TFLite for Flutter)")
print("   - training_curves.png (Visualization)")

### 💡 How to Use the OCR Test Cell (Cell 23):

**Option 1: Upload your own image** (used when `val_dataset` is not defined)
1. Run Cell 23
2. Click "Choose Files" when prompted
3. Select a receipt image from your computer
4. View the OCR result

**Option 2: Test with an image from the dataset** (used automatically when `val_dataset` is defined)
- The cell picks a random image from the validation set
- Compares the ground truth with the prediction
- Computes the CER (Character Error Rate)

**Displayed results:**
- ✅ Input image (resized to a height of 64 px)
- ✅ Recognized text
- ✅ Comparison with the ground truth (if available)
- ✅ CER score
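
CER is conventionally defined as the Levenshtein (edit) distance between the prediction and the ground truth, divided by the ground-truth length. This is what `calculate_cer` in the training configuration cell computes via `editdistance.eval`. Counting unmatched characters with `difflib.SequenceMatcher` is not equivalent, because it ignores extra characters inserted into the prediction.

A dependency-free sketch of the standard definition follows. The function names are illustrative. The NFC normalization step is an added assumption: it ensures that precomposed and decomposed Vietnamese diacritics (for example `ó` vs. `o` plus a combining acute accent) are not counted as errors.

```python
import unicodedata


def levenshtein(a: str, b: str) -> int:
    """Minimum number of insertions, deletions and substitutions turning a into b."""
    prev = list(range(len(b) + 1))  # distances from "" to each prefix of b
    for i, ca in enumerate(a, start=1):
        curr = [i]
        for j, cb in enumerate(b, start=1):
            curr.append(min(
                prev[j] + 1,               # delete ca
                curr[j - 1] + 1,           # insert cb
                prev[j - 1] + (ca != cb),  # substitute (free when equal)
            ))
        prev = curr
    return prev[-1]


def cer_percent(ground_truth: str, predicted: str) -> float:
    """Character Error Rate in percent; 0 for an empty ground truth, as in the notebook."""
    # Assumption: normalize to NFC so composed/decomposed diacritics compare equal.
    gt = unicodedata.normalize("NFC", ground_truth)
    pred = unicodedata.normalize("NFC", predicted)
    if not gt:
        return 0.0
    return levenshtein(pred, gt) / len(gt) * 100


print(f"{cer_percent('abc', 'abd'):.2f}%")     # 33.33%
print(f"{cer_percent('abc', 'abcxyz'):.2f}%")  # 100.00%
```

Because insertions count as errors, CER can exceed 100% when the prediction is much longer than the ground truth.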

---

### 📊 Cell 24: Batch Test Features

**Features:**
- ✅ Tests several images at once (default: 10)
- ✅ Computes the CER for each image
- ✅ Summary statistics: average, best, and worst CER
- ✅ Performance classification: Excellent/Good/Fair/Poor
- ✅ Visualization: best vs. worst prediction side by side
- ✅ Saves the figure as `batch_test_results.png`

**Adjusting the number of images:**
```python
# Test 20 images instead of 10
batch_results = batch_test_ocr(num_samples=20, ...)
```
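
In Cell 24, the per-image status labels and the summary breakdown rely on the same thresholds (10%, 20%, 30%). The minimal sketch below keeps those thresholds in one table so the two cannot drift apart. `cer_bucket` and `summarize` are hypothetical helpers, not functions defined elsewhere in this notebook.

```python
from collections import Counter
from statistics import mean
from typing import Dict, List

# Exclusive upper bounds in percent, matching the thresholds used in Cell 24.
BUCKETS = [(10.0, "Excellent"), (20.0, "Good"), (30.0, "Fair")]
LABELS = ("Excellent", "Good", "Fair", "Poor")


def cer_bucket(cer: float) -> str:
    """Map a CER percentage to the same label Cell 24 prints."""
    for upper, label in BUCKETS:
        if cer < upper:
            return label
    return "Poor"  # cer >= 30%


def summarize(cers: List[float]) -> Dict[str, object]:
    """Average/best/worst CER plus a count per performance bucket."""
    if not cers:
        raise ValueError("no CER values to summarize")
    counts = Counter(cer_bucket(c) for c in cers)
    return {
        "average": mean(cers),
        "best": min(cers),
        "worst": max(cers),
        "breakdown": {label: counts.get(label, 0) for label in LABELS},
    }


stats = summarize([5.0, 12.5, 25.0, 40.0])
print(stats["breakdown"])  # {'Excellent': 1, 'Good': 1, 'Fair': 1, 'Poor': 1}
```

After a successful run, `summarize([r['cer'] for r in batch_results])` should give the same average, best, and worst values that Cell 24 prints.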

---

In [None]:
# Cell 24: Test OCR on multiple images (batch test)
import torch
import cv2
import editdistance
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from tqdm import tqdm

def batch_test_ocr(num_samples=10, model=None, idx_to_char=None, dataset=None, device='cuda'):
    """
    Run OCR on randomly sampled validation images and report per-image and aggregate CER.
    """
    if model is None or idx_to_char is None or dataset is None:
        print("❌ Load the model and dataset first!")
        return None
    
    print(f"\n{'='*80}")
    print(f"üß™ Batch OCR Test - Testing {num_samples} random images")
    print(f"{'='*80}\n")
    
    # Random sampling
    import random
    sample_indices = random.sample(range(len(dataset.valid_indices)), min(num_samples, len(dataset.valid_indices)))
    
    results = []
    total_cer = 0
    
    for i, val_idx in enumerate(tqdm(sample_indices, desc="Testing")):
        real_idx = dataset.valid_indices[val_idx]
        row = dataset.df.iloc[real_idx]
        
        img_path = f"/content/data/train_images/train_images/{row['img_id']}"
        ground_truth = row['anno_texts']
        
        # Preprocess
        img_tensor, _ = preprocess_test_image(img_path)
        if img_tensor is None:
            continue
        
        img_tensor = img_tensor.to(device)
        
        # Inference
        model.eval()
        with torch.no_grad():
            logits = model(img_tensor)
        
        # Decode
        predicted = decode_ctc_predictions(logits, idx_to_char)
        
        # Calculate CER (%): Levenshtein edit distance / ground-truth length
        cer = (editdistance.eval(predicted, ground_truth) / len(ground_truth) * 100) if len(ground_truth) > 0 else 0
        total_cer += cer
        
        results.append({
            'image': row['img_id'],
            'ground_truth': ground_truth,
            'predicted': predicted,
            'cer': cer
        })
    
    if not results:
        print("❌ No results: none of the sampled images could be processed!")
        return None
    
    # Display results
    print(f"\n{'='*80}")
    print("üìä BATCH TEST RESULTS")
    print(f"{'='*80}\n")
    
    for i, result in enumerate(results, 1):
        print(f"üîπ Image {i}: {result['image']}")
        print(f"   Ground Truth: {result['ground_truth'][:50]}{'...' if len(result['ground_truth']) > 50 else ''}")
        print(f"   Predicted:    {result['predicted'][:50]}{'...' if len(result['predicted']) > 50 else ''}")
        print(f"   CER: {result['cer']:.2f}%")
        
        # Status indicator
        if result['cer'] < 10:
            status = "✅ Excellent"
        elif result['cer'] < 20:
            status = "👍 Good"
        elif result['cer'] < 30:
            status = "⚠️  Fair"
        else:
            status = "❌ Poor"
        print(f"   {status}")
        print()
    
    # Summary
    avg_cer = total_cer / len(results) if results else 0
    
    print(f"{'='*80}")
    print("üìà SUMMARY STATISTICS")
    print(f"{'='*80}")
    print(f"‚úÖ Total Images Tested: {len(results)}")
    print(f"üìä Average CER: {avg_cer:.2f}%")
    print(f"üèÜ Best CER: {min(r['cer'] for r in results):.2f}%")
    print(f"üìâ Worst CER: {max(r['cer'] for r in results):.2f}%")
    
    # Performance breakdown
    excellent = sum(1 for r in results if r['cer'] < 10)
    good = sum(1 for r in results if 10 <= r['cer'] < 20)
    fair = sum(1 for r in results if 20 <= r['cer'] < 30)
    poor = sum(1 for r in results if r['cer'] >= 30)
    
    print(f"\nüìä Performance Breakdown:")
    print(f"   ‚úÖ Excellent (< 10%):  {excellent} ({excellent/len(results)*100:.1f}%)")
    print(f"   üëç Good (10-20%):      {good} ({good/len(results)*100:.1f}%)")
    print(f"   ‚ö†Ô∏è  Fair (20-30%):      {fair} ({fair/len(results)*100:.1f}%)")
    print(f"   ‚ùå Poor (> 30%):       {poor} ({poor/len(results)*100:.1f}%)")
    
    print(f"\n{'='*80}\n")
    
    # Visualize best and worst
    try:
        print("üñºÔ∏è  Visualizing best and worst predictions...")
        
        sorted_results = sorted(results, key=lambda x: x['cer'])
        best_result = sorted_results[0]
        worst_result = sorted_results[-1]
        
        fig, axes = plt.subplots(2, 2, figsize=(16, 8))
        
        for idx, (result, title_prefix) in enumerate([(best_result, 'BEST'), (worst_result, 'WORST')]):
            img_path = f"/content/data/train_images/train_images/{result['image']}"
            img = cv2.imread(img_path)
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            
            axes[idx, 0].imshow(img)
            axes[idx, 0].set_title(f"{title_prefix} - {result['image']} (CER: {result['cer']:.2f}%)", 
                                   fontweight='bold', fontsize=12)
            axes[idx, 0].axis('off')
            
            text_display = f"Ground Truth:\n{result['ground_truth'][:100]}\n\n"
            text_display += f"Predicted:\n{result['predicted'][:100]}\n\n"
            text_display += f"CER: {result['cer']:.2f}%"
            
            axes[idx, 1].text(0.1, 0.5, text_display, 
                             ha='left', va='center', 
                             fontsize=10, wrap=True,
                             bbox=dict(boxstyle='round', 
                                      facecolor='lightgreen' if result['cer'] < 20 else 'lightcoral',
                                      alpha=0.7))
            axes[idx, 1].set_xlim(0, 1)
            axes[idx, 1].set_ylim(0, 1)
            axes[idx, 1].axis('off')
            axes[idx, 1].set_title("OCR Result", fontweight='bold', fontsize=12)
        
        plt.tight_layout()
        plt.savefig('batch_test_results.png', dpi=150, bbox_inches='tight')
        plt.show()
        
        print("\n✅ Visualization saved: batch_test_results.png")
    except Exception as e:
        print(f"⚠️  Could not visualize: {e}")
    
    return results

# ══════════════════════════════════════════════════════════════════════
# 🎯 AUTO-RUN BATCH TEST:
# ══════════════════════════════════════════════════════════════════════

print("\n📊 Batch Test Cell")
print("="*70)

try:
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    # Check prerequisites
    if 'model' not in dir():
        raise NameError("model")
    if 'idx_to_char' not in dir():
        raise NameError("idx_to_char")
    if 'val_dataset' not in dir():
        raise NameError("val_dataset")
    if 'preprocess_test_image' not in dir():
        raise NameError("preprocess_test_image (run Cell 23 first)")
    if 'decode_ctc_predictions' not in dir():
        raise NameError("decode_ctc_predictions (run Cell 23 first)")
    
    print("✅ All prerequisites are ready!")
    print(f"✅ Device: {device}")
    
    # Run the batch test on 10 images
    print("\n🚀 Running the batch test on 10 random images...")
    batch_results = batch_test_ocr(
        num_samples=10,
        model=model,
        idx_to_char=idx_to_char,
        dataset=val_dataset,
        device=device
    )
    
    print("\n✅ Batch test complete!")
    print("="*70)
    
    print("\nüí° ƒê·ªÉ test nhi·ªÅu ·∫£nh h∆°n:")
    print("   batch_results = batch_test_ocr(")
    print("       num_samples=20,  # Thay ƒë·ªïi s·ªë l∆∞·ª£ng")
    print("       model=model,")
    print("       idx_to_char=idx_to_char,")
    print("       dataset=val_dataset,")
    print("       device=device")
    print("   )")

except NameError as e:
    print(f"❌ Missing: {e}")
    print("\n💡 Run these first:")
    print("   - Cell 9: charset & idx_to_char")
    print("   - Cell 12: model CRNN")
    print("   - Cell 16: val_dataset")
    print("   - Cell 23: preprocess_test_image & decode_ctc_predictions")
    print("   - Cells 15-17: training")
    print()
    print("üìã Sau khi train xong, ch·∫°y Cell 23 tr∆∞·ªõc r·ªìi m·ªõi ch·∫°y Cell 24 n√†y")


In [None]:
# Cell 23: Test OCR with any receipt image
import torch
import cv2
import editdistance
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

def preprocess_test_image(image_path, target_height=64):
    """
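    Preprocess the image the same way as in training: RGB, resized to target_height
    (keeping the aspect ratio), scaled to [0, 1]. Returns a [1, C, H, W] tensor and the resized image.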
    """
    # Read the image
    img = cv2.imread(image_path)
    if img is None:
        print(f"❌ Cannot read image: {image_path}")
        return None, None
    
    # Convert BGR to RGB
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    
    # Resize to target_height (default 64), keeping the aspect ratio
    h, w = img.shape[:2]
    new_h = target_height
    new_w = max(1, int(round(w * new_h / h)))  # at least 1 px so cv2.resize gets a valid size
    img_resized = cv2.resize(img, (new_w, new_h))
    
    # Normalize [0, 255] -> [0, 1]
    img_normalized = img_resized.astype(np.float32) / 255.0
    
    # Convert to tensor [C, H, W]
    img_tensor = torch.from_numpy(img_normalized).permute(2, 0, 1)
    
    # Add batch dimension [1, C, H, W]
    img_tensor = img_tensor.unsqueeze(0)
    
    return img_tensor, img_resized

def decode_ctc_predictions(logits, idx_to_char):
    """
    Decode CTC predictions into text (greedy: merge repeats, then drop blanks)
    """
    # logits shape: [seq_len, batch_size=1, num_chars]
    _, preds = logits.max(2)  # [seq_len, 1]
    preds = preds.squeeze(1)  # [seq_len]
    
    # CTC decoding: remove blanks and duplicates
    decoded_chars = []
    prev_char = None
    
    for idx in preds:
        idx = idx.item()
        if idx == 0:  # blank token
            prev_char = None
            continue
        if idx == prev_char:  # duplicate
            continue
        if idx in idx_to_char:
            decoded_chars.append(idx_to_char[idx])
        prev_char = idx
    
    return ''.join(decoded_chars)

def test_image_ocr(image_path, model, idx_to_char, device='cuda'):
    """
    Run OCR on a single receipt image
    """
    print(f"\n{'='*70}")
    print(f"üß™ Testing OCR with image: {image_path}")
    print(f"{'='*70}")
    
    # Preprocess
    img_tensor, img_display = preprocess_test_image(image_path)
    if img_tensor is None:
        return None
    
    # Move to device
    img_tensor = img_tensor.to(device)
    
    # Inference
    model.eval()
    with torch.no_grad():
        logits = model(img_tensor)  # [seq_len, 1, num_chars]
    
    # Decode
    predicted_text = decode_ctc_predictions(logits, idx_to_char)
    
    # Display results
    plt.figure(figsize=(15, 5))
    
    # Show image
    plt.subplot(1, 2, 1)
    plt.imshow(img_display)
    plt.title('Input Image', fontsize=14, fontweight='bold')
    plt.axis('off')
    
    # Show prediction
    plt.subplot(1, 2, 2)
    plt.text(0.5, 0.5, predicted_text, 
             ha='center', va='center', 
             fontsize=12, wrap=True,
             bbox=dict(boxstyle='round', facecolor='lightblue', alpha=0.8))
    plt.xlim(0, 1)
    plt.ylim(0, 1)
    plt.axis('off')
    plt.title('OCR Prediction', fontsize=14, fontweight='bold')
    
    plt.tight_layout()
    plt.show()
    
    # Print results
    print(f"\nüìù Recognized Text:")
    print(f"{'‚îÄ'*70}")
    print(f"{predicted_text}")
    print(f"{'‚îÄ'*70}")
    print(f"\n‚úÖ Text Length: {len(predicted_text)} characters")
    print(f"‚úÖ Image Size: {img_display.shape[1]}x{img_display.shape[0]}px")
    print(f"‚úÖ Tensor Shape: {img_tensor.shape}")
    
    return predicted_text

# ══════════════════════════════════════════════════════════════════════
# 🎯 AUTO-RUN TEST:
# ══════════════════════════════════════════════════════════════════════

print("\n📋 OCR Test Cell")
print("="*70)

try:
    # Check device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"✅ Device: {device}")
    
    # Check if model exists
    if 'model' not in dir():
        raise NameError("model is not defined")
    
    if 'idx_to_char' not in dir():
        raise NameError("idx_to_char is not defined")
    
    print("✅ model and idx_to_char are ready!")
    
    # Check whether val_dataset exists
    if 'val_dataset' in dir():
        # val_dataset available: test with a random validation image
        print("\n🔄 Testing with a random image from the validation set...")
        
        import random
        val_idx = random.randint(0, len(val_dataset.valid_indices) - 1)
        real_idx = val_dataset.valid_indices[val_idx]
        sample_row = val_dataset.df.iloc[real_idx]
        
        sample_img_path = f"/content/data/train_images/train_images/{sample_row['img_id']}"
        sample_text = sample_row['anno_texts']
        
        print(f"\nüìå Sample from validation set:")
        print(f"   - Image: {sample_row['img_id']}")
        print(f"   - Ground Truth: {sample_text[:100]}{'...' if len(sample_text) > 100 else ''}")
        
        # Test OCR
        predicted = test_image_ocr(sample_img_path, model, idx_to_char, device)
        
        # Compare with ground truth
        print(f"\nüìä COMPARISON:")
        print(f"{'='*70}")
        print(f"Ground Truth:\n{sample_text}")
        print(f"\n{'-'*70}")
        print(f"Predicted:\n{predicted}")
        print(f"{'='*70}")
        
        # Calculate CER (%): Levenshtein edit distance / ground-truth length.
        # Computed inline rather than redefining calculate_cer(), which the
        # training config cell defines for lists of strings.
        predicted = predicted or ""
        cer = (editdistance.eval(predicted, sample_text) / len(sample_text) * 100) if len(sample_text) > 0 else 0
        print(f"\n📈 Character Error Rate (CER): {cer:.2f}%")
        
        if cer < 10:
            print("   ✅ Excellent! (< 10%)")
        elif cer < 20:
            print("   👍 Good! (10-20%)")
        elif cer < 30:
            print("   ⚠️  Fair (20-30%)")
        else:
            print("   ❌ Needs improvement (>= 30%)")
    
    else:
        # No val_dataset: fall back to uploading an image from your computer
        print("\n⚠️  val_dataset not found")
        print("📤 Please upload a receipt image to test OCR:\n")
        
        from google.colab import files
        uploaded = files.upload()
        
        if uploaded:
            test_image_path = list(uploaded.keys())[0]
            print(f"\n✅ Uploaded: {test_image_path}")
            
            # Test OCR
            predicted = test_image_ocr(test_image_path, model, idx_to_char, device)
            
            print(f"\nüìù K·∫øt qu·∫£ OCR:")
            print(f"{'='*70}")
            print(predicted)
            print(f"{'='*70}")
        else:
            print("❌ No file was uploaded")
    
    print("\n" + "="*70)
    print("✅ Test complete!")
    print("="*70)
    
    print("\nüí° ƒê·ªÉ test th√™m ·∫£nh:")
    print("   from google.colab import files")
    print("   uploaded = files.upload()")
    print("   test_image_path = list(uploaded.keys())[0]")
    print("   result = test_image_ocr(test_image_path, model, idx_to_char, device)")

except NameError as e:
    print(f"❌ Error: {e}")
    print("\n💡 This cell needs the following variables from training:")
    print("   - model (from Cell 12 and training)")
    print("   - idx_to_char (from Cell 9)")
    print()
    print("📋 Make sure that:")
    print("   1. Cells 1-14 (setup & config) have been run")
    print("   2. Training has finished")
    print("   3. The model has been loaded into the variable 'model'")


## 🧪 PART 8: Test OCR with Any Image

Test the trained model on any receipt image to check its accuracy.
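
`preprocess_test_image` scales every image to a height of 64 px and derives the width from the aspect ratio. As a result, the CRNN's output sequence length grows with the image width.

The sketch below reproduces that arithmetic. The rounding, the 1 px floor, and especially `downsample=4` are assumptions. `downsample` stands in for the CNN's total horizontal stride, and the Cell 12 architecture is not shown here.

```python
def resized_width(w: int, h: int, target_height: int = 64) -> int:
    """Width after scaling an image of size (h, w) to target_height, keeping the aspect ratio."""
    if w <= 0 or h <= 0:
        raise ValueError("image dimensions must be positive")
    # Assumption: round to nearest and keep at least 1 px so cv2.resize gets a valid size.
    return max(1, round(w * target_height / h))


def output_frames(width: int, downsample: int = 4) -> int:
    """Approximate number of CTC time steps for an input of the given width.

    ASSUMPTION: `downsample` stands in for the CNN's total horizontal stride;
    the real value (and any off-by-one from padding) depends on the Cell 12 model.
    """
    return max(1, width // downsample)


w64 = resized_width(w=1000, h=500)
print(w64, output_frames(w64))  # 128 32
```

A transcript can only be decoded in full if the frame count is at least its length, plus one for every pair of adjacent identical characters. To find the real stride, pass a dummy tensor through `model` and compare the output's sequence dimension with the input width.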

---

## 📚 REFERENCES

### 📖 Accompanying files:
1. **OCR_TRAINING_GUIDE.md** - Detailed guide with the full code
2. **OCR_TROUBLESHOOTING.md** - Fixing OCR errors in the app
3. **vietnamese_receipt_ocr_training.ipynb** - This notebook

### 🔗 Resources:
- [MC-OCR 2021 Dataset](https://www.kaggle.com/datasets/domixi1989/vietnamese-receipts-mc-ocr-2021)
- [CRNN Paper](https://arxiv.org/abs/1507.05717)
- [CTC Loss Explained](https://distill.pub/2017/ctc/)
- [TFLite Flutter Plugin](https://pub.dev/packages/tflite_flutter)

### ⚙️ System Requirements:
- **Google Colab:** Pro recommended (GPU T4/V100)
- **RAM:** 12GB+ for full dataset
- **Storage:** 5GB+ for dataset + models
- **Training Time:** 2-4 hours

### 🎯 Expected Performance:
- **CER:** 5-15% on MC-OCR test set
- **Model Size:** ~25MB (TFLite FP16)
- **Inference Speed:** 200-500ms/image on mobile
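
The inference-speed range above depends heavily on the device, the runtime, and the input width. Below is a minimal, framework-agnostic sketch for measuring it on your own hardware. `run_inference` is a placeholder for whichever call you benchmark, such as a PyTorch forward pass, an ONNX Runtime session run, or a TFLite interpreter invocation. The warm-up and repeat counts are arbitrary defaults.

```python
import statistics
import time
from typing import Callable


def median_latency_ms(run_inference: Callable[[], object], warmup: int = 3, repeats: int = 20) -> float:
    """Median wall-clock latency of run_inference() in milliseconds."""
    if repeats < 1:
        raise ValueError("repeats must be at least 1")
    for _ in range(warmup):  # early calls often include one-off setup costs
        run_inference()
    samples = []
    for _ in range(repeats):
        start = time.perf_counter()
        run_inference()
        samples.append((time.perf_counter() - start) * 1000.0)
    return statistics.median(samples)


# Stand-in workload; replace the lambda with a real model call.
latency = median_latency_ms(lambda: sum(range(100_000)))
print(f"{latency:.2f} ms")
```

When timing a CUDA model, call `torch.cuda.synchronize()` inside the callable. GPU kernels run asynchronously, so without it the measured time can be too short.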

---

## ✅ COMPLETION CHECKLIST

**Setup:**
- [ ] Upload kaggle.json
- [ ] Download the dataset successfully
- [ ] Verify dataset structure

**Training:**
- [ ] Run cells 1-14 without errors
- [ ] Copy code from OCR_TRAINING_GUIDE.md
- [ ] Training completed
- [ ] Val CER < 20%

**Export:**
- [ ] ONNX model exported
- [ ] TFLite model exported
- [ ] Models downloaded to your machine

**Integration:**
- [ ] Copy the TFLite model into Flutter /assets/models/
- [ ] Update pubspec.yaml
- [ ] Test inference in the app

---

**🎉 Good luck with training!**

If you run into problems, check:
1. OCR_TRAINING_GUIDE.md - Full code
2. OCR_TROUBLESHOOTING.md - Error fixes
3. GitHub Issues - Bug reports

**Author:** KHANH | **Date:** Nov 17, 2025 | **Version:** 1.0