In [1]:
import torch
import torch.nn as nn
import torch.utils.data
import torch.nn.functional as F
import numpy as np
from numpy import *
import numpy.random as rng
import os
import matplotlib.pyplot as plt
from models.pointnet2_utils import PointNetSetAbstractionMsg, PointNetSetAbstraction
from sklearn.metrics import accuracy_score
import pandas as pd


In [2]:
torch.cuda.set_device(1)
# os.environ['CUDA_VISIBLE_DEVICES'] = '1,2'

In [3]:

class get_model(nn.Module):
    def __init__(self,num_class=10,normal_channel=False):
        super(get_model, self).__init__()
        in_channel = 3 if normal_channel else 0
        self.normal_channel = normal_channel
        self.sa1 = PointNetSetAbstractionMsg(512, [0.1, 0.2, 0.4], [16, 32, 128], in_channel,[[32, 32, 64], [64, 64, 128], [64, 96, 128]])
        self.sa2 = PointNetSetAbstractionMsg(128, [0.2, 0.4, 0.8], [32, 64, 128], 320,[[64, 64, 128], [128, 128, 256], [128, 128, 256]])
        self.sa3 = PointNetSetAbstraction(None, None, None, 640 + 3, [256, 512, 1024], True)
        self.fc1 = nn.Linear(1024, 512)
        self.bn1 = nn.BatchNorm1d(512)
        self.drop1 = nn.Dropout(0.4)
        self.fc2 = nn.Linear(512, 256)
        self.bn2 = nn.BatchNorm1d(256)
        self.drop2 = nn.Dropout(0.5)
        self.fc3 = nn.Linear(256,5)
        self.fc5 = nn.Linear(100, 1)
    
    def forward_once(self, xyz):
        B, _, _ = xyz.shape
        if self.normal_channel:
            norm = xyz[:, 3:, :]
            xyz = xyz[:, :3, :]
        else:
            norm = None
        l1_xyz, l1_points = self.sa1(xyz, norm)
        l2_xyz, l2_points = self.sa2(l1_xyz, l1_points)
        l3_xyz, l3_points = self.sa3(l2_xyz, l2_points)
        x = l3_points.view(B, 1024)
        x = self.drop1(F.relu(self.bn1(self.fc1(x))))
        x = self.drop2(F.relu(self.bn2(self.fc2(x))))
        output = self.fc3(x)
        return output
    
    def forward(self, input1, input2):
  
        output1 = self.forward_once(input1)

        output2 = self.forward_once(input2)
        
#         dis = F.pairwise_distance(output1, output2)
#         dis = torch.abs(output1 - output2)

#         pre = torch.sigmoid(self.fc5(dis.unsqueeze(-1)))
#         pre = self.fc5(dis)
#         pre = torch.sigmoid(self.fc5(dis))
#         print("pre = ", pre)
        
        return output1,output2


In [4]:
class ContrastiveLoss(torch.nn.Module):
    """
    Contrastive loss function.
    Based on: http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf
    """

    def __init__(self, margin=2.0):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin

    def forward(self, output1, output2, label):

        euclidean_distance = F.pairwise_distance(output1, output2)

        print(euclidean_distance)

        loss_contrastive = torch.mean((label) * torch.pow(euclidean_distance, 2) +
                                      (1 - label) * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2))
#         print("test:", loss_contrastive)
        return loss_contrastive*0.5

In [5]:
def pc_normalize(pc):
    centroid = np.mean(pc, axis=0)
    pc = pc - centroid
    m = np.max(np.sqrt(np.sum(pc**2, axis=1)))
    pc = pc / m
    return pc

In [6]:
def test_data(idx = 2120):
    npoints = 1024    
    split='test'
    root = "/home/jovyan/code/PointCloudClassifier/PointNet/pointnet/modelnet40_normal_resampled"
    num_category = 10 

    catfile = os.path.join(root, 'modelnet10_shape_names.txt')

    cat = [line.rstrip() for line in open(catfile)]  # 以列表的形式存放40类物品
    classes = dict(zip(cat, range(len(cat))))
    shape_ids = {}

    shape_ids['train'] = [line.rstrip() for line in open(os.path.join(root, 'modelnet10_train.txt'))]
    shape_ids['test'] = [line.rstrip() for line in open(os.path.join(root, 'modelnet10_test.txt'))]

    assert (split == 'train' or split == 'test')
    shape_names = ['_'.join(x.split('_')[0:-1]) for x in shape_ids[split]]
    datapath = [(shape_names[i], os.path.join(root, shape_names[i], shape_ids[split][i]) + '.txt') 
                for i in range(len(shape_ids[split]))]
    
