# Vision Transformer (ViT)

In [None]:
# library imports and settings
import os, csv, json, unicodedata
import math, random, re
from typing import Optional, List

import pandas as pd
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import functional as TF
from torchvision import transforms

import matplotlib.pyplot as plt
from matplotlib import rc

# Device and seed configuration.
# Uses the MPS backend when torch reports it as available, otherwise CPU.
device = torch.device(
    "mps" if hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
    else "cpu"
)
SEED = 42
# Seed both Python's and torch's random number generators.
random.seed(SEED); torch.manual_seed(SEED)

# Path configuration
CSV_TRAIN = "../text_crops/train_labels.csv"   # training label CSV
CSV_VAL = "../text_crops/valid_labels.csv"     # validation label CSV
CHARSET_JSON = "charset.json"                  # charset JSON file name
# NOTE: despite the ".pt" suffix, train() uses this as a directory
# (it calls os.makedirs on it and writes files inside it).
save_dir  = "../vitctc_best.pt"

# ViT input configuration
IMG_H = 64      # fixed image height
IN_CH = 1       # number of input channels (grayscale = 1)
PATCH_W = 8     # patch width
STRIDE = 8      # patch stride

# Training configuration
BATCH = 32      # batch size
LR = 0.0003     # learning rate
EPOCHS = 10     # number of training epochs

# Model size
D_MODEL = 256   # hidden dimension
N_HEAD = 8      # number of attention heads
N_LAYERS = 4    # number of encoder layers
FF_DIM = 1024   # feed-forward hidden dimension

# charset class 설정
class Charset:
    """Bidirectional mapping between characters and class indices for CTC.

    Attributes:
        itos: index -> character table.
        blank_idx: index reserved for the CTC blank symbol.
        stoi: character -> index table derived from ``itos``.
    """

    def __init__(self, itos, blank_idx=0):
        self.itos = itos
        self.blank_idx = blank_idx
        self.stoi = dict((symbol, index) for index, symbol in enumerate(itos))

    def encode(self, s):
        """Convert a string to a list of indices, silently dropping unknown characters."""
        table = self.stoi
        ids = []
        for symbol in s:
            if symbol in table:
                ids.append(table[symbol])
        return ids

    def decode_ctc(self, seq: List[int]) -> str:
        """Convert a raw CTC index sequence to text.

        Blank indices are removed and runs of the same index are collapsed
        to a single character.
        """
        current = list(seq)
        # Pair every index with its predecessor (None for the first one).
        previous = [None] + current[:-1]
        return "".join(
            self.itos[cur]
            for before, cur in zip(previous, current)
            if cur != self.blank_idx and cur != before
        )

    @property
    def size(self):
        """Number of classes, i.e. the length of ``itos``."""
        return len(self.itos)

# Load the index-to-character table from the configured charset file and
# build the global Charset; index 0 is used as the CTC blank.
with open(CHARSET_JSON, "r", encoding="utf-8") as f:
    itos = json.load(f)
charset = Charset(itos, blank_idx=0)

In [None]:
## 데이터 관련 함수
# data read 함수
DATA_ROOT = "../text_crops"


def resolve_path(p: str) -> str:
    """Map a path stored in the label CSV to a path under ``DATA_ROOT``.

    The first 32 characters of ``p`` are discarded, the remainder is
    stripped of surrounding whitespace, backslashes are turned into
    forward slashes, and the result is joined onto ``DATA_ROOT`` and
    normalized.
    """
    # NOTE(review): the 32-character cut presumably removes a fixed
    # absolute-path prefix present in the CSV -- confirm against the data.
    relative = p[32:].strip().replace("\\", "/")
    joined = os.path.join(DATA_ROOT, relative)
    return os.path.normpath(joined)

