In [1]:
import os
from pathlib import Path
from typing import Dict, Optional, Tuple, List

import cv2
import numpy as np
import torch  

In [2]:
def imwrite_unicode(path, img):
    """
    Save `img` to `path`, supporting non-ASCII (e.g. Korean) paths.

    cv2.imwrite cannot handle such paths on Windows, so the image is encoded in
    memory with cv2.imencode (format chosen from the file extension) and the bytes
    are written with Python's own file I/O.

    Returns True if encoding succeeded and the file was written, otherwise False.
    """
    target = str(path)
    success, encoded = cv2.imencode(os.path.splitext(target)[1], img)
    if success:
        with open(target, "wb") as fh:
            fh.write(encoded.tobytes())
    return bool(success)

def imread_unicode(path, flags=cv2.IMREAD_COLOR):
    """
    Read an image from `path`, supporting non-ASCII (e.g. Korean) paths.

    The raw file bytes are read with Python I/O and decoded with cv2.imdecode
    using `flags` (default: 3-channel BGR).

    Returns the decoded array (cv2.imdecode may itself return None for
    undecodable data), or None after printing a message if any exception is raised.
    """
    try:
        path = str(path)
        with open(path, "rb") as fh:
            raw_bytes = fh.read()
        encoded = np.frombuffer(raw_bytes, np.uint8)
        return cv2.imdecode(encoded, flags)
    except Exception as e:
        print("[imread_unicode ERROR]", e, "->", path)
        return None


In [3]:
def colorize_label(mask_0_3: np.ndarray) -> np.ndarray:
    """
    Map a label map to RGB colors.

    Palette: 0 = background (black), 1 = cho / initial (red),
    2 = jung / medial (green), 3 = jong / final (blue).
    Values outside [0, 3] are clamped into that range before lookup.

    Returns an array of shape mask_0_3.shape + (3,) with dtype uint8.
    """
    colors = np.array(
        [
            (0, 0, 0),      # 0 bg
            (255, 0, 0),    # 1 cho
            (0, 255, 0),    # 2 jung
            (0, 0, 255),    # 3 jong
        ],
        dtype=np.uint8,
    )
    index = mask_0_3.astype(np.int32).clip(0, 3)
    return np.take(colors, index, axis=0)

def overlay_rgb(img_rgb: np.ndarray, mask_rgb: np.ndarray, alpha: float = 0.45) -> np.ndarray:
    """
    Alpha-blend `mask_rgb` over `img_rgb`: (1 - alpha) * img + alpha * mask.

    The blend is computed in float32, clipped to [0, 255] and truncated to uint8.
    """
    base = img_rgb.astype(np.float32)
    layer = mask_rgb.astype(np.float32)
    blended = (1 - alpha) * base + alpha * layer
    return blended.clip(0, 255).astype(np.uint8)


In [4]:
def make_ink_mask_white_on_black(
    img_rgb: np.ndarray,
    thr: int = 160,
    blur_ksize: int = 0,
    morph_open: int = 0,
    morph_close: int = 0,
) -> np.ndarray:
    """
    Build a binary ink mask, always assuming a black background with white ink.

    A pixel is ink (255) when its grayscale value is strictly greater than `thr`.

    Args:
        img_rgb: RGB image (HxWx3). An HxW grayscale image is also accepted and
            used directly as the gray channel.
        thr: grayscale threshold; gray > thr => ink.
        blur_ksize: Gaussian blur kernel size applied before thresholding
            (<= 1 disables). cv2.GaussianBlur requires an odd size, so an even
            value is rounded up to the next odd value (e.g. 4 -> 5).
        morph_open: elliptical kernel size for a morphological opening of the
            mask (<= 1 disables).
        morph_close: elliptical kernel size for a morphological closing of the
            mask (<= 1 disables).

    Returns:
        HxW uint8 mask with values {0, 255}.
    """
    if img_rgb.ndim == 2:
        gray = img_rgb
    else:
        gray = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2GRAY)

    if blur_ksize and blur_ksize > 1:
        # GaussianBlur rejects even kernel sizes; force the size odd.
        k_blur = int(blur_ksize) | 1
        gray = cv2.GaussianBlur(gray, (k_blur, k_blur), 0)

    # white ink on black background
    ink = (gray > thr).astype(np.uint8) * 255

    if morph_open and morph_open > 1:
        k = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (morph_open, morph_open))
        ink = cv2.morphologyEx(ink, cv2.MORPH_OPEN, k)

    if morph_close and morph_close > 1:
        k = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (morph_close, morph_close))
        ink = cv2.morphologyEx(ink, cv2.MORPH_CLOSE, k)

    return ink


