In [27]:
import os
import time
from collections import OrderedDict

import cv2
import numpy as np
from PIL import Image
import math
import re

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.autograd import Variable

import sys
from pathlib import Path

def setup_project_path(marker: str = 'craft', start=None):
    """Locate the project root by walking up from a starting directory.

    The project root is the first directory, checking the starting directory
    itself first, that contains an entry named ``marker``.

    Args:
        marker: Name of the file/directory that identifies the root. Defaults
            to ``'craft'``, the package this notebook imports from.
        start: Directory to start from. Defaults to ``Path.cwd()``.

    Returns:
        pathlib.Path of the project root.

    Raises:
        FileNotFoundError: If neither the start directory nor any ancestor
            contains ``marker``. A bare ``while`` walk would loop forever here,
            because the parent of the filesystem root is the root itself.
    """
    current = Path.cwd() if start is None else Path(start).resolve()
    for candidate in (current, *current.parents):
        if (candidate / marker).exists():
            return candidate
    raise FileNotFoundError(
        f"Could not find {marker!r} in {current} or any of its parent directories."
    )
project_root = setup_project_path()
# Put the project root first on sys.path so `craft.common...` resolves to this
# checkout. Remove earlier copies first so re-running this cell does not keep
# prepending duplicate entries (the cell is then idempotent).
while str(project_root) in sys.path:
    sys.path.remove(str(project_root))
sys.path.insert(0, str(project_root))

from craft.common import craft
from craft.common import craft_utils
from craft.common import file_utils
from craft.common import imgproc
from craft.common.craft import CRAFT
from craft.common.refinenet import RefineNet

In [2]:
def imwrite_unicode(path, img):
    """Encode ``img`` in memory and write the bytes to ``path`` with Python I/O.

    The image format is picked by ``cv2.imencode`` from the extension of
    ``path``. Writing through ``open`` (instead of ``cv2.imwrite``) is
    presumably meant to support non-ASCII paths, e.g. on Windows.

    Returns:
        True if the image was encoded and written; False if encoding failed,
        in which case no file is created.
    """
    _, extension = os.path.splitext(path)
    success, encoded = cv2.imencode(extension, img)
    if success:
        with open(path, "wb") as handle:
            handle.write(encoded.tobytes())
        return True
    return False

In [3]:
def copyStateDict(state_dict):
    """Return a copy of ``state_dict`` with a leading ``module.`` prefix removed.

    Checkpoints saved from a ``torch.nn.DataParallel`` wrapper name their
    parameters ``module.<name>``; dropping that first component lets them load
    into an unwrapped model. As before, the decision to strip is made from the
    first key only.

    Args:
        state_dict: Mapping of parameter name -> value. Values are passed
            through unchanged.

    Returns:
        OrderedDict with the same values, in the same order, under the
        normalized keys. An empty input yields an empty OrderedDict (instead
        of raising IndexError).
    """
    if not state_dict:
        return OrderedDict()

    first_key = next(iter(state_dict))
    # Match the DataParallel prefix "module." exactly; a key such as
    # "modules.0.weight" starts with "module" but its first component is a
    # genuine parameter path segment and must be kept.
    start_idx = 1 if first_key.startswith("module.") else 0

    return OrderedDict(
        (".".join(k.split(".")[start_idx:]), v) for k, v in state_dict.items()
    )


