# HybridTwoWay Model


In [None]:
!pip install roboflow torch torchvision torchaudio opencv-python numpy tqdm pillow matplotlib albumentations timm

In [None]:
# ============================================
# Cell 0: Installation, basic imports, and Roboflow download
# ============================================

import math
import os
from typing import List, Tuple
import cv2
import numpy as np
import torch

import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from roboflow import Roboflow
from tqdm import tqdm
import timm # [NEW] library used for pretrained models
import albumentations as A # [NEW] augmentation library
import random

# [Optional] matmul precision (only applied in environments that support it).
try:
    torch.set_float32_matmul_precision('high')
except Exception:
    # Deliberate best effort: if the setter is unavailable or rejects the
    # value, keep the default precision instead of aborting the notebook.
    pass

# NOTE(security): an API key is committed in this file. Prefer supplying it via
# the ROBOFLOW_API_KEY environment variable (and consider rotating the committed
# key); the literal is kept only as a fallback so existing runs keep working.
rf = Roboflow(api_key=os.environ.get("ROBOFLOW_API_KEY", "HG9M6YJZpcCUgAQaKO9v"))
project = rf.workspace("arakon").project("detection-base-hqaeg")
version = project.version(6)
dataset = version.download("yolov8")

print(f'Roboflow dataset downloaded to: {dataset.location}')


In [None]:
# ============================================
# Cell 1: Shared utilities (IoU, NMS, etc.)
# ============================================

def box_iou_matrix(box1, box2):
    """
    Pairwise IoU between two sets of xyxy boxes.

    box1: (N, 4), box2: (M, 4)  -> (N, M)
    An all-zero (N, M) tensor is returned when either set is empty.
    """
    num_a, num_b = box1.size(0), box2.size(0)
    if num_a == 0 or num_b == 0:
        return torch.zeros(num_a, num_b, device=box1.device)

    # Broadcast every box in box1 against every box in box2.
    expanded = box1[:, None, :]                                   # (N, 1, 4)
    top_left = torch.max(expanded[..., :2], box2[:, :2])          # (N, M, 2)
    bottom_right = torch.min(expanded[..., 2:], box2[:, 2:])      # (N, M, 2)
    extent = (bottom_right - top_left).clamp(min=0)               # no overlap -> 0
    inter = extent[..., 0] * extent[..., 1]

    area1 = (box1[:, 2] - box1[:, 0]) * (box1[:, 3] - box1[:, 1])
    area2 = (box2[:, 2] - box2[:, 0]) * (box2[:, 3] - box2[:, 1])
    union = area1[:, None] + area2 - inter
    # The small constant keeps the division finite for zero-area pairs.
    return inter / (union + 1e-6)

def box_iou_pair(box1, box2):
    """
    Element-wise IoU, for the case where each anchor has exactly one matched GT.

    box1, box2: (N, 4) xyxy -> (N,) IoU
    """
    if box1.numel() == 0:
        return torch.zeros(0, device=box1.device)

    # Intersection rectangle of the i-th box in box1 with the i-th box in box2.
    inter_tl = torch.max(box1[:, :2], box2[:, :2])
    inter_br = torch.min(box1[:, 2:], box2[:, 2:])
    inter_wh = (inter_br - inter_tl).clamp(min=0)
    inter = inter_wh[:, 0] * inter_wh[:, 1]

    area1 = (box1[:, 2] - box1[:, 0]) * (box1[:, 3] - box1[:, 1])
    area2 = (box2[:, 2] - box2[:, 0]) * (box2[:, 3] - box2[:, 1])
    union = area1 + area2 - inter
    return inter / (union + 1e-6)

def nms(boxes, scores, iou_thres=0.5):
    """
    Greedy non-maximum suppression.

    boxes: (N, 4) xyxy, scores: (N,). Returns a long tensor with the indices of
    the kept boxes, ordered by descending score. A candidate is dropped when
    its IoU with an already kept box is >= iou_thres.
    """
    if boxes.numel() == 0:
        return torch.zeros(0, dtype=torch.long, device=boxes.device)

    order = scores.argsort(descending=True)
    selected = []
    while order.numel() > 0:
        best, rest = order[0], order[1:]
        selected.append(best.item())
        if rest.numel() == 0:
            break
        # IoU of the current best box against all remaining candidates.
        overlap = box_iou_matrix(boxes[best].unsqueeze(0), boxes[rest]).squeeze(0)
        order = rest[overlap < iou_thres]
    return torch.tensor(selected, dtype=torch.long, device=boxes.device)

In [None]:
# ============================================
# Cell 2: Conv Block & Stem
# ============================================

def conv_bn_act(in_ch, out_ch, k=3, s=1, p=1, act=True):
    """Build Conv2d (no bias) -> BatchNorm2d, optionally followed by in-place SiLU.

    k, s, p are the kernel size, stride and padding of the convolution.
    """
    layers = [
        nn.Conv2d(in_ch, out_ch, k, s, p, bias=False),
        nn.BatchNorm2d(out_ch),
    ]
    if act:
        layers += [nn.SiLU(inplace=True)]
    return nn.Sequential(*layers)

class FixedGaussianBlur(nn.Module):
    """Depthwise Gaussian blur with a fixed, non-trainable kernel.

    The normalized k x k kernel is stored as a buffer (one copy per channel)
    and applied with a grouped convolution, so channels are blurred
    independently. The input is reflect-padded so the output keeps the input's
    spatial size.

    Args:
        channels: number of input/output channels.
        k: kernel side length.
        sigma: standard deviation of the Gaussian, in pixels.
    """

    def __init__(self, channels, k=5, sigma=1.0):
        super().__init__()
        # Sample positions centred on zero, e.g. [-2, -1, 0, 1, 2] for k=5.
        grid = torch.arange(k).float() - (k - 1) / 2
        gauss = torch.exp(-(grid ** 2) / (2 * sigma ** 2))
        kernel1d = gauss / gauss.sum()
        kernel2d = torch.outer(kernel1d, kernel1d)
        weight = kernel2d[None, None, :, :].repeat(channels, 1, 1, 1)
        self.register_buffer('weight', weight)
        self.groups = channels
        self.k = k

    def forward(self, x):
        # Total padding per axis must be k - 1 to preserve the spatial size.
        # For odd k this is k // 2 on both sides (unchanged behaviour); for
        # even k a symmetric k // 2 would grow the output by one pixel, so the
        # extra pixel is put on the right/bottom side only.
        before = (self.k - 1) // 2
        after = self.k // 2
        pad = (before, after, before, after)
        return F.conv2d(F.pad(x, pad, mode='reflect'),
                        self.weight, groups=self.groups)

