In [9]:
import os
from pathlib import Path
from typing import Optional, Tuple, Dict
import importlib, sys

import cv2
import numpy as np

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, Subset

import matplotlib.pyplot as plt
from tqdm.auto import tqdm

from fpn_common import (
    setup_project_path, get_device,
    imread_unicode,
    ResNetFPN, imagenet_normalize_chw,
    assert_saved_matches_backbone,
)

In [11]:
# Dataset layout used below: <BASE_DIR>/{images,labels}/<type folder>/<file name>,
# where an image and its label mask share the same file name.
project_root = setup_project_path()
BASE_DIR = project_root / "hangul_dataset" / "korean_generated"
IMAGES_DIR, LABELS_DIR = (BASE_DIR / sub for sub in ("images", "labels"))

# Glyph-layout sub-folders; one model is trained per folder, in this order.
TYPE_FOLDERS = (
    "vertical_no_jong complex_no_jong horizontal_no_jong "
    "complex_jong horizontal_jong vertical_jong"
).split()

In [13]:
def apply_augment_cv2(
    img_rgb: np.ndarray,
    lbl: np.ndarray,
    rng: np.random.Generator,
    rotate_deg: float = 8.0,
    scale_range: Tuple[float, float] = (0.9, 1.1),
    translate_frac: float = 0.06,
    hflip_p: float = 0.5,
    brightness: float = 0.15,
    contrast: float = 0.15,
) -> Tuple[np.ndarray, np.ndarray]:
    """Apply one random geometric + photometric augmentation to an image/label pair.

    The same affine warp (rotation about the centre, isotropic scale, translation)
    and the same optional horizontal flip are applied to both arrays. The image is
    interpolated bilinearly, the label with nearest-neighbour so class ids are never
    blended. Pixels uncovered by the warp become 0 in both. Brightness/contrast
    jitter is applied to the image only.

    Args:
        img_rgb: image, assumed uint8 in [0, 255] with (H, W, C) layout
            (the flip indexes axis 1 of a 3-D array). Output is always uint8.
        lbl: label mask; its first two dims define the output size (H, W).
        rng: NumPy generator driving every random draw (fixed draw order:
            angle, scale, tx, ty, flip, contrast, brightness).
        rotate_deg: max absolute rotation in degrees.
        scale_range: (low, high) bounds for the uniform scale factor.
        translate_frac: max shift as a fraction of W (x) and of H (y).
        hflip_p: probability of a horizontal flip.
            NOTE(review): mirroring Hangul glyphs produces layouts that never
            occur in real text; consider 0.0 for this data.
        brightness: max additive offset on the [0, 1] intensity scale.
        contrast: max deviation of the multiplicative contrast factor from 1.

    Returns:
        (img_out, lbl_out): augmented uint8 image and label cast back to lbl.dtype.
    """
    H, W = lbl.shape[:2]

    angle = float(rng.uniform(-rotate_deg, rotate_deg))
    scale = float(rng.uniform(scale_range[0], scale_range[1]))
    tx = float(rng.uniform(-translate_frac, translate_frac) * W)
    ty = float(rng.uniform(-translate_frac, translate_frac) * H)

    M = cv2.getRotationMatrix2D((W * 0.5, H * 0.5), angle, scale)
    M[0, 2] += tx
    M[1, 2] += ty

    img_aug = cv2.warpAffine(
        img_rgb, M, (W, H),
        flags=cv2.INTER_LINEAR,
        borderMode=cv2.BORDER_CONSTANT,
        borderValue=(0, 0, 0),
    )
    lbl_aug = cv2.warpAffine(
        lbl, M, (W, H),
        flags=cv2.INTER_NEAREST,  # never interpolate class ids
        borderMode=cv2.BORDER_CONSTANT,
        borderValue=0,
    )

    if rng.random() < hflip_p:
        img_aug = np.ascontiguousarray(img_aug[:, ::-1, :])
        lbl_aug = np.ascontiguousarray(lbl_aug[:, ::-1])

    img_f = img_aug.astype(np.float32) / 255.0
    c = float(rng.uniform(1.0 - contrast, 1.0 + contrast))
    b = float(rng.uniform(-brightness, brightness))
    img_f = (img_f - 0.5) * c + 0.5 + b
    img_f = np.clip(img_f, 0.0, 1.0)

    # Round instead of truncating: astype(uint8) floors, which biased every pixel
    # downward and, through float32 round-off, could turn a no-op jitter
    # (brightness = contrast = 0) into off-by-one intensity changes.
    img_out = np.rint(img_f * 255.0).astype(np.uint8)
    lbl_out = lbl_aug.astype(lbl.dtype)
    return img_out, lbl_out


