In [12]:
!git clone https://github.com/htrnguyen/compare_ocr_benchmark.git

Cloning into 'compare_ocr_benchmark'...
remote: Enumerating objects: 58, done.[K
remote: Counting objects: 100% (58/58), done.[K
remote: Compressing objects: 100% (38/38), done.[K
remote: Total 58 (delta 23), reused 49 (delta 14), pack-reused 0 (from 0)[K
Receiving objects: 100% (58/58), 2.18 MiB | 25.97 MiB/s, done.
Resolving deltas: 100% (23/23), done.


In [3]:
!pip install python-Levenshtein jiwer ultralytics



In [13]:
import sys, os, time
import torch
import pandas as pd
from PIL import Image
import numpy as np

# Đảm bảo các module custom có trong sys.path
sys.path.append('/kaggle/working/compare_ocr_benchmark/CNN_TR_OCR_Resnet')

from dataset_polygon import char2idx, idx2char
from model_cnn_transformer import OCRModel
from ultralytics import YOLO

# Nếu dùng file chung cho tiền xử lý/metrics/utils thì thêm:
sys.path.append('/kaggle/working/compare_ocr_benchmark/common')
from metrics import compute_metrics
from utils import read_annotations, save_results

Creating new Ultralytics Settings v0.0.6 file ✅ 
View Ultralytics Settings with 'yolo settings' or at '/root/.config/Ultralytics/settings.json'
Update Settings with 'yolo settings key=value', i.e. 'yolo settings runs_dir=path/to/dir'. For help see https://docs.ultralytics.com/quickstart/#ultralytics-settings.


In [14]:
# Dữ liệu trên Kaggle:
IMG_DIR = '/kaggle/input/nckh-2425-crops'
CSV_ANN = '/kaggle/input/nckh-2425-crops/crops_gt.csv'

# Đường dẫn model trong working dir (nên copy model vào đây trước)
YOLO_MODEL_PATH = '/kaggle/input/cnn_tr_ocr_resnet/pytorch/default/1/best.pt'
OCR_MODEL_PATH = '/kaggle/input/cnn_tr_ocr_resnet/pytorch/default/1/best_ocr_model.pth'
FONT_PATH = "/kaggle/working/compare_ocr_benchmark/CNN_TR_OCR_Resnet/Roboto-Regular.ttf"

In [15]:
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", DEVICE)
VOCAB_SIZE = len(char2idx)

yolo_model = YOLO(YOLO_MODEL_PATH)
ocr_model = OCRModel(vocab_size=VOCAB_SIZE).to(DEVICE)
ocr_model.load_state_dict(torch.load(OCR_MODEL_PATH, map_location=DEVICE))
ocr_model.eval()

Device: cuda


Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100%|██████████| 97.8M/97.8M [00:00<00:00, 214MB/s]


OCRModel(
  (encoder): CNNEncoder(
    (layer0): Sequential(
      (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    )
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downs

In [16]:
from torchvision import transforms

def preprocess_ocr_image(pil_img):
    transform = transforms.Compose(
        [
            transforms.Resize((32, 128)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )
    return transform(pil_img).unsqueeze(0)

def decode_sequence(indices):
    chars = []
    SOS_TOKEN = next((token for token in char2idx.keys() if "SOS" in token), None)
    for idx in indices:
        ch = idx2char.get(idx, "")
        if ch == "<EOS>":
            break
        if ch not in ("<PAD>", SOS_TOKEN):
            chars.append(ch)
    return "".join(chars)

In [None]:
def group_text_lines(results, y_threshold=80):
    # Nhóm bbox thành dòng, ưu tiên trái-phải nếu 2 bbox gần nhau theo y
    bboxes_with_text = []
    for bbox, text in results:
        x1, y1, x2, y2 = bbox
        y_center = (y1 + y2) / 2
        x_left = min(x1, x2)
        x_right = max(x1, x2)
        bboxes_with_text.append({
            "bbox": bbox,
            "text": text,
            "y_center": y_center,
            "x_left": x_left,
            "x_right": x_right,
        })

    # Sắp xếp ban đầu theo y_center (trên xuống dưới), x_left (trái-phải)
    bboxes_with_text.sort(key=lambda x: (x["y_center"], x["x_left"]))

    # Gộp thành các dòng
    lines = []
    for item in bboxes_with_text:
        added = False
        for line in lines:
            # Nếu khoảng cách y gần 1 dòng, gộp vào dòng đó
            if any(abs(item["y_center"] - i["y_center"]) < y_threshold for i in line):
                line.append(item)
                added = True
                break
        if not added:
            lines.append([item])
    # Sort từng dòng theo x_left (trái-phải)
    for line in lines:
        line.sort(key=lambda x: x["x_left"])
    return lines

def ocr_text_from_lines(results, y_threshold=80):
    grouped_lines = group_text_lines(results, y_threshold)
    # Ghép các dòng lại thành 1 string (mỗi dòng 1 phần), cách nhau dấu cách
    lines_text = [" ".join([item["text"] for item in line]) for line in grouped_lines]
    return " ".join(lines_text).strip()

In [18]:
def yolo_ocr_pipeline(image_path, conf_threshold=0.5, y_threshold=80):
    # 1. Detection
    results = yolo_model(image_path)
    boxes = results[0].boxes.xyxy.cpu().numpy()
    confs = results[0].boxes.conf.cpu().numpy()
    filtered = [(box, conf) for box, conf in zip(boxes, confs) if conf > conf_threshold]
    boxes = [box for box, conf in filtered]

    # 2. Crop bbox và nhận diện text từng bbox
    img_pil = Image.open(image_path).convert("RGB")
    ocr_results = []
    for box in boxes:
        x1, y1, x2, y2 = map(int, box)
        crop = img_pil.crop((x1, y1, x2, y2))
        image_tensor = preprocess_ocr_image(crop).to(DEVICE)
        with torch.no_grad():
            memory = ocr_model.encoder(image_tensor)
            SOS_TOKEN = next((token for token in char2idx.keys() if "SOS" in token), None)
            MAX_LEN = 36
            ys = torch.tensor([[char2idx[SOS_TOKEN]]], device=DEVICE)
            for _ in range(MAX_LEN):
                out = ocr_model.decoder(
                    ys,
                    memory,
                    tgt_mask=ocr_model.generate_square_subsequent_mask(ys.size(1)).to(DEVICE),
                )
                prob = out[:, -1, :]
                _, next_word = torch.max(prob, dim=1)
                ys = torch.cat([ys, next_word.unsqueeze(1)], dim=1)
                if next_word.item() == char2idx["<EOS>"]:
                    break
            pred_text = decode_sequence(ys.squeeze(0).tolist())
            ocr_results.append(((x1, y1, x2, y2), pred_text))
    # 3. Ghép kết quả theo thứ tự trái qua phải, trên xuống dưới thành 1 dòng
    return ocr_text_from_lines(ocr_results, y_threshold=y_threshold)

In [19]:
df = read_annotations(CSV_ANN)

results = []
for idx, row in df.iterrows():
    fname = row['filename']
    desc_gt = row['description_gt']
    label = row.get('label', '')
    img_path = os.path.join(IMG_DIR, fname)

    try:
        t1 = time.perf_counter()
        pred = yolo_ocr_pipeline(img_path, conf_threshold=0.5, y_threshold=80)
        t2 = time.perf_counter()
        infer_time = round(t2 - t1, 3)
    except Exception as e:
        pred = f"OCR_Error: {e}"
        infer_time = 0.0

    metrics = compute_metrics(desc_gt, pred)

    results.append({
        "filename": fname,
        "label": label,
        "ground_truth": desc_gt,
        "predicted_text": pred,
        "cer": metrics["cer"],
        "wer": metrics["wer"],
        "lev": metrics["lev"],
        "acc": metrics["acc"],
        "time": infer_time
    })
    if idx % 50 == 0:
        print(f"Processed: {idx}/{len(df)}")


image 1/1 /kaggle/input/nckh-2425-crops/001_heo-cao-boi_F_crop_0.jpg: 288x640 3 texts, 53.8ms
Speed: 9.5ms preprocess, 53.8ms inference, 265.2ms postprocess per image at shape (1, 3, 288, 640)
Processed: 0/2284

image 1/1 /kaggle/input/nckh-2425-crops/001_heo-cao-boi_F_crop_1.jpg: 640x640 4 texts, 9.6ms
Speed: 3.0ms preprocess, 9.6ms inference, 1.5ms postprocess per image at shape (1, 3, 640, 640)

image 1/1 /kaggle/input/nckh-2425-crops/001_heo-cao-boi_F_crop_2.jpg: 288x640 1 text, 9.0ms
Speed: 1.0ms preprocess, 9.0ms inference, 1.4ms postprocess per image at shape (1, 3, 288, 640)

image 1/1 /kaggle/input/nckh-2425-crops/001_heo-cao-boi_F_crop_3.jpg: 288x640 1 text, 8.0ms
Speed: 1.0ms preprocess, 8.0ms inference, 1.3ms postprocess per image at shape (1, 3, 288, 640)

image 1/1 /kaggle/input/nckh-2425-crops/001_heo-cao-boi_F_crop_4.jpg: 544x640 2 texts, 40.0ms
Speed: 1.6ms preprocess, 40.0ms inference, 1.3ms postprocess per image at shape (1, 3, 544, 640)

image 1/1 /kaggle/input/nck

In [20]:
import os
OUT_CSV = '/kaggle/working/compare_ocr_benchmark/results/cnntr_results.csv'
os.makedirs(os.path.dirname(OUT_CSV), exist_ok=True)
save_results(results, OUT_CSV)

Lưu thành công: /kaggle/working/compare_ocr_benchmark/results/cnntr_results.csv
