In [5]:
# === Fast Inference API (Multi-pill, Top-1 Only) ===
# YOLO → 최대 4개 알약 bbox 탐지 → 각 알약별 ResNet(1324) Top-1 (dl_idx 계산식)
# -*- coding: utf-8 -*-
import os
import json
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models, transforms

from ultralytics import YOLO
from PIL import Image

# -------------------------------------------------
# 0) Paths / base configuration
# -------------------------------------------------
# Root of the mounted Google Drive; the weight and label files below are resolved against it.
DRIVE = "/content/drive/MyDrive"

# Sample image used by the test run at the bottom of this cell (absolute path, not built from DRIVE).
IMG_PATH = "/content/drive/MyDrive/캡스톤_원천_데이터/TS_48_단일.zip/K-018357/K-018357_0_2_0_0_70_000_200.png"

# Detector weights, classifier checkpoint, and the two label-map JSON files.
BEST_YOLO      = os.path.join(DRIVE, "best.pt")
RESNET_1324_PT = os.path.join(DRIVE, "best_model_generalized.pth")
CLASS_JSON_1K  = os.path.join(DRIVE, "pill_label_path_sharp_score.json")
CLASS_JSON_324 = os.path.join(DRIVE, "class_mapping_from_cache_1324.json")

# Fail early when a weight file is missing. The label JSON files are NOT checked here;
# load_label_map_generic returns an empty dict when its file does not exist.
for p in [BEST_YOLO, RESNET_1324_PT]:
    assert os.path.exists(p), f"가중치 파일 없음: {p}"

# Passed to the YOLO call as conf / iou / imgsz respectively.
YOLO_CONF  = 0.25
YOLO_IOU   = 0.45
YOLO_IMGSZ = 640

# CROP_SIZE: side length the classifier input is resized to.
# MIN_BOX_SIDE_PX: detections narrower or shorter than this are discarded; it also sets
#                  the minimum crop side (1.5x this value) in square_crop_from_bbox.
# SQUARE_SCALE: enlargement factor applied to the longer bbox side when cropping.
CROP_SIZE        = 224
MIN_BOX_SIDE_PX  = 40
SQUARE_SCALE     = 1.3

# NUM_CLASSES: output width of the classifier head.
# LABEL_OFFSET: class indices below this are looked up in LABEL_MAP_1K; indices at or
#               above it are shifted down by this amount and looked up in LABEL_MAP_324.
NUM_CLASSES  = 1324
LABEL_OFFSET = 1000

# Default upper bound on how many detections are classified per image.
MAX_PILLS_MULTI = 4

# Run on the GPU when CUDA is available, otherwise on the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", DEVICE)