In [15]:
class HangulTypeSegDataset(Dataset):
    """Segmentation samples for one glyph-type folder.

    Every regular file in ``LABELS_DIR/<type_name>`` is a sample; its image is
    the file with the same name in ``IMAGES_DIR/<type_name>``. Samples are
    ordered by file name.

    Args:
        type_name: sub-folder name under IMAGES_DIR / LABELS_DIR.
        size: (height, width) every image and label is resized to.
        augment: if True, run ``apply_augment_cv2`` on each sample.
        seed: base seed for the per-sample augmentation RNG.
        aug_cfg: keyword arguments forwarded to ``apply_augment_cv2``.

    Raises:
        RuntimeError: if the label folder holds no files.
    """

    def __init__(
        self,
        type_name: str,
        size: Tuple[int, int] = (256, 256),
        augment: bool = False,
        seed: int = 42,
        aug_cfg: Optional[Dict] = None,
    ):
        super().__init__()
        self.type_name = type_name
        self.size = size
        self.augment = augment
        self.seed = seed
        self.aug_cfg = aug_cfg if aug_cfg else {}

        self.img_dir = IMAGES_DIR / type_name
        self.lbl_dir = LABELS_DIR / type_name

        label_names = [entry.name for entry in self.lbl_dir.iterdir() if entry.is_file()]
        label_names.sort()
        self.files = label_names
        if not self.files:
            raise RuntimeError(f"No label files in {self.lbl_dir}")

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx: int):
        """Return ``(image, label, file_name)`` for sample ``idx``.

        ``image`` is a float32 CHW tensor scaled to [0, 1], then passed through
        ``imagenet_normalize_chw``; ``label`` is an int64 (H, W) tensor.
        """
        name = self.files[idx]
        target_wh = (self.size[1], self.size[0])  # cv2 sizes are (width, height)

        img_path = self.img_dir / name
        bgr = imread_unicode(img_path, cv2.IMREAD_COLOR)
        if bgr is None:
            raise RuntimeError(f"Failed to read image: {img_path}")
        rgb = cv2.resize(
            cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB),
            target_wh,
            interpolation=cv2.INTER_LINEAR,
        )

        lbl_path = self.lbl_dir / name
        mask = imread_unicode(lbl_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise RuntimeError(f"Failed to read label: {lbl_path}")
        # Nearest-neighbour keeps class ids intact when resizing the mask.
        mask = cv2.resize(mask, target_wh, interpolation=cv2.INTER_NEAREST)

        if self.augment:
            # The RNG depends only on (self.seed, idx), so a given sample receives
            # the same augmentation every time it is fetched with the same seed.
            sample_rng = np.random.default_rng(self.seed * 1000003 + idx)
            rgb, mask = apply_augment_cv2(rgb, mask, sample_rng, **self.aug_cfg)

        image_chw = torch.from_numpy(rgb.astype(np.float32) / 255.0).permute(2, 0, 1)
        return imagenet_normalize_chw(image_chw), torch.from_numpy(mask.astype(np.int64)), name


In [17]:
def evaluate(model: nn.Module, loader: DataLoader, device, criterion: nn.Module) -> float:
    """Return the sample-weighted average of ``criterion`` over all batches in ``loader``.

    Switches ``model`` to eval mode (it is left there) and disables autograd.
    Each batch's loss is multiplied by its batch size before averaging, which
    yields a true per-sample mean when ``criterion`` returns a per-batch mean.
    Returns 0.0 if the loader yields no batches.
    """
    model.eval()
    loss_sum = 0.0
    sample_count = 0
    with torch.no_grad():
        for inputs, targets, _names in loader:
            batch = inputs.size(0)
            batch_loss = criterion(model(inputs.to(device)), targets.to(device))
            loss_sum += float(batch_loss.item()) * batch
            sample_count += batch
    return loss_sum / max(sample_count, 1)


In [19]:
def _plot_loss_curves(
    train_losses,
    val_losses,
    title: str,
    save_path: Optional[Path] = None,
) -> None:
    """Draw per-epoch train/val losses on one figure, optionally save it, then show it.

    The x-axis is 1-based so it matches the "Epoch k/N" training log.
    """
    epoch_axis = range(1, len(train_losses) + 1)
    fig, ax = plt.subplots(figsize=(6, 4))
    ax.plot(epoch_axis, train_losses, label="Train Loss")
    ax.plot(epoch_axis, val_losses, label="Val Loss")
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss")
    ax.set_title(title)
    ax.legend()
    ax.grid(True)
    if save_path is not None:
        # Save before show(): the saved file then never depends on what the
        # interactive backend does with a figure after displaying it.
        fig.savefig(save_path, dpi=150)
    plt.show()
    plt.close(fig)


def train_one_type(
    type_name: str,
    out_dir: Path,
    size=(256, 256),
    batch_size: int = 8,
    epochs: int = 20,
    lr: float = 1e-3,
    val_ratio: float = 0.1,
    num_workers: int = 0,
    backbone: str = "resnet34",
    pretrained_backbone: bool = True,
    seed: int = 42,
    use_augmentation: bool = True,
    aug_cfg: Optional[Dict] = None,
    plot_loss: bool = True,
    save_plot: bool = True,
):
    """Train a 4-class ResNet-FPN segmentation model on one glyph-type folder.

    The folder is split into train/val with a permutation seeded by ``seed``.
    After every epoch the model is evaluated on the val split, and the
    state_dict with the lowest val loss so far is written to
    ``out_dir/fpn_<type_name>_<backbone>.pth``.

    Args:
        type_name: sub-folder to train on.
        out_dir: directory for checkpoints and loss plots (created if missing).
        size: (height, width) samples are resized to.
        batch_size, epochs, lr: optimisation settings (Adam, cross-entropy).
        val_ratio: fraction of samples held out (at least one sample).
        num_workers: DataLoader worker processes.
        backbone, pretrained_backbone: forwarded to ``ResNetFPN``.
        seed: controls the split and the augmentation RNG.
        use_augmentation: augment training samples with ``aug_cfg``.
        aug_cfg: kwargs for ``apply_augment_cv2``; falls back to a default config when falsy.
        plot_loss: show the loss curves after training.
        save_plot: also save the plot (only when ``plot_loss`` is True).

    Returns:
        Path of the best checkpoint.

    Raises:
        ValueError: if the split would leave no training samples.
    """
    out_dir.mkdir(parents=True, exist_ok=True)

    # Deterministic split: the permutation depends only on `seed`.
    g = torch.Generator().manual_seed(seed)

    # Un-augmented view of the folder; also used to count samples.
    val_ds_full = HangulTypeSegDataset(type_name=type_name, size=size, augment=False, seed=seed)
    n_total = len(val_ds_full)
    n_val = max(1, int(n_total * val_ratio))
    n_train = n_total - n_val
    if n_train < 1:
        raise ValueError(
            f"{type_name}: {n_total} sample(s) with val_ratio={val_ratio} leaves no training samples"
        )

    all_idx = torch.randperm(n_total, generator=g).tolist()
    val_idx = all_idx[:n_val]
    train_idx = all_idx[n_val:]

    train_ds_full = HangulTypeSegDataset(
        type_name=type_name,
        size=size,
        augment=bool(use_augmentation),
        seed=seed,
        aug_cfg=aug_cfg or dict(
            rotate_deg=8.0,
            scale_range=(0.9, 1.1),
            translate_frac=0.06,
            hflip_p=0.5,
            brightness=0.15,
            contrast=0.15,
        ),
    )

    train_ds = Subset(train_ds_full, train_idx)
    val_ds = Subset(val_ds_full, val_idx)

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=False)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=False)

    device = get_device()
    model = ResNetFPN(num_classes=4, backbone=backbone, pretrained=pretrained_backbone).to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    best_val = float("inf")
    best_path = out_dir / f"fpn_{type_name}_{backbone}.pth"

    train_losses, val_losses = [], []

    print(f"\n==============================")
    print(f" Train type: {type_name}")
    print(f" total={n_total}, train={n_train}, val={n_val}")
    print(f" aug={use_augmentation} | save -> {best_path}")
    print(f"==============================")

    for epoch in range(1, epochs + 1):
        # The dataset seeds each sample's augmentation RNG from (dataset.seed, idx),
        # so with a fixed seed every epoch would replay the exact same augmented
        # copy of each image. Offsetting the seed per epoch gives fresh, still
        # reproducible augmentations (epoch 1 uses the original seed). The
        # DataLoader builds a new iterator each epoch, so workers see the update.
        train_ds_full.seed = seed + (epoch - 1)

        model.train()
        running, seen = 0.0, 0

        pbar = tqdm(train_loader, desc=f"[{type_name}] Epoch {epoch}/{epochs}", leave=False)
        for x, y, _ in pbar:
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)

            optimizer.zero_grad(set_to_none=True)
            logits = model(x)
            loss = criterion(logits, y)
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()  # one device->host sync per step
            running += batch_loss * x.size(0)
            seen += x.size(0)
            pbar.set_postfix(train_loss=f"{batch_loss:.4f}")

        train_loss = running / max(seen, 1)
        train_losses.append(train_loss)

        val_loss = evaluate(model, val_loader, device, criterion)
        val_losses.append(val_loss)

        print(f"[{type_name}] Epoch {epoch:03d}/{epochs} | train_loss={train_loss:.4f} | val_loss={val_loss:.4f}")

        if val_loss < best_val:
            best_val = val_loss
            torch.save(model.state_dict(), best_path)
            assert_saved_matches_backbone(best_path, backbone)

    print(f"\n Done: {type_name}")
    print(f"Best val_loss = {best_val:.4f}")
    print(f"Saved to: {best_path}")
    if best_val == float("inf"):
        # NaN/inf losses never beat the initial best, so nothing was written in
        # this run; a file at best_path (if any) comes from an earlier run.
        print(f"WARNING: no checkpoint was written for {type_name}; {best_path} may be missing or stale")

    if plot_loss:
        plot_path = (out_dir / f"loss_{type_name}_{backbone}.png") if save_plot else None
        _plot_loss_curves(train_losses, val_losses, f"Loss Curve - {type_name}", plot_path)
        if plot_path is not None:
            print(f"Loss plot saved: {plot_path}")

    return best_path

