In [None]:
# -*- coding: utf-8 -*-
import os, csv, math, unicodedata, random
from pathlib import Path
from typing import List, Tuple, Dict
import numpy as np
from PIL import Image
from time import time

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

# ===================== Paths / hyperparameters =====================
# Dataset root. Defaults to the original local path; set the CRNN_ROOT environment
# variable to point the script at a dataset on another machine.
ROOT = os.environ.get("CRNN_ROOT", r"C:\Users\USER\DL_OCR\dataset\crnn_crops")
TRAIN_CSV = Path(ROOT) / "train_labels.csv"
VALID_CSV = Path(ROOT) / "valid_labels.csv"

# Values of the CSV "variant" column kept for each split.
TRAIN_VARIANTS = {"gt_pad_blur", "det_pad"}
VALID_VARIANTS = {"gt_pad"}

IMG_HEIGHT = 32            # every crop is resized to this height (aspect ratio kept)
MAX_WIDTH  = 512           # crops are right-padded (or squeezed) to this width;
                           # keep it a multiple of 4 since the CNN downsamples width by 4
BATCH_SIZE = 500
EPOCHS = 10
LR = 1e-3
WEIGHT_DECAY = 0.0
EARLY_STOP_PATIENCE = 5    # epochs without valid-CER improvement before stopping
NUM_WORKERS = 0
SEED = 1337
SAVE_PATH = "crnn_ctc_best3.pth"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed the Python, NumPy and PyTorch random number generators.
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)

# ===================== 유틸 =====================
def normalize_text(s: str) -> str:
    """Return ``s`` NFKC-normalized with surrounding newlines, tabs and spaces removed.

    ``None`` maps to ``""``. Normalization runs *before* stripping: NFKC maps
    compatibility spaces such as U+3000 (IDEOGRAPHIC SPACE) and U+00A0
    (NO-BREAK SPACE) to an ASCII space, which would otherwise survive as a
    leading/trailing space in the label.
    """
    if s is None:
        return ""
    return unicodedata.normalize("NFKC", s).strip("\n\r\t ")

def read_csv_filter(csv_path: Path, variants: set) -> List[Tuple[str,str]]:
    """Load ``(img_path, text)`` pairs from a label CSV.

    The CSV is read with ``csv.DictReader`` and must provide ``variant``,
    ``img_path`` and ``text`` columns. Rows whose ``variant`` is not in
    ``variants`` are skipped (an empty/falsy ``variants`` keeps every row).
    Texts are cleaned with ``normalize_text``, and rows whose cleaned text is
    empty are dropped.
    """
    rows = []
    # "utf-8-sig" drops a leading UTF-8 BOM if present (otherwise it would be glued
    # onto the first header name and break the row[...] lookups) and reads BOM-less
    # files exactly like "utf-8". newline="" is what the csv module requires.
    with open(csv_path, "r", encoding="utf-8-sig", newline="") as f:
        for row in csv.DictReader(f):
            if variants and row["variant"] not in variants:
                continue
            img_path = row["img_path"]
            text = normalize_text(row["text"])
            if not text:
                continue
            rows.append((img_path, text))
    return rows

def build_vocab(samples: List[Tuple[str,str]]) -> Tuple[Dict[str,int], List[str]]:
    """Build a character vocabulary from the label texts of ``samples``.

    Returns ``(stoi, itos)``: ``itos`` is the sorted list of distinct characters
    and ``stoi`` maps each character to its position in ``itos``.
    """
    itos = sorted({ch for _, label in samples for ch in label})
    stoi = dict(zip(itos, range(len(itos))))
    return stoi, itos

def text_to_indices(text: str, stoi: Dict[str,int]) -> List[int]:
    """Encode ``text`` as vocabulary indices, silently skipping characters absent from ``stoi``."""
    encoded = []
    for ch in text:
        idx = stoi.get(ch)
        if idx is not None:
            encoded.append(idx)
    return encoded

