In [1]:
import torch  # 导入PyTorch库
import torch.nn as nn  # 导入神经网络模块
import torch.nn.functional as F  # 导入函数式API
import math  # 导入数学函数
from typing import List, Tuple, Optional  # 导入类型提示


class DETR3D(nn.Module):
    """DETR3D: 3D object detection from multi-view images via 3D-to-2D queries.

    The model owns a backbone, a table of learnable object queries, a stack
    of decoder layers and three linear heads (reference point, box, class).
    The three heads are shared by every decoder layer.
    """

    def __init__(self,
                 num_classes: int = 10,
                 num_queries: int = 900,
                 num_layers: int = 6,
                 hidden_dim: int = 256,
                 num_heads: int = 8,
                 num_feature_levels: int = 4):
        """Build the sub-modules and initialise their weights.

        Args:
            num_classes: Number of object categories (a "no-object" logit is
                added on top of these).
            num_queries: Number of learnable object queries.
            num_layers: Number of stacked decoder layers.
            hidden_dim: Channel width shared by queries and image features.
            num_heads: Attention heads inside each decoder layer.
            num_feature_levels: Number of feature maps the backbone returns.
        """
        super().__init__()

        self.num_classes = num_classes
        self.num_queries = num_queries
        self.num_layers = num_layers
        self.hidden_dim = hidden_dim

        # NOTE: sub-modules are created in a fixed order; changing it would
        # change the weights obtained for a given random seed.

        # Image feature extractor (simplified stand-in for ResNet + FPN).
        self.backbone = SimplifiedBackbone(hidden_dim, num_feature_levels)

        # One learnable embedding per object query.
        self.query_embed = nn.Embedding(num_queries, hidden_dim)

        # Decoder stack.
        self.transformer_layers = nn.ModuleList(
            DETR3DLayer(hidden_dim, num_heads, num_feature_levels)
            for _ in range(num_layers)
        )

        # Prediction heads.
        # 3D reference point (x, y, z) per query.
        self.reference_points_head = nn.Linear(hidden_dim, 3)
        # Nine box parameters per query: (x, y, z, w, h, l, roll, pitch, yaw).
        self.bbox_head = nn.Linear(hidden_dim, 9)
        # Class logits; the extra slot is the "no-object" class.
        self.cls_head = nn.Linear(hidden_dim, num_classes + 1)

        self._reset_parameters()

    def _reset_parameters(self):
        """Xavier-uniform initialise every parameter with more than one dim."""
        for weight in (p for p in self.parameters() if p.dim() > 1):
            nn.init.xavier_uniform_(weight)

    def forward(self,
                images: torch.Tensor,
                camera_matrices: torch.Tensor,
                image_shapes: List[Tuple[int, int]]) -> dict:
        """Run the backbone and iteratively refine the object queries.

        Args:
            images: [B, N_cams, 3, H, W] multi-view images.
            camera_matrices: [B, N_cams, 3, 4] camera transformation matrices.
            image_shapes: List of (H, W), one entry per image.

        Returns:
            dict with two keys:
                'predictions': list with one dict per decoder layer holding
                    'pred_boxes' [B, N_queries, 9],
                    'pred_logits' [B, N_queries, num_classes + 1] and
                    'reference_points' [B, N_queries, 3].
                'final_queries': [B, N_queries, C] queries after the last layer.
        """
        batch_size, num_cams = images.shape[:2]

        # Multi-level image features: list of [B*N_cams, C, H_i, W_i].
        feature_pyramid = self.backbone(images)

        # Give every sample in the batch its own copy of the query table.
        queries = self.query_embed.weight[None].repeat(batch_size, 1, 1)  # [B, N_queries, C]

        per_layer_outputs = []
        for decoder_layer in self.transformer_layers:
            # Reference points are re-predicted from the current queries and
            # squashed into (0, 1) by the sigmoid.
            anchors = torch.sigmoid(self.reference_points_head(queries))  # [B, N_queries, 3]

            # Update the queries with image features sampled at the anchors.
            queries = decoder_layer(
                queries,
                feature_pyramid,
                anchors,
                camera_matrices,
                image_shapes,
                batch_size,
                num_cams,
            )

            # Per-layer predictions from the refined queries.
            boxes = self.bbox_head(queries)    # [B, N_queries, 9]
            logits = self.cls_head(queries)    # [B, N_queries, num_classes+1]
            per_layer_outputs.append({
                'pred_boxes': boxes,
                'pred_logits': logits,
                'reference_points': anchors,
            })

        return {
            'predictions': per_layer_outputs,
            'final_queries': queries,
        }