In [4]:
def load_craft_model(
    trained_model: str = "weights/craft_mlt_25k.pth",
    use_cuda: bool = True,
    use_refiner: bool = False,
    refiner_model: str = "weights/craft_refiner_CTW1500.pth",
):
    """Build the CRAFT detector (and optionally a RefineNet) and load their weights.

    Args:
        trained_model: Path to the CRAFT checkpoint.
        use_cuda: Request GPU execution. It is downgraded to CPU, with a
            printed warning, when ``torch.cuda.is_available()`` is False.
        use_refiner: Also build and load a ``RefineNet``.
        refiner_model: Path to the RefineNet checkpoint. It is only read when
            ``use_refiner`` is True.

    Returns:
        ``(net, refine_net, use_cuda)``:
        - ``net``: the CRAFT model in eval mode, wrapped in ``DataParallel``
          on GPU.
        - ``refine_net``: the RefineNet, or ``None``.
        - ``use_cuda``: the effective CUDA flag after the availability check.
    """
    if use_cuda and not torch.cuda.is_available():
        print("[WARN] CUDA is not available, fallback to CPU.")
        use_cuda = False

    # On CPU, remap tensors to the CPU while loading; on GPU keep torch.load's
    # default placement.
    load_kwargs = {} if use_cuda else {"map_location": "cpu"}

    def _restore_weights(model, checkpoint):
        model.load_state_dict(copyStateDict(torch.load(checkpoint, **load_kwargs)))

    def _place_on_device(model):
        if not use_cuda:
            return model
        return torch.nn.DataParallel(model.cuda())

    net = CRAFT()
    print(f"Loading CRAFT weights from checkpoint: {trained_model}")
    _restore_weights(net, trained_model)
    net = _place_on_device(net)
    if use_cuda:
        cudnn.benchmark = False
    net.eval()

    if not use_refiner:
        return net, None, use_cuda

    refine_net = RefineNet()
    print(f"Loading Refiner weights from checkpoint: {refiner_model}")
    _restore_weights(refine_net, refiner_model)
    refine_net = _place_on_device(refine_net)
    refine_net.eval()

    return net, refine_net, use_cuda


In [5]:
def extract_centers_from_score_text(
    score_text: np.ndarray,
    thr: float = 0.7,
    min_area: int = 3,
    close_kernel_size: int = 5,
    close_iterations: int = 2,
):
    """Find per-character centers as centroids of blobs in a region-score map.

    Pipeline:
      1. Binarize: pixels strictly greater than ``thr`` become foreground
         (``cv2.THRESH_BINARY`` semantics).
      2. Apply a morphological close to bridge small gaps inside a blob.
      3. Label 8-connected components and keep the centroid of every
         component whose area after closing is at least ``min_area``.

    Args:
        score_text: Score map, cast to float32 before thresholding. It is
            assumed to be 2-D, as produced by ``detect_text``.
        thr: Binarization threshold.
        min_area: Minimum component area, in pixels, to keep.
        close_kernel_size: Side length of the square structuring element used
            for closing. Formerly hard-coded to 5.
        close_iterations: Number of closing iterations. Formerly hard-coded
            to 2. Closing is skipped when this or ``close_kernel_size`` is <= 0.

    Returns:
        List of ``(cx, cy)`` centroids in the pixel coordinates of
        ``score_text``, in OpenCV label order (not sorted by x).
    """
    _, binary = cv2.threshold(
        score_text.astype(np.float32),
        thr,
        1.0,
        cv2.THRESH_BINARY,
    )
    binary = (binary * 255).astype(np.uint8)

    if close_kernel_size > 0 and close_iterations > 0:
        kernel = np.ones((close_kernel_size, close_kernel_size), np.uint8)
        binary = cv2.morphologyEx(
            binary, cv2.MORPH_CLOSE, kernel, iterations=close_iterations
        )

    num_labels, _, stats, centroids = cv2.connectedComponentsWithStats(
        binary, connectivity=8
    )

    # Label 0 is the background component, so only labels 1.. are characters.
    return [
        (cx, cy)
        for (cx, cy), area in zip(
            centroids[1:num_labels], stats[1:num_labels, cv2.CC_STAT_AREA]
        )
        if area >= min_area
    ]

In [6]:
def compute_center_angles(centers):
    """Return the angle in degrees between each pair of x-adjacent centers.

    The centers are sorted by x first. For each consecutive pair the angle is
    ``atan2(-(dy), dx)``. The y difference is negated so that a positive
    angle means the next center sits higher on the page, assuming image
    coordinates where y grows downward.

    Returns:
        A list of ``len(centers) - 1`` floats, or ``[]`` for fewer than two
        centers.
    """
    if len(centers) < 2:
        return []

    ordered = sorted(centers, key=lambda pt: pt[0])
    return [
        math.atan2(-(y_next - y_prev), x_next - x_prev) * 180.0 / math.pi
        for (x_prev, y_prev), (x_next, y_next) in zip(ordered, ordered[1:])
    ]