In [5]:
def make_clustered_mask_vis_white_on_black(
    refined_label: np.ndarray,    # HxW uint8 (0~3)
    ink_mask_u8: np.ndarray,      # HxW uint8 {0,255}
    bg_black: bool = True,        # keep the background black
) -> np.ndarray:
    """
    Visualize only the final clustered labels, without the source image.

    Ink pixels (ink_mask_u8 > 0) are painted with their label color
    (see colorize_label); every other pixel gets the background color, which is
    black when `bg_black` is True and white otherwise.

    Returns an HxWx3 RGB uint8 image.
    """
    height, width = refined_label.shape
    fill_value = 255
    if bg_black:
        fill_value = 0
    canvas = np.full((height, width, 3), fill_value, dtype=np.uint8)

    ink_pixels = ink_mask_u8 > 0
    label_colors = colorize_label(refined_label)
    canvas[ink_pixels] = label_colors[ink_pixels]
    return canvas


In [6]:
def refine_with_seed_distance(
    ink_mask_u8: np.ndarray,     # HxW uint8 {0,255}
    pred_label: np.ndarray,      # HxW uint8 (0~3)
    seed_label: np.ndarray,      # HxW uint8 (0~3)
    max_dist: float = 40.0,      # seed trust radius (pixels)
    allow_classes: Tuple[int, ...] = (1, 2, 3),
    fallback_to_pred: bool = True,   # far from every seed -> keep pred
    mask_size: int = 3,
) -> np.ndarray:
    """
    Reassign ink pixels to classes 1/2/3 by distance to the nearest seed pixel.

    For every class k in `allow_classes` that actually occurs in `seed_label`,
    a Euclidean distance map to the seed pixels of k is computed. Each ink pixel
    (ink_mask_u8 > 0) takes the class whose seed is closest. Classes absent from
    the seed are not candidates (e.g. a syllable without a final consonant).
    Non-ink pixels keep their `pred_label` value.

    Ink pixels whose nearest seed is farther than `max_dist`:
        - fallback_to_pred=True : keep their pred_label value (recommended)
        - fallback_to_pred=False: set to 0 (background; the older behavior)

    Args:
        mask_size: maskSize passed to cv2.distanceTransform. 3 (default) is a fast
            approximation; cv2.DIST_MASK_PRECISE gives exact Euclidean distances.

    Returns:
        A new label map with the same shape/dtype as `pred_label`. If no allowed
        class is present in the seed, an unchanged copy of `pred_label`.

    Raises:
        ValueError: if the three input maps do not have the same shape.
    """
    if not (ink_mask_u8.shape == pred_label.shape == seed_label.shape):
        raise ValueError(
            f"shape mismatch: ink={ink_mask_u8.shape} pred={pred_label.shape} "
            f"seed={seed_label.shape}"
        )

    ink = ink_mask_u8 > 0

    candidates = []
    dist_maps = []
    for k in allow_classes:
        seed_k = seed_label == k
        if not seed_k.any():
            continue
        candidates.append(k)
        # distanceTransform measures distance to the nearest ZERO pixel,
        # so seed pixels are 0 and everything else is 255.
        inv = np.where(seed_k, 0, 255).astype(np.uint8)
        dist_maps.append(
            cv2.distanceTransform(inv, distanceType=cv2.DIST_L2, maskSize=mask_size)
        )

    if not candidates:
        return pred_label.copy()

    stack = np.stack(dist_maps, axis=0)   # CxHxW
    min_dist = stack.min(axis=0)
    assigned = np.asarray(candidates, dtype=np.uint8)[stack.argmin(axis=0)]

    far = ink & (min_dist > max_dist)
    near = ink & ~far

    refined = pred_label.copy()
    refined[near] = assigned[near]
    if not fallback_to_pred:
        refined[far] = 0
    # with fallback_to_pred, far ink pixels simply keep their pred value
    return refined


