# In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# -----------------------------------------------------------------------------
# 1. HELPER MODULES (GhostConv & CBAM)
# -----------------------------------------------------------------------------

class GhostConv(nn.Module):
    """
    GhostConv: More Features from Cheap Operations.

    Produces ``out_channels`` feature maps by running a regular ("primary") convolution
    for about half of them and deriving the rest with a cheap 3x3 depthwise convolution
    applied to the primary output, which reduces computational cost.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels (odd values are supported).
        kernel_size, stride, padding: Forwarded to the primary convolution.
        relu: If True, apply an in-place ReLU to the output; otherwise identity.
    """
    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, relu=True):
        super(GhostConv, self).__init__()
        self.out_channels = out_channels
        # Round up so primary + cheap maps cover out_channels even when it is odd;
        # any surplus channel is trimmed in forward().
        init_channels = (out_channels + 1) // 2
        self.primary_conv = nn.Conv2d(in_channels, init_channels, kernel_size, stride, padding, bias=False)
        # Depthwise (groups == channels) 3x3 conv: the "cheap operation".
        self.cheap_operation = nn.Conv2d(init_channels, init_channels, kernel_size=3, stride=1, padding=1, groups=init_channels, bias=False)
        self.relu = nn.ReLU(inplace=True) if relu else nn.Identity()

    def forward(self, x):
        primary_out = self.primary_conv(x)
        cheap_out = self.cheap_operation(primary_out)
        out = torch.cat([primary_out, cheap_out], dim=1)
        # Keep exactly out_channels maps (drops the extra one when out_channels is odd).
        return self.relu(out[:, :self.out_channels, :, :])