class AnomalyAwareStem(nn.Module):
    """Convolutional stem with an additional high-frequency branch.

    Main branch: three stride-2 conv blocks (base_ch -> 2*base_ch -> 4*base_ch).
    Side branch: the input minus its Gaussian blur is resized to the main
    branch resolution and projected to (4*base_ch)//4 channels. Both branches
    are concatenated and fused back to 4*base_ch channels.

    forward(x) returns (f, v): the fused feature map and a 1-channel sigmoid map
    computed from the main-branch features.
    """

    def __init__(self, in_ch=3, base_ch=48):
        super().__init__()
        widths = (base_ch, base_ch * 2, base_ch * 4)
        out_ch = widths[-1]
        side_ch = out_ch // 4

        # Chain of stride-2 conv blocks, each consuming the previous width.
        stages, prev = [], in_ch
        for width in widths:
            stages.append(conv_bn_act(prev, width, 3, 2, 1))
            prev = width
        self.stem = nn.Sequential(*stages)

        self.blur = FixedGaussianBlur(in_ch, k=5, sigma=1.0)
        self.anom = nn.Sequential(
            # Depthwise 3x3 on the high-frequency residual, then 1x1 projection.
            nn.Conv2d(in_ch, in_ch, 3, 1, 1, groups=in_ch, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_ch, side_ch, 1, 1, 0, bias=False),
            nn.BatchNorm2d(side_ch),
            nn.SiLU(inplace=True),
        )
        self.fuse = nn.Conv2d(out_ch + side_ch, out_ch, 1, 1, 0, bias=False)
        self.fuse_bn = nn.BatchNorm2d(out_ch)
        self.vis_head = nn.Conv2d(out_ch, 1, 1, 1, 0)
        self.base_ch = base_ch

    @property
    def out_channels(self):
        """Channel count of the fused feature map (4 * base_ch)."""
        return 4 * self.base_ch

    def forward(self, x):
        main = self.stem(x)

        # High-frequency residual: what the Gaussian blur removed from the input.
        residual = x - self.blur(x)
        residual = F.interpolate(residual, size=main.shape[-2:],
                                 mode='bilinear', align_corners=False)
        side = self.anom(residual)

        fused = self.fuse(torch.cat([main, side], dim=1))
        fused = F.silu(self.fuse_bn(fused), inplace=True)

        # The auxiliary map is predicted from the main branch only.
        vis_map = torch.sigmoid(self.vis_head(main))
        return fused, vis_map

In [None]:
# ============================================
# Cell 3 : Feedback Adapter
# ============================================

class FeedbackAdapter(nn.Module):
    """Modulate a stem feature map with a scale/shift predicted from tokens.

    A 1x1 conv (+ optional BatchNorm) + SiLU maps the token grid to 2 * c_stem
    channels, which are split into gamma and beta. The result is
    ``f_stem * (1 + tanh(gamma)) + beta``.
    """

    def __init__(self, d_token: int, c_stem: int, use_bn: bool = True):
        super().__init__()
        out_ch = c_stem * 2
        # The conv bias is redundant when BatchNorm follows, so it is dropped.
        modules = [nn.Conv2d(d_token, out_ch, 1, 1, 0, bias=not use_bn)]
        if use_bn:
            modules.append(nn.BatchNorm2d(out_ch))
        modules.append(nn.SiLU(inplace=True))
        self.adapter = nn.Sequential(*modules)

    def forward(self, tokens, Ht, Wt, f_stem):
        """tokens: (B, N, D) with N == Ht * Wt; f_stem: (B, Cs, Ht, Wt)."""
        batch, _, dim = tokens.shape
        # (B, N, D) -> (B, D, Ht, Wt) so the tokens can go through a 2D conv.
        token_map = tokens.transpose(1, 2).reshape(batch, dim, Ht, Wt)
        gamma, beta = torch.split(self.adapter(token_map), f_stem.shape[1], dim=1)
        scale = 1 + torch.tanh(gamma)
        return f_stem * scale + beta

In [None]:
# ============================================
# Cell 4 : PANLite Neck
# ============================================

class PANLite(nn.Module):
    """Small PAN-style neck that builds three scales from one input map.

    The input is projected to ``mid`` channels (P3), downsampled twice with
    stride-2 convs (P4, P5), refined top-down (P5 -> P4 -> P3) and then
    bottom-up (P3 -> P4 -> P5). Returns (p3, p4, p5), each with ``mid`` channels.
    """

    def __init__(self, in_ch=512, mid=256):
        super().__init__()

        def halve():
            # 3x3 stride-2 block that keeps the channel count.
            return conv_bn_act(mid, mid, 3, 2, 1)

        def merge():
            # 3x3 block that fuses two concatenated `mid`-channel maps.
            return conv_bn_act(mid + mid, mid, 3, 1, 1)

        self.lateral = conv_bn_act(in_ch, mid, 1, 1, 0)
        self.down4 = halve()
        self.down5 = halve()
        self.up4 = merge()
        self.up3 = merge()
        self.down_f4 = halve()
        self.fuse4 = merge()
        self.down_f5 = halve()
        self.fuse5 = merge()

    def forward(self, p3):
        # Build the initial pyramid.
        c3 = self.lateral(p3)
        c4 = self.down4(c3)
        c5 = self.down5(c4)

        # Top-down pass: upsample the coarser level and merge it into the finer one.
        from_p5 = F.interpolate(c5, size=c4.shape[-2:], mode='nearest')
        t4 = self.up4(torch.cat([c4, from_p5], dim=1))

        from_p4 = F.interpolate(t4, size=c3.shape[-2:], mode='nearest')
        out3 = self.up3(torch.cat([c3, from_p4], dim=1))

        # Bottom-up pass: downsample the finer level and merge it into the coarser one.
        out4 = self.fuse4(torch.cat([t4, self.down_f4(out3)], dim=1))
        out5 = self.fuse5(torch.cat([c5, self.down_f5(out4)], dim=1))

        return out3, out4, out5

In [None]:
# ============================================
# Cell 5 : YOLOLite Head
# ============================================

class YOLOHeadLite(nn.Module):
    """Per-level detection head for three feature maps (P3, P4, P5).

    Each level has its own 3x3 conv block followed by three 1x1 convs that
    produce class logits (num_classes), an objectness logit (1) and box
    outputs (4). The objectness bias of every level is initialised to -4.59.
    """

    _LEVELS = (3, 4, 5)

    def __init__(self, in_ch=256, num_classes=1):
        super().__init__()
        width = in_ch

        # Shared-width 3x3 blocks first, so they are registered before the 1x1 heads.
        for level in self._LEVELS:
            setattr(self, f"stem{level}", conv_bn_act(width, width, 3, 1, 1))

        for level in self._LEVELS:
            setattr(self, f"cls{level}", nn.Conv2d(width, num_classes, 1, 1, 0))
            obj_conv = nn.Conv2d(width, 1, 1, 1, 0)
            setattr(self, f"obj{level}", obj_conv)
            setattr(self, f"box{level}", nn.Conv2d(width, 4, 1, 1, 0))
            # sigmoid(-4.59) is about 0.01, i.e. a low initial objectness.
            nn.init.constant_(obj_conv.bias, -4.59)

    def forward_single(self, x, stem, cls, obj, box):
        """Run one level: returns (cls_logits, obj_logits, box_outputs)."""
        feat = stem(x)
        return cls(feat), obj(feat), box(feat)

    def forward(self, p3, p4, p5):
        """Return [(cls, obj, box)] for P3, P4, P5 in that order."""
        outputs = []
        for level, feat in zip(self._LEVELS, (p3, p4, p5)):
            outputs.append(
                self.forward_single(
                    feat,
                    getattr(self, f"stem{level}"),
                    getattr(self, f"cls{level}"),
                    getattr(self, f"obj{level}"),
                    getattr(self, f"box{level}"),
                )
            )
        return outputs