def read_rows(csv_path, path_key="crop_path", text_key="text"):
    """Read a label CSV into a list of ``{"crop_path", "text"}`` dicts.

    Args:
        csv_path: path of the CSV file (must have a header row).
        path_key: name of the column holding the image path.
        text_key: name of the column holding the label text.

    Returns:
        One dict per CSV row; ``crop_path`` is passed through
        ``resolve_path`` and ``text`` is the raw cell value (``None`` when
        the cell is missing).
    """
    # newline="" lets the csv module handle line endings itself, so quoted
    # fields containing newlines are parsed correctly. "utf-8-sig" reads
    # plain UTF-8 too, but strips a leading BOM that would otherwise end
    # up inside the first header name.
    with open(csv_path, "r", encoding="utf-8-sig", newline="") as f:
        rows = [
            {"crop_path": resolve_path(row.get(path_key)), "text": row.get(text_key)}
            for row in csv.DictReader(f)
        ]
    print(f"[{csv_path}] loaded {len(rows)} rows")
    return rows

# 이미지 전처리 함수
class Resize:
    """Rescale an image to a fixed height while keeping its aspect ratio."""

    def __init__(self, h: int):
        self.h = h

    def __call__(self, img: Image.Image) -> Image.Image:
        """Return ``img`` resized to height ``self.h`` (width at least 1 px)."""
        src_w, src_h = img.size
        ratio = self.h / float(src_h)
        scaled_w = int(src_w * ratio)
        if scaled_w < 1:
            # Never produce a zero-width image.
            scaled_w = 1
        return img.resize((scaled_w, self.h), Image.BILINEAR)

def resize_keep_ratio_pad(img: Image.Image, target_h: int, target_min_w: int, pad_multiple: int):
    """Resize to ``target_h`` keeping the aspect ratio, then right-pad with white.

    Args:
        img: source image.
        target_h: output height in pixels.
        target_min_w: lower bound for the resized width.
        pad_multiple: the padded width is rounded up to a multiple of this.

    Returns:
        ``(canvas, new_w, pad_w)``: the padded grayscale image, the width of
        the resized content, and the full padded width.
    """
    src_w, src_h = img.size
    natural_w = int(math.ceil(src_w * (target_h / src_h)))
    new_w = max(target_min_w, natural_w)
    resized = img.resize((new_w, target_h), Image.BILINEAR)

    # Round the width up to the next multiple of pad_multiple.
    n_blocks = math.ceil(new_w / pad_multiple)
    pad_w = int(n_blocks * pad_multiple)

    # White background; the resized content is pasted at the left edge.
    canvas = Image.new('L', (pad_w, target_h), color=255)
    canvas.paste(resized, (0, 0))
    return canvas, new_w, pad_w

class OCRDataset(Dataset):
    """Dataset mapping CSV rows (crop_path, text) to model inputs.

    Each item is ``(x, y, steps)``:
        x: image tensor produced by ``TF.to_tensor`` on the padded
           grayscale image.
        y: 1-D ``torch.long`` tensor of label indices (may be empty).
        steps: number of patch columns covering the resized content,
           i.e. the CTC input length for this sample.
    """

    def __init__(self, csv_path: str, charset: Charset, target_h: int = IMG_H,
                 path_key="crop_path", patch_w: int = PATCH_W):
        """
        Args:
            csv_path: label CSV to read via ``read_rows``.
            charset: character table used to encode the label text.
            target_h: height every image is resized to.
            path_key: CSV column holding the image path.
            patch_w: patch width; used as minimum width, padding multiple
                and step size (defaults to the global ``PATCH_W``).
        """
        self.rows = read_rows(csv_path, path_key=path_key, text_key="text")
        self.cs = charset
        self.norm = "NFC"
        self.target_h = target_h
        self.patch_w = patch_w

    def __len__(self):
        return len(self.rows)

    def __getitem__(self, i):
        r = self.rows[i]

        # Close the file handle deterministically; convert() returns an
        # independent in-memory grayscale copy.
        with Image.open(r["crop_path"]) as src:
            img = src.convert("L")

        img_resized, new_w, _ = resize_keep_ratio_pad(
            img, self.target_h, self.patch_w, self.patch_w
        )
        x = TF.to_tensor(img_resized)

        # A missing CSV cell arrives as None; treat it as an empty label
        # instead of crashing inside unicodedata.normalize.
        txt = unicodedata.normalize(self.norm, r["text"] or "")
        # Charset.encode already drops out-of-vocabulary characters.
        y = torch.tensor(self.cs.encode(txt), dtype=torch.long)

        # Number of patch columns that contain real (non-padding) content.
        steps = int(math.ceil(new_w / self.patch_w))

        return x, y, steps