# ===================== 전처리/데이터셋 =====================
class CRNNDataset(Dataset):
    """Dataset of text-line crops for CTC training.

    Each item is ``(img_t, valid_w, target_idx, img_path)``:
    ``img_t`` is a ``[1, img_h, max_w]`` float tensor scaled to [0, 1] with white
    right padding, ``valid_w`` is the resized width before padding (at most
    ``max_w``), and ``target_idx`` is a LongTensor of the label's vocabulary
    indices (characters missing from ``stoi`` are dropped).
    """
    def __init__(self, rows: List[Tuple[str,str]], stoi: Dict[str,int],
                 img_h: int = IMG_HEIGHT, max_w: int = MAX_WIDTH):
        self.rows = rows
        self.stoi = stoi
        self.img_h = img_h
        self.max_w = max_w

    def __len__(self): return len(self.rows)

    def _resize_keep_ratio_pad(self, img: Image.Image):
        """Resize to height ``img_h`` keeping the aspect ratio, then right-pad with white.

        Images wider than ``max_w`` after resizing are squeezed to ``max_w``.
        Returns ``(uint8 array of shape [img_h, max_w], new_w)``.
        """
        w, h = img.size
        scale = self.img_h / h
        new_w = max(1, int(round(w * scale)))
        new_w = min(new_w, self.max_w)
        img = img.resize((new_w, self.img_h), Image.BILINEAR)
        canvas = Image.new("L", (self.max_w, self.img_h), 255)  # 255 = white background
        canvas.paste(img, (0, 0))
        return np.array(canvas), new_w

    def __getitem__(self, idx):
        img_path, text = self.rows[idx]
        # The context manager closes the source file deterministically; convert()
        # returns a fully loaded, independent image, so it stays valid afterwards.
        with Image.open(img_path) as src:
            img = src.convert("L")
        img_np, valid_w = self._resize_keep_ratio_pad(img)
        img_t = torch.from_numpy(img_np).unsqueeze(0).float() / 255.0
        target_idx = torch.tensor(text_to_indices(text, self.stoi), dtype=torch.long)
        return img_t, valid_w, target_idx, img_path

def crnn_collate(batch, blank_idx: int):
    """Collate ``(img, valid_w, target_idx, path)`` samples into a CTC-ready batch.

    Returns ``(imgs [B,1,H,W], input_lengths [B], target_concat [sum of target
    lengths], target_lengths [B], paths)``. ``blank_idx`` is accepted for
    interface compatibility and is not used here.
    """
    imgs, widths, targets, paths = zip(*batch)
    imgs = torch.stack(imgs, dim=0)  # [B,1,H,Wmax]
    # The CRNN applies two 2x2 max-pools, so its output has floor(Wmax/4) time
    # steps. ceil(w/4) also counts the column covering a partial final 4-px
    # window; clamp it so it never exceeds the real sequence length (CTCLoss
    # rejects input_lengths > T, which can happen when Wmax is not a multiple of 4).
    max_steps = imgs.size(-1) // 4
    input_lengths = torch.tensor([min(math.ceil(w / 4), max_steps) for w in widths],
                                 dtype=torch.long)
    target_concat = torch.cat(targets)
    target_lengths = torch.tensor([len(t) for t in targets], dtype=torch.long)
    return imgs, input_lengths, target_concat, target_lengths, paths

