## 예측 파이프라인

In [None]:
# -*- coding: utf-8 -*-
import os, json, csv, math
from pathlib import Path
from typing import List, Tuple, Dict
import numpy as np
from PIL import Image

import torch
import torch.nn as nn
from torchvision.ops import box_iou
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from tqdm import tqdm

# -------------------- 경로/설정 --------------------
# Dataset root; validation images and their JSON labels live in sibling folders.
ROOT = r"C:\Users\USER\DL_OCR\dataset"
IMG_DIR = Path(ROOT) / "valid_image"
LBL_DIR = Path(ROOT) / "valid_label"

DET_CKPT = r"C:\Users\USER\DL_OCR\char_det_frcnn_best.pth"           # Faster R-CNN checkpoint
REC_CKPT = r"C:\Users\USER\DL_OCR\crnn_ctc_best3.pth"  # CRNN checkpoint

IOU_MATCH_THR = 0.8     # IoU threshold for detection evaluation / matching
SCORE_THR = 0.8         # minimum score for a predicted box to be kept
MIN_BOX_WH = 2          # drop boxes narrower or shorter than this
PAD_PX = 2              # padding added around each crop
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Results CSV is written next to (not inside) the dataset folder.
OUT_CSV = str(Path(ROOT).parent / "val_detect_recognize.csv")

# -------------------- CRNN 정의/로더 --------------------
class CRNN(nn.Module):
    """Convolutional-recurrent text recognizer producing per-timestep logits.

    A three-convolution feature extractor over a single-channel image (with
    two 2x2 max-pools, i.e. a total stride of 4) feeds a bidirectional LSTM.
    A linear layer then maps every timestep to ``num_classes`` scores.

    Args:
        num_classes: Size of the output layer.
        img_h: Stored on the module as ``img_h``; not used by ``forward``.
        cnn_out: Channel count of the last convolution (LSTM input size).
        hidden: LSTM hidden size per direction.
        layers: Number of stacked LSTM layers.
    """

    def __init__(self, num_classes: int, img_h=32, cnn_out=256, hidden=256, layers=2):
        super().__init__()
        self.img_h = img_h

        # The position of each layer inside nn.Sequential determines its
        # state-dict key, so the order below must stay as it is.
        conv_stack = [
            nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),   # cumulative stride 2
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),   # cumulative stride 4
            nn.Conv2d(128, cnn_out, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
        ]
        self.cnn = nn.Sequential(*conv_stack)
        self.rnn = nn.LSTM(
            input_size=cnn_out,
            hidden_size=hidden,
            num_layers=layers,
            batch_first=False,
            bidirectional=True,
        )
        self.fc = nn.Linear(hidden * 2, num_classes)

    def forward(self, x):
        """Map an image batch (B,1,H,W) to logits of shape (T,B,num_classes)."""
        feats = self.cnn(x)                        # (B, C, H', W')
        # Collapse the height axis, then put width first as the time axis.
        seq = feats.mean(dim=2).permute(2, 0, 1)   # (T, B, C)
        rnn_out, _state = self.rnn(seq)            # (T, B, 2*hidden)
        return self.fc(rnn_out)

def load_crnn(rec_ckpt_path: str, device=DEVICE):
    """Restore the CRNN recognizer and its decoding metadata from a checkpoint.

    The checkpoint must contain the keys ``"itos"``, ``"stoi"`` and
    ``"model"``; ``"img_height"`` and ``"max_width"`` are optional and default
    to 32 and 512.

    Args:
        rec_ckpt_path: Path to the recognizer checkpoint file.
        device: Device the model is moved to after loading.

    Returns:
        ``(model, meta)`` where ``model`` is in eval mode and ``meta`` holds
        ``itos``, ``stoi``, ``img_height``, ``max_width`` and ``blank_idx``.
    """
    # NOTE(review): weights_only=False unpickles arbitrary objects, so only
    # load checkpoint files from a trusted source.
    state = torch.load(rec_ckpt_path, map_location="cpu", weights_only=False)
    itos = state["itos"]
    stoi = state["stoi"]
    vocab_size = len(itos)
    img_h = state.get("img_height", 32)

    # One extra output class is reserved for the CTC blank (the last index).
    net = CRNN(num_classes=vocab_size + 1, img_h=img_h)
    net.load_state_dict(state["model"], strict=True)
    net.to(device).eval()

    return net, {
        "itos": itos,
        "stoi": stoi,
        "img_height": img_h,
        "max_width": state.get("max_width", 512),
        "blank_idx": vocab_size,
    }

