# HybridTwoWay Model (Colab Ready)


## Imports
필요한 PyTorch 모듈과 타입 힌트를 불러옵니다.

In [None]:
#!pip install roboflow torch torchvision torchaudio opencv-python numpy tqdm pillow matplotlib

In [None]:
import math
from typing import List, Tuple
import os

import torch
import torch.nn as nn
import torch.nn.functional as F


In [None]:
# ============================================
# Cell 1: 설치 및 기본 Import, Roboflow 다운로드
# ============================================
!pip install roboflow torch torchvision torchaudio opencv-python numpy tqdm pillow matplotlib

import math
from typing import List, Tuple
import os

import torch
import torch.nn as nn
import torch.nn.functional as F

from roboflow import Roboflow

# Download the YOLOv8-format export of the detection dataset from Roboflow.
# A key supplied through the ROBOFLOW_API_KEY environment variable takes
# precedence so the credential does not have to live in the notebook; the
# literal is kept only as a backward-compatible fallback.
# NOTE(review): the fallback key is committed in plain text - consider
# rotating it and removing the literal.
rf = Roboflow(api_key=os.environ.get("ROBOFLOW_API_KEY", "HG9M6YJZpcCUgAQaKO9v"))
project = rf.workspace("arakon").project("detection-base-hqaeg")
version = project.version(6)
dataset = version.download("yolov8")

print(f'Roboflow dataset downloaded to: {dataset.location}')


## 0. Utility Functions
## 1. Anomaly-Aware CNN Stem

In [None]:
# ============================================
# Cell 2: 기본 Conv 블록 + Stem
# ============================================

def conv_bn_act(in_ch, out_ch, k=3, s=1, p=1, act=True):
    """Build a Conv2d -> BatchNorm2d (-> SiLU) stack.

    Args:
        in_ch: input channel count.
        out_ch: output channel count.
        k: kernel size.
        s: stride.
        p: padding.
        act: when true, append an in-place SiLU after the batch norm.

    Returns:
        An ``nn.Sequential`` holding the layers in that order. The convolution
        has no bias because the batch norm that follows supplies the shift.
    """
    layers = [
        nn.Conv2d(in_ch, out_ch, kernel_size=k, stride=s, padding=p, bias=False),
        nn.BatchNorm2d(out_ch),
    ]
    if act:
        layers += [nn.SiLU(inplace=True)]
    return nn.Sequential(*layers)

class FixedGaussianBlur(nn.Module):
    """Depthwise Gaussian blur whose kernel is a fixed buffer, not a parameter.

    Args:
        channels: number of input channels; each one is blurred independently.
        k: kernel side length.
        sigma: standard deviation of the Gaussian, in pixels.
    """

    def __init__(self, channels, k=5, sigma=1.0):
        super().__init__()
        # Tap positions centred on zero, e.g. [-2, -1, 0, 1, 2] for k=5.
        offsets = torch.arange(k).float() - (k - 1) / 2
        taps = torch.exp(-(offsets ** 2) / (2 * sigma ** 2))
        taps = taps / taps.sum()
        # Separable 2-D kernel replicated once per channel: (channels, 1, k, k).
        kernel = torch.outer(taps, taps).view(1, 1, k, k).repeat(channels, 1, 1, 1)
        self.register_buffer('weight', kernel)
        self.groups = channels
        self.k = k

    def forward(self, x):
        """Blur ``x`` using reflect padding of ``k // 2`` on every side."""
        r = self.k // 2
        padded = F.pad(x, (r, r, r, r), mode='reflect')
        return F.conv2d(padded, self.weight, groups=self.groups)

class AnomalyAwareStem(nn.Module):
    """CNN stem with an extra high-frequency ("anomaly") branch.

    The main branch is three stride-2 3x3 conv-BN-SiLU stages, so its output
    is downsampled by 8 and has ``4 * base_ch`` channels. The second branch
    takes the residual between the input and its Gaussian-blurred copy,
    resizes it to the main branch's resolution and projects it to ``base_ch``
    channels. Both are concatenated and fused back to ``4 * base_ch`` channels
    with a 1x1 conv, batch norm and SiLU.

    Args:
        in_ch: number of input image channels.
        base_ch: channel width of the first stage.
    """

    def __init__(self, in_ch=3, base_ch=48):
        super().__init__()
        # Stored because the ``out_channels`` property reads it; without this
        # attribute the property raised AttributeError.
        self.base_ch = base_ch
        C1, C2, C3 = base_ch, base_ch * 2, base_ch * 4
        self.stem = nn.Sequential(
            conv_bn_act(in_ch, C1, 3, 2, 1),
            conv_bn_act(C1, C2, 3, 2, 1),
            conv_bn_act(C2, C3, 3, 2, 1),
        )
        self.blur = FixedGaussianBlur(in_ch, k=5, sigma=1.0)
        self.anom = nn.Sequential(
            # Depthwise 3x3 on the high-frequency residual, then 1x1 projection.
            nn.Conv2d(in_ch, in_ch, 3, 1, 1, groups=in_ch, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_ch, C3 // 4, 1, 1, 0, bias=False),
            nn.BatchNorm2d(C3 // 4),
            nn.SiLU(inplace=True),
        )
        self.fuse = nn.Conv2d(C3 + C3 // 4, C3, 1, 1, 0, bias=False)
        self.fuse_bn = nn.BatchNorm2d(C3)
        self.vis_head = nn.Conv2d(C3, 1, 1, 1, 0)

    @property
    def out_channels(self):
        """Channel count of the fused feature map returned by ``forward``."""
        return 4 * self.base_ch

    def forward(self, x):
        """Return ``(f, v)`` for an image batch ``x``.

        ``f`` is the fused feature map (B, 4*base_ch, H/8, W/8) and ``v`` is a
        single-channel sigmoid map at the same resolution, predicted from the
        main-branch features only.
        """
        f_main = self.stem(x)  # (B, C3, H/8, W/8)
        # High-frequency residual: what the Gaussian blur removed.
        high = x - self.blur(x)
        high_ds = F.interpolate(high, size=f_main.shape[-2:],
                                mode='bilinear', align_corners=False)
        f_anom = self.anom(high_ds)
        f = torch.cat([f_main, f_anom], dim=1)
        f = self.fuse_bn(self.fuse(f))
        f = F.silu(f, inplace=True)
        v = torch.sigmoid(self.vis_head(f_main))
        return f, v


## 2–3. Vision Transformer Encoder and Feedback Adapter

In [None]:
# ============================================
# Cell 3: ViT Encoder + Feedback Adapter
# ============================================

class PatchEmbed1x1(nn.Module):
    """Project CNN features to the ViT embedding width, keeping resolution.

    A bias-free 1x1 convolution followed by batch norm and SiLU; the spatial
    size of the input is unchanged.
    """

    def __init__(self, in_ch, embed_dim):
        super().__init__()
        self.proj = nn.Conv2d(in_ch, embed_dim, kernel_size=1, stride=1,
                              padding=0, bias=False)
        self.bn = nn.BatchNorm2d(embed_dim)

    def forward(self, x):
        """Map (B, in_ch, H, W) to (B, embed_dim, H, W)."""
        normed = self.bn(self.proj(x))
        return F.silu(normed, inplace=True)

class MLP(nn.Module):
    """Two-layer feed-forward block: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        dim: input and output feature width.
        mlp_ratio: hidden width as a multiple of ``dim`` (truncated to int).
        drop: dropout probability used after both linear layers.
    """

    def __init__(self, dim, mlp_ratio=4.0, drop=0.0):
        super().__init__()
        hidden_dim = int(dim * mlp_ratio)
        self.fc1 = nn.Linear(dim, hidden_dim)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_dim, dim)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """Apply the block along the last dimension of ``x``."""
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))