class DETR3DLayer(nn.Module):
    """One DETR3D decoder layer.

    Each call (1) projects the 3D reference points into every camera, samples
    the multi-level image features at the projected locations and adds the
    (linearly projected) result to the queries, (2) runs multi-head
    self-attention among the queries and (3) applies a feed-forward network.
    Both (2) and (3) use a residual connection followed by LayerNorm.
    """

    # Minimum depth for a point to count as being in front of a camera.
    _MIN_DEPTH = 1e-5

    def __init__(self, hidden_dim: int, num_heads: int, num_feature_levels: int):
        """
        Args:
            hidden_dim: Channel width of queries and image features.
            num_heads: Number of self-attention heads.
            num_feature_levels: Number of feature levels (stored only).
        """
        super().__init__()

        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.num_feature_levels = num_feature_levels

        # Self-attention among object queries (expects [N_queries, B, C]).
        self.self_attn = nn.MultiheadAttention(hidden_dim, num_heads, dropout=0.1)

        # Linear projection applied to the sampled image features.
        self.feature_proj = nn.Linear(hidden_dim, hidden_dim)

        # Post-attention and post-FFN normalisation.
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.norm2 = nn.LayerNorm(hidden_dim)

        # Position-wise feed-forward network (expand x4, then project back).
        self.ffn = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim * 4),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim * 4, hidden_dim),
            nn.Dropout(0.1)
        )

    def forward(self,
                queries: torch.Tensor,
                multi_level_features: List[torch.Tensor],
                reference_points: torch.Tensor,
                camera_matrices: torch.Tensor,
                image_shapes: List[Tuple[int, int]],
                batch_size: int,
                num_cams: int) -> torch.Tensor:
        """Refine the object queries with image features and self-attention.

        Args:
            queries: [B, N_queries, C]
            multi_level_features: List of [B*N_cams, C, H_i, W_i]
            reference_points: [B, N_queries, 3] - 3D reference points
            camera_matrices: [B, N_cams, 3, 4]
            image_shapes: List of (H, W)
            batch_size: int
            num_cams: int

        Returns:
            torch.Tensor: [B, N_queries, C] updated queries.
        """
        # 1. Sample image features at the projected reference points.
        sampled_features = self.sample_multi_view_features(
            reference_points,
            multi_level_features,
            camera_matrices,
            image_shapes,
            batch_size,
            num_cams
        )  # [B, N_queries, C]

        # 2. Inject the projected image evidence into the queries.
        queries = queries + self.feature_proj(sampled_features)

        # 3. Self-attention; nn.MultiheadAttention here is sequence-first.
        queries_t = queries.transpose(0, 1)  # [N_queries, B, C]
        attn_queries, _ = self.self_attn(queries_t, queries_t, queries_t)
        attn_queries = attn_queries.transpose(0, 1)  # [B, N_queries, C]

        # 4. Residual connection + normalisation.
        queries = self.norm1(queries + attn_queries)

        # 5. Feed-forward network with residual connection + normalisation.
        queries = self.norm2(queries + self.ffn(queries))

        return queries

    @staticmethod
    def _image_sizes_wh(image_shapes: List[Tuple[int, int]],
                        batch_size: int,
                        num_cams: int,
                        like: torch.Tensor) -> torch.Tensor:
        """Return image sizes as a [B, N_cams, 2] tensor ordered (W, H).

        Accepts one (H, W) per image (B*N_cams entries, batch-major), one per
        camera (N_cams entries, shared by the batch) or a single shared entry.
        """
        shapes = list(image_shapes) if image_shapes is not None else []
        num_images = batch_size * num_cams
        if len(shapes) == num_images:
            pass
        elif len(shapes) == num_cams:
            shapes = shapes * batch_size
        elif len(shapes) == 1:
            shapes = shapes * num_images
        else:
            raise ValueError(
                f"image_shapes must hold 1, {num_cams} or {num_images} (H, W) "
                f"entries, got {len(shapes)}"
            )
        sizes_hw = torch.tensor(shapes, dtype=like.dtype, device=like.device)
        # (H, W) -> (W, H) so the last axis lines up with (x, y) coordinates.
        return sizes_hw.view(batch_size, num_cams, 2).flip(-1)

    def sample_multi_view_features(self,
                                 reference_points: torch.Tensor,
                                 multi_level_features: List[torch.Tensor],
                                 camera_matrices: torch.Tensor,
                                 image_shapes: List[Tuple[int, int]],
                                 batch_size: int,
                                 num_cams: int) -> torch.Tensor:
        """Sample and aggregate image features at projected reference points.

        Every reference point is projected into every camera with the 3x4
        camera matrix. The resulting pixel coordinates are normalised by the
        *image* size (not the feature-map size) so that the same normalised
        location is valid on every feature level. A point is visible in a
        camera when it lies in front of it (depth > ``_MIN_DEPTH``) and its
        projection falls inside the image. Features are bilinearly sampled,
        averaged over levels, and then averaged over the cameras in which
        the point is visible. Points visible in no camera get zeros.

        Args:
            reference_points: [B, N_queries, 3] - 3D points
            multi_level_features: List of [B*N_cams, C, H_i, W_i]
            camera_matrices: [B, N_cams, 3, 4] - maps homogeneous 3D points
                to homogeneous image pixel coordinates
            image_shapes: List of (H, W) used to normalise pixel coordinates
            batch_size: int
            num_cams: int

        Returns:
            torch.Tensor: [B, N_queries, C] - Aggregated features

        Raises:
            ValueError: if ``image_shapes`` has an unsupported length.
        """
        num_queries = reference_points.shape[1]

        # Nothing to sample from: keep the "all zeros" contract.
        if num_cams == 0 or not multi_level_features:
            return reference_points.new_zeros(batch_size, num_queries, self.hidden_dim)

        image_wh = self._image_sizes_wh(
            image_shapes, batch_size, num_cams, reference_points
        )  # [B, N_cams, 2]

        # Homogeneous 3D points, projected into every camera at once.
        ones = reference_points.new_ones(batch_size, num_queries, 1)
        points_homo = torch.cat([reference_points, ones], dim=-1)  # [B, Q, 4]
        cams = camera_matrices[:batch_size, :num_cams]             # [B, N, 3, 4]
        projected = points_homo.unsqueeze(1) @ cams.transpose(-1, -2)  # [B, N, Q, 3]

        # Perspective divide. Depth is clamped so the division is always
        # well defined; points behind the camera are masked out below.
        depth = projected[..., 2]
        in_front = depth > self._MIN_DEPTH
        pixels = projected[..., :2] / depth.clamp(min=self._MIN_DEPTH).unsqueeze(-1)

        # Pixel coordinates -> [-1, 1] grid coordinates for grid_sample.
        grid = 2.0 * pixels / image_wh.unsqueeze(2) - 1.0          # [B, N, Q, 2]
        visible = in_front & (grid.abs() <= 1.0).all(dim=-1)       # [B, N, Q]

        # Park invisible points well outside the grid so that non-finite or
        # huge coordinates never reach grid_sample.
        safe_grid = torch.where(
            visible.unsqueeze(-1), grid, torch.full_like(grid, -2.0)
        ).reshape(batch_size * num_cams, 1, num_queries, 2)

        # Bilinear sampling on every level, then mean over levels.
        level_sum = None
        for features in multi_level_features:
            sampled = F.grid_sample(
                features[:batch_size * num_cams],
                safe_grid,
                mode='bilinear',
                padding_mode='zeros',
                align_corners=False
            )  # [B*N, C, 1, Q]
            level_sum = sampled if level_sum is None else level_sum + sampled
        level_mean = level_sum / len(multi_level_features)

        # [B*N, C, 1, Q] -> [B, N, Q, C]
        per_camera = level_mean.squeeze(2).permute(0, 2, 1).reshape(
            batch_size, num_cams, num_queries, -1
        )
        per_camera = per_camera.masked_fill(~visible.unsqueeze(-1), 0.0)

        # Average over the cameras that actually see each point.
        num_visible = visible.sum(dim=1).clamp(min=1).unsqueeze(-1)  # [B, Q, 1]
        return per_camera.sum(dim=1) / num_visible.to(per_camera.dtype)


