In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
from collections import defaultdict
from det3d.models.centerpoint import CenterPoint
from det3d.types.pointcloud import PointCloud
from det3d.utils import move_to_gpu, move_to_cpu, print_dict_tensors_size
from torchinfo import summary
from pprint import pprint
import matplotlib.pyplot as plt
import math
import glob
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F

# 1. 定义 Dataset（与模型输出保持一致）

In [2]:
def myfunc(batch_data):
    """Collate ``(point_cloud, label)`` samples into one zero-padded batch tensor.

    Args:
        batch_data: list of samples; ``sample[0]`` must expose a ``points``
            attribute convertible to a 2-D float tensor ``(N_i, feature_dim)``
            and ``sample[1]`` is that frame's label object (e.g. a GT dict).

    Returns:
        tuple ``(padded_points, labels, lengths)``:
          * padded_points -- float32 tensor ``[batch_size, max_N, feature_dim + 1]``.
            Column 0 is the sample's position in the batch, the remaining
            columns are the original point features. Rows beyond a sample's
            own point count are zero padding in every column (including the
            batch-index column), so use ``lengths`` to mask them out.
          * labels -- list of ``sample[1]`` in batch order.
          * lengths -- list of the original point counts ``N_i``.
    """
    tagged_clouds, labels, lengths = [], [], []

    for index, sample in enumerate(batch_data):
        pts = torch.tensor(sample[0].points, dtype=torch.float)
        num_pts = pts.size(0)
        lengths.append(num_pts)
        labels.append(sample[1])

        # Prepend a constant column holding this sample's batch index.
        index_col = torch.full((num_pts, 1), index, dtype=torch.float)
        tagged_clouds.append(torch.cat((index_col, pts), dim=1))

    padded_points = pad_sequence(tagged_clouds, batch_first=True, padding_value=0)
    return padded_points, labels, lengths