class MultiheadSelfAttention(nn.Module):
    """Standard multi-head scaled dot-product self-attention.

    Args:
        dim: token width; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        attn_drop: dropout probability on the attention weights.
        proj_drop: dropout probability after the output projection.
    """

    def __init__(self, dim, num_heads=8, attn_drop=0.0, proj_drop=0.0):
        super().__init__()
        assert dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        # One fused projection produces queries, keys and values together.
        self.qkv = nn.Linear(dim, dim * 3, bias=True)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """Attend over the token axis of ``x`` with shape (B, N, C)."""
        B, N, C = x.shape
        fused = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        # (3, B, heads, N, head_dim) -> three (B, heads, N, head_dim) tensors.
        q, k, v = fused.permute(2, 0, 3, 1, 4).unbind(0)
        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        weights = self.attn_drop(scores.softmax(dim=-1))
        # Merge the heads back into the channel axis.
        context = torch.matmul(weights, v).transpose(1, 2).reshape(B, N, C)
        return self.proj_drop(self.proj(context))

class TransformerBlock(nn.Module):
    """Pre-norm transformer layer: attention and MLP, each with a residual.

    Args:
        dim: token width.
        num_heads: number of attention heads.
        mlp_ratio: hidden-width multiplier of the MLP.
        drop: dropout probability shared by the attention weights, the
            attention output projection and the MLP.
    """

    def __init__(self, dim, num_heads=8, mlp_ratio=4.0, drop=0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = MultiheadSelfAttention(dim, num_heads,
                                           attn_drop=drop, proj_drop=drop)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = MLP(dim, mlp_ratio=mlp_ratio, drop=drop)

    def forward(self, x):
        """Apply the layer to tokens ``x``; the shape is preserved."""
        attended = self.attn(self.norm1(x))
        x = x + attended
        transformed = self.mlp(self.norm2(x))
        return x + transformed

class ViTEncoder(nn.Module):
    """A plain stack of ``depth`` transformer blocks applied in sequence.

    Every block uses ``mlp_ratio=4.0`` and no dropout. The encoder adds no
    positional information itself; callers are expected to do that.
    """

    def __init__(self, embed_dim=512, depth=8, num_heads=8):
        super().__init__()
        self.blocks = nn.ModuleList(
            TransformerBlock(embed_dim, num_heads, mlp_ratio=4.0, drop=0.0)
            for _ in range(depth)
        )

    def forward(self, tokens):
        """Run ``tokens`` through every block and return the result."""
        out = tokens
        for layer in self.blocks:
            out = layer(out)
        return out

class FeedbackAdapter(nn.Module):
    """Turn ViT tokens into a per-pixel (gamma, beta) pair that re-scales
    and shifts the stem feature map.

    Args:
        d_token: token width.
        c_stem: channel count of the stem feature being modulated.
        use_bn: insert a batch norm after the 1x1 projection; the projection
            carries a bias only when the batch norm is absent.
    """

    def __init__(self, d_token: int, c_stem: int, use_bn: bool = True):
        super().__init__()
        projection = nn.Conv2d(d_token, c_stem * 2, kernel_size=1, stride=1,
                               padding=0, bias=not use_bn)
        norm = [nn.BatchNorm2d(c_stem * 2)] if use_bn else []
        self.adapter = nn.Sequential(projection, *norm, nn.SiLU(inplace=True))

    def forward(self, tokens: torch.Tensor, Ht: int, Wt: int,
                f_stem: torch.Tensor):
        """Modulate ``f_stem`` with tokens laid out on an ``Ht x Wt`` grid.

        Returns ``f_stem * (1 + tanh(gamma)) + beta``, where gamma and beta are
        the two channel halves of the adapter output.
        """
        batch, _, width = tokens.shape
        # (B, N, D) -> (B, D, Ht, Wt) so the 1x1 conv can act per location.
        token_map = tokens.transpose(1, 2).reshape(batch, width, Ht, Wt)
        gamma, beta = torch.split(self.adapter(token_map), f_stem.shape[1], dim=1)
        scale = 1 + torch.tanh(gamma)
        return f_stem * scale + beta


## 4. PAN-Lite Neck

In [None]:
# ============================================
# Cell 4: PAN-Lite Neck (P3, P4, P5) + YOLO Head (3-Scale)
# ============================================

class PANLite(nn.Module):
    """Lightweight PAN-style neck built from a single input level.

    Input:  one feature map at the P3 level (stride 8 in this model).
    Output: P3, P4, P5 feature maps, each with ``mid`` channels; P4 and P5
            are produced by stride-2 convolutions, i.e. strides 16 and 32.
    """

    def __init__(self, in_ch=512, mid=256):
        super().__init__()

        def halve():
            # 3x3 stride-2 conv block that keeps the channel count.
            return conv_bn_act(mid, mid, 3, 2, 1)

        def merge():
            # 3x3 conv block that fuses two concatenated `mid`-channel maps.
            return conv_bn_act(2 * mid, mid, 3, 1, 1)

        # Channel alignment of the incoming feature.
        self.lateral = conv_bn_act(in_ch, mid, 1, 1, 0)

        # Build the coarser levels from P3.
        self.down4 = halve()  # P3 -> P4
        self.down5 = halve()  # P4 -> P5

        # Top-down fusion.
        self.up4 = merge()
        self.up3 = merge()

        # Bottom-up refinement.
        self.down_f4 = halve()
        self.fuse4 = merge()
        self.down_f5 = halve()
        self.fuse5 = merge()

    def forward(self, p3):
        """Return the refined ``(P3, P4, P5)`` pyramid for input ``p3``."""
        c3 = self.lateral(p3)   # (B, mid, H/8,  W/8)
        c4 = self.down4(c3)     # (B, mid, H/16, W/16)
        c5 = self.down5(c4)     # (B, mid, H/32, W/32)

        # Top-down: upsample the coarser level and fuse it into the finer one.
        c5_up = F.interpolate(c5, size=c4.shape[-2:], mode='nearest')
        t4 = self.up4(torch.cat([c4, c5_up], dim=1))
        t4_up = F.interpolate(t4, size=c3.shape[-2:], mode='nearest')
        t3 = self.up3(torch.cat([c3, t4_up], dim=1))

        # Bottom-up: downsample the finer level and fuse it into the coarser one.
        n4 = self.fuse4(torch.cat([t4, self.down_f4(t3)], dim=1))
        n5 = self.fuse5(torch.cat([c5, self.down_f5(n4)], dim=1))

        return t3, n4, n5



## 5. YOLO-style Detection Head

In [None]:
class YOLOHeadLite(nn.Module):
    """Anchor-free detection head (simple version) applied to P3, P4 and P5.

    Each level has its own 3x3 conv stem followed by three 1x1 predictors:
    class logits (``num_classes`` channels), objectness (1 channel) and box
    (4 channels).
    """

    _LEVELS = (3, 4, 5)

    def __init__(self, in_ch=256, num_classes=1):
        super().__init__()
        # Per-level stems are registered first, then the predictors level by
        # level, so parameter order matches stem3..5, cls3/obj3/box3, ...
        for lvl in self._LEVELS:
            setattr(self, f"stem{lvl}", conv_bn_act(in_ch, in_ch, 3, 1, 1))

        for lvl in self._LEVELS:
            objectness = nn.Conv2d(in_ch, 1, 1, 1, 0)
            setattr(self, f"cls{lvl}", nn.Conv2d(in_ch, num_classes, 1, 1, 0))
            setattr(self, f"obj{lvl}", objectness)
            setattr(self, f"box{lvl}", nn.Conv2d(in_ch, 4, 1, 1, 0))
            # sigmoid(-4.59) is about 0.01, so objectness starts out low.
            nn.init.constant_(objectness.bias, -4.59)

    def forward_single(self, x, stem, cls, obj, box):
        """Run one level: shared stem, then the three 1x1 predictors."""
        feat = stem(x)
        return cls(feat), obj(feat), box(feat)

    def forward(self, p3, p4, p5):
        """Return ``[(cls, obj, box)]`` for P3, P4, P5, in that order."""
        outputs = []
        for lvl, feat in zip(self._LEVELS, (p3, p4, p5)):
            outputs.append(
                self.forward_single(
                    feat,
                    getattr(self, f"stem{lvl}"),
                    getattr(self, f"cls{lvl}"),
                    getattr(self, f"obj{lvl}"),
                    getattr(self, f"box{lvl}"),
                )
            )
        return outputs


## 6. HybridTwoWay Model

In [None]:
class HybridTwoWay(nn.Module):
    """Stem -> ViT (with positional embedding) -> feedback -> PANLite -> YOLOHeadLite.

    The ViT / feedback / neck / head sequence is repeated ``iters`` times; only
    the predictions of the last iteration are returned.

    Args:
        in_ch: input image channels.
        stem_base: base channel width of the CNN stem (its output has 4x this).
        embed_dim: ViT token width.
        vit_depth: number of transformer blocks.
        vit_heads: attention heads per block.
        num_classes: number of class logits per location.
        iters: number of ViT -> feedback rounds (>= 1).
        detach_feedback: detach the ViT tokens before they drive the feedback
            adapter.
        img_size: nominal (square) input size; the positional embedding is
            allocated for an ``img_size // 8`` grid and resized for other inputs.

    NOTE(review): the ViT output is consumed only by the feedback adapter, and
    the tokens for the next iteration are rebuilt from the patch embedding. With
    ``detach_feedback=True`` no gradient therefore reaches ``self.vit`` or
    ``self.pos_embed`` - confirm that this is intended.
    """

    def __init__(
        self,
        in_ch=3,
        stem_base=32,
        embed_dim=256,
        vit_depth=4,
        vit_heads=4,
        num_classes=3,
        iters=1,
        detach_feedback=True,
        img_size=640  # nominal input size; determines the positional grid
    ):
        super().__init__()
        assert iters >= 1
        self.iters = iters
        self.detach_feedback = detach_feedback

        # 1) CNN stem
        self.stem = AnomalyAwareStem(in_ch=in_ch, base_ch=stem_base)
        c_stem = stem_base * 4

        # 2) Stem feature -> ViT token projection
        self.patch = PatchEmbed1x1(c_stem, embed_dim)

        # Learned positional embedding on a square grid. The stem output stride
        # is 8, so the grid side is img_size // 8. The side is kept as an
        # integer so the grid never has to be recovered through a float sqrt.
        self.grid_size = img_size // 8
        self.num_patches = self.grid_size ** 2
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches, embed_dim))
        nn.init.trunc_normal_(self.pos_embed, std=0.02)

        # 3) ViT encoder
        self.vit = ViTEncoder(
            embed_dim=embed_dim,
            depth=vit_depth,
            num_heads=vit_heads
        )

        # 4) Feedback adapter
        self.feedback = FeedbackAdapter(
            d_token=embed_dim,
            c_stem=c_stem,
            use_bn=True
        )

        # 5) Neck + head
        self.neck = PANLite(in_ch=embed_dim, mid=256)
        self.head = YOLOHeadLite(in_ch=256, num_classes=num_classes)

    def _positional_embedding(self, Ht, Wt):
        """Return the positional embedding for an ``Ht x Wt`` token grid.

        The stored embedding is used as-is only when the grid shape matches.
        Comparing the shape rather than the token count matters: a non-square
        grid with the same number of tokens would otherwise receive a spatially
        misaligned embedding. Any other shape gets a bicubic resize.
        """
        g = self.grid_size
        if (Ht, Wt) == (g, g):
            return self.pos_embed
        grid = self.pos_embed.reshape(1, g, g, -1).permute(0, 3, 1, 2)
        resized = F.interpolate(grid, size=(Ht, Wt), mode='bicubic',
                                align_corners=False)
        return resized.flatten(2).transpose(1, 2)

    def forward(self, x):
        """Return ``(preds, aux)`` for an image batch ``x``.

        ``preds`` is the head output of the last iteration: a list with one
        ``(cls, obj, box)`` tuple per pyramid level. ``aux`` holds the neck
        features under "P3"/"P4"/"P5" and the stem's sigmoid map under "V".
        """
        # 1) Stem
        f_stem, vis = self.stem(x)  # (B, C_s, H/8, W/8)

        # 2) Initial tokens
        p0 = self.patch(f_stem)     # (B, D, H/8, W/8)
        Ht, Wt = p0.shape[-2:]
        # Resolved once and reused whenever tokens are rebuilt below.
        pos = self._positional_embedding(Ht, Wt)
        tokens = p0.flatten(2).transpose(1, 2) + pos  # (B, N, D)

        f_fb = f_stem
        preds = None
        aux = None

        # 3) Repeat ViT -> feedback -> neck/head `iters` times
        for i in range(self.iters):
            # 3-1) ViT
            tokens = self.vit(tokens)

            # 3-2) Feedback: re-scale the stem feature with the ViT tokens
            toks_for_fb = tokens.detach() if self.detach_feedback else tokens
            f_fb = self.feedback(toks_for_fb, Ht, Wt, f_fb)

            # 3-3) Neck / head on the re-embedded, modulated feature
            p3_in = self.patch(f_fb)
            p3, p4, p5 = self.neck(p3_in)

            preds = self.head(p3, p4, p5)
            aux = {"P3": p3, "P4": p4, "P5": p5, "V": vis}

            # 3-4) Tokens for the next round, with positions added again
            if i != self.iters - 1:
                tokens = p3_in.flatten(2).transpose(1, 2) + pos

        return preds, aux

