In [36]:
import torch
import torch.nn as nn
import torch.nn.functional as F


# Basic Conv + BN + SiLU block

class ConvBNAct(nn.Module):
    def __init__(self, in_channels, out_channels, k=1, s=1, p=0):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, k, s, p, bias=False)
        self.bn = nn.BatchNorm2d(out_channels)
        self.act = nn.SiLU(inplace=True)

    def forward(self, x):
        return self.act(self.bn(self.conv(x)))


# CSP block 
class CSPBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        mid = out_channels // 2
        self.conv1 = ConvBNAct(in_channels, mid, 1, 1, 0)
        self.conv2 = ConvBNAct(mid, out_channels, 3, 1, 1)
        self.conv_merge = ConvBNAct(in_channels + out_channels, out_channels, 1, 1, 0)

    def forward(self, x):
        x1 = self.conv1(x)
        y1 = self.conv2(x1)
        out = torch.cat([x, y1], dim=1)
        return self.conv_merge(out)


# Backbone (CSPDarknet-tiny)

class CSPDarknetTiny(nn.Module):
    def __init__(self):
        super().__init__()
        self.stem = ConvBNAct(3, 32, 3, 1, 1)
        self.stage2 = nn.Sequential(
            ConvBNAct(32, 64, 3, 2, 1),
            CSPBlock(64, 128),   # c2 → 128
        )
        self.stage3 = nn.Sequential(
            ConvBNAct(128, 128, 3, 2, 1),
            CSPBlock(128, 256),  # c3 → 256
        )
        self.stage5 = nn.Sequential(
            ConvBNAct(256, 256, 3, 2, 1),
            CSPBlock(256, 512),  # c5 → 512
        )

    def forward(self, x):
        x = self.stem(x)
        c2 = self.stage2(x)
        c3 = self.stage3(c2)
        c5 = self.stage5(c3)
        return c2, c3, c5  

    
class PAFPN(nn.Module):
    def __init__(self):
        super().__init__()
        
        # Top-down reductions
        self.reduce_c5 = ConvBNAct(512, 256, 1, 1, 0)
        self.reduce_c3 = ConvBNAct(256, 128, 1, 1, 0)
        self.reduce_p5_to_128 = ConvBNAct(256, 128, 1, 1, 0)

        # Top-down outputs
        self.top_out_c3 = ConvBNAct(128, 128, 3, 1, 1)
        self.top_out_c2 = ConvBNAct(128, 128, 3, 1, 1)

        # Bottom-up downsampling (channels corrected)
        self.down_c2 = ConvBNAct(128, 128, 3, 2, 1)    # p3 → n3 input
        self.down_c3 = ConvBNAct(256, 256, 3, 2, 1)    # n3 → n5 input (fixed!)

        # Bottom-up outputs
        self.out_c3 = ConvBNAct(128, 256, 3, 1, 1)     # mid-level → 256
        self.out_c5 = ConvBNAct(256, 512, 3, 1, 1)     # top-level → 512

    def forward(self, feats):
        c2, c3, c5 = feats

        # Top-down
        p5 = self.reduce_c5(c5)       # 512→256
        p4 = self.reduce_c3(c3)       # 256→128
        p5_upsampled = F.interpolate(p5, size=p4.shape[2:], mode="nearest")
        p5_upsampled = self.reduce_p5_to_128(p5_upsampled)
        p4 = p4 + p5_upsampled
        p4 = self.top_out_c3(p4)

        p3 = c2 + F.interpolate(p4, size=c2.shape[2:], mode="nearest")
        p3 = self.top_out_c2(p3)      # 128 channels

        # Bottom-up
        n3 = p4 + self.down_c2(p3)    # 128 + 128 → 128
        n3 = self.out_c3(n3)          # 256 channels

        n5 = p5 + self.down_c3(n3)    # 256 + 256 → 256
        n5 = self.out_c5(n5)          # 512 channels

        return [p3, n3, n5]           # [128, 256, 512]


        
        
# YOLOX Head