def vit_collate(batch):
    """Collate ``(image, target, steps)`` samples into one CTC batch.

    Args:
        batch: iterable of ``(x, y, steps)`` where ``x`` is a ``[C, H, W]``
            tensor, ``y`` a 1-D long tensor and ``steps`` an int.

    Returns:
        ``(xs, ys, y_lens, input_lengths)``:
            xs: ``[B, C, H, max_W]`` images, right-padded to the widest one.
            ys: all targets concatenated into one 1-D tensor.
            y_lens: length of each target.
            input_lengths: per-sample number of valid time steps.
    """
    xs, ys, steps_list = zip(*batch)

    # Right-pad every image to the widest one in the batch. Pad with 1.0
    # (white) so batch padding matches the white canvas used when each
    # image is padded individually, instead of introducing black bars.
    max_w = max(x.shape[2] for x in xs)
    xs = torch.stack(
        [F.pad(x, (0, max_w - x.shape[2], 0, 0), value=1.0) for x in xs], 0
    )

    # CTC expects targets as one flat tensor plus per-sample lengths.
    y_lens = torch.tensor([y.numel() for y in ys], dtype=torch.long)
    ys = torch.cat(ys, 0)

    input_lengths = torch.tensor(steps_list, dtype=torch.long)

    return xs, ys, y_lens, input_lengths

In [None]:
# ViT 모델 (위치 인코딩 포함)

class SinusoidalPositionalEncoding(nn.Module):
    """Sine/cosine positional encoding computed on the fly.

    Even feature indices hold ``sin(pos * freq)`` and odd indices hold
    ``cos(pos * freq)``. Works for both even and odd ``d_model``.
    """

    def __init__(self, d_model: int):
        super().__init__()
        self.d_model = d_model

    def forward(self, x):
        """Return the ``[T, d_model]`` encoding for ``T = x.size(0)`` positions.

        Only the first dimension and the device of ``x`` are used; the
        result is created on ``x.device``.
        """
        T = x.size(0)
        device = x.device
        pe = torch.zeros(T, self.d_model, device=device)
        pos = torch.arange(0, T, device=device).unsqueeze(1)
        div = torch.exp(torch.arange(0, self.d_model, 2, device=device) *
                        (-math.log(10000.0) / self.d_model))
        angles = pos * div  # [T, ceil(d_model / 2)]
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer cosine column than sine
        # columns, so trim the angles to avoid a shape mismatch.
        pe[:, 1::2] = torch.cos(angles[:, : self.d_model // 2])
        return pe

class ViTCTC(nn.Module):
    """Transformer encoder over vertical image strips with a per-step classifier.

    The image is cut into full-height strips of width ``patch_w``; each
    strip becomes one token, so the token sequence runs along the image
    width.
    """

    def __init__(self, vocab_size: int, img_h=IMG_H, patch_w=PATCH_W,
                 d_model=D_MODEL, nhead=N_HEAD, num_layers=N_LAYERS, dim_ff=FF_DIM):
        super().__init__()
        self.img_h, self.patch_w, self.d_model = img_h, patch_w, d_model

        # Patch embedding: a conv whose kernel and stride both equal one
        # full-height strip (single input channel).
        strip = (img_h, patch_w)
        self.patch_embed = nn.Conv2d(1, d_model, kernel_size=strip, stride=strip)

        # Pre-norm transformer encoder operating on [T, B, D] tensors.
        layer_cfg = dict(
            d_model=d_model, nhead=nhead, dim_feedforward=dim_ff,
            dropout=0.1, activation='relu', batch_first=False, norm_first=True,
        )
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(**layer_cfg), num_layers=num_layers
        )

        # Parameter-free positional encoding.
        self.pos_enc = SinusoidalPositionalEncoding(d_model)

        # Per-token classification head.
        self.fc = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        """Encode a batch of images.

        Args:
            x: image batch ``[B, C, H, W]``.

        Returns:
            ``(logits, T)`` where ``logits`` is ``[B, T, V]`` (batch first)
            and ``T`` is the number of tokens per image.
        """
        # [B, C, H, W] -> [B, D, 1, T] when H equals img_h -> [T, B, D]
        embedded = self.patch_embed(x)
        tokens = embedded.squeeze(2).permute(2, 0, 1)

        # The positional table is [T, D]; add a batch axis so it broadcasts.
        positions = self.pos_enc(tokens)
        tokens = tokens + positions.unsqueeze(1)
        seq_len = tokens.size(0)

        hidden = self.encoder(tokens)  # [T, B, D]

        # Classify every token, then move the batch axis first: [B, T, V].
        logits = self.fc(hidden).transpose(0, 1)

        return logits, seq_len
    
# Train/validation datasets and data loaders.
# Both loaders use vit_collate; only the training loader shuffles.
train_set = OCRDataset(CSV_TRAIN, charset, target_h=IMG_H, path_key="crop_path")
val_set = OCRDataset(CSV_VAL, charset, target_h=IMG_H, path_key="crop_path")

train_loader = DataLoader(train_set, batch_size=BATCH, shuffle=True, collate_fn=vit_collate)
val_loader = DataLoader(val_set, batch_size=BATCH, shuffle=False, collate_fn=vit_collate)

# Model: one output class per charset entry (including the blank).
VOCAB_SIZE = charset.size
model = ViTCTC(vocab_size=VOCAB_SIZE, img_h=IMG_H, patch_w=PATCH_W,
               d_model=D_MODEL, nhead=N_HEAD, num_layers=N_LAYERS, 
               dim_ff=FF_DIM).to(device)

# NOTE(review): in the visible code train() builds its own Adam optimizer
# and computes the loss through ctc_loss_fn, so `criterion` and `optimizer`
# below do not appear to be used by the training loop -- confirm before
# relying on them.
criterion = nn.CTCLoss(blank=charset.blank_idx, zero_infinity=True)
optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=1e-4)