class SimplifiedBackbone(nn.Module):
    """Simplified stand-in for a ResNet + FPN backbone.

    A chain of ``num_levels`` convolutional stages. Each stage is
    Conv-BN-ReLU-Conv-BN-ReLU and its output is both returned as one feature
    level and fed to the next stage.

    NOTE: stage ``i > 0`` uses stride ``2**i`` on the output of stage
    ``i - 1``, so the strides compound along the chain (1, 2, 4, 8 per stage
    for four levels, i.e. the last level is far smaller than the input).
    """

    def __init__(self, hidden_dim: int, num_levels: int = 4):
        """
        Args:
            hidden_dim: Number of output channels of every stage.
            num_levels: Number of stages / returned feature maps.
        """
        super().__init__()

        self.hidden_dim = hidden_dim
        self.num_levels = num_levels

        # Stage 0 consumes RGB at full resolution; later stages consume the
        # previous stage's output and downsample it.
        self.conv_layers = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(3 if i == 0 else hidden_dim, hidden_dim, 3,
                         stride=2**i if i > 0 else 1, padding=1),
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU(inplace=True),
                nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1),
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU(inplace=True)
            ) for i in range(num_levels)
        ])

    def forward(self, images: torch.Tensor) -> List[torch.Tensor]:
        """Extract one feature map per stage for every camera image.

        Args:
            images: [B, N_cams, 3, H, W]

        Returns:
            List of feature maps: [B*N_cams, C, H_i, W_i]

        Raises:
            ValueError: if ``images`` is not a 5-D tensor.
        """
        if images.dim() != 5:
            raise ValueError(
                "expected images of shape [B, N_cams, 3, H, W], "
                f"got {tuple(images.shape)}"
            )

        # Fold the camera axis into the batch axis. flatten() copies only
        # when it has to, so non-contiguous input (e.g. a transposed or
        # permuted tensor) works too, unlike view().
        x = images.flatten(0, 1)

        features = []
        for stage in self.conv_layers:
            x = stage(x)
            features.append(x)

        return features