# ===================== CRNN 모델 =====================
class CRNN(nn.Module):
    """Convolutional-recurrent text recognizer producing per-column class scores for CTC.

    Pipeline: three 3x3 conv layers (two 2x2 max-pools shrink height and width
    by 4) -> mean over the remaining height axis -> bidirectional LSTM
    (``layers`` layers, ``hidden`` units per direction) -> linear classifier.

    Args:
        num_classes: number of output classes (vocabulary size + CTC blank in this script).
        img_h: accepted for interface compatibility; not used by any layer.
        cnn_out: channels of the last conv layer, which is also the LSTM input size.
        hidden: LSTM hidden size per direction.
        layers: number of stacked LSTM layers.
    """
    def __init__(self, num_classes: int, img_h: int = IMG_HEIGHT, cnn_out: int = 256,
                 hidden: int = 256, layers: int = 2):
        super().__init__()
        # Layer creation order determines both the state_dict keys (cnn.0 / cnn.3 /
        # cnn.6, rnn.*, fc.*) and the RNG draws used for initialization.
        self.cnn = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),          # /2
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),          # /4
            nn.Conv2d(in_channels=128, out_channels=cnn_out, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
        )
        # Height is averaged away, so each time step carries cnn_out features.
        self.rnn = nn.LSTM(input_size=cnn_out, hidden_size=hidden, num_layers=layers,
                           batch_first=False, bidirectional=True)
        self.fc = nn.Linear(in_features=hidden * 2, out_features=num_classes)

    def forward(self, x):
        """Map images ``[B, 1, H, W]`` to time-major logits ``[W // 4, B, num_classes]``."""
        column_feats = self.cnn(x).mean(dim=2)      # [B, C, W'] after averaging height
        time_major = column_feats.permute(2, 0, 1)  # [W', B, C] = (T, B, D)
        recurrent_out, _state = self.rnn(time_major)  # [T, B, 2*hidden]
        return self.fc(recurrent_out)

# ===================== 디코딩/지표 =====================
def greedy_decode(logits: torch.Tensor, input_lengths: torch.Tensor, blank_idx: int) -> List[List[int]]:
    """Best-path CTC decoding.

    Takes the argmax class per time step of ``logits`` ``[T, B, C]``, looks only at
    the first ``input_lengths[b]`` steps (capped at T) of each batch item,
    collapses consecutive repeats and removes ``blank_idx``. Returns one list of
    class indices per batch item.
    """
    best_path = logits.log_softmax(dim=-1).argmax(dim=-1)  # [T, B]
    num_steps, batch_size = best_path.shape
    per_item = best_path.t().tolist()  # B lists of T class ids
    decoded = []
    for b in range(batch_size):
        limit = min(int(input_lengths[b].item()), num_steps)
        kept, last = [], -1
        for cls in per_item[b][:max(limit, 0)]:
            if cls != blank_idx and cls != last:
                kept.append(cls)
            last = cls
        decoded.append(kept)
    return decoded

def cer(ref: str, hyp: str) -> float:
    """Character error rate: Levenshtein distance(ref, hyp) / max(len(ref), 1).

    Unit-cost insertions, deletions and substitutions, computed with a rolling
    two-row dynamic program. An empty ``ref`` yields ``len(hyp)``.
    """
    prev_row = list(range(len(hyp) + 1))
    for i, ref_ch in enumerate(ref, start=1):
        cur_row = [i]
        for j, hyp_ch in enumerate(hyp, start=1):
            substitution = prev_row[j - 1] + (ref_ch != hyp_ch)
            cur_row.append(min(prev_row[j] + 1, cur_row[j - 1] + 1, substitution))
        prev_row = cur_row
    return prev_row[-1] / max(len(ref), 1)

# ===================== 학습/평가 루프 =====================
def evaluate(model, loader, itos, blank_idx):
    """Return the mean per-sample character error rate of ``model`` over ``loader``, in percent.

    Puts the model in eval mode and decodes each batch greedily. Each hypothesis
    is compared with the reference string rebuilt from the concatenated targets.
    Returns 0.0 when ``loader`` yields no samples.
    """
    model.eval()
    total_cer, n_samples = 0.0, 0
    vocab_size = len(itos)
    with torch.no_grad():
        for imgs, input_lens, tgt_concat, tgt_lens, _paths in loader:
            logits = model(imgs.to(DEVICE))
            decoded = greedy_decode(logits.cpu(), input_lens, blank_idx)
            # Split the flat target tensor back into one reference string per sample.
            references = []
            start = 0
            for length in tgt_lens.tolist():
                stop = start + length
                references.append("".join(itos[k] for k in tgt_concat[start:stop].tolist()))
                start = stop
            hypotheses = ["".join(itos[k] for k in seq if k < vocab_size) for seq in decoded]
            for ref, hyp in zip(references, hypotheses):
                total_cer += cer(ref, hyp)
                n_samples += 1
    return (total_cer / max(n_samples, 1)) * 100.0