# -------------------------------------------------
# 1) 라벨 맵 + dl_idx 계산
# -------------------------------------------------
def load_label_map_generic(json_path):
    """Load a ``{class index -> K-code}`` table from a JSON file.

    The JSON root must be an object whose keys are integer-like strings. If the
    root has a ``"label_to_kcode"`` entry, that nested object is used instead.
    Each value is reduced to a code by taking its file-name component and
    keeping the text before the first underscore. A 7-character code of the
    form ``K`` + six digits (``K018357``) is rewritten to the hyphenated form
    (``K-018357``); every other code is kept verbatim.

    Args:
        json_path: Path of the JSON file. May be empty or ``None``.

    Returns:
        dict[int, str]: Mapping from integer key to code. Empty when the path
        is falsy, the file does not exist, or the (unwrapped) JSON root is not
        an object. Entries whose key cannot be parsed as an int are skipped.
    """
    if not json_path or not os.path.exists(json_path):
        return {}
    with open(json_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    if isinstance(data, dict) and "label_to_kcode" in data:
        data = data["label_to_kcode"]
    if not isinstance(data, dict):
        return {}

    out = {}
    for k, v in data.items():
        try:
            key_int = int(k)
        except (TypeError, ValueError):
            # Only skip keys that are genuinely non-numeric; a bare ``except``
            # would also swallow KeyboardInterrupt / SystemExit.
            continue

        # "<dir>/K-018357_0_2_....png" -> "K-018357"
        first = os.path.basename(str(v)).split("_")[0]

        # Normalize the un-hyphenated spelling; already-hyphenated or
        # unrecognized codes pass through unchanged.
        if first.startswith("K") and len(first) == 7 and first[1:].isdigit():
            first = "K-" + first[1:]

        out[key_int] = first

    return out

# Label tables loaded once when the cell runs; each is an empty dict when its JSON file is absent.
LABEL_MAP_1K  = load_label_map_generic(CLASS_JSON_1K)
LABEL_MAP_324 = load_label_map_generic(CLASS_JSON_324)

def class_idx_to_kcode(global_idx: int) -> str:
    """Translate a classifier output index into its code string.

    Indices at or above ``LABEL_OFFSET`` are shifted down by ``LABEL_OFFSET``
    and looked up in ``LABEL_MAP_324``; smaller indices are looked up directly
    in ``LABEL_MAP_1K``. When the index is absent from its table a placeholder
    string (``unknown_<local>`` or ``imagenet_<idx>``) is returned instead.
    """
    if global_idx >= LABEL_OFFSET:
        shifted = global_idx - LABEL_OFFSET
        return LABEL_MAP_324.get(shifted, f"unknown_{shifted}")
    return LABEL_MAP_1K.get(global_idx, f"imagenet_{global_idx}")

def kcode_to_dl_idx(kcode: str) -> str:
    """Derive the ``dl_idx`` string from a code such as ``K-018357``.

    When the code is at least 7 characters long and its last six characters
    are all digits, those digits are read as an integer, decremented by one,
    and returned as a string (``"K-018357"`` -> ``"18356"``). In every other
    case, including when the decremented value would be negative, the input
    is returned unchanged.
    """
    if len(kcode) < 7:
        return kcode

    digits = kcode[-6:]
    if not digits.isdigit():
        return kcode

    shifted = int(digits) - 1
    return kcode if shifted < 0 else str(shifted)

def idx_to_dl_idx(global_idx: int) -> str:
    """Map a classifier output index straight to its ``dl_idx`` string.

    Composes ``class_idx_to_kcode`` (index -> code) with ``kcode_to_dl_idx``
    (code -> dl_idx).
    """
    return kcode_to_dl_idx(class_idx_to_kcode(global_idx))

# -------------------------------------------------
# 2) ResNet 1324 + 전처리 (전역 1회)
# -------------------------------------------------
def build_resnet_1324(num_classes=NUM_CLASSES, model_path=RESNET_1324_PT):
    """Construct the ResNet-152 classifier and load its checkpoint.

    Args:
        num_classes: Output width of the replacement classification head.
        model_path: Path of the checkpoint file to load.

    Returns:
        The model in eval mode on ``DEVICE``; converted to half precision when
        ``DEVICE`` is CUDA.
    """
    # Backbone without pretrained weights; the final layer becomes Dropout + Linear.
    net = models.resnet152(weights=None)
    net.fc = nn.Sequential(
        nn.Dropout(p=0.5),
        nn.Linear(net.fc.in_features, num_classes)
    )

    # NOTE(review): torch.load unpickles the file — only load checkpoints from a trusted source.
    checkpoint = torch.load(model_path, map_location="cpu")

    # Some checkpoints wrap the weights under a key; unwrap the first one that is present.
    if isinstance(checkpoint, dict):
        for wrapper_key in ("model_state_dict", "model"):
            if wrapper_key in checkpoint:
                checkpoint = checkpoint[wrapper_key]
                break

    # Non-strict load: key mismatches are tolerated and only their counts are reported.
    missing, unexpected = net.load_state_dict(checkpoint, strict=False)
    if missing or unexpected:
        print(f"ℹ️ state_dict load: missing={len(missing)}, unexpected={len(unexpected)}")

    net.to(DEVICE)
    net.eval()

    # Half precision only on the GPU path.
    if DEVICE.type == "cuda":
        net.half()

    return net

# Build the classifier once when the cell runs so every inference call reuses it.
RESNET_MODEL = build_resnet_1324()

# Shared preprocessing: resize to CROP_SIZE x CROP_SIZE, convert to a tensor, normalize per channel.
# NOTE(review): the mean/std values look like the usual ImageNet statistics — confirm they
# match the preprocessing used when the checkpoint was trained.
base_transform = transforms.Compose([
    transforms.Resize((CROP_SIZE, CROP_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
])

def preprocess_pil(pil_img):
    """Apply ``base_transform`` to an image and return the resulting tensor.

    The tensor is converted to half precision when ``DEVICE`` is CUDA, so its
    dtype matches the half-precision model built for that device.
    """
    tensor = base_transform(pil_img)
    return tensor.half() if DEVICE.type == "cuda" else tensor

@torch.no_grad()
def predict_resnet_batch_top1(pil_imgs):
    """Classify a batch of images and return the top-1 prediction for each.

    Args:
        pil_imgs: Sequence of images accepted by ``preprocess_pil``.

    Returns:
        list[dict]: One ``{"idx": int, "prob": float}`` entry per input image,
        in input order. An empty (or falsy) input yields an empty list without
        touching the model.
    """
    if not pil_imgs:
        return []

    batch = torch.stack([preprocess_pil(im) for im in pil_imgs]).to(DEVICE)

    # The forward pass is the same on every device; the precision difference is
    # already handled by preprocess_pil and the model itself.
    logits = RESNET_MODEL(batch)

    # Compute the softmax in float32 so probabilities are not rounded to fp16
    # when the model runs in half precision (no effect on the CPU path).
    probs = F.softmax(logits, dim=1, dtype=torch.float32)
    top1_prob, top1_idx = torch.topk(probs, 1, dim=1)

    return [
        {"idx": int(cls), "prob": float(prob)}
        for cls, prob in zip(top1_idx[:, 0].tolist(), top1_prob[:, 0].tolist())
    ]

# -------------------------------------------------
# 3) YOLO 감지 + 크롭
# -------------------------------------------------
# Device argument handed to the YOLO call: GPU index 0 when DEVICE is CUDA, otherwise "cpu".
YOLO_DEVICE = 0 if DEVICE.type == "cuda" else "cpu"
# Detector loaded once from the BEST_YOLO weights file.
YOLO_MODEL = YOLO(BEST_YOLO)

def detect_yolo_boxes(img_path):
    """Run the YOLO detector on one image file and return the usable boxes.

    Args:
        img_path: Path of the image file. It is passed to the detector and is
            also opened separately with PIL.

    Returns:
        tuple: ``(img, boxes, confs)`` where ``img`` is the image as an RGB PIL
        image, ``boxes`` is a list of ``[x1, y1, x2, y2]`` integer boxes clamped
        to the image, and ``confs`` is the matching list of detection
        confidences. Boxes whose width or height is below ``MIN_BOX_SIDE_PX``
        after clamping are dropped.
    """
    det = YOLO_MODEL(
        img_path,
        imgsz=YOLO_IMGSZ,
        conf=YOLO_CONF,
        iou=YOLO_IOU,
        device=YOLO_DEVICE,
        verbose=False
    )[0]

    # Close the file handle deterministically instead of leaving it to the
    # garbage collector; convert() returns a separate, fully loaded image.
    # NOTE(review): this read is independent of the detector's own loader; if the
    # two treat EXIF orientation differently the boxes would not line up — confirm.
    with Image.open(img_path) as src:
        img = src.convert("RGB")
    W, H = img.size

    boxes = []
    confs = []
    if det.boxes is not None and len(det.boxes) > 0:
        xyxy = det.boxes.xyxy.cpu().numpy().tolist()
        conf = det.boxes.conf.cpu().numpy().tolist()
        for (b, c) in zip(xyxy, conf):
            # Truncate to integers, then clamp to [0, W-1] x [0, H-1].
            x1, y1, x2, y2 = map(int, b)
            x1 = max(0, x1)
            y1 = max(0, y1)
            x2 = min(W-1, x2)
            y2 = min(H-1, y2)
            # Skip detections that are too small on either side.
            if (x2-x1) >= MIN_BOX_SIDE_PX and (y2-y1) >= MIN_BOX_SIDE_PX:
                boxes.append([x1, y1, x2, y2])
                confs.append(float(c))

    return img, boxes, confs

def square_crop_from_bbox(pil_img, xyxy, scale=1.3):
    """Crop an enlarged square window centred on a bounding box.

    The window side is the longer bbox side times ``scale``, but never less
    than ``MIN_BOX_SIDE_PX * 1.5``. The window is then clipped to the image,
    so the result can be non-square near the borders.

    Args:
        pil_img: Image providing ``.size`` and ``.crop``.
        xyxy: ``(x1, y1, x2, y2)`` of the box.
        scale: Enlargement factor applied to the longer side.

    Returns:
        The cropped image, or ``None`` when the clipped window is empty.
    """
    img_w, img_h = pil_img.size
    left, top, right, bottom = xyxy

    center_x = (left + right) / 2.0
    center_y = (top + bottom) / 2.0

    longest = max(right - left, bottom - top)
    side = max(longest * scale, MIN_BOX_SIDE_PX * 1.5)
    half = side / 2.0

    # Round to whole pixels, then clip each edge to the image.
    crop_left = max(0, int(round(center_x - half)))
    crop_top = max(0, int(round(center_y - half)))
    crop_right = min(img_w, int(round(center_x + half)))
    crop_bottom = min(img_h, int(round(center_y + half)))

    if crop_right <= crop_left or crop_bottom <= crop_top:
        return None

    return pil_img.crop((crop_left, crop_top, crop_right, crop_bottom))

# -------------------------------------------------
# 4) 다중 알약 Top-1 추론
# -------------------------------------------------
def infer_pill_image_multi_top1(img_path: str,
                                max_pills: int = MAX_PILLS_MULTI):
    """Detect pills in an image and return the top-1 class for each.

    Detections are ranked by confidence and at most ``max_pills`` of them are
    cropped and classified. When the detector returns no usable box, or no
    selected box produces a crop, the whole image is classified as pill 0 with
    ``bbox`` set to ``None``.

    Returns:
        A list of dicts, one per classified pill::

            {
              "pill_index": 0,
              "bbox": [x1, y1, x2, y2] or None,
              "idx": <class_idx>,
              "dl_idx": "<dl_idx>",
              "prob": <float>,
            }
    """
    img, boxes, confs = detect_yolo_boxes(img_path)

    # (bbox, crop) pairs for the detections that will be classified.
    candidates = []
    if boxes:
        # Highest-confidence detections first, capped at max_pills.
        by_conf = sorted(range(len(boxes)), key=lambda i: confs[i], reverse=True)
        for i in by_conf[:max_pills]:
            crop = square_crop_from_bbox(img, boxes[i], scale=SQUARE_SCALE)
            if crop is not None:
                candidates.append((boxes[i], crop))

    # Nothing to crop: treat the full image as a single pill.
    if not candidates:
        candidates.append((None, img.copy()))

    preds = predict_resnet_batch_top1([crop for _, crop in candidates])

    return [
        {
            "pill_index": pill_idx,
            "bbox": bbox,
            "idx": pred["idx"],
            "dl_idx": idx_to_dl_idx(pred["idx"]),
            "prob": pred["prob"],
        }
        for pill_idx, ((bbox, _), pred) in enumerate(zip(candidates, preds))
    ]

# -------------------------------------------------
# 5) 테스트 실행
# -------------------------------------------------
# Smoke test: classify the sample image and print one line per detected pill.
# The recorded output below shows this guard is true when the cell is run in the notebook.
if __name__ == "__main__":
    # Fail with a clear message when the sample image is missing.
    assert os.path.exists(IMG_PATH), f"이미지 없음: {IMG_PATH}"

    print("=== MULTI MODE (up to 4 pills) TOP-1 RESULTS ===")
    out = infer_pill_image_multi_top1(IMG_PATH, max_pills=4)
    for r in out:
        print(
            f"[Pill #{r['pill_index']}] bbox={r['bbox']} "
            f"class_idx={r['idx']}, dl_idx={r['dl_idx']}, prob={r['prob']*100:.2f}%"
        )
    # To hand the result to another API, `out` can be passed to json.dumps as-is.


Device: cpu
=== MULTI MODE (up to 4 pills) TOP-1 RESULTS ===
[Pill #0] bbox=None class_idx=1054, dl_idx=18356, prob=99.71%


In [2]:
# Colab-only: mount Google Drive at /content/drive. Run this cell before the
# inference cell, which reads its weights and images from /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
!pip install ultralytics

Collecting ultralytics
  Downloading ultralytics-8.3.228-py3-none-any.whl.metadata (37 kB)
Collecting ultralytics-thop>=2.0.18 (from ultralytics)
  Downloading ultralytics_thop-2.0.18-py3-none-any.whl.metadata (14 kB)
Downloading ultralytics-8.3.228-py3-none-any.whl (1.1 MB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.1/1.1 MB 51.6 MB/s eta 0:00:00
Downloading ultralytics_thop-2.0.18-py3-none-any.whl (28 kB)
Installing collected packages: ultralytics-thop, ultralytics
Successfully installed ultralytics-8.3.228 ultralytics-thop-2.0.18