In [3]:
class PointCloudDataset(Dataset):
    """CenterPoint-style BEV dataset built from local LiDAR ``.npy`` frames and ``.txt`` labels.

    ``__getitem__`` returns ``(PointCloud, gt_dict)``. ``gt_dict`` keeps only the
    boxes whose centre projects inside the BEV grid and contains:

    * ``gt_boxes``  ``[M, 7]`` -- ``x, y, z, l, w, h, rot`` as read from the label file
    * ``gt_labels`` ``[M]``    -- class ids from ``label_mapping``
    * ``heatmap``   ``[num_classes, H, W]`` -- Gaussian centre heatmap
    * ``ind``       ``[M]``    -- flattened BEV cell index ``y_int * W + x_int``
    * ``reg``       ``[M, 2]`` -- sub-cell centre offset ``(bev_x - x_int, bev_y - y_int)``
    * ``reg_mask``  ``[M]``    -- all ones (every kept box is valid)
    * ``size``      ``[M, 3]`` -- ``l, w, h`` divided by ``voxel_size[k] * feature_map_stride``
    * ``height``    ``[M]``    -- box centre ``z``
    * ``rot``       ``[M]``    -- box yaw

    Point-cloud files and label files are each sorted by path and paired by
    position, so both directories must use names that sort in the same frame
    order (NOTE(review): presumably matching stems -- confirm against the export).
    """

    def __init__(
        self,
        voxel_size: list,          # [voxel_x, voxel_y, voxel_z]; x/y define the BEV grid
        feature_map_stride: int,
        num_classes: int,
        pc_range: list,            # [x_min, y_min, z_min, x_max, y_max, z_max]
        debug: bool = False,       # verbose prints and heatmap plots
        num_samples=100,           # max number of frames to expose; None = all available
        pc_path: str = '/workspace/rosbags/archive/data_city/data_city/lidar_livox',
        label_path: str = '/workspace/rosbags/archive/data_city/data_city/label3',
    ):
        import warnings

        self.num_points = 1000     # kept for backward compatibility; not used by this class
        self.voxel_size = voxel_size
        self.feature_map_stride = feature_map_stride
        self.num_classes = num_classes
        self.debug = debug

        self.label_mapping = {
            "Car": 0,
            "Cyclist": 1,
            "Pedestrian": 2,
            # "Hero": 3
        }

        # glob() yields files in arbitrary filesystem order; sort both lists so that
        # the i-th point cloud and the i-th label file describe the same frame.
        npy_files = sorted(glob.glob(f"{pc_path}/*.npy"))
        label_files = sorted(glob.glob(f"{label_path}/*.txt"))
        if not npy_files or not label_files:
            raise FileNotFoundError(
                f"No point clouds ({pc_path}/*.npy) or labels ({label_path}/*.txt) found")
        if len(npy_files) != len(label_files):
            warnings.warn(
                f"{len(npy_files)} point-cloud files but {len(label_files)} label files; "
                "frames are paired by sorted position")

        # Clamp to the files actually present so __getitem__ never indexes past the end.
        limit = len(npy_files) if num_samples is None else num_samples
        self.num_samples = min(limit, len(npy_files), len(label_files))
        npy_files = npy_files[:self.num_samples]
        label_files = label_files[:self.num_samples]

        # ------------------------------
        # Load the local point clouds
        # ------------------------------
        self.pc_list = [np.load(file) for file in npy_files]

        # Use the configured point-cloud range as-is.
        self.point_cloud_range = pc_range
        if self.debug:
            print(f'[DEBUG] Computed point_cloud_range: {self.point_cloud_range}')

        # ------------------------------
        # Load the local label files
        # ------------------------------
        self.label_list = [self._load_label_file(file) for file in label_files]

    def __len__(self):
        return self.num_samples

    # ---------- helpers ----------

    def _load_label_file(self, label_file):
        """Parse one label file into ``{"gt_boxes": [M, 7] float, "gt_labels": [M] long}``.

        Each non-empty line is whitespace-separated numbers followed by a class
        name; the numbers are ``x y z l w h rot``. Objects of class ``"Hero"``
        are skipped; any other name not in ``label_mapping`` raises KeyError.
        """
        boxes, labels = [], []
        with open(label_file, "r") as file:
            for line in file:
                tokens = line.split()  # tolerant to repeated / trailing whitespace
                if not tokens:
                    continue  # blank line
                cur_label = tokens[-1]
                if cur_label == "Hero":
                    continue
                boxes.append(torch.Tensor([float(x) for x in tokens[:-1]]))
                labels.append(self.label_mapping[cur_label])

        if boxes:
            return dict(
                gt_boxes=torch.stack(boxes),
                gt_labels=torch.tensor(labels, dtype=torch.long),
            )
        return dict(
            gt_boxes=torch.empty((0, 7)),
            gt_labels=torch.empty((0,), dtype=torch.long),
        )

    def get_bev_size(self):
        """Return the BEV feature-map size ``(H, W)`` implied by the range, voxel size and stride."""
        x_min, y_min, z_min, x_max, y_max, z_max = self.point_cloud_range
        vx, vy = self.voxel_size[0], self.voxel_size[1]
        stride = self.feature_map_stride
        W = round((x_max - x_min) / (vx * stride))
        H = round((y_max - y_min) / (vy * stride))
        return H, W

    def map_to_bev(self, x, y):
        """Map a metric ``(x, y)`` to BEV feature-map coordinates.

        Returns:
            ``(bev_x, bev_y, x_int, y_int)``: continuous BEV coordinates and the
            integer cell they fall in. ``x_int``/``y_int`` use floor, so points
            just below ``x_min``/``y_min`` give -1 and fail the caller's bounds check.
        """
        x_min, y_min, _, _, _, _ = self.point_cloud_range
        vx, vy = self.voxel_size[0], self.voxel_size[1]
        stride = self.feature_map_stride
        bev_x = (x - x_min) / (vx * stride)
        bev_y = (y - y_min) / (vy * stride)
        # floor, not int(): int() truncates toward zero and would map (-1, 0) to cell 0.
        x_int, y_int = math.floor(bev_x), math.floor(bev_y)
        return bev_x, bev_y, x_int, y_int

    def get_gaussian_patch_indices(self, center_idx, radius, H, W):
        """Clip a ``(2*radius+1)``-wide Gaussian patch centred at ``center_idx`` to an ``H x W`` map.

        Args:
            center_idx: ``(x_int, y_int)`` cell of the object centre.
            radius: Gaussian radius in cells.
            H, W: heatmap height and width.

        Returns:
            ``(h_x_min, h_x_max, h_y_min, h_y_max, g_x_min, g_x_max, g_y_min, g_y_max)``:
            half-open slice bounds into the heatmap (``h_*``) and into the
            Gaussian kernel (``g_*``) covering the same region.
        """
        x_int, y_int = center_idx
        left = x_int - radius
        right = x_int + radius + 1
        top = y_int - radius
        bottom = y_int + radius + 1

        g_x_min = max(0, -left)
        g_x_max = (2 * radius + 1) - max(0, right - W)
        g_y_min = max(0, -top)
        g_y_max = (2 * radius + 1) - max(0, bottom - H)

        h_x_min = max(0, left)
        h_x_max = min(W, right)
        h_y_min = max(0, top)
        h_y_max = min(H, bottom)

        return h_x_min, h_x_max, h_y_min, h_y_max, g_x_min, g_x_max, g_y_min, g_y_max

    def __getitem__(self, idx):
        # 1. Load the point cloud and its pre-parsed GT dict.
        pc = PointCloud(self.pc_list[idx])
        gt_dict = self.label_list[idx]

        # 2. Keep only boxes whose centre falls inside the BEV grid and build
        #    the sparse regression targets for them.
        original_boxes = gt_dict['gt_boxes']
        original_labels = gt_dict['gt_labels']
        H, W = self.get_bev_size()

        valid_reg_list = []
        valid_ind_list = []
        valid_size_list = []
        valid_boxes_list = []
        valid_labels_list = []
        valid_height_list = []
        valid_rot_list = []

        for i, box in enumerate(original_boxes):
            x, y, z_val, l_box, w_box, h_box, rot_val = box.tolist()
            bev_x, bev_y, x_int, y_int = self.map_to_bev(x, y)
            if 0 <= x_int < W and 0 <= y_int < H:
                valid_boxes_list.append(box)
                valid_labels_list.append(original_labels[i])
                # Fractional offset of the true centre inside its BEV cell.
                valid_reg_list.append([bev_x - x_int, bev_y - y_int])
                # Row-major flat index into an [H, W] map.
                valid_ind_list.append(y_int * W + x_int)
                # Box size expressed in (strided) voxel units.
                size_w = l_box / (self.voxel_size[0] * self.feature_map_stride)
                size_h = w_box / (self.voxel_size[1] * self.feature_map_stride)
                size_z = h_box / (self.voxel_size[2] * self.feature_map_stride)
                valid_size_list.append([size_w, size_h, size_z])
                valid_height_list.append(z_val)
                valid_rot_list.append(rot_val)
            elif self.debug:
                print(f'[DEBUG] Box {i} with center ({x:.2f},{y:.2f}) mapped to ({x_int},{y_int}) is out of BEV range.')

        if len(valid_boxes_list) > 0:
            filtered_gt_boxes = torch.stack(valid_boxes_list)
            filtered_gt_labels = torch.tensor(valid_labels_list, dtype=torch.long)
            reg = torch.tensor(valid_reg_list, dtype=torch.float32)        # [valid_num, 2]
            ind = torch.tensor(valid_ind_list, dtype=torch.long)           # [valid_num]
            size = torch.tensor(valid_size_list, dtype=torch.float32)      # [valid_num, 3]
            reg_mask = torch.ones(len(valid_ind_list), dtype=torch.uint8)  # every kept box is valid
            height = torch.tensor(valid_height_list, dtype=torch.float32)  # [valid_num]
            rot = torch.tensor(valid_rot_list, dtype=torch.float32)        # [valid_num]
        else:
            filtered_gt_boxes = torch.empty((0, 7))
            filtered_gt_labels = torch.empty((0,), dtype=torch.long)
            reg = torch.empty((0, 2), dtype=torch.float32)
            ind = torch.empty((0,), dtype=torch.long)
            size = torch.empty((0, 3), dtype=torch.float32)
            reg_mask = torch.empty((0,), dtype=torch.uint8)
            height = torch.empty((0,), dtype=torch.float32)
            rot = torch.empty((0,), dtype=torch.float32)

        # 3. Dense heatmap target built from the kept boxes only.
        filtered_heatmap = self.generate_heatmap(
            filtered_gt_boxes, filtered_gt_labels,
            point_cloud_range=self.point_cloud_range,
            voxel_size=self.voxel_size,
            feature_map_stride=self.feature_map_stride,
            num_classes=self.num_classes
        )

        final_gt = {
            'gt_boxes': filtered_gt_boxes,   # only in-grid boxes
            'gt_labels': filtered_gt_labels,
            'heatmap': filtered_heatmap,
            'ind': ind,
            'reg': reg,
            'reg_mask': reg_mask,
            'size': size,
            'height': height,
            'rot': rot
        }

        return pc, final_gt

    def gaussian2D(self, shape, sigma=1):
        """Return an un-normalised 2-D Gaussian of the given ``shape`` (peak value 1 at the centre)."""
        m, n = [(ss - 1.) / 2. for ss in shape]
        y, x = np.ogrid[-m:m+1, -n:n+1]
        h = np.exp(-(x * x + y * y) / (2 * sigma * sigma))
        # Zero out values that are negligible relative to the peak.
        h[h < np.finfo(h.dtype).eps * h.max()] = 0
        return h

    def gaussian_radius(self, det_size, min_overlap=0.5):
        """CornerNet-style Gaussian radius for a box of size ``(height, width)`` in cells.

        Takes the minimum over the three corner-displacement cases so that a
        displaced box still overlaps the GT by at least ``min_overlap``.
        """
        height, width = det_size
        a1 = 1
        b1 = (height + width)
        c1 = width * height * (1 - min_overlap) / (1 + min_overlap)
        sq1 = np.sqrt(b1 ** 2 - 4 * a1 * c1)
        r1 = (b1 - sq1) / (2 * a1)

        a2 = 4
        b2 = 2 * (height + width)
        c2 = (1 - min_overlap) * width * height
        sq2 = np.sqrt(b2 ** 2 - 4 * a2 * c2)
        r2 = (b2 - sq2) / (2 * a2)

        a3 = 4 * min_overlap
        b3 = -2 * min_overlap * (height + width)
        c3 = (min_overlap - 1) * width * height
        sq3 = np.sqrt(b3 ** 2 - 4 * a3 * c3)
        r3 = (b3 + sq3) / (2 * a3)

        return min(r1, r2, r3)

    def generate_heatmap(self, gt_boxes, gt_labels, point_cloud_range, voxel_size,
                         feature_map_stride=4, num_classes=3):
        """Render a ``[num_classes, H, W]`` Gaussian centre heatmap for ``gt_boxes``.

        ``H, W`` come from ``get_bev_size()`` and centres from ``map_to_bev()``,
        i.e. from ``self.point_cloud_range``/``self.voxel_size``/
        ``self.feature_map_stride``; the ``point_cloud_range`` argument is not
        read, and ``voxel_size``/``feature_map_stride`` are only used for the
        Gaussian radius. Overlapping Gaussians of one class are merged with max.
        """
        dx, dy = voxel_size[0], voxel_size[1]
        H, W = self.get_bev_size()
        if self.debug:
            print(f'[DEBUG] BEV feature map size: H={H}, W={W}')

        heatmap = torch.zeros((num_classes, H, W))
        for i, (box, label) in enumerate(zip(gt_boxes, gt_labels)):
            x, y, _, dx_size, dy_size, _, _ = box.tolist()
            bev_x, bev_y, x_int, y_int = self.map_to_bev(x, y)
            if self.debug:
                print(f'[DEBUG] Box {i}: center=({x:.2f}, {y:.2f}), mapped BEV=({bev_x:.2f}, {bev_y:.2f}) -> (x_int, y_int)=({x_int}, {y_int}), label={label}')
            if not (0 <= x_int < W and 0 <= y_int < H):
                if self.debug:
                    print(f'[DEBUG] Box {i} is out of BEV bounds, skipped.')
                continue
            # Box footprint in BEV cells, ordered (height-along-y, width-along-x).
            box_hw = (dy_size / dy / feature_map_stride, dx_size / dx / feature_map_stride)
            radius = self.gaussian_radius(box_hw)
            radius = max(0, int(radius))
            diameter = 2 * radius + 1
            if self.debug:
                print(f'[DEBUG] Box {i}: box_hw={box_hw}, gaussian radius={radius}, diameter={diameter}')

            gaussian = self.gaussian2D((diameter, diameter), sigma=diameter / 6)
            gaussian = torch.from_numpy(gaussian).float()

            h_x_min, h_x_max, h_y_min, h_y_max, g_x_min, g_x_max, g_y_min, g_y_max = \
                self.get_gaussian_patch_indices((x_int, y_int), radius, H, W)

            if self.debug:
                print(f'[DEBUG] Box {i}: BEV region: x[{h_x_min}:{h_x_max}], y[{h_y_min}:{h_y_max}]')
                print(f'[DEBUG] Box {i}: Gaussian patch indices: x[{g_x_min}:{g_x_max}], y[{g_y_min}:{g_y_max}]')

            masked_heatmap = heatmap[label, h_y_min:h_y_max, h_x_min:h_x_max]
            masked_gaussian = gaussian[g_y_min:g_y_max, g_x_min:g_x_max]
            if self.debug:
                print(f'[DEBUG] Box {i}: before update, heatmap sum={masked_heatmap.sum():.4f}, gaussian sum={masked_gaussian.sum():.4f}')
            if masked_gaussian.shape != masked_heatmap.shape:
                if self.debug:
                    print(f'[DEBUG] Box {i}: Shape mismatch: heatmap patch shape {masked_heatmap.shape}, gaussian patch shape {masked_gaussian.shape}. Skipping this box.')
                continue
            heatmap[label, h_y_min:h_y_max, h_x_min:h_x_max] = torch.maximum(masked_heatmap, masked_gaussian)
            if self.debug:
                print(f'[DEBUG] Box {i}: after update, new heatmap channel {label} sum={heatmap[label].sum():.4f}')

        if self.debug:
            import matplotlib.pyplot as plt
            fig, axs = plt.subplots(1, num_classes, figsize=(5 * num_classes, 4))
            if num_classes == 1:
                axs = [axs]
            for i in range(num_classes):
                heat = heatmap[i].numpy()
                im = axs[i].imshow(heat, cmap='hot', interpolation='nearest')
                axs[i].set_title(f"Heatmap for class {i}")
                fig.colorbar(im, ax=axs[i])
            plt.suptitle("All Channel Heatmaps")
            plt.tight_layout()
            plt.show()

        return heatmap

