In [37]:
import os
import time
from collections import OrderedDict

import cv2
import numpy as np
from PIL import Image
import math
import re

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.autograd import Variable

import sys
from pathlib import Path

def setup_project_path(start=None, marker: str = "craft") -> Path:
    """Find the project root by walking upward until a directory containing ``marker`` is found.

    Args:
        start: Directory to begin searching from. Defaults to ``Path.cwd()``,
            which is what the original zero-argument call used.
        marker: Name of the child entry whose presence identifies the project
            root. Defaults to ``"craft"``.

    Returns:
        The first of ``start`` and its ancestors that contains ``marker``.

    Raises:
        FileNotFoundError: If neither ``start`` nor any ancestor up to the
            filesystem root contains ``marker``. (The previous ``while`` loop
            never terminated in that case because ``Path('/').parent`` is
            ``Path('/')``.)
    """
    current = Path.cwd() if start is None else Path(start).resolve()
    # `parents` is finite and ends at the filesystem root, so the search always terminates.
    for candidate in (current, *current.parents):
        if (candidate / marker).exists():
            return candidate
    raise FileNotFoundError(
        f"Could not find a '{marker}' entry in {current} or any of its parent directories."
    )
project_root = setup_project_path()
# Prepend the root so the `craft.common.*` imports below are searched here first.
sys.path.insert(0, os.fspath(project_root))

from craft.common import craft
from craft.common import craft_utils
from craft.common import file_utils
from craft.common import imgproc
from craft.common.craft import CRAFT
from craft.common.refinenet import RefineNet

In [38]:
def imwrite_unicode(path, img):
    """Encode ``img`` in memory with OpenCV and write the bytes using Python's ``open``.

    The encoder is picked from the file extension of ``path``. Writing through
    ``open`` (rather than an OpenCV path-based writer) lets non-ASCII paths work,
    which is presumably the point of the ``_unicode`` name.

    Returns:
        True once the encoded bytes were written; False if ``cv2.imencode``
        reported failure, in which case no file is created.
    """
    _, extension = os.path.splitext(path)
    success, encoded = cv2.imencode(extension, img)
    if success:
        with open(path, "wb") as out:
            out.write(encoded.tobytes())
    return bool(success)

In [39]:
def copyStateDict(state_dict):
    """Return a copy of ``state_dict`` with the ``module.`` key prefix removed.

    Models wrapped in ``torch.nn.DataParallel`` (as ``load_craft_model`` does on
    GPU) save their parameters under ``module.<name>``; stripping that prefix lets
    the weights load into an unwrapped model.

    Args:
        state_dict: Mapping of parameter names to values (e.g. tensors).

    Returns:
        An ``OrderedDict`` with the same values in the same order. Keys that start
        with ``module.`` lose exactly that prefix; all other keys are unchanged.
        An empty input gives an empty ``OrderedDict`` (the previous version raised
        ``IndexError`` by indexing the first key).
    """
    prefix = "module."
    # Match the exact "module." prefix per key. The old check used
    # startswith("module") on the first key only, which also mangled names
    # such as "modules_list.weight" and stripped keys lacking the prefix.
    return OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )


In [40]:
def load_craft_model(
    trained_model: str = "weights/craft_mlt_25k.pth",
    use_cuda: bool = True,
    use_refiner: bool = False,
    refiner_model: str = "weights/craft_refiner_CTW1500.pth",
):
    """Build CRAFT (and optionally RefineNet), load checkpoints, and switch to eval mode.

    Args:
        trained_model: Path to the CRAFT checkpoint.
        use_cuda: Request GPU execution. Downgraded to CPU, with a printed
            warning, when ``torch.cuda.is_available()`` is False.
        use_refiner: Also build and load a RefineNet.
        refiner_model: Path to the RefineNet checkpoint; used only when
            ``use_refiner`` is True.

    Returns:
        ``(net, refine_net, use_cuda)``: the CRAFT model, the RefineNet or ``None``,
        and the effective CUDA flag. On GPU each model is moved to CUDA and wrapped
        in ``torch.nn.DataParallel``.

    NOTE(review): ``torch.load`` unpickles the checkpoint file; only load trusted weights.
    """
    if use_cuda and not torch.cuda.is_available():
        print("[WARN] CUDA is not available, fallback to CPU.")
        use_cuda = False

    # None is torch.load's default map_location; "cpu" maps all tensors onto the CPU.
    map_location = None if use_cuda else "cpu"

    def _prepare(model, checkpoint_path):
        # Load weights (with any "module" prefix handled by copyStateDict), then
        # move to GPU and wrap in DataParallel when running on CUDA.
        state = copyStateDict(torch.load(checkpoint_path, map_location=map_location))
        model.load_state_dict(state)
        if use_cuda:
            model = torch.nn.DataParallel(model.cuda())
        return model

    net = CRAFT()
    print(f"Loading CRAFT weights from checkpoint: {trained_model}")
    net = _prepare(net, trained_model)
    if use_cuda:
        cudnn.benchmark = False
    net.eval()

    if not use_refiner:
        return net, None, use_cuda

    refine_net = RefineNet()
    print(f"Loading Refiner weights from checkpoint: {refiner_model}")
    refine_net = _prepare(refine_net, refiner_model)
    refine_net.eval()
    return net, refine_net, use_cuda


