# HybridTwoWay Model (Colab Ready)


## Imports
필요한 PyTorch 모듈과 타입 힌트를 불러옵니다.

In [None]:
import math
from typing import List, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


In [None]:
# dataset 가져오기
import os
from PIL import Image
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader

# --- Option 1: Kaggle dataset via kagglehub --- #
# !pip install kagglehub
import kagglehub

# Image file extensions recognised when scanning the dataset folder.
IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.gif', '.bmp')

# 1. Enter the Kaggle dataset handle in "<owner>/<dataset-slug>" form,
#    e.g. 'your-user/your-dataset'.
KAGGLE_DATASET_HANDLE = 'your-user/your-dataset'  # <--- put your handle here
# kagglehub downloads datasets with dataset_download(); it has no
# snapshot_download() (that name belongs to huggingface_hub), so the previous
# call failed with AttributeError before anything was downloaded.
DATASET_PATH = kagglehub.dataset_download(KAGGLE_DATASET_HANDLE)

# 2. Load a sample image from the downloaded dataset.
image_folder = os.path.join(DATASET_PATH, 'images')  # <--- adjust to the dataset's layout
if os.path.isdir(image_folder):
    # Sorted so the "first" sample is the same on every run (os.listdir order is arbitrary).
    image_files = sorted(
        os.path.join(image_folder, f)
        for f in os.listdir(image_folder)
        if f.lower().endswith(IMAGE_EXTENSIONS)
    )
    if image_files:
        print(f'Found {len(image_files)} images in {image_folder}. Loading a sample image...')
        sample_image_path = image_files[0]
        # The context manager closes the underlying file; convert() returns an
        # independent, fully loaded copy that stays usable afterwards.
        with Image.open(sample_image_path) as img:
            sample_image = img.convert('RGB')
        print(f'Sample image loaded from {sample_image_path}. Size: {sample_image.size}')
        # sample_image.show()  # Displaying images directly may not work in Colab.
    else:
        print(f'No image files found in {image_folder}.')
else:
    print(f'Image folder not found at {image_folder}. Please check the path inside the dataset.')
    print(f'Available files/folders in {DATASET_PATH}:', os.listdir(DATASET_PATH))


# --- 옵션 2: Google Drive 데이터셋 사용 --- #
# 1. 아래 주석을 해제하여 Google Drive를 마운트하세요.
# from google.colab import drive
# drive.mount('/content/drive')

# 2. Google Drive에 있는 이미지 폴더의 전체 경로를 'GDRIVE_IMAGE_FOLDER'에 입력하세요.
# GDRIVE_IMAGE_FOLDER = '/content/drive/My Drive/your_image_folder' # <--- 여기에 이미지 폴더 경로 입력
# if os.path.exists(GDRIVE_IMAGE_FOLDER):
#     gdrive_image_files = [os.path.join(GDRIVE_IMAGE_FOLDER, f) for f in os.listdir(GDRIVE_IMAGE_FOLDER) if f.lower().endswith(('.png', '.jpg', '.jpeg', '.gif', '.bmp'))]
#     if gdrive_image_files:
#         print(f'Found {len(gdrive_image_files)} images in {GDRIVE_IMAGE_FOLDER}. Loading a sample image...')
#         gdrive_sample_image_path = gdrive_image_files[0]
#         gdrive_sample_image = Image.open(gdrive_sample_image_path).convert('RGB')
#         print(f'Sample image loaded from {gdrive_sample_image_path}. Size: {gdrive_sample_image.size}')
#     else:
#         print(f'No image files found in {GDRIVE_IMAGE_FOLDER}.')
# else:
#     print(f'Image folder not found at {GDRIVE_IMAGE_FOLDER}. Please check the path.')


In [None]:
# --- Option 3: Roboflow dataset --- #
# !pip install roboflow
import os
from getpass import getpass

from roboflow import Roboflow

# Never hard-code API keys in a notebook: anyone with the file can use them.
# The key that was previously committed here should be treated as leaked and
# revoked/rotated in the Roboflow dashboard. Set ROBOFLOW_API_KEY in the
# environment, or type it at the hidden prompt.
ROBOFLOW_API_KEY = os.environ.get("ROBOFLOW_API_KEY") or getpass("Roboflow API key: ")

