In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
from collections import defaultdict
from det3d.models.centerpoint import CenterPoint
from det3d.types.pointcloud import PointCloud
from det3d.utils import move_to_gpu, move_to_cpu, print_dict_tensors_size
from torchinfo import summary
from pprint import pprint
import matplotlib.pyplot as plt
import math
import glob
from torch.nn.utils.rnn import pad_sequence

# 1. 定义一个Toy Dataset（与模型输出保持一致）

In [2]:
def myfunc(batch_data):
    """
    batch_data: list of samples, 每个样本格式为 (pc, label)
        其中 pc 为一个对象，其属性 points 为 numpy array (N_i, feature_dim)
        label 为该帧对应的标签数据（例如 gt dict 等）。
    
    返回：
        padded_points: shape [batch_size, max_num_points, feature_dim+1]
           —— 第一列为 batch_index, 后面的列为原始特征
        labels: 列表形式保存每个样本的 label
        lengths: 每帧原始点数，方便后续做 mask
    """
    point_clouds = []
    labels = []
    lengths = []
    # 遍历每个样本，同时记录下样本在 batch 中的 index
    for batch_idx, sample in enumerate(batch_data):
        # sample[0].points 是一个 numpy array，形状为 (N_i, feature_dim)
        points_tensor = torch.tensor(sample[0].points, dtype=torch.float)
        lengths.append(points_tensor.shape[0])
        labels.append(sample[1])
        # 为当前样本所有的点创建 batch index 列
        batch_idx_tensor = torch.full((points_tensor.size(0), 1), batch_idx, dtype=torch.float)
        # 拼接，得到新的 tensor，形状为 (N_i, feature_dim + 1)
        points_tensor = torch.cat([batch_idx_tensor, points_tensor], dim=1)
        point_clouds.append(points_tensor)
    
    # 使用 pad_sequence 对齐所有点云，padding_value=0
    padded_points = pad_sequence(point_clouds, batch_first=True, padding_value=0)
    
    return padded_points, labels, lengths