In [None]:
# ============================================
# Cell 6: Dataset / Dataloader
# ============================================

from torch.utils.data import Dataset, DataLoader
import cv2

IMG_SIZE = 640

def yolo_collate_fn(batch):
    """Collate ``(image, target)`` samples into a batch.

    Images are stacked into one tensor along a new leading dimension. Targets
    stay a plain list because each image can have a different number of boxes.
    """
    images = [img for img, _ in batch]
    targets = [tgt for _, tgt in batch]
    return torch.stack(images, 0), targets

class YoloDataset(Dataset):
    """Image / label dataset for a YOLO-format split directory.

    Expects ``root/images`` with the image files and ``root/labels`` with one
    ``<image stem>.txt`` per image, each line being ``cls x y w h``.

    Each item is ``(img, boxes)``: ``img`` is a float RGB tensor of shape
    (3, IMG_SIZE, IMG_SIZE) scaled to [0, 1], and ``boxes`` is a float32 tensor
    of shape (num_boxes, 5), which is (0, 5) when the image has no labels.
    """

    # Only these extensions are treated as images, so stray directory entries
    # (notebook checkpoints, text files, ...) are not handed to cv2.
    _IMAGE_EXTS = (".jpg", ".jpeg", ".png", ".bmp", ".webp", ".tif", ".tiff")

    def __init__(self, root):
        self.img_dir = os.path.join(root, "images")
        self.label_dir = os.path.join(root, "labels")
        self.images = sorted(
            name for name in os.listdir(self.img_dir)
            if name.lower().endswith(self._IMAGE_EXTS)
        )

    def __len__(self):
        return len(self.images)

    @staticmethod
    def _read_boxes(label_path):
        """Parse a YOLO label file into a (num_boxes, 5) float32 tensor."""
        rows = []
        if os.path.exists(label_path):
            with open(label_path, "r") as f:
                for line in f:
                    fields = line.split()
                    if not fields:
                        # Blank lines (e.g. a trailing newline) carry no box.
                        continue
                    cls, x, y, w, h = map(float, fields)
                    rows.append([cls, x, y, w, h])
        # reshape keeps the column dimension even when there are no boxes.
        return torch.tensor(rows, dtype=torch.float32).reshape(-1, 5)

    def __getitem__(self, idx):
        name = self.images[idx]

        img_path = os.path.join(self.img_dir, name)
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread reports failure by returning None, not by raising.
            raise OSError(f"Failed to read image: {img_path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, (IMG_SIZE, IMG_SIZE))
        img = torch.tensor(img).permute(2, 0, 1).float() / 255.0

        # Swap only the final extension; this also covers .jpeg and
        # upper-case suffixes.
        stem = os.path.splitext(name)[0]
        boxes = self._read_boxes(os.path.join(self.label_dir, stem + ".txt"))
        return img, boxes

# Root of the downloaded export; each split lives in its own sub-folder.
DATA_PATH = dataset.location

train_dataset, val_dataset, test_dataset = (
    YoloDataset(os.path.join(DATA_PATH, split_name))
    for split_name in ("train", "valid", "test")
)

# Only the training split is shuffled; validation and test keep a fixed order.
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True,
                          collate_fn=yolo_collate_fn)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False,
                        collate_fn=yolo_collate_fn)
