In [55]:
import os
import time
from collections import OrderedDict

import cv2
import numpy as np
from PIL import Image

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.autograd import Variable

import sys
from pathlib import Path

def setup_project_path(marker: str = "craft", start=None):
    """Locate the project root by searching upward for a marker entry.

    The search begins at ``start`` (the current working directory by default)
    and visits that directory followed by each of its ancestors, returning the
    first one that contains an entry named ``marker``.

    Parameters
    ----------
    marker : str
        Name of the file or directory that identifies the project root.
    start : str or pathlib.Path, optional
        Directory to begin the search from. Defaults to ``Path.cwd()``.

    Returns
    -------
    pathlib.Path
        The nearest directory (``start`` or an ancestor) containing ``marker``.

    Raises
    ------
    FileNotFoundError
        If neither ``start`` nor any ancestor contains ``marker``. Without this
        the search would spin forever at the filesystem root, because the
        parent of the root is the root itself.
    """
    # Resolve an explicit start so relative paths still expose their ancestors.
    origin = Path.cwd() if start is None else Path(start).resolve()

    for candidate in (origin, *origin.parents):
        if (candidate / marker).exists():
            return candidate

    raise FileNotFoundError(
        f"Could not find '{marker}' in '{origin}' or any of its parent directories."
    )
# Resolve the project root once and make it the first import location so the
# `craft` package can be imported from this notebook.
project_root = setup_project_path()
_root_entry = str(project_root)
# Rebuild sys.path in place with the root first and any earlier copies removed,
# so re-running this cell does not stack duplicate entries.
sys.path[:] = [_root_entry] + [entry for entry in sys.path if entry != _root_entry]

from craft.common import craft
from craft.common import craft_utils
from craft.common import file_utils
from craft.common import imgproc
from craft.common.craft import CRAFT
from craft.common.refinenet import RefineNet

In [56]:
def imwrite_unicode(path, img):
    """Encode ``img`` in memory and write the encoded bytes to ``path``.

    The image format is chosen from the extension of ``path``. The bytes are
    written with Python's own ``open`` instead of an OpenCV file writer
    (presumably so that non-ASCII paths work -- confirm against the caller).

    Parameters
    ----------
    path : str
        Destination file path; its extension is passed to ``cv2.imencode``.
    img : array-like
        Image handed to ``cv2.imencode``.

    Returns
    -------
    bool
        ``True`` once the file has been written, ``False`` if encoding was
        reported as unsuccessful (nothing is written in that case).
    """
    _, suffix = os.path.splitext(path)
    encoded_ok, encoded = cv2.imencode(suffix, img)
    if encoded_ok:
        with open(path, "wb") as out_file:
            out_file.write(encoded.tobytes())
        return True
    return False

In [57]:
def copyStateDict(state_dict):
    """Return a copy of ``state_dict`` with a leading ``module.`` prefix removed.

    Whether stripping happens is decided by the first key: if it starts with
    ``"module."`` (presumably the prefix a ``DataParallel``-wrapped model adds
    when saved -- confirm against how the checkpoint was produced), the prefix
    is removed from every key that carries it. Keys without the prefix are
    copied unchanged, and so is the whole mapping when the first key does not
    have it.

    Parameters
    ----------
    state_dict : Mapping[str, Any]
        Parameter-name to value mapping. May be empty.

    Returns
    -------
    collections.OrderedDict
        New mapping with the same values in the same order.
    """
    prefix = "module."
    # An empty mapping has no first key; the "" default simply disables stripping.
    first_key = next(iter(state_dict), "")
    strip_prefix = first_key.startswith(prefix)

    new_state_dict = OrderedDict()
    for key, value in state_dict.items():
        # Match the full "module." component so names such as "modulefoo.w"
        # are not truncated by accident.
        if strip_prefix and key.startswith(prefix):
            key = key[len(prefix):]
        new_state_dict[key] = value
    return new_state_dict


