In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
from collections import defaultdict
from det3d.models.centerpoint import CenterPoint
from det3d.types.pointcloud import PointCloud
from det3d.utils import move_to_gpu, move_to_cpu, print_dict_tensors_size
from torchinfo import summary
from pprint import pprint
import matplotlib.pyplot as plt
import math
import glob
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F

# 1. 定义点云 Dataset（读取本地点云与标签文件，GT 格式与模型输出保持一致）

In [2]:
def myfunc(batch_data):
    """Collate ``(pc, label)`` samples into a single zero-padded point tensor.

    Args:
        batch_data: list of samples; ``sample[0]`` is an object whose ``points``
            attribute is an array of shape ``(N_i, feature_dim)`` and
            ``sample[1]`` is that frame's label (e.g. a GT dict), passed through as-is.

    Returns:
        padded_points: float tensor ``[batch_size, max_N, feature_dim + 1]``.
            Column 0 is the sample's index within the batch; the remaining
            columns are the original point features. Padding rows are all zeros
            (including column 0), so use ``lengths`` to mask them out.
        labels: list of each sample's label, in batch order.
        lengths: list of each sample's original point count ``N_i``.
    """
    labels = [sample[1] for sample in batch_data]

    per_sample = []
    for sample_idx, sample in enumerate(batch_data):
        pts = torch.tensor(sample[0].points, dtype=torch.float)
        # Prepend a column holding this sample's batch index.
        idx_col = pts.new_full((pts.shape[0], 1), sample_idx)
        per_sample.append(torch.cat((idx_col, pts), dim=1))

    lengths = [t.shape[0] for t in per_sample]
    padded = pad_sequence(per_sample, batch_first=True, padding_value=0)
    return padded, labels, lengths

