In [1]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
class YOLOv1(nn.Module):
    """YOLOv1-style detector: a 24-conv backbone followed by a fully connected head.

    Every convolution is followed by BatchNorm2d and LeakyReLU(0.1). The head
    pools the backbone features to ``grid_size x grid_size``, flattens them and
    maps them to ``grid_size * grid_size * (num_classes + num_boxes * 5)``
    values, which ``forward`` reshapes into one prediction vector per grid cell.

    Args:
        num_classes: Number of object classes predicted per cell.
        num_boxes: Number of bounding boxes predicted per cell (5 values each).
        grid_size: Side length of the output grid. Defaults to 7, which
            reproduces the previously hard-coded behaviour.
    """

    # Backbone layout. A tuple is (out_channels, kernel_size, stride) for one
    # Conv2d -> BatchNorm2d -> LeakyReLU group; "M" is a 2x2 max-pool, stride 2.
    # Padding is always kernel_size // 2 (7 -> 3, 3 -> 1, 1 -> 0).
    _BACKBONE_CFG = (
        # conv block 1
        (64, 7, 2), "M",
        # conv block 2
        (192, 3, 1), "M",
        # conv block 3
        (128, 1, 1), (256, 3, 1), (256, 1, 1), (512, 3, 1), "M",
        # conv block 4
        (256, 1, 1), (512, 3, 1), (256, 1, 1), (512, 3, 1),
        (256, 1, 1), (512, 3, 1), (256, 1, 1), (512, 3, 1),
        (512, 1, 1), (1024, 3, 1), "M",
        # conv block 5 (last conv downsamples with stride 2)
        (512, 1, 1), (1024, 3, 1), (512, 1, 1), (1024, 3, 1),
        (1024, 3, 1), (1024, 3, 2),
        # conv block 6
        (1024, 3, 1), (1024, 3, 1),
    )

    def __init__(self, num_classes=20, num_boxes=2, grid_size=7):
        super(YOLOv1, self).__init__()
        self.num_classes = num_classes
        self.num_boxes = num_boxes
        self.grid_size = grid_size

        # Layers are kept in one flat Sequential so that module indices (and
        # therefore state_dict keys such as "backbone.0.weight") stay stable.
        layers = []
        in_channels = 3
        for spec in self._BACKBONE_CFG:
            if spec == "M":
                layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
                continue
            out_channels, kernel_size, stride = spec
            layers.append(
                nn.Conv2d(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=kernel_size,
                    stride=stride,
                    padding=kernel_size // 2,
                )
            )
            layers.append(nn.BatchNorm2d(out_channels))
            layers.append(nn.LeakyReLU(0.1))
            in_channels = out_channels
        self.backbone = nn.Sequential(*layers)

        # Fully connected head producing a grid_size x grid_size output.
        cell_dim = num_classes + num_boxes * 5
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d((grid_size, grid_size)),
            nn.Flatten(),
            nn.Linear(in_channels * grid_size * grid_size, 4096),
            nn.LeakyReLU(0.1),
            nn.Dropout(0.5),
            nn.Linear(4096, grid_size * grid_size * cell_dim),
        )

    def forward(self, x):
        """Run the network.

        Args:
            x: Image batch with 3 input channels, i.e. (N, 3, H, W).

        Returns:
            Tensor of shape
            (N, grid_size, grid_size, num_classes + num_boxes * 5).
        """
        x = self.backbone(x)
        x = self.classifier(x)
        # Reshape the flat head output into one prediction vector per grid cell.
        return x.view(
            -1,
            self.grid_size,
            self.grid_size,
            self.num_classes + self.num_boxes * 5,
        )