可以看到 dataset 中每个样本的输入为一帧点云（`PointCloud`，点数组形状为 (N, 4)）；GT 为一个字典，包含 `gt_boxes`、`gt_labels`、`heatmap`、`ind`、`reg`、`reg_mask`、`size`、`height`、`rot`（不包含 scores）。

# 2. 使用提供的CenterPoint


In [4]:
# model = CenterPoint([0.2,0.2,0.2],[0.2,0.2,0.2,0.2,0.2,0.2])
# summary(model,input_size=(2,10, 5))

这些尺寸实际上正是由各个预测分支的设计决定的。一般来说，CenterPoint 和类似模型会为每个任务分配固定数量的输出通道，然后对整个 BEV 特征图进行卷积预测。具体解释如下：

1. **raw_hm (热力图)**  
   - 输出尺寸为 `[3, 448, 1120]`，其中 3 表示预测类别数（本数据集按 `label_mapping` 依次为 Car=0、Cyclist=1、Pedestrian=2）。因此，每个类别对应一个通道，最终形成一个形状为 `[num_classes, H, W]` 的热力图。

2. **raw_center (中心偏移)**  
   - 输出尺寸为 `[2, 448, 1120]`，这里 2 表示 BEV 特征图上每个 grid cell 内的中心偏移（x 和 y 的残差）。在本 notebook 中，监督目标 `reg` 为目标中心的连续 BEV 坐标减去其整数网格索引，即从网格单元的整数角点（而非网格中心）到真实中心的小数偏移。