In [41]:
def extract_centers_from_score_text(
    score_text: np.ndarray,
    thr: float = 0.7,
    min_area: int = 3,
):
    """Return the centroids of connected regions where ``score_text`` exceeds ``thr``.

    Args:
        score_text: 2-D score map (converted to float32 before thresholding).
        thr: Pixels strictly above this value (``cv2.THRESH_BINARY``) form the mask.
        min_area: Components with fewer pixels than this are discarded.

    Returns:
        List of ``(cx, cy)`` centroids, in the score map's pixel coordinates, in
        the label order produced by ``cv2.connectedComponentsWithStats``
        (8-connectivity).
    """
    _, mask = cv2.threshold(score_text.astype(np.float32), thr, 1.0, cv2.THRESH_BINARY)
    mask_u8 = (mask * 255).astype(np.uint8)

    label_count, _, stats, centroids = cv2.connectedComponentsWithStats(
        mask_u8, connectivity=8
    )

    # Label 0 is the background component, so real regions start at label 1.
    return [
        (centroids[label][0], centroids[label][1])
        for label in range(1, label_count)
        if stats[label, cv2.CC_STAT_AREA] >= min_area
    ]

In [42]:
def compute_center_angles(centers):
    """Angles in degrees between horizontally consecutive centers.

    Centers are ordered by x (stable sort), and for each adjacent pair the angle
    of the vector from the left center to the right one is computed with
    ``atan2``. The y difference is negated, presumably because image y grows
    downward, so a positive angle means the next center sits higher in the image.

    Args:
        centers: Sized sequence of ``(x, y)`` points.

    Returns:
        A list of ``len(centers) - 1`` angles in (-180, 180], or ``[]`` when fewer
        than two centers are given.
    """
    if len(centers) < 2:
        return []

    ordered = sorted(centers, key=lambda point: point[0])
    return [
        math.atan2(-(y_next - y_prev), x_next - x_prev) * 180.0 / math.pi
        for (x_prev, y_prev), (x_next, y_next) in zip(ordered, ordered[1:])
    ]

