In [1]:
import os
from pathlib import Path
from typing import Optional, Tuple

import cv2
import numpy as np
from tqdm import tqdm

In [2]:
def imwrite_unicode(path, img):
    """
    Encode `img` with OpenCV and write the bytes with Python's own file I/O.

    The output format is chosen from the file extension of `path`.
    NOTE(review): presumably this exists to avoid cv2.imwrite's trouble with
    non-ASCII paths -- confirm.

    Returns:
        True when the image was encoded and written, False when
        cv2.imencode reported failure (nothing is written in that case).
    """
    target = str(path)
    _, extension = os.path.splitext(target)

    success, encoded = cv2.imencode(extension, img)
    if success:
        with open(target, "wb") as stream:
            stream.write(encoded.tobytes())
        return True
    return False

def imread_unicode(path, flags=cv2.IMREAD_COLOR):
    """
    Read a file's raw bytes with Python file I/O and decode them with OpenCV.

    NOTE(review): presumably this exists to avoid cv2.imread's trouble with
    non-ASCII paths -- confirm.

    Args:
        path: file to read (converted with str()).
        flags: decode flags forwarded to cv2.imdecode.

    Returns:
        Whatever cv2.imdecode returns for the file contents, or None when any
        exception is raised while reading or decoding (the error is printed,
        not re-raised).
    """
    try:
        with open(str(path), "rb") as stream:
            encoded = np.frombuffer(stream.read(), np.uint8)
        return cv2.imdecode(encoded, flags)
    except Exception as err:
        # Best-effort reader: report the problem and let the caller check for None.
        print("[imread_unicode ERROR]", err)
        return None


In [3]:
def load_label_any(path: str | Path) -> np.ndarray:
    path = Path(path)
    if path.suffix.lower() == ".npy":
        arr = np.load(str(path))
        if arr.dtype != np.uint8:
            arr = arr.astype(np.uint8)
        return arr
    else:
        m = imread_unicode(path, cv2.IMREAD_GRAYSCALE)
        if m is None:
            raise FileNotFoundError(path)
        if m.dtype != np.uint8:
            m = m.astype(np.uint8)
        return m


In [4]:
def colorize_label(label_0_3: np.ndarray) -> np.ndarray:
    """
    Map class ids to RGB colors: 0=black, 1=red, 2=green, 3=blue.

    Ids outside 0..3 are clipped into that range before the lookup.

    Returns:
        uint8 array with the input's shape plus a trailing axis of length 3.
    """
    # Row i holds the RGB color of class i; class 0 stays all-zero (black).
    class_colors = np.zeros((4, 3), dtype=np.uint8)
    class_colors[1, 0] = 255  # red
    class_colors[2, 1] = 255  # green
    class_colors[3, 2] = 255  # blue

    class_ids = np.clip(label_0_3, 0, 3).astype(np.uint8)
    return class_colors[class_ids]


In [5]:
def make_masked_vis_white_on_black(
    refined_label: np.ndarray,   # HxW uint8 0~3
    ink_mask_u8: np.ndarray,     # HxW uint8 {0,255}
) -> np.ndarray:
    """
    Visualization for output:
    - the background stays black (0)
    - only ink pixels (mask > 0) are painted with the label palette

    Returns an HxWx3 uint8 RGB image.
    """
    height, width = refined_label.shape
    canvas = np.zeros((height, width, 3), dtype=np.uint8)

    is_ink = ink_mask_u8 > 0
    # Copy palette colors onto the black canvas only where there is ink.
    canvas[is_ink] = colorize_label(refined_label)[is_ink]
    return canvas