In [5]:
class YoloV1LossCE(nn.Module):
    """YOLOv1 loss with a cross-entropy classification term.

    The loss is the sum of
      * ``lambda_coord`` x squared error on (x, y) and on (sqrt(w), sqrt(h)) of
        the responsible box,
      * squared error between the responsible box confidence and its IoU,
      * ``lambda_noobj`` x squared confidence of every non-responsible box,
      * cross-entropy between class logits and the target class for object cells,
    divided by the batch size.

    Args:
        S: Grid size used to normalise cell-relative centres to image coordinates.
        B: Number of boxes predicted per cell.
        C: Number of classes.
        lambda_coord: Weight of the coordinate term.
        lambda_noobj: Weight of the no-object confidence term.
        eps: Small constant guarding the IoU division and the sqrt gradient.
    """

    def __init__(self, S=7, B=2, C=20, lambda_coord=5.0, lambda_noobj=0.5, eps=1e-9):
        super().__init__()
        self.S, self.B, self.C = S, B, C
        self.lambda_coord = lambda_coord
        self.lambda_noobj = lambda_noobj
        self.eps = eps

    def _xywh_to_xyxy(self, cx, cy, w, h):
        """Convert centre/size boxes to (x1, y1, x2, y2) corner form."""
        x1 = cx - w / 2.0; y1 = cy - h / 2.0   # top-left corner
        x2 = cx + w / 2.0; y2 = cy + h / 2.0   # bottom-right corner
        return x1, y1, x2, y2

    def _bbox_iou(self, x1, y1, x2, y2, X1, Y1, X2, Y2):
        """Element-wise IoU of two sets of corner-form boxes (broadcastable)."""
        xx1 = torch.maximum(x1, X1); yy1 = torch.maximum(y1, Y1)
        xx2 = torch.minimum(x2, X2); yy2 = torch.minimum(y2, Y2)
        # clamp: disjoint boxes would otherwise give negative extents
        inter = (xx2 - xx1).clamp(min=0) * (yy2 - yy1).clamp(min=0)
        area1 = (x2 - x1).clamp(min=0) * (y2 - y1).clamp(min=0)
        area2 = (X2 - X1).clamp(min=0) * (Y2 - Y1).clamp(min=0)
        return inter / (area1 + area2 - inter + self.eps)

    def forward(self, pred, target):
        """Compute the batch-averaged loss.

        Args:
            pred: (N, S, S, B*5 + C). The first B*5 channels hold B boxes as
                (x, y, w, h, confidence); the remaining channels are class logits.
            target: (N, S, S, 5 + C) laid out as (x, y, w, h, objectness,
                one-hot class).

        Returns:
            Scalar tensor: total loss divided by N.
        """
        N, S = pred.shape[0], pred.shape[1]
        B = self.B

        # 1) predicted boxes, each component (N,S,S,B)
        box = pred[..., :B * 5].reshape(N, S, S, B, 5)
        bx, by, bw, bh, bc = box.unbind(dim=-1)
        class_logits = pred[..., B * 5:]                              # (N,S,S,C)

        # 2) targets; keep a trailing axis so they broadcast against the B boxes
        tx = target[..., 0:1]                                         # (N,S,S,1)
        ty = target[..., 1:2]
        tw = target[..., 2:3]
        th = target[..., 3:4]
        tobj = target[..., 4]                                         # (N,S,S)
        tcls_idx = target[..., 5:].argmax(dim=-1)                     # (N,S,S)

        # 3) cell offsets, created on pred's device/dtype to avoid mismatches
        cells = torch.arange(S, device=pred.device, dtype=pred.dtype)
        yy, xx = torch.meshgrid(cells, cells, indexing="ij")
        grid_x = xx[None, ..., None]                                  # (1,S,S,1)
        grid_y = yy[None, ..., None]

        # 4-5) IoU in image-normalised coordinates and responsible-box masks.
        # These only act as targets/masks, so no gradient is tracked here.
        with torch.no_grad():
            bcx = (grid_x + bx) / self.S                              # (N,S,S,B)
            bcy = (grid_y + by) / self.S
            tcx = (grid_x + tx) / self.S                              # (N,S,S,1)
            tcy = (grid_y + ty) / self.S
            iou = self._bbox_iou(
                *self._xywh_to_xyxy(bcx, bcy, bw, bh),
                *self._xywh_to_xyxy(tcx, tcy, tw, th),
            )                                                         # (N,S,S,B)
            best = torch.argmax(iou, dim=-1, keepdim=True)            # (N,S,S,1)
            # 1 only for the highest-IoU box of a cell that contains an object
            resp_mask = torch.zeros_like(iou).scatter_(-1, best, 1.0) * tobj.unsqueeze(-1)
            noobj_mask = 1.0 - resp_mask

        # 6) coordinate regression. Raw w/h predictions may be negative; a signed
        # sqrt keeps the value and its gradient finite (NaN * 0 would still be NaN).
        sqrt_bw = torch.sign(bw) * torch.sqrt(bw.abs() + self.eps)
        sqrt_bh = torch.sign(bh) * torch.sqrt(bh.abs() + self.eps)
        sqrt_tw = torch.sqrt(tw.clamp(min=0))
        sqrt_th = torch.sqrt(th.clamp(min=0))

        coord_xy = ((bx - tx) ** 2 + (by - ty) ** 2) * resp_mask
        coord_wh = ((sqrt_bw - sqrt_tw) ** 2 + (sqrt_bh - sqrt_th) ** 2) * resp_mask
        loss_coord = self.lambda_coord * (coord_xy.sum() + coord_wh.sum())

        # 7) confidence: responsible boxes regress to their IoU, the rest to 0
        loss_obj = ((bc - iou) ** 2 * resp_mask).sum()
        loss_noobj = self.lambda_noobj * ((bc ** 2) * noobj_mask).sum()

        # 8) classification: cross-entropy over cells that contain an object
        obj_mask_bool = tobj.bool()                                   # (N,S,S)
        loss_cls = F.cross_entropy(
            class_logits[obj_mask_bool], tcls_idx[obj_mask_bool], reduction="sum"
        )

        return (loss_coord + loss_obj + loss_noobj + loss_cls) / N

In [6]:
import torch
import torch.nn.functional as F