In [7]:
def cleanup_small_label_fragments_in_components(
    ink_mask_u8: np.ndarray,
    label_0_3: np.ndarray,
    min_ratio: float = 0.03,     # labels covering < 3% of a component are absorbed
    min_pixels: int = 15,        # labels with fewer pixels than this are absorbed
    connectivity: int = 8,
) -> np.ndarray:
    """
    Absorb rare labels inside each ink connected component into its majority label.

    For every connected component of the ink mask (ink_mask_u8 > 0), the majority
    non-zero label is found; ties go to the smallest label value. Any other
    non-zero label in that component is relabeled to the majority when
        count < min_pixels   OR   count / component_area < min_ratio.
    Label 0 pixels are never changed, and a component containing only 0 is skipped.
    Pixels outside the ink mask are left untouched.

    Each component is processed inside its own bounding box (from
    cv2.connectedComponentsWithStats). Without that, building a full-image mask per
    component costs O(#components * H * W) on noisy masks.

    Returns:
        A new label map (same shape/dtype as `label_0_3`).

    Raises:
        ValueError: if `ink_mask_u8` and `label_0_3` have different shapes.
    """
    if ink_mask_u8.shape != label_0_3.shape:
        raise ValueError(
            f"shape mismatch: ink_mask_u8={ink_mask_u8.shape} label_0_3={label_0_3.shape}"
        )

    ink = (ink_mask_u8 > 0).astype(np.uint8)
    num, cc, stats, _ = cv2.connectedComponentsWithStats(ink, connectivity=connectivity)
    out = label_0_3.copy()

    for cid in range(1, num):  # component 0 is the non-ink background
        x = int(stats[cid, cv2.CC_STAT_LEFT])
        y = int(stats[cid, cv2.CC_STAT_TOP])
        w = int(stats[cid, cv2.CC_STAT_WIDTH])
        h = int(stats[cid, cv2.CC_STAT_HEIGHT])
        area = int(stats[cid, cv2.CC_STAT_AREA])
        if area == 0:
            continue

        cc_box = cc[y:y + h, x:x + w]
        out_box = out[y:y + h, x:x + w]   # basic slice -> view; writes land in `out`
        region = cc_box == cid

        vals, counts = np.unique(out_box[region], return_counts=True)
        fg = vals != 0
        if not fg.any():
            continue  # the whole component is labeled 0

        fg_vals = vals[fg]
        fg_counts = counts[fg]
        # np.unique sorts vals and argmax returns the first maximum,
        # so ties resolve to the smallest label value.
        best_label = int(fg_vals[np.argmax(fg_counts)])

        for v, c in zip(fg_vals.tolist(), fg_counts.tolist()):
            v = int(v)
            if v == best_label:
                continue
            if c < min_pixels or (c / area) < min_ratio:
                out_box[region & (out_box == v)] = best_label

    return out


In [8]:
def load_label_any(path: Path) -> np.ndarray:
    """
    Load a label map stored either as .npy or as an image (e.g. a 0~3 PNG).

    - .npy (case-insensitive suffix): loaded with np.load and cast to uint8 if needed.
    - anything else: decoded as a grayscale image via imread_unicode and cast to uint8.

    Raises:
        FileNotFoundError: if the image file cannot be read/decoded.
    """
    if path.suffix.lower() != ".npy":
        label_img = imread_unicode(path, cv2.IMREAD_GRAYSCALE)
        if label_img is None:
            raise FileNotFoundError(path)
        return label_img.astype(np.uint8)

    arr = np.load(str(path))
    return arr if arr.dtype == np.uint8 else arr.astype(np.uint8)

def save_label_outputs(
    out_root: Path,
    stem: str,
    refined_label: np.ndarray,
    ink_mask_u8: np.ndarray,
    save_npy: bool = True,
    save_masked: bool = True,
):
    """
    Write the refined label map of one image under `out_root`.

    Outputs (a sub-folder is created only when something is written into it):
        labels_png/<stem>.png : label map as a uint8 PNG (values 0~3)
        labels_npy/<stem>.npy : the same label map as uint8 .npy   (if save_npy)
        masked/<stem>.png     : ink pixels colored by label on a black
                                background, no source image       (if save_masked)

    A failed image write (imwrite_unicode returning False) is reported with a
    "[WARN]" message instead of being ignored silently.
    """
    out_root = Path(out_root)
    label_u8 = refined_label.astype(np.uint8)

    # label png (0~3)
    png_dir = out_root / "labels_png"
    png_dir.mkdir(parents=True, exist_ok=True)
    png_path = png_dir / f"{stem}.png"
    if not imwrite_unicode(png_path, label_u8):
        print(f"[WARN] failed to write label png: {png_path}")

    # label npy
    if save_npy:
        npy_dir = out_root / "labels_npy"
        npy_dir.mkdir(parents=True, exist_ok=True)
        np.save(str(npy_dir / f"{stem}.npy"), label_u8)

    # masked: cluster result only (NO original image)
    if save_masked:
        masked_dir = out_root / "masked"
        masked_dir.mkdir(parents=True, exist_ok=True)
        vis = make_clustered_mask_vis_white_on_black(
            refined_label=refined_label,
            ink_mask_u8=ink_mask_u8,
            bg_black=True,   # inputs are always black-background, so keep it black
        )
        masked_path = masked_dir / f"{stem}.png"
        # vis is RGB; OpenCV encoders expect BGR
        if not imwrite_unicode(masked_path, cv2.cvtColor(vis, cv2.COLOR_RGB2BGR)):
            print(f"[WARN] failed to write masked image: {masked_path}")