In [3]:
class PointCloudDataset(Dataset):
    def __init__(
        self,
        voxel_size: list,          # [voxel_x, voxel_y, voxel_z] (这里只使用 x, y)
        feature_map_stride: int,
        num_classes: int,
        pc_range:list,
        debug: bool = False        # 调试模式，默认 False
    ):
        self.num_samples = 100
        self.num_points = 1000
        self.voxel_size = voxel_size
        self.feature_map_stride = feature_map_stride
        self.num_classes = num_classes
        self.debug = debug
        
        self.label_mapping = {
            "Car": 0,
            "Cyclist": 1,
            "Pedestrian": 2,
            # "Hero": 3
        }
        
        # ------------------------------
        # 读取本地点云数据
        # ------------------------------
        pc_path = '/workspace/rosbags/archive/data_city/data_city/lidar_livox'
        npy_files = glob.glob(f"{pc_path}/*.npy")
        self.pc_list = [np.load(file) for file in npy_files]
        
        # 根据所有加载的点云数据计算全局的点云范围
        # 假设每个点云数组的形状为 (N, 4)，其中前三列分别为 x, y, z
        # all_points = np.concatenate(self.pc_list, axis=0)
        # x_min = np.min(all_points[:, 0])
        # y_min = np.min(all_points[:, 1])
        # z_min = np.min(all_points[:, 2])
        # x_max = np.max(all_points[:, 0])
        # y_max = np.max(all_points[:, 1])
        # z_max = np.max(all_points[:, 2])
        # self.point_cloud_range = [x_min, y_min, z_min, x_max, y_max, z_max]
        self.point_cloud_range = pc_range
        if self.debug:
            print(f'[DEBUG] Computed point_cloud_range: {self.point_cloud_range}')
        
        # ------------------------------
        # 读取本地标签数据
        # ------------------------------
        label_path = '/workspace/rosbags/archive/data_city/data_city/label3'
        labels = glob.glob(f"{label_path}/*.txt")
        label_list = []
        for label_file in labels:
            gt_dict = dict(
                gt_boxes = [],
                gt_labels = [],
            )
            with open(label_file, "r") as file:
                lines = file.readlines()
                for line in lines:
                    raw_line_list = line.split(" ")
                    # 前7个数字对应 [x, y, z, l, w, h, rot]
                    xyzlwhr = torch.Tensor([float(x) for x in raw_line_list[:-1]])
                    cur_label = raw_line_list[-1].replace('\n','').strip()
                    if cur_label == "Hero":
                        continue
                    gt_dict["gt_boxes"].append(xyzlwhr)
                    gt_dict["gt_labels"].append(self.label_mapping[cur_label])
            if len(gt_dict["gt_boxes"]) > 0:
                gt_dict["gt_boxes"] = torch.stack(gt_dict["gt_boxes"])
                gt_dict["gt_labels"] = torch.tensor(gt_dict["gt_labels"], dtype=torch.long)
            else:
                gt_dict["gt_boxes"] = torch.empty((0, 7))
                gt_dict["gt_labels"] = torch.empty((0,), dtype=torch.long)
            label_list.append(gt_dict)
        self.label_list = label_list

    def __len__(self):
        return self.num_samples

    # ---------- 辅助函数 ----------
    def get_bev_size(self):
        """
        计算 BEV 特征图尺寸 (H, W)。
        """
        x_min, y_min, z_min, x_max, y_max, z_max = self.point_cloud_range
        vx, vy = self.voxel_size[0], self.voxel_size[1]
        stride = self.feature_map_stride
        W = round((x_max - x_min) / (vx * stride))
        H = round((y_max - y_min) / (vy * stride))
        return H, W

    def map_to_bev(self, x, y):
        """
        将 (x, y) 坐标映射到 BEV 特征图中，返回：
        - bev_x, bev_y: 浮点数坐标
        - x_int, y_int: 对应的整数索引（grid cell 坐标）
        """
        x_min, y_min, _, _, _, _ = self.point_cloud_range
        vx, vy = self.voxel_size[0], self.voxel_size[1]
        stride = self.feature_map_stride
        bev_x = (x - x_min) / (vx * stride)
        bev_y = (y - y_min) / (vy * stride)
        x_int, y_int = int(bev_x), int(bev_y)
        return bev_x, bev_y, x_int, y_int

    def get_gaussian_patch_indices(self, center_idx, radius, H, W):
        """
        根据中心坐标 center_idx = (x_int, y_int) 以及高斯核 radius，
        计算 BEV 区域及 Gaussian patch 索引范围。
        返回：
            h_x_min, h_x_max, h_y_min, h_y_max,
            g_x_min, g_x_max, g_y_min, g_y_max
        """
        x_int, y_int = center_idx
        left = x_int - radius
        right = x_int + radius + 1
        top = y_int - radius
        bottom = y_int + radius + 1

        g_x_min = max(0, -left)
        g_x_max = (2 * radius + 1) - max(0, right - W)
        g_y_min = max(0, -top)
        g_y_max = (2 * radius + 1) - max(0, bottom - H)

        h_x_min = max(0, left)
        h_x_max = min(W, right)
        h_y_min = max(0, top)
        h_y_max = min(H, bottom)

        return h_x_min, h_x_max, h_y_min, h_y_max, g_x_min, g_x_max, g_y_min, g_y_max

    def __getitem__(self, idx):
        # 1. 加载点云与对应 GT 标签字典
        pc = PointCloud(self.pc_list[idx])
        gt_dict = self.label_list[idx]

        # 2. 生成 heatmap
        gt_boxes_tensor = gt_dict['gt_boxes']  # [num_boxes, 7]
        gt_labels_tensor = gt_dict['gt_labels']  # [num_boxes]
        heatmap = self.generate_heatmap(
            gt_boxes_tensor, gt_labels_tensor,
            point_cloud_range=self.point_cloud_range,
            voxel_size=self.voxel_size,
            feature_map_stride=self.feature_map_stride,
            num_classes=self.num_classes
        )

        # 3. 计算 BEV 特征图尺寸
        H, W = self.get_bev_size()

        # 4. 根据每个 GT 框计算回归目标
        ind_list = []
        reg_list = []
        size_list = []
        for box in gt_boxes_tensor:
            # box 格式: [x, y, z, l, w, h, rot]
            x, y, z_val, l_box, w_box, h_box, rot_val = box.tolist()
            bev_x, bev_y, x_int, y_int = self.map_to_bev(x, y)
            # 记录偏移（offset）
            offset_x = bev_x - x_int
            offset_y = bev_y - y_int
            reg_list.append([offset_x, offset_y])
            # 计算 flatten 后的索引
            ind = y_int * W + x_int
            ind_list.append(ind)
            # 计算 BEV 上的尺寸回归目标 (l, w)
            size_w = l_box / (self.voxel_size[0] * self.feature_map_stride)
            size_h = w_box / (self.voxel_size[1] * self.feature_map_stride)
            size_list.append([size_w, size_h])
        if len(ind_list) > 0:
            reg = torch.tensor(reg_list, dtype=torch.float32)     # [num_boxes, 2]
            ind = torch.tensor(ind_list, dtype=torch.long)          # [num_boxes]
            size = torch.tensor(size_list, dtype=torch.float32)     # [num_boxes, 2]
            reg_mask = torch.ones(len(ind_list), dtype=torch.uint8) # 回归 mask: 有效GT标记
            height = gt_boxes_tensor[:, 2]  # [num_boxes] z 值
            rot = gt_boxes_tensor[:, 6]     # [num_boxes] 旋转角
        else:
            reg = torch.empty((0, 2), dtype=torch.float32)
            ind = torch.empty((0,), dtype=torch.long)
            size = torch.empty((0, 2), dtype=torch.float32)
            reg_mask = torch.empty((0,), dtype=torch.uint8)
            height = torch.empty((0,), dtype=torch.float32)
            rot = torch.empty((0,), dtype=torch.float32)

        # 5. 构造返回的 GT 字典
        final_box_dicts = {
            'gt_boxes': gt_boxes_tensor,   # [num_boxes, 7]
            'gt_labels': gt_labels_tensor, # [num_boxes]
            'heatmap': heatmap,            # [num_classes, H, W]
            'ind': ind,                    # [num_boxes]
            'reg': reg,                    # [num_boxes, 2]
            'reg_mask': reg_mask,          # [num_boxes]
            'size': size,                  # [num_boxes, 2]
            'height': height,              # [num_boxes]
            'rot': rot                     # [num_boxes]
        }
        
        return pc, final_box_dicts

    def gaussian2D(self, shape, sigma=1):
        """生成 2D 高斯核矩阵"""
        m, n = [(ss - 1.) / 2. for ss in shape]
        y, x = np.ogrid[-m:m+1, -n:n+1]
        h = np.exp(-(x * x + y * y) / (2 * sigma * sigma))
        h[h < np.finfo(h.dtype).eps * h.max()] = 0
        return h

    def gaussian_radius(self, det_size, min_overlap=0.5):
        """计算高斯核半径（基于目标尺寸）"""
        height, width = det_size
        a1 = 1
        b1 = (height + width)
        c1 = width * height * (1 - min_overlap) / (1 + min_overlap)
        sq1 = np.sqrt(b1 ** 2 - 4 * a1 * c1)
        r1 = (b1 - sq1) / (2 * a1)
        
        a2 = 4
        b2 = 2 * (height + width)
        c2 = (1 - min_overlap) * width * height
        sq2 = np.sqrt(b2 ** 2 - 4 * a2 * c2)
        r2 = (b2 - sq2) / (2 * a2)
        
        a3 = 4 * min_overlap
        b3 = -2 * min_overlap * (height + width)
        c3 = (min_overlap - 1) * width * height
        sq3 = np.sqrt(b3 ** 2 - 4 * a3 * c3)
        r3 = (b3 + sq3) / (2 * a3)
        
        return min(r1, r2, r3)

    def generate_heatmap(self, gt_boxes, gt_labels, point_cloud_range, voxel_size,
                         feature_map_stride=4, num_classes=3):
        """
        根据 gt_boxes 和 gt_labels 生成 shape 为 [num_classes, H, W] 的 heatmap。
        """
        x_min, y_min, z_min, x_max, y_max, z_max = point_cloud_range
        dx, dy = voxel_size[0], voxel_size[1]
        H, W = self.get_bev_size()
        if self.debug:
            print(f'[DEBUG] BEV feature map size: H={H}, W={W}')

        heatmap = torch.zeros((num_classes, H, W))
        for i, (box, label) in enumerate(zip(gt_boxes, gt_labels)):
            x, y, _, dx_size, dy_size, _, _ = box.tolist()
            bev_x, bev_y, x_int, y_int = self.map_to_bev(x, y)
            if self.debug:
                print(f'[DEBUG] Box {i}: center=({x:.2f}, {y:.2f}), mapped BEV=({bev_x:.2f}, {bev_y:.2f}) -> (x_int, y_int)=({x_int}, {y_int}), label={label}')
            if not (0 <= x_int < W and 0 <= y_int < H):
                if self.debug:
                    print(f'[DEBUG] Box {i} is out of BEV bounds, skipped.')
                continue
            # 计算 box 在 BEV 上的尺寸 (注意顺序与 voxel_size 保持一致)
            box_hw = (dy_size / dy / feature_map_stride, dx_size / dx / feature_map_stride)
            radius = self.gaussian_radius(box_hw)
            radius = max(0, int(radius))
            diameter = 2 * radius + 1
            if self.debug:
                print(f'[DEBUG] Box {i}: box_hw={box_hw}, gaussian radius={radius}, diameter={diameter}')
            
            gaussian = self.gaussian2D((diameter, diameter), sigma=diameter / 6)
            gaussian = torch.from_numpy(gaussian).float()
            
            h_x_min, h_x_max, h_y_min, h_y_max, g_x_min, g_x_max, g_y_min, g_y_max = \
                self.get_gaussian_patch_indices((x_int, y_int), radius, H, W)
            
            if self.debug:
                print(f'[DEBUG] Box {i}: BEV region: x[{h_x_min}:{h_x_max}], y[{h_y_min}:{h_y_max}]')
                print(f'[DEBUG] Box {i}: Gaussian patch indices: x[{g_x_min}:{g_x_max}], y[{g_y_min}:{g_y_max}]')
            
            masked_heatmap = heatmap[label, h_y_min:h_y_max, h_x_min:h_x_max]
            masked_gaussian = gaussian[g_y_min:g_y_max, g_x_min:g_x_max]
            if self.debug:
                print(f'[DEBUG] Box {i}: before update, heatmap sum={masked_heatmap.sum():.4f}, gaussian sum={masked_gaussian.sum():.4f}')
            if masked_gaussian.shape != masked_heatmap.shape:
                if self.debug:
                    print(f'[DEBUG] Box {i}: Shape mismatch: heatmap patch shape {masked_heatmap.shape}, gaussian patch shape {masked_gaussian.shape}. Skipping this box.')
                continue
            heatmap[label, h_y_min:h_y_max, h_x_min:h_x_max] = torch.maximum(masked_heatmap, masked_gaussian)
            if self.debug:
                print(f'[DEBUG] Box {i}: after update, new heatmap channel {label} sum={heatmap[label].sum():.4f}')
        
        if self.debug:
            fig, axs = plt.subplots(1, num_classes, figsize=(5 * num_classes, 4))
            if num_classes == 1:
                axs = [axs]
            for i in range(num_classes):
                heat = heatmap[i].numpy()
                im = axs[i].imshow(heat, cmap='hot', interpolation='nearest')
                axs[i].set_title(f"Heatmap for class {i}")
                fig.colorbar(im, ax=axs[i])
            plt.suptitle("All Channel Heatmaps")
            plt.tight_layout()
            plt.show()
        
        return heatmap