test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False,
                         collate_fn=yolo_collate_fn)



In [None]:
# ============================================
# Cell 7: YOLO-style Loss (Focal + GIoU)
# ============================================

import torch.nn.functional as F

class FocalLoss(nn.Module):
    """Binary focal loss computed on logits.

    The element-wise BCE-with-logits is scaled by ``(1 - p_t) ** gamma`` and,
    when ``alpha >= 0``, by the class-balancing weight ``alpha_t``.

    Args:
        alpha: weight of the positive class (negatives get ``1 - alpha``);
            a negative value disables the weighting.
        gamma: focusing exponent.
        reduction: "mean", "sum", or anything else for no reduction.
    """

    def __init__(self, alpha=0.25, gamma=2.0, reduction="mean"):
        super().__init__()
        self.alpha, self.gamma, self.reduction = alpha, gamma, reduction

    def forward(self, logits, targets):
        """Return the focal loss of ``logits`` against same-shaped 0/1 ``targets``."""
        prob = logits.sigmoid()
        bce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")

        # Probability assigned to the true label of each element.
        p_t = prob * targets + (1 - prob) * (1 - targets)
        loss = bce * (1 - p_t) ** self.gamma

        if self.alpha >= 0:
            alpha_t = self.alpha * targets + (1 - self.alpha) * (1 - targets)
            loss = alpha_t * loss

        if self.reduction == "mean":
            return loss.mean()
        if self.reduction == "sum":
            return loss.sum()
        return loss