In [None]:
# ============================================
# Cell 6 : HybridTwoWay Model (timm ViT + Feedback)
# ============================================

class HybridTwoWay(nn.Module):
    """Detector combining a conv stem, pretrained timm ViT blocks and a feedback path.

    Pipeline per forward pass:
      1. ``AnomalyAwareStem`` produces a feature map and an auxiliary map.
      2. The stem features are projected to the ViT width, flattened to tokens,
         and the (interpolated) ViT positional embedding is added.
      3. The tokens go through ``vit.blocks`` (and ``vit.norm`` when present).
      4. ``FeedbackAdapter`` modulates the stem features with the ViT tokens.
      5. The modulated features go through ``to_vit`` -> ``from_vit`` ->
         ``PANLite`` -> ``YOLOHeadLite``.
    Steps 3-5 are repeated ``iters`` times, re-tokenising the modulated features.

    Args:
        in_ch: number of image channels.
        stem_base: base width of the stem (its output has 4 * stem_base channels).
        embed_dim: not used by this class; the ViT width is read from the timm
            model (``vit.num_features``). Kept for interface compatibility.
        vit_model_name: timm model name, created with ``pretrained=True``.
        num_classes: number of detection classes.
        iters: number of ViT/feedback iterations.
        detach_feedback: detach the ViT tokens before they enter the adapter.
        img_size: stored on the module; not otherwise used here.

    NOTE(review): with ``detach_feedback=True`` the ViT output only reaches the
    heads through a detached tensor, so the ViT blocks appear to receive no
    gradient from the detection loss -- confirm this is intended.
    """

    def __init__(self,
                 in_ch=3,
                 stem_base=32,
                 embed_dim=768,
                 vit_model_name='vit_base_patch16_224.augreg_in21k_ft_in1k',
                 num_classes=3,
                 iters=1,
                 detach_feedback=True,
                 img_size=512):
        super().__init__()
        self.iters = iters
        self.detach_feedback = detach_feedback
        self.img_size = img_size

        # Stem
        self.stem = AnomalyAwareStem(in_ch=in_ch, base_ch=stem_base)
        c_stem = stem_base * 4  # e.g. 64 * 4 = 256

        # Pretrained ViT (classifier head removed via num_classes=0)
        print(f"üîÑ Loading Pretrained Weights: {vit_model_name}...")
        self.vit = timm.create_model(vit_model_name, pretrained=True, num_classes=0)
        self.vit_dim = self.vit.num_features

        # Stem -> ViT width
        self.to_vit = nn.Sequential(
            nn.Conv2d(c_stem, self.vit_dim, 1),
            nn.BatchNorm2d(self.vit_dim),
            nn.SiLU()
        )

        # ViT width -> Neck (256)
        self.from_vit = nn.Sequential(
            nn.Conv2d(self.vit_dim, 256, 1),
            nn.BatchNorm2d(256),
            nn.SiLU()
        )

        # Feedback Adapter
        self.feedback = FeedbackAdapter(d_token=self.vit_dim, c_stem=c_stem, use_bn=True)

        # Neck & Head
        self.neck = PANLite(in_ch=256, mid=256)
        self.head = YOLOHeadLite(in_ch=256, num_classes=num_classes)

        # Flag so models without a positional embedding are handled safely.
        self.has_pos_embed = hasattr(self.vit, "pos_embed") and self.vit.pos_embed is not None

    def _source_grid(self, n_src):
        """Return the (height, width) of the grid the ViT's patch positions were laid out on."""
        # Prefer the grid recorded by the patch embedding, when it matches.
        grid = getattr(getattr(self.vit, "patch_embed", None), "grid_size", None)
        if grid is not None and len(grid) == 2 and grid[0] * grid[1] == n_src:
            return int(grid[0]), int(grid[1])
        # Otherwise fall back to a square grid, and fail loudly if that is wrong.
        side = math.isqrt(n_src)
        if side * side != n_src:
            raise ValueError(
                f"Cannot infer a 2D grid for {n_src} positional embeddings"
            )
        return side, side

    def get_interpolated_pos_embed(self, Ht, Wt):
        """
        Interpolate the timm ViT's pos_embed to the current token grid (Ht, Wt).

        Returns a (1, Ht*Wt, D) tensor, or 0 when the ViT has no pos_embed.
        """
        if not self.has_pos_embed:
            return 0

        pos_embed = self.vit.pos_embed  # (1, prefix+N, D) or (1, N, D)
        # Models built with no_embed_class=True keep positions for patch tokens
        # only, so there are no prefix slots to strip from pos_embed.
        if getattr(self.vit, "no_embed_class", False):
            num_prefix = 0
        else:
            num_prefix = getattr(self.vit, "num_prefix_tokens", 1)

        src_pos = pos_embed[:, num_prefix:]  # (1, N, D)
        n_src = src_pos.shape[1]
        if n_src == Ht * Wt:
            # Token count already matches: use the embedding as is.
            return src_pos

        src_h, src_w = self._source_grid(n_src)
        src_map = src_pos.reshape(1, src_h, src_w, -1).permute(0, 3, 1, 2)
        dst = F.interpolate(src_map, size=(Ht, Wt), mode='bicubic', align_corners=False)
        return dst.flatten(2).transpose(1, 2)  # (1, Ht*Wt, D)

    def _to_tokens(self, feat):
        """Project a stem-domain map to ViT tokens with positional embedding added.

        Returns (tokens, Ht, Wt) with tokens of shape (B, Ht*Wt, D).
        """
        vit_in = self.to_vit(feat)  # (B, D, Ht, Wt)
        Ht, Wt = vit_in.shape[-2:]
        tokens = vit_in.flatten(2).transpose(1, 2)  # (B, N, D), N = Ht*Wt
        pos_embed = self.get_interpolated_pos_embed(Ht, Wt)  # (1, N, D) or 0
        if isinstance(pos_embed, torch.Tensor):
            tokens = tokens + pos_embed.to(tokens.device)  # broadcast over batch
        return tokens, Ht, Wt

    def forward(self, x):
        """Return (preds, aux): head outputs per level and {"P3","P4","P5","V"}."""
        f_stem, vis = self.stem(x)  # (B, Cs, Hs, Ws)
        tokens, Ht, Wt = self._to_tokens(f_stem)

        f_fb = f_stem
        preds, aux = None, None
        norm = getattr(self.vit, "norm", None)

        for i in range(self.iters):
            # Run the ViT blocks directly (the patch embedding is bypassed).
            t = tokens
            for blk in self.vit.blocks:
                t = blk(t)
            if norm is not None:
                t = norm(t)

            # Feedback: modulate the stem-domain features with the ViT tokens.
            toks_for_fb = t.detach() if self.detach_feedback else t
            f_fb = self.feedback(toks_for_fb, Ht, Wt, f_fb)

            # Feature fed to the neck (stem domain -> ViT dim -> neck dim).
            p3_in = self.from_vit(self.to_vit(f_fb))  # (B, 256, Ht, Wt)
            p3, p4, p5 = self.neck(p3_in)
            preds = self.head(p3, p4, p5)
            aux = {"P3": p3, "P4": p4, "P5": p5, "V": vis}

            # Re-tokenise the modulated features when another iteration follows.
            if i != self.iters - 1:
                tokens, Ht, Wt = self._to_tokens(f_fb)

        return preds, aux