3. **raw_center_z (高度/中心z值)**  
   - 输出尺寸为 `[1, 448, 1120]`，只有 1 个通道，用于预测目标中心在 z 轴方向上的坐标。监督目标 `height` 是 box 的中心 z 值，而不是 box 自身的高度 h（h 由 `raw_dim` 预测）。

4. **raw_dim (尺寸)**  
   - 输出尺寸为 `[3, 448, 1120]`，这里 3 个通道分别对应目标的长、宽、高（l, w, h）的预测值。常见实现会预测对数尺度（后续做 exp 处理）；而在本 notebook 中，监督目标 `size` 是以 voxel 为单位的尺寸，即 l/(vx·stride)、w/(vy·stride)、h/(vz·stride)，没有取对数。

5. **raw_rot (旋转角度)**  
   - 输出尺寸为 `[2, 448, 1120]`，通常用两个通道分别表示旋转角度的 sin 和 cos 值，这样可以避免角度回归带来的周期性问题。

**总结：**  
- 每个分支的通道数即对应任务的输出数量。  
- 448 和 1120 是 BEV 特征图的高度和宽度，由点云范围、voxel 大小和下采样系数决定：H = (44.8 − (−44.8)) / (0.2 × 1) = 448，W = (224 − 0) / (0.2 × 1) = 1120。  
- 例如，hm 分支输出 3 个通道对应 3 个类别，中心偏移输出 2 个通道（x,y），中心高度（z）输出 1 个通道，尺寸输出 3 个通道，而旋转输出 2 个通道。

# 3. 定义损失函数

In [5]:
def focal_loss(pred, gt, alpha=1, beta=4):
    """CenterNet-style penalty-reduced focal loss on a heatmap.

    Args:
        pred: predicted heatmap, already passed through sigmoid (values in [0, 1]).
        gt: target heatmap of the same shape; cells equal to 1 are positives.
        alpha: focusing exponent applied to both positive and negative terms.
        beta: exponent of the ``(1 - gt)`` penalty reduction for negatives.

    Returns:
        Scalar tensor: the summed loss divided by the number of positives, or
        the summed negative loss alone when there are no positives.

    Raises:
        AssertionError: if ``pred`` or ``gt`` is not a ``torch.Tensor``.
    """
    assert isinstance(pred, torch.Tensor)
    assert isinstance(gt, torch.Tensor)

    eps = 1e-4  # keeps log() finite when pred hits 0 or 1
    is_pos = gt.eq(1).float()
    is_neg = gt.lt(1).float()

    pos_term = -torch.log(pred + eps) * (1 - pred).pow(alpha) * is_pos
    neg_term = -torch.log(1 - pred + eps) * pred.pow(alpha) * (1 - gt).pow(beta) * is_neg

    num_pos = is_pos.sum()
    if num_pos == 0:
        return neg_term.sum()
    return (pos_term.sum() + neg_term.sum()) / num_pos