可以看到dataset中每个data的input为一帧点云 (N,4), gt为boxes,labels,scores

In [4]:
# voxel_size = [0.2, 0.2, 8]
# feature_map_stride = 4
# num_classes = 3

# dataset = PointCloudDataset(debug=True,
#                             voxel_size=voxel_size,
#                             feature_map_stride=feature_map_stride,
#                             num_classes=num_classes)

# # 划分训练集和测试集
# train_size = int(len(dataset) * 0.8)
# test_size = len(dataset) - train_size
# train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])


# # Dataloader
# trainloader = DataLoader(train_dataset, batch_size=2,shuffle=True,collate_fn=myfunc,drop_last=True)

# testloader = DataLoader(test_dataset, batch_size=2,shuffle=True,collate_fn=myfunc,drop_last=True)

# for i in trainloader:
#     pc_, gt_,len_ = i
#     print(type(pc_))
#     print(type(gt_))
#     print(pc_.size())
#     print(len(gt_))
#     pprint(gt_[0])
#     print(type(gt_[0]['gt_boxes']))
#     break

In [5]:
# pc,gt = dataset[0]
# pc

In [6]:
# gt

In [7]:
# gt['gt_boxes'].size() # [#boxes,7]

In [8]:
# gt['gt_labels'].size() #[boxes]