In [None]:
# ============================================
# Cell 7 : Dataset & Dataloader
# ============================================

# Side length, in pixels, of the square images produced by YoloDataset.
IMG_SIZE = 512

def yolo_collate_fn(batch):
    """Collate (image, target) samples.

    Images are stacked into one (B, ...) tensor; targets stay a Python list
    because each sample may have a different number of boxes.
    """
    images = [image for image, _ in batch]
    targets = [target for _, target in batch]
    return torch.stack(images, 0), targets

class YoloDataset(Dataset):
    """YOLO-format detection dataset read from ``<root>/images`` and ``<root>/labels``.

    Each item is ``(img, targets)``:
      * ``img``: float32 tensor (3, img_size, img_size), RGB, scaled to [0, 1].
      * ``targets``: float32 tensor (N, 5) with rows ``[cls, cx, cy, w, h]``
        normalized by ``img_size``; shape (0, 5) when there are no boxes.

    Args:
        root: directory containing the ``images`` and ``labels`` sub-directories.
        is_train: enables mosaic augmentation.
        mosaic_prob: probability of building a 2x2 mosaic (forced to 0 when
            ``is_train`` is False).
    """

    def __init__(self, root, is_train=True, mosaic_prob=0.5):
        self.img_dir = os.path.join(root, "images")
        self.label_dir = os.path.join(root, "labels")
        self.images = sorted(os.listdir(self.img_dir))
        self.is_train = is_train
        self.mosaic_prob = mosaic_prob if is_train else 0.0
        self.img_size = IMG_SIZE

    def __len__(self):
        return len(self.images)

    @staticmethod
    def _parse_label_file(label_path, w, h):
        """Read one YOLO label file into a list of [cls, x1, y1, x2, y2] in pixels.

        A missing file yields an empty list. Blank lines are skipped; any other
        line must contain exactly five numbers (cls cx cy w h, normalized).
        """
        boxes = []
        if not os.path.exists(label_path):
            return boxes
        with open(label_path, "r") as f:
            for line in f:
                fields = line.split()
                if not fields:
                    # Blank lines are common in label files and carry no box.
                    continue
                cls, cx, cy, bw, bh = map(float, fields)
                x1 = (cx - bw/2) * w
                y1 = (cy - bh/2) * h
                x2 = (cx + bw/2) * w
                y2 = (cy + bh/2) * h
                boxes.append([cls, x1, y1, x2, y2])
        return boxes

    def load_image_and_boxes(self, index):
        """Load at original size; return RGB image, (N, 5) xyxy (absolute) boxes, (h, w).

        Raises:
            FileNotFoundError: when OpenCV cannot read the image file.
        """
        name = self.images[index]
        img_path = os.path.join(self.img_dir, name)
        img = cv2.imread(img_path)
        if img is None:
            raise FileNotFoundError(f"Image not found: {img_path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        h, w = img.shape[:2]

        label_path = os.path.join(self.label_dir, name.rsplit(".", 1)[0] + ".txt")
        boxes = self._parse_label_file(label_path, w, h)
        boxes = np.array(boxes) if len(boxes) > 0 else np.zeros((0, 5))
        return img, boxes, (h, w)

    def load_mosaic(self, index):
        """
        [Safe version]
        Final image: (img_size, img_size)
        Each tile: (tile, tile) = (img_size/2, img_size/2)
        Four images are always placed on a fixed 2x2 grid, so slice sizes
        always match the resized tiles.

        Returns (mosaic_img, mosaic_boxes) with boxes as (N, 5) xyxy in mosaic pixels.
        """
        tile = self.img_size // 2
        mosaic_img = np.full((self.img_size, self.img_size, 3), 114, dtype=np.uint8)
        mosaic_boxes = []

        # The requested image plus three randomly drawn ones.
        indices = [index] + [random.randint(0, len(self.images) - 1) for _ in range(3)]

        for i, idx in enumerate(indices):
            img, boxes, (h0, w0) = self.load_image_and_boxes(idx)

            # Original -> tile size
            img_resized = cv2.resize(img, (tile, tile))
            scale_x = tile / w0
            scale_y = tile / h0

            # Tile position on the 2x2 grid
            row = i // 2  # 0 or 1
            col = i % 2   # 0 or 1
            x_off = col * tile
            y_off = row * tile

            # Paste the tile into the mosaic image
            mosaic_img[y_off:y_off+tile, x_off:x_off+tile] = img_resized

            # Scale the boxes and shift them by the tile offset
            if len(boxes) > 0:
                b = boxes.copy()
                b[:, 1] = b[:, 1] * scale_x + x_off
                b[:, 3] = b[:, 3] * scale_x + x_off
                b[:, 2] = b[:, 2] * scale_y + y_off
                b[:, 4] = b[:, 4] * scale_y + y_off
                mosaic_boxes.append(b)

        if len(mosaic_boxes) > 0:
            mosaic_boxes = np.concatenate(mosaic_boxes, 0)
            np.clip(mosaic_boxes[:, 1:], 0, self.img_size, out=mosaic_boxes[:, 1:])
        else:
            mosaic_boxes = np.zeros((0, 5))

        return mosaic_img, mosaic_boxes

    def __getitem__(self, idx):
        # 1) Load image/boxes (mosaic or plain resize)
        if self.is_train and random.random() < self.mosaic_prob:
            img, boxes_xyxy = self.load_mosaic(idx)
        else:
            img, boxes_xyxy, (h0, w0) = self.load_image_and_boxes(idx)
            # Original -> (img_size, img_size)
            img = cv2.resize(img, (self.img_size, self.img_size))
            if len(boxes_xyxy) > 0:
                boxes_xyxy = boxes_xyxy.copy()
                scale_x = self.img_size / w0
                scale_y = self.img_size / h0
                boxes_xyxy[:, 1] *= scale_x
                boxes_xyxy[:, 3] *= scale_x
                boxes_xyxy[:, 2] *= scale_y
                boxes_xyxy[:, 4] *= scale_y
                np.clip(boxes_xyxy[:, 1:], 0, self.img_size, out=boxes_xyxy[:, 1:])
            else:
                boxes_xyxy = np.zeros((0, 5))

        # 2) Normalize the image -> Tensor
        img = img.astype(np.float32) / 255.0
        img = torch.from_numpy(img).permute(2, 0, 1)  # (3,H,W)

        # 3) xyxy -> cxcywh (normalized), dropping boxes that clipping made degenerate
        if len(boxes_xyxy) > 0:
            valid = np.logical_and(
                boxes_xyxy[:, 3] > boxes_xyxy[:, 1],
                boxes_xyxy[:, 4] > boxes_xyxy[:, 2],
            )
            kept = boxes_xyxy[valid]
            x1, y1, x2, y2 = kept[:, 1], kept[:, 2], kept[:, 3], kept[:, 4]
            cx = (x1 + x2) / 2 / self.img_size
            cy = (y1 + y2) / 2 / self.img_size
            w = (x2 - x1) / self.img_size
            h = (y2 - y1) / self.img_size
            targets = np.stack([kept[:, 0], cx, cy, w, h], axis=1)
        else:
            # Keep the (N, 5) layout even when there is nothing to detect.
            targets = np.zeros((0, 5))

        targets = torch.as_tensor(targets, dtype=torch.float32)
        return img, targets


In [None]:
# ============================================
# Cell 8 : Task Aligned Assigner & Loss
# ============================================

def select_candidates_in_gts(xy_centers, gt_bboxes, eps=1e-9):
    """Mark which anchor centres lie strictly inside which GT boxes.

    xy_centers: (A, 2) anchor centres; gt_bboxes: (B, M, 4) xyxy.
    Returns a boolean (B, M, A) mask that is True where the smallest distance
    from the centre to the four box sides is greater than eps.
    """
    num_anchors = xy_centers.shape[0]
    batch, num_gt, _ = gt_bboxes.shape

    flat = gt_bboxes.view(-1, 1, 4)                    # (B*M, 1, 4)
    top_left, bottom_right = flat[..., :2], flat[..., 2:]
    centers = xy_centers[None]                         # (1, A, 2)

    # Signed distances to the left/top and right/bottom sides.
    margins = torch.cat((centers - top_left, bottom_right - centers), dim=2)
    margins = margins.view(batch, num_gt, num_anchors, -1)
    return margins.min(dim=3).values > eps

class TaskAlignedAssigner(nn.Module):
    """Task-aligned label assignment.

    For every GT box, anchors are ranked by ``score**alpha * IoU**beta`` and the
    top-k are selected as positives. Only anchors whose centre lies inside the
    GT box, and only real (non-padded) GT rows, can produce positives.

    Args:
        topk: number of anchors selected per GT box.
        alpha: exponent applied to the classification score.
        beta: exponent applied to the IoU.
        eps: stored on the module; not used by the visible methods.
    """

    def __init__(self, topk=13, alpha=1.0, beta=6.0, eps=1e-9):
        super().__init__()
        self.topk = topk
        self.alpha = alpha
        self.beta = beta
        self.eps = eps

    @torch.no_grad()
    def forward(self, pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt):
        """
        pd_scores: (B, A, C) - cls scores after sigmoid
        pd_bboxes: (B, A, 4) - xyxy
        anc_points: (A, 2)
        gt_labels: (B, M, 1)
        gt_bboxes: (B, M, 4) - xyxy
        mask_gt: (B, M) - non-zero for real GT rows, zero for padding

        Returns (target_labels (B, A), target_bboxes (B, A, 4), None,
        fg_mask (B, A) bool, target_gt_idx (B, A)). Entries for anchors where
        fg_mask is False are placeholders taken from GT index 0.
        """
        mask_pos, align_metric, _ = self.get_pos_mask(
            pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points, mask_gt
        )

        # top-k, restricted to valid GT rows and to anchors inside the GT box
        target_gt_idx, fg_mask, mask_pos = self.select_topk_candidates(
            align_metric * mask_pos,
            topk_mask=mask_gt[..., None].expand(-1, -1, self.topk).bool(),
            candidate_mask=mask_pos,
        )

        # Look up the label and box of the GT assigned to each anchor.
        labels = gt_labels.long().squeeze(-1)  # (B, M)
        target_labels = torch.gather(labels, 1, target_gt_idx)  # (B, A)
        box_idx = target_gt_idx[..., None].expand(-1, -1, gt_bboxes.shape[-1])
        target_bboxes = torch.gather(gt_bboxes, 1, box_idx)  # (B, A, 4)

        # The alignment metric is not returned as a score target (slot kept as None).
        return target_labels, target_bboxes, None, fg_mask, target_gt_idx

    def get_pos_mask(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points, mask_gt):
        """Return (mask_pos, align_metric, overlaps), each of shape (B, M, A).

        mask_pos is True where the anchor centre is inside the GT box and the
        GT row is real; align_metric is ``score**alpha * IoU**beta``.
        """
        mask_in_gts = select_candidates_in_gts(anc_points, gt_bboxes)  # (B, M, A)

        overlaps = torch.stack(
            [box_iou_matrix(gts, preds) for gts, preds in zip(gt_bboxes, pd_bboxes)],
            dim=0,
        )  # (B, M, A)

        _, num_anchors, num_classes = pd_scores.shape
        # Clamp labels so an out-of-range class id cannot index outside pd_scores.
        labels = gt_labels.long().squeeze(-1).clamp(min=0, max=num_classes - 1)  # (B, M)
        # bbox_scores[b, m, a] = pd_scores[b, a, labels[b, m]]
        score_idx = labels[..., None].expand(-1, -1, num_anchors)
        bbox_scores = torch.gather(pd_scores.transpose(1, 2), 1, score_idx)  # (B, M, A)

        align_metric = bbox_scores.pow(self.alpha) * overlaps.pow(self.beta)

        mask_pos = mask_in_gts & mask_gt[..., None].bool()
        return mask_pos, align_metric, overlaps

    def select_topk_candidates(self, metrics, topk_mask=None, candidate_mask=None):
        """Pick the top-k anchors per GT and resolve anchors claimed by several GTs.

        Args:
            metrics: (B, M, A) ranking metric.
            topk_mask: optional (B, M, topk) boolean; a GT row is treated as
                valid when any of its entries is True. Invalid rows select nothing.
            candidate_mask: optional (B, M, A) boolean; anchors outside the mask
                are never selected for that GT.

        Returns:
            target_gt_idx (B, A): index of the GT with the highest metric among
                those that selected the anchor (0 when none did).
            fg_mask (B, A): True where at least one GT selected the anchor.
            mask_pos (B, M, A): the selection mask.
        """
        B, M, A = metrics.shape
        k = min(self.topk, A)

        allowed = None
        if topk_mask is not None:
            allowed = topk_mask.bool().any(dim=-1, keepdim=True)  # (B, M, 1)
        if candidate_mask is not None:
            candidate_mask = candidate_mask.bool()
            allowed = candidate_mask if allowed is None else allowed & candidate_mask

        # Rank disallowed anchors last so that ties at zero cannot push an
        # anchor outside the box (or a padded GT row) into the top-k.
        ranked = metrics
        if allowed is not None:
            ranked = metrics.masked_fill(~allowed, float("-inf"))
        _, topk_idxs = torch.topk(ranked, k=k, dim=-1, largest=True)

        mask_pos = torch.zeros_like(metrics, dtype=torch.bool)
        mask_pos.scatter_(-1, topk_idxs, True)
        if allowed is not None:
            mask_pos &= allowed

        fg_mask = mask_pos.sum(-2) > 0  # (B, A)

        # For each anchor, the GT that selected it with the highest metric.
        metrics_pos = metrics.masked_fill(~mask_pos, float("-inf"))
        target_gt_idx = metrics_pos.argmax(dim=1)  # (B, A)

        return target_gt_idx, fg_mask, mask_pos

class ComputeLoss(nn.Module):
    """Detection loss: BCE on class logits plus (1 - IoU) on assigned boxes.

    Positives are chosen by ``TaskAlignedAssigner(topk=10)``. Both terms are
    divided by the number of positive anchors (at least 1) and combined as
    ``loss_cls + 4.0 * loss_box``. The objectness output of the head is not
    used here.
    """

    def __init__(self, num_classes=3, img_size=512):
        super().__init__()
        self.assigner = TaskAlignedAssigner(topk=10)
        self.bce = nn.BCEWithLogitsLoss(reduction='none')
        self.num_classes = num_classes
        self.img_size = img_size

    def forward(self, preds, batch_targets, img_size):
        """
        preds: [(cls, obj, box), ...], one tuple per pyramid level, each (B, *, H, W)
        batch_targets: list of (N, 5) tensors with rows [cls, cx, cy, w, h] (normalized)
        img_size: image side length in pixels used to scale boxes and anchors

        Returns a scalar loss tensor (zero, but attached to the graph, when the
        batch contains no target at all).
        """
        batch_size = preds[0][0].shape[0]

        level_logits, level_boxes, level_points = [], [], []
        for cls_map, _obj_map, box_map in preds:
            n, num_cls, height, width = cls_map.shape
            stride = img_size // height

            # Anchor points are the cell centres of this level, in pixels.
            rows, cols = torch.meshgrid(
                torch.arange(height, device=cls_map.device),
                torch.arange(width, device=cls_map.device),
                indexing='ij'
            )
            centers = torch.stack((cols, rows), 2).float()
            centers = (centers + 0.5) * stride
            level_points.append(centers.view(-1, 2))

            level_logits.append(cls_map.permute(0, 2, 3, 1).reshape(n, -1, num_cls))

            # Box outputs are sigmoid-normalized cxcywh relative to the image.
            norm = box_map.permute(0, 2, 3, 1).reshape(n, -1, 4).sigmoid()
            cx = norm[..., 0] * img_size
            cy = norm[..., 1] * img_size
            w = norm[..., 2] * img_size
            h = norm[..., 3] * img_size
            level_boxes.append(
                torch.stack([cx - w/2, cy - h/2, cx + w/2, cy + h/2], -1)
            )

        cls_preds = torch.cat(level_logits, dim=1)   # (B, A_tot, C)
        box_preds = torch.cat(level_boxes, dim=1)    # (B, A_tot, 4)
        anchors = torch.cat(level_points, dim=0)     # (A_tot, 2)
        device = cls_preds.device

        max_boxes = max(len(t) for t in batch_targets)
        if max_boxes == 0:
            return cls_preds.sum() * 0.0

        # Pad the per-image targets to a dense (B, max_boxes, ...) layout.
        batch_gt_labels = torch.zeros(batch_size, max_boxes, 1, device=device)
        batch_gt_bboxes = torch.zeros(batch_size, max_boxes, 4, device=device)
        batch_mask_gt = torch.zeros(batch_size, max_boxes, device=device)

        for b, tgt in enumerate(batch_targets):
            count = len(tgt)
            if count == 0:
                continue
            batch_gt_labels[b, :count, 0] = tgt[:, 0]
            cx, cy, w, h = tgt[:, 1], tgt[:, 2], tgt[:, 3], tgt[:, 4]
            x1 = (cx - w/2) * img_size
            y1 = (cy - h/2) * img_size
            x2 = (cx + w/2) * img_size
            y2 = (cy + h/2) * img_size
            batch_gt_bboxes[b, :count] = torch.stack([x1, y1, x2, y2], dim=-1)
            batch_mask_gt[b, :count] = 1.0

        target_labels, target_bboxes, _, fg_mask, _ = self.assigner(
            cls_preds.sigmoid(), box_preds, anchors,
            batch_gt_labels, batch_gt_bboxes, batch_mask_gt
        )

        num_pos = fg_mask.sum().clamp(min=1.0)

        # Classification target: one-hot at the assigned class for positive anchors.
        target_cls = torch.zeros_like(cls_preds)
        img_idx, anchor_idx = torch.nonzero(fg_mask, as_tuple=True)
        cls_idx = target_labels[img_idx, anchor_idx].long().clamp(
            min=0, max=self.num_classes - 1
        )
        target_cls[img_idx, anchor_idx, cls_idx] = 1.0

        loss_cls = self.bce(cls_preds, target_cls).sum() / num_pos

        # Box loss (element-wise IoU between each positive and its assigned GT).
        loss_box = 0.0
        if fg_mask.sum() > 0:
            ious = box_iou_pair(box_preds[fg_mask], target_bboxes[fg_mask])
            loss_box = (1.0 - ious).sum() / num_pos

        return loss_cls + 4.0 * loss_box

# Shared loss instance used by yolo_loss_tal (3 classes, IMG_SIZE-pixel images).
criterion = ComputeLoss(num_classes=3, img_size=IMG_SIZE)

def yolo_loss_tal(preds, targets, img_size):
    """Compute the detection loss with the module-level ``criterion``.

    preds: head outputs; targets: list of per-image target tensors;
    img_size: image side length in pixels, forwarded to the criterion.
    """
    loss = criterion(preds, targets, img_size)
    return loss


In [None]:
# ============================================
# Cell 9 : Decode Predictions + mAP Evaluation
# ============================================

def decode_predictions(preds, img_size=512, conf_thres=0.25, nms_iou_thres=0.5):
    """Turn raw head outputs into per-image detections.

    Args:
        preds: list of (cls_pred, obj_pred, box_pred) per pyramid level, each
            shaped (B, *, H, W). The objectness output is not used.
        img_size: image side length in pixels used to scale the boxes.
        conf_thres: minimum class score (max over classes, after sigmoid).
        nms_iou_thres: IoU threshold for class-agnostic NMS.

    Returns:
        A list with one entry per image: a (K, 6) tensor with rows
        [x1, y1, x2, y2, score, class_id], or an empty list when nothing
        passes the confidence threshold.

    NMS runs once per image on the candidates of all levels together, so the
    same object predicted at several scales is reported only once.
    """
    all_outputs = []
    batch_size = preds[0][0].shape[0]

    for b in range(batch_size):
        level_boxes, level_scores, level_cls = [], [], []

        for cls_pred, _obj_pred, box_pred in preds:
            num_cls = cls_pred.shape[1]

            cls_scores = cls_pred[b].permute(1, 2, 0).reshape(-1, num_cls).sigmoid()
            box_norm = box_pred[b].permute(1, 2, 0).reshape(-1, 4).sigmoid()

            scores, cls_ids = cls_scores.max(dim=-1)
            mask = scores > conf_thres
            if mask.sum() == 0:
                continue

            boxes = box_norm[mask]
            # Normalized cxcywh -> pixel xyxy, clipped to the image.
            x_c = boxes[:, 0] * img_size
            y_c = boxes[:, 1] * img_size
            w = boxes[:, 2] * img_size
            h = boxes[:, 3] * img_size
            x1 = (x_c - w/2).clamp(0, img_size)
            y1 = (y_c - h/2).clamp(0, img_size)
            x2 = (x_c + w/2).clamp(0, img_size)
            y2 = (y_c + h/2).clamp(0, img_size)

            level_boxes.append(torch.stack([x1, y1, x2, y2], dim=1))
            level_scores.append(scores[mask])
            level_cls.append(cls_ids[mask])

        if len(level_boxes) == 0:
            all_outputs.append([])
            continue

        # Merge all levels before suppression so cross-scale duplicates are removed.
        boxes_xyxy = torch.cat(level_boxes, dim=0)
        scores_all = torch.cat(level_scores, dim=0)
        cls_all = torch.cat(level_cls, dim=0)

        keep = nms(boxes_xyxy, scores_all, iou_thres=nms_iou_thres)
        if keep.numel() == 0:
            all_outputs.append([])
            continue

        all_outputs.append(torch.cat([
            boxes_xyxy[keep],
            scores_all[keep].unsqueeze(1),
            cls_all[keep].float().unsqueeze(1)
        ], dim=1))
    return all_outputs

def compute_ap(recall, precision):
    """Area under the precision-recall curve (all-point interpolation).

    recall, precision: 1-D tensors of equal length, with recall in
    non-decreasing order. Returns the average precision as a Python float.
    """
    zero = torch.tensor([0.0])
    one = torch.tensor([1.0])
    # Sentinels: recall spans [0, 1]; precision is 0 at both ends.
    mrec = torch.cat([zero, recall, one])
    mpre = torch.cat([zero, precision, zero])

    # Make precision monotonically non-increasing, sweeping right to left.
    for right in reversed(range(1, mpre.size(0))):
        mpre[right - 1] = torch.max(mpre[right - 1], mpre[right])

    # Sum rectangle areas at the positions where recall changes.
    steps = torch.nonzero(mrec[1:] != mrec[:-1], as_tuple=True)[0]
    widths = mrec[steps + 1] - mrec[steps]
    return (widths * mpre[steps + 1]).sum().item()

def evaluate_map(model, dataloader, num_classes=3, img_size=512, iou_thr=0.5, conf_thres=0.25):
    """Evaluate mean average precision of ``model`` over ``dataloader``.

    The model is switched to eval mode and run without gradients. Each batch
    is decoded with ``decode_predictions``; ground-truth rows are read as
    ``[class, cx, cy, w, h]`` and scaled by ``img_size`` (presumably they are
    normalized to [0, 1] -- confirm against the dataset class).

    Detections of each class are ranked by score across all images and
    greedily matched to the highest-IoU ground-truth box of their image; a
    detection counts as a true positive only if that IoU reaches ``iou_thr``
    and the box has not been claimed already.

    Returns:
        ``(mAP, aps)``. ``aps`` has one entry per class that has at least one
        ground-truth box (classes without ground truth are skipped), so an
        index into ``aps`` is not necessarily a class id. ``mAP`` is the mean
        of ``aps``, or 0.0 when ``aps`` is empty.
    """
    model.eval()
    device = next(model.parameters()).device
    dets_by_class = {k: [] for k in range(num_classes)}
    gts_by_class = {k: [] for k in range(num_classes)}
    next_img_id = 0

    with torch.no_grad():
        for imgs, targets in dataloader:
            imgs = imgs.to(device)
            targets = [t.to(device) for t in targets]
            preds, _ = model(imgs)
            decoded = decode_predictions(preds, img_size=img_size, conf_thres=conf_thres)

            for b in range(len(imgs)):
                dets = decoded[b]
                gt = targets[b]
                img_id = next_img_id
                next_img_id += 1

                if len(gt) > 0:
                    gt_cls = gt[:, 0].long()
                    centers = gt[:, 1:3] * img_size
                    half_wh = gt[:, 3:5] * img_size / 2
                    # Columns end up as x1, y1, x2, y2.
                    gt_xyxy = torch.cat([centers - half_wh, centers + half_wh], dim=1)
                    for k in range(num_classes):
                        sel = gt_cls == k
                        if sel.any():
                            gts_by_class[k].append((img_id, gt_xyxy[sel].cpu()))

                if dets is not None and len(dets) > 0:
                    det_boxes = dets[:, :4]
                    det_scores = dets[:, 4]
                    det_cls = dets[:, 5].long()
                    for k in range(num_classes):
                        sel = det_cls == k
                        if sel.any():
                            dets_by_class[k].append(
                                (img_id, det_scores[sel].cpu(), det_boxes[sel].cpu())
                            )

    aps = []
    for k in range(num_classes):
        class_gts = gts_by_class[k]
        if not class_gts:
            continue

        # Merge entries that share an image id into one record per image.
        gt_lookup = {}
        for img_id, boxes in class_gts:
            fresh_flags = torch.zeros(boxes.size(0), dtype=torch.bool)
            previous = gt_lookup.get(img_id)
            if previous is None:
                gt_lookup[img_id] = {"boxes": boxes.clone(), "matched": fresh_flags}
            else:
                gt_lookup[img_id] = {
                    "boxes": torch.cat([previous["boxes"], boxes], dim=0),
                    "matched": torch.cat([previous["matched"], fresh_flags], dim=0),
                }

        n_gt = sum(rec["boxes"].size(0) for rec in gt_lookup.values())

        class_dets = dets_by_class[k]
        if not class_dets:
            aps.append(0.0)
            continue

        # Flatten every detection of this class into parallel sequences.
        flat_scores = torch.tensor(
            [s for _, sc, _ in class_dets for s in sc.tolist()]
        )
        flat_boxes = torch.cat([bx for _, _, bx in class_dets], dim=0)
        flat_ids = [
            img_id for img_id, _, bx in class_dets for _ in range(bx.size(0))
        ]

        # Rank by descending confidence.
        order = flat_scores.argsort(descending=True)
        flat_scores = flat_scores[order]
        flat_boxes = flat_boxes[order]
        flat_ids = [flat_ids[j] for j in order.tolist()]

        n_det = len(flat_scores)
        tps = torch.zeros(n_det)
        fps = torch.zeros(n_det)

        for rank in range(n_det):
            is_tp = False
            record = gt_lookup.get(flat_ids[rank])
            if record is not None:
                ious = box_iou_matrix(
                    flat_boxes[rank].unsqueeze(0), record["boxes"]
                ).squeeze(0)
                if ious.numel() > 0:
                    best_iou, best_idx = ious.max(0)
                    if best_iou >= iou_thr and not record["matched"][best_idx]:
                        record["matched"][best_idx] = True
                        is_tp = True
            if is_tp:
                tps[rank] = 1
            else:
                fps[rank] = 1

        tp_cum = torch.cumsum(tps, dim=0)
        fp_cum = torch.cumsum(fps, dim=0)
        recall = tp_cum / (n_gt + 1e-6)
        precision = tp_cum / (tp_cum + fp_cum + 1e-6)
        aps.append(compute_ap(recall, precision))

    mAP = sum(aps) / len(aps) if aps else 0.0
    return mAP, aps

In [None]:
# ============================================
# Cell 10: Dataset definition
# ============================================
DATA_PATH = dataset.location


def _make_loader(split_dataset, shuffle):
    """Wrap one split in a DataLoader using the settings shared by all splits."""
    return DataLoader(
        split_dataset,
        batch_size=4,
        shuffle=shuffle,
        collate_fn=yolo_collate_fn,
        num_workers=2,
        pin_memory=True,
    )


# One dataset per split; only the training split is built with is_train=True.
train_dataset = YoloDataset(os.path.join(DATA_PATH, "train"), is_train=True)
val_dataset = YoloDataset(os.path.join(DATA_PATH, "valid"), is_train=False)
test_dataset = YoloDataset(os.path.join(DATA_PATH, "test"), is_train=False)

# Only the training loader shuffles.
train_loader = _make_loader(train_dataset, shuffle=True)
val_loader = _make_loader(val_dataset, shuffle=False)
test_loader = _make_loader(test_dataset, shuffle=False)

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using Device: {device}")

In [None]:
# ============================================
# Cell 11 : Train Config & Loop
# ============================================

cfg = dict(
    in_ch=3,
    stem_base=64,
    embed_dim=768,
    vit_model_name='vit_base_patch16_224.augreg_in21k_ft_in1k',
    num_classes=3,
    iters=1,
    detach_feedback=False,
    img_size=IMG_SIZE
)

print(f"‚öôÔ∏è Configuration: {cfg}")

model = HybridTwoWay(**cfg).to(device)

# Mixed precision is only enabled on CUDA; the scaler is unused otherwise.
use_amp = (device.type == 'cuda')
if use_amp:
    scaler = torch.amp.GradScaler('cuda')
else:
    scaler = None

optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5, weight_decay=0.05)
EPOCHS = 15
ACCUM_STEPS = 4