In [6]:
def xyz_loss(pred_dicts, gt_dicts):
    """Masked L1 losses for the BEV centre offset and the centre z.

    For every sample the dense predictions are gathered at the GT cell indices
    ``gt["ind"]`` and compared with the per-object targets; only objects with
    ``reg_mask == 1`` contribute.

    Args:
        pred_dicts: list (one entry per sample) of dicts with
            ``"raw_center"`` ``[2, H, W]`` and ``"raw_center_z"`` ``[1, H, W]``.
        gt_dicts: list of dicts with ``"ind"`` ``[N]`` (row-major flat cell
            index), ``"reg"`` ``[N, 2]``, ``"height"`` ``[N]`` and
            ``"reg_mask"`` ``[N]``.

    Returns:
        ``(avg_offset_loss, avg_height_loss)``: each loss averaged over samples.

    Raises:
        ValueError: if ``raw_center`` is not 3-D.
    """
    total_offset_loss = 0.0
    total_height_loss = 0.0
    batch_count = len(pred_dicts)

    for i in range(batch_count):
        pred = pred_dicts[i]
        gt = gt_dicts[i]

        raw_center = pred["raw_center"]
        if raw_center.dim() != 3:
            raise ValueError(f"expected raw_center of shape [2, H, W], got {tuple(raw_center.shape)}")
        # reshape (not view) so non-contiguous head outputs are accepted as well.
        raw_center_flat = raw_center.reshape(2, -1).transpose(0, 1)                # [H*W, 2]
        raw_center_z_flat = pred["raw_center_z"].reshape(1, -1).transpose(0, 1)   # [H*W, 1]

        indices = gt["ind"].to(torch.long)
        if indices.nelement() > 0:
            pred_offset = raw_center_flat[indices]      # [N, 2]
            pred_height = raw_center_z_flat[indices]    # [N, 1]

            gt_offset = gt["reg"].to(pred_offset.dtype)                       # [N, 2]
            gt_height = gt["height"].to(pred_height.dtype).unsqueeze(-1)      # [N, 1]
            mask = gt["reg_mask"].to(pred_offset.dtype).unsqueeze(-1)         # [N, 1]
            num_valid = mask.sum() + 1e-4

            # Apply the mask element-wise so invalid objects contribute nothing.
            offset_loss = (F.l1_loss(pred_offset, gt_offset, reduction="none") * mask).sum() / num_valid
            height_loss = (F.l1_loss(pred_height, gt_height, reduction="none") * mask).sum() / num_valid
        else:
            # No targets: a zero loss still attached to the graph so backward() works.
            offset_loss = raw_center_flat[0].sum() * 0.0
            height_loss = raw_center_z_flat[0].sum() * 0.0

        total_offset_loss += offset_loss
        total_height_loss += height_loss

    avg_offset_loss = total_offset_loss / batch_count
    avg_height_loss = total_height_loss / batch_count
    return avg_offset_loss, avg_height_loss

In [7]:
def lwh_loss(pred_dicts, gt_dicts):
    """Masked L1 loss on the box size (l, w, h), averaged over samples.

    Takes the same inputs as ``xyz_loss``.

    Args:
        pred_dicts: list (one entry per sample) of dicts with ``"raw_dim"``
            ``[3, H, W]``.
        gt_dicts: list of dicts with ``"size"`` ``[N, 3]``, ``"ind"`` ``[N]``
            (row-major flat cell index) and ``"reg_mask"`` ``[N]``.

    Returns:
        avg_size_loss: scalar tensor; only objects with ``reg_mask == 1`` contribute.

    Raises:
        ValueError: if ``raw_dim`` is not 3-D.
    """
    total_size_loss = 0.0
    batch_count = len(pred_dicts)

    for i in range(batch_count):
        pred = pred_dicts[i]
        gt = gt_dicts[i]

        raw_dim = pred["raw_dim"]
        if raw_dim.dim() != 3:
            raise ValueError(f"expected raw_dim of shape [3, H, W], got {tuple(raw_dim.shape)}")
        # reshape (not view) so non-contiguous head outputs are accepted as well.
        raw_dim_flat = raw_dim.reshape(3, -1).transpose(0, 1)   # [H*W, 3]

        indices = gt["ind"].to(torch.long)
        if indices.nelement() > 0:
            pred_size = raw_dim_flat[indices]                           # [N, 3]
            gt_size = gt["size"].to(pred_size.dtype)                    # [N, 3]
            mask = gt["reg_mask"].to(pred_size.dtype).unsqueeze(-1)     # [N, 1]
            num_valid = mask.sum() + 1e-4

            # Apply the mask element-wise so invalid objects contribute nothing.
            loss_size = (F.l1_loss(pred_size, gt_size, reduction="none") * mask).sum() / num_valid
        else:
            # No targets: a zero loss still attached to the graph so backward() works.
            loss_size = raw_dim_flat[0].sum() * 0.0

        total_size_loss += loss_size

    avg_size_loss = total_size_loss / batch_count
    return avg_size_loss

In [8]:
def rot_loss(pred_dicts, gt_dicts):
    """Masked L1 loss between predicted and target ``(sin, cos)`` of the yaw.

    Takes the same inputs as ``xyz_loss`` and ``lwh_loss``.

    Args:
        pred_dicts: list (one entry per sample) of dicts with ``"raw_rot"``
            ``[2, H, W]``; channel 0 is compared to sin(rot), channel 1 to cos(rot).
        gt_dicts: list of dicts with ``"rot"`` ``[N]`` (yaw angle),
            ``"ind"`` ``[N]`` (row-major flat cell index) and ``"reg_mask"`` ``[N]``.

    Returns:
        avg_rot_loss: scalar tensor; only objects with ``reg_mask == 1`` contribute.

    Raises:
        ValueError: if ``raw_rot`` is not 3-D.
    """
    total_rot_loss = 0.0
    batch_count = len(pred_dicts)

    for i in range(batch_count):
        pred = pred_dicts[i]
        gt = gt_dicts[i]

        raw_rot = pred["raw_rot"]
        if raw_rot.dim() != 3:
            raise ValueError(f"expected raw_rot of shape [2, H, W], got {tuple(raw_rot.shape)}")
        # reshape (not view) so non-contiguous head outputs are accepted as well.
        raw_rot_flat = raw_rot.reshape(2, -1).transpose(0, 1)   # [H*W, 2]

        indices = gt["ind"].to(torch.long)
        if indices.nelement() > 0:
            pred_rot = raw_rot_flat[indices]                                  # [N, 2]

            # Encode the scalar yaw as (sin, cos) to avoid the 2*pi wrap-around.
            gt_rot = gt["rot"].to(pred_rot.dtype).unsqueeze(-1)               # [N, 1]
            gt_rot_targets = torch.cat([torch.sin(gt_rot), torch.cos(gt_rot)], dim=1)  # [N, 2]

            mask = gt["reg_mask"].to(pred_rot.dtype).unsqueeze(-1)            # [N, 1]
            num_valid = mask.sum() + 1e-4

            # Apply the mask element-wise so invalid objects contribute nothing.
            loss_rot = (F.l1_loss(pred_rot, gt_rot_targets, reduction="none") * mask).sum() / num_valid
        else:
            # No targets: a zero loss still attached to the graph so backward() works.
            loss_rot = raw_rot_flat[0].sum() * 0.0

        total_rot_loss += loss_rot

    avg_rot_loss = total_rot_loss / batch_count
    return avg_rot_loss