In [9]:
# gt['heatmap'].size() # [#classes, H, W]

In [10]:
# gt['height'].size()

In [11]:
# gt['ind'].size() # 每个框的中心点在热力图上的坐标

In [12]:
# gt['reg_mask'].size()

In [13]:
# gt['rot'].size()

# 2. 定义一个简单的Toy模型，使用提供的CenterPointModel作为输出


In [14]:
# class ToyModel(nn.Module):
#     def __init__(self, voxel_size, pc_range):
#         super(ToyModel, self).__init__()
#         # 使用预训练模型（例如 CenterPointModel）
#         self.centerpoint_model = CenterPointModel(voxel_size=voxel_size,pc_range=pc_range)

#     def forward(self, x):
#         temp = dict(
#             batch_size=1,
#             points=x
#         )
#         return self.centerpoint_model.forward(temp)

In [15]:
# model = CenterPoint([0.2,0.2,0.2],[0.2,0.2,0.2,0.2,0.2,0.2])
# summary(model,input_size=(2,10, 5))

In [16]:
# # 设置随机点云数据：每个点的格式为 [batch_index, x, y, z, intensity]
# # 确保 batch_index 取值在 [0, batch_size)，这里 batch_size=1，因此所有点的 batch_index 应为 0
# num_points = 43201
# batch_size = 1
# # 假设 point_cloud_range 为 [0, -44.8, -2, 224, 44.8, 4]
# x_vals = torch.rand(num_points) * (224 - 0) + 0
# y_vals = torch.rand(num_points) * (44.8 - (-44.8)) + (-44.8)
# z_vals = torch.rand(num_points) * (4 - (-2)) + (-2)
# intensity = torch.rand(num_points)
# batch_idx = torch.zeros(num_points)  # 所有点所属 batch 为 0

