In [51]:
from pathlib import Path
import sys
from typing import Optional, Tuple, Dict, List
import numpy as np
import cv2
import torch
import matplotlib.pyplot as plt

def setup_project_path():
    """Locate the project root by walking up from the current working directory.

    Returns:
        The nearest directory (the working directory itself or one of its
        ancestors) that contains an entry named ``fpn``.

    Raises:
        RuntimeError: if no directory up to the filesystem root contains ``fpn``.
    """
    start = Path.cwd()
    # Nearest match wins: check the working directory first, then each ancestor.
    for candidate in (start, *start.parents):
        if (candidate / "fpn").exists():
            return candidate
    raise RuntimeError("Could not find project_root containing 'fpn'")

project_root = setup_project_path()

# Put the project root at the front of sys.path so project packages resolve first.
# Any copy left by an earlier run of this cell is dropped before inserting, so
# re-executing the cell does not keep stacking duplicate entries.
_project_root_str = str(project_root)
sys.path[:] = [entry for entry in sys.path if entry != _project_root_str]
sys.path.insert(0, _project_root_str)

print("project_root:", project_root)

project_root: D:\Study\학교강의\4학년2학기\캡스톤\Baram_Handwritting_Analysis


In [53]:
def ensure_out_dirs(save_dir: Path):
    """Create the output sub-folders under ``save_dir`` and return it as a ``Path``.

    The folders ``labels_png``, ``labels_npy``, ``masked`` and ``images`` are
    created (including missing parents); folders that already exist are left alone.
    """
    root = Path(save_dir)
    # "images" is created unconditionally; if it is not needed, image saving
    # can simply be switched off at save time.
    for sub_name in ("labels_png", "labels_npy", "masked", "images"):
        (root / sub_name).mkdir(parents=True, exist_ok=True)
    return root

def save_label_raw_split(save_dir: Path, stem: str, pred_mask: np.ndarray):
    """
    Write one predicted label mask to disk in two formats.

    pred_mask: expected to be HxW with values {0,1,2,3}; it is cast to uint8.
    - labels_png/<stem>.png: the label values stored unchanged as a grayscale PNG
    - labels_npy/<stem>.npy: the same uint8 array saved with np.save
    """
    out_root = Path(save_dir)
    labels_u8 = pred_mask.astype(np.uint8)

    png_path = out_root / "labels_png" / f"{stem}.png"
    npy_path = out_root / "labels_npy" / f"{stem}.npy"

    # PNG keeps the raw 0~3 values (no scaling to a visible range).
    imwrite_unicode(png_path, labels_u8)
    # NPY holds the identical array.
    np.save(str(npy_path), labels_u8)

def save_masked_one(save_dir: Path, stem: str, img_rgb: np.ndarray, pred_mask: np.ndarray, alpha: float = 0.45,
                    masked_mode: str = "overlay"):
    """
    Save one visual-check image to masked/<stem>.png.

    masked_mode:
      - "overlay": the source image blended with the colorized mask (recommended)
      - "color"  : the colorized mask alone, without the source image

    Raises ValueError for any other masked_mode.
    """
    out_path = Path(save_dir) / "masked" / f"{stem}.png"

    # Colorized mask, treated as RGB like the input image.
    mask_rgb = colorize_mask(pred_mask.astype(np.uint8))

    if masked_mode == "color":
        rendered_rgb = mask_rgb
    elif masked_mode == "overlay":
        rendered_rgb = overlay(img_rgb, mask_rgb, alpha=alpha)
    else:
        raise ValueError("masked_mode must be 'overlay' or 'color'")

    # Convert RGB -> BGR before handing the array to the writer.
    rendered_bgr = cv2.cvtColor(rendered_rgb, cv2.COLOR_RGB2BGR)
    imwrite_unicode(out_path, rendered_bgr)

def save_image_optional(save_dir: Path, stem: str, img_rgb: np.ndarray, save_image: bool = True):
    """Write ``img_rgb`` to images/<stem>.png under ``save_dir`` when ``save_image`` is truthy.

    The image is converted RGB -> BGR before writing. When ``save_image`` is
    falsy nothing is touched. Returns None in both cases.
    """
    if save_image:
        target = Path(save_dir) / "images" / f"{stem}.png"
        imwrite_unicode(target, cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR))