In [None]:
# ctc header functions
def ctc_loss_fn(model_output, targets, input_lengths, target_lengths, blank=0):
    """Compute the CTC loss from batch-first logits.

    Args:
        model_output: logits ``[B, T, C]`` or a tuple whose first element
            is that tensor.
        targets: concatenated target indices.
        input_lengths: valid time steps per sample.
        target_lengths: target length per sample.
        blank: index of the CTC blank class.

    Returns:
        The scalar loss from ``F.ctc_loss`` with ``zero_infinity=True``.
    """
    logits = model_output[0] if isinstance(model_output, tuple) else model_output

    # CTC wants time-major log-probabilities (T, B, C); the loss itself is
    # evaluated on CPU.
    time_major = logits.permute(1, 0, 2)
    log_probs = F.log_softmax(time_major, dim=-1).cpu()

    return F.ctc_loss(
        log_probs,
        targets,
        input_lengths,
        target_lengths,
        blank=blank,
        zero_infinity=True,
    )

@torch.no_grad()
def greedy_decode(model_output, input_lengths, blank=0):
    """Greedy (arg-max) CTC decoding to raw index sequences.

    Args:
        model_output: logits ``(B, T, C)`` or a tuple whose first element
            is that tensor.
        input_lengths: tensor with the number of valid time steps per sample.
        blank: accepted for interface symmetry; not used here because
            blank removal and repeat collapsing are left to the caller.

    Returns:
        A list with one list of class indices per sample, truncated to the
        sample's valid length.
    """
    logits = model_output[0] if isinstance(model_output, tuple) else model_output
    best = logits.argmax(-1)  # (B, T)
    return [
        best[row, :int(input_lengths[row].item())].tolist()
        for row in range(best.size(0))
    ]

def _levenshtein(a: str, b: str) -> int:
    """레벤슈타인 거리 계산 (공간 최적화 버전)"""
    n, m = len(a), len(b)
    dp = list(range(m+1))
    for i in range(1, n+1):
        prev = dp[0]; dp[0] = i
        for j in range(1, m+1):
            tmp = dp[j]
            cost = 0 if a[i-1]==b[j-1] else 1
            dp[j] = min(dp[j]+1, dp[j-1]+1, prev+cost)
            prev = tmp
    return dp[m]

def cer(ref: str, hyp: str) -> float:
    """Character error rate: edit distance divided by the reference length.

    For an empty reference the result is 0.0 when the hypothesis is also
    empty and 1.0 otherwise.
    """
    if ref:
        return _levenshtein(ref, hyp) / len(ref)
    if hyp:
        return 1.0
    return 0.0