def _train_one_epoch(model, loader, ctc_loss, optimizer, epoch):
    """Run one CTC training pass over ``loader`` with a tqdm progress bar.

    Returns the mean per-batch loss (0.0 for an empty loader).
    """
    model.train()
    running_loss, n_seen = 0.0, 0
    t0 = time()
    pbar = tqdm(loader, desc=f"Epoch {epoch:02d}", ncols=120)
    for step, (imgs, input_lens, tgt_concat, tgt_lens, _paths) in enumerate(pbar, start=1):
        imgs = imgs.to(DEVICE)
        logits = model(imgs)                    # [T, B, num_classes]
        log_probs = logits.log_softmax(dim=-1)  # CTCLoss expects log-probabilities
        loss = ctc_loss(log_probs, tgt_concat.to(DEVICE), input_lens.to(DEVICE), tgt_lens.to(DEVICE))

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 5.0)
        optimizer.step()

        running_loss += loss.item()
        n_seen += imgs.size(0)
        ips = n_seen / max(time() - t0, 1e-6)   # images per second
        lr = optimizer.param_groups[0]['lr']
        pbar.set_postfix(loss=f"{running_loss/step:.4f}", lr=f"{lr:.2e}", ips=f"{ips:.1f}")
    # max(..., 1) avoids a ZeroDivisionError on an empty loader.
    return running_loss / max(len(loader), 1)


def _print_decode_samples(model, loader, itos, blank_idx, k=3):
    """Greedy-decode the first batch of ``loader`` and print up to ``k`` REF/HYP pairs."""
    model.eval()
    with torch.no_grad():
        for imgs, input_lens, tgt_concat, tgt_lens, _paths in loader:
            logits = model(imgs.to(DEVICE))
            hyps_idx = greedy_decode(logits.cpu(), input_lens, blank_idx)
            offset = 0
            refs = []
            for L in tgt_lens.tolist():
                refs.append("".join(itos[i] for i in tgt_concat[offset:offset+L].tolist()))
                offset += L
            for i in range(min(k, len(hyps_idx))):
                hyp_str = "".join(itos[j] for j in hyps_idx[i] if j < len(itos))
                print(f"  ex{i+1}: REF='{refs[i]}' | HYP='{hyp_str}'")
            break  # only the first batch is shown