rf = Roboflow(api_key=ROBOFLOW_API_KEY)
project = rf.workspace("arakon").project("detection-base-hqaeg")
version = project.version(6)
dataset = version.download("yolov8")

print(f'Roboflow dataset downloaded to: {dataset.location}')
                

## 0. Utility Functions

In [None]:
def conv_bn_act(in_ch, out_ch, k=3, s=1, p=1, act=True):
    """Return ``Conv2d -> BatchNorm2d [-> SiLU]`` as an ``nn.Sequential``.

    The convolution is created without a bias because the BatchNorm that
    follows provides its own affine shift.

    Args:
        in_ch: input channels.
        out_ch: output channels.
        k: kernel size.
        s: stride.
        p: zero padding (callers pass ``p=0`` for 1x1 kernels).
        act: append an in-place SiLU when True.
    """
    conv = nn.Conv2d(in_ch, out_ch, kernel_size=k, stride=s, padding=p, bias=False)
    norm = nn.BatchNorm2d(out_ch)
    if not act:
        return nn.Sequential(conv, norm)
    return nn.Sequential(conv, norm, nn.SiLU(inplace=True))


## 1. Anomaly-Aware CNN Stem

In [None]:
class FixedGaussianBlur(nn.Module):
    """Non-learnable depthwise Gaussian blur with reflect padding.

    Every channel is convolved with the same normalised k x k Gaussian kernel
    (``groups=channels``), and the output has the same spatial size as the input.

    Args:
        channels: number of input channels; must equal ``x.shape[1]`` in forward.
        k: kernel size. Odd sizes pad symmetrically. Even sizes pad one extra
            pixel on the right/bottom so the output size still equals the input
            size (symmetric ``k // 2`` padding would grow it by one pixel).
        sigma: Gaussian standard deviation in pixels.

    Raises:
        ValueError: if ``k < 1`` or ``sigma <= 0`` (the latter made the kernel NaN).

    NOTE: reflect padding requires each padded side to be smaller than the
    corresponding input dimension.
    """

    def __init__(self, channels, k=5, sigma=1.0):
        super().__init__()
        if k < 1:
            raise ValueError(f"kernel size k must be >= 1, got {k}")
        if sigma <= 0:
            raise ValueError(f"sigma must be > 0, got {sigma}")
        # Sample positions centred on the kernel middle, e.g. k=5 -> [-2, -1, 0, 1, 2].
        grid = torch.arange(k, dtype=torch.float32) - (k - 1) / 2
        gauss = torch.exp(-(grid ** 2) / (2 * sigma ** 2))
        kernel1d = gauss / gauss.sum()
        # Separable Gaussian: the 2-D kernel is the outer product and still sums to 1.
        kernel2d = torch.outer(kernel1d, kernel1d)
        # One identical (1, k, k) filter per channel for the grouped conv.
        weight = kernel2d[None, None, :, :].repeat(channels, 1, 1, 1)
        self.register_buffer('weight', weight)
        self.groups = channels
        self.k = k
        # F.pad order is (left, right, top, bottom); lo + hi == k - 1 keeps size.
        lo, hi = (k - 1) // 2, k // 2
        self.pad = (lo, hi, lo, hi)

    def forward(self, x):
        return F.conv2d(F.pad(x, self.pad, mode='reflect'), self.weight, groups=self.groups)