# One optimizer/scheduler step happens per full accumulation group, plus one
# for a trailing partial group, hence the ceiling. max(1, ...) keeps
# OneCycleLR constructible when the loader is empty.
num_batches = len(train_loader)
steps_per_epoch = max(1, math.ceil(num_batches / ACCUM_STEPS))
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=2e-5,
    epochs=EPOCHS,
    steps_per_epoch=steps_per_epoch,
    pct_start=0.3,
    div_factor=25.0,
    final_div_factor=1000.0
)

best_map = -1.0
best_epoch = -1

print("‚úÖ Safe Mode Ready: Iters=1, Max LR=2e-5")

print("üöÄ Start Training (Safe Mode)...")

for epoch in range(EPOCHS):
    model.train()
    total_loss = 0.0
    loop = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}")

    optimizer.zero_grad()

    for i, (imgs, targets) in enumerate(loop):
        imgs = imgs.to(device)
        targets = [t.to(device) for t in targets]

        # The loss is divided by ACCUM_STEPS so that summed gradients over an
        # accumulation group correspond to an average over the group.
        if use_amp:
            with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
                preds, aux = model(imgs)
                loss = yolo_loss_tal(preds, targets, img_size=IMG_SIZE)
                loss = loss / ACCUM_STEPS
            scaler.scale(loss).backward()
        else:
            preds, aux = model(imgs)
            loss = yolo_loss_tal(preds, targets, img_size=IMG_SIZE)
            loss = loss / ACCUM_STEPS
            loss.backward()

        # Step on every full group AND on the last batch of the epoch.
        # Without the last-batch flush, gradients from a trailing partial
        # group would be wiped by the next epoch's zero_grad() unused.
        # (A partial group is still scaled by 1/ACCUM_STEPS, so its update is
        # proportionally smaller.)
        is_last_batch = (i + 1) == num_batches
        if (i + 1) % ACCUM_STEPS == 0 or is_last_batch:
            if use_amp:
                # Unscale first so clipping sees true gradient magnitudes.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)
                scaler.step(optimizer)
                scaler.update()
            else:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)
                optimizer.step()
            optimizer.zero_grad()
            scheduler.step()

        # Undo the accumulation scaling for logging.
        current_loss = loss.item() * ACCUM_STEPS
        total_loss += current_loss
        current_lr = optimizer.param_groups[0]['lr']
        loop.set_postfix(loss=f"{current_loss:.4f}", lr=f"{current_lr:.8f}")

    # Guard the division so an empty loader reports 0 instead of raising.
    avg_loss = total_loss / max(1, num_batches)
    print(f"Epoch {epoch+1} | Train Average Loss: {avg_loss:.4f}")

    val_map, val_aps = evaluate_map(
        model, val_loader,
        num_classes=cfg["num_classes"],
        img_size=IMG_SIZE,
        iou_thr=0.5,
        conf_thres=0.001
    )
    print(f"Epoch {epoch+1} | Val mAP@0.5: {val_map:.4f}")

    if val_map > best_map:
        best_map = val_map
        best_epoch = epoch + 1
        # If the model carries an `_orig_mod` attribute (as torch.compile
        # wrappers do), save the wrapped module's weights.
        state_dict = model._orig_mod.state_dict() if hasattr(model, '_orig_mod') else model.state_dict()
        ckpt = {"state_dict": state_dict, "cfg": cfg, "epoch": best_epoch, "val_map": best_map}
        torch.save(ckpt, "hybrid_two_way_best.pt")
        print(f"‚úÖ Best model saved! (Val mAP: {best_map:.4f})")