def demo_forward_pass():
    """Run one DETR3D forward pass on random inputs and print a summary.

    Builds the model, feeds it random images and random camera matrices,
    prints the shape and a couple of statistics of every layer's predictions,
    and finishes with a toy post-processing step that counts detections whose
    best foreground probability exceeds 0.5.
    """
    # Demo configuration.
    batch_size, num_cams = 2, 6
    num_queries, num_classes = 900, 10
    image_size = (256, 256)

    # Model first, inputs second: keeps the random-number stream order fixed.
    model = DETR3D(
        num_classes=num_classes,
        num_queries=num_queries,
        num_layers=6,
        hidden_dim=256,
    )

    # Random stand-ins for images and (simplified) camera matrices.
    images = torch.randn(batch_size, num_cams, 3, *image_size)
    camera_matrices = torch.randn(batch_size, num_cams, 3, 4)

    # One (H, W) entry per image, batch_size * num_cams in total.
    image_shapes = [image_size] * (batch_size * num_cams)

    rule = "=" * 50
    print(rule)
    print("DETR3D Forward Pass Demo")
    print(rule)

    print("Input shapes:")
    print(f"  Images: {images.shape}")
    print(f"  Camera matrices: {camera_matrices.shape}")
    print(f"  Number of queries: {num_queries}")
    print(f"  Number of classes: {num_classes}")

    # Inference-mode forward pass.
    model.eval()
    with torch.no_grad():
        outputs = model(images, camera_matrices, image_shapes)

    layer_outputs = outputs['predictions']
    print("\nOutput structure:")
    print(f"  Number of layers: {len(layer_outputs)}")

    # Shapes and simple statistics for every decoder layer.
    for layer_no, layer_pred in enumerate(layer_outputs, start=1):
        boxes = layer_pred['pred_boxes']
        logits = layer_pred['pred_logits']

        print(f"\n  Layer {layer_no}:")
        print(f"    Bounding boxes: {boxes.shape}")
        print(f"    Classification logits: {logits.shape}")
        print(f"    Reference points: {layer_pred['reference_points'].shape}")

        bbox_mean = boxes.mean().item()
        cls_max_prob = logits.softmax(dim=-1).max().item()

        print(f"    Bbox mean: {bbox_mean:.4f}")
        print(f"    Max class probability: {cls_max_prob:.4f}")

    print(f"\nFinal queries shape: {outputs['final_queries'].shape}")

    # Toy post-processing on the last layer's predictions.
    last = layer_outputs[-1]
    class_probs = last['pred_logits'].softmax(dim=-1)

    # Best probability among the real classes (last slot is "no object").
    object_probs = class_probs[:, :, :-1].max(dim=-1).values  # [B, N_queries]

    print("\nPost-processing example:")
    per_sample = zip(object_probs, last['pred_boxes'])
    for sample_no, (scores, sample_boxes) in enumerate(per_sample, start=1):
        # Keep detections above the confidence threshold.
        keep = scores > 0.5
        num_detections = keep.sum().item()

        print(f"  Batch {sample_no}: {num_detections} confident detections (>0.5 confidence)")

        if num_detections <= 0:
            continue

        kept_boxes = sample_boxes[keep]
        kept_scores = scores[keep]

        print(f"    Top detection - Confidence: {kept_scores.max().item():.4f}")
        print(f"    Box params: {kept_boxes[kept_scores.argmax()].tolist()}")