In [3]:
class PointCloudDataset(Dataset):
    """Lidar dataset pairing ``.npy`` point-cloud frames with ``.txt`` box labels.

    Each item is ``(PointCloud, gt_dict)``. ``gt_dict`` keeps only the boxes whose
    BEV centre falls inside the feature map. It also holds the regression targets
    (``ind``, ``reg``, ``size``, ``height``, ``rot``), a ``reg_mask`` and a
    per-class Gaussian ``heatmap`` of shape ``[num_classes, H, W]``.

    Point-cloud files and label files are each sorted by path and paired by
    position. The two directories are therefore expected to hold matching,
    identically ordered file names.
    """

    DEFAULT_PC_PATH = '/workspace/rosbags/archive/data_city/data_city/lidar_livox'
    DEFAULT_LABEL_PATH = '/workspace/rosbags/archive/data_city/data_city/label3'

    def __init__(
        self,
        voxel_size: list,          # [voxel_x, voxel_y, voxel_z] (only x and y are used)
        feature_map_stride: int,
        num_classes: int,
        pc_range: list,            # [x_min, y_min, z_min, x_max, y_max, z_max]
        debug: bool = False,       # print diagnostics and plot heatmaps
        pc_path: str = DEFAULT_PC_PATH,
        label_path: str = DEFAULT_LABEL_PATH,
        max_samples=100,           # cap on dataset length; None = use every available pair
    ):
        self.num_points = 1000  # kept for compatibility; not read by this class
        self.voxel_size = voxel_size
        self.feature_map_stride = feature_map_stride
        self.num_classes = num_classes
        self.debug = debug

        self.label_mapping = {
            "Car": 0,
            "Cyclist": 1,
            "Pedestrian": 2,
            # "Hero": 3
        }

        # Point clouds. glob() yields files in arbitrary filesystem order, so sort
        # to keep frame `idx` aligned with the label file at the same `idx`.
        npy_files = sorted(glob.glob(f"{pc_path}/*.npy"))
        self.pc_list = [np.load(file) for file in npy_files]

        # Use the preset point-cloud range.
        self.point_cloud_range = pc_range
        if self.debug:
            print(f'[DEBUG] Computed point_cloud_range: {self.point_cloud_range}')

        # Labels, sorted for the same reason as the point clouds.
        label_files = sorted(glob.glob(f"{label_path}/*.txt"))
        self.label_list = [self._parse_label_file(f) for f in label_files]

        num_pairs = min(len(self.pc_list), len(self.label_list))
        if len(self.pc_list) != len(self.label_list):
            import warnings
            warnings.warn(
                f"Found {len(self.pc_list)} point clouds but {len(self.label_list)} "
                f"label files; only the first {num_pairs} sorted pairs are used."
            )
        # Never report more samples than actually exist (avoids IndexError in __getitem__).
        self.num_samples = num_pairs if max_samples is None else min(max_samples, num_pairs)

    def _parse_label_file(self, label_file):
        """Parse one label file into ``{'gt_boxes': [K, D] float, 'gt_labels': [K] long}``.

        Each non-blank line holds whitespace-separated numbers followed by a class
        name. The numbers are expected to be ``[x, y, z, l, w, h, rot]``.
        "Hero" objects are skipped. Any other name must be a key of ``label_mapping``.
        """
        boxes, labels = [], []
        with open(label_file, "r") as file:
            for line in file:
                fields = line.split()
                if not fields:
                    continue  # tolerate blank / trailing lines
                *numbers, cur_label = fields
                if cur_label == "Hero":
                    continue
                boxes.append(torch.tensor([float(v) for v in numbers], dtype=torch.float32))
                labels.append(self.label_mapping[cur_label])
        if boxes:
            return {
                "gt_boxes": torch.stack(boxes),
                "gt_labels": torch.tensor(labels, dtype=torch.long),
            }
        return {
            "gt_boxes": torch.empty((0, 7)),
            "gt_labels": torch.empty((0,), dtype=torch.long),
        }

    def __len__(self):
        return self.num_samples

    # ---------- helpers ----------

    def get_bev_size(self):
        """Return the BEV feature-map size ``(H, W)``.

        H spans the y range and W the x range, each divided by ``voxel * stride``.
        """
        x_min, y_min, z_min, x_max, y_max, z_max = self.point_cloud_range
        vx, vy = self.voxel_size[0], self.voxel_size[1]
        stride = self.feature_map_stride
        W = round((x_max - x_min) / (vx * stride))
        H = round((y_max - y_min) / (vy * stride))
        return H, W

    def map_to_bev(self, x, y):
        """Map a metric ``(x, y)`` to BEV feature-map coordinates.

        Returns:
            bev_x, bev_y: continuous (un-rounded) BEV coordinates.
            x_int, y_int: integer grid-cell indices (floor of the above).
        """
        x_min, y_min, _, _, _, _ = self.point_cloud_range
        vx, vy = self.voxel_size[0], self.voxel_size[1]
        stride = self.feature_map_stride
        bev_x = (x - x_min) / (vx * stride)
        bev_y = (y - y_min) / (vy * stride)
        # floor, not int(): int() truncates toward zero, which would map points just
        # below the range edge (bev in (-1, 0)) to cell 0 instead of the invalid -1.
        x_int, y_int = math.floor(bev_x), math.floor(bev_y)
        return bev_x, bev_y, x_int, y_int

    def get_gaussian_patch_indices(self, center_idx, radius, H, W):
        """Clip a ``(2*radius+1)``-wide Gaussian patch centred at ``center_idx`` to the map.

        Args:
            center_idx: ``(x_int, y_int)`` grid cell of the object centre.
            radius: Gaussian radius in cells.
            H, W: feature-map height and width.

        Returns:
            ``h_x_min, h_x_max, h_y_min, h_y_max`` (slice bounds in the heatmap) and
            ``g_x_min, g_x_max, g_y_min, g_y_max`` (matching slice bounds in the patch).
        """
        x_int, y_int = center_idx
        left = x_int - radius
        right = x_int + radius + 1
        top = y_int - radius
        bottom = y_int + radius + 1

        g_x_min = max(0, -left)
        g_x_max = (2 * radius + 1) - max(0, right - W)
        g_y_min = max(0, -top)
        g_y_max = (2 * radius + 1) - max(0, bottom - H)

        h_x_min = max(0, left)
        h_x_max = min(W, right)
        h_y_min = max(0, top)
        h_y_max = min(H, bottom)

        return h_x_min, h_x_max, h_y_min, h_y_max, g_x_min, g_x_max, g_y_min, g_y_max

    def __getitem__(self, idx):
        # 1. Load the point cloud and its GT label dict.
        pc = PointCloud(self.pc_list[idx])
        gt_dict = self.label_list[idx]

        # 2. Keep only boxes whose centre falls inside the BEV grid.
        original_boxes = gt_dict['gt_boxes']
        original_labels = gt_dict['gt_labels']
        H, W = self.get_bev_size()

        valid_reg_list = []
        valid_ind_list = []
        valid_size_list = []
        valid_boxes_list = []
        valid_labels_list = []
        valid_height_list = []
        valid_rot_list = []

        for i, box in enumerate(original_boxes):
            x, y, z_val, l_box, w_box, h_box, rot_val = box.tolist()
            bev_x, bev_y, x_int, y_int = self.map_to_bev(x, y)
            # Only keep boxes whose centre cell lies in [0, W) x [0, H).
            if 0 <= x_int < W and 0 <= y_int < H:
                valid_boxes_list.append(box)
                valid_labels_list.append(original_labels[i])
                # Sub-cell offset of the centre inside its grid cell.
                valid_reg_list.append([bev_x - x_int, bev_y - y_int])
                # Flattened (row-major) index of the centre cell.
                valid_ind_list.append(y_int * W + x_int)
                size_w = l_box / (self.voxel_size[0] * self.feature_map_stride)
                size_h = w_box / (self.voxel_size[1] * self.feature_map_stride)
                valid_size_list.append([size_w, size_h])
                valid_height_list.append(z_val)
                valid_rot_list.append(rot_val)
            else:
                if self.debug:
                    print(f'[DEBUG] Box {i} with center ({x:.2f},{y:.2f}) mapped to ({x_int},{y_int}) is out of BEV range.')

        if len(valid_boxes_list) > 0:
            filtered_gt_boxes = torch.stack(valid_boxes_list)
            filtered_gt_labels = torch.tensor(valid_labels_list, dtype=torch.long)
            reg = torch.tensor(valid_reg_list, dtype=torch.float32)        # [valid_num, 2]
            ind = torch.tensor(valid_ind_list, dtype=torch.long)           # [valid_num]
            size = torch.tensor(valid_size_list, dtype=torch.float32)      # [valid_num, 2]
            reg_mask = torch.ones(len(valid_ind_list), dtype=torch.uint8)  # all targets valid
            height = torch.tensor(valid_height_list, dtype=torch.float32)  # [valid_num]
            rot = torch.tensor(valid_rot_list, dtype=torch.float32)        # [valid_num]
        else:
            filtered_gt_boxes = torch.empty((0, 7))
            filtered_gt_labels = torch.empty((0,), dtype=torch.long)
            reg = torch.empty((0, 2), dtype=torch.float32)
            ind = torch.empty((0,), dtype=torch.long)
            size = torch.empty((0, 2), dtype=torch.float32)
            reg_mask = torch.empty((0,), dtype=torch.uint8)
            height = torch.empty((0,), dtype=torch.float32)
            rot = torch.empty((0,), dtype=torch.float32)

        # 3. GT heatmap from the kept boxes only.
        filtered_heatmap = self.generate_heatmap(
            filtered_gt_boxes, filtered_gt_labels,
            point_cloud_range=self.point_cloud_range,
            voxel_size=self.voxel_size,
            feature_map_stride=self.feature_map_stride,
            num_classes=self.num_classes
        )

        final_gt = {
            'gt_boxes': filtered_gt_boxes,   # only boxes inside the BEV grid
            'gt_labels': filtered_gt_labels,
            'heatmap': filtered_heatmap,
            'ind': ind,
            'reg': reg,
            'reg_mask': reg_mask,
            'size': size,
            'height': height,
            'rot': rot
        }

        return pc, final_gt

    def gaussian2D(self, shape, sigma=1):
        """Return a 2-D Gaussian kernel of the given ``shape`` (peak value 1 at the centre)."""
        m, n = [(ss - 1.) / 2. for ss in shape]
        y, x = np.ogrid[-m:m+1, -n:n+1]
        h = np.exp(-(x * x + y * y) / (2 * sigma * sigma))
        # Zero out values that are negligible relative to the peak.
        h[h < np.finfo(h.dtype).eps * h.max()] = 0
        return h

    def gaussian_radius(self, det_size, min_overlap=0.5):
        """Gaussian radius for an object of size ``(height, width)``.

        Returns the minimum of three quadratic-root radii, each parameterised by
        ``min_overlap``.
        """
        height, width = det_size
        a1 = 1
        b1 = (height + width)
        c1 = width * height * (1 - min_overlap) / (1 + min_overlap)
        sq1 = np.sqrt(b1 ** 2 - 4 * a1 * c1)
        r1 = (b1 - sq1) / (2 * a1)

        a2 = 4
        b2 = 2 * (height + width)
        c2 = (1 - min_overlap) * width * height
        sq2 = np.sqrt(b2 ** 2 - 4 * a2 * c2)
        r2 = (b2 - sq2) / (2 * a2)

        a3 = 4 * min_overlap
        b3 = -2 * min_overlap * (height + width)
        c3 = (min_overlap - 1) * width * height
        sq3 = np.sqrt(b3 ** 2 - 4 * a3 * c3)
        r3 = (b3 + sq3) / (2 * a3)

        return min(r1, r2, r3)

    def generate_heatmap(self, gt_boxes, gt_labels, point_cloud_range, voxel_size,
                         feature_map_stride=4, num_classes=3):
        """Build a ``[num_classes, H, W]`` heatmap by splatting a Gaussian per box.

        ``H, W`` and the box-centre cells come from ``get_bev_size`` / ``map_to_bev``.
        Both use ``self.point_cloud_range`` and ``self.voxel_size``. The
        ``point_cloud_range`` argument is not read. ``voxel_size`` and
        ``feature_map_stride`` set the box footprint used for the Gaussian radius.
        Overlapping Gaussians of the same class are merged with an element-wise max.
        """
        dx, dy = voxel_size[0], voxel_size[1]
        H, W = self.get_bev_size()
        if self.debug:
            print(f'[DEBUG] BEV feature map size: H={H}, W={W}')

        heatmap = torch.zeros((num_classes, H, W))
        for i, (box, label) in enumerate(zip(gt_boxes, gt_labels)):
            x, y, _, dx_size, dy_size, _, _ = box.tolist()
            bev_x, bev_y, x_int, y_int = self.map_to_bev(x, y)
            if self.debug:
                print(f'[DEBUG] Box {i}: center=({x:.2f}, {y:.2f}), mapped BEV=({bev_x:.2f}, {bev_y:.2f}) -> (x_int, y_int)=({x_int}, {y_int}), label={label}')
            if not (0 <= x_int < W and 0 <= y_int < H):
                if self.debug:
                    print(f'[DEBUG] Box {i} is out of BEV bounds, skipped.')
                continue
            # Box footprint in feature-map cells, ordered (height along y, width along x).
            box_hw = (dy_size / dy / feature_map_stride, dx_size / dx / feature_map_stride)
            radius = self.gaussian_radius(box_hw)
            radius = max(0, int(radius))
            diameter = 2 * radius + 1
            if self.debug:
                print(f'[DEBUG] Box {i}: box_hw={box_hw}, gaussian radius={radius}, diameter={diameter}')

            gaussian = self.gaussian2D((diameter, diameter), sigma=diameter / 6)
            gaussian = torch.from_numpy(gaussian).float()

            h_x_min, h_x_max, h_y_min, h_y_max, g_x_min, g_x_max, g_y_min, g_y_max = \
                self.get_gaussian_patch_indices((x_int, y_int), radius, H, W)

            if self.debug:
                print(f'[DEBUG] Box {i}: BEV region: x[{h_x_min}:{h_x_max}], y[{h_y_min}:{h_y_max}]')
                print(f'[DEBUG] Box {i}: Gaussian patch indices: x[{g_x_min}:{g_x_max}], y[{g_y_min}:{g_y_max}]')

            masked_heatmap = heatmap[label, h_y_min:h_y_max, h_x_min:h_x_max]
            masked_gaussian = gaussian[g_y_min:g_y_max, g_x_min:g_x_max]
            if self.debug:
                print(f'[DEBUG] Box {i}: before update, heatmap sum={masked_heatmap.sum():.4f}, gaussian sum={masked_gaussian.sum():.4f}')
            if masked_gaussian.shape != masked_heatmap.shape:
                # Defensive: the clipped patch and heatmap window should always agree.
                if self.debug:
                    print(f'[DEBUG] Box {i}: Shape mismatch: heatmap patch shape {masked_heatmap.shape}, gaussian patch shape {masked_gaussian.shape}. Skipping this box.')
                continue
            heatmap[label, h_y_min:h_y_max, h_x_min:h_x_max] = torch.maximum(masked_heatmap, masked_gaussian)
            if self.debug:
                print(f'[DEBUG] Box {i}: after update, new heatmap channel {label} sum={heatmap[label].sum():.4f}')

        if self.debug:
            import matplotlib.pyplot as plt
            fig, axs = plt.subplots(1, num_classes, figsize=(5 * num_classes, 4))
            if num_classes == 1:
                axs = [axs]
            for i in range(num_classes):
                heat = heatmap[i].numpy()
                im = axs[i].imshow(heat, cmap='hot', interpolation='nearest')
                axs[i].set_title(f"Heatmap for class {i}")
                fig.colorbar(im, ax=axs[i])
            plt.suptitle("All Channel Heatmaps")
            plt.tight_layout()
            plt.show()

        return heatmap