In [9]:
def mse_heatmap_loss(pred_dicts, gt_dicts):
    """Mean-squared error between the sigmoid of the predicted heatmap and the GT heatmap.

    Args:
        pred_dicts: list (one entry per sample) of dicts with ``"raw_hm"``
            ``[num_classes, H, W]`` holding raw (pre-sigmoid) logits.
        gt_dicts: list of dicts with ``"heatmap"`` ``[num_classes, H, W]``
            with values in [0, 1].

    Returns:
        Scalar tensor: the per-pixel mean MSE of each sample, averaged over samples.
    """
    per_sample_losses = []
    for idx, pred in enumerate(pred_dicts):
        target_hm = gt_dicts[idx]["heatmap"]
        # Squash logits into [0, 1] so they are comparable with the GT heatmap.
        prob_hm = torch.sigmoid(pred["raw_hm"])
        per_sample_losses.append(F.mse_loss(prob_hm, target_hm, reduction="mean"))

    return sum(per_sample_losses) / len(pred_dicts)

In [10]:
def compute_loss(pred_dicts, gt_dicts):
    """Combine the CenterPoint training losses for one batch.

    Components, each already averaged over the samples in the batch and summed
    with weight 1 into ``total_loss``:
      * centre offset + centre z L1 (``xyz_loss``)
      * box size L1 (``lwh_loss``)
      * yaw (sin, cos) L1 (``rot_loss``)
      * sigmoid-heatmap MSE (``mse_heatmap_loss``)
    The focal heatmap loss is currently disabled and reported as 0.

    Args:
        pred_dicts: list of per-sample prediction dicts (``raw_hm``,
            ``raw_center``, ``raw_center_z``, ``raw_dim``, ``raw_rot``).
        gt_dicts: list of per-sample GT dicts as produced by the dataset.

    Returns:
        dict of 0-d tensors: ``total_loss``, ``fc_loss``, ``xyz_loss``,
        ``lwh_loss``, ``rot_loss`` and ``hm_loss``.
    """
    # Each helper loops over the whole batch itself, so call each exactly once;
    # calling them inside a per-sample loop repeats the work (and the autograd
    # graph) batch_size times without changing the averaged value.
    xy_loss, z_loss = xyz_loss(pred_dicts, gt_dicts)
    xyz_ = xy_loss + z_loss
    lwh_ = lwh_loss(pred_dicts, gt_dicts)
    rot_ = rot_loss(pred_dicts, gt_dicts)
    hm_ = mse_heatmap_loss(pred_dicts, gt_dicts)

    total_loss = xyz_ + lwh_ + rot_ + hm_

    # Focal loss is disabled; report a zero tensor (same device/dtype) so that
    # callers can uniformly call .item() on every entry.
    fc_ = torch.zeros_like(total_loss)

    return dict(
        total_loss=total_loss,
        fc_loss=fc_,
        xyz_loss=xyz_,
        lwh_loss=lwh_,
        rot_loss=rot_,
        hm_loss=hm_,
    )

# 4. 定义Evaluation Metrics

In [11]:
def detection_metrics(pred_dicts, gt_dicts, center_thresh=2.0):
    """Greedy centre-distance matching metrics between predicted and GT boxes.

    Matching strategy: for each GT box (in order) find the prediction whose
    (x, y) centre is closest. It counts as a match only if that distance is
    below ``center_thresh`` AND that prediction has not already been matched
    to an earlier GT box. The next-closest prediction is not tried.

    Args:
        pred_dicts: list of per-sample dicts; optional keys ``"pred_boxes"``
            ``[N_pred, 7]`` (x, y, z, l, w, h, rot) and ``"pred_labels"``
            ``[N_pred]``. Missing keys are treated as "no predictions".
        gt_dicts: list of per-sample dicts; optional keys ``"gt_boxes"``
            ``[N_gt, 7]`` and ``"gt_labels"`` ``[N_gt]``.
        center_thresh: maximum centre distance for a match, in the same unit
            as the box x/y coordinates.

    Returns:
        dict with
          * ``total_pred`` / ``total_gt`` / ``total_matches`` -- counts,
          * ``avg_center_error`` -- mean (x, y) distance over matches (0.0 if none),
          * ``cls_accuracy`` -- fraction of matches with the correct label (0.0 if none),
          * ``true_match`` -- total_matches / total_pred (0.0 if no predictions),
          * ``gt_box_recall`` -- total_matches / total_gt (0.0 if no GT boxes).
    """
    n_matched = 0
    n_correct_cls = 0
    center_err_sum = 0.0
    n_pred = 0
    n_gt = 0

    for sample_idx in range(len(pred_dicts)):
        pred, gt = pred_dicts[sample_idx], gt_dicts[sample_idx]

        boxes_p = pred.get("pred_boxes", torch.empty((0, 7)))
        labels_p = pred.get("pred_labels", torch.empty((0,), dtype=torch.long))
        boxes_g = gt.get("gt_boxes", torch.empty((0, 7)))
        labels_g = gt.get("gt_labels", torch.empty((0,), dtype=torch.long))

        n_pred += boxes_p.shape[0]
        n_gt += boxes_g.shape[0]

        if boxes_p.shape[0] == 0:
            continue  # nothing to match these GT boxes against

        centers_p = boxes_p[:, :2]  # [N_pred, 2]; only (x, y) is compared
        taken = set()               # enforces one-to-one matching
        for g in range(boxes_g.shape[0]):
            dists = torch.norm(centers_p - boxes_g[g][:2].unsqueeze(0), dim=1)
            best_dist, best_idx = dists.min(0)
            best = best_idx.item()
            if best_dist.item() >= center_thresh or best in taken:
                continue
            n_matched += 1
            center_err_sum += best_dist.item()
            if labels_p[best_idx] == labels_g[g]:
                n_correct_cls += 1
            taken.add(best)

    avg_center_error = center_err_sum / n_matched if n_matched > 0 else 0.0
    cls_accuracy = n_correct_cls / n_matched if n_matched > 0 else 0.0

    return {
        "total_pred": n_pred,
        "total_gt": n_gt,
        "total_matches": n_matched,
        "avg_center_error": avg_center_error,
        "cls_accuracy": cls_accuracy,
        "true_match": n_matched / n_pred if n_pred > 0 else 0.0,
        "gt_box_recall": n_matched / n_gt if n_gt > 0 else 0.0,
    }