# Script entry point: run the demo only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    demo_forward_pass()

DETR3D Forward Pass Demo
Input shapes:
  Images: torch.Size([2, 6, 3, 256, 256])
  Camera matrices: torch.Size([2, 6, 3, 4])
  Number of queries: 900
  Number of classes: 10

Output structure:
  Number of layers: 6

  Layer 1:
    Bounding boxes: torch.Size([2, 900, 9])
    Classification logits: torch.Size([2, 900, 11])
    Reference points: torch.Size([2, 900, 3])
    Bbox mean: -0.0753
    Max class probability: 0.9284

  Layer 2:
    Bounding boxes: torch.Size([2, 900, 9])
    Classification logits: torch.Size([2, 900, 11])
    Reference points: torch.Size([2, 900, 3])
    Bbox mean: -0.0304
    Max class probability: 0.9046

  Layer 3:
    Bounding boxes: torch.Size([2, 900, 9])
    Classification logits: torch.Size([2, 900, 11])
    Reference points: torch.Size([2, 900, 3])
    Bbox mean: -0.1728
    Max class probability: 0.9444

  Layer 4:
    Bounding boxes: torch.Size([2, 900, 9])
    Classification logits: torch.Size([2, 900, 11])
    Reference points: torch.Size([2, 900, 3]