@torch.no_grad()
def greedy_decode_ctc(logits: torch.Tensor, input_lengths: torch.Tensor, blank_idx: int):
    """Best-path CTC decoding: collapse repeated labels and drop blanks.

    Args:
        logits: Scores of shape (T, B, num_classes).
        input_lengths: Per-sample count of valid timesteps; values larger
            than T are clipped to T.
        blank_idx: Class index of the CTC blank.

    Returns:
        A list with one list of label indices per batch element.
    """
    best = logits.log_softmax(dim=-1).argmax(dim=-1)   # (T, B)
    num_steps, batch = best.shape

    decoded = []
    for b in range(batch):
        # Clip to [0, T] so the slice below covers exactly the valid steps.
        valid = max(0, min(int(input_lengths[b].item()), num_steps))
        labels = []
        last = -1
        for label in best[:valid, b].tolist():
            if label != blank_idx and label != last:
                labels.append(label)
            last = label
        decoded.append(labels)
    return decoded

# -------------------- Faster R-CNN 로더 --------------------
def load_detector(det_ckpt_path: str, num_classes: int = 2, device=DEVICE):
    """Build a Faster R-CNN (ResNet-50 FPN v2) detector and load a checkpoint.

    Args:
        det_ckpt_path: Path to a checkpoint whose ``"model"`` entry is the
            detector state dict.
        num_classes: Number of box-predictor classes.
        device: Device the model is moved to after loading.

    Returns:
        The detector in eval mode on ``device``.
    """
    # The strict state-dict load below replaces every parameter and buffer,
    # so pretrained weights would be discarded right away. Building without
    # them avoids a large download and lets the loader work offline.
    m = torchvision.models.detection.fasterrcnn_resnet50_fpn_v2(
        weights=None, weights_backbone=None
    )
    # Swap in a box predictor sized for this task's class count.
    in_feat = m.roi_heads.box_predictor.cls_score.in_features
    m.roi_heads.box_predictor = FastRCNNPredictor(in_feat, num_classes)
    # NOTE(review): weights_only=False unpickles arbitrary objects, so only
    # load checkpoint files from a trusted source.
    ckpt = torch.load(det_ckpt_path, map_location="cpu", weights_only=False)
    m.load_state_dict(ckpt["model"], strict=True)
    m.to(device).eval()
    return m

# -------------------- 유틸 --------------------
def cer(ref: str, hyp: str) -> float:
    """Return the character error rate of ``hyp`` against ``ref`` in percent.

    The rate is the Levenshtein distance divided by ``len(ref)``, times 100.
    For an empty reference the result is 0.0 when the hypothesis is also
    empty and 100.0 otherwise.
    """
    ref_len, hyp_len = len(ref), len(hyp)
    if ref_len == 0:
        return 100.0 if hyp_len > 0 else 0.0

    # table[i, j] = edit distance between ref[:i] and hyp[:j]
    table = np.zeros((ref_len + 1, hyp_len + 1), dtype=np.int32)
    table[:, 0] = np.arange(ref_len + 1)
    table[0, :] = np.arange(hyp_len + 1)

    for row, ref_ch in enumerate(ref, start=1):
        for col, hyp_ch in enumerate(hyp, start=1):
            substitution_cost = 0 if ref_ch == hyp_ch else 1
            table[row, col] = min(
                table[row - 1, col] + 1,                      # deletion
                table[row, col - 1] + 1,                      # insertion
                table[row - 1, col - 1] + substitution_cost,  # match / substitution
            )
    return 100.0 * table[ref_len, hyp_len] / ref_len

def poly_to_xyxy(x4: List[float], y4: List[float]) -> Tuple[int,int,int,int]:
    """Return the bounding box ``(x1, y1, x2, y2)`` of a polygon's vertices.

    Each extreme is converted with ``int()``, which truncates toward zero.
    """
    left, right = min(x4), max(x4)
    top, bottom = min(y4), max(y4)
    return int(left), int(top), int(right), int(bottom)