In [9]:
def make_clustered_mask_vis(
    refined_label: np.ndarray,   # HxW, 0~3
    ink_mask_u8: np.ndarray,     # HxW, 0/255
    bg_color: Tuple[int,int,int] = (255, 255, 255),
) -> np.ndarray:
    """
    Visualize the (post-processed) clustering result by itself.

    The source image is not used: pixels where ink_mask_u8 > 0 are painted with
    their label color (see colorize_label), and all others get `bg_color`
    (white by default).

    Returns an HxWx3 uint8 image.
    """
    height, width = refined_label.shape
    vis = np.full((height, width, 3), bg_color, dtype=np.uint8)

    colored = colorize_label(refined_label)
    ink_pixels = ink_mask_u8 > 0
    vis[ink_pixels] = colored[ink_pixels]
    return vis


In [10]:
def _find_label_file(label_dir: Path, stem: str, priority=(".npy", ".png")) -> Optional[Path]:
    """Return the first existing `label_dir/<stem><ext>` for ext in `priority`, else None."""
    for ext in priority:
        candidate = label_dir / f"{stem}{ext}"
        if candidate.exists():
            return candidate
    return None


def refine_folder_with_print_seeds(
    hw_img_dir: str | Path,          # handwritten source image folder
    hw_pred_dir: str | Path,         # handwritten FPN predicted labels (png or npy)
    print_seed_dir: str | Path,      # printed-font seed labels (png or npy)
    out_dir: str | Path,             # output folder (refined)
    file_exts_img=(".png", ".jpg", ".jpeg", ".bmp"),
    pred_ext_priority=(".npy", ".png"),   # lookup order for pred labels
    seed_ext_priority=(".npy", ".png"),   # lookup order for seed labels

    # ink mask params
    ink_thr: int = 200,
    ink_blur: int = 0,
    ink_open: int = 0,
    ink_close: int = 0,

    # refinement params
    max_dist: float = 40.0,
    alpha: float = 0.45,

    # optional cleanup
    do_cleanup: bool = True,
    cleanup_min_ratio: float = 0.03,
    cleanup_min_pixels: int = 15,
):
    """
    Refine handwritten FPN label maps with printed-font seed labels, for a whole folder.

    For every image in `hw_img_dir` (extensions in `file_exts_img`) whose stem has
    both a prediction in `hw_pred_dir` and a seed in `print_seed_dir` (files chosen
    by `pred_ext_priority` / `seed_ext_priority`):
      1. build a white-on-black ink mask using `ink_thr`, `ink_blur`, `ink_open`
         and `ink_close` (make_ink_mask_white_on_black),
      2. reassign ink pixels by distance to the seed classes 1/2/3; ink farther
         than `max_dist` from every seed keeps the prediction,
      3. optionally absorb tiny label fragments per ink component
         (`do_cleanup`, `cleanup_min_ratio`, `cleanup_min_pixels`),
      4. force every non-ink pixel to 0 and save via save_label_outputs into `out_dir`.
    Labels whose shape differs from the image are resized (nearest neighbor).
    Images with missing or unreadable inputs are skipped with a "[SKIP]" message.

    Args:
        alpha: unused; kept only so existing calls keep working (the saved
            visualization does not blend the source image).

    Raises:
        RuntimeError: if `hw_img_dir` contains no matching image file.
    """
    hw_img_dir = Path(hw_img_dir)
    hw_pred_dir = Path(hw_pred_dir)
    print_seed_dir = Path(print_seed_dir)
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    # handwritten image list
    img_files = sorted(
        p for p in hw_img_dir.iterdir()
        if p.is_file() and p.suffix.lower() in file_exts_img
    )
    if not img_files:
        raise RuntimeError(f"No images in: {hw_img_dir}")

    print(f"[REFINE] images={len(img_files)}")
    print(f" - hw_img_dir     : {hw_img_dir}")
    print(f" - hw_pred_dir    : {hw_pred_dir}")
    print(f" - print_seed_dir : {print_seed_dir}")
    print(f" - out_dir        : {out_dir}")

    n_saved = 0
    n_skipped = 0
    for img_path in img_files:
        stem = img_path.stem

        pred_path = _find_label_file(hw_pred_dir, stem, pred_ext_priority)
        seed_path = _find_label_file(print_seed_dir, stem, seed_ext_priority)
        if pred_path is None or seed_path is None:
            print(f"[SKIP] missing pred/seed for {stem} | pred={pred_path} seed={seed_path}")
            n_skipped += 1
            continue

        # load hw image (RGB)
        bgr = imread_unicode(img_path, cv2.IMREAD_COLOR)
        if bgr is None:
            print(f"[SKIP] failed to read image: {img_path}")
            n_skipped += 1
            continue
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

        # load labels; one corrupt file should not abort the whole folder
        try:
            pred = load_label_any(pred_path)   # uint8
            seed = load_label_any(seed_path)   # uint8
        except (OSError, ValueError) as e:
            print(f"[SKIP] failed to load labels for {stem}: {e}")
            n_skipped += 1
            continue

        # align label shapes to the image (nearest neighbor keeps label ids intact)
        H, W = rgb.shape[:2]
        if pred.shape != (H, W):
            pred = cv2.resize(pred, (W, H), interpolation=cv2.INTER_NEAREST)
        if seed.shape != (H, W):
            seed = cv2.resize(seed, (W, H), interpolation=cv2.INTER_NEAREST)

        # ink mask (uses the caller's parameters)
        ink = make_ink_mask_white_on_black(
            rgb,
            thr=ink_thr,
            blur_ksize=ink_blur,
            morph_open=ink_open,
            morph_close=ink_close,
        )

        # 1) seeded refinement
        refined = refine_with_seed_distance(
            ink_mask_u8=ink,
            pred_label=pred,
            seed_label=seed,
            max_dist=max_dist,
            allow_classes=(1, 2, 3),
        )

        # 2) optional cleanup (only ever touches ink pixels)
        if do_cleanup:
            refined = cleanup_small_label_fragments_in_components(
                ink_mask_u8=ink,
                label_0_3=refined,
                min_ratio=cleanup_min_ratio,
                min_pixels=cleanup_min_pixels,
                connectivity=8,
            )

        # 3) anything that is not ink is background
        refined[ink == 0] = 0

        # 4) save outputs (labels_npy stores the final refined labels)
        save_label_outputs(
            out_root=out_dir,
            stem=stem,
            refined_label=refined,
            ink_mask_u8=ink,
            save_npy=True,
            save_masked=True,
        )
        n_saved += 1

    print(f"[REFINE] done: saved={n_saved} skipped={n_skipped}")
        