def xywh_to_xyxy(box_xywh):
    """Convert centre-format boxes to corner format.

    Args:
        box_xywh: tensor of shape [..., 4] holding (x_c, y_c, w, h).

    Returns:
        Tensor of shape [..., 4] holding (x1, y1, x2, y2) in the same units.
    """
    cx, cy, bw, bh = box_xywh.unbind(dim=-1)
    half_w, half_h = bw / 2, bh / 2
    corners = (cx - half_w, cy - half_h, cx + half_w, cy + half_h)
    return torch.stack(corners, dim=-1)


def giou_loss(pred_box_xyxy, tgt_box_xyxy):
    """Mean ``1 - GIoU`` between paired boxes.

    Args:
        pred_box_xyxy: [N, 4] predicted boxes as (x1, y1, x2, y2).
        tgt_box_xyxy:  [N, 4] target boxes as (x1, y1, x2, y2).

    Returns:
        Scalar tensor. For N == 0 the result is a zero that is still attached
        to ``pred_box_xyxy``'s graph, instead of the NaN that the mean of an
        empty tensor would give.
    """
    if pred_box_xyxy.numel() == 0:
        return pred_box_xyxy.sum() * 0.0

    pred_lo, pred_hi = pred_box_xyxy[:, :2], pred_box_xyxy[:, 2:4]
    tgt_lo, tgt_hi = tgt_box_xyxy[:, :2], tgt_box_xyxy[:, 2:4]

    # Intersection; the clamp makes disjoint boxes contribute zero area.
    inter_wh = (torch.min(pred_hi, tgt_hi) - torch.max(pred_lo, tgt_lo)).clamp(min=0)
    inter = inter_wh[:, 0] * inter_wh[:, 1]

    # Clamping also guards against degenerate boxes with x2 < x1 or y2 < y1.
    area_p = (pred_hi - pred_lo).clamp(min=0).prod(dim=1)
    area_t = (tgt_hi - tgt_lo).clamp(min=0).prod(dim=1)

    union = area_p + area_t - inter + 1e-6
    iou = inter / union  # [N]

    # Smallest box enclosing both the prediction and the target.
    enclose_wh = (torch.max(pred_hi, tgt_hi) - torch.min(pred_lo, tgt_lo)).clamp(min=0)
    c_area = enclose_wh[:, 0] * enclose_wh[:, 1] + 1e-6

    giou = iou - (c_area - union) / c_area  # [N]
    return (1.0 - giou).mean()


# focal loss 인스턴스 (object + class 둘 다 사용)
_focal_loss = FocalLoss(alpha=0.25, gamma=2.0, reduction="mean")