def read_gt(json_path: Path):
    """Load ground-truth boxes and texts from one label JSON file.

    Every entry of the top-level ``"bbox"`` list provides polygon coordinate
    lists under ``"x"`` and ``"y"`` and, optionally, the text under ``"data"``.

    Returns:
        ``(boxes, texts)``: a float32 tensor of shape (N, 4) in xyxy order and
        a list of N strings. With no entries the tensor has shape (0, 4).
    """
    with open(json_path, "r", encoding="utf-8") as fh:
        annotation = json.load(fh)

    records = [
        (list(poly_to_xyxy(entry["x"], entry["y"])), str(entry.get("data", "")))
        for entry in annotation.get("bbox", [])
    ]
    if not records:
        # Keep the (0, 4) shape so downstream box operations still work.
        return torch.zeros((0, 4), dtype=torch.float32), []

    box_rows, labels = zip(*records)
    return torch.tensor(list(box_rows), dtype=torch.float32), list(labels)

def clamp_box(x1,y1,x2,y2,W,H,pad=PAD_PX):
    """Pad a box by ``pad`` pixels and clip it to a ``W`` x ``H`` image.

    A degenerate result (zero or negative extent) is widened to one pixel.
    If the box starts at or beyond the right/bottom edge, its origin is first
    pulled back inside the image, so the returned box has positive area
    whenever ``W`` and ``H`` are at least 1.

    Returns:
        ``(x1, y1, x2, y2)`` with ``0 <= x1 < x2 <= W`` and
        ``0 <= y1 < y2 <= H`` for any image with ``W, H >= 1``.
    """
    x1 = max(0, x1 - pad)
    y1 = max(0, y1 - pad)
    x2 = min(W, x2 + pad)
    y2 = min(H, y2 + pad)
    if x2 <= x1:
        # Without pulling x1 inside, x1 == W would leave x2 == W (empty crop).
        x1 = max(0, min(x1, W - 1))
        x2 = min(W, x1 + 1)
    if y2 <= y1:
        y1 = max(0, min(y1, H - 1))
        y2 = min(H, y1 + 1)
    return x1, y1, x2, y2

def resize_keep_ratio_pad_gray(img_pil: Image.Image, img_h: int, max_w: int):
    """Convert a crop to a fixed-size grayscale tensor for the recognizer.

    The image is scaled to height ``img_h`` keeping its aspect ratio (width
    capped at ``max_w``) and pasted at the left of a white ``max_w`` x
    ``img_h`` canvas.

    Returns:
        ``(tensor, input_len)``: a float32 tensor of shape [1, 1, img_h, max_w]
        with values in [0, 1], and the number of timesteps covered by the
        resized (unpadded) content.
    """
    gray = img_pil.convert("L")
    src_w, src_h = gray.size

    # max(src_h, 1) guards against division by zero for a zero-height input.
    target_w = int(round(src_w * (img_h / max(src_h, 1))))
    target_w = min(max(1, target_w), max_w)

    padded = Image.new("L", (max_w, img_h), 255)
    padded.paste(gray.resize((target_w, img_h), Image.BILINEAR), (0, 0))

    pixels = np.array(padded, dtype=np.float32) / 255.0
    tensor = torch.from_numpy(pixels)[None, None]   # [1,1,H,W]

    # Two 2x2 max-pools in the recognizer give a horizontal stride of 4.
    return tensor, math.ceil(target_w / 4)