In [63]:
@torch.no_grad()
def segment_sentence_from_chars(
    sentence: str,
    char_dir: str | Path,
    weight_map: Dict[str, str | Path],
    size: Tuple[int, int] = (256, 256),
    alpha: float = 0.45,
    backbone: str = "resnet34",
    pretrained_backbone: bool = False,
    file_exts: Tuple[str, ...] = (".png", ".jpg", ".jpeg", ".bmp"),

    # --- saving ---
    save_dir: Optional[str | Path] = None,
    save_each_char: bool = True,
    masked_mode: str = "overlay",   # "overlay" or "color"
    save_image: bool = True,       # 원본도 저장할지
):
    """
    - 글자별 pred_mask(0~3) 저장: labels_png/, labels_npy/
    - 상태 확인용 1종 저장: masked/ (overlay 또는 color 중 하나)
    - 원본 저장은 옵션(images/)
    - concat 저장/리턴 없음
    """
    char_dir = Path(char_dir)
    if not char_dir.exists():
        raise FileNotFoundError(f"char_dir not found: {char_dir}")

    files = sorted([p for p in char_dir.iterdir() if p.is_file() and p.suffix.lower() in file_exts])
    if len(files) == 0:
        raise RuntimeError(f"No image files found in: {char_dir}")

    chars = [c for c in sentence if c != " "]
    if len(chars) != len(files):
        n = min(len(chars), len(files))
        print(f"[WARN] mismatch: sentence chars={len(chars)} vs files={len(files)}. Using first {n}.")
        chars = chars[:n]
        files = files[:n]

    device = get_device()
    model_cache: Dict[str, torch.nn.Module] = {}

    if save_dir is not None:
        save_dir = ensure_out_dirs(Path(save_dir))

    results: List[Dict] = []

    for i, (ch, img_path) in enumerate(zip(chars, files)):
        type_name = hangul_char_to_type(ch)
        if type_name not in weight_map:
            raise KeyError(f"weight_map missing type '{type_name}'. Available: {list(weight_map.keys())}")
        wpath = Path(weight_map[type_name])

        # 모델 로드/캐시
        if type_name not in model_cache:
            model = ResNetFPN(num_classes=4, backbone=backbone, pretrained=pretrained_backbone).to(device)
            state = torch.load(str(wpath), map_location=device)
            model.load_state_dict(state)
            model.eval()
            model_cache[type_name] = model
        model = model_cache[type_name]

        # 이미지 로드
        img_bgr = imread_unicode(img_path, cv2.IMREAD_COLOR)
        if img_bgr is None:
            raise RuntimeError(f"Failed to read image: {img_path}")

        img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
        img_rgb = cv2.resize(img_rgb, (size[1], size[0]), interpolation=cv2.INTER_LINEAR)

        # 텐서화 + normalize
        x = img_rgb.astype(np.float32) / 255.0
        x = torch.from_numpy(x).permute(2, 0, 1).to(device)
        x = imagenet_normalize_chw(x)
        x = x.unsqueeze(0)

        # 추론
        logits = model(x)
        pred = torch.argmax(logits, dim=1)[0].cpu().numpy().astype(np.uint8)  # 0~3

        # 결과 기록
        stem = f"{i:02d}_{ch}_{type_name}"

        results.append({
            "index": i,
            "char": ch,
            "type": type_name,
            "file": img_path.name,
            "weight": str(wpath),
            "stem": stem,
            "pred_mask": pred,
        })

        # 저장
        if save_dir is not None and save_each_char:
            save_label_raw_split(save_dir, stem, pred)
            save_masked_one(save_dir, stem, img_rgb, pred, alpha=alpha, masked_mode=masked_mode)
            save_image_optional(save_dir, stem, img_rgb, save_image=save_image)

    return results


In [75]:
# Folder that holds the final FPN checkpoints.
weight_dir = project_root / "fpn" / "weights_final"

# Type name -> checkpoint path, one resnet34 checkpoint per character type.
# Built from the layout x jong/no_jong combinations, in the same key order as
# listing them by hand.
weight_map = {
    f"{layout}_{jong}": weight_dir / f"fpn_{layout}_{jong}_resnet34.pth"
    for layout in ("complex", "horizontal", "vertical")
    for jong in ("jong", "no_jong")
}


In [85]:
sentence = "소프트웨어분석"

# Input glyph images and output root. Alternative locations are kept commented out.
char_dir = project_root / "results" / "segment_results" / "printed_chars" / "images"
#char_dir = project_root / "characters" / "cropped" / "test5"
out_dir = project_root / "results" / "segment_results" / "printed_chars"
#out_dir = project_root / "results" / "segment_results"

# With the paths above, char_dir IS out_dir / "images". Saving the resized
# copies there would add NN_<char>_<type>.png files to the input folder, and the
# next run would pair characters with the wrong files. Only save image copies
# when the two folders differ.
save_image_copies = (out_dir / "images").resolve() != char_dir.resolve()

results = segment_sentence_from_chars(
    sentence=sentence,
    char_dir=char_dir,
    weight_map=weight_map,
    size=(256, 256),
    alpha=0.45,
    backbone="resnet34",
    pretrained_backbone=False,
    save_dir=out_dir,
    save_each_char=True,
    save_image=save_image_copies,
)

# One summary line per processed character.
for r in results:
    print(r["index"], r["char"], r["type"], r["file"])


  state = torch.load(str(wpath), map_location=device)


0 소 horizontal_no_jong 00_U+C18C.png
1 프 horizontal_no_jong 01_U+D504.png
2 트 horizontal_no_jong 02_U+D2B8.png
3 웨 complex_no_jong 03_U+C6E8.png
4 어 vertical_no_jong 04_U+C5B4.png
5 분 horizontal_jong 05_U+BD84.png
6 석 vertical_jong 06_U+C11D.png