print("üèÅ Training Finished!")

In [None]:
# ============================================
# Cell 12 : Load Best & Test Eval
# ============================================

# Read the best checkpoint back from disk, mapped onto the active device.
checkpoint = torch.load("hybrid_two_way_best.pt", map_location=device)
loaded_cfg = checkpoint["cfg"]

print(f"üìÑ Loaded Config: {loaded_cfg}")

# Rebuild the network from the stored config before loading its weights.
model = HybridTwoWay(**loaded_cfg)
model = model.to(device)
model.load_state_dict(checkpoint["state_dict"])
model.eval()

print(f"‚úÖ Model Loaded from Epoch {checkpoint['epoch']} (Val mAP: {checkpoint['val_map']:.4f})")

test_map, class_aps = evaluate_map(
    model,
    test_loader,
    num_classes=loaded_cfg["num_classes"],
    img_size=IMG_SIZE,
    conf_thres=0.001,
)

print(f"\nüèÜ Final Test mAP@0.5: {test_map:.4f}")
# NOTE: evaluate_map omits classes that have no ground-truth boxes, so `i`
# is a position in class_aps and may not equal the class id.
for i, ap in enumerate(class_aps):
    print(f"   Class {i} AP@0.5: {ap:.4f}")


In [None]:
# ============================================
# Cell 13 : Sanity Check
# ============================================