# -------------------- 메인 파이프라인 --------------------
@torch.no_grad()
def main():
    """Evaluate detection + recognition end-to-end on the validation set.

    For every image in IMG_DIR the label file ``LBL_DIR/<stem>.json`` is read,
    the detector is run, and predictions are filtered by SCORE_THR and
    MIN_BOX_WH. Each remaining prediction is compared against its single
    highest-IoU ground-truth box: it counts as a true positive when that box
    is still unused and the IoU reaches IOU_MATCH_THR, otherwise as a false
    positive. Matched crops are recognized with the CRNN and scored with CER.

    One CSV row is written to OUT_CSV per kept prediction; precision, recall,
    F1 and the mean CER over matched boxes are printed at the end.
    """
    # Load models
    det = load_detector(DET_CKPT, num_classes=2, device=DEVICE)
    rec, meta = load_crnn(REC_CKPT, device=DEVICE)
    itos, img_h, max_w, blank_idx = meta["itos"], meta["img_height"], meta["max_width"], meta["blank_idx"]

    img_paths = sorted([p for p in IMG_DIR.iterdir() if p.suffix.lower() in (".png",".jpg",".jpeg",".tif",".bmp")])
    assert len(img_paths)>0, f"No images in {IMG_DIR}"

    # Running statistics
    TP=FP=FN=0
    cer_sum=0.0; cer_cnt=0
    matched_boxes=0; total_gt=0; total_pred=0

    with open(OUT_CSV, "w", newline="", encoding="utf-8") as f:
        w = csv.writer(f)
        w.writerow(["image_path","pred_x1","pred_y1","pred_x2","pred_y2","score",
                    "matched","gt_text","pred_text","cer"])

        for img_path in tqdm(img_paths, desc="Detect→Recognize (valid)"):
            # Load ground truth (label file shares the image's stem)
            json_path = LBL_DIR / (img_path.stem + ".json")
            gt_boxes, gt_texts = read_gt(json_path)
            total_gt += len(gt_boxes)

            # Load image & run detection
            pil = Image.open(img_path).convert("RGB")
            W,H = pil.size
            out = det([torchvision.transforms.functional.to_tensor(pil).to(DEVICE)])[0]
            pred_boxes = out["boxes"].detach().cpu()
            scores     = out["scores"].detach().cpu()

            # Score / size filtering
            keep = scores >= SCORE_THR
            pred_boxes = pred_boxes[keep]; scores = scores[keep]
            if len(pred_boxes) > 0 and MIN_BOX_WH > 0:
                ww = (pred_boxes[:,2]-pred_boxes[:,0])
                hh = (pred_boxes[:,3]-pred_boxes[:,1])
                big = (ww>=MIN_BOX_WH) & (hh>=MIN_BOX_WH)
                pred_boxes = pred_boxes[big]; scores = scores[big]
            total_pred += len(pred_boxes)

            # Matching (by highest IoU, one-to-one assignment)
            if len(pred_boxes)==0 and len(gt_boxes)==0:
                continue
            if len(pred_boxes)==0:
                # Nothing detected: every ground-truth box is a miss
                FN += len(gt_boxes); continue
            if len(gt_boxes)==0:
                FP += len(pred_boxes)
                # Record in the CSV with empty text fields (optional)
                for pb, sc in zip(pred_boxes, scores):
                    x1,y1,x2,y2 = map(int, pb.tolist())
                    x1,y1,x2,y2 = clamp_box(x1,y1,x2,y2,W,H,PAD_PX)
                    w.writerow([str(img_path), x1,y1,x2,y2, float(sc), 0, "", "", ""])
                continue

            ious = box_iou(pred_boxes, gt_boxes)  # (P,G)
            used_g = set()
            # Simple greedy matching: each prediction tries only its highest-IoU GT box
            for p in range(len(pred_boxes)):
                gi = int(torch.argmax(ious[p]).item())
                iou = float(ious[p,gi].item())
                if gi not in used_g and iou >= IOU_MATCH_THR:
                    TP += 1; used_g.add(gi); matched = True
                    # ---- Recognition (CRNN) ----
                    x1,y1,x2,y2 = map(int, pred_boxes[p].tolist())
                    x1,y1,x2,y2 = clamp_box(x1,y1,x2,y2,W,H,PAD_PX)
                    crop = pil.crop((x1,y1,x2,y2))
                    crop_t, in_len = resize_keep_ratio_pad_gray(crop, img_h=img_h, max_w=max_w)
                    logits = rec(crop_t.to(DEVICE))         # [T,1,C]
                    hyp_idx = greedy_decode_ctc(logits.cpu(), torch.tensor([in_len]), blank_idx)[0]
                    hyp = "".join(itos[j] for j in hyp_idx if j < len(itos))
                    ref = gt_texts[gi]
                    c = cer(ref, hyp)
                    cer_sum += c; cer_cnt += 1; matched_boxes += 1
                    w.writerow([str(img_path), x1,y1,x2,y2, float(scores[p]), 1, ref, hyp, f"{c:.2f}"])
                else:
                    # No match -> false positive
                    FP += 1; matched = False
                    x1,y1,x2,y2 = map(int, pred_boxes[p].tolist())
                    x1,y1,x2,y2 = clamp_box(x1,y1,x2,y2,W,H,PAD_PX)
                    # Unmatched predictions have no reference text, so the text fields stay blank
                    w.writerow([str(img_path), x1,y1,x2,y2, float(scores[p]), 0, "", "", ""])
            # Ground-truth boxes never claimed by a prediction are misses
            FN += (len(gt_boxes) - len(used_g))

    # The 1e-9 terms keep the ratios defined when a denominator is zero
    P = TP / (TP+FP+1e-9)
    R = TP / (TP+FN+1e-9)
    F1 = 2*P*R / (P+R+1e-9)
    avg_cer = (cer_sum / max(cer_cnt,1))
    print(f"\nDetection@IoU={IOU_MATCH_THR}: P={P:.3f}  R={R:.3f}  F1={F1:.3f}  "
          f"(TP={TP}, FP={FP}, FN={FN}, Pred={total_pred}, GT={total_gt})")
    print(f"OCR CER on matched boxes: {avg_cer:.2f}%  (matched={matched_boxes}/{total_gt})")
    print(f"Saved: {OUT_CSV}")