def main():
    """Train the CRNN with CTC loss, keep the best-validation-CER checkpoint, and early-stop.

    Reads TRAIN_CSV / VALID_CSV and builds a character vocabulary from both
    splits (blank = last index). Trains for up to EPOCHS epochs, evaluating
    valid CER after each one, and saves the best model to SAVE_PATH. Stops
    after EARLY_STOP_PATIENCE epochs without improvement. A few validation
    decodes are printed after every epoch that does not trigger the stop.
    """
    from functools import partial

    # Data.
    train_rows = read_csv_filter(TRAIN_CSV, TRAIN_VARIANTS)
    valid_rows = read_csv_filter(VALID_CSV, VALID_VARIANTS)
    print(f"train samples: {len(train_rows)}, valid samples: {len(valid_rows)}")

    # Vocabulary; the CTC blank takes the index right after the last character.
    stoi, itos = build_vocab(train_rows + valid_rows)
    blank_idx = len(itos)
    num_classes = len(itos) + 1

    # Datasets / loaders.
    train_ds = CRNNDataset(train_rows, stoi, IMG_HEIGHT, MAX_WIDTH)
    valid_ds = CRNNDataset(valid_rows, stoi, IMG_HEIGHT, MAX_WIDTH)
    # A partial of a module-level function is picklable (a local lambda is not), so
    # NUM_WORKERS > 0 also works with the "spawn" start method used on Windows.
    collate_fn = partial(crnn_collate, blank_idx=blank_idx)
    # Pinned host memory only helps host->GPU copies; skip it on CPU-only runs.
    pin_memory = DEVICE.type == "cuda"
    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,
                              num_workers=NUM_WORKERS, collate_fn=collate_fn, pin_memory=pin_memory)
    valid_loader = DataLoader(valid_ds, batch_size=BATCH_SIZE, shuffle=False,
                              num_workers=NUM_WORKERS, collate_fn=collate_fn, pin_memory=pin_memory)

    # Model / optimizer / loss.
    model = CRNN(num_classes=num_classes).to(DEVICE)
    ctc_loss = nn.CTCLoss(blank=blank_idx, zero_infinity=True)
    optimizer = optim.Adam(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)

    best_cer = float("inf")
    no_improve = 0

    for epoch in range(1, EPOCHS+1):
        train_loss = _train_one_epoch(model, train_loader, ctc_loss, optimizer, epoch)

        # Validation.
        val_cer = evaluate(model, valid_loader, itos, blank_idx)
        print(f"[Epoch {epoch:02d}] train_loss={train_loss:.4f}  valid_CER={val_cer:.2f}%")

        # Early stopping & checkpointing of the best model.
        if val_cer < best_cer - 1e-6:
            best_cer = val_cer
            no_improve = 0
            ckpt = {
                "model": model.state_dict(),
                "itos": itos,
                "stoi": stoi,
                "epoch": epoch,
                "valid_cer": best_cer,
                "img_height": IMG_HEIGHT,
                "max_width": MAX_WIDTH,
            }
            torch.save(ckpt, SAVE_PATH)
            print(f"  >> Saved BEST to {SAVE_PATH} (CER {best_cer:.2f}%)")
        else:
            no_improve += 1
            print(f"  >> no improvement ({no_improve}/{EARLY_STOP_PATIENCE})")
            if no_improve >= EARLY_STOP_PATIENCE:
                print("Early stopping triggered.")
                break

        # Print a few sample decodes.
        _print_decode_samples(model, valid_loader, itos, blank_idx)

    print("Done.")

# Entry point: run training when this file/cell is executed directly
# (in a notebook cell __name__ is also "__main__").
if __name__ == "__main__":
    main()

train samples: 265413, valid samples: 17145


Epoch 01: 100%|██████████████████████████████████| 531/531 [05:06<00:00,  1.73it/s, ips=866.7, loss=5.0740, lr=1.00e-03]


[Epoch 01] train_loss=5.0740  valid_CER=94.57%
  >> Saved BEST to crnn_ctc_best3.pth (CER 94.57%)
  ex1: REF='474387-05-714500' | HYP='2'
  ex2: REF='탕시우' | HYP='2'
  ex3: REF='부산광역시' | HYP='2'


Epoch 02: 100%|██████████████████████████████████| 531/531 [05:03<00:00,  1.75it/s, ips=875.3, loss=3.0941, lr=1.00e-03]


[Epoch 02] train_loss=3.0941  valid_CER=25.86%
  >> Saved BEST to crnn_ctc_best3.pth (CER 25.86%)
  ex1: REF='474387-05-714500' | HYP='474387-05-714500'
  ex2: REF='탕시우' | HYP='명서우'
  ex3: REF='부산광역시' | HYP='부산광역시'


Epoch 03: 100%|██████████████████████████████████| 531/531 [04:59<00:00,  1.77it/s, ips=884.8, loss=0.6370, lr=1.00e-03]


