In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
from collections import defaultdict
from det3d.models.centerpoint import CenterPointModel
from det3d.types.pointcloud import PointCloud
from det3d.utils import move_to_gpu
from torchinfo import summary
from pprint import pprint

# 1. 定义一个Toy Dataset（与模型输出保持一致）

In [2]:
class DummyPointCloudDataset(Dataset):
    """Synthetic dataset of random point clouds with random box annotations.

    Every ``__getitem__`` call draws fresh values from NumPy's global RNG, so
    fetching the same index twice returns different data.

    Args:
        num_samples: Number of samples reported by ``__len__``.
        num_points: Number of points generated for each cloud.
    """

    def __init__(self, num_samples, num_points=10000):
        self.num_samples = num_samples
        self.num_points = num_points

    def __len__(self):
        """Return the configured number of samples."""
        return self.num_samples

    def __getitem__(self, idx):
        """Generate one random sample.

        Args:
            idx: Sample index; only used for bounds checking, the content is random.

        Returns:
            A ``(pc, gt)`` pair. ``pc`` is a ``PointCloud`` built from a float32
            array of shape ``(num_points, 4)``. ``gt`` is a dict with
            ``'gt_boxes'`` (float32 tensor of shape ``(num_boxes, 7)``) and
            ``'gt_labels'`` (int64 tensor of shape ``(num_boxes,)``).

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        # Without a bounds check, plain iteration (``for sample in dataset``)
        # never stops: Python's sequence protocol keeps calling __getitem__
        # with 0, 1, 2, ... until an IndexError is raised.
        if not -self.num_samples <= idx < self.num_samples:
            raise IndexError(
                f"index {idx} is out of range for a dataset of size {self.num_samples}"
            )

        # Point cloud of shape [N, 4]; columns are [x, y, z, intensity].
        points = np.random.rand(self.num_points, 4).astype(np.float32)
        pc = PointCloud(points)

        # Between 1 and 3 boxes per sample (the upper bound of randint is exclusive).
        num_boxes = np.random.randint(1, 4)
        # All 7 values describe the box geometry: [x, y, z, dx, dy, dz, heading].
        # The class label is stored separately in 'gt_labels'.
        boxes = np.random.rand(num_boxes, 7).astype(np.float32)

        # Class ids: 0 = Vehicle, 1 = Pedestrian, 2 = Cyclist.
        labels = np.random.randint(0, 3, size=(num_boxes,))

        gt = {
            'gt_boxes': torch.tensor(boxes),
            # Force int64: NumPy's default integer width is platform-dependent,
            # while nn.CrossEntropyLoss expects class-index targets of dtype long.
            'gt_labels': torch.tensor(labels, dtype=torch.long),
        }

        return pc, gt

In [3]:
# Build the toy dataset; it reports 1000 samples, each generated randomly on access.
dataset = DummyPointCloudDataset(num_samples=1000)

In [4]:
# Collate function for the DataLoaders
def myfunc(batch_data):
    '''
    Merge N ``(point_cloud, gt)`` samples into one batch.

    batch_data: sequence of N items; item[0] must expose a ``.points`` array
        and item[1] is passed through untouched.

    Returns ``(points, gts)``: ``points`` is a float32 tensor made from the
    array of all ``.points`` entries, ``gts`` is the list of the N second
    elements in input order.
    '''
    # Single pass over the batch, keeping each cloud's array next to its gt.
    pairs = [(sample[0].points, sample[1]) for sample in batch_data]
    stacked_points = np.array([points for points, _ in pairs])
    gt_list = [gt for _, gt in pairs]
    return torch.tensor(stacked_points, dtype=torch.float), gt_list
    
# Split the dataset into a training part (80%) and a test part (the remainder).
train_size = int(len(dataset) * 0.8)
test_size = len(dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])


# DataLoaders
# Training: reshuffle every epoch and drop the incomplete final batch.
trainloader = DataLoader(train_dataset, batch_size=32,shuffle=True,collate_fn=myfunc,drop_last=True)

# Evaluation: keep a fixed order and keep the final partial batch, otherwise the
# tail of the test split would be silently excluded from every evaluation pass.
testloader = DataLoader(test_dataset, batch_size=32, shuffle=False, collate_fn=myfunc, drop_last=False)

In [5]:
# Fetch one sample: the point-cloud object and its ground-truth dict.
pc,gt = dataset[0]

In [6]:
# Display the raw point array of the sample together with its shape.
pc.points,pc.points.shape

(array([[0.7901998 , 0.34698248, 0.6846077 , 0.13670315],
        [0.42195234, 0.73787856, 0.14577417, 0.62523973],
        [0.8697555 , 0.9963083 , 0.23345533, 0.37873486],
        ...,
        [0.40082225, 0.8219846 , 0.4352875 , 0.41617623],
        [0.15004998, 0.32496294, 0.2700084 , 0.17962174],
        [0.08487051, 0.86554825, 0.89977574, 0.41413808]], dtype=float32),
 (10000, 4))

In [7]:
# Display the ground-truth dict returned alongside the point cloud.
gt

{'gt_boxes': tensor([[0.9191, 0.0458, 0.5771, 0.1287, 0.1028, 0.2372, 0.7943],
         [0.0083, 0.6910, 0.3120, 0.7608, 0.2222, 0.5067, 0.6568]]),
 'gt_labels': tensor([2, 1])}

可以看到，dataset 中每个样本的输入是一帧点云，形状为 (N, 4)；gt 是一个字典，只包含 gt_boxes 和 gt_labels（scores 的生成代码已被注释掉，因此不包含 scores）。

In [8]:
# Peek at the first collated batch only, to check what the loader yields.
for in_, out_ in trainloader:
    # Container types of the stacked points and of the ground-truth collection.
    print(type(in_), type(out_), sep="\n")
    # Shape of the points tensor, then the number of ground-truth entries.
    print(in_.size(), len(out_), sep="\n")
    # Contents and box container type of the first ground-truth entry.
    pprint(out_[0])
    print(type(out_[0]['gt_boxes']))
    break

<class 'torch.Tensor'>
<class 'list'>
torch.Size([32, 10000, 4])
32
{'gt_boxes': tensor([[0.6827, 0.0780, 0.6565, 0.8463, 0.3140, 0.7587, 0.2996],
        [0.0463, 0.5335, 0.0498, 0.6002, 0.0024, 0.1941, 0.5001]]),
 'gt_labels': tensor([1, 2])}
<class 'torch.Tensor'>


# 2. 定义一个简单的Toy模型，使用提供的CenterPointModel作为输出


In [9]:
class ToyModel(nn.Module):
    """Thin wrapper that feeds a raw points tensor to a ``CenterPointModel``.

    The wrapped model is constructed with its default arguments; no checkpoint
    is loaded in this class.
    """

    def __init__(self):
        super().__init__()
        self.centerpoint_model = CenterPointModel()

    def forward(self, x):
        """Wrap ``x`` in the dict the wrapped model takes and return its output.

        Args:
            x: Points tensor, passed through unchanged under the ``points`` key.

        Returns:
            Whatever ``self.centerpoint_model`` returns for the wrapped input.
        """
        # NOTE(review): batch_size is fixed to 1 whatever the shape of x is —
        # confirm this matches how CenterPointModel interprets a batched tensor.
        batch = dict(
            batch_size=1,
            points=x
        )
        # Call the module itself instead of .forward(): only __call__ runs the
        # registered forward hooks (e.g. those installed by torchinfo.summary).
        return self.centerpoint_model(batch)

In [10]:
# Print a per-layer summary of a freshly built ToyModel for an input of size (10, 4).
summary(ToyModel(),input_size=(10, 4))

Layer (type:depth-idx)                                  Output Shape              Param #
ToyModel                                                [24, 7]                   --
├─LD_base: 1-1                                          --                        --
│    └─BoolMap: 2-1                                     [1, 30, 448, 1120]        --
│    └─ResBEVBackboneConcat: 2-2                        [1, 128, 448, 1120]       --
│    │    └─ModuleList: 3-9                             --                        (recursive)
│    │    └─ModuleList: 3-10                            --                        --
│    │    └─ModuleList: 3-9                             --                        (recursive)
│    │    └─ModuleList: 3-10                            --                        --
│    │    └─ModuleList: 3-9                             --                        (recursive)
│    │    └─ModuleList: 3-10                            --                        --
│    │    └─ModuleList: 3-9      

# 3. 定义损失函数

In [11]:
def compute_loss(pred_dicts, gt_dicts):
    """
    Return the box regression loss plus the classification loss of the first frame.

    Only element 0 of both lists is used, i.e. one frame per batch is assumed.
    The box term is the L1 distance summed over every (predicted box, gt box)
    pair; no matching between predictions and ground truth is performed.

    Args:
        pred_dicts: Model output; ``pred_dicts[0]`` must provide
            ``'pred_boxes'`` (tensor of shape ``(P, D)``) and ``'pred_labels'``.
        gt_dicts: Ground truth; ``gt_dicts[0]`` must provide ``'gt_boxes'``
            (tensor of shape ``(G, D)``) and ``'gt_labels'``.

    Returns:
        Scalar tensor: box L1 term + cross-entropy term.
    """
    pred = pred_dicts[0]
    gt = gt_dicts[0]

    pred_boxes = pred['pred_boxes']
    gt_boxes = gt['gt_boxes']

    # Regression loss (L1): broadcast (P, 1, D) against (1, G, D) so that all
    # P x G pairs are covered at once; an empty side simply contributes 0.
    box_loss = torch.abs(pred_boxes[:, None, :] - gt_boxes[None, :, :]).sum()

    # Classification loss.
    # NOTE(review): CrossEntropyLoss needs per-class scores as input and one
    # target per input row — confirm 'pred_labels' really holds logits aligned
    # with 'gt_labels' rather than predicted class ids.
    cls_loss = nn.CrossEntropyLoss()(pred['pred_labels'], gt['gt_labels'])

    return box_loss + cls_loss

In [12]:
# Standalone cross-entropy criterion, used for the quick sanity check in the next cell.
CELoss = nn.CrossEntropyLoss()

In [16]:
# Sanity check: one sample with 4 class scores and target class 0.
temp_loss = CELoss(torch.tensor([[3, 1, 1,2]],dtype=torch.float32),torch.tensor([0],dtype=torch.long))
temp_loss

tensor(0.4938)

In [18]:
# The inputs above are plain tensors created without requires_grad, so this loss tracks no gradient.
temp_loss.requires_grad

False

# 4. 训练流程

In [14]:
num_samples = 100
batch_size = 4
dataset = DummyPointCloudDataset(num_samples=num_samples)

# Split the dataset into a training part (80%) and a test part (the remainder).
train_size = int(len(dataset) * 0.8)
test_size = len(dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])


# DataLoaders
# Use the batch_size configured above. A hard-coded 32 combined with
# drop_last=True would leave the 20-sample test split with zero batches.
trainloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=myfunc, drop_last=True)

# Evaluation: fixed order, and keep the final partial batch so no sample is skipped.
testloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=myfunc, drop_last=False)

In [20]:
# Exploratory training cell: builds the model and optimizer, runs ONE forward
# pass on the first batch, prints what the model returns, and stops.
# No loss is computed and no parameter is updated in its current form.
model = ToyModel().cuda()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

num_epochs = 1


for epoch in range(num_epochs):
    model.train()  # switch to training mode
    running_loss = 0.0
    
    for i, (pointclouds, gt_dicts) in enumerate(trainloader):

        pointclouds = pointclouds.float().cuda()  # move the point-cloud batch to the GPU
        gt_dicts = move_to_gpu(gt_dicts)          # presumably moves the gt tensors to the GPU — verify against det3d.utils
        
        optimizer.zero_grad()  # clear gradients from the previous step
        
        # Forward pass
        pred_dicts = model(pointclouds)

        # Inspect the structure of the model output and of the matching ground truth.
        pprint(type(pred_dicts))
        pprint(type(pred_dicts[0]))
        print(pred_dicts[0]['pred_boxes'].requires_grad)
        pprint(pred_dicts[0])
        pprint(gt_dicts[0])

        # NOTE(review): unconditional break — only the first batch is processed and
        # the training steps below stay disabled.
        break
        
        # Compute the loss
        # loss = compute_loss(pred_dicts, gt_dicts)
        
        # loss.backward()  # backward pass
        # optimizer.step()  # update parameters
        
        # running_loss += loss.item()
        
        # Log every 10 batches.
        # NOTE(review): `dataloader` is not defined in the visible cells; this line
        # probably needs `trainloader` before it is re-enabled.
        # if (i + 1) % 10 == 0:  # 每10个batch打印一次
        #     print(f"Epoch [{epoch + 1}/{num_epochs}], Step [{i + 1}/{len(dataloader)}], Loss: {running_loss / (i + 1):.4f}")
    
    # running_loss is never incremented while the loss code is commented out, so this reports 0.0000.
    print(f"Epoch [{epoch + 1}/{num_epochs}], Average Loss: {running_loss / len(trainloader):.4f}")

<class 'list'>
<class 'dict'>
True
{'pred_boxes': tensor([[ 2.0887e+00,  1.3918e+00,  2.5717e+00,  2.8909e+01,  1.4069e+03,
          6.0850e+00,  9.7034e-01],
        [ 3.8051e+00, -2.0893e+00,  3.9188e+00,  7.4553e+03,  1.2600e+03,
          4.4188e+04,  2.1101e+00],
        [ 5.1828e+00, -6.0257e+00,  2.5024e+00,  1.3809e+02,  2.0294e-06,
          5.8427e-03,  1.2547e+00],
        [ 1.3964e+01,  4.7740e+00, -1.4441e+00,  2.0946e-04,  1.7454e+00,
          6.9788e-05,  1.2425e+00],
        [ 5.8305e+00, -6.9399e+00,  1.9456e+00,  1.0143e+04,  3.5738e-06,
          4.5877e-01,  1.2364e+00],
        [ 8.0057e-01, -1.0993e+01,  1.6278e+00,  2.0919e+32,  1.2613e+16,
          4.0584e+18,  1.2269e+00],
        [ 2.3284e+00, -3.0907e+00, -7.1253e-01,  8.2326e+01,  4.8963e+02,
          4.2024e+05,  1.5845e+00],
        [ 8.5516e-02, -4.9989e+00,  2.3617e-01,  3.3218e-01,  5.5053e-04,
          8.9094e+00,  2.7714e+00],
        [ 8.0948e+00, -6.6940e+00,  3.7929e+00,  2.9757e+05,  2.8290e-