# Run the validation evaluation when this cell/script is executed directly.
if __name__ == "__main__":
    main()

Detect→Recognize (valid): 100%|██████████| 727/727 [03:30<00:00,  3.45it/s]


Detection@IoU=0.8: P=0.998  R=0.999  F1=0.999  (TP=17122, FP=28, FN=23, Pred=17150, GT=17145)
OCR CER on matched boxes: 2.07%  (matched=17122/17145)
Saved: C:\Users\USER\DL_OCR\val_detect_recognize.csv





---

## 우리 손글씨 예측시키기

In [15]:
# ===== 경로만 정확히 맞추기 =====
ROOT = r"C:\Users\USER\DL_OCR\dataset"
TEST_DIR = Path(ROOT) / "test_image"             # <- note: "test_image", not datasets/test
# Test outputs (CSV and visualizations) are written next to the dataset folder.
OUT_CSV_TEST = str(Path(ROOT).parent / "test_detect_recognize.csv")
VIS_DIR = Path(ROOT).parent / "test_vis"
# NOTE: this rebinds the module-level SCORE_THR, so any later call that reads
# SCORE_THR in the same namespace also sees 0.7.
SCORE_THR = 0.7
VIS_DIR.mkdir(parents=True, exist_ok=True)

In [None]:
# ===== 이미지 수집을 rglob로(하위 폴더까지) + 디버그 로그 =====
def _collect_images(folder: Path):
    """Return the sorted paths under ``folder`` (recursively) with an image suffix.

    Suffix comparison is case-insensitive.
    """
    image_suffixes = (".png", ".jpg", ".jpeg", ".tif", ".tiff", ".bmp", ".gif")
    matches = []
    for candidate in folder.rglob("*"):
        if candidate.suffix.lower() in image_suffixes:
            matches.append(candidate)
    matches.sort()
    return matches