class AnomalyAwareStem(nn.Module):
    """CNN stem that fuses a strided main branch with a high-frequency residual branch.

    Main branch: three stride-2 ``conv_bn_act`` layers, in_ch -> C1 -> C2 -> C3
    with C1 = base_ch, C2 = 2 * base_ch and C3 = 4 * base_ch. Each layer halves
    the spatial size (rounding up), so the output is roughly H/8 x W/8.

    Residual branch: ``x - GaussianBlur(x)`` keeps only high-frequency content.
    It is bilinearly resized to the main branch's resolution, then passed through
    a depthwise 3x3 conv + ReLU and a 1x1 projection to C3 // 4 channels
    (+ BN + SiLU).

    The two branches are concatenated and fused back to C3 channels
    (1x1 conv + BN + SiLU).

    forward returns ``(f, v)``:
        f: fused features with ``out_channels`` (= 4 * base_ch) channels.
        v: single-channel sigmoid map from ``vis_head``, computed from the main
           branch only (presumably a visibility/saliency map — TODO confirm).
    """

    def __init__(self, in_ch=3, base_ch=48):
        super().__init__()
        C1, C2, C3 = base_ch, base_ch * 2, base_ch * 4
        self.stem = nn.Sequential(
            conv_bn_act(in_ch, C1, 3, 2, 1),
            conv_bn_act(C1, C2, 3, 2, 1),
            conv_bn_act(C2, C3, 3, 2, 1),
        )
        self.blur = FixedGaussianBlur(in_ch, k=5, sigma=1.0)
        self.anom = nn.Sequential(
            nn.Conv2d(in_ch, in_ch, 3, 1, 1, groups=in_ch, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_ch, C3 // 4, 1, 1, 0, bias=False),
            nn.BatchNorm2d(C3 // 4),
            nn.SiLU(inplace=True),
        )
        self.fuse = nn.Conv2d(C3 + C3 // 4, C3, 1, 1, 0, bias=False)
        self.fuse_bn = nn.BatchNorm2d(C3)
        self.vis_head = nn.Conv2d(C3, 1, 1, 1, 0)

    @property
    def out_channels(self):
        """Channel count of the fused output ``f`` (== 4 * base_ch).

        Read from the fuse conv so it always matches the configured base_ch.
        The previous hard-coded ``4 * 48`` was wrong for any other base_ch.
        """
        return self.fuse.out_channels

    def forward(self, x):
        f_main = self.stem(x)
        # High-pass residual: the input minus its low-pass (blurred) version.
        blurred = self.blur(x)
        high = x - blurred
        high_ds = F.interpolate(high, size=f_main.shape[-2:], mode='bilinear', align_corners=False)
        f_anom = self.anom(high_ds)
        f = torch.cat([f_main, f_anom], dim=1)
        f = self.fuse_bn(self.fuse(f))
        f = F.silu(f, inplace=True)
        v = torch.sigmoid(self.vis_head(f_main))
        return f, v


## 2. Vision Transformer Encoder

In [None]:
class PatchEmbed1x1(nn.Module):
    """Project CNN feature maps to the ViT embedding width at unchanged resolution.

    A bias-free pointwise (1x1) convolution, followed by BatchNorm and an
    in-place SiLU. Input ``(B, in_ch, H, W)`` becomes ``(B, embed_dim, H, W)``.
    """

    def __init__(self, in_ch, embed_dim):
        super().__init__()
        self.proj = nn.Conv2d(in_ch, embed_dim, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn = nn.BatchNorm2d(embed_dim)

    def forward(self, x):
        normalized = self.bn(self.proj(x))
        return F.silu(normalized, inplace=True)


class MLP(nn.Module):
    """Position-wise feed-forward network used inside each transformer block.

    ``Linear(dim -> int(dim * mlp_ratio)) -> GELU -> Dropout -> Linear(-> dim) -> Dropout``.
    One Dropout module (rate ``drop``) is reused after both stages.
    """

    def __init__(self, dim, mlp_ratio=4.0, drop=0.0):
        super().__init__()
        hidden_dim = int(dim * mlp_ratio)
        self.fc1 = nn.Linear(dim, hidden_dim)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_dim, dim)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))


class MultiheadSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention over a token sequence.

    Input and output are ``(B, N, dim)``. Q, K and V come from a single fused
    linear layer. The heads are concatenated and re-projected with ``proj``.

    When ``F.scaled_dot_product_attention`` is available, the fused kernel is
    used. It computes the same softmax(q k^T / sqrt(head_dim)) v but does not
    have to materialise the full (B, heads, N, N) attention matrix. For example,
    a 640x640 image reduced 8x gives 80*80 = 6400 tokens, and that matrix is
    ~2.6 GB in fp32 for batch 2 and 8 heads. Otherwise the explicit path is used.
    """

    def __init__(self, dim, num_heads=8, attn_drop=0.0, proj_drop=0.0):
        super().__init__()
        assert dim % num_heads == 0, f"dim ({dim}) must be divisible by num_heads ({num_heads})"
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=True)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        # (B, N, 3*C) -> (3, B, heads, N, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        if hasattr(F, "scaled_dot_product_attention"):
            # Default scale is 1/sqrt(head_dim) == self.scale. Dropout is only
            # applied in training mode, matching nn.Dropout's behaviour.
            out = F.scaled_dot_product_attention(
                q, k, v,
                dropout_p=self.attn_drop.p if self.training else 0.0,
            )
        else:
            attn = (q @ k.transpose(-2, -1)) * self.scale
            attn = self.attn_drop(attn.softmax(dim=-1))
            out = attn @ v
        # Merge heads back: (B, heads, N, head_dim) -> (B, N, C)
        out = out.transpose(1, 2).reshape(B, N, C)
        out = self.proj(out)
        return self.proj_drop(out)


class TransformerBlock(nn.Module):
    """Pre-norm transformer encoder block.

    Computes ``x + Attn(LN1(x))`` and then ``x + MLP(LN2(x))`` on ``(B, N, dim)``
    tokens. ``drop`` is used for the attention-weight dropout, the attention
    output-projection dropout and the MLP dropout.
    """

    def __init__(self, dim, num_heads=8, mlp_ratio=4.0, drop=0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = MultiheadSelfAttention(dim, num_heads, attn_drop=drop, proj_drop=drop)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = MLP(dim, mlp_ratio=mlp_ratio, drop=drop)

    def forward(self, x):
        attn_residual = self.attn(self.norm1(x))
        x = x + attn_residual
        mlp_residual = self.mlp(self.norm2(x))
        return x + mlp_residual


class ViTEncoder(nn.Module):
    """Apply ``depth`` TransformerBlocks in sequence to ``(B, N, embed_dim)`` tokens.

    NOTE(review): this encoder adds no positional embedding and applies no final
    LayerNorm after the last block. Confirm that this is intended.
    """

    def __init__(self, embed_dim=512, depth=8, num_heads=8):
        super().__init__()
        layers = [
            TransformerBlock(embed_dim, num_heads, mlp_ratio=4.0, drop=0.0)
            for _ in range(depth)
        ]
        self.blocks = nn.ModuleList(layers)

    def forward(self, tokens):
        hidden = tokens
        for layer in self.blocks:
            hidden = layer(hidden)
        return hidden


## 3. Feedback Adapter

In [None]:
class FeedbackAdapter(nn.Module):
    """Modulate CNN stem features with ViT tokens (FiLM-style scale and shift).

    The tokens ``(B, Ht*Wt, d_token)`` are folded back into a ``(B, d_token, Ht, Wt)``
    map. A 1x1 conv (+ optional BN) + SiLU then predicts ``2 * c_stem`` channels.
    These are split into per-pixel ``gamma`` and ``beta``, and the result is
    ``f_stem * (1 + tanh(gamma)) + beta``.

    Args:
        d_token: token embedding width.
        c_stem: channel count of the stem features being modulated.
        use_bn: insert BatchNorm after the conv (the conv then has no bias).
    """

    def __init__(self, d_token: int, c_stem: int, use_bn: bool = True):
        super().__init__()
        film_ch = 2 * c_stem
        modules = [nn.Conv2d(d_token, film_ch, kernel_size=1, stride=1, padding=0, bias=not use_bn)]
        if use_bn:
            modules.append(nn.BatchNorm2d(film_ch))
        modules.append(nn.SiLU(inplace=True))
        self.adapter = nn.Sequential(*modules)

    def forward(self, tokens: torch.Tensor, Ht: int, Wt: int, f_stem: torch.Tensor):
        batch, _, dim = tokens.shape
        token_map = tokens.transpose(1, 2).reshape(batch, dim, Ht, Wt)
        # Chunks of f_stem's channel count: first half gamma, second half beta.
        gamma, beta = self.adapter(token_map).split(f_stem.shape[1], dim=1)
        modulation = 1 + torch.tanh(gamma)
        return f_stem * modulation + beta


## 4. PAN-Lite Neck

In [None]:
class PANLite(nn.Module):
    """Lightweight PAN neck that builds a 3-level pyramid from one feature map.

    Starting from the single input map, it first derives two coarser levels with
    stride-2 convs. A top-down pass (nearest upsample + concat + 3x3 fuse) and a
    bottom-up pass (stride-2 conv + concat + 3x3 fuse) follow. Returns
    ``(p3, p4, p5)`` with ``mid`` channels each, at 1x, 1/2 and 1/4 of the input
    resolution (rounded up by the stride-2 convs).
    """

    def __init__(self, in_ch=512, mid=256):
        super().__init__()
        # Pyramid construction
        self.lateral = conv_bn_act(in_ch, mid, k=1, s=1, p=0)
        self.down4 = conv_bn_act(mid, mid, k=3, s=2, p=1)
        self.down5 = conv_bn_act(mid, mid, k=3, s=2, p=1)
        # Top-down fusion
        self.up4 = conv_bn_act(2 * mid, mid, k=3, s=1, p=1)
        self.up3 = conv_bn_act(2 * mid, mid, k=3, s=1, p=1)
        # Bottom-up fusion
        self.down_f4 = conv_bn_act(mid, mid, k=3, s=2, p=1)
        self.fuse4 = conv_bn_act(2 * mid, mid, k=3, s=1, p=1)
        self.down_f5 = conv_bn_act(mid, mid, k=3, s=2, p=1)
        self.fuse5 = conv_bn_act(2 * mid, mid, k=3, s=1, p=1)

    def forward(self, p3):
        base3 = self.lateral(p3)
        base4 = self.down4(base3)
        base5 = self.down5(base4)

        # Top-down: bring the coarser level up to the finer size and fuse.
        up5 = F.interpolate(base5, size=base4.shape[-2:], mode='nearest')
        td4 = self.up4(torch.cat([base4, up5], dim=1))
        up4 = F.interpolate(td4, size=base3.shape[-2:], mode='nearest')
        td3 = self.up3(torch.cat([base3, up4], dim=1))

        # Bottom-up: push the refined finer level back down and fuse.
        bu4 = self.fuse4(torch.cat([td4, self.down_f4(td3)], dim=1))
        bu5 = self.fuse5(torch.cat([base5, self.down_f5(bu4)], dim=1))
        return td3, bu4, bu5


## 5. YOLO-style Detection Head

In [None]:
class YOLOHeadLite(nn.Module):
    """Anchor-free YOLO-style head with independent weights for each of 3 levels.

    For each level, a 3x3 ``conv_bn_act`` stem feeds three 1x1 convs producing
    class logits (``num_classes`` ch), an objectness logit (1 ch) and a box
    regression map (4 ch). No activation is applied to these outputs.

    Args:
        in_ch: channels of every input level.
        num_classes: number of class logits per location.
        reg_max: accepted for interface compatibility but currently unused.
    """

    def __init__(self, in_ch=256, num_classes=1, reg_max=0):
        super().__init__()
        ch = in_ch
        self.stem3 = conv_bn_act(ch, ch, k=3, s=1, p=1)
        self.stem4 = conv_bn_act(ch, ch, k=3, s=1, p=1)
        self.stem5 = conv_bn_act(ch, ch, k=3, s=1, p=1)
        self.cls3 = nn.Conv2d(ch, num_classes, kernel_size=1, stride=1, padding=0)
        self.obj3 = nn.Conv2d(ch, 1, kernel_size=1, stride=1, padding=0)
        self.box3 = nn.Conv2d(ch, 4, kernel_size=1, stride=1, padding=0)
        self.cls4 = nn.Conv2d(ch, num_classes, kernel_size=1, stride=1, padding=0)
        self.obj4 = nn.Conv2d(ch, 1, kernel_size=1, stride=1, padding=0)
        self.box4 = nn.Conv2d(ch, 4, kernel_size=1, stride=1, padding=0)
        self.cls5 = nn.Conv2d(ch, num_classes, kernel_size=1, stride=1, padding=0)
        self.obj5 = nn.Conv2d(ch, 1, kernel_size=1, stride=1, padding=0)
        self.box5 = nn.Conv2d(ch, 4, kernel_size=1, stride=1, padding=0)

    def forward_single(self, x, stem, cls, obj, box):
        """Run one level: shared stem, then the (cls, obj, box) 1x1 convs."""
        feat = stem(x)
        return tuple(layer(feat) for layer in (cls, obj, box))

    def forward(self, p3, p4, p5):
        levels = (
            (p3, self.stem3, self.cls3, self.obj3, self.box3),
            (p4, self.stem4, self.cls4, self.obj4, self.box4),
            (p5, self.stem5, self.cls5, self.obj5, self.box5),
        )
        return [self.forward_single(*level) for level in levels]


## 6. HybridTwoWay Model

In [None]:
class HybridTwoWay(nn.Module):
    """Hybrid CNN + ViT detector with a ViT -> CNN feedback path.

    Data flow:
        stem(x) -> f_stem, vis
        repeat ``iters`` times (starting from f = f_stem):
            tokens = vit(patch(f))              # patch: 1x1 conv to embed_dim
            f = feedback(tokens, f_stem)        # tokens modulate the stem features
        neck(patch(f)) -> p3, p4, p5 -> head -> preds

    With ``iters > 1`` each pass re-encodes the previous pass's feedback features,
    so the extra iterations refine the result. The old implementation re-ran an
    identical independent pass and discarded it, wasting compute. ``iters=1``
    behaves exactly as before.

    NOTE: ``self.patch`` is shared between the ViT input and the neck input.

    Args:
        detach_feedback: detach ViT tokens before the feedback adapter, so that
            gradients do not flow back into the ViT through the feedback path.
    """

    def __init__(
        self,
        in_ch=3,
        stem_base=48,
        embed_dim=512,
        vit_depth=8,
        vit_heads=8,
        num_classes=1,
        iters=1,
        detach_feedback=False,
    ):
        super().__init__()
        assert iters >= 1
        self.iters = iters
        self.detach_feedback = detach_feedback
        self.stem = AnomalyAwareStem(in_ch=in_ch, base_ch=stem_base)
        c_stem = stem_base * 4  # channel count of the stem's fused output
        self.patch = PatchEmbed1x1(c_stem, embed_dim)
        self.vit = ViTEncoder(embed_dim=embed_dim, depth=vit_depth, num_heads=vit_heads)
        self.feedback = FeedbackAdapter(embed_dim, c_stem, use_bn=True)
        self.neck = PANLite(in_ch=embed_dim, mid=256)
        self.head = YOLOHeadLite(in_ch=256, num_classes=num_classes)

    def _encode(self, f):
        """Patch-embed ``f``, run the ViT and return ``(tokens, Ht, Wt)``.

        Tokens are detached when ``detach_feedback`` is set.
        """
        p = self.patch(f)
        Ht, Wt = p.shape[-2:]
        tokens = self.vit(p.flatten(2).transpose(1, 2))  # (B, Ht*Wt, embed_dim)
        if self.detach_feedback:
            tokens = tokens.detach()
        return tokens, Ht, Wt

    def _decode(self, f_fb, vis):
        """Re-embed the feedback features and run the neck and head; return ``(preds, aux)``."""
        p3, p4, p5 = self.neck(self.patch(f_fb))
        preds = self.head(p3, p4, p5)
        aux = {"P3": p3, "P4": p4, "P5": p5, "V": vis}
        return preds, aux

    def forward_once(self, x):
        """Run a single stem -> ViT -> feedback -> neck -> head pass.

        Returns ``(preds, aux, f_fb)``.
        """
        f_stem, vis = self.stem(x)
        tokens, Ht, Wt = self._encode(f_stem)
        f_fb = self.feedback(tokens, Ht, Wt, f_stem)
        preds, aux = self._decode(f_fb, vis)
        return preds, aux, f_fb

    def forward(self, x):
        """Return ``(preds, aux)`` after ``iters`` feedback refinement passes."""
        f_stem, vis = self.stem(x)
        f_fb = f_stem
        for _ in range(self.iters):
            tokens, Ht, Wt = self._encode(f_fb)
            # Modulate the original stem features with the latest tokens.
            f_fb = self.feedback(tokens, Ht, Wt, f_stem)
        return self._decode(f_fb, vis)


## 7. Quick Sanity Check
Colab에서 바로 실행해 모델 입출력 형태를 확인할 수 있습니다.

In [None]:
# Quick shape check: build the model, run one forward pass and print the
# per-level prediction shapes.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = HybridTwoWay(
    in_ch=3,
    stem_base=48,
    embed_dim=512,
    vit_depth=8,
    vit_heads=8,
    num_classes=1,
    iters=1,
    detach_feedback=True,
).to(device)
# eval(): use BN running statistics instead of updating them with random noise.
model.eval()

x = torch.randn(2, 3, 640, 640, device=device)
# no_grad: a shape check needs no autograd graph. Keeping one would retain the
# activations of all 8 ViT layers over 80x80 = 6400 tokens and can exhaust memory.
with torch.no_grad():
    preds, aux = model(x)
for i, (c, o, b) in enumerate(preds, start=3):
    print(f"[TwoWay] P{i} cls:{list(c.shape)} obj:{list(o.shape)} box:{list(b.shape)}")