In [6]:
def save_label_outputs(
    out_root: Path,
    stem: str,
    refined_label: np.ndarray,
    ink_mask_u8: np.ndarray,
    save_npy: bool = True,
    save_masked: bool = True,
):
    """
    Write the refined label map for one image under `out_root`.

    Outputs (all named after `stem`):
    - labels_png/<stem>.png : the label map as uint8 (always written)
    - labels_npy/<stem>.npy : the same array via np.save (if `save_npy`)
    - masked/<stem>.png     : color visualization restricted to ink pixels
                              (if `save_masked`)

    All three sub-directories are created regardless of the flags.
    A PNG that fails to encode is reported on stdout instead of being
    silently dropped; the function still returns None and does not raise.
    """
    out_root = Path(out_root)
    png_dir = out_root / "labels_png"
    npy_dir = out_root / "labels_npy"
    masked_dir = out_root / "masked"
    for directory in (png_dir, npy_dir, masked_dir):
        directory.mkdir(parents=True, exist_ok=True)

    refined_u8 = refined_label.astype(np.uint8)

    # label png (0~3)
    label_png = png_dir / f"{stem}.png"
    if not imwrite_unicode(label_png, refined_u8):
        print("[save_label_outputs ERROR] could not write", label_png)

    # label npy
    if save_npy:
        np.save(str(npy_dir / f"{stem}.npy"), refined_u8)

    # masked visualization (clustered/cleaned label only)
    if save_masked:
        vis_rgb = make_masked_vis_white_on_black(refined_u8, ink_mask_u8)
        masked_png = masked_dir / f"{stem}.png"
        # The visualization is RGB; convert to BGR before encoding with OpenCV.
        if not imwrite_unicode(masked_png, cv2.cvtColor(vis_rgb, cv2.COLOR_RGB2BGR)):
            print("[save_label_outputs ERROR] could not write", masked_png)


In [7]:
def make_ink_mask_white_on_black(
    img_rgb: np.ndarray,
    thr: int = 160,
    blur_ksize: int = 0,
    morph_open: int = 0,
    morph_close: int = 0,
) -> np.ndarray:
    """
    Build a binary ink mask from an RGB image.

    Assumption: black background with white ink (ink is the bright part).

    Args:
        img_rgb: RGB image.
        thr: pixels whose gray value is strictly greater than this are ink.
        blur_ksize: Gaussian blur kernel size applied before thresholding;
            values <= 1 disable the blur. Even values are rounded up to the
            next odd number because cv2.GaussianBlur only accepts odd sizes.
        morph_open: elliptical kernel size for morphological opening of the
            mask; values <= 1 disable it.
        morph_close: elliptical kernel size for morphological closing of the
            mask (applied after opening); values <= 1 disable it.

    Returns:
        ink_mask_u8: uint8 mask with values 0 or 255.
    """
    gray = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2GRAY)

    if blur_ksize and blur_ksize > 1:
        # `| 1` turns an even size into the next odd one and leaves odd sizes alone.
        ksize = int(blur_ksize) | 1
        gray = cv2.GaussianBlur(gray, (ksize, ksize), 0)

    # white ink on black bg
    ink = (gray > thr).astype(np.uint8) * 255

    # Opening first, then closing -- same order as before, one shared code path.
    for size, operation in ((morph_open, cv2.MORPH_OPEN), (morph_close, cv2.MORPH_CLOSE)):
        if size and size > 1:
            kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (size, size))
            ink = cv2.morphologyEx(ink, operation, kernel)

    return ink