#     print('The size of %s data is %d' % (split, len(datapath)))

#     np.random.shuffle(datapath)
    save_path = os.path.join(root, 'modelnet%d_%s_%dpts.dat' % (num_category, split, npoints))

    if split == 'train':    
        y_train_point = []
        for index in range(len(datapath)):
            cls = classes[datapath[index][0]]
            y_train_point.append(cls)

        y_train_point = np.array(y_train_point).astype(np.int32)

        indices = [np.where(y_train_point == i)[0] for i in sorted(list(set(y_train_point)))]
        n_classes = len(sorted(list(set(y_train_point))))
    if split == 'test':
        y_test_point = []
        for index in range(len(datapath)):
            cls = classes[datapath[index][0]]
            y_test_point.append(cls)
        
        #需要输出，表示分类标签，908
        y_test_point = np.array(y_test_point).astype(np.int32)

        indices = [np.where(y_test_point == i)[0] for i in sorted(list(set(y_test_point)))]
        sort_classes = sorted(list(set(y_test_point)))
        n_classes = len(sorted(list(set(y_test_point))))
        N = n_classes
        
  
        #这里
#         point_set = np.loadtxt(datapath[1][1], delimiter=',').astype(np.float32)[0:npoints, :][:, 0:3]
        point_set = pc_normalize(np.loadtxt(datapath[1][1], delimiter=',').astype(np.float32)[0:npoints, :][:, 0:3])#要改为3个长度的数据
        w, h = point_set.shape

        temp = np.loadtxt(datapath[idx][1], delimiter=',').astype(np.float32)[0:npoints, :][:, 0:3]
        test_image = np.asarray([temp]*N).reshape(N, w, h)
        
        support_set = np.zeros((N,w,h))

        for index in range(N):
            #这里
            support_set[index,:,:] = pc_normalize(np.loadtxt(datapath[int(rng.choice(indices[index],size=(1,),replace=False))][1], delimiter=',').astype(np.float32)[0:npoints, :][:, 0:3])
#             support_set[index,:,:] = np.loadtxt(datapath[int(rng.choice(indices[index],size=(1,),replace=False))][1], delimiter=',').astype(np.float32)[0:npoints, :][:, 0:3]
            

        targets = np.zeros((N,))
        
        true_index = sort_classes.index(y_test_point[idx])
        targets[true_index] = 1

        categories = sort_classes
        pairs = [test_image,support_set]
        pairs = torch.from_numpy(np.array(pairs).astype(np.float32))
        pairs = pairs.transpose(3, 2).cuda()
        
    return pairs,targets,categories,y_test_point

In [7]:
def y_test_point(idx = 2120):
    npoints = 1024    
    split='test'
    root = "/home/jovyan/code/PointCloudClassifier/PointNet/pointnet/modelnet40_normal_resampled"
    num_category = 10 

    catfile = os.path.join(root, 'modelnet10_shape_names.txt')

    cat = [line.rstrip() for line in open(catfile)]  # 以列表的形式存放40类物品
    classes = dict(zip(cat, range(len(cat))))
    shape_ids = {}

    shape_ids['train'] = [line.rstrip() for line in open(os.path.join(root, 'modelnet10_train.txt'))]
    shape_ids['test'] = [line.rstrip() for line in open(os.path.join(root, 'modelnet10_test.txt'))]

    assert (split == 'train' or split == 'test')
    shape_names = ['_'.join(x.split('_')[0:-1]) for x in shape_ids[split]]
    datapath = [(shape_names[i], os.path.join(root, shape_names[i], shape_ids[split][i]) + '.txt') 
                for i in range(len(shape_ids[split]))]
    
#     print('The size of %s data is %d' % (split, len(datapath)))

#     np.random.shuffle(datapath)
    save_path = os.path.join(root, 'modelnet%d_%s_%dpts.dat' % (num_category, split, npoints))

    if split == 'train':    
        y_train_point = []
        for index in range(len(datapath)):
            cls = classes[datapath[index][0]]
            y_train_point.append(cls)

        y_train_point = np.array(y_train_point).astype(np.int32)

        indices = [np.where(y_train_point == i)[0] for i in sorted(list(set(y_train_point)))]
        n_classes = len(sorted(list(set(y_train_point))))
    if split == 'test':
        y_test_point = []
        for index in range(len(datapath)):
            cls = classes[datapath[index][0]]
            y_test_point.append(cls)
        
        #需要输出，表示分类标签，908
        y_test_point = np.array(y_test_point).astype(np.int32)
    
    return y_test_point

