In [None]:
import argparse
import datetime
import importlib
import logging
import os
import shutil
import sys
from pathlib import Path

import numpy as np
import torch
from tqdm import tqdm

from data_utils.ModelNetDataLoader import ModelNetDataLoader

# Set up the base directory variables.
# NOTE: '__file__' below is a quoted string literal, not the module attribute,
# so abspath() resolves it against the current working directory and BASE_DIR
# ends up being that working directory.
BASE_DIR = os.path.dirname(os.path.abspath('__file__')) # base directory
ROOT_DIR = BASE_DIR # root directory
sys.path.append(os.path.join(ROOT_DIR, 'models')) # lets modules under models/ be imported by bare name

# Show the current directory
print(f"当前工作目录：{BASE_DIR}")

In [None]:
def log_string(str):
    """Record a message at INFO level on the global ``logger`` and echo it to stdout.

    A module-level ``logger`` must already exist when this is called.
    """
    # The parameter name shadows the builtin; it is kept so existing callers
    # are unaffected, and aliased here for readability.
    message = str
    logger.info(message)
    print(message)

In [None]:
def parse_args(argv=None):
    """Build the training argument parser and parse the given arguments.

    Args:
        argv: Optional list of argument strings. When omitted, the process
            command line is parsed, except inside a Jupyter/IPython kernel:
            there ``sys.argv`` holds the kernel launcher's own flags (such as
            ``-f <connection-file>``), which argparse would reject, so an
            empty list is parsed instead and every option takes its default.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    if argv is None and 'ipykernel' in sys.modules:
        argv = []

    parser = argparse.ArgumentParser('training')
    parser.add_argument('--use_cpu', action='store_true', default=False, help='use cpu mode')
    parser.add_argument('--gpu', type=str, default='0', help='specify gpu device')
    parser.add_argument('--batch_size', type=int, default=24, help='batch size in training')
    parser.add_argument('--model', default='pointnet_cls', help='model name [default: pointnet_cls]')
    parser.add_argument('--num_category', default=40, type=int, choices=[10, 40],  help='training on ModelNet10/40')
    parser.add_argument('--epoch', default=200, type=int, help='number of epoch in training')
    parser.add_argument('--learning_rate', default=0.001, type=float, help='learning rate in training')
    parser.add_argument('--num_point', type=int, default=1024, help='Point Number')
    parser.add_argument('--optimizer', type=str, default='Adam', help='optimizer for training')
    parser.add_argument('--log_dir', type=str, default=None, help='experiment root')
    parser.add_argument('--decay_rate', type=float, default=1e-4, help='decay rate')
    parser.add_argument('--use_normals', action='store_true', default=False, help='use normals')
    parser.add_argument('--process_data', action='store_true', default=False, help='save data offline')
    parser.add_argument('--use_uniform_sample', action='store_true', default=False, help='use uniform sampiling')
    # Paired on/off switches writing to the same destination; resuming is the default.
    parser.add_argument('--continue_from_checkpoint', action='store_true', help='Continue training from checkpoint if available')
    parser.add_argument('--no_continue_from_checkpoint', action='store_false', dest='continue_from_checkpoint', help='Start training from scratch')
    parser.set_defaults(continue_from_checkpoint=True)
    return parser.parse_args(argv)
    
# Try out the argument parser and show the resulting namespace
args = parse_args()
print(f"解析到的参数：{args}")

In [None]:
# --- Hyper parameters ---
# Restrict the CUDA devices visible to this process to the requested id(s).
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

# --- Experiment directory tree ---
# Layout: ./log/classification/<run>/{checkpoints,logs}, where <run> is
# args.log_dir when given and a minute-resolution timestamp otherwise.
timestr = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M')
exp_dir = Path('./log/').joinpath('classification', timestr if args.log_dir is None else args.log_dir)
checkpoints_dir = exp_dir.joinpath('checkpoints/')
log_dir = exp_dir.joinpath('logs/')
# parents=True creates every missing ancestor in one go, so a nested
# --log_dir such as "a/b" works as well.
for directory in (checkpoints_dir, log_dir):
    directory.mkdir(parents=True, exist_ok=True)

In [None]:
def inplace_relu(m):
    """Switch ReLU-style modules to in-place operation.

    Intended to be applied to a model's modules and sub-modules one by one.
    Any object whose class name contains ``'ReLU'`` gets ``inplace = True``,
    i.e. the activation overwrites its input instead of producing a new
    output; every other object is left untouched.
    """
    if 'ReLU' in m.__class__.__name__:
        m.inplace = True

# Placeholder for an inplace_relu test: this only prints a banner, the function itself is not called here.
print("测试 inplace_relu 函数...")

In [None]:


def test(model, loader, num_class=40, verbose=True):
    """Evaluate a classifier on a data loader and report its accuracy.

    Args:
        model: Network to evaluate. It is switched to eval mode and called as
            ``model(points)``; the first element of the returned pair is
            treated as the per-class score tensor.
        loader: Sized iterable yielding ``(points, target)`` batches.
            Presumably ``points`` is (batch, num_point, channels) - it is
            transposed on axes 1 and 2 before being fed to the model.
        num_class: Number of categories (40 by default, i.e. ModelNet40).
        verbose: When true (the default), print the predicted and the actual
            labels of every batch.

    Returns:
        ``(instance_acc, class_acc)``: the mean of the per-batch accuracies,
        and the mean per-class accuracy taken over the classes that actually
        occur in ``loader``. Both are 0.0 when the loader is empty.
    """
    mean_correct = []  # accuracy of each batch
    # Per class: [sum of per-batch accuracies, number of batches seen, mean accuracy]
    class_acc = np.zeros((num_class, 3))
    classifier = model.eval()

    # Send data to wherever the model lives rather than consulting a
    # notebook-level `args`; a parameter-free model leaves the data in place.
    first_param = next(classifier.parameters(), None)
    device = None if first_param is None else first_param.device

    with torch.no_grad():
        for j, (points, target) in tqdm(enumerate(loader), total=len(loader)):
            if device is not None:
                points, target = points.to(device), target.to(device)

            # Swap axes 1 and 2 into the layout the model is called with
            points = points.transpose(2, 1)
            pred, _ = classifier(points)
            # Index of the highest score = predicted class
            pred_choice = pred.max(1)[1]

            if verbose:
                print(f"Batch {j+1}:")
                print("Predicted classes:", pred_choice)
                print("Actual classes:", target)

            labels = target.long()
            hits = pred_choice.eq(labels)

            # Accumulate the accuracy of every class present in this batch
            for cat in torch.unique(labels).tolist():
                in_class = labels == cat
                class_acc[cat, 0] += hits[in_class].sum().item() / float(in_class.sum().item())
                class_acc[cat, 1] += 1

            mean_correct.append(hits.sum().item() / float(points.size()[0]))

    # Average only over classes that were seen: dividing by a zero batch count
    # would turn the whole class accuracy into NaN.
    seen = class_acc[:, 1] > 0
    class_acc[seen, 2] = class_acc[seen, 0] / class_acc[seen, 1]
    class_acc = np.mean(class_acc[seen, 2]) if seen.any() else 0.0
    instance_acc = np.mean(mean_correct) if mean_correct else 0.0

    # Final summary
    print(f"Instance Accuracy: {instance_acc}")
    print(f"Class Accuracy: {class_acc}")

    return instance_acc, class_acc


In [None]:
def main(args):
    """Train a point-cloud classifier and keep the best checkpoint.

    Args:
        args: Parsed options as produced by ``parse_args()``. The object
            passed in is used as-is; it is not re-parsed.

    Returns:
        ``(test_instance_acc_history, test_class_acc_history)``: the instance
        and class accuracy measured after every epoch.

    Side effects:
        Sets ``CUDA_VISIBLE_DEVICES``, creates
        ``./log/classification/<run>/{checkpoints,logs}``, writes a log file,
        copies the model sources into the run directory and saves
        ``checkpoints/best_model.pth`` whenever the instance accuracy matches
        or beats the best one so far.
    """
    # Project-local augmentation helpers used in the batch loop; the visible
    # notebook cells never import them.
    import provider

    def log_string(message):
        """Log at INFO level and echo to stdout."""
        logger.info(message)
        print(message)

    # --- Hyper parameters ---
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    # --- Directories ---
    # <run> is args.log_dir when given, otherwise a minute-resolution timestamp.
    timestr = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M')
    exp_dir = Path('./log/').joinpath('classification', timestr if args.log_dir is None else args.log_dir)
    checkpoints_dir = exp_dir.joinpath('checkpoints/')
    log_dir = exp_dir.joinpath('logs/')
    for directory in (checkpoints_dir, log_dir):
        directory.mkdir(parents=True, exist_ok=True)

    # --- Logging ---
    logger = logging.getLogger("Model")
    logger.setLevel(logging.INFO)
    # getLogger() hands back the same object on every call, so drop file
    # handlers left by an earlier run; otherwise each line is written once per
    # stacked handler (and into the earlier run's file).
    for stale_handler in [h for h in logger.handlers if isinstance(h, logging.FileHandler)]:
        logger.removeHandler(stale_handler)
        stale_handler.close()
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    file_handler = logging.FileHandler('%s/%s.txt' % (log_dir, args.model))
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    log_string('PARAMETER ...')
    log_string(args)

    test_instance_acc_history = []
    test_class_acc_history = []

    # --- Data loading ---
    log_string('Load dataset ...')
    data_path = 'data/modelnet40_normal_resampled/'

    train_dataset = ModelNetDataLoader(root=data_path, args=args, split='train', process_data=args.process_data)
    test_dataset = ModelNetDataLoader(root=data_path, args=args, split='test', process_data=args.process_data)
    trainDataLoader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=10, drop_last=True)
    testDataLoader = torch.utils.data.DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=10)

    # --- Model loading ---
    num_class = args.num_category
    model = importlib.import_module(args.model)

    # Keep a copy of the sources next to the results. A file that is not
    # present (e.g. the stand-alone training script when working from a
    # notebook) is reported and skipped instead of aborting the run.
    for source_file in ('./models/%s.py' % args.model, 'models/pointnet2_utils.py', './train_classification.py'):
        if os.path.isfile(source_file):
            shutil.copy(source_file, str(exp_dir))
        else:
            log_string('Skip source snapshot, file not found: %s' % source_file)

    classifier = model.get_model(num_class, normal_channel=args.use_normals)
    criterion = model.get_loss()
    classifier.apply(inplace_relu)

    if not args.use_cpu:
        classifier = classifier.cuda()
        criterion = criterion.cuda()

    # --- Optional resume ---
    start_epoch = 0
    best_epoch = 0
    best_instance_acc = 0.0
    best_class_acc = 0.0
    if args.continue_from_checkpoint:
        checkpoint_path = str(exp_dir) + '/checkpoints/best_model.pth'
        try:
            # Load onto the CPU first; load_state_dict copies the values into
            # the model's own parameters wherever they live.
            checkpoint = torch.load(checkpoint_path, map_location='cpu')
            classifier.load_state_dict(checkpoint['model_state_dict'])
        except FileNotFoundError:
            log_string('No existing model, starting training from scratch...')
        except Exception as exc:
            # Still best-effort, but say why the checkpoint was not used.
            log_string('Could not load checkpoint %s (%r), starting training from scratch...' % (checkpoint_path, exc))
        else:
            start_epoch = checkpoint['epoch']
            best_epoch = start_epoch
            # Carry the stored scores over so a worse epoch after resuming
            # does not overwrite the better checkpoint.
            best_instance_acc = float(checkpoint.get('instance_acc', 0.0))
            best_class_acc = float(checkpoint.get('class_acc', 0.0))
            log_string('Use pretrain model')
    else:
        log_string('Starting training from scratch...')

    # --- Optimizer and LR schedule ---
    if args.optimizer == 'Adam':
        optimizer = torch.optim.Adam(
            classifier.parameters(),
            lr=args.learning_rate,
            betas=(0.9, 0.999),
            eps=1e-08,
            weight_decay=args.decay_rate
        )
    else:
        optimizer = torch.optim.SGD(classifier.parameters(), lr=0.01, momentum=0.9)

    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.7)
    global_epoch = 0
    global_step = 0

    # --- Training ---
    logger.info('Start training...')
    for epoch in range(start_epoch, args.epoch):
        log_string('Epoch %d (%d/%s):' % (global_epoch + 1, epoch + 1, args.epoch))
        mean_correct = []
        classifier = classifier.train()

        for batch_id, (points, target) in tqdm(enumerate(trainDataLoader, 0), total=len(trainDataLoader), smoothing=0.9):
            optimizer.zero_grad()

            # Augment on the NumPy side: point dropout, then scale and shift
            # of the first three channels.
            points = points.data.numpy()
            points = provider.random_point_dropout(points)
            points[:, :, 0:3] = provider.random_scale_point_cloud(points[:, :, 0:3])
            points[:, :, 0:3] = provider.shift_point_cloud(points[:, :, 0:3])
            points = torch.Tensor(points)
            points = points.transpose(2, 1)

            if not args.use_cpu:
                points, target = points.cuda(), target.cuda()

            pred, trans_feat = classifier(points)
            loss = criterion(pred, target.long(), trans_feat)
            pred_choice = pred.data.max(1)[1]

            correct = pred_choice.eq(target.long().data).cpu().sum()
            mean_correct.append(correct.item() / float(points.size()[0]))
            loss.backward()
            optimizer.step()
            global_step += 1

        # Advance the schedule only after this epoch's optimizer steps;
        # stepping it first would skip the initial learning-rate value.
        scheduler.step()

        train_instance_acc = np.mean(mean_correct)
        log_string('Train Instance Accuracy: %f' % train_instance_acc)

        # --- Evaluation ---
        with torch.no_grad():
            instance_acc, class_acc = test(classifier.eval(), testDataLoader, num_class=num_class)

            test_instance_acc_history.append(instance_acc)
            test_class_acc_history.append(class_acc)

            # Decide once, before the best values are updated.
            improved = instance_acc >= best_instance_acc
            if improved:
                best_instance_acc = instance_acc
                best_epoch = epoch + 1

            if class_acc >= best_class_acc:
                best_class_acc = class_acc
            log_string('Test Instance Accuracy: %f, Class Accuracy: %f' % (instance_acc, class_acc))
            log_string('Best Instance Accuracy: %f, Class Accuracy: %f' % (best_instance_acc, best_class_acc))

            if improved:
                logger.info('Save model...')
                savepath = str(checkpoints_dir) + '/best_model.pth'
                log_string('Saving at %s' % savepath)
                state = {
                    'epoch': best_epoch,
                    # Plain floats keep NumPy scalar objects out of the saved file.
                    'instance_acc': float(instance_acc),
                    'class_acc': float(class_acc),
                    'model_state_dict': classifier.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                }
                torch.save(state, savepath)
            global_epoch += 1

    logger.info('End of training...')
    return test_instance_acc_history, test_class_acc_history

# Script-style entry point. NOTE: inside a notebook kernel __name__ is
# '__main__' as well, so executing this cell starts a full training run.
if __name__ == '__main__':
    args = parse_args()
    main(args)

In [None]:
# Parse the options
args = parse_args()
print("解析到的参数：", args)

# Create the logger that log_string() writes to
logger = logging.getLogger("Model")
logger.setLevel(logging.INFO)
# getLogger() returns the same object on every call, so remove file handlers
# attached earlier (a previous execution of this cell, or main()); otherwise
# every message is written once per stacked handler.
for stale_handler in [h for h in logger.handlers if isinstance(h, logging.FileHandler)]:
    logger.removeHandler(stale_handler)
    stale_handler.close()
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
file_handler = logging.FileHandler('%s/%s.txt' % (log_dir, args.model))
file_handler.setLevel(logging.INFO)
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)
log_string('参数设置：')
log_string(args)


In [None]:
# Accuracy after each evaluation pass, filled in by the evaluation cell
test_instance_acc_history, test_class_acc_history = [], []

# Load the datasets
log_string('Load dataset ...')
data_path = 'data/modelnet40_normal_resampled/'

# Same constructor arguments for both splits; only the split name differs
train_dataset, test_dataset = (
    ModelNetDataLoader(root=data_path, args=args, split=split_name, process_data=args.process_data)
    for split_name in ('train', 'test')
)

# Training batches are shuffled and the trailing partial batch is dropped;
# evaluation batches keep their order and every sample.
trainDataLoader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=args.batch_size,
    shuffle=True,
    num_workers=10,
    drop_last=True,
)
testDataLoader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=args.batch_size,
    shuffle=False,
    num_workers=10,
)

# Report the dataset sizes
print("训练集大小：", len(train_dataset))
print("测试集大小：", len(test_dataset))


In [None]:

# Model initialisation
log_string('正在初始化模型...')
num_class = args.num_category
model = importlib.import_module(args.model)

# Keep a copy of the sources next to the results. A file that is not present
# (e.g. the stand-alone training script when working from this notebook) is
# reported and skipped instead of stopping the cell.
for source_file in ('./models/%s.py' % args.model, 'models/pointnet2_utils.py', './train_classification.py'):
    if os.path.isfile(source_file):
        shutil.copy(source_file, str(exp_dir))
    else:
        log_string('Skip source snapshot, file not found: %s' % source_file)

classifier = model.get_model(num_class, normal_channel=args.use_normals)
criterion = model.get_loss()
classifier.apply(inplace_relu)

if not args.use_cpu:
    classifier = classifier.cuda()
    criterion = criterion.cuda()

# Training state read and updated by the training and evaluation cells
start_epoch = 0
best_epoch = 0
best_instance_acc = 0.0
best_class_acc = 0.0
global_epoch = 0
global_step = 0

# Optional resume from the best checkpoint of this experiment directory
if args.continue_from_checkpoint:
    checkpoint_path = str(exp_dir) + '/checkpoints/best_model.pth'
    try:
        # Load onto the CPU first; load_state_dict copies the values into the
        # model's own parameters wherever they live.
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        classifier.load_state_dict(checkpoint['model_state_dict'])
    except FileNotFoundError:
        log_string('No existing model, starting training from scratch...')
    except Exception as exc:
        # Still best-effort, but say why the checkpoint was not used.
        log_string('Could not load checkpoint %s (%r), starting training from scratch...' % (checkpoint_path, exc))
    else:
        start_epoch = checkpoint['epoch']
        best_epoch = start_epoch
        # Carry the stored scores over so a worse result after resuming does
        # not overwrite the better checkpoint.
        best_instance_acc = float(checkpoint.get('instance_acc', 0.0))
        best_class_acc = float(checkpoint.get('class_acc', 0.0))
        log_string('Use pretrain model')
else:
    log_string('Starting training from scratch...')


if args.optimizer == 'Adam':
    optimizer = torch.optim.Adam(
        classifier.parameters(),
        lr=args.learning_rate,
        betas=(0.9, 0.999),
        eps=1e-08,
        weight_decay=args.decay_rate
    )
else:
    optimizer = torch.optim.SGD(classifier.parameters(), lr=0.01, momentum=0.9)

# Multiply the learning rate by 0.7 every 20 scheduler steps
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.7)


In [None]:
# Train the model
# Project-local augmentation helpers used in the batch loop; no earlier cell
# imports them.
import provider

logger.info('开始训练...')
for epoch in range(start_epoch, args.epoch):
    # Training pass
    log_string('Epoch %d (%d/%s):' % (global_epoch + 1, epoch + 1, args.epoch))
    mean_correct = []
    classifier = classifier.train()

    for batch_id, (points, target) in tqdm(enumerate(trainDataLoader, 0), total=len(trainDataLoader), smoothing=0.9):
        optimizer.zero_grad()

        # Augment on the NumPy side: point dropout, then scale and shift of
        # the first three channels.
        points = points.data.numpy()
        points = provider.random_point_dropout(points)
        points[:, :, 0:3] = provider.random_scale_point_cloud(points[:, :, 0:3])
        points[:, :, 0:3] = provider.shift_point_cloud(points[:, :, 0:3])
        points = torch.Tensor(points)
        points = points.transpose(2, 1)

        if not args.use_cpu:
            points, target = points.cuda(), target.cuda()

        pred, trans_feat = classifier(points)
        loss = criterion(pred, target.long(), trans_feat)
        pred_choice = pred.data.max(1)[1]

        correct = pred_choice.eq(target.long().data).cpu().sum()
        mean_correct.append(correct.item() / float(points.size()[0]))
        loss.backward()
        optimizer.step()
        global_step += 1
        print(f"批次 {batch_id+1}/{len(trainDataLoader)}，损失：{loss.item()}")

    # Advance the schedule only after this epoch's optimizer steps; stepping
    # it first would skip the initial learning-rate value.
    scheduler.step()

    train_instance_acc = np.mean(mean_correct)
    log_string('Train Instance Accuracy: %f' % train_instance_acc)
    # Report the result of this epoch
    print(f"训练完成，第 {epoch+1}/{args.epoch} 轮，准确率：{train_instance_acc}")

In [None]:
# Evaluation pass (gradients disabled)
with torch.no_grad():
    instance_acc, class_acc = test(classifier.eval(), testDataLoader, num_class=num_class)

    # Record this evaluation in the accuracy histories
    test_instance_acc_history.append(instance_acc)
    test_class_acc_history.append(class_acc)

    # Decide once whether this result matches or beats the best so far, and
    # update the best values before they are logged so the log is not stale.
    improved = instance_acc >= best_instance_acc
    if improved:
        best_instance_acc = instance_acc
        best_epoch = epoch + 1

    if class_acc >= best_class_acc:
        best_class_acc = class_acc

    # Report the results
    log_string('测试实例准确率: %f, 类别准确率: %f' % (instance_acc, class_acc))
    log_string('最佳实例准确率: %f, 类别准确率: %f' % (best_instance_acc, best_class_acc))

    # Save the model when it is the best one so far
    if improved:
        log_string('保存模型中...')
        savepath = str(checkpoints_dir) + '/best_model.pth'
        log_string('Saving at %s' % savepath)
        state = {
            'epoch': best_epoch,
            # Plain floats keep NumPy scalar objects out of the saved file.
            'instance_acc': float(instance_acc),
            'class_acc': float(class_acc),
            'model_state_dict': classifier.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
        }
        torch.save(state, savepath)
    global_epoch += 1


In [None]:
# pyplot is used below, but no earlier cell imports it.
import matplotlib.pyplot as plt

# Build the two-panel accuracy figure once; both the interactive and the
# headless case show exactly the same chart.
fig, (instance_ax, class_ax) = plt.subplots(1, 2, figsize=(10, 4))

instance_ax.plot(test_instance_acc_history, label='Test Instance Accuracy')
instance_ax.set_title('Test Instance Accuracy Over Epochs')
instance_ax.set_xlabel('Epoch')
instance_ax.set_ylabel('Accuracy')
instance_ax.legend()

class_ax.plot(test_class_acc_history, label='Class Accuracy')
class_ax.set_title('Class Accuracy Over Epochs')
class_ax.set_xlabel('Epoch')
class_ax.set_ylabel('Accuracy')
class_ax.legend()

# No graphical environment (DISPLAY unset or empty): also write the chart to
# a file. Save BEFORE show() - once show() has finished with the figure, a
# later pyplot savefig would write an empty canvas.
if not os.environ.get('DISPLAY', None):
    fig.savefig('accuracy_plot.png')

plt.show()