class ChannelAttention(nn.Module):
    """CBAM channel-attention gate.

    Squeezes each channel with global average- and max-pooling, passes both descriptors
    through a shared two-layer 1x1-conv bottleneck, and returns ``sigmoid(avg + max)``.
    For a (B, C, H, W) input the result has shape (B, C, 1, 1), ready to broadcast.

    Args:
        in_planes: Number of input channels C.
        ratio: Bottleneck reduction factor; the hidden width is ``max(1, C // ratio)``.
    """
    def __init__(self, in_planes, ratio=16):
        super(ChannelAttention, self).__init__()
        # Clamp to >= 1: when in_planes < ratio, integer division gives 0 hidden channels,
        # and a zero-width bottleneck cannot carry any channel information.
        hidden = max(1, in_planes // ratio)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.fc = nn.Sequential(
            nn.Conv2d(in_planes, hidden, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(hidden, in_planes, 1, bias=False)
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # The same MLP (self.fc) is shared by both pooled descriptors.
        avg_out = self.fc(self.avg_pool(x))
        max_out = self.fc(self.max_pool(x))
        out = avg_out + max_out
        return self.sigmoid(out)

class SpatialAttention(nn.Module):
    """CBAM spatial-attention gate.

    Pools across channels (mean and max), convolves the 2-channel result to a single
    map, and returns its sigmoid. The map has the input's spatial size, so it can be
    multiplied back onto the features.

    Args:
        kernel_size: Odd convolution kernel size ("same" padding of ``kernel_size // 2``).

    Raises:
        ValueError: If ``kernel_size`` is even (no symmetric "same" padding exists).
    """
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        if kernel_size % 2 != 1:
            raise ValueError(f"SpatialAttention kernel_size must be odd, got {kernel_size}")
        # "Same" padding for any odd kernel. The previous 3-if-7-else-1 rule only worked
        # for kernel sizes 3 and 7 and changed the output size otherwise.
        padding = kernel_size // 2
        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x = torch.cat([avg_out, max_out], dim=1)
        x = self.conv1(x)
        return self.sigmoid(x)

class CBAM(nn.Module):
    """Convolutional Block Attention Module.

    Refines a feature map with two multiplicative sigmoid gates applied in sequence:
    a per-channel gate first, then a per-location gate computed on the
    channel-refined features.

    Args:
        in_planes: Channel count of the features being refined.
        ratio: Reduction ratio for the channel-attention bottleneck.
        kernel_size: Kernel size of the spatial-attention convolution.
    """
    def __init__(self, in_planes, ratio=16, kernel_size=7):
        super(CBAM, self).__init__()
        self.channel_attention = ChannelAttention(in_planes, ratio)
        self.spatial_attention = SpatialAttention(kernel_size)

    def forward(self, x):
        # Channel gate: one weight per channel, broadcast over the spatial dims.
        refined = self.channel_attention(x) * x
        # Spatial gate: one weight per location, derived from the channel-refined map.
        return self.spatial_attention(refined) * refined

# -----------------------------------------------------------------------------
# 2. CORE ARCHITECTURAL BLOCKS
# -----------------------------------------------------------------------------

class MobileViTBlock(nn.Module):
    """ A simplified MobileViT block: local convolution plus transformer-based global mixing.

    Global path: a 1x1 conv projects to ``transformer_dim``. The map is unfolded into
    non-overlapping ``patch_size`` patches. For each pixel offset inside a patch, a
    transformer attends across all patches. The result is folded back and projected to
    ``in_channels``. Sequence length is therefore (H // ph) * (W // pw) rather than H * W.
    If H or W is not a multiple of the patch size, the projected map is bilinearly resized
    up to the next multiple and resized back afterwards.

    Args:
        in_channels: Channels of the input and output feature map.
        transformer_dim: Embedding width inside the transformer.
        ffn_dim: Hidden width of each transformer feed-forward layer.
        n_transformer_blocks: Number of stacked transformer encoder layers.
        patch_size: (ph, pw) patch height and width.
        num_heads: Attention heads per layer; must divide ``transformer_dim``.
    """
    def __init__(self, in_channels, transformer_dim, ffn_dim, n_transformer_blocks, patch_size=(2, 2), num_heads=4):
        super(MobileViTBlock, self).__init__()
        # Local representation
        self.local_conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
        # Global representation (unfolding -> transformer -> folding)
        self.global_conv1 = nn.Conv2d(in_channels, transformer_dim, kernel_size=1)
        encoder_layer = nn.TransformerEncoderLayer(d_model=transformer_dim, nhead=num_heads, dim_feedforward=ffn_dim, batch_first=True)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=n_transformer_blocks)
        self.global_conv2 = nn.Conv2d(transformer_dim, in_channels, kernel_size=1)
        self.patch_size = tuple(patch_size)

    def _unfold(self, y):
        """Rearrange (B, C, H, W) into (B * ph * pw, N, C) sequences, N = (H // ph) * (W // pw).

        Sequence ``b * ph * pw + i * pw + j`` holds, for every patch in row-major order,
        the pixel at offset (i, j) inside that patch. H and W must be multiples of the
        patch size.
        """
        ph, pw = self.patch_size
        batch, channels, height, width = y.shape
        nh, nw = height // ph, width // pw
        # (B, C, nh, ph, nw, pw) -> (B, ph, pw, nh, nw, C)
        y = y.reshape(batch, channels, nh, ph, nw, pw).permute(0, 3, 5, 2, 4, 1)
        return y.reshape(batch * ph * pw, nh * nw, channels)

    def _fold(self, y, batch, height, width):
        """Exact inverse of ``_unfold``: (B * ph * pw, N, C) -> (B, C, height, width)."""
        ph, pw = self.patch_size
        nh, nw = height // ph, width // pw
        channels = y.shape[-1]
        # (B, ph, pw, nh, nw, C) -> (B, C, nh, ph, nw, pw)
        y = y.reshape(batch, ph, pw, nh, nw, channels).permute(0, 5, 3, 1, 4, 2)
        return y.reshape(batch, channels, height, width)

    def forward(self, x):
        local_features = self.local_conv(x)

        # Global path
        y = self.global_conv1(x)
        batch, _, height, width = y.shape
        ph, pw = self.patch_size
        # Patches must tile the map exactly; otherwise resize up to the next multiple.
        tiled_h = ((height + ph - 1) // ph) * ph
        tiled_w = ((width + pw - 1) // pw) * pw
        resized = (tiled_h, tiled_w) != (height, width)
        if resized:
            y = F.interpolate(y, size=(tiled_h, tiled_w), mode='bilinear', align_corners=False)

        y = self._unfold(y)       # (B * P, N, C)
        y = self.transformer(y)   # attention across the N patches
        y = self._fold(y, batch, tiled_h, tiled_w)

        if resized:
            y = F.interpolate(y, size=(height, width), mode='bilinear', align_corners=False)
        global_features = self.global_conv2(y)

        return local_features + global_features

class BiFPNLayer(nn.Module):
    """ Simplified BiFPN layer for efficient multi-scale feature fusion.

    Fuses two feature maps with the same channel count: a higher-resolution ``P_high``
    and a lower-resolution ``P_low``. A top-down step upsamples ``P_low`` onto ``P_high``.
    A bottom-up step then pools the refined high-resolution map back onto ``P_low``.
    Each fusion is an element-wise sum followed by a 3x3 GhostConv.

    Args:
        channels: Channel count shared by both input maps.
    """
    def __init__(self, channels):
        super(BiFPNLayer, self).__init__()
        self.up_conv = GhostConv(channels, channels, kernel_size=3, padding=1)
        self.down_conv = GhostConv(channels, channels, kernel_size=3, padding=1)

    def forward(self, inputs):
        """Return ``[P_high_up, P_low_down]`` with the same sizes as ``inputs``."""
        # A true BiFPN has multiple levels, this is a simplified one-level example
        P_high, P_low = inputs # Expects two feature maps of different scales

        # Top-down pathway
        P_high_up = self.up_conv(P_high + F.interpolate(P_low, size=P_high.shape[-2:], mode='nearest'))

        # Bottom-up pathway. Adaptive max-pooling to P_low's exact size equals a 2x2/stride-2
        # max-pool when the sizes are exactly 2:1. It also handles odd high-res sizes, where a
        # stride-2 padded conv yields ceil(H/2) but a fixed 2x2 pool would yield floor(H/2).
        pooled = F.adaptive_max_pool2d(P_high_up, output_size=tuple(P_low.shape[-2:]))
        P_low_down = self.down_conv(P_low + pooled)

        return [P_high_up, P_low_down]

class DecoupledHead(nn.Module):
    """Decoupled detection head with independent classification, regression and objectness branches.

    Each branch starts with a 1x1 GhostConv over the shared input features and ends in a
    1x1 Conv2d producing raw (un-activated) per-location outputs: ``num_classes`` class
    scores, 4 box values, and 1 objectness score.

    Args:
        in_channels: Channel count of the input feature map.
        num_classes: Number of class scores predicted per location.
    """
    def __init__(self, in_channels, num_classes):
        super(DecoupledHead, self).__init__()
        width = in_channels

        # Classification branch, refined by CBAM attention before the class projection.
        cls_layers = [GhostConv(width, width, relu=True), CBAM(width), nn.Conv2d(width, num_classes, kernel_size=1)]
        self.cls_conv = nn.Sequential(*cls_layers)

        # Regression branch: 4 values per location for the box (x, y, w, h).
        reg_layers = [GhostConv(width, width, relu=True), CBAM(width), nn.Conv2d(width, 4, kernel_size=1)]
        self.reg_conv = nn.Sequential(*reg_layers)

        # Objectness branch: no attention module, a single score per location.
        self.obj_conv = nn.Sequential(GhostConv(width, width, relu=True), nn.Conv2d(width, 1, kernel_size=1))

    def forward(self, x):
        """Return ``(cls_out, reg_out, obj_out)`` computed from the same input ``x``."""
        branches = (self.cls_conv, self.reg_conv, self.obj_conv)
        return tuple(branch(x) for branch in branches)

# -----------------------------------------------------------------------------
# 3. THE COMPLETE CUSTOM MODEL
# -----------------------------------------------------------------------------

class CustomTrafficDetector(nn.Module):
    """Single-scale traffic detector: conv + MobileViT backbone, BiFPN-style neck, decoupled head.

    Args:
        num_classes: Class scores predicted per location (e.g. car, bus, truck, ambulance).

    ``forward(x)`` returns ``(cls_out, reg_out, obj_out)`` computed on the fused map at the
    resolution of backbone stage 1 (stride 2 relative to ``x``).
    """
    def __init__(self, num_classes=4): # e.g., car, bus, truck, ambulance
        super(CustomTrafficDetector, self).__init__()
        # Backbone stage 1: 3 -> 16 channels, stride 2, then a MobileViT block.
        stage1_layers = [
            nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            MobileViTBlock(16, 32, 64, 2),
        ]
        self.backbone_stage1 = nn.Sequential(*stage1_layers)

        # Backbone stage 2: 16 -> 32 channels, another stride-2 reduction.
        stage2_layers = [
            nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            MobileViTBlock(32, 64, 128, 2),
        ]
        self.backbone_stage2 = nn.Sequential(*stage2_layers)

        # Neck fusing the two stages, then the prediction head.
        self.neck = BiFPNLayer(channels=32)
        self.head = DecoupledHead(in_channels=32, num_classes=num_classes)

    def forward(self, x):
        high_res = self.backbone_stage1(x)
        low_res = self.backbone_stage2(high_res)

        fused = self.neck([high_res, low_res])
        # Only the higher-resolution fused map is sent to the head.
        # NOTE(review): the low-resolution neck output is discarded, so the neck's down_conv
        # weights receive no gradient; check DistributedDataParallel's unused-parameter
        # handling if training with DDP.
        return self.head(fused[0])

# -----------------------------------------------------------------------------
# 4. LOSS FUNCTIONS
# -----------------------------------------------------------------------------

def bbox_ciou(box1, box2, eps=1e-7):
    """Complete-IoU (CIoU) between boxes in (center_x, center_y, w, h) format.

    Args:
        box1, box2: Tensors whose last dimension is (cx, cy, w, h); broadcastable together.
        eps: Small constant guarding the divisions.

    Returns:
        Tensor over the broadcast leading dims: IoU minus the normalized squared
        center distance and minus the weighted aspect-ratio penalty. The penalty weight
        (alpha) is computed without gradient.
    """
    def to_corners(b):
        half_w, half_h = b[..., 2] / 2, b[..., 3] / 2
        return b[..., 0] - half_w, b[..., 1] - half_h, b[..., 0] + half_w, b[..., 1] + half_h

    ax1, ay1, ax2, ay2 = to_corners(box1)
    bx1, by1, bx2, by2 = to_corners(box2)

    # Overlap region (clamped at zero for disjoint boxes).
    overlap_w = (torch.min(ax2, bx2) - torch.max(ax1, bx1)).clamp(0)
    overlap_h = (torch.min(ay2, by2) - torch.max(ay1, by1)).clamp(0)
    inter = overlap_w * overlap_h

    wa, ha = ax2 - ax1, ay2 - ay1
    wb, hb = bx2 - bx1, by2 - by1
    iou = inter / (wa * ha + wb * hb - inter + eps)

    # Squared diagonal of the smallest box enclosing both.
    enclose_w = torch.max(ax2, bx2) - torch.min(ax1, bx1)
    enclose_h = torch.max(ay2, by2) - torch.min(ay1, by1)
    diag_sq = enclose_w ** 2 + enclose_h ** 2 + eps

    # Squared distance between the two centers.
    center_dist_sq = ((bx1 + bx2 - ax1 - ax2) ** 2 + (by1 + by2 - ay1 - ay2) ** 2) / 4

    # Aspect-ratio consistency term and its trade-off weight.
    v = (4 / (torch.pi ** 2)) * torch.pow(torch.atan(wb / (hb + eps)) - torch.atan(wa / (ha + eps)), 2)
    with torch.no_grad():
        alpha = v / (1 - iou + v + eps)

    return iou - (center_dist_sq / diag_sq + v * alpha)

class CIoULoss(nn.Module):
    """Box-regression loss: the mean of ``1 - CIoU`` over all (pred, target) box pairs.

    Both inputs use the (center_x, center_y, w, h) layout expected by ``bbox_ciou``.
    """
    def forward(self, pred, target):
        ciou = bbox_ciou(pred, target)
        return torch.mean(1.0 - ciou)

class FocalLoss(nn.Module):
    """Sigmoid (binary) focal loss (Lin et al., 2017), averaged over all elements.

    Args:
        alpha: Weight for positive targets in [0, 1]; negatives are weighted ``1 - alpha``.
            Pass ``None`` or a negative value to disable this class balancing.
        gamma: Focusing exponent that down-weights well-classified elements.
    """
    def __init__(self, alpha=0.25, gamma=2.0):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, inputs, targets):
        """Compute the loss from raw logits ``inputs`` and same-shaped ``targets`` in [0, 1]."""
        bce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        # For hard 0/1 targets, exp(-BCE) is the probability assigned to the true class.
        pt = torch.exp(-bce_loss)
        focal_loss = (1 - pt) ** self.gamma * bce_loss
        if self.alpha is not None and self.alpha >= 0:
            # alpha_t balances positives (alpha) against negatives (1 - alpha). A single
            # scalar applied to every element would only rescale the loss.
            alpha_t = self.alpha * targets + (1 - self.alpha) * (1 - targets)
            focal_loss = alpha_t * focal_loss
        return focal_loss.mean()

class CompositeLoss(nn.Module):
    """Weighted sum of box (CIoU), classification (focal) and objectness (BCE) losses.

    Args:
        lambda_reg, lambda_cls, lambda_obj: Weights of the three loss terms.

    ``forward(preds, targets)`` takes ``preds = (cls_pred, reg_pred, obj_pred)``.
    ``targets`` is a dict with keys 'mask' (bool, True at positive locations), 'bbox',
    'cls' and 'obj'. Box and class tensors may be channel-first dense, channel-last
    dense, or already gathered to the positives; see ``_gather_positives``.
    """
    def __init__(self, lambda_reg=1.0, lambda_cls=1.0, lambda_obj=1.0):
        super(CompositeLoss, self).__init__()
        self.lambda_reg = lambda_reg
        self.lambda_cls = lambda_cls
        self.lambda_obj = lambda_obj
        self.reg_loss = CIoULoss()
        self.cls_loss = FocalLoss()
        self.obj_loss = nn.BCEWithLogitsLoss()

    @staticmethod
    def _gather_positives(tensor, mask):
        """Return the rows of ``tensor`` at the True positions of ``mask``, one row per positive.

        Accepted layouts (in mask order):
          * channel-last dense:  ``tensor.shape[:mask.dim()] == mask.shape`` -> ``tensor[mask]``
          * channel-first dense: (B, C, *spatial) with mask (B, *spatial), as emitted by
            Conv2d heads -> channels moved last, then masked
          * already gathered:    first dim equals the number of True entries -> returned as is
        The channel-last check runs first, so existing channel-last callers keep their behavior.

        Raises:
            ValueError: If ``tensor`` matches none of the layouts.
        """
        if tensor.shape[:mask.dim()] == mask.shape:
            return tensor[mask]
        if (tensor.dim() == mask.dim() + 1
                and tensor.shape[0] == mask.shape[0]
                and tensor.shape[2:] == mask.shape[1:]):
            return tensor.movedim(1, -1)[mask]
        num_pos = int(mask.sum())
        if tensor.dim() >= 1 and tensor.shape[0] == num_pos:
            return tensor
        raise ValueError(
            f"Cannot align tensor of shape {tuple(tensor.shape)} with mask of shape "
            f"{tuple(mask.shape)} ({num_pos} positives)"
        )

    def forward(self, preds, targets):
        cls_pred, reg_pred, obj_pred = preds
        mask = targets['mask']
        # Note: prediction-to-target matching (e.g. anchor matching) is a complex part of
        # object detection and is out of scope here; targets are assumed already matched.
        # NOTE(review): reg_pred is consumed as raw (cx, cy, w, h); nothing here keeps w/h
        # positive, so confirm predictions are decoded before this loss.
        if mask.any():
            loss_reg = self.reg_loss(self._gather_positives(reg_pred, mask),
                                     self._gather_positives(targets['bbox'], mask))
            loss_cls = self.cls_loss(self._gather_positives(cls_pred, mask),
                                     self._gather_positives(targets['cls'], mask))
        else:
            # No positives: a mean over an empty selection is NaN and would poison the total.
            # Use zeros that stay attached to the graph so backward() still reaches the heads.
            loss_reg = reg_pred.sum() * 0.0
            loss_cls = cls_pred.sum() * 0.0
        loss_obj = self.obj_loss(obj_pred, targets['obj'])

        total_loss = self.lambda_reg * loss_reg + self.lambda_cls * loss_cls + self.lambda_obj * loss_obj
        return total_loss

# -----------------------------------------------------------------------------
# 5. USAGE EXAMPLE
# -----------------------------------------------------------------------------

# if __name__ == '__main__':
#     # Ensure you have PyTorch installed: pip install torch torchvision
    
#     # --- Model Initialization ---
#     num_classes = 4  # e.g., 0: car, 1: bus, 2: truck, 3: ambulance
#     model = CustomTrafficDetector(num_classes=num_classes)
#     print(f"Model created with {sum(p.numel() for p in model.parameters() if p.requires_grad)} trainable parameters.")
    
#     # --- Dummy Input ---
#     batch_size = 2
#     input_image = torch.randn(batch_size, 3, 256, 256) # B, C, H, W
    
#     # --- Forward Pass ---
#     cls_out, reg_out, obj_out = model(input_image)
#     print("Output shapes:")
#     print(f"  Classification: {cls_out.shape}")
#     print(f"  Regression:     {reg_out.shape}")
#     print(f"  Objectness:     {obj_out.shape}")

#     # --- Loss Calculation Example ---
#     loss_fn = CompositeLoss()
    
#     # Dummy targets (in a real scenario, these come from your dataloader)
#     # The 'mask' indicates which grid cells contain an object.
#     mask = torch.zeros(cls_out.shape[0], cls_out.shape[2], cls_out.shape[3], dtype=torch.bool)
#     mask[0, 10, 10] = True # Example: one object in the first image
    
#     targets = {
#         'bbox': torch.randn(mask.sum(), 4), # Dummy bbox values for the object
#         'cls': torch.zeros(mask.sum(), num_classes).scatter_(1, torch.randint(0, num_classes, (mask.sum(), 1)), 1.), # One-hot class
#         'obj': mask.float().unsqueeze(1), # Objectness target
#         'mask': mask
#     }

#     # You'll need to process the model output to match the target format
#     # This is a simplification.
#     # total_loss = loss_fn((cls_out, reg_out, obj_out), targets)
#     # print(f"\nExample Loss (requires proper target matching): {total_loss.item()}")

