let's understand DETR's architecture by implementing a simplified version. Let's break it down into key components.



Let me explain the key components and architecture of DETR:

1. Overall Architecture Flow:
   - Input image → CNN Backbone → Transformer Encoder → Transformer Decoder → FFNs → Predictions

2. Key Components:

a) CNN Backbone:
```python
self.backbone = torchvision.models.resnet50(pretrained=True)
```
- Uses ResNet50 to extract image features
- Converts image to high-level feature map

b) Position Embeddings:
```python
self.pos_embedding = PositionalEncoding(hidden_dim)
```
- Adds positional information to features
- Helps transformer understand spatial relationships

c) Transformer:
```python
self.transformer = nn.Transformer(
    d_model=hidden_dim,
    nhead=nheads,
    num_encoder_layers=num_encoder_layers,
    num_decoder_layers=num_decoder_layers
)
```
- Encoder: Processes image features globally
- Decoder: Converts object queries into object predictions

d) Object Queries:
```python
self.query_embed = nn.Embedding(num_queries, hidden_dim)
```
- Learnable embeddings that specify what objects to look for
- Fixed number (typically 100) of queries

e) Prediction Heads:
```python
self.class_embed = nn.Linear(hidden_dim, num_classes + 1)
self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)
```
- Class prediction head
- Box coordinate prediction head

3. Key Innovations:

a) No Anchor Boxes:
- Uses direct set prediction instead
- Learnable object queries replace anchor boxes

b) No NMS:
- Built-in set prediction eliminates need for post-processing
- Each query predicts a unique object

c) Global Processing:
- Transformer allows global context
- All objects processed in parallel

Would you like me to:
1. Explain the loss function and bipartite matching?
2. Show how to train this model?
3. Discuss how to use it for inference?
4. Dive deeper into any specific component?

In [2]:
import math
import torch
import torch.nn as nn
import torchvision

class CustomTransformerEncoder(nn.Module):
    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, num_layers=6):
        super().__init__()
        encoder_layer = CustomTransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout)
        self.layers = nn.ModuleList([encoder_layer for _ in range(num_layers)])
        self.norm = nn.LayerNorm(d_model)

    def forward(self, src, pos=None):
        output = src
        for layer in self.layers:
            output = layer(output, pos=pos)
        return self.norm(output)

class CustomTransformerDecoder(nn.Module):
    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, num_layers=6):
        super().__init__()
        decoder_layer = CustomTransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout)
        self.layers = nn.ModuleList([decoder_layer for _ in range(num_layers)])
        self.norm = nn.LayerNorm(d_model)

    def forward(self, tgt, memory, pos=None, query_pos=None):
        output = tgt
        for layer in self.layers:
            output = layer(output, memory, pos=pos, query_pos=query_pos)
        return self.norm(output)

class CustomTransformerEncoderLayer(nn.Module):
    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.activation = nn.ReLU()

    def with_pos_embed(self, tensor, pos):
        return tensor if pos is None else tensor + pos

    def forward(self, src, pos=None):
        q = k = self.with_pos_embed(src, pos)
        src2 = self.self_attn(q, k, value=src)[0]
        src = src + self.dropout1(src2)
        src = self.norm1(src)
        
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)
        return src

class CustomTransformerDecoderLayer(nn.Module):
    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)
        self.activation = nn.ReLU()

    def with_pos_embed(self, tensor, pos):
        return tensor if pos is None else tensor + pos

    def forward(self, tgt, memory, pos=None, query_pos=None):
        q = k = self.with_pos_embed(tgt, query_pos)
        tgt2 = self.self_attn(q, k, value=tgt)[0]
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)
        
        tgt2 = self.multihead_attn(
            query=self.with_pos_embed(tgt, query_pos),
            key=self.with_pos_embed(memory, pos),
            value=memory)[0]
        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)
        
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout3(tgt2)
        tgt = self.norm3(tgt)
        return tgt

class SimplifiedDETR(nn.Module):
    def __init__(self, num_classes, hidden_dim=256, nheads=8, num_encoder_layers=6, num_decoder_layers=6, num_queries=100):
        super().__init__()
        
        # 1. CNN Backbone (using ResNet50 by default)
        self.backbone = torchvision.models.resnet50(pretrained=True)
        del self.backbone.fc  # Remove the classification head
        
        # 2. Position Embeddings
        self.pos_embedding = PositionalEncoding(hidden_dim)
        
        # 3. Input Projection: Convert backbone features to transformer dimensions
        self.input_proj = nn.Conv2d(2048, hidden_dim, kernel_size=1)
        
        # 4. Custom Transformer
        self.transformer_encoder = CustomTransformerEncoder(
            hidden_dim, nheads, dim_feedforward=2048, 
            dropout=0.1, num_layers=num_encoder_layers
        )
        self.transformer_decoder = CustomTransformerDecoder(
            hidden_dim, nheads, dim_feedforward=2048, 
            dropout=0.1, num_layers=num_decoder_layers
        )
        
        # 5. Object Queries (learnable parameters)
        self.query_embed = nn.Embedding(num_queries, hidden_dim)
        
        # 6. Output FFNs (Feed-Forward Networks)
        self.class_embed = nn.Linear(hidden_dim, num_classes + 1)  # +1 for no-object class
        self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)  # 4 for box coordinates
        
        self.hidden_dim = hidden_dim
        self.num_queries = num_queries

    def forward(self, x):
        # 1. Extract features using CNN backbone
        features = self.backbone.conv1(x)
        features = self.backbone.bn1(features)
        features = self.backbone.relu(features)
        features = self.backbone.maxpool(features)
        
        features = self.backbone.layer1(features)
        features = self.backbone.layer2(features)
        features = self.backbone.layer3(features)
        features = self.backbone.layer4(features)
        
        # 2. Project features to transformer dimension
        features_proj = self.input_proj(features)
        
        # 3. Flatten spatial dimensions and transpose for transformer
        batch_size, channels, height, width = features_proj.shape
        features_flat = features_proj.flatten(2).permute(2, 0, 1)
        
        # 4. Generate positional embeddings
        pos = self.pos_embedding(features_flat)
        
        # 5. Object queries
        query_embed = self.query_embed.weight.unsqueeze(1).repeat(1, batch_size, 1)
        tgt = torch.zeros_like(query_embed)
        
        # 6. Pass through transformer
        memory = self.transformer_encoder(features_flat, pos=pos)
        hs = self.transformer_decoder(tgt, memory, pos=pos, query_pos=query_embed)
        
        # 7. Predict classes and boxes
        outputs_class = self.class_embed(hs)
        outputs_coord = self.bbox_embed(hs).sigmoid()
        
        return {'pred_logits': outputs_class[-1], 'pred_boxes': outputs_coord[-1]}

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        return x + self.pe[:x.size(0), :]


class MLP(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(
            nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
        )

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = nn.functional.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x

# Example usage:
def main():
    # Create model
    model = SimplifiedDETR(num_classes=91)  # COCO has 91 classes
    
    # Create dummy input
    x = torch.randn(2, 3, 800, 1200)  # batch_size=2, 3 channels, 800x1200 image
    
    # Forward pass
    outputs = model(x)
    
    # Print output shapes
    print("Prediction logits shape:", outputs['pred_logits'].shape)
    print("Prediction boxes shape:", outputs['pred_boxes'].shape)

if __name__ == "__main__":
    main()

Prediction logits shape: torch.Size([2, 92])
Prediction boxes shape: torch.Size([2, 4])