In [58]:
def load_craft_model(
    trained_model: str = "weights/craft_mlt_25k.pth",
    use_cuda: bool = True,
    use_refiner: bool = False,
    refiner_model: str = "weights/craft_refiner_CTW1500.pth",
):
    """
    Build the CRAFT model (and optionally RefineNet) and load their checkpoints.

    Parameters
    ----------
    trained_model : str
        Path of the CRAFT checkpoint.
    use_cuda : bool
        Request CUDA. Downgraded to CPU with a warning when CUDA is unavailable.
    use_refiner : bool
        Also build and load RefineNet when True.
    refiner_model : str
        Path of the RefineNet checkpoint (only read when ``use_refiner``).

    Returns
    -------
    net : torch.nn.Module
        CRAFT model in eval mode (wrapped in ``DataParallel`` when on CUDA).
    refine_net : torch.nn.Module or None
        RefineNet in eval mode, or None when ``use_refiner`` is False.
    use_cuda : bool
        Whether CUDA is actually being used.
    """
    cuda_missing = not torch.cuda.is_available()
    if use_cuda and cuda_missing:
        print("[WARN] CUDA is not available, fallback to CPU.")
        use_cuda = False

    # None keeps torch.load's default device mapping; "cpu" remaps onto the CPU.
    map_location = None if use_cuda else "cpu"

    def _restore(model, checkpoint_path):
        """Load checkpoint weights into ``model`` and move it to CUDA if in use."""
        state = torch.load(checkpoint_path, map_location=map_location)
        model.load_state_dict(copyStateDict(state))
        if use_cuda:
            model = torch.nn.DataParallel(model.cuda())
        return model

    net = CRAFT()
    print(f"Loading CRAFT weights from checkpoint: {trained_model}")
    net = _restore(net, trained_model)
    if use_cuda:
        cudnn.benchmark = False
    net.eval()

    refine_net = None
    if use_refiner:
        refine_net = RefineNet()
        print(f"Loading Refiner weights from checkpoint: {refiner_model}")
        refine_net = _restore(refine_net, refiner_model)
        refine_net.eval()

    return net, refine_net, use_cuda


In [59]:
def detect_text(
    net,
    image,
    text_threshold: float = 0.7,
    link_threshold: float = 0.4,
    low_text: float = 0.4,
    use_cuda: bool = True,
    poly: bool = False,
    refine_net=None,
    canvas_size: int = 1280,
    mag_ratio: float = 1.5,
    show_time: bool = False,
):
    """
    Run CRAFT inference on a single image and return the detections.

    Parameters
    ----------
    net : CRAFT model
        Called as ``net(x)`` and expected to return ``(y, feature)``.
    image : np.ndarray or PIL.Image
        Input image; a PIL image is converted with ``np.array`` first.
        Documented by the original author as (H, W, 3).
    text_threshold, link_threshold, low_text : float
        Thresholds forwarded to ``craft_utils.getDetBoxes``.
    use_cuda : bool
        Move the input tensor to CUDA before the forward pass.
    poly : bool
        Forwarded to ``craft_utils.getDetBoxes``.
    refine_net : RefineNet or None
        When given, its output replaces the link score map.
    canvas_size : int
    mag_ratio : float
        Both forwarded to ``imgproc.resize_aspect_ratio``.
    show_time : bool
        Print the inference and post-processing durations.

    Returns
    -------
    boxes : list
        Text-region boxes, rescaled to the original image coordinates.
    polys : list
        Text-region polygons; entries that were None are replaced by the
        box at the same index.
    score_heatmap : np.ndarray
        Heatmap rendering of the text and link score maps placed side by side.
    """
    infer_started = time.time()

    # Normalise the input type: PIL images become numpy arrays.
    if isinstance(image, Image.Image):
        image = np.array(image)

    # Aspect-preserving resize; the inverse ratio maps results back afterwards.
    resized, scale, _ = imgproc.resize_aspect_ratio(
        image, canvas_size, interpolation=cv2.INTER_LINEAR, mag_ratio=mag_ratio
    )
    inv_scale = 1 / scale

    # Preprocess: normalise, [h, w, c] -> [c, h, w], then add the batch axis.
    normalized = imgproc.normalizeMeanVariance(resized)
    chw = torch.from_numpy(normalized).permute(2, 0, 1)
    batch = Variable(chw.unsqueeze(0))
    if use_cuda:
        batch = batch.cuda()

    # Forward pass (plus the optional refiner) without gradient tracking.
    with torch.no_grad():
        y, feature = net(batch)
        y_refiner = None
        if refine_net is not None:
            y_refiner = refine_net(y, feature)

    # Channel 0 of the CRAFT output is used as the text score map; the link
    # score map is channel 1 unless the refiner supplied a replacement.
    score_text = y[0, :, :, 0].cpu().data.numpy()
    if y_refiner is None:
        score_link = y[0, :, :, 1].cpu().data.numpy()
    else:
        score_link = y_refiner[0, :, :, 0].cpu().data.numpy()

    infer_elapsed = time.time() - infer_started
    post_started = time.time()

    # Post-processing: extract boxes/polygons from the score maps.
    boxes, polys = craft_utils.getDetBoxes(
        score_text, score_link, text_threshold, link_threshold, low_text, poly
    )

    # Map coordinates back to the original image size.
    boxes = craft_utils.adjustResultCoordinates(boxes, inv_scale, inv_scale)
    polys = craft_utils.adjustResultCoordinates(polys, inv_scale, inv_scale)
    for idx, candidate in enumerate(polys):
        if candidate is None:
            polys[idx] = boxes[idx]

    post_elapsed = time.time() - post_started

    # Render both score maps side by side as one heatmap image.
    score_heatmap = imgproc.cvt2HeatmapImg(np.hstack((score_text, score_link)))

    if show_time:
        print("infer/postproc time : {:.3f}/{:.3f}".format(infer_elapsed, post_elapsed))

    return boxes, polys, score_heatmap