# # 拼接成 (num_points, 5) 的 tensor
# toy_points = torch.stack([batch_idx, x_vals, y_vals, z_vals, intensity], dim=1)

# # 构造 toy 模型
# toy_model = CenterPoint(voxel_size=None, pc_range=None)

# # 将模型与输入数据放到同一设备：优先使用 GPU（如果可用）
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# toy_model.to(device)
# toy_points = toy_points.to(device)

# pprint(toy_points)
# pprint(toy_points.size())

# # 前向传播
# toy_model.eval()
# with torch.no_grad():
#     outputs = toy_model(toy_points)

# # 输出结果
# print("输出的结果：")

In [17]:
# type(toy_points.size()[-1])

In [18]:
# toy_points.shape

In [19]:
# pprint(outputs)

In [20]:
# print(outputs[0]['raw_hm'].size())
# print(outputs[0]['raw_center'].size())
# print(outputs[0]['raw_center_z'].size())
# print(outputs[0]['raw_dim'].size())
# print(outputs[0]['raw_rot'].size())

这些尺寸实际上正是由各个预测分支的设计决定的。一般来说，CenterPoint 和类似模型会为每个任务分配固定数量的输出通道，然后对整个 BEV 特征图进行卷积预测。具体解释如下：

1. **raw_hm (热力图)**  
   - 输出尺寸为 `[3, 448, 1120]`，其中 3 表示预测类别数（这里类别数为 3，比如 Vehicle、Pedestrian、Cyclist）。因此，每个类别对应一个通道，最终形成一个形状为 `[num_classes, H, W]` 的热力图。