def yolo_loss(preds, targets, img_size=512,
              lambda_obj=1.0, lambda_cls=1.0, lambda_box=5.0):
    """Detection loss: focal loss for objectness/class, GIoU loss for boxes.

    Args:
        preds: list of per-scale ``(cls, obj, box)`` logit maps shaped
            (B, C, H, W), (B, 1, H, W) and (B, 4, H, W).
        targets: list of length B; each entry is [num_gt, 5] with
            (cls, x, y, w, h), all normalized to [0, 1]. An entry with no
            elements marks an image without objects.
        img_size: kept for backward compatibility and no longer used. Grid
            cells are derived from the normalized coordinates and each map's
            own H and W, so the assignment cannot go wrong when this value
            does not match the real input size or the map is not square.
        lambda_obj, lambda_cls, lambda_box: weights of the three terms.

    Returns:
        The weighted sum of the three terms, accumulated over every scale and
        every image (no normalization by batch size or scale count).

    Every ground-truth box is assigned to the cell containing its centre on
    every scale; there is no scale-specific assignment. Box predictions are
    read as sigmoid-normalized (x, y, w, h) in image coordinates.
    """
    total_obj_loss = 0.0
    total_cls_loss = 0.0
    total_box_loss = 0.0

    for cls_pred, obj_pred, box_pred in preds:
        B, C, H, W = cls_pred.shape
        device = cls_pred.device

        # Flatten the spatial grid: [B, H*W, C], [B, H*W, 1], [B, H*W, 4].
        cls_p = cls_pred.permute(0, 2, 3, 1).reshape(B, H * W, C)
        obj_p = obj_pred.permute(0, 2, 3, 1).reshape(B, H * W, 1)
        box_p = box_pred.permute(0, 2, 3, 1).reshape(B, H * W, 4)

        for b in range(B):
            gt = targets[b]  # [num_gt, 5]
            obj_tgt = torch.zeros((H * W, 1), device=device)

            if gt.numel() == 0:
                # No objects: train objectness only, with every cell negative.
                total_obj_loss += _focal_loss(obj_p[b], obj_tgt)
                continue

            gcls = gt[:, 0].long()   # [num_gt]
            gxywh = gt[:, 1:5]       # [num_gt, 4], normalized

            # Cell holding each box centre, as a flat index into H*W.
            gx = (gxywh[:, 0] * W).long().clamp(0, W - 1)
            gy = (gxywh[:, 1] * H).long().clamp(0, H - 1)
            gi = gy * W + gx         # [num_gt]

            # Objectness: positive only at the assigned cells.
            obj_tgt[gi] = 1.0
            total_obj_loss += _focal_loss(obj_p[b], obj_tgt)

            # Class: one-hot at the assigned cells, all-zero elsewhere.
            cls_tgt = torch.zeros((H * W, C), device=device)
            cls_tgt[gi, gcls] = 1.0
            total_cls_loss += _focal_loss(cls_p[b], cls_tgt)

            # Box: GIoU between the assigned cells' predictions and the targets.
            pred_xyxy = xywh_to_xyxy(box_p[b][gi].sigmoid())  # [num_gt, 4]
            tgt_xyxy = xywh_to_xyxy(gxywh)                    # [num_gt, 4]
            total_box_loss += giou_loss(pred_xyxy, tgt_xyxy)

    total = (lambda_obj * total_obj_loss +
             lambda_cls * total_cls_loss +
             lambda_box * total_box_loss)
    return total


In [None]:
# ============================================
# Cell 8: Decode Predictions + mAP Evaluation
# ============================================

import math

def decode_predictions(preds, img_size=512, conf_thres=0.25,
                       nms_iou_thres=0.5):
    """Turn raw head outputs into per-image detections.

    Args:
        preds: list of per-scale ``(cls, obj, box)`` logit maps.
        img_size: side length used to scale the normalized boxes to pixels.
        conf_thres: minimum ``objectness * best class score`` to keep a cell.
        nms_iou_thres: IoU threshold for the (class-agnostic) NMS.

    Returns:
        A list with one entry per image: a [K, 6] tensor of
        (x1, y1, x2, y2, score, cls) sorted by descending score, or an empty
        list when nothing passes the threshold.

    Candidates from all scales are pooled before NMS. The loss assigns every
    ground-truth box to all three scales, so the same object is normally
    predicted once per scale; suppressing within each scale separately would
    leave those duplicates in the output.
    """
    all_outputs = []
    B = preds[0][0].shape[0]

    for b in range(B):
        boxes_all, scores_all, cls_all = [], [], []

        for cls_pred, obj_pred, box_pred in preds:
            C = cls_pred.shape[1]

            cls_scores = cls_pred[b].permute(1, 2, 0).reshape(-1, C).sigmoid()  # [HW, C]
            obj_scores = obj_pred[b].reshape(-1).sigmoid()                      # [HW]
            box_norm = box_pred[b].permute(1, 2, 0).reshape(-1, 4).sigmoid()    # [HW, 4]

            cls_max_scores, cls_ids = cls_scores.max(dim=-1)                    # [HW]
            scores = obj_scores * cls_max_scores                                # [HW]

            mask = scores > conf_thres
            if not mask.any():
                continue

            # (cx, cy, w, h) normalized -> (x1, y1, x2, y2) in pixels.
            xywh = box_norm[mask] * img_size
            half = xywh[:, 2:] / 2
            xyxy = torch.cat([xywh[:, :2] - half, xywh[:, :2] + half],
                             dim=1).clamp(0, img_size)

            boxes_all.append(xyxy)
            scores_all.append(scores[mask])
            cls_all.append(cls_ids[mask])

        if not boxes_all:
            all_outputs.append([])
            continue

        boxes = torch.cat(boxes_all, dim=0)
        scores = torch.cat(scores_all, dim=0)
        cls_ids = torch.cat(cls_all, dim=0)

        # One NMS pass over the pooled candidates of every scale.
        keep = nms(boxes, scores, iou_thres=nms_iou_thres)

        dets = torch.cat(
            [boxes[keep],
             scores[keep].unsqueeze(1),
             cls_ids[keep].float().unsqueeze(1)],
            dim=1
        )  # [K, 6]
        all_outputs.append(dets)

    return all_outputs