class YOLOXHead(nn.Module):
    def __init__(self, num_classes=80):
        super().__init__()
        self.cls_convs = nn.ModuleList([
            ConvBNAct(128, 128, 3),  # for out_c2
            ConvBNAct(256, 256, 3),  # for out_c3
            ConvBNAct(512, 512, 3)   # for out_c5
        ])
        self.reg_convs = nn.ModuleList([
            ConvBNAct(128, 128, 3),
            ConvBNAct(256, 256, 3),
            ConvBNAct(512, 512, 3)
        ])
        self.cls_preds = nn.ModuleList([
            nn.Conv2d(128, num_classes, 1),
            nn.Conv2d(256, num_classes, 1),
            nn.Conv2d(512, num_classes, 1)
        ])
        self.obj_preds = nn.ModuleList([
            nn.Conv2d(128, 1, 1),
            nn.Conv2d(256, 1, 1),
            nn.Conv2d(512, 1, 1)
        ])
        self.reg_preds = nn.ModuleList([
            nn.Conv2d(128, 4, 1),
            nn.Conv2d(256, 4, 1),
            nn.Conv2d(512, 4, 1)
        ])

    def forward(self, feats):
        outputs = []
        for i, feat in enumerate(feats):
            cls_feat = self.cls_convs[i](feat)
            reg_feat = self.reg_convs[i](feat)

            cls_output = self.cls_preds[i](cls_feat)
            obj_output = self.obj_preds[i](reg_feat)
            reg_output = self.reg_preds[i](reg_feat)

            out = torch.cat([reg_output, obj_output, cls_output], 1)
            outputs.append(out)
        return outputs


# Full YOLOX (Backbone + FPN + Head)

class YOLOX(nn.Module):
    def __init__(self, num_classes=80):
        super().__init__()
        self.backbone = CSPDarknetTiny()
        self.fpn = PAFPN()
        self.head = YOLOXHead(num_classes)

    def forward(self, x):
        feats = self.backbone(x)
        fpn_outs = self.fpn(feats)
        outputs = self.head(fpn_outs)
        return outputs


# Test

if __name__ == "__main__":
    model = YOLOX(num_classes=80)
    x = torch.randn(1, 3, 640, 640)
    out = model(x)
    for o in out:
        print(o.shape)


torch.Size([1, 85, 318, 318])
torch.Size([1, 85, 158, 158])
torch.Size([1, 85, 78, 78])


In [38]:
import torch
from torch.utils.data import Dataset
import os
import json
from PIL import Image
import torchvision.transforms as T
import random
import numpy as np

