In [70]:
import torch
import torch.nn as nn

class ShapeClassifier(nn.Module):
    def __init__(self, sub_function, net):
        super().__init__()
        self.sub_function = sub_function
        self.encoder = net
        self.relu = nn.ReLU()
        self.linear1 = nn.Linear(1024, 512, bias=False)
        self.bn1 = nn.BatchNorm1d(512)
        self.dp1 = nn.Dropout(p=0.5)
        self.linear2 = nn.Linear(512, 256)
        self.bn2 = nn.BatchNorm1d(256)
        self.dp2 = nn.Dropout(p=0.5)
        self.linear3 = nn.Linear(256, 40)
    
    def forward(self, x):
        x = x.numpy()
        sub = torch.Tensor(self.sub_function(x, 1024))
        print('sub shape: ', sub.shape)
        with torch.no_grad():
            for p in self.encoder.parameters():
                p.requires_grad = False
            x = self.encoder(sub)
        x = x.reshape(x.shape[0], -1)  # bs, 1024
        x = self.relu(self.bn1(self.linear1(x)))
        x = self.dp1(x)
        x = self.relu(self.bn2(self.linear2(x)))
        x = self.dp2(x)
        x = self.linear3(x)
        return x

随机输入，检查数据流

In [71]:
from SimAttention.network.encoder import PCT_Encoder
from SimAttention.network.augmentation import Batch_PointWOLF
from SimAttention.utils.crops import *
from SimAttention.utils.provider import *
from SimAttention.dataloader import ModelNetDataSet
from torch.utils.data import DataLoader

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
root = '/home/akira/下载/Pointnet2_PyTorch-master/byol_pcl/data/modelnet40_normal_resampled'
dataset = ModelNetDataSet(root)
trainDataLoader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=True)

The size of train data is 9843


In [72]:
online_encoder = PCT_Encoder()
eval_net = ShapeClassifier(net=online_encoder, sub_function=b_fps)
rand_input = torch.ones((4, 10000, 3))
rand_output = eval_net(rand_input)
print(rand_output.shape)

sub shape:  torch.Size([4, 1024, 3])
torch.Size([4, 40])


In [None]:
# 也需要在optimizer中增加filter，表示不更新参数

# optimizer = torch.optim.Adam(filter(lambda p:p.requires_grad, model.parameters()), lr=0.01,
                                   # betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-5)

In [None]:
# 保存训练好的模型

# 那么现在有一个问题了，就是我这个分类器的后面几层用什么数据来训练
# 首先可以肯定的是，真正测试的时候这些都是固定的，不能继续训练了，所以用的应该还是训练数据来训练后面这些层
# 流程： trainLoader - 训练net模型 - 读取最佳net模型 - 继续使用trainLoader训练分类模型 - 用testLoader检测
# 需要训练多epoch，然后记录loss表现最好的一个epoch的网络参数，这个也是个问题，应该怎么才算是最好？

In [None]:
# 假设现在已经训练好了一个encoder，并保存了
# 下面这个就是训练验证模型的过程了
net = torch.load('net.pkl')
classifier = ShapeClassifier(net)
criterion = torch.nn.CrossEntropyLoss()

optimizer = torch.optim.Adam(filter(lambda p:p.requires_grad, classifier.parameters()), lr=0.01,
                                   betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-5)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.3)

for epoch in range(0, max_epoch):
    classifier.train()
    
    for data in trainDataloader:
        points, target = data
        points = points.data.numpy()
        points = provider.random_point_dropout(points)
        points[:,:, 0:3] = provider.random_scale_point_cloud(points[:,:, 0:3])
        points[:,:, 0:3] = provider.shift_point_cloud(points[:,:, 0:3])
        points = torch.Tensor(points)
        target = target[:, 0]
        
        points, target = points.cuda(), target.cuda()
        optimizer.zero_grad()

        pred = classifier(points)
        loss = criterion(pred, target.long())
        pred_choice = pred.data.max(1)[1]
        correct = pred_choice.eq(target.long().data).cpu().sum()
        mean_correct.append(correct.item() / float(points.size()[0]))
        loss.backward()
        optimizer.step()
    scheduler.step()

In [None]:
# 就是遍历test数据，然后检查正确性
def test(testDataloader, pred_func):
    acc_number = 0.0
    total_num = len(testDataloader)
    for data in testDataloader:
        points, target = data
        target = target[:, 0]
        pred = pred_func(points)
        for i in range(data.shape[0]):
            if pred[i] == target[i]:
                acc_number += 1.0
    acc_ratio = acc_number / total_num
    print('Instance Accuracy: ', acc_ratio)