def box_iou(box1, box2):
    """Pairwise IoU between two sets of (x1, y1, x2, y2) boxes.

    Args:
        box1: [N, 4] tensor.
        box2: [M, 4] tensor.

    Returns:
        [N, M] tensor of IoU values. When either set is empty the result is an
        empty tensor created on ``box1``'s device with its dtype, so it can be
        combined with the non-empty results without a device mismatch.
    """
    N, M = box1.size(0), box2.size(0)
    if N == 0 or M == 0:
        return box1.new_zeros((N, M))

    # Broadcast box1 against box2: intersection corners are [N, M, 2].
    top_left = torch.max(box1[:, None, :2], box2[:, :2])
    bottom_right = torch.min(box1[:, None, 2:], box2[:, 2:])

    wh = (bottom_right - top_left).clamp(min=0)
    inter = wh[..., 0] * wh[..., 1]

    area1 = (box1[:, 2] - box1[:, 0]) * (box1[:, 3] - box1[:, 1])
    area2 = (box2[:, 2] - box2[:, 0]) * (box2[:, 3] - box2[:, 1])

    # The epsilon keeps zero-area pairs from dividing by zero.
    return inter / (area1[:, None] + area2 - inter + 1e-6)

def nms(boxes: torch.Tensor,
        scores: torch.Tensor,
        iou_thres: float = 0.5) -> torch.Tensor:
    """Greedy non-maximum suppression.

    Args:
        boxes: [N, 4] boxes as (x1, y1, x2, y2).
        scores: [N] confidence per box.
        iou_thres: a box is dropped when its IoU with an already kept,
            higher-scoring box reaches this value.

    Returns:
        LongTensor with the indices of the kept boxes, highest score first.
    """
    if boxes.numel() == 0:
        return torch.zeros(0, dtype=torch.long, device=boxes.device)

    # Candidates still alive, ordered by descending score.
    remaining = torch.argsort(scores, descending=True)
    kept = []

    while remaining.numel() > 0:
        best, others = remaining[0], remaining[1:]
        kept.append(int(best))

        if others.numel() == 0:
            break

        # IoU of the chosen box against every other surviving candidate.
        overlap = box_iou(boxes[best].unsqueeze(0), boxes[others]).squeeze(0)

        # Only candidates that overlap less than the threshold survive.
        remaining = others[overlap < iou_thres]

    return torch.tensor(kept, dtype=torch.long, device=boxes.device)

def compute_ap(recall, precision):
    """Area under the precision-recall curve (all-point interpolation).

    Args:
        recall: 1-D tensor of cumulative recall values.
        precision: 1-D tensor of the matching precision values.

    Returns:
        The average precision as a Python float.
    """
    zero, one = torch.tensor([0.0]), torch.tensor([1.0])
    mrec = torch.cat([zero, recall, one])
    mpre = torch.cat([zero, precision, zero])

    # Precision envelope: each entry becomes the maximum of everything to its
    # right, i.e. a running maximum taken from the end of the curve.
    mpre = torch.cummax(mpre.flip(0), dim=0).values.flip(0)

    # Sum the rectangle areas at the points where recall changes.
    idx = (mrec[1:] != mrec[:-1]).nonzero().squeeze()
    widths = mrec[idx + 1] - mrec[idx]
    return (widths * mpre[idx + 1]).sum().item()

def evaluate_map(model, dataloader, num_classes=3, img_size=512,
                 iou_thr=0.5, conf_thres=0.25):
    """Compute mAP at a single IoU threshold over a dataloader.

    Args:
        model: detector returning ``(preds, aux)``; it is switched to eval mode.
        dataloader: yields ``(imgs, targets)`` where ``targets`` is a list of
            [num_gt, 5] tensors (cls, x, y, w, h), normalized to [0, 1].
        num_classes: number of classes to evaluate.
        img_size: side length used to scale boxes to pixels.
        iou_thr: IoU needed for a detection to match a ground-truth box.
        conf_thres: confidence threshold passed to ``decode_predictions``.

    Returns:
        ``(mAP, aps)``. ``aps`` holds one AP per class that has at least one
        ground-truth box, in ascending class order; classes without ground
        truth are skipped, so a position in ``aps`` is not necessarily the
        class id. ``mAP`` is the mean of ``aps`` (0.0 when it is empty).

    Every image of every batch is evaluated and gets its own running image id.
    Reading only element 0 of each batch would silently ignore the rest
    whenever the loader's batch size is above 1.
    """
    model.eval()
    device = next(model.parameters()).device

    all_dets = {c: [] for c in range(num_classes)}
    all_gts = {c: [] for c in range(num_classes)}

    img_id = 0
    with torch.no_grad():
        for imgs, targets in dataloader:
            preds, _ = model(imgs.to(device))
            batch_dets = decode_predictions(preds, img_size=img_size,
                                            conf_thres=conf_thres)

            for dets, gt in zip(batch_dets, targets):
                if len(gt) > 0:
                    gt = gt.cpu()
                    gcls = gt[:, 0].long()
                    centers = gt[:, 1:3] * img_size
                    half = gt[:, 3:5] * img_size / 2
                    gboxes = torch.cat([centers - half, centers + half], dim=1)

                    for c in range(num_classes):
                        mask = gcls == c
                        if mask.any():
                            all_gts[c].append((img_id, gboxes[mask]))

                # decode_predictions returns [] for an image with no detections.
                if dets is not None and len(dets) > 0:
                    dets = dets.cpu()
                    boxes = dets[:, :4]
                    scores = dets[:, 4]
                    cls_ids = dets[:, 5].long()
                    for c in range(num_classes):
                        mask = cls_ids == c
                        if mask.any():
                            all_dets[c].append((img_id, scores[mask], boxes[mask]))

                img_id += 1

    aps = []
    for c in range(num_classes):
        gts_c = all_gts[c]
        if not gts_c:
            continue

        n_gt = sum(boxes.size(0) for _, boxes in gts_c)
        # Each image id appears at most once per class.
        gt_by_image = {
            gid: {
                "boxes": boxes,
                "matched": torch.zeros(boxes.size(0), dtype=torch.bool),
            }
            for gid, boxes in gts_c
        }

        dets_c = all_dets[c]
        if not dets_c:
            aps.append(0.0)
            continue

        scores_all = torch.cat([s for _, s, _ in dets_c], dim=0)
        boxes_all = torch.cat([b for _, _, b in dets_c], dim=0)
        img_ids_all = [gid for gid, s, _ in dets_c for _ in range(s.numel())]

        # Walk the detections from the most to the least confident.
        order = scores_all.argsort(descending=True).tolist()
        tps = torch.zeros(len(order))
        fps = torch.zeros(len(order))

        for rank, det_idx in enumerate(order):
            entry = gt_by_image.get(img_ids_all[det_idx])
            if entry is None:
                fps[rank] = 1
                continue

            ious = box_iou(boxes_all[det_idx].unsqueeze(0),
                           entry["boxes"]).squeeze(0)
            max_iou, max_idx = ious.max(0)

            # A ground-truth box can be claimed by one detection only.
            if max_iou >= iou_thr and not entry["matched"][max_idx]:
                tps[rank] = 1
                entry["matched"][max_idx] = True
            else:
                fps[rank] = 1

        tp_cum = torch.cumsum(tps, dim=0)
        fp_cum = torch.cumsum(fps, dim=0)
        recall = tp_cum / (n_gt + 1e-6)
        precision = tp_cum / (tp_cum + fp_cum + 1e-6)

        aps.append(compute_ap(recall, precision))

    mAP = sum(aps) / len(aps) if len(aps) > 0 else 0.0
    return mAP, aps