class COCODataset(Dataset):
    def __init__(self, img_dir, ann_file, img_size=640, mosaic=True):
        super().__init__()
        self.img_dir = img_dir
        self.img_size = img_size
        self.mosaic = mosaic

        # Load COCO annotations
        with open(ann_file) as f:
            data = json.load(f)
        self.images = {img['id']: img for img in data['images']}
        self.annotations = data['annotations']

        # Map image_id -> annotations
        self.img_to_anns = {}
        for ann in self.annotations:
            img_id = ann['image_id']
            if img_id not in self.img_to_anns:
                self.img_to_anns[img_id] = []
            self.img_to_anns[img_id].append(ann)
        self.ids = list(self.images.keys())

        # Category mapping
        self.cat2id = {cat['id']: idx for idx, cat in enumerate(data['categories'])}
        self.num_classes = len(self.cat2id)

        self.transforms = T.Compose([
            T.Resize((img_size, img_size)),
            T.ToTensor(),
            T.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
        ])

    def __len__(self):
        return len(self.ids) #Returns number of images in dataset.

    def load_image_target(self, idx):
        #Loads the image by id and opens it as RGB.
        img_id = self.ids[idx]
        img_info = self.images[img_id]
        img_path = os.path.join(self.img_dir, img_info['file_name'])
        img = Image.open(img_path).convert("RGB")

        
        #Collects all bounding boxes and labels for that image.
        anns = self.img_to_anns.get(img_id, [])
        boxes, labels = [], []
        for ann in anns:
            x, y, w, h = ann['bbox']
            boxes.append([x, y, x+w, y+h])
            labels.append(self.cat2id[ann['category_id']])
        boxes = torch.tensor(boxes, dtype=torch.float32) if boxes else torch.zeros((0,4))
        labels = torch.tensor(labels, dtype=torch.int64) if labels else torch.zeros((0,), dtype=torch.int64)

        target = torch.zeros((boxes.shape[0],5))
        if boxes.shape[0] > 0:
            target[:,:4] = boxes
            target[:,4] = labels
            #Final target shape: [num_objects, 5].
        return img, target

    def mosaic_augment(self):
        
        indices = [random.randint(0, len(self)-1) for _ in range(4)] #Randomly select 4 images.
        imgs, targets = [], []
        for idx in indices:
            img, target = self.load_image_target(idx)
            img = np.array(img)
            imgs.append(img)
            targets.append(target)

       #A blank canvas, divided into 4 quadrants.
        mosaic_img = np.zeros((self.img_size*2, self.img_size*2, 3), dtype=np.uint8)
        positions = [(0,0),(0,self.img_size),(self.img_size,0),(self.img_size,self.img_size)]
        mosaic_targets = []

        for i,(img,target) in enumerate(zip(imgs,targets)):
            h,w,_ = img.shape
            img = np.array(Image.fromarray(img).resize((self.img_size,self.img_size), Image.BILINEAR), dtype=np.uint8)
            y1,x1 = positions[i]
            mosaic_img[y1:y1+self.img_size, x1:x1+self.img_size] = img

            # Adjust targets
            #Scale bounding boxes to resized image size.Shift them according to the image’s new location in the mosaic
            if target.shape[0] > 0:
                t = target.clone()
                t[:,0] = t[:,0]/w*self.img_size + x1
                t[:,1] = t[:,1]/h*self.img_size + y1
                t[:,2] = t[:,2]/w*self.img_size + x1
                t[:,3] = t[:,3]/h*self.img_size + y1
                mosaic_targets.append(t)
       #Combine all bounding boxes from the 4 images into one tensor.
        if mosaic_targets:
            mosaic_targets = torch.cat(mosaic_targets, dim=0)
        else:
            mosaic_targets = torch.zeros((0,5))
        .
        # Resize final mosaic back to standard input size
        #The 2×size mosaic is resized back to img_size × img_size.,Converted to tensor and normalized.
        mosaic_img = Image.fromarray(mosaic_img).resize((self.img_size,self.img_size), Image.BILINEAR)
        mosaic_img = T.ToTensor()(mosaic_img)
        mosaic_img = T.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])(mosaic_img)

        return mosaic_img, mosaic_targets

#50% chance: apply mosaic augmentation.50% chance: load a single image normally with transforms.
#If you used mosaic for every single sample, the model might not see enough "normal-looking" images.
#So we balance: sometimes show the model a mosaic, sometimes a regular image.
    def __getitem__(self, idx):
        if self.mosaic and random.random() < 0.5:
            img, target = self.mosaic_augment()
        else:
            img, target = self.load_image_target(idx)
            img = self.transforms(img)
        return img, target

#img: tensor of shape (3, img_size, img_size) (normalized).
#target: [num_objects, 5] (x1, y1, x2, y2, class).


In [39]:
import torch
import torch.nn as nn


# CIoU loss

def ciou_loss(pred_boxes, gt_boxes, eps=1e-7):
    """
    pred_boxes, gt_boxes: [N,4] format: x1,y1,x2,y2
    """
    #Checks if there are no predicted boxes (empty tensor).If yes, return 0 as the loss.
    if pred_boxes.numel() == 0:
        return torch.tensor(0.0, device=pred_boxes.device)

    # Intersection
    #Computes the coordinates of the intersection box between prediction and GT.
    #x1, y1 → top-left of intersection (max of both boxes’ top-left corners).
    #x2, y2 → bottom-right of intersection (min of both boxes’ bottom-right corners).
    x1 = torch.max(pred_boxes[:,0], gt_boxes[:,0])
    y1 = torch.max(pred_boxes[:,1], gt_boxes[:,1])
    x2 = torch.min(pred_boxes[:,2], gt_boxes[:,2])
    y2 = torch.min(pred_boxes[:,3], gt_boxes[:,3])

    #Computes intersection area.
    #(x2 - x1) → width of intersection; (y2 - y1) → height.
    #.clamp(0) ensures width/height cannot be negative (no intersection → 0 area).

    inter = (x2 - x1).clamp(0) * (y2 - y1).clamp(0)
    area_pred = (pred_boxes[:,2]-pred_boxes[:,0])*(pred_boxes[:,3]-pred_boxes[:,1])
    area_gt   = (gt_boxes[:,2]-gt_boxes[:,0])*(gt_boxes[:,3]-gt_boxes[:,1])

    union = area_pred + area_gt - inter + eps
    iou = inter / union

    loss = 1 - iou
    return loss.mean() #Returns the average CIoU loss over all boxes in the batch.