# 5. 训练流程

In [12]:
# ---- Experiment configuration ----
voxel_size = [0.2, 0.2, 0.2]             # [vx, vy, vz]
feature_map_stride = 1                   # BEV down-sampling factor of the detection head
num_classes = 3                          # Car / Cyclist / Pedestrian
batch_size = 1
pc_range = [0, -44.8, -2, 224, 44.8, 4]  # [x_min, y_min, z_min, x_max, y_max, z_max]
lr = 1e-4
num_epochs = 100
SEED = 42                                # fixes the train/test split and the shuffle order

torch.manual_seed(SEED)
np.random.seed(SEED)

dataset = PointCloudDataset(debug=False,
                            voxel_size=voxel_size,
                            feature_map_stride=feature_map_stride,
                            pc_range=pc_range,
                            num_classes=num_classes)

# Reproducible 80/20 train/test split.
train_size = int(len(dataset) * 0.8)
test_size = len(dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(
    dataset, [train_size, test_size], generator=torch.Generator().manual_seed(SEED))

# Training batches are reshuffled every epoch (seeded). Evaluation keeps a fixed
# order and keeps the final partial batch so no test frame is silently skipped.
trainloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                         collate_fn=myfunc, drop_last=True,
                         generator=torch.Generator().manual_seed(SEED))
testloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False,
                        collate_fn=myfunc, drop_last=False)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = CenterPoint(
    voxel_size=voxel_size,
    pc_range=pc_range,
    feature_map_stride=feature_map_stride,
).to(device)

if torch.cuda.is_available():
    torch.cuda.empty_cache()

optimizer = optim.Adam(model.parameters(), lr=lr)


In [13]:
# Sanity check: (number of in-grid GT boxes, box parameters) of the first frame.
first_gt = dataset[0][1]
first_gt['gt_boxes'].shape

torch.Size([33, 7])

In [14]:
# Train for `num_epochs`, evaluating on the held-out split after every epoch.
num_train_batches = max(len(trainloader), 1)  # avoids ZeroDivisionError on an empty loader

for epoch in range(num_epochs):
    model.train()  # training mode
    running_loss = 0.0
    running_fc = 0.0
    running_xyz = 0.0
    running_lwh = 0.0
    running_rot = 0.0

    for i, (pointclouds, gt_dicts, lengths) in enumerate(trainloader):
        # Move the batch to the GPU (no-op if it is already there).
        pointclouds = move_to_gpu(pointclouds)
        gt_dicts = move_to_gpu(gt_dicts)

        # print_dict_tensors_size(gt_dicts)  # uncomment to inspect GT tensor sizes

        optimizer.zero_grad()
        pred_dicts = model(pointclouds)
        loss_dict = compute_loss(pred_dicts, gt_dicts)
        loss = loss_dict['total_loss']
        loss.backward()
        optimizer.step()

        # float() accepts both 0-d tensors and plain Python numbers, so a component
        # reported as a constant 0.0 (e.g. the disabled focal loss) is still tracked.
        running_loss += loss.item()
        running_fc += float(loss_dict['fc_loss'])
        running_xyz += float(loss_dict['xyz_loss'])
        running_lwh += float(loss_dict['lwh_loss'])
        running_rot += float(loss_dict['rot_loss'])

        if (i + 1) % 10 == 0:
            print(f"Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(trainloader)}], "
                  f"Total Loss: {running_loss/(i+1):.4f}, FC Loss: {running_fc/(i+1):.4f}, "
                  f"XYZ Loss: {running_xyz/(i+1):.4f}, LWH Loss: {running_lwh/(i+1):.4f}, "
                  f"Rot Loss: {running_rot/(i+1):.4f}")

    print(f"Epoch [{epoch+1}/{num_epochs}], Average Total Loss: {running_loss/num_train_batches:.4f}, "
          f"Average FC Loss: {running_fc/num_train_batches:.4f}, Average XYZ Loss: {running_xyz/num_train_batches:.4f}, "
          f"Average LWH Loss: {running_lwh/num_train_batches:.4f}, Average Rot Loss: {running_rot/num_train_batches:.4f}")

    # ===========================
    # Evaluation: validation loss and detection metrics on the test split
    # ===========================
    model.eval()
    val_loss_total = 0.0
    val_batches = 0
    all_metrics = []

    with torch.no_grad():
        for j, (val_pointclouds, val_gt_dicts, val_lengths) in enumerate(testloader):
            val_pointclouds = move_to_gpu(val_pointclouds)
            val_gt_dicts = move_to_gpu(val_gt_dicts)

            pred_dicts_val = model(val_pointclouds)

            loss_dict_val = compute_loss(pred_dicts_val, val_gt_dicts)
            val_loss_total += float(loss_dict_val['total_loss'])
            val_batches += 1

            all_metrics.append(detection_metrics(pred_dicts_val, val_gt_dicts, center_thresh=1.0))

    avg_val_loss = val_loss_total / val_batches if val_batches > 0 else 0.0

    # Simple per-batch mean of every detection metric.
    if all_metrics:
        avg_metrics = {key: sum(m[key] for m in all_metrics) / len(all_metrics)
                       for key in all_metrics[0]}
    else:
        avg_metrics = {}

    print(f"Epoch [{epoch+1}/{num_epochs}] Validation Loss: {avg_val_loss:.4f}")
    print("Validation detection metrics:", avg_metrics)

    # Leave the model in training mode (matches the state at the start of each epoch).
    model.train()

