In [10]:
!pip install ultralytics

Collecting ultralytics
  Downloading ultralytics-8.3.228-py3-none-any.whl.metadata (37 kB)
Collecting ultralytics-thop>=2.0.18 (from ultralytics)
  Downloading ultralytics_thop-2.0.18-py3-none-any.whl.metadata (14 kB)
Downloading ultralytics-8.3.228-py3-none-any.whl (1.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m15.0 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading ultralytics_thop-2.0.18-py3-none-any.whl (28 kB)
Installing collected packages: ultralytics-thop, ultralytics
Successfully installed ultralytics-8.3.228 ultralytics-thop-2.0.18


In [31]:
# === Fast Inference API (Single-pill, Top-1 Only) ===
# YOLO → (조건부 전체이미지 / 소프트 크롭)
#     → ResNet(1324) 최종 Top-1 (dl_idx 계산식 기반)
# -*- coding: utf-8 -*-
import os
import json
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models, transforms

from ultralytics import YOLO
from PIL import Image

# -------------------------------------------------
# 0) Paths / basic settings
# -------------------------------------------------
# Base directory that all weight and label files below are resolved against
# (presumably the Colab Google Drive mount point — confirm for other environments).
DRIVE = "/content/drive/MyDrive"

# Default image path for the test run (change only this value to try another image)
IMG_PATH = "/content/drive/MyDrive/캡스톤_원천_데이터/TS_34_단일.zip/K-009272/K-009272_0_1_0_0_70_000_200.png"

# Model weights (required) and label-map JSON files (optional: a missing JSON
# yields an empty map in load_label_map_generic).
BEST_YOLO      = os.path.join(DRIVE, "best.pt")
RESNET_1324_PT = os.path.join(DRIVE, "best_model_generalized.pth")
CLASS_JSON_1K  = os.path.join(DRIVE, "pill_label_path_sharp_score.json")
CLASS_JSON_324 = os.path.join(DRIVE, "class_mapping_from_cache_1324.json")

# Fail early if either weight file is missing; only the two weight files are checked here.
for p in [BEST_YOLO, RESNET_1324_PT]:
    assert os.path.exists(p), f"가중치 파일 없음: {p}"

# Passed to the YOLO call as conf / iou / imgsz respectively.
YOLO_CONF  = 0.25
YOLO_IOU   = 0.45
YOLO_IMGSZ = 640

# CROP_SIZE: side length every crop is resized to before classification.
# MIN_BOX_SIDE_PX: detections narrower or shorter than this are dropped; it also
#   sets the lower bound (x1.5) of the square crop side.
# FULL_IMAGE_AREA_RATIO_THRESHOLD: when a single detection covers at least this
#   fraction of the image, the whole image is classified instead of a crop.
# SQUARE_SCALE: enlargement factor applied to the longer box side when cropping.
CROP_SIZE        = 224
MIN_BOX_SIDE_PX  = 40
FULL_IMAGE_AREA_RATIO_THRESHOLD = 0.65
SQUARE_SCALE     = 1.3

# NUM_CLASSES: output width of the classifier head.
# LABEL_OFFSET: class indices below this use LABEL_MAP_1K; indices at or above
#   it use LABEL_MAP_324 after subtracting the offset.
NUM_CLASSES  = 1324
LABEL_OFFSET = 1000

# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", DEVICE)

# -------------------------------------------------
# 1) 라벨 맵 로드 (class_idx → K-코드, dl_idx 계산)
# -------------------------------------------------
def load_label_map_generic(json_path):
    """Load a ``class index -> K-code`` mapping from a JSON file.

    The JSON root must be an object, either the mapping itself or an object
    holding the mapping under the ``"label_to_kcode"`` key. Each value is
    converted with ``str()``, reduced to its basename, and the text before the
    first underscore is taken as the code. A code of the form ``K`` + 6 digits
    is normalized to ``K-`` + 6 digits; anything else is kept unchanged.

    Args:
        json_path: Path to the JSON file. May be empty/``None``.

    Returns:
        dict[int, str]: Mapping from integer key to K-code. Empty when the
        path is falsy, the file does not exist, or the JSON root is not an
        object. Entries whose key cannot be parsed as an integer are skipped.

    Raises:
        json.JSONDecodeError: If the file exists but is not valid JSON.
    """
    if not json_path or not os.path.exists(json_path):
        return {}
    with open(json_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    if isinstance(data, dict) and "label_to_kcode" in data:
        data = data["label_to_kcode"]

    if not isinstance(data, dict):
        return {}

    out = {}
    for k, v in data.items():
        # Only a failed integer parse should skip an entry; a bare `except`
        # would also swallow KeyboardInterrupt and genuine programming errors.
        try:
            key_int = int(k)
        except (TypeError, ValueError):
            continue

        # e.g. ".../K-009272_0_1.png" -> "K-009272", "K013101_x" -> "K013101"
        kcode = os.path.basename(str(v)).split("_")[0]

        # Normalize the dash-less form "K013101" to "K-013101". Codes already
        # in "K-NNNNNN" form, and anything unrecognized, pass through as-is.
        if kcode.startswith("K") and len(kcode) == 7 and kcode[1:].isdigit():
            kcode = "K-" + kcode[1:]

        out[key_int] = kcode

    return out

# Both maps are loaded once when the cell runs; each is {} if its JSON file is missing.
LABEL_MAP_1K  = load_label_map_generic(CLASS_JSON_1K)   # intended key range 0..999 (not validated here)
LABEL_MAP_324 = load_label_map_generic(CLASS_JSON_324)  # intended local key range 0..323, i.e. global indices 1000..1323

def class_idx_to_kcode(global_idx: int) -> str:
    """Translate a classifier output index into its K-code string.

    Indices at or above ``LABEL_OFFSET`` are shifted down by the offset and
    looked up in ``LABEL_MAP_324``; all other indices are looked up directly
    in ``LABEL_MAP_1K``. A placeholder string is returned when the index is
    absent from the relevant map.
    """
    if global_idx >= LABEL_OFFSET:
        shifted = global_idx - LABEL_OFFSET
        return LABEL_MAP_324.get(shifted, f"unknown_{shifted}")
    return LABEL_MAP_1K.get(global_idx, f"imagenet_{global_idx}")

def kcode_to_dl_idx(kcode: str) -> str:
    """Derive the ``dl_idx`` string from a K-code.

    The last six characters are read as a number and decremented by one,
    e.g. ``"K-009272"`` -> ``"9271"``.

    Args:
        kcode: K-code such as ``"K-009272"`` (or any other label string).

    Returns:
        The decremented number as a string, or ``kcode`` unchanged when it is
        shorter than 7 characters, its last six characters are not all
        decimal digits, or the decremented value would be negative.
    """
    if len(kcode) >= 7:
        tail = kcode[-6:]
        # isdecimal() rather than isdigit(): isdigit() also accepts characters
        # such as superscript digits that int() cannot parse, which would
        # raise ValueError instead of reaching the fallback.
        if tail.isdecimal():
            dl_val = int(tail) - 1
            if dl_val >= 0:
                return str(dl_val)
    return kcode  # fallback: hand the input back untouched

def idx_to_dl_idx(global_idx: int) -> str:
    """Map a classifier output index to its ``dl_idx`` string via the K-code."""
    return kcode_to_dl_idx(class_idx_to_kcode(global_idx))

# -------------------------------------------------
# 2) ResNet 1324 모델 + 전처리 (전역 1회 로드)
# -------------------------------------------------
def build_resnet_1324(num_classes=NUM_CLASSES, model_path=RESNET_1324_PT):
    """Build a ResNet-152 classifier and load its weights from ``model_path``.

    The final layer is replaced by ``Dropout(0.5) -> Linear(in_features,
    num_classes)``. The checkpoint may be a bare state dict or a dict that
    wraps one under ``"model_state_dict"`` or ``"model"``. The returned model
    is on ``DEVICE``, in eval mode, and in half precision when ``DEVICE`` is
    CUDA.

    Args:
        num_classes: Output width of the replacement head.
        model_path: Checkpoint file passed to ``torch.load``.

    Returns:
        The prepared ``torch.nn.Module``.
    """
    net = models.resnet152(weights=None)
    net.fc = nn.Sequential(
        nn.Dropout(p=0.5),
        nn.Linear(net.fc.in_features, num_classes)
    )

    # NOTE(review): torch.load unpickles the file, so only load checkpoints from a trusted source.
    checkpoint = torch.load(model_path, map_location="cpu")
    if isinstance(checkpoint, dict):
        # Unwrap the first recognized container key, if any.
        for wrapper_key in ("model_state_dict", "model"):
            if wrapper_key in checkpoint:
                checkpoint = checkpoint[wrapper_key]
                break

    # strict=False tolerates key mismatches; they are only reported, not raised.
    load_report = net.load_state_dict(checkpoint, strict=False)
    missing = load_report.missing_keys
    unexpected = load_report.unexpected_keys
    if missing or unexpected:
        print(f"ℹ️ state_dict load: missing={len(missing)}, unexpected={len(unexpected)}")

    net.to(DEVICE)
    net.eval()
    if DEVICE.type == "cuda":
        net.half()

    return net

# Build the classifier once when the cell runs so every prediction reuses the loaded weights.
RESNET_MODEL = build_resnet_1324()

# Preprocessing applied to each crop: resize to CROP_SIZE x CROP_SIZE (aspect ratio
# is not preserved), convert to a tensor, then normalize per channel with the given
# mean/std (the values match the widely used ImageNet statistics).
base_transform = transforms.Compose([
    transforms.Resize((CROP_SIZE, CROP_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
])

def preprocess_pil(pil_img):
    """Run ``base_transform`` on an image; cast to half precision when ``DEVICE`` is CUDA."""
    tensor = base_transform(pil_img)
    return tensor.half() if DEVICE.type == "cuda" else tensor

@torch.no_grad()
def predict_resnet_batch_top1(pil_imgs):
    """Classify a batch of images and return the Top-1 prediction for each.

    Args:
        pil_imgs: Sequence of images accepted by ``preprocess_pil``.

    Returns:
        ``[[{"idx": int, "prob": float}], ...]`` — one single-element list per
        input image, in input order. An empty input yields ``[]``.
    """
    if not pil_imgs:
        return []

    batch = torch.stack([preprocess_pil(im) for im in pil_imgs]).to(DEVICE)

    # The forward pass is the same on every device, so no device branch is needed.
    logits = RESNET_MODEL(batch)

    # Compute softmax in float32: when the model runs in half precision the
    # logits are fp16, and normalizing over many classes in fp16 loses precision.
    probs = F.softmax(logits, dim=1, dtype=torch.float32)
    top1_prob, top1_idx = torch.topk(probs, 1, dim=1)

    # One bulk transfer per tensor instead of a device sync per element.
    idx_list = top1_idx.squeeze(1).tolist()
    prob_list = top1_prob.squeeze(1).tolist()
    return [
        [{"idx": int(i), "prob": float(p)}]
        for i, p in zip(idx_list, prob_list)
    ]

# -------------------------------------------------
# 3) YOLO 감지 + 크롭 생성 (단일 모드용)
# -------------------------------------------------
# Device argument handed to the YOLO call: GPU index 0 when DEVICE is CUDA, otherwise "cpu".
YOLO_DEVICE = 0 if DEVICE.type == "cuda" else "cpu"
# The detector is loaded once from the BEST_YOLO weights file and reused by detect_yolo_boxes().
YOLO_MODEL = YOLO(BEST_YOLO)

def detect_yolo_boxes(img_path):
    """Run the YOLO detector on an image file and collect usable boxes.

    Args:
        img_path: Path of the image to analyze.

    Returns:
        ``(img, boxes)`` where ``img`` is the image loaded as an RGB
        ``PIL.Image`` and ``boxes`` is a list of ``[x1, y1, x2, y2]`` integer
        boxes clamped to the image and at least ``MIN_BOX_SIDE_PX`` wide and
        tall. ``boxes`` is empty when nothing qualifies.
    """
    det = YOLO_MODEL(
        img_path,
        imgsz=YOLO_IMGSZ,
        conf=YOLO_CONF,
        iou=YOLO_IOU,
        device=YOLO_DEVICE,
        verbose=False
    )[0]

    # Image.open keeps the underlying file open; convert() returns an
    # independent in-memory copy, so the handle can be released immediately.
    with Image.open(img_path) as src:
        img = src.convert("RGB")
    W, H = img.size

    boxes = []
    if det.boxes is not None and len(det.boxes) > 0:
        for b in det.boxes.xyxy.cpu().numpy().tolist():
            # Truncate to integer pixels, then clamp to the image bounds.
            x1, y1, x2, y2 = map(int, b)
            x1 = max(0, x1)
            y1 = max(0, y1)
            x2 = min(W-1, x2)
            y2 = min(H-1, y2)
            # Drop detections too small to classify reliably.
            if (x2-x1) >= MIN_BOX_SIDE_PX and (y2-y1) >= MIN_BOX_SIDE_PX:
                boxes.append([x1, y1, x2, y2])

    return img, boxes

def square_crop_from_bbox(pil_img, xyxy, scale=1.3):
    """Crop an enlarged square window centered on a bounding box.

    The window side is the longer box side times ``scale``, but never less
    than ``MIN_BOX_SIDE_PX * 1.5``. The window is clipped to the image, so the
    result may be non-square near the borders.

    Args:
        pil_img: Source image exposing ``.size`` and ``.crop``.
        xyxy: Box as ``(x1, y1, x2, y2)``.
        scale: Enlargement factor for the longer box side.

    Returns:
        The cropped image, or ``None`` when the clipped window is empty.
    """
    img_w, img_h = pil_img.size
    left, top, right, bottom = xyxy

    center_x = (left + right) / 2.0
    center_y = (top + bottom) / 2.0

    longest = max(right - left, bottom - top)
    half_side = max(longest * scale, MIN_BOX_SIDE_PX * 1.5) / 2.0

    # Round to whole pixels, then clip each edge to the image.
    crop_l = max(0, int(round(center_x - half_side)))
    crop_t = max(0, int(round(center_y - half_side)))
    crop_r = min(img_w, int(round(center_x + half_side)))
    crop_b = min(img_h, int(round(center_y + half_side)))

    if crop_r <= crop_l or crop_b <= crop_t:
        return None
    return pil_img.crop((crop_l, crop_t, crop_r, crop_b))

def make_crops_for_single_mode(img_path):
    """Produce the list of images to classify for one input picture.

    * Exactly one detection covering at least
      ``FULL_IMAGE_AREA_RATIO_THRESHOLD`` of the image -> the whole image.
    * Otherwise -> one enlarged square crop per detection.
    * No usable crop at all -> the whole image as a fallback.

    Returns:
        A non-empty list of images.
    """
    img, boxes = detect_yolo_boxes(img_path)
    width, height = img.size
    total_area = width * height

    if len(boxes) == 1:
        left, top, right, bottom = boxes[0]
        covered = (right - left) * (bottom - top)
        # The small epsilon keeps the division defined for a zero-area image.
        if covered / float(total_area + 1e-9) >= FULL_IMAGE_AREA_RATIO_THRESHOLD:
            return [img.copy()]

    candidates = (
        square_crop_from_bbox(img, box, scale=SQUARE_SCALE) for box in boxes
    )
    crops = [crop for crop in candidates if crop is not None]

    return crops if crops else [img.copy()]

# -------------------------------------------------
# 4) 엔드투엔드 단일 알약 Top-1 추론
# -------------------------------------------------
def infer_pill_image_single_top1(img_path: str):
    """Return the single best prediction for the whole image.

    Every crop contributes its Top-1 class; for each class the highest
    probability across crops is kept, and the class with the highest kept
    probability wins.

    Returns:
        ``{"idx": <class_idx>, "dl_idx": "<dl_idx>", "prob": <float>}``, or
        ``None`` when there are no predictions.
    """
    per_crop = predict_resnet_batch_top1(make_crops_for_single_mode(img_path))

    # Per-class max-pooling over the crops' Top-1 results.
    best_by_class = {}
    for crop_result in per_crop:
        top = crop_result[0]
        cls, score = top["idx"], top["prob"]
        if cls not in best_by_class or score > best_by_class[cls]:
            best_by_class[cls] = score

    if not best_by_class:
        return None

    winner = max(best_by_class, key=best_by_class.get)
    return {
        "idx":   winner,
        "dl_idx": idx_to_dl_idx(winner),
        "prob":  best_by_class[winner],
    }

# -------------------------------------------------
# 5) 테스트 실행
# -------------------------------------------------
if __name__ == "__main__":
    # Stop immediately if the sample image does not exist.
    assert os.path.exists(IMG_PATH), f"이미지 없음: {IMG_PATH}"
    res = infer_pill_image_single_top1(IMG_PATH)
    print("=== SINGLE MODE TOP-1 RESULT ===")
    if res is not None:
        print(f"class_idx={res['idx']}, dl_idx={res['dl_idx']}, prob={res['prob']*100:.2f}%")
    else:
        print("No prediction.")
    # `res` can be passed to another API unchanged.


Device: cpu
=== SINGLE MODE TOP-1 RESULT ===
class_idx=1262, dl_idx=9271, prob=99.13%