2. **raw_center (中心偏移)**  
   - 输出尺寸为 `[2, 448, 1120]`，这里 2 通常表示在 BEV 特征图上，每个 grid cell 内的预测中心坐标（x 和 y 的残差）。即从离散的网格中心到真实中心的偏移。

3. **raw_center_z (高度/中心z值)**  
   - 输出尺寸为 `[1, 448, 1120]`，只有 1 个通道，用于预测目标在 z 轴方向上的位置（或者说目标的高度信息）。

4. **raw_dim (尺寸)**  
   - 输出尺寸为 `[3, 448, 1120]`，这里 3 个通道分别对应目标的长、宽、高（l, w, h）的预测值，通常预测的是对数尺度（后续可能会做 exp 处理）或者直接在 BEV 中的尺寸。

5. **raw_rot (旋转角度)**  
   - 输出尺寸为 `[2, 448, 1120]`，通常用两个通道分别表示旋转角度的 sin 和 cos 值，这样可以避免角度回归带来的周期性问题。

**总结：**  
- 每个分支的通道数即对应任务的输出数量。  
- 448 和 1120 是 BEV 特征图的高度和宽度（由点云范围、voxel 大小和下采样系数决定）。  
- 例如，hm 分支输出 3 个通道对应 3 个类别，中心偏移输出 2 个通道（x,y），中心高度（z）输出 1 个通道，尺寸输出 3 个通道，而旋转输出 2 个通道。

# 3. 定义损失函数

In [21]:
def compute_loss(pred_dicts, gt_dicts):
    """
    计算框的损失（回归损失）和分类损失
    pred_dicts: 模型的输出
    gt_dicts: ground truth
    """
    assert(isinstance(pred_dicts,list))
    assert(isinstance(gt_dicts, list))
    assert(len(pred_dicts) == len(gt_dicts))
    # print(f'Batch size: {len(pred_dicts)}')
    
    batch_loss = 0
    batch_size = len(pred_dicts)
    
    for i in range(batch_size):
        print(f"计算第{i+1}个batch的loss ...")
    
    
        print("pred_dicts keys:")
        print(pred_dicts[i].keys())
        
        print("gt_dict keys:")
        print(gt_dicts[i].keys())
        
        print_dict_tensors_size(pred_dicts[i])
        print_dict_tensors_size(gt_dicts[i])
    
    
    return 0


    box_loss = 0.0
    cls_loss = 0.0
    num_boxes = len(pred_dicts[0]['pred_boxes'])  # 假设每个batch只有1个frame
    
    pred_boxes = pred_dicts[0]['pred_boxes']
    pred_labels = pred_dicts[0]['pred_labels']
    pred_scores = pred_dicts[0]['pred_scores']
    
    gt_boxes = gt_dicts[0]['gt_boxes']
    gt_labels = gt_dicts[0]['gt_labels']
    
    # 计算回归损失 (L1 loss)
    for i in range(num_boxes):
        pred_box = pred_boxes[i]
        for j in range(len(gt_boxes)):
            gt_box = gt_boxes[j]
            box_loss += torch.sum(torch.abs(pred_box - gt_box))  # L1 loss
    
    # 计算分类损失
    cls_loss = nn.CrossEntropyLoss()(pred_labels, gt_labels)
    
    return box_loss + cls_loss

In [22]:
# CELoss = nn.CrossEntropyLoss()

In [23]:
# temp_loss = CELoss(torch.tensor([[3, 1, 1,2]],dtype=torch.float32),torch.tensor([0],dtype=torch.long))
# temp_loss

In [24]:
# temp_loss.requires_grad

# 4. 训练流程

In [25]:
voxel_size = [0.2, 0.2, 0.2]
feature_map_stride = 1
num_classes = 3
batch_size = 1
pc_range = [0, -44.8, -2, 224, 44.8, 4]

dataset = PointCloudDataset(debug=False,
                            voxel_size=voxel_size,
                            feature_map_stride=feature_map_stride,
                            pc_range=pc_range,
                            num_classes=num_classes)