In [8]:
def test_oneshot_new(model,verbose=0):

    n_correct = 0
    k = 908
    preds = []
    probs_all = []
    err_print_num = 0
    for idx in range(k):
        inputs,targets,categories,y_test = test_data(idx) #通过函数调用测试数据
        model.eval()
        output1,output2 = model(inputs[0],inputs[1])  #需要改
#         probs.detach() #将图中的变量解除出来
        #probs = F.pairwise_distance(output1, output2)
           
        euclidean_distance = F.pairwise_distance(output1, output2)

#         diff = euclidean_distance.cpu().detach().numpy()[0]
#         print(euclidean_distance.cpu().detach().numpy()[0])


        if np.argmin(euclidean_distance.detach().cpu().numpy()) == np.argmax(targets):
#         if np.argmax(probs) == np.argmax(targets):
            n_correct+=1
            
        preds.append([categories[np.argmax(targets)],categories[np.argmin(euclidean_distance.detach().cpu().numpy())]])
        probs_all.append(euclidean_distance.detach().cpu().numpy())

    percent_correct = (100.0*n_correct / k)
    print("*"*50)
    print(n_correct)
#     print(probs)

    return percent_correct,np.array(preds),np.array(probs_all),categories

In [9]:
exps = [150,180,210]
y_test = y_test_point()
exp_name = "test"
for exp in exps:
    checkpoint = torch.load('./checkpoints_new/best_model_%s.pth'% exp)
    model = get_model().cuda()
#     model = nn.DataParallel(model,device_ids=[0,1,2])
#     model = nn.DataParallel(model).cuda()
#     device=torch.device("cuda:0" )
#     model.to(device)

    model.load_state_dict(checkpoint['model_state_dict'], False)

    scores_1_shot = []
    scores_5_shot = []
    scores_5_shot_prod = []
    
    preds_5_shot = []
    prods_5_shot = []
    scores = []
    with torch.no_grad():
        for k in range(5):
            val_acc,preds, prods, categories = test_oneshot_new(model)
        #                 utils.confusion_plot(preds[:,1],preds[:,0])
            print(val_acc,preds.shape, prods.shape)
            scores.append(val_acc)
        #                     print("preds[:,1]",preds[:,1])
            preds_5_shot.append(preds[:,1])
            prods_5_shot.append(prods)
        preds = []
        #                 print("np.array(preds_5_shot).T:", np.array(preds_5_shot).T)
        for line in np.array(preds_5_shot).T:
            pass
            preds.append(np.argmax(np.bincount(line)))
        #             utils.confusion_plot(np.array(preds),data.y_test) 
        prod_preds = np.argmin(np.sum(prods_5_shot,axis=0),axis=1).reshape(-1)

        score_5_shot = accuracy_score(y_test,np.array(preds))*100
        print('5_shot:',score_5_shot)

        score_5_shot_prod = accuracy_score(y_test,prod_preds)*100
        print('5_shot_prod:',score_5_shot_prod)

        scores_1_shot.append(scores[0])
        scores_5_shot.append(score_5_shot)
        scores_5_shot_prod.append(score_5_shot_prod)

        a =pd.DataFrame(np.array(scores_1_shot).reshape(-1))
        a.to_csv("scores_1_shot.csv")

        a =pd.DataFrame(np.array(scores_5_shot).reshape(-1))
        a.to_csv("scores_5_shot.csv")

        a =pd.DataFrame(np.array(scores_5_shot_prod).reshape(-1))
        a.to_csv("scores_5_shot_prod.csv")


**************************************************
771
84.91189427312776 (908, 2) (908, 10)
**************************************************
758
83.48017621145374 (908, 2) (908, 10)
**************************************************
767
84.47136563876651 (908, 2) (908, 10)
**************************************************
760
83.70044052863436 (908, 2) (908, 10)
**************************************************
761
83.81057268722466 (908, 2) (908, 10)
5_shot: 86.78414096916299
5_shot_prod: 84.91189427312776
**************************************************
753
82.9295154185022 (908, 2) (908, 10)
**************************************************
756
83.25991189427313 (908, 2) (908, 10)
**************************************************
748
82.37885462555066 (908, 2) (908, 10)
**************************************************
749
82.48898678414096 (908, 2) (908, 10)
**************************************************
738
81.27753303964758 (908, 2) (908, 10)
5_shot: 85.0220264317180

In [10]:
# import torch
# exps = [60,90,120,200,300,600,900]
# for exp in exps:
#     checkpoint = torch.load('./checkpoints/best_model_%s.pth'% exp)
#     print(checkpoint['instance_acc'])