In [60]:
def run_craft_on_folder(
    net,
    refine_net,
    use_cuda: bool,
    test_folder: str,
    result_folder: str = "./result",
    text_threshold: float = 0.7,
    link_threshold: float = 0.4,
    low_text: float = 0.4,
    canvas_size: int = 1280,
    mag_ratio: float = 1.5,
    poly: bool = False,
    show_time: bool = False,
):
    """
    Run CRAFT detection on every image in a folder and save the results.

    For each image this writes the score heatmap to
    ``<result_folder>/res_<name>_mask.jpg`` and hands the image and polygons to
    ``file_utils.saveResult``.

    Parameters
    ----------
    net : CRAFT model
    refine_net : RefineNet or None
    use_cuda : bool
    test_folder : str
        Folder scanned for input images via ``file_utils.get_files``.
    result_folder : str
        Output folder; created if it does not exist.
    text_threshold, link_threshold, low_text, canvas_size, mag_ratio, poly, show_time
        Forwarded unchanged to ``detect_text()``.

    Returns
    -------
    None
        Returns early (after a warning) when no images are found.
    """
    os.makedirs(result_folder, exist_ok=True)

    image_list, _, _ = file_utils.get_files(test_folder)
    if not image_list:
        print(f"[WARN] No images found in folder: {test_folder}")
        return

    total = len(image_list)
    # With show_time on, detect_text prints a timing line per image; a "\r"
    # progress line would be half-overwritten by it, so end with a newline then.
    progress_end = "\n" if show_time else "\r"

    t_start = time.time()

    for index, image_path in enumerate(image_list, start=1):
        print(f"Test image {index}/{total}: {image_path}", end=progress_end)
        image = imgproc.loadImage(image_path)

        bboxes, polys, score_heatmap = detect_text(
            net=net,
            image=image,
            text_threshold=text_threshold,
            link_threshold=link_threshold,
            low_text=low_text,
            use_cuda=use_cuda,
            poly=poly,
            refine_net=refine_net,
            canvas_size=canvas_size,
            mag_ratio=mag_ratio,
            show_time=show_time,
        )

        # Save the score heatmap; report a failed write instead of ignoring it.
        filename = os.path.splitext(os.path.basename(image_path))[0]
        mask_file = os.path.join(result_folder, f"res_{filename}_mask.jpg")
        if not imwrite_unicode(mask_file, score_heatmap):
            print(f"\n[WARN] Failed to write score mask: {mask_file}")

        # Save the polygon result. The channel axis is reversed before saving
        # (presumably RGB -> BGR for OpenCV -- confirm against loadImage).
        file_utils.saveResult(
            image_path, image[:, :, ::-1], polys, dirname=result_folder
        )

    print("\nElapsed time : {:.3f}s".format(time.time() - t_start))

In [61]:
# Load the model. CUDA is requested here; load_craft_model() falls back to CPU
# (and returns the effective flag) when CUDA is not available.
net, refine_net, use_cuda = load_craft_model(
    trained_model=project_root.joinpath("craft", "common", "weights", "craft_mlt_25k.pth"),
    use_cuda=True,
    use_refiner=False,
    refiner_model=project_root.joinpath(
        "craft", "common", "weights", "craft_refiner_CTW1500.pth"
    ),
)

[WARN] CUDA is not available, fallback to CPU.




Loading CRAFT weights from checkpoint: D:\Study\학교강의\4학년2학기\캡스톤\Baram_Handwritting_Analysis\craft\common\weights\craft_mlt_25k.pth


In [62]:
# Detection hyperparameters (edit as needed); these are passed as keyword
# arguments to run_craft_on_folder().
hp = dict(
    text_threshold=0.7,
    link_threshold=0.4,
    low_text=0.4,
    canvas_size=1280,
    mag_ratio=1.5,
    poly=False,
    show_time=True,
)

In [103]:
# Run detection over every image under craft/images.
# NOTE(review): results are written into the same folder that is scanned for
# inputs, so generated res_* files may be picked up as inputs on a later run --
# confirm how file_utils.get_files filters, or use a separate result folder.
run_craft_on_folder(
    net=net,
    refine_net=refine_net,
    use_cuda=use_cuda,
    test_folder=project_root.joinpath("craft", "images"),
    result_folder=project_root.joinpath("craft", "images"),
    **hp,
)

infer/postproc time : 1.779/0.0032학기\캡스톤\Baram_Handwritting_Analysis\craft\images\test2.png
infer/postproc time : 1.789/0.0062학기\캡스톤\Baram_Handwritting_Analysis\craft\images\test3.png
infer/postproc time : 1.803/0.0052학기\캡스톤\Baram_Handwritting_Analysis\craft\images\test4.png
infer/postproc time : 1.885/0.0052학기\캡스톤\Baram_Handwritting_Analysis\craft\images\test5.png

Elapsed time : 7.476s