@torch.no_grad()
def main_test_debug():
    """Run detection + recognition on the unlabeled test images.

    Images are collected recursively from TEST_DIR. For each one the detector
    output is filtered by SCORE_THR and MIN_BOX_WH; if nothing survives, the
    single highest-scoring raw detection is used instead. Every kept box is
    cropped, recognized with the CRNN, drawn onto the image, and written as a
    row to OUT_CSV_TEST. An annotated ``<stem>_vis.png`` is saved to VIS_DIR
    for every image, including those without any box.

    NOTE(review): ``_load_korean_font`` and ``draw_box_and_text`` are expected
    to be defined elsewhere in the notebook; they are not visible here.
    """
    print(f"[DBG] TEST_DIR={TEST_DIR}")
    if not TEST_DIR.exists():
        print("[ERR] TEST_DIR 없음")
        return
    img_paths = _collect_images(TEST_DIR)
    print(f"[DBG] 이미지 개수: {len(img_paths)}  (예: {img_paths[0] if img_paths else '없음'})")
    print(f"[DBG] 결과 CSV: {OUT_CSV_TEST}")
    print(f"[DBG] VIS_DIR: {VIS_DIR}")

    if len(img_paths) == 0:
        print("[ERR] 처리할 이미지가 없습니다.")
        return

    det = load_detector(DET_CKPT, num_classes=2, device=DEVICE)
    rec, meta = load_crnn(REC_CKPT, device=DEVICE)
    itos, img_h, max_w, blank_idx = meta["itos"], meta["img_height"], meta["max_width"], meta["blank_idx"]
    font = _load_korean_font(18)

    with open(OUT_CSV_TEST, "w", newline="", encoding="utf-8") as f:
        w = csv.writer(f)
        w.writerow(["image_path","x1","y1","x2","y2","score","pred_text"])

        for img_path in tqdm(img_paths, desc="Detect→Recognize (TEST)"):
            vis_path = VIS_DIR / f"{img_path.stem}_vis.png"
            pil = Image.open(img_path).convert("RGB")
            W, H = pil.size
            out = det([torchvision.transforms.functional.to_tensor(pil).to(DEVICE)])[0]
            boxes = out["boxes"].detach().cpu()
            scores = out["scores"].detach().cpu()

            # Score / size filtering
            keep = scores >= SCORE_THR
            boxes, scores = boxes[keep], scores[keep]
            if len(boxes) > 0 and MIN_BOX_WH > 0:
                ww = boxes[:,2] - boxes[:,0]
                hh = boxes[:,3] - boxes[:,1]
                big = (ww >= MIN_BOX_WH) & (hh >= MIN_BOX_WH)
                boxes, scores = boxes[big], scores[big]

            # If nothing survived, fall back to the single best raw detection
            if len(boxes) == 0 and out["boxes"].shape[0] > 0:
                top = int(torch.argmax(out["scores"]).item())
                boxes = out["boxes"][top:top+1].detach().cpu()
                scores = out["scores"][top:top+1].detach().cpu()

            # A visualization file is saved even when there is nothing to draw
            if len(boxes) == 0:
                pil.save(vis_path)
                continue

            # Process boxes from highest to lowest score
            idxs = torch.argsort(scores, descending=True).tolist()
            for i in idxs:
                x1,y1,x2,y2 = map(int, boxes[i].tolist())
                x1,y1,x2,y2 = clamp_box(x1,y1,x2,y2,W,H,PAD_PX)
                # `pil` already carries earlier annotations at this point, so
                # later crops can include previously drawn boxes/text.
                crop = pil.crop((x1,y1,x2,y2))
                crop_t, in_len = resize_keep_ratio_pad_gray(crop, img_h=img_h, max_w=max_w)
                logits = rec(crop_t.to(DEVICE))
                hyp_idx = greedy_decode_ctc(logits.cpu(), torch.tensor([in_len]), blank_idx)[0]
                hyp = "".join(itos[j] for j in hyp_idx if j < len(itos))
                pil = draw_box_and_text(pil, (x1,y1,x2,y2), hyp, font=font)
                w.writerow([str(img_path), x1,y1,x2,y2, float(scores[i]), hyp])

            # Save once as a real PNG. The earlier extra write of raw pixel
            # bytes to this path was redundant and left an invalid file
            # behind whenever save() failed.
            pil.save(vis_path)
    print("[DBG] 완료")

In [17]:
# ===== 엔트리포인트: 테스트만 강제 실행 =====
# Run only the test-set pipeline when this cell/script is executed directly.
if __name__ == "__main__":
    main_test_debug()

[DBG] TEST_DIR=C:\Users\USER\DL_OCR\dataset\test_image
[DBG] 이미지 개수: 9  (예: C:\Users\USER\DL_OCR\dataset\test_image\권서영_250910_155950-6.png)
[DBG] 결과 CSV: C:\Users\USER\DL_OCR\test_detect_recognize.csv
[DBG] VIS_DIR: C:\Users\USER\DL_OCR\test_vis


Detect→Recognize (TEST): 100%|██████████| 9/9 [00:02<00:00,  3.15it/s]

[DBG] 완료