In [21]:
def train_all_types(
    out_dir: Path,
    backbone: str = "resnet34",
    pretrained_backbone: bool = True,
    size=(256, 256),
    batch_size=2,
    epochs=10,
    lr=1e-3,
    val_ratio=0.1,
    num_workers=0,
    seed=42,
    use_augmentation=True,
    aug_cfg=None,
    type_names=None,
    plot_loss: bool = True,
    save_plot: bool = True,
):
    """Train one model per glyph-type folder with shared settings.

    Every argument except ``type_names`` is forwarded unchanged to
    ``train_one_type``.

    Args:
        type_names: iterable of folder names to train; ``None`` (default)
            trains every entry of ``TYPE_FOLDERS`` in order.
        plot_loss, save_plot: forwarded to ``train_one_type`` (both default to
            True, matching the previous hard-coded behaviour).

    Returns:
        Dict mapping each trained type name to its best checkpoint path.
    """
    types_to_train = TYPE_FOLDERS if type_names is None else list(type_names)
    saved = {}
    for t in types_to_train:
        saved[t] = train_one_type(
            type_name=t,
            out_dir=out_dir,
            size=size,
            batch_size=batch_size,
            epochs=epochs,
            lr=lr,
            val_ratio=val_ratio,
            num_workers=num_workers,
            backbone=backbone,
            pretrained_backbone=pretrained_backbone,
            seed=seed,
            use_augmentation=use_augmentation,
            aug_cfg=aug_cfg,
            plot_loss=plot_loss,
            save_plot=save_plot,
        )
    return saved