In [None]:
# ============================================
# Cell 9: 학습 루프 + Val mAP 기준 best.pt 저장
# ============================================

if 'IMG_SIZE' not in globals():
    IMG_SIZE = 640

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Model configuration shared by training, validation and testing.
# img_size is recorded so the positional-embedding grid is tied to the
# resolution used for training and is restored with the checkpoint.
cfg = dict(
    in_ch=3,
    stem_base=32,
    embed_dim=256,
    vit_depth=4,
    vit_heads=4,
    num_classes=3,
    iters=1,
    detach_feedback=True,
    img_size=IMG_SIZE,
)

model = HybridTwoWay(**cfg).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
EPOCHS = 5

from tqdm import tqdm

# Start below any reachable mAP so the first epoch always writes a checkpoint.
# With a 0.0 start and a strict ">" comparison, a run whose mAP never left
# zero saved nothing and the later torch.load failed.
best_map = float("-inf")
best_epoch = -1

for epoch in range(EPOCHS):
    # ------------------------------
    # 1) Train
    # ------------------------------
    model.train()
    total_loss = 0.0

    loop = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}")
    for imgs, targets in loop:
        imgs = imgs.to(device)
        targets = [t.to(device) for t in targets]

        preds, aux = model(imgs)
        loss = yolo_loss(preds, targets, img_size=IMG_SIZE)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        loop.set_postfix(loss=loss.item())

    avg_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch+1} | Train Average Loss: {avg_loss:.4f}")

    # ------------------------------
    # 2) Validation mAP
    # ------------------------------
    val_map, val_aps = evaluate_map(
        model,
        val_loader,
        num_classes=cfg["num_classes"],
        img_size=IMG_SIZE
    )
    print(f"Epoch {epoch+1} | Val mAP@0.5: {val_map:.4f}")

    # ------------------------------
    # 3) Save the best model so far
    # ------------------------------
    if val_map > best_map:
        best_map = val_map
        best_epoch = epoch + 1
        ckpt = {
            "state_dict": model.state_dict(),
            "cfg": cfg,
            "epoch": best_epoch,
            "val_map": best_map,
        }
        torch.save(ckpt, "hybrid_two_way_best.pt")
        print(f"✅ New best model saved at epoch {best_epoch} | Val mAP: {best_map:.4f}")

print("학습 완료!")
print(f"Best epoch: {best_epoch}, Best Val mAP: {best_map:.4f}")


In [None]:
# ============================================
# 저장된 Best 모델 불러오기 및 테스트 평가
# ============================================

# 1. Load the checkpoint file written by the training cell.
checkpoint_path = "hybrid_two_way_best.pt"
checkpoint = torch.load(checkpoint_path, map_location=device)

# 2. Rebuild the model from the saved config so the architecture matches
#    the one that was trained.
loaded_cfg = checkpoint["cfg"]
model = HybridTwoWay(**loaded_cfg).to(device)

# 3. Load the weights (state_dict).
# The checkpoint is a dict; the weights were stored under the "state_dict"
# key, so they have to be pulled out of it before loading.
model.load_state_dict(checkpoint["state_dict"])

model.eval()
print(f"✅ 모델 로드 완료 (Epoch: {checkpoint['epoch']}, Val mAP: {checkpoint['val_map']:.4f})")

# 4. Evaluate on the test split.
test_map, class_aps = evaluate_map(
    model, test_loader, num_classes=loaded_cfg["num_classes"], img_size=IMG_SIZE
)

print(f"\nTest mAP@0.5: {test_map:.4f}")
# NOTE(review): evaluate_map skips classes that have no ground-truth boxes, so
# `i` is the position in class_aps and equals the class id only when every
# class is present in the test split.
for i, ap in enumerate(class_aps):
    print(f"  Class {i} AP@0.5: {ap:.4f}")

## 7. Quick Sanity Check
Colab에서 바로 실행해 모델 입출력 형태를 확인할 수 있습니다.

In [None]:
# ============================================
# Cell 10: Quick Sanity Check (입출력 shape 확인)
# ============================================

x = torch.randn(2, 3, 640, 640).to(device)
# Only the output shapes are inspected, so skip autograd bookkeeping; the
# attention over the 80x80 token grid would otherwise keep its activations
# alive for a backward pass that never happens.
with torch.no_grad():
    preds, aux = model(x)

for level, (c, o, b) in zip(["P3","P4","P5"], preds):
    print(f"[{level}] cls: {list(c.shape)}, obj: {list(o.shape)}, box: {list(b.shape)}")