# 划分训练集和测试集
train_size = int(len(dataset) * 0.8)
test_size = len(dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])


# Dataloader
trainloader = DataLoader(train_dataset, batch_size=batch_size,shuffle=False,collate_fn=myfunc,drop_last=True)

testloader = DataLoader(test_dataset, batch_size=batch_size,shuffle=False,collate_fn=myfunc,drop_last=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # 这行注释掉

model = CenterPoint(
    voxel_size=voxel_size,
    pc_range=pc_range,
    feature_map_stride=feature_map_stride,
    ).to(device)
torch.cuda.empty_cache()  

optimizer = optim.Adam(model.parameters(), lr=1e-3)

num_epochs = 1


In [26]:
# for i in trainloader:
#     pc_, gt_,len_ = i
#     print(type(pc_))
#     print(type(gt_))
#     print(pc_.size())
#     print(pc_[:,:3,:].size())
#     print(len(gt_))
#     pprint(gt_[0])
#     print(type(gt_[0]['gt_boxes']))
#     break

In [27]:
for epoch in range(num_epochs):
    model.train()  # 训练模式
    running_loss = 0.0
    
    for i, (pointclouds, gt_dicts, lengths) in enumerate(trainloader):
        # 将点云数据和对应的标签移到 CPU 上（如果不调用 .cuda() 则本来就在 CPU）
        pointclouds = pointclouds.float().to(device)
        gt_dicts = move_to_cpu(gt_dicts)  # 此函数需要保证将所有 tensor 移到 CPU；如果没有可以直接使用原数据
        
        optimizer.zero_grad()  # 清除梯度
        
        # pprint(pointclouds)
        # pprint(pointclouds[0].size())
        
        # 前向传播
        pred_dicts = model(pointclouds)
        
        # pprint(type(pred_dicts))
        # pprint(type(pred_dicts[0]))
        # print(pred_dicts[0]['pred_boxes'].requires_grad)
        # pprint(pred_dicts[0])
        # pprint(gt_dicts[0])
        
        
        
        # 计算损失
        loss = compute_loss(pred_dicts, gt_dicts)
        # loss.backward()  # 反向传播
        # optimizer.step()  # 更新参数
        
        # running_loss += loss.item()
        
        if (i + 1) % 10 == 0:
            print(f"Epoch [{epoch + 1}/{num_epochs}], Step [{i + 1}/{len(trainloader)}], Loss: {running_loss / (i + 1):.4f}")
    print(f"Epoch [{epoch + 1}/{num_epochs}], Average Loss: {running_loss / len(trainloader):.4f}")

计算第1个batch的loss ...
pred_dicts keys:
dict_keys(['pred_boxes', 'pred_scores', 'pred_labels', 'raw_hm', 'raw_center', 'raw_center_z', 'raw_dim', 'raw_rot'])
gt_dict keys:
dict_keys(['gt_boxes', 'gt_labels', 'heatmap', 'ind', 'reg', 'reg_mask', 'size', 'height', 'rot'])
Size of pred_boxes: [51, 7]
Size of pred_scores: [51]
Size of pred_labels: [51]
Size of raw_hm: [3, 448, 1120]
Size of raw_center: [2, 448, 1120]
Size of raw_center_z: [1, 448, 1120]
Size of raw_dim: [3, 448, 1120]
Size of raw_rot: [2, 448, 1120]
Size of gt_boxes: [256, 7]
Size of gt_labels: [256]
Size of heatmap: [3, 448, 1120]
Size of ind: [256]
Size of reg: [256, 2]
Size of reg_mask: [256]
Size of size: [256, 2]
Size of height: [256]
Size of rot: [256]
计算第1个batch的loss ...
pred_dicts keys:
dict_keys(['pred_boxes', 'pred_scores', 'pred_labels', 'raw_hm', 'raw_center', 'raw_center_z', 'raw_dim', 'raw_rot'])
gt_dict keys:
dict_keys(['gt_boxes', 'gt_labels', 'heatmap', 'ind', 'reg', 'reg_mask', 'size', 'height', 'rot'])
Size