可以看到 dataset 中每个样本的输入为一帧点云 (N,4)，GT 为一个 dict，包含 `gt_boxes`、`gt_labels`、`heatmap`、`ind`、`reg`、`reg_mask`、`size`、`height`、`rot`（只保留中心落在 BEV 范围内的 box）。

# 2. 模型：直接使用 det3d 提供的 CenterPoint 模型


In [4]:
# model = CenterPoint([0.2,0.2,0.2],[0.2,0.2,0.2,0.2,0.2,0.2])
# summary(model,input_size=(2,10, 5))

这些输出的通道数由各个预测分支的设计决定，而空间尺寸 (H, W) 由 BEV 网格决定。一般来说，CenterPoint 和类似模型会为每个任务分配固定数量的输出通道，然后对整个 BEV 特征图进行卷积预测。具体解释如下：

1. **raw_hm (热力图)**  
   - 输出尺寸为 `[3, 448, 1120]`，其中 3 表示预测类别数（本 notebook 中为 Car、Cyclist、Pedestrian，对应 dataset 中的 `label_mapping`）。因此，每个类别对应一个通道，最终形成一个形状为 `[num_classes, H, W]` 的热力图。

2. **raw_center (中心偏移)**  
   - 输出尺寸为 `[2, 448, 1120]`，这里 2 通常表示在 BEV 特征图上，每个 grid cell 内的预测中心坐标（x 和 y 的残差）。即从离散的网格中心到真实中心的偏移。