In [8]:
def suppress_minor_labels_in_components(
    ink_mask_u8: np.ndarray,
    label_0_3: np.ndarray,
    allow_classes: Tuple[int, ...] = (1, 2, 3),
    min_ratio: float = 0.06,
    min_pixels: int = 20,
    connectivity: int = 8,
) -> np.ndarray:
    """
    Per connected ink component, absorb very sparse labels into the
    component's majority label.

    For each connected component of the ink mask, the pixels carrying each
    class in `allow_classes` are counted. The class with the highest count is
    the majority (ties go to the class listed first). Any other class is
    relabelled to the majority when its pixel count is below `min_pixels`
    OR its share of the allowed-class pixels is below `min_ratio`.

    Args:
        ink_mask_u8: mask, > 0 means ink; must have the same shape as `label_0_3`.
        label_0_3: label map; it is copied, not modified.
        allow_classes: class ids taking part in the majority vote.
        min_ratio: minimum share a non-majority class needs to survive.
        min_pixels: minimum pixel count a non-majority class needs to survive.
        connectivity: 8 selects 8-connectivity; any other value means 4.

    Returns:
        uint8 label map with everything outside the ink mask set to 0.
    """
    assert ink_mask_u8.shape == label_0_3.shape
    is_ink = ink_mask_u8 > 0
    is_background = ~is_ink

    cleaned = label_0_3.copy().astype(np.uint8)
    cleaned[is_background] = 0

    neighbourhood = 8 if connectivity == 8 else 4
    n_components, component_map = cv2.connectedComponents(
        is_ink.astype(np.uint8), connectivity=neighbourhood
    )

    # Component id 0 is the background, so real components start at 1.
    for component_id in range(1, n_components):
        in_component = component_map == component_id
        component_labels = cleaned[in_component]  # copy of this component's labels

        pixel_counts = {
            cls: int(np.count_nonzero(component_labels == cls)) for cls in allow_classes
        }
        n_allowed = sum(pixel_counts[cls] for cls in allow_classes)
        if n_allowed == 0:
            continue

        dominant = max(pixel_counts, key=pixel_counts.get)

        for cls in allow_classes:
            n_cls = pixel_counts[cls]
            if cls == dominant or n_cls == 0:
                continue
            too_few = n_cls < min_pixels
            too_rare = n_cls / float(n_allowed) < min_ratio
            if too_few or too_rare:
                component_labels[component_labels == cls] = dominant

        # Write the (possibly relabelled) pixels back into the full map.
        cleaned[in_component] = component_labels

    cleaned[is_background] = 0
    return cleaned