Epoch [1/100], Step [10/80], Total Loss: 32.9584, FC Loss: 0.0000, XYZ Loss: 4.8522, LWH Loss: 26.3080, Rot Loss: 1.5694
Epoch [1/100], Step [20/80], Total Loss: 33.2658, FC Loss: 0.0000, XYZ Loss: 3.7111, LWH Loss: 27.7019, Rot Loss: 1.6271
Epoch [1/100], Step [30/80], Total Loss: 33.1203, FC Loss: 0.0000, XYZ Loss: 3.6411, LWH Loss: 27.6574, Rot Loss: 1.6003
Epoch [1/100], Step [40/80], Total Loss: 32.7973, FC Loss: 0.0000, XYZ Loss: 3.3769, LWH Loss: 27.6272, Rot Loss: 1.5763
Epoch [1/100], Step [50/80], Total Loss: 34.1637, FC Loss: 0.0000, XYZ Loss: 4.0029, LWH Loss: 28.3760, Rot Loss: 1.5722
Epoch [1/100], Step [60/80], Total Loss: 33.4046, FC Loss: 0.0000, XYZ Loss: 3.7768, LWH Loss: 27.8663, Rot Loss: 1.5536
Epoch [1/100], Step [70/80], Total Loss: 33.4870, FC Loss: 0.0000, XYZ Loss: 3.6586, LWH Loss: 28.0649, Rot Loss: 1.5596
Epoch [1/100], Step [80/80], Total Loss: 33.2568, FC Loss: 0.0000, XYZ Loss: 3.8310, LWH Loss: 27.6794, Rot Loss: 1.5466
Epoch [1/100], Average Total Los

KeyboardInterrupt: 

In [None]:
def train_with_wandb(model, trainloader, num_epochs, batch_size,
                     learning_rate=1e-3, project="your_project_name"):
    """Train ``model`` with Adam and log running / per-epoch losses to Weights & Biases.

    This loop used to live inside a bare triple-quoted string. That kept it from
    running, but Jupyter then echoed the whole source as the cell's output. As a
    function it is inert until called, and ``wandb`` is only imported when it is.

    Args:
        model: module called as ``model(pointclouds)``. Its output is passed to
            ``compute_loss`` together with the ground-truth dicts.
        trainloader: iterable supporting ``len()`` and yielding
            ``(pointclouds, gt_dicts, lengths)`` batches (``lengths`` is unused here).
        num_epochs: number of passes over ``trainloader``.
        batch_size: only recorded in the wandb run config.
        learning_rate: Adam learning rate. It is read back through ``wandb.config``
            so that a wandb sweep can override it.
        project: wandb project name.

    Relies on ``move_to_gpu`` and ``compute_loss`` from the notebook namespace.
    ``compute_loss`` must return a dict with ``total_loss``, ``fc_loss``,
    ``xyz_loss``, ``lwh_loss`` and ``rot_loss`` tensors.
    """
    import wandb  # optional dependency (pip install wandb); imported lazily

    # (compute_loss key, display / wandb label), in the order they are printed.
    loss_fields = (
        ("total_loss", "Total Loss"),
        ("fc_loss", "FC Loss"),
        ("xyz_loss", "XYZ Loss"),
        ("lwh_loss", "LWH Loss"),
        ("rot_loss", "Rot Loss"),
    )

    wandb.init(project=project, config={
        "learning_rate": learning_rate,
        "batch_size": batch_size,
        "num_epochs": num_epochs,
    })
    try:
        optimizer = optim.Adam(model.parameters(), lr=wandb.config.learning_rate)
        steps_per_epoch = len(trainloader)

        for epoch in range(num_epochs):
            model.train()
            running = {key: 0.0 for key, _ in loss_fields}

            for i, (pointclouds, gt_dicts, _lengths) in enumerate(trainloader):
                pointclouds = move_to_gpu(pointclouds)
                gt_dicts = move_to_gpu(gt_dicts)

                optimizer.zero_grad()
                pred_dicts = model(pointclouds)
                loss_dict = compute_loss(pred_dicts, gt_dicts)
                loss_dict["total_loss"].backward()
                optimizer.step()

                for key, _ in loss_fields:
                    running[key] += loss_dict[key].item()

                if (i + 1) % 10 == 0:
                    # Running mean over the batches seen so far in this epoch.
                    avg = {label: running[key] / (i + 1) for key, label in loss_fields}
                    print(f"Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{steps_per_epoch}], "
                          + ", ".join(f"{label}: {value:.4f}" for label, value in avg.items()))
                    wandb.log({"Epoch": epoch + 1, "Step": i + 1, **avg},
                              step=epoch * steps_per_epoch + i)

            # Guard against an empty loader instead of raising ZeroDivisionError.
            n_batches = max(steps_per_epoch, 1)
            epoch_avg = {label: running[key] / n_batches for key, label in loss_fields}
            print(f"Epoch [{epoch+1}/{num_epochs}], Average "
                  + ", ".join(f"{label}: {value:.4f}" for label, value in epoch_avg.items()))
            # Logged at the step index where the next epoch begins. Per-step logs only
            # happen at i = 9, 19, ..., so this never collides with them.
            wandb.log({"Epoch": epoch + 1,
                       **{f"Average {label}": value for label, value in epoch_avg.items()}},
                      step=(epoch + 1) * steps_per_epoch)
    finally:
        # Close the run even if training is interrupted (e.g. KeyboardInterrupt).
        wandb.finish()


# Set to True to run the wandb-logged training loop in this cell.
USE_WANDB = False
if USE_WANDB:
    train_with_wandb(model, trainloader, num_epochs, batch_size)