In [7]:
def _auto_merge_dist(centers, ratio=0.45, min_px=15, max_px=120):
    """Estimate a merge distance from the typical horizontal spacing of centers.

    The estimate is the upper median of the positive gaps between x-sorted
    center x-coordinates, scaled by ``ratio``. It is clamped to
    ``[min_px, max_px]`` and returned as a float.

    If there are fewer than two centers, or no positive gap, ``min_px`` is
    returned unchanged (not converted to float).
    """
    if len(centers) < 2:
        return min_px

    xs = sorted(pt[0] for pt in centers)
    gaps = sorted(right - left for left, right in zip(xs, xs[1:]) if right - left > 1e-6)
    if not gaps:
        return min_px

    median_gap = gaps[len(gaps) // 2]
    scaled = median_gap * ratio
    return float(max(min_px, min(max_px, scaled)))

In [8]:
def merge_centers_by_gap_to_expected(centers, expected_count: int):
    """Collapse x-sorted centers into ``expected_count`` groups by cutting at the widest x-gaps.

    Cases:
      * ``expected_count <= 0``: ``centers`` is returned as-is (same object,
        unsorted).
      * ``len(centers) <= expected_count``: a new list sorted by x.
      * ``expected_count == 1``: a single point, the mean of all centers.
      * Otherwise: the ``expected_count - 1`` largest gaps between x-adjacent
        centers become split points. Among equal gaps the leftmost wins,
        because the sort is stable. Each resulting group is replaced by its
        mean point.

    Returns:
        List of ``(x, y)`` points ordered by x, except in the first case.
    """
    if expected_count <= 0:
        return centers

    ordered = sorted(centers, key=lambda pt: pt[0])
    if len(ordered) <= expected_count:
        return ordered

    def mean_point(group):
        size = len(group)
        return (sum(pt[0] for pt in group) / size, sum(pt[1] for pt in group) / size)

    splits_needed = expected_count - 1
    if splits_needed <= 0:
        return [mean_point(ordered)]

    # Gap i separates ordered[i] from ordered[i + 1].
    gaps = [right[0] - left[0] for left, right in zip(ordered, ordered[1:])]
    widest_first = sorted(range(len(gaps)), key=lambda idx: gaps[idx], reverse=True)
    cut_points = sorted(widest_first[:splits_needed])

    edges = [0] + [cut + 1 for cut in cut_points] + [len(ordered)]
    return [mean_point(ordered[lo:hi]) for lo, hi in zip(edges, edges[1:])]


In [9]:
def _score_text_to_original(score_text, resized_shape, target_ratio, orig_h, orig_w):
    """Crop the score map to the part covering the real image and resize it to ``(orig_w, orig_h)``.

    The network input (shape ``resized_shape``) may be larger than the image
    scaled by ``target_ratio``, i.e. padded. The score map covers that whole
    input at its own resolution, so the valid fraction is computed per axis
    and cropped before resizing back to the original image size.
    """
    pad_h, pad_w = resized_shape[:2]
    valid_w = int(round(orig_w * target_ratio))
    valid_h = int(round(orig_h * target_ratio))

    heat_h, heat_w = score_text.shape[:2]
    # Clamp to at least one row/column and at most the full map.
    valid_heat_w = max(1, min(int(round(heat_w * (valid_w / pad_w))), heat_w))
    valid_heat_h = max(1, min(int(round(heat_h * (valid_h / pad_h))), heat_h))

    score_text_valid = score_text[:valid_heat_h, :valid_heat_w]
    return cv2.resize(
        score_text_valid,
        (orig_w, orig_h),
        interpolation=cv2.INTER_LINEAR,
    )


def _normalize_centers(score_text_orig, expected_char_count, center_thr, center_min_area, merge_centers):
    """Extract character centers and reduce them to exactly ``expected_char_count`` points, sorted by x.

    Raises:
        ValueError: If fewer centers than expected are detected, or if the
            final count does not match (e.g. surplus centers when
            ``merge_centers`` is False).
    """
    centers_raw = extract_centers_from_score_text(
        score_text_orig,
        thr=center_thr,
        min_area=center_min_area,
    )

    if len(centers_raw) < expected_char_count:
        raise ValueError(
            f"center 미검출: expected={expected_char_count}, detected_raw={len(centers_raw)}"
        )

    if merge_centers and len(centers_raw) > expected_char_count:
        centers = merge_centers_by_gap_to_expected(centers_raw, expected_char_count)
    else:
        centers = sorted(centers_raw, key=lambda p: p[0])

    if len(centers) != expected_char_count:
        raise ValueError(
            f"center 정규화 실패: expected={expected_char_count}, detected={len(centers)}"
        )
    return centers


def detect_text(
    net,
    image,
    text_threshold=0.7,
    link_threshold=0.4,
    low_text=0.4,
    use_cuda=True,
    poly=False,
    refine_net=None,
    canvas_size=1280,
    mag_ratio=1.5,
    show_time=False,
    center_thr=0.7,
    center_min_area=3,
    expected_char_count=None,
    merge_centers: bool = True,
    merge_ratio: float = 0.45,
    merge_min_px: int = 18,
    merge_max_px: int = 140,
):
    """Run CRAFT on one image and return boxes, a heatmap, and character centers.

    Args:
        net: CRAFT model called as ``net(x)`` and returning ``(y, feature)``.
            ``y[0, :, :, 0]`` is used as the text score and ``y[0, :, :, 1]``
            as the link score.
        image: ``np.ndarray`` of shape (H, W, C) or a ``PIL.Image``
            (converted with ``np.array``). Presumably RGB as produced by
            ``imgproc.loadImage``; TODO confirm.
        text_threshold, link_threshold, low_text, poly: Forwarded to
            ``craft_utils.getDetBoxes``.
        use_cuda: Move the input tensor to the GPU before inference.
        refine_net: Optional RefineNet. When given, its channel 0 replaces
            the link score.
        canvas_size, mag_ratio: Forwarded to ``imgproc.resize_aspect_ratio``.
        show_time: Print inference/post-processing timings.
        center_thr, center_min_area: Forwarded to
            ``extract_centers_from_score_text``.
        expected_char_count: Required. Number of characters the image is known
            to contain.
        merge_centers: When True (default) and more centers than expected are
            found, centers are merged across the widest x-gaps. When False,
            any surplus raises the "center 정규화 실패" ValueError.
        merge_ratio, merge_min_px, merge_max_px: Currently unused; kept for
            API compatibility.

    Returns:
        ``(boxes, polys, score_heatmap, centers, score_text_orig)``:
        - ``boxes``, ``polys``: box/polygon detections in original-image
          coordinates.
        - ``score_heatmap``: a heatmap of the text and link scores side by side.
        - ``centers``: ``expected_char_count`` points sorted by x.
        - ``score_text_orig``: the text score resized to the original image
          size.

    Raises:
        ValueError: If ``expected_char_count`` is None (checked before running
            the network), or if the detected centers cannot be normalized to
            that count.
    """
    # Validate up front so a missing count does not cost a full forward pass.
    if expected_char_count is None:
        raise ValueError("expected_char_count must be provided (not None)")

    t0 = time.time()

    if isinstance(image, Image.Image):
        image = np.array(image)

    H0, W0 = image.shape[:2]

    img_resized, target_ratio, _ = imgproc.resize_aspect_ratio(
        image,
        canvas_size,
        interpolation=cv2.INTER_LINEAR,
        mag_ratio=mag_ratio,
    )
    ratio_h = ratio_w = 1 / target_ratio

    # HWC -> 1xCxHxW. torch.autograd.Variable is a no-op wrapper on modern
    # PyTorch, so the plain tensor is used directly.
    x = torch.from_numpy(imgproc.normalizeMeanVariance(img_resized))
    x = x.permute(2, 0, 1).unsqueeze(0)
    if use_cuda:
        x = x.cuda()

    with torch.no_grad():
        y, feature = net(x)
        y_refiner = refine_net(y, feature) if refine_net is not None else None

    score_text = y[0, :, :, 0].detach().cpu().numpy()
    if y_refiner is not None:
        score_link = y_refiner[0, :, :, 0].detach().cpu().numpy()
    else:
        score_link = y[0, :, :, 1].detach().cpu().numpy()

    t0 = time.time() - t0
    t1 = time.time()

    boxes, polys = craft_utils.getDetBoxes(
        score_text, score_link, text_threshold, link_threshold, low_text, poly
    )
    boxes = craft_utils.adjustResultCoordinates(boxes, ratio_w, ratio_h)
    polys = craft_utils.adjustResultCoordinates(polys, ratio_w, ratio_h)
    # Entries without a polygon are None; fall back to the matching box.
    for k in range(len(polys)):
        if polys[k] is None:
            polys[k] = boxes[k]

    t1 = time.time() - t1

    score_text_orig = _score_text_to_original(
        score_text, img_resized.shape, target_ratio, H0, W0
    )
    centers = _normalize_centers(
        score_text_orig,
        expected_char_count,
        center_thr,
        center_min_area,
        merge_centers,
    )

    # np.hstack already returns a new array, so the inputs are not modified.
    score_heatmap = imgproc.cvt2HeatmapImg(np.hstack((score_text, score_link)))

    if show_time:
        print("infer/postproc time : {:.3f}/{:.3f}".format(t0, t1))

    return boxes, polys, score_heatmap, centers, score_text_orig


In [10]:
def run_craft_on_folder(
    net,
    refine_net,
    use_cuda: bool,
    test_folder: str,
    result_folder: str,
    score_text_folder: str,
    text_threshold: float = 0.7,
    link_threshold: float = 0.4,
    low_text: float = 0.4,
    canvas_size: int = 1280,
    mag_ratio: float = 1.5,
    poly: bool = False,
    show_time: bool = False,
    center_thr: float = 0.7,
    center_min_area: int = 3,
    expected_map: dict | None = None,
    save_mask: bool = False,
    save_polys: bool = False,
    save_score_text: bool = True,
    skip_errors: bool = False,
):
    """Run ``detect_text`` on every image in ``test_folder`` and save per-image results.

    For an image named ``<name>.<ext>`` this writes:
      * ``result_folder/res_<name>_center.txt``: one ``cx,cy`` line per center,
        with 2 decimals.
      * ``result_folder/res_<name>_mask.jpg``: the score heatmap, if
        ``save_mask``.
      * Polygon output via ``file_utils.saveResult``, if ``save_polys``.
      * ``score_text_folder/res_<name>_score_text.npy``: the float32 text-score
        map at original image size, if ``save_score_text``. This flag used to
        be ignored, so the file was always written.

    Args:
        expected_map: Maps file stem -> ground-truth text. The expected
            character count is the number of non-whitespace characters.
            Images not in the map get ``None``, which ``detect_text`` rejects
            with ValueError.
        skip_errors: When True, an image whose ``detect_text`` call raises
            ValueError is reported and skipped instead of aborting the run.
            Defaults to False, which re-raises as before.
        Remaining detection arguments are forwarded to ``detect_text``.
    """
    os.makedirs(result_folder, exist_ok=True)
    if save_score_text:
        os.makedirs(score_text_folder, exist_ok=True)

    image_list, _, _ = file_utils.get_files(test_folder)
    if len(image_list) == 0:
        print(f"[WARN] No images found in folder: {test_folder}")
        return

    t_start = time.time()
    # detect_text prints a timing line when show_time is set. A "\r"-terminated
    # progress line would then be partially overwritten, garbling both lines.
    progress_end = "\n" if show_time else "\r"
    skipped = []

    for k, image_path in enumerate(image_list):
        print(f"Test image {k+1}/{len(image_list)}: {image_path}", end=progress_end)

        image = imgproc.loadImage(image_path)
        filename, _ = os.path.splitext(os.path.basename(image_path))

        local_expected = None
        if expected_map is not None and filename in expected_map:
            # Count characters, ignoring any whitespace (not only " ").
            local_expected = sum(1 for ch in expected_map[filename] if not ch.isspace())

        try:
            _, polys, score_heatmap, centers, score_text_up = detect_text(
                net=net,
                image=image,
                text_threshold=text_threshold,
                link_threshold=link_threshold,
                low_text=low_text,
                use_cuda=use_cuda,
                poly=poly,
                refine_net=refine_net,
                canvas_size=canvas_size,
                mag_ratio=mag_ratio,
                show_time=show_time,
                center_thr=center_thr,
                center_min_area=center_min_area,
                expected_char_count=local_expected,
                merge_centers=True,
            )
        except ValueError as exc:
            if not skip_errors:
                raise
            print(f"\n[WARN] Skipping {image_path}: {exc}")
            skipped.append(image_path)
            continue

        if save_mask:
            mask_file = os.path.join(result_folder, f"res_{filename}_mask.jpg")
            imwrite_unicode(mask_file, score_heatmap)

        centers_file = os.path.join(result_folder, f"res_{filename}_center.txt")
        with open(centers_file, "w", encoding="utf-8") as f:
            f.writelines(f"{cx:.2f},{cy:.2f}\n" for cx, cy in centers)

        if save_polys:
            # NOTE(review): pass a str ending in a path separator in case
            # saveResult builds output paths by string concatenation
            # (dirname + "res_..."); a Path object would fail there. Confirm
            # against craft.common.file_utils.
            file_utils.saveResult(
                image_path,
                image[:, :, ::-1],
                polys,
                dirname=os.path.join(str(result_folder), ""),
            )

        if save_score_text:
            score_file = os.path.join(score_text_folder, f"res_{filename}_score_text.npy")
            np.save(score_file, score_text_up.astype(np.float32))

    print("\nElapsed time : {:.3f}s".format(time.time() - t_start))
    if skipped:
        print(f"[WARN] Skipped {len(skipped)} image(s): {skipped}")


In [11]:
# Load the CRAFT detector; the refiner is disabled, so refine_net will be None.
# use_cuda is reassigned to the effective flag (False when no GPU is present).
net, refine_net, use_cuda = load_craft_model(
    trained_model=project_root.joinpath("craft", "common", "weights", "craft_mlt_25k.pth"),
    use_cuda=True,
    use_refiner=False,
    refiner_model=project_root.joinpath("craft", "common", "weights", "craft_refiner_CTW1500.pth"),
)

[WARN] CUDA is not available, fallback to CPU.




Loading CRAFT weights from checkpoint: D:\Study\학교강의\4학년2학기\캡스톤\Baram_Handwritting_Analysis\craft\common\weights\craft_mlt_25k.pth


In [12]:
# Detection hyper-parameters, unpacked into run_craft_on_folder(**hp).
hp = dict(
    # Box detection thresholds, forwarded to craft_utils.getDetBoxes.
    text_threshold=0.5,
    link_threshold=0.1,
    low_text=0.2,
    # Input resizing, forwarded to imgproc.resize_aspect_ratio.
    canvas_size=1280,
    mag_ratio=1.5,
    poly=False,
    show_time=True,
    # Center extraction: binarization threshold and minimum blob area in pixels.
    center_thr=0.4,
    center_min_area=30,
)

In [13]:
# Ground-truth text for each image, keyed by file stem. run_craft_on_folder
# derives the expected character count from it, with spaces not counted.
expected_map = dict(
    test1="바른글씨",
    test2="캡스톤디자인",
    test3="숭실대학교",
    test4="기말 시험",
    test5="소프트웨어 분석",
    test6="안녕하세요",
    test7="숭실대학교",
    test8="안녕하세요",
)

# Batch run over the normalized test images. Centers go to results/center,
# raw text-score maps go to results/score_text.
run_craft_on_folder(
    net=net,
    refine_net=refine_net,
    use_cuda=use_cuda,
    test_folder=project_root.joinpath("craft", "images_normalized"),
    result_folder=project_root.joinpath("craft", "results", "center"),
    score_text_folder=project_root.joinpath("craft", "results", "score_text"),
    expected_map=expected_map,
    **hp,
)

infer/postproc time : 2.466/0.0082학기\캡스톤\Baram_Handwritting_Analysis\craft\images_normalized\test1.png
infer/postproc time : 2.183/0.0102학기\캡스톤\Baram_Handwritting_Analysis\craft\images_normalized\test2.png
infer/postproc time : 1.459/0.0142학기\캡스톤\Baram_Handwritting_Analysis\craft\images_normalized\test3.png
infer/postproc time : 1.237/0.0092학기\캡스톤\Baram_Handwritting_Analysis\craft\images_normalized\test4.png
infer/postproc time : 2.236/0.0072학기\캡스톤\Baram_Handwritting_Analysis\craft\images_normalized\test5.png
infer/postproc time : 1.097/0.0092학기\캡스톤\Baram_Handwritting_Analysis\craft\images_normalized\test6.png
infer/postproc time : 2.091/0.0072학기\캡스톤\Baram_Handwritting_Analysis\craft\images_normalized\test7.png
infer/postproc time : 1.394/0.0082학기\캡스톤\Baram_Handwritting_Analysis\craft\images_normalized\test8.png

Elapsed time : 14.542s