# NMS (class-wise)
def _nms(boxes, scores, iou_thres=0.5):
    # boxes: (M,4) xyxy, scores: (M,)
    keep = []
    order = torch.argsort(scores, descending=True)  # (M,)
    while order.numel():
        i = order[0].item()
        keep.append(i)
        if order.numel() == 1:
            break
        bb = boxes[i].unsqueeze(0)                 # (1,4)
        rest = boxes[order[1:]]                    # (K,4)

        xx1 = torch.maximum(bb[:, 0], rest[:, 0])  # (K,)
        yy1 = torch.maximum(bb[:, 1], rest[:, 1])  # (K,)
        xx2 = torch.minimum(bb[:, 2], rest[:, 2])  # (K,)
        yy2 = torch.minimum(bb[:, 3], rest[:, 3])  # (K,)
        inter = (xx2 - xx1).clamp(min=0) * (yy2 - yy1).clamp(min=0)   # (K,)
        area1 = (bb[:, 2] - bb[:, 0]) * (bb[:, 3] - bb[:, 1])         # (1,)
        area2 = (rest[:, 2] - rest[:, 0]) * (rest[:, 3] - rest[:, 1]) # (K,)
        iou = inter / (area1 + area2 - inter + 1e-9)                  # (K,)

        order = order[1:][iou < iou_thres]        # 남길 인덱스만 유지
    return torch.tensor(keep)



# YOLOv1 추론 
@torch.no_grad()
def yolov1_infer(pred, S=7, B=2, C=20, score_thresh=0.2, iou_thresh=0.5):
    """Decode raw YOLOv1 predictions into per-image detections.

    Args:
        pred: (N, S, S, B*5 + C). The first B*5 channels hold B boxes as
            (x, y, w, h, confidence) with x/y relative to the cell; the
            remaining C channels are class logits.
        S: Grid size.
        B: Boxes per cell.
        C: Number of classes.
        score_thresh: Minimum class score (confidence x class probability).
        iou_thresh: IoU threshold passed to the class-wise NMS.

    Returns:
        A list with one dict per image, each holding ``'boxes'`` (K, 4) in
        image-normalised (x1, y1, x2, y2), ``'scores'`` (K,) and ``'labels'``
        (K,) long, sorted by descending score. K is 0 when nothing passes the
        threshold.
    """
    device = pred.device
    N = pred.shape[0]                                              # batch size

    # 1) box predictions, each component (N,S,S,B)
    box = pred[..., :B * 5].reshape(N, S, S, B, 5)
    bx, by, bw, bh, bc = box.unbind(dim=-1)
    class_logits = pred[..., B * 5:]                               # (N,S,S,C)

    # 2) p(class | object)
    p_class = F.softmax(class_logits, dim=-1)                      # (N,S,S,C)

    # 3) cell offsets on pred's device/dtype, then corner-form boxes
    cells = torch.arange(S, device=device, dtype=pred.dtype)
    yy, xx = torch.meshgrid(cells, cells, indexing="ij")           # (S,S) each
    grid_x = xx[None, ..., None]                                   # (1,S,S,1)
    grid_y = yy[None, ..., None]

    cx = (grid_x + bx) / S                                         # (N,S,S,B)
    cy = (grid_y + by) / S
    x1 = cx - bw / 2;  y1 = cy - bh / 2
    x2 = cx + bw / 2;  y2 = cy + bh / 2

    # 4) per-class score = confidence * p(class | object)
    scores = bc.unsqueeze(-1) * p_class.unsqueeze(3)               # (N,S,S,B,C)

    # 5) flatten the grid and box axes
    boxes = torch.stack([x1, y1, x2, y2], dim=-1).reshape(N, -1, 4)   # (N,S*S*B,4)
    scores = scores.reshape(N, -1, C)                                 # (N,S*S*B,C)

    # 6) per image and per class: threshold -> NMS -> merge
    results = []
    for n in range(N):
        all_b, all_s, all_l = [], [], []
        for k in range(C):
            s_k = scores[n, :, k]                                  # (S*S*B,)
            m = s_k >= score_thresh
            if not m.any():
                continue                                           # nothing for this class
            b_k = boxes[n, m]                                      # (M,4)
            s_k = s_k[m]                                           # (M,)
            keep = _nms(b_k, s_k, iou_thres=iou_thresh)
            # Guarantee a valid index tensor whatever _nms returned.
            keep = keep.to(device=device, dtype=torch.long)        # (L,)

            all_b.append(b_k[keep])
            all_s.append(s_k[keep])
            all_l.append(torch.full((keep.numel(),), k, dtype=torch.long, device=device))

        if all_b:
            b = torch.cat(all_b, dim=0)                            # (K,4)
            s = torch.cat(all_s, dim=0)                            # (K,)
            l = torch.cat(all_l, dim=0)                            # (K,)
        else:
            # No detection in this image: return correctly typed empty tensors.
            b = boxes.new_zeros((0, 4))
            s = scores.new_zeros((0,))
            l = torch.zeros((0,), dtype=torch.long, device=device)

        order = torch.argsort(s, descending=True)
        results.append({'boxes':  b[order],
                        'scores': s[order],
                        'labels': l[order]})
    # Return only after every image in the batch has been processed.
    return results