3. **raw_center_z (高度/中心z值)**  
   - 输出尺寸为 `[1, 448, 1120]`，只有 1 个通道，用于预测目标在 z 轴方向上的位置（或者说目标的高度信息）。

4. **raw_dim (尺寸)**  
   - 输出尺寸为 `[3, 448, 1120]`，这里 3 个通道分别对应目标的长、宽、高（l, w, h）的预测值，通常预测的是对数尺度（后续可能会做 exp 处理）或者直接在 BEV 中的尺寸。

5. **raw_rot (旋转角度)**  
   - 输出尺寸为 `[2, 448, 1120]`，通常用两个通道分别表示旋转角度的 sin 和 cos 值，这样可以避免角度回归带来的周期性问题。

**总结：**  
- 每个分支的通道数即对应任务的输出数量。  
- 448 和 1120 是 BEV 特征图的高度和宽度（由点云范围、voxel 大小和下采样系数决定）：H = (44.8 − (−44.8)) / (0.2 × 1) = 448，W = (224 − 0) / (0.2 × 1) = 1120。  
- 例如，hm 分支输出 3 个通道对应 3 个类别，中心偏移输出 2 个通道（x,y），中心高度（z）输出 1 个通道，尺寸输出 3 个通道，而旋转输出 2 个通道。