[Epoch 03] train_loss=0.6370  valid_CER=9.16%
  >> Saved BEST to crnn_ctc_best3.pth (CER 9.16%)
  ex1: REF='474387-05-714500' | HYP='474387-05-714500'
  ex2: REF='탕시우' | HYP='탕서우'
  ex3: REF='부산광역시' | HYP='부산광역시'


Epoch 04: 100%|██████████████████████████████████| 531/531 [04:53<00:00,  1.81it/s, ips=903.6, loss=0.2722, lr=1.00e-03]


[Epoch 04] train_loss=0.2722  valid_CER=5.45%
  >> Saved BEST to crnn_ctc_best3.pth (CER 5.45%)
  ex1: REF='474387-05-714500' | HYP='474387-05-714500'
  ex2: REF='탕시우' | HYP='담시우'
  ex3: REF='부산광역시' | HYP='부산광역시'


Epoch 05: 100%|██████████████████████████████████| 531/531 [04:55<00:00,  1.80it/s, ips=897.5, loss=0.1668, lr=1.00e-03]


[Epoch 05] train_loss=0.1668  valid_CER=3.98%
  >> Saved BEST to crnn_ctc_best3.pth (CER 3.98%)
  ex1: REF='474387-05-714500' | HYP='474387-05-714500'
  ex2: REF='탕시우' | HYP='탕시우'
  ex3: REF='부산광역시' | HYP='부산광역시'


Epoch 06: 100%|██████████████████████████████████| 531/531 [04:57<00:00,  1.79it/s, ips=893.0, loss=0.1165, lr=1.00e-03]


[Epoch 06] train_loss=0.1165  valid_CER=3.10%
  >> Saved BEST to crnn_ctc_best3.pth (CER 3.10%)
  ex1: REF='474387-05-714500' | HYP='474387-05-714500'
  ex2: REF='탕시우' | HYP='탕시우'
  ex3: REF='부산광역시' | HYP='부산광역시'


Epoch 07: 100%|██████████████████████████████████| 531/531 [04:59<00:00,  1.78it/s, ips=887.4, loss=0.0857, lr=1.00e-03]


[Epoch 07] train_loss=0.0857  valid_CER=2.60%
  >> Saved BEST to crnn_ctc_best3.pth (CER 2.60%)
  ex1: REF='474387-05-714500' | HYP='474387-05-714500'
  ex2: REF='탕시우' | HYP='탕시우'
  ex3: REF='부산광역시' | HYP='부산광역시'


Epoch 08: 100%|██████████████████████████████████| 531/531 [04:52<00:00,  1.82it/s, ips=907.2, loss=0.0661, lr=1.00e-03]


[Epoch 08] train_loss=0.0661  valid_CER=2.17%
  >> Saved BEST to crnn_ctc_best3.pth (CER 2.17%)
  ex1: REF='474387-05-714500' | HYP='474387-05-714500'
  ex2: REF='탕시우' | HYP='탕시우'
  ex3: REF='부산광역시' | HYP='부산광역시'


Epoch 09: 100%|██████████████████████████████████| 531/531 [05:05<00:00,  1.74it/s, ips=868.8, loss=0.0507, lr=1.00e-03]


[Epoch 09] train_loss=0.0507  valid_CER=2.16%
  >> Saved BEST to crnn_ctc_best3.pth (CER 2.16%)
  ex1: REF='474387-05-714500' | HYP='474387-05-714500'
  ex2: REF='탕시우' | HYP='탕시우'
  ex3: REF='부산광역시' | HYP='부산광역시'


Epoch 10: 100%|██████████████████████████████████| 531/531 [04:59<00:00,  1.77it/s, ips=884.9, loss=0.0413, lr=1.00e-03]


[Epoch 10] train_loss=0.0413  valid_CER=1.78%
  >> Saved BEST to crnn_ctc_best3.pth (CER 1.78%)
  ex1: REF='474387-05-714500' | HYP='474387-05-714500'
  ex2: REF='탕시우' | HYP='탕시우'
  ex3: REF='부산광역시' | HYP='부산광역시'
Done.