@torch.no_grad()
def evaluate(model, loader, charset, device):
    """Return the mean character error rate of ``model`` over ``loader``.

    Args:
        model: network returning logits (or a tuple starting with logits).
        loader: yields ``(imgs, targets, target_lengths, input_lengths)``
            batches as produced by ``vit_collate``.
        charset: provides ``itos``, ``blank_idx`` and ``decode_ctc``.
        device: device the images are moved to before the forward pass.

    Returns:
        Average per-sample CER; 0.0 when the loader yields no samples.
    """
    model.eval()
    total = 0.0
    n = 0

    for imgs, targets, target_lengths, input_lengths in loader:
        model_output = model(imgs.to(device))
        pred_ids = greedy_decode(model_output, input_lengths, blank=charset.blank_idx)
        preds = [charset.decode_ctc(s) for s in pred_ids]

        # Targets are stored as one flat tensor; walk it sample by sample.
        flat_targets = targets.tolist()
        start = 0
        for hyp, length in zip(preds, target_lengths.tolist()):
            ids = flat_targets[start:start + length]
            # The reference is ground truth, not CTC output: map indices
            # straight to characters. Running it through decode_ctc would
            # collapse genuine doubled characters ("hello" -> "helo").
            ref = "".join(charset.itos[t] for t in ids if t != charset.blank_idx)

            total += cer(ref, hyp)
            n += 1
            start += length

    return total / max(1, n)

def train(model, train_loader, val_loader, charset, device, epochs=EPOCHS, lr=LR):
    """Train ``model`` with CTC loss and keep the checkpoint with the best CER.

    Args:
        model: network to train (already on ``device``).
        train_loader: yields ``(imgs, targets, target_lengths, input_lengths)``.
        val_loader: loader passed to ``evaluate`` after every epoch.
        charset: character table (blank index, size, ``itos``).
        device: device batches are moved to.
        epochs: number of epochs.
        lr: learning rate for the Adam optimizer.

    Returns:
        The best (lowest) validation CER observed.

    Side effects:
        Whenever validation CER improves, writes the model state and the
        charset into the ``save_dir`` directory.
    """
    opt = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)
    best = 1e9
    ckpt_path = f"{save_dir}/best_vit_ctc.pt"
    charset_path = f"{save_dir}/charset.json"

    for ep in range(1, epochs+1):
        model.train()
        total = 0.0

        for imgs, targets, target_lengths, input_lengths in train_loader:
            imgs = imgs.to(device)
            targets = targets.to(device)
            input_lengths = input_lengths.to(device)
            target_lengths = target_lengths.to(device)

            model_output = model(imgs)
            loss = ctc_loss_fn(model_output, targets, input_lengths, target_lengths, blank=charset.blank_idx)

            opt.zero_grad()
            loss.backward()
            # Clip the gradient norm to keep updates bounded.
            nn.utils.clip_grad_norm_(model.parameters(), 5.0)
            opt.step()

            total += loss.item()

        val_cer = evaluate(model, val_loader, charset, device)
        # max(1, ...) avoids a ZeroDivisionError on an empty train loader.
        mean_loss = total / max(1, len(train_loader))
        print(f"[{ep}] train_loss={mean_loss:.4f} | val_CER={val_cer:.4f}")

        if val_cer < best:
            best = val_cer
            os.makedirs(save_dir, exist_ok=True)
            torch.save({"model": model.state_dict(),
                        "config": {"img_h": model.img_h, "patch_w": model.patch_w,
                                   "d_model": model.d_model, "num_classes": charset.size}},
                       ckpt_path)

            # NOTE(review): the saved charset omits index 0 (the blank), so
            # it has one entry fewer than "num_classes" -- confirm that the
            # consumer of this file re-inserts the blank.
            with open(charset_path, "w", encoding="utf-8") as f:
                json.dump(charset.itos[1:], f, ensure_ascii=False, indent=2)
            # Report the path that was actually written.
            print("  saved:", ckpt_path)
    return best

# Train the ViT-CTC model and print the final result.
# The printed label is Korean for "final best CER".
best_cer = train(model, train_loader, val_loader, charset, device)
print("최종 best CER:", best_cer)