In [43]:
def detect_text(
    net,
    image,
    text_threshold=0.7,
    link_threshold=0.4,
    low_text=0.4,
    use_cuda=True,
    poly=False,
    refine_net=None,
    canvas_size=1280,
    mag_ratio=1.5,
    show_time=False,
    center_thr=0.7,
    center_min_area=3,
    expected_char_count=None,
):
    """Run CRAFT on one image; return boxes, polygons, a score heatmap and character centers.

    Args:
        net: CRAFT model (possibly wrapped in DataParallel), already in eval mode.
        image: NumPy image or PIL image (PIL input is converted with ``np.array``).
            Channel order is assumed to be what ``imgproc.normalizeMeanVariance``
            expects — TODO confirm (callers here pass ``imgproc.loadImage`` output).
        text_threshold, link_threshold, low_text, poly: Forwarded to
            ``craft_utils.getDetBoxes``.
        use_cuda: Move the input tensor to the GPU before inference.
        refine_net: Optional RefineNet; when given, its output replaces the link map.
        canvas_size, mag_ratio: Forwarded to ``imgproc.resize_aspect_ratio``.
        show_time: Print inference / post-processing durations.
        center_thr, center_min_area: Forwarded to ``extract_centers_from_score_text``.
        expected_char_count: If not None, the number of detected centers must match.

    Returns:
        ``(boxes, polys, score_heatmap, centers)``. Boxes and polys are rescaled by
        ``craft_utils.adjustResultCoordinates``. ``score_heatmap`` is the text and
        link maps side by side, rendered by ``imgproc.cvt2HeatmapImg``. ``centers``
        is a list of ``(x, y)`` character centers scaled back by ``1 / target_ratio``.

    Raises:
        ValueError: If ``expected_char_count`` is set and differs from ``len(centers)``.
    """
    t0 = time.time()

    if isinstance(image, Image.Image):
        image = np.array(image)

    img_resized, target_ratio, _ = imgproc.resize_aspect_ratio(
        image,
        canvas_size,
        interpolation=cv2.INTER_LINEAR,
        mag_ratio=mag_ratio,
    )
    # Maps img_resized coordinates back to the input image's scale.
    ratio_h = ratio_w = 1 / target_ratio

    # HWC -> NCHW with a batch of one. The tensor is moved to the GPU exactly once
    # (the previous code called .cuda() twice), and the deprecated
    # torch.autograd.Variable wrapper is dropped: it just returns the tensor.
    x = torch.from_numpy(imgproc.normalizeMeanVariance(img_resized))
    x = x.permute(2, 0, 1).unsqueeze(0)
    if use_cuda:
        x = x.cuda()

    with torch.no_grad():
        y, feature = net(x)
        score_text = y[0, :, :, 0].detach().cpu().numpy()
        score_link = y[0, :, :, 1].detach().cpu().numpy()
        if refine_net is not None:
            y_refiner = refine_net(y, feature)
            score_link = y_refiner[0, :, :, 0].detach().cpu().numpy()

    t0 = time.time() - t0
    t1 = time.time()

    boxes, polys = craft_utils.getDetBoxes(
        score_text, score_link, text_threshold, link_threshold, low_text, poly
    )
    boxes = craft_utils.adjustResultCoordinates(boxes, ratio_w, ratio_h)
    polys = craft_utils.adjustResultCoordinates(polys, ratio_w, ratio_h)
    # Any polygon entry that is None falls back to the corresponding box.
    for k in range(len(polys)):
        if polys[k] is None:
            polys[k] = boxes[k]

    t1 = time.time() - t1

    # Resize the text score map to img_resized's size so that centroids are
    # measured in img_resized pixels, then map them back with the same ratio.
    h_resized, w_resized = img_resized.shape[:2]
    score_text_up = cv2.resize(
        score_text,
        (w_resized, h_resized),
        interpolation=cv2.INTER_LINEAR,
    )
    centers_resized = extract_centers_from_score_text(
        score_text_up,
        thr=center_thr,
        min_area=center_min_area,
    )
    centers = [(cx * ratio_w, cy * ratio_h) for (cx, cy) in centers_resized]

    # Message reads "character count mismatch".
    if expected_char_count is not None and len(centers) != expected_char_count:
        raise ValueError(
            "글자 수 불일치: expected=%d, detected=%d"
            % (expected_char_count, len(centers))
        )

    # np.hstack already allocates a new array, so no defensive copy is needed.
    score_heatmap = imgproc.cvt2HeatmapImg(np.hstack((score_text, score_link)))

    if show_time:
        print("infer/postproc time : {:.3f}/{:.3f}".format(t0, t1))

    return boxes, polys, score_heatmap, centers