# 3. 定义损失函数

In [5]:
def focal_loss(pred, gt, alpha=2, beta=4):
    """Penalty-reduced pixel-wise focal loss for heatmaps.

    Args:
        pred: predicted heatmap, already passed through a sigmoid.
        gt: target heatmap (elementwise-compatible with ``pred``, e.g.
            ``[num_classes, H, W]``); cells exactly equal to 1 are positives.
        alpha: exponent on the prediction error for both positives and negatives.
        beta: exponent on ``(1 - gt)`` that down-weights negatives near a positive.

    Returns:
        Scalar tensor: ``(pos_loss + neg_loss) / num_pos``, or just the summed
        negative loss when there are no positive cells.

    Raises:
        AssertionError: if ``pred`` or ``gt`` is not a ``torch.Tensor``.
    """
    assert isinstance(pred, torch.Tensor)
    assert isinstance(gt, torch.Tensor)

    eps = 1e-4
    is_pos = gt.eq(1).float()
    is_neg = gt.lt(1).float()

    pos_term = -torch.log(pred + eps) * (1 - pred).pow(alpha) * is_pos
    neg_term = -torch.log(1 - pred + eps) * pred.pow(alpha) * (1 - gt).pow(beta) * is_neg

    n_pos = is_pos.sum()
    if n_pos == 0:
        return neg_term.sum()
    return (pos_term.sum() + neg_term.sum()) / n_pos