# Shape-only smoke test on a random batch of two images.
x = torch.randn(2, 3, IMG_SIZE, IMG_SIZE).to(device)
# No gradients are needed to inspect output shapes; skipping autograd avoids
# holding activations for a backward pass that never happens.
with torch.no_grad():
    preds, aux = model(x)
for level, (c, o, b) in zip(["P3","P4","P5"], preds):
    print(f"[{level}] cls: {list(c.shape)}, obj: {list(o.shape)}, box: {list(b.shape)}")

In [None]:
# Check the class distribution
def count_class_dist(dataset, name):
    """Print how many boxes of each class id a dataset contains.

    Args:
        dataset: indexable object supporting ``len()`` whose items are
            ``(image, target)`` pairs; column 0 of each non-empty ``target``
            tensor is read as the class id.
        name: label printed in the header line.
    """
    # Imported here so the function does not depend on another cell having
    # imported Counter.
    from collections import Counter

    cnt = Counter()
    for i in range(len(dataset)):
        _, tgt = dataset[i]
        # Samples without any boxes contribute nothing.
        if tgt.numel() == 0:
            continue
        cnt.update(tgt[:, 0].long().tolist())
    print(f"\n[{name}] class distribution:")
    for k in sorted(cnt):
        print(f"  class {k}: {cnt[k]} boxes")

# Report the class distribution of every split, in train / val / test order.
for _split_dataset, _split_name in (
    (train_dataset, "train"),
    (val_dataset, "val"),
    (test_dataset, "test"),
):
    count_class_dist(_split_dataset, _split_name)