In [44]:
def run_craft_on_folder(
    net,
    refine_net,
    use_cuda: bool,
    test_folder: str,
    result_folder: str,
    text_threshold: float = 0.7,
    link_threshold: float = 0.4,
    low_text: float = 0.4,
    canvas_size: int = 1280,
    mag_ratio: float = 1.5,
    poly: bool = False,
    show_time: bool = False,
    center_thr: float = 0.7,
    center_min_area: int = 3,
    expected_map: dict | None = None,
    save_mask: bool = False,
    save_polys: bool = False,
):
    """Run ``detect_text`` on every image listed by ``file_utils.get_files(test_folder)``.

    For each image this writes ``res_<stem>_center.txt`` into ``result_folder``,
    with one ``x,y`` line (2 decimals) per detected character center. Optionally
    it also writes the score heatmap (``res_<stem>_mask.jpg``) and the polygons
    (via ``file_utils.saveResult``).

    Args:
        net, refine_net, use_cuda: As returned by ``load_craft_model``.
        test_folder: Folder scanned for images.
        result_folder: Output folder; created if missing.
        text_threshold, link_threshold, low_text, canvas_size, mag_ratio, poly,
            show_time, center_thr, center_min_area: Forwarded to ``detect_text``.
        expected_map: Optional ``{image stem: expected text}``. The text's length
            with all whitespace removed is passed as ``expected_char_count``, so a
            mismatch raises ``ValueError`` from ``detect_text`` and stops the run.
        save_mask: Also save the score heatmap image.
        save_polys: Also save detected polygons.

    Returns:
        None. Prints a warning and returns early if no images are found.
    """
    os.makedirs(result_folder, exist_ok=True)

    image_list, _, _ = file_utils.get_files(test_folder)
    if len(image_list) == 0:
        print(f"[WARN] No images found in folder: {test_folder}")
        return

    # With show_time, detect_text prints its own timing line. A "\r"-terminated
    # progress line would then be overwritten only partially, leaving fragments
    # of the path glued to the timing output. Use a real newline in that case.
    progress_end = "\n" if show_time else "\r"

    t_start = time.time()

    for k, image_path in enumerate(image_list):
        print(f"Test image {k+1}/{len(image_list)}: {image_path}", end=progress_end)

        image = imgproc.loadImage(image_path)
        filename, _ = os.path.splitext(os.path.basename(image_path))

        local_expected = None
        if expected_map is not None and filename in expected_map:
            # Ignore every kind of whitespace (not only " ") when counting characters.
            local_expected = len("".join(expected_map[filename].split()))

        _, polys, score_heatmap, centers = detect_text(
            net=net,
            image=image,
            text_threshold=text_threshold,
            link_threshold=link_threshold,
            low_text=low_text,
            use_cuda=use_cuda,
            poly=poly,
            refine_net=refine_net,
            canvas_size=canvas_size,
            mag_ratio=mag_ratio,
            show_time=show_time,
            center_thr=center_thr,
            center_min_area=center_min_area,
            expected_char_count=local_expected,
        )

        if save_mask:
            mask_file = os.path.join(result_folder, f"res_{filename}_mask.jpg")
            imwrite_unicode(mask_file, score_heatmap)

        centers_file = os.path.join(result_folder, f"res_{filename}_center.txt")
        with open(centers_file, "w", encoding="utf-8") as f:
            for (cx, cy) in centers:
                f.write(f"{cx:.2f},{cy:.2f}\n")

        if save_polys:
            # Reverse the channel axis before handing the image to saveResult.
            file_utils.saveResult(
                image_path, image[:, :, ::-1], polys, dirname=result_folder
            )

    print("\nElapsed time : {:.3f}s".format(time.time() - t_start))


In [45]:
# Load CRAFT, requesting the GPU (load_craft_model falls back to CPU when CUDA
# is unavailable). The refiner is disabled, so refine_net comes back as None.
net, refine_net, use_cuda = load_craft_model(
    use_cuda=True,
    use_refiner=False,
    trained_model=project_root / "craft" / "common" / "weights" / "craft_mlt_25k.pth",
    refiner_model=(
        project_root / "craft" / "common" / "weights" / "craft_refiner_CTW1500.pth"
    ),
)

[WARN] CUDA is not available, fallback to CPU.




Loading CRAFT weights from checkpoint: D:\Study\학교강의\4학년2학기\캡스톤\Baram_Handwritting_Analysis\craft\common\weights\craft_mlt_25k.pth


In [46]:
# Detection hyper-parameters forwarded to run_craft_on_folder via **hp.
hp = dict(
    text_threshold=0.7,
    link_threshold=0.4,
    low_text=0.4,
    canvas_size=1280,
    mag_ratio=1.5,
    poly=False,
    show_time=True,
)

In [47]:
# Ground-truth text per image stem. run_craft_on_folder compares its character
# count (with spaces removed) against the number of detected centers.
expected_map = {
    "test1": "안녕하세요",
    "test2": "안녕하세요",
    "test3": "안녕하세요",
    "test4": "안녕하세요",
    "test5": "다람쥐 헌 챗바퀴에 타고파",
}

run_craft_on_folder(
    net,
    refine_net,
    use_cuda,
    project_root / "craft" / "images",
    project_root / "craft" / "results" / "center",
    expected_map=expected_map,
    **hp,
)

infer/postproc time : 0.260/0.0012학기\캡스톤\Baram_Handwritting_Analysis\craft\images\test1.png
infer/postproc time : 0.118/0.0012학기\캡스톤\Baram_Handwritting_Analysis\craft\images\test2.png
infer/postproc time : 0.263/0.0022학기\캡스톤\Baram_Handwritting_Analysis\craft\images\test3.png
infer/postproc time : 0.126/0.0012학기\캡스톤\Baram_Handwritting_Analysis\craft\images\test4.png
infer/postproc time : 0.154/0.0022학기\캡스톤\Baram_Handwritting_Analysis\craft\images\test5.png

Elapsed time : 0.953s