# SimOTA Assigner

class SimOTAAssigner:
    def __init__(self, num_classes=80, top_k=10):
        self.num_classes = num_classes
        self.top_k = top_k

    def assign(self, pred_boxes, pred_cls, gt_boxes, gt_classes):
        """
        Returns:
        unique_pos_idx: indices of positive predictions
        final_assigned_gt_idx: corresponding GT index for each positive
        """
        #If there are no predictions or no GT boxes, return empty tensors.
        num_preds = pred_boxes.shape[0]
        num_gt = gt_boxes.shape[0]

        if num_gt == 0 or num_preds == 0:
            return torch.tensor([], dtype=torch.long), torch.tensor([], dtype=torch.long)

        # Step 1: IoU matrix [num_preds, num_gt]
        ious = self.bbox_iou(pred_boxes, gt_boxes)

        # Step 2: top-k candidates per GT
        pos_idx_list = []
        assigned_gt_idx_list = []

        #For each GT box:
          #Take IoU of all predictions with this GT (iou_per_gt).
          #Pick top-k predictions that best match this GT.
          #Save their indices in pos_idx_list.
          #Assign GT index to these predictions in assigned_gt_idx_list

        for gt_idx in range(num_gt):
            iou_per_gt = ious[:, gt_idx]
            k = min(self.top_k, num_preds)
            topk_val, topk_idx = torch.topk(iou_per_gt, k)
            pos_idx_list.append(topk_idx)
            assigned_gt_idx_list.append(torch.full_like(topk_idx, gt_idx))

       #Concatenate all top-k candidates across GTs.
       #Now pos_idx = indices of candidate positive predictions.
       #assigned_gt_idx = corresponding GT index for each candidate.
        pos_idx = torch.cat(pos_idx_list)
        assigned_gt_idx = torch.cat(assigned_gt_idx_list)

        # Step 3: ensure unique predictions → assign GT with max IoU
        unique_pos_idx = torch.unique(pos_idx)
        final_assigned_gt_idx = torch.zeros_like(unique_pos_idx)

        for i, p in enumerate(unique_pos_idx):
            mask = (pos_idx == p)
            # select GT with maximum IoU
            gt_candidates = assigned_gt_idx[mask]
            iou_candidates = ious[p, gt_candidates]
            final_assigned_gt_idx[i] = gt_candidates[iou_candidates.argmax()]

        return unique_pos_idx, final_assigned_gt_idx
#unique_pos_idx → final positive prediction indices.
#final_assigned_gt_idx → which GT each positive prediction is assigned to.

    @staticmethod
    def bbox_iou(box1, box2):
        N = box1.shape[0]
        M = box2.shape[0]

        lt = torch.max(box1[:,None,:2], box2[:,:2])
        rb = torch.min(box1[:,None,2:], box2[:,2:])

        wh = (rb - lt).clamp(0)
        inter = wh[:,:,0] * wh[:,:,1]

        area1 = ((box1[:,2]-box1[:,0])*(box1[:,3]-box1[:,1]))[:,None]
        area2 = ((box2[:,2]-box2[:,0])*(box2[:,3]-box2[:,1]))[None,:]

        iou = inter / (area1 + area2 - inter + 1e-7)
        return iou


# YOLOX Loss