In [25]:
# Example: per-character folders. All paths derive from one results root.
RESULTS_DIR    = Path(r"D:\Study\학교강의\4학년2학기\캡스톤\Baram_Handwritting_Analysis\results\segment_results")
HW_IMG_DIR     = RESULTS_DIR / "images"                          # handwritten source images (or cropped char images)
HW_PRED_DIR    = RESULTS_DIR / "labels_npy"                      # handwritten FPN predicted labels (png/npy)
PRINT_SEED_DIR = RESULTS_DIR / "printed_chars" / "labels_npy"    # printed-font seed labels (png/npy)
OUT_DIR        = RESULTS_DIR / "clustered"                       # output

refine_folder_with_print_seeds(
    hw_img_dir=HW_IMG_DIR,
    hw_pred_dir=HW_PRED_DIR,
    print_seed_dir=PRINT_SEED_DIR,
    out_dir=OUT_DIR,

    # ink mask tuning (inputs are white ink on a black background)
    ink_thr=200,
    ink_blur=0,
    ink_open=0,
    ink_close=0,

    # seed trust radius in pixels: ink farther than this from every seed keeps the
    # FPN prediction, so a smaller value limits seed reassignment to pixels near seed strokes
    max_dist=8.0,

    # component cleanup
    do_cleanup=True,
    cleanup_min_ratio=0.03,
    cleanup_min_pixels=15,
)


[REFINE] images=7
 - hw_img_dir     : D:\Study\학교강의\4학년2학기\캡스톤\Baram_Handwritting_Analysis\results\segment_results\images
 - hw_pred_dir    : D:\Study\학교강의\4학년2학기\캡스톤\Baram_Handwritting_Analysis\results\segment_results\labels_npy
 - print_seed_dir : D:\Study\학교강의\4학년2학기\캡스톤\Baram_Handwritting_Analysis\results\segment_results\printed_chars\labels_npy
 - out_dir        : D:\Study\학교강의\4학년2학기\캡스톤\Baram_Handwritting_Analysis\results\segment_results\clustered