In [18]:
def refine_folder_suppress_minor_only(
    hw_img_dir: str | Path,
    hw_pred_dir: str | Path,
    out_dir: str | Path,
    file_exts_img=(".png", ".jpg", ".jpeg", ".bmp"),
    pred_ext_priority=(".npy", ".png"),

    # ink mask params
    ink_thr: int = 160,
    ink_blur: int = 0,
    ink_open: int = 0,
    ink_close: int = 0,

    # suppress-minor params
    do_suppress_minor: bool = True,
    suppress_min_ratio: float = 0.06,
    suppress_min_pixels: int = 20,
    suppress_connectivity: int = 8,

    # Originally described as "roll back if a label vanishes completely after
    # suppression". NOTE(review): this flag is not read anywhere in the body;
    # the restore step is controlled by `preserve_disappeared_labels`.
    rollback_if_label_disappears: bool = True,
    protect_classes: Tuple[int, ...] = (1, 2, 3),   # usually protect 1~3

    # save options
    save_npy: bool = True,
    save_masked: bool = True,

    preserve_disappeared_labels=True,

):
    """
    Refine every predicted label map in a folder and save the results.

    For each image in `hw_img_dir` that has a prediction of the same stem in
    `hw_pred_dir`:
      1. build an ink mask from the image (bright ink on dark background),
      2. force the prediction to 0 outside the ink,
      3. optionally absorb minority labels inside each connected ink component,
      4. optionally restore the original pixels of any protected class that
         the suppression step wiped out completely,
      5. save the result with `save_label_outputs`.

    Images without a prediction file, or that cannot be decoded, are skipped;
    the number of skipped images is reported before the final "Done.".

    Args:
        hw_img_dir: folder with the handwriting images.
        hw_pred_dir: folder with predicted label files named `<stem><ext>`.
        out_dir: output root (created if missing).
        file_exts_img: lowercase image suffixes to pick up from `hw_img_dir`.
        pred_ext_priority: prediction suffixes to try, in order.
        ink_thr, ink_blur, ink_open, ink_close: forwarded to
            `make_ink_mask_white_on_black` as thr / blur_ksize / morph_open /
            morph_close.
        do_suppress_minor: enable steps 3 and 4.
        suppress_min_ratio, suppress_min_pixels, suppress_connectivity:
            forwarded to `suppress_minor_labels_in_components`.
        rollback_if_label_disappears: accepted for compatibility; currently
            has no effect (see the note in the signature).
        protect_classes: class ids passed to the suppression step as its
            allowed classes, and also the ids checked for disappearance.
        save_npy, save_masked: forwarded to `save_label_outputs`.
        preserve_disappeared_labels: enable step 4.

    Raises:
        RuntimeError: when `hw_img_dir` contains no matching image files.
        FileNotFoundError: propagated from `load_label_any` when a prediction
            image cannot be decoded.
    """
    hw_img_dir = Path(hw_img_dir)
    hw_pred_dir = Path(hw_pred_dir)
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    img_files = sorted(
        p for p in hw_img_dir.iterdir()
        if p.is_file() and p.suffix.lower() in file_exts_img
    )
    if not img_files:
        raise RuntimeError(f"No images in: {hw_img_dir}")

    def find_label_file(label_dir: Path, stem: str, priority=(".npy", ".png")) -> Optional[Path]:
        """Return the first existing `<stem><ext>` in `label_dir`, trying `priority` in order."""
        for ext in priority:
            candidate = label_dir / f"{stem}{ext}"
            if candidate.exists():
                return candidate
        return None

    def protected_labels_present(label_map: np.ndarray, region: np.ndarray) -> set:
        """Return the protected class ids that occur at least once inside `region`."""
        return {int(v) for v in np.unique(label_map[region]) if int(v) in protect_classes}

    print(f"[SUPPRESS_ONLY] images={len(img_files)}")
    print(f" - hw_img_dir  : {hw_img_dir}")
    print(f" - hw_pred_dir : {hw_pred_dir}")
    print(f" - out_dir     : {out_dir}")
    print(f" - ink_thr     : {ink_thr} (black bg + white ink)")
    print(f" - suppress    : {do_suppress_minor} (ratio={suppress_min_ratio}, px={suppress_min_pixels}, conn={suppress_connectivity})")

    # Names of images that produced no output, kept so the run can report them.
    missing_pred = []
    unreadable = []

    for img_path in tqdm(img_files, desc="suppress_minor_only", leave=True):
        stem = img_path.stem

        pred_path = find_label_file(hw_pred_dir, stem, pred_ext_priority)
        if pred_path is None:
            missing_pred.append(img_path.name)
            continue

        # load hw image
        bgr = imread_unicode(img_path, cv2.IMREAD_COLOR)
        if bgr is None:
            unreadable.append(img_path.name)
            continue
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

        # load pred label
        pred = load_label_any(pred_path)

        # match size (nearest-neighbour keeps class ids intact)
        H, W = rgb.shape[:2]
        if pred.shape != (H, W):
            pred = cv2.resize(pred, (W, H), interpolation=cv2.INTER_NEAREST)

        # ink mask (white ink on black bg)
        ink = make_ink_mask_white_on_black(
            rgb,
            thr=ink_thr,
            blur_ksize=ink_blur,
            morph_open=ink_open,
            morph_close=ink_close,
        )
        ink_bool = ink > 0

        # enforce background outside ink (astype already returns a fresh copy)
        refined = pred.astype(np.uint8)
        refined[~ink_bool] = 0

        # suppress-minor inside components
        if do_suppress_minor:
            # (A) protected labels present before suppression (inside ink only)
            before_labels = protected_labels_present(refined, ink_bool)

            # (B) run suppression
            suppressed = suppress_minor_labels_in_components(
                ink_mask_u8=ink,
                label_0_3=refined,
                allow_classes=protect_classes,
                min_ratio=suppress_min_ratio,
                min_pixels=suppress_min_pixels,
                connectivity=suppress_connectivity,
            )
            suppressed[~ink_bool] = 0

            # (C) protected labels present after suppression
            after_labels = protected_labels_present(suppressed, ink_bool)

            disappeared = sorted(before_labels - after_labels)

            # (D) if a label vanished completely, restore only that label's
            #     pixels from the pre-suppression map.
            # NOTE: the goal is only to keep the label's pixel count above zero.
            #       If very few pixels are restored (e.g. 1~2) this may be
            #       visually meaningless; a minimum-restore threshold could be
            #       added if desired.
            if preserve_disappeared_labels:
                for k in disappeared:
                    # The mask is limited to ink pixels, so the background stays 0.
                    suppressed[ink_bool & (refined == k)] = k

            refined = suppressed

        # save
        save_label_outputs(
            out_root=out_dir,
            stem=stem,
            refined_label=refined,
            ink_mask_u8=ink,
            save_npy=save_npy,
            save_masked=save_masked,
        )

    # Surface skipped inputs so an incomplete run is not mistaken for a full one.
    n_skipped = len(missing_pred) + len(unreadable)
    if n_skipped:
        print(f"[SUPPRESS_ONLY] skipped {n_skipped}/{len(img_files)} images")
        if missing_pred:
            print(f" - no prediction file : {len(missing_pred)} (e.g. {missing_pred[0]})")
        if unreadable:
            print(f" - unreadable image   : {len(unreadable)} (e.g. {unreadable[0]})")

    print("Done.")