class YOLOXLoss(nn.Module):
    def __init__(self, num_classes=80):
        super().__init__()
        self.num_classes = num_classes
        self.assigner = SimOTAAssigner(num_classes)

    def forward(self, preds, targets):
        """
        preds: list of feature maps at 3 scales [B,4+1+C,H,W]
        targets: list of tensors [num_objects,5] per image
        """
        device = preds[0].device
        loss_cls, loss_obj, loss_reg = 0.0, 0.0, 0.0
        B = len(targets)

        #Each pred is a feature map.permute → move channel dimension last to [B,H,W,C].reshape → flatten spatial dimensions → [B, H*W, C]
        #Now every predicted box is a row in pred.
        for pred in preds:
            B, C, H, W = pred.shape
            pred = pred.permute(0,2,3,1).reshape(B, -1, C)

            pred_boxes = pred[...,:4]
            pred_obj   = pred[...,4:5]
            pred_cls   = pred[...,5:]

            #For each image in the batch:
           #gt → GT boxes for this image.
           #Skip if no GT boxes (empty image).
            for b in range(B):
                gt = targets[b]
                if gt.shape[0] == 0:
                    continue
            #Separate GT boxes and classes
                gt_boxes = gt[:,:4]
                gt_classes = gt[:,4].long()

                # SimOTA assignment
                pos_idx, assigned_gt_idx = self.assigner.assign(pred_boxes[b], pred_cls[b], gt_boxes, gt_classes)
                if pos_idx.numel() == 0:
                    continue

                # Positive predictions and corresponding GTs
                pos_pred_boxes = pred_boxes[b][pos_idx]
                assigned_gt_boxes = gt_boxes[assigned_gt_idx]
                pos_pred_obj = pred_obj[b][pos_idx]
                pos_pred_cls = pred_cls[b][pos_idx]
                assigned_gt_classes = gt_classes[assigned_gt_idx]

                # Regression loss (CIoU)
                loss_reg += ciou_loss(pos_pred_boxes, assigned_gt_boxes)

                # Objectness loss
                target_obj = torch.ones_like(pos_pred_obj, device=device)
                loss_obj += nn.BCEWithLogitsLoss()(pos_pred_obj, target_obj)

                # Classification loss (multi-class BCE)
                target_cls = torch.zeros_like(pos_pred_cls, device=device)
                target_cls[range(len(assigned_gt_classes)), assigned_gt_classes] = 1.0
                loss_cls += nn.BCEWithLogitsLoss()(pos_pred_cls, target_cls)

        total_loss = loss_cls + loss_obj + loss_reg
        return total_loss, loss_cls, loss_obj, loss_reg


In [None]:
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

# Make sure these are imported or defined in your project
# from dataset import COCODatasetYOLOX
# from model import YOLOX
# from loss import YOLOXLoss

def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # ----------------------------
    # Dataset and DataLoader
    # ----------------------------
    train_dataset = COCODataset(
        img_dir='/kaggle/input/2017-2017/train2017/train2017',
        ann_file='/kaggle/input/2017-2017/annotations_trainval2017/annotations/instances_train2017.json',
        img_size=640,
        mosaic=True
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=2,
        shuffle=True,
        num_workers=4,
        collate_fn=lambda x: tuple(zip(*x))
    )

    # ----------------------------
    # Model, Loss, Optimizer
    # ----------------------------
    model = YOLOX(num_classes=train_dataset.num_classes).to(device)
    criterion = YOLOXLoss(num_classes=train_dataset.num_classes)
    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=0.01,
        momentum=0.9,
        weight_decay=5e-4
    )

    num_epochs = 20

    # ----------------------------
    # Training Loop
    # ----------------------------
    for epoch in range(num_epochs):
        model.train()
        loop = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")
        for imgs, targets in loop:
            imgs = torch.stack(imgs).to(device)        # [B,3,H,W]
            targets = [t.to(device) for t in targets]  # list of [N,5]

            optimizer.zero_grad()
            preds = model(imgs)  # list of feature maps per scale
            loss, loss_cls, loss_obj, loss_reg = criterion(preds, targets)
            loss.backward()
            optimizer.step()

            loop.set_postfix({
                'loss': loss.item(),
                'cls': loss_cls.item(),
                'obj': loss_obj.item(),
                'reg': loss_reg.item()
            })

        # Optional: Save checkpoint every epoch
        torch.save({
            'epoch': epoch+1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict()
        }, f'yolox_epoch{epoch+1}.pt')

if __name__ == "__main__":
    main()


Epoch 1/20:   0%|          | 11/59144 [00:08<9:51:43,  1.67it/s, loss=6.54, cls=0.535, obj=9.31e-12, reg=6] 