In [6]:
def regression_loss(pred_dicts, gt_dicts):
    """L1 losses for the BEV centre offset and the centre height, averaged over the batch.

    Args:
        pred_dicts: sequence of per-sample prediction dicts with ``raw_center``
            (2 channels over the BEV grid, e.g. ``[2, H, W]``) and ``raw_center_z``
            (1 channel, e.g. ``[1, H, W]``).
        gt_dicts: sequence of per-sample GT dicts, indexed in parallel with
            ``pred_dicts``, holding ``ind`` ``[N]`` (flat row-major cell indices),
            ``reg`` ``[N, 2]``, ``height`` ``[N]`` and ``reg_mask`` ``[N]``.

    Returns:
        ``(avg_offset_loss, avg_height_loss)``. For each sample, the L1 sum is taken
        over targets with ``reg_mask != 0`` and divided by their count (+1e-4).
        The per-sample losses are then averaged over the batch.

    Raises:
        ValueError: if ``pred_dicts`` is empty.
    """
    batch_count = len(pred_dicts)
    if batch_count == 0:
        raise ValueError("regression_loss requires at least one prediction dict")

    total_offset_loss = 0.0
    total_height_loss = 0.0

    for i in range(batch_count):
        pred = pred_dicts[i]
        gt = gt_dicts[i]

        # reshape (not view) so non-contiguous head outputs, e.g. channel slices, also work.
        raw_center_flat = pred["raw_center"].reshape(2, -1).transpose(0, 1)      # [H*W, 2]
        raw_center_z_flat = pred["raw_center_z"].reshape(1, -1).transpose(0, 1)  # [H*W, 1]

        indices = gt["ind"].to(torch.long)
        if indices.numel() > 0:
            pred_offset = raw_center_flat[indices]    # [N, 2]
            pred_height = raw_center_z_flat[indices]  # [N, 1]

            gt_offset = gt["reg"].to(pred_offset.dtype)                     # [N, 2]
            gt_height = gt["height"].to(pred_height.dtype).unsqueeze(-1)    # [N, 1]
            # Zero out masked-off targets so they add nothing to the sum; this is
            # what normalising by the number of valid targets assumes.
            mask = gt["reg_mask"].to(pred_offset.dtype).unsqueeze(-1)       # [N, 1]
            num_valid = mask.sum() + 1e-4

            offset_loss = F.l1_loss(pred_offset * mask, gt_offset * mask, reduction="sum") / num_valid
            height_loss = F.l1_loss(pred_height * mask, gt_height * mask, reduction="sum") / num_valid
        else:
            # No targets: a zero loss still connected to the graph so backward() works.
            offset_loss = raw_center_flat[0].sum() * 0.0
            height_loss = raw_center_z_flat[0].sum() * 0.0

        total_offset_loss += offset_loss
        total_height_loss += height_loss

    return total_offset_loss / batch_count, total_height_loss / batch_count