In [28]:
# Run configuration: input/output folders for the suppress-minor refinement.
# NOTE(review): these are absolute, machine-specific paths; consider deriving
# them from a single configurable base directory.
HW_IMG_DIR     = Path(r"D:\Study\학교강의\4학년2학기\캡스톤\Baram_Handwritting_Analysis\results\segment_results\images")  # original handwriting images (or cropped character images)
HW_PRED_DIR    = Path(r"D:\Study\학교강의\4학년2학기\캡스톤\Baram_Handwritting_Analysis\results\segment_results\clustered\labels_npy")  # FPN-predicted labels for the handwriting (png/npy)
OUT_DIR        = Path(r"D:\Study\학교강의\4학년2학기\캡스톤\Baram_Handwritting_Analysis\results\segment_results\suppressed")  # output


# Process every image in HW_IMG_DIR and write the refined labels to OUT_DIR.
refine_folder_suppress_minor_only(
    hw_img_dir=HW_IMG_DIR,
    hw_pred_dir=HW_PRED_DIR,
    out_dir=OUT_DIR,

    # ink mask (black background + white ink)
    ink_thr=160,
    ink_blur=0,
    ink_open=0,
    ink_close=0,

    # absorb minority labels within each connected component
    # (min ratio 0.3 here overrides the function default of 0.06)
    do_suppress_minor=True,
    suppress_min_ratio=0.3,
    suppress_min_pixels=20,
    suppress_connectivity=8,

    save_npy=True,
    save_masked=True,
)


[SUPPRESS_ONLY] images=7
 - hw_img_dir  : D:\Study\학교강의\4학년2학기\캡스톤\Baram_Handwritting_Analysis\results\segment_results\images
 - hw_pred_dir : D:\Study\학교강의\4학년2학기\캡스톤\Baram_Handwritting_Analysis\results\segment_results\clustered\labels_npy
 - out_dir     : D:\Study\학교강의\4학년2학기\캡스톤\Baram_Handwritting_Analysis\results\segment_results\suppressed
 - ink_thr     : 160 (black bg + white ink)
 - suppress    : True (ratio=0.3, px=20, conn=8)


suppress_minor_only: 100%|███████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 56.53it/s]

Done.