In [None]:
if __name__ == "__main__":
    # In a Jupyter kernel __name__ is "__main__", so this cell runs whenever it
    # is executed. Augmentation for this run is mild geometric jitter only:
    # flips and brightness/contrast jitter are switched off.
    AUG_CFG = {
        "rotate_deg": 2.0,
        "scale_range": (0.8, 1.1),
        "translate_frac": 0.06,
        "hflip_p": 0.0,
        "brightness": 0.0,
        "contrast": 0.0,
    }
    OUT_DIR = project_root / "fpn" / "weights" / "resnet34_run"

    saved = train_all_types(
        out_dir=OUT_DIR,
        aug_cfg=AUG_CFG,
        use_augmentation=True,
        backbone="resnet34",
        pretrained_backbone=True,
        size=(256, 256),
        batch_size=2,
        epochs=10,
        lr=1e-3,
        val_ratio=0.1,
        num_workers=0,
        seed=42,
    )

    print("\nSaved weights:")
    for k, v in saved.items():
        print(k, "->", v)



 Train type: vertical_no_jong
 total=2880, train=2592, val=288
 aug=True | save -> D:\Study\학교강의\4학년2학기\캡스톤\Baram_Handwritting_Analysis\fpn\weights\resnet34_run\fpn_vertical_no_jong_resnet34.pth


[vertical_no_jong] Epoch 1/10:   0%|          | 0/1296 [00:00<?, ?it/s]