In [7]:
def compute_loss(pred_dicts, gt_dicts, use_focal_loss=False):
    """Total training loss for one batch of per-sample prediction / GT dicts.

    The regression terms (centre offset + centre height) come from
    ``regression_loss``, which already averages over all samples in the batch.
    It is therefore called exactly once rather than once per sample.

    Args:
        pred_dicts: list of per-sample prediction dicts (``raw_center``,
            ``raw_center_z``, and ``raw_hm`` when ``use_focal_loss`` is True).
        gt_dicts: list of per-sample GT dicts (``ind``, ``reg``, ``height``,
            ``reg_mask``, and ``heatmap`` when ``use_focal_loss`` is True).
        use_focal_loss: if True, also add the heatmap focal loss
            (``sigmoid(raw_hm)`` vs. ``heatmap``), averaged over samples.
            Defaults to False, i.e. the regression-only loss.

    Returns:
        Scalar loss tensor.
    """
    xy_loss, z_loss = regression_loss(pred_dicts, gt_dicts)
    total_loss = xy_loss + z_loss

    if use_focal_loss:
        hm_losses = [
            focal_loss(torch.sigmoid(pred["raw_hm"]), gt["heatmap"])
            for pred, gt in zip(pred_dicts, gt_dicts)
        ]
        total_loss = total_loss + sum(hm_losses) / len(hm_losses)

    return total_loss

# 4. 训练流程

In [None]:
# ---- Configuration ----
voxel_size = [0.2, 0.2, 0.2]
feature_map_stride = 1
num_classes = 3
batch_size = 1
pc_range = [0, -44.8, -2, 224, 44.8, 4]   # [x_min, y_min, z_min, x_max, y_max, z_max]
seed = 42                                 # fixes the train/test split and shuffle order

torch.manual_seed(seed)

dataset = PointCloudDataset(debug=False,
                            voxel_size=voxel_size,
                            feature_map_stride=feature_map_stride,
                            pc_range=pc_range,
                            num_classes=num_classes)

# 80/20 train/test split; a seeded generator makes it reproducible across runs.
train_size = int(len(dataset) * 0.8)
test_size = len(dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(
    dataset, [train_size, test_size], generator=torch.Generator().manual_seed(seed)
)

# Shuffle the training set every epoch; keep evaluation order fixed.
trainloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=myfunc, drop_last=True)

testloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=myfunc, drop_last=True)

# Use the GPU when present, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = CenterPoint(
    voxel_size=voxel_size,
    pc_range=pc_range,
    feature_map_stride=feature_map_stride,
    ).to(device)

if device.type == "cuda":
    torch.cuda.empty_cache()

optimizer = optim.Adam(model.parameters(), lr=1e-3)

num_epochs = 100

In [None]:
# Sanity check: shape of the filtered GT boxes of the first sample (num_kept_boxes, 7).
dataset[0][1]["gt_boxes"].shape

In [None]:
# for i in trainloader:
#     pc_, gt_,len_ = i
#     print(type(pc_))
#     print(type(gt_))
#     print(pc_.size())
#     print(pc_[:,:3,:].size())
#     print(len(gt_))
#     pprint(gt_[0])
#     print(type(gt_[0]['gt_boxes']))
#     break

In [None]:
for epoch in range(num_epochs):
    model.train()  # training mode
    running_loss = 0.0

    for i, (pointclouds, gt_dicts, lengths) in enumerate(trainloader):
        # Move inputs and targets to the model's device (works on CPU-only machines too).
        # `pointclouds` is the padded tensor built by `myfunc`; every GT value is a tensor.
        pointclouds = pointclouds.to(device)
        gt_dicts = [{key: value.to(device) for key, value in gt.items()} for gt in gt_dicts]

        optimizer.zero_grad()

        # Forward pass
        pred_dicts = model(pointclouds)

        # Loss, backward pass, parameter update
        loss = compute_loss(pred_dicts, gt_dicts)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

        if (i + 1) % 10 == 0:
            print(f"Epoch [{epoch + 1}/{num_epochs}], Step [{i + 1}/{len(trainloader)}], Loss: {running_loss / (i + 1):.4f}")
    # max(..., 1) guards against an empty loader (e.g. fewer samples than batch_size with drop_last).
    print(f"Epoch [{epoch + 1}/{num_epochs}], Average Loss: {running_loss / max(len(trainloader), 1):.4f}")