In [1]:
import os
import torch
import numpy as np

import provider

from tqdm import tqdm
import matplotlib
import matplotlib.pyplot as plt
from ModelNetDataLoader import ModelNetDataLoader
from visualizer.pc_utils import pyplot_draw_point_cloud

from PointNet.PointNetModel import PointNetCls, PointNetCls_Loss

matplotlib.use('TkAgg')

In [2]:
# 加载数据集
data_path = "./../data/modelnet40_normal_resampled/"
batch_size = 24

train_dataset = ModelNetDataLoader(root=data_path, split='train')
test_dataset = ModelNetDataLoader(root=data_path, split='test')
trainDataLoader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=1,
                                              drop_last=True)
testDataLoader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=1)


The size of train data is 9843
The size of test data is 2468


In [3]:
# 某个点云数据的可视化
_x, _y = train_dataset.__getitem__(5000)
pyplot_draw_point_cloud(_x)
plt.show()

In [3]:
# 加载模型，设置损失函数、优化器、scheduler
num_class = 40
lr = 0.001
model_save_pth = './best_model_%d.pth' % num_class

classifier = PointNetCls(num_class, normal_channel=False).cuda()

criterion = PointNetCls_Loss().cuda()
optimizer = torch.optim.Adam(classifier.parameters(), lr=lr, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-4)

scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.7)

In [23]:
# 是否要使用已经存在的“预训练模型”？
if os.path.exists(model_save_pth):
    print("Load state dict")
    classifier.load_state_dict(torch.load(model_save_pth))
else:
    print("No model state dict saved!")

No model state dict saved!


In [4]:
# 定义测试函数
def test(model, loader, num_class=40):
    mean_correct = []
    class_acc = np.zeros((num_class, 3))
    classifier = model.eval()

    for j, (points, target) in tqdm(enumerate(loader), total=len(loader)):

        points, target = points.cuda(), target.cuda()

        points = points.transpose(2, 1)
        pred, _ = classifier(points)
        pred_choice = pred.data.max(1)[1]

        for cat in np.unique(target.cpu()):
            classacc = pred_choice[target == cat].eq(target[target == cat].long().data).cpu().sum()
            class_acc[cat, 0] += classacc.item() / float(points[target == cat].size()[0])
            class_acc[cat, 1] += 1

        correct = pred_choice.eq(target.long().data).cpu().sum()
        mean_correct.append(correct.item() / float(points.size()[0]))

    class_acc[:, 2] = class_acc[:, 0] / class_acc[:, 1]
    class_acc = np.mean(class_acc[:, 2])
    instance_acc = np.mean(mean_correct)

    return instance_acc, class_acc

In [6]:
# 训练模型
epochs = 200
best_instance_acc = 0.0

for epoch in range(epochs):
    print("Epoch %d (%d/%d):" % (epoch + 1, epoch + 1, epochs))
    classifier.train()

    mean_correct = []
    for batch_id, (points, target) in tqdm(enumerate(trainDataLoader, 0), total=len(trainDataLoader), smoothing=0.9):
        optimizer.zero_grad()

        # 点云数据增强
        points = points.data.numpy()
        points = provider.random_point_dropout(points)
        points[:, :, 0:3] = provider.random_scale_point_cloud(points[:, :, 0:3])
        points[:, :, 0:3] = provider.shift_point_cloud(points[:, :, 0:3])
        points = torch.Tensor(points)
        points = points.transpose(2, 1)
        #
        points = points.cuda()
        target = target.cuda()
        # 反向传播
        pred, trans_feat = classifier(points)
        loss = criterion(pred, target.long(), trans_feat)
        loss.backward()
        optimizer.step()

        # 统计
        pred_choice = pred.data.max(1)[1]
        correct = pred_choice.eq(target.long().data).cpu().sum()
        mean_correct.append(correct.item() / float(points.size()[0]))

    train_instance_acc = np.mean(mean_correct)
    print('Train Instance Accuracy: %f' % train_instance_acc)
    scheduler.step()

    # 测试
    with torch.no_grad():
        instance_acc, class_acc = test(classifier.eval(), testDataLoader, num_class=num_class)

        print("Test Instance Accuracy: %f" % instance_acc)

        if instance_acc >= best_instance_acc:
            best_instance_acc = instance_acc
            torch.save(classifier.state_dict(), model_save_pth)
            print("得到最好的Accuracy，保存当前模型参数")


Epoch 1 (1/200):


KeyboardInterrupt: 

In [16]:
# 加载模型进行预测验证
classifier = PointNetCls(num_class, normal_channel=False)
classifier.eval()
classifier.load_state_dict(torch.load('./best_model_40_acc_902.pth'))

while True:
    a = int(input())
    x, label = train_dataset.__getitem__(a)
    x_ = torch.unsqueeze(torch.tensor(x), 0)
    x_ = x_.transpose(2, 1)
    with torch.no_grad():
        pred, _ = classifier(x_)
    pred = int(pred.argmax(1)[0])
    print("pred: ", pred, " label: ", label)
    pyplot_draw_point_cloud(x)
    plt.show()

pred:  0  label:  0
pred:  1  label:  1
pred:  2  label:  2
pred:  2  label:  2
pred:  4  label:  4
pred:  6  label:  6


ValueError: invalid literal for int() with base 10: ''