In [1]:
from __future__ import print_function
import argparse
import os
import random
import torch
import numpy as np
import torch.utils.data
import logging
from visualization.visualization_utils import pyplot_draw_point_cloud
from visualization.visualize_pointnet import make_one_critical

from tqdm import tqdm
from dataset.mydataset import PoisonDataset
from models.pointnet_cls import get_loss, get_model
from config import *
from visualization.customized_open3d import *
from load_data import load_data
import matplotlib.pyplot as plt
import data_utils

manualSeed = random.randint(1, 10000)  # fix seed
print("Random Seed: ", manualSeed)
random.seed(manualSeed)
torch.manual_seed(manualSeed)

Random Seed:  1590


<torch._C.Generator at 0x7f74b408c570>

In [2]:
categories = {
    0: 'airplane',
    1: 'bathtub',
    2: 'bed',
    3: 'bench',
    4: 'bookshelf',
    5: 'bottle',
    6: 'bowl',
    7: 'car',
    8: 'chair',
    9: 'cone',
    10: 'cup',
    11: 'curtain',
    12: 'desk',
    13: 'door',
    14: 'dresser',
    15: 'flower_pot',
    16: 'glass_box',
    17: 'guitar',
    18: 'keyboard',
    19: 'lamp',
    20: 'laptop',
    21: 'mantel',
    22: 'monitor',
    23: 'night_stand',
    24: 'person',
    25: 'piano',
    26: 'plant',
    27: 'radio',
    28: 'range_hood',
    29: 'sink',
    30: 'sofa',
    31: 'stairs',
    32: 'stool',
    33: 'table',
    34: 'tent',
    35: 'toilet',
    36: 'tv_stand',
    37: 'vase',
    38: 'wardrobe',
    39: 'xbox'
}

In [3]:
batch_size = 32
log_dir = 'train_500_24_modelnet40'
dataset = 'modelnet40'

In [4]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)
global x_train, y_train, x_test, y_test, num_classes
if dataset == "modelnet40":
    x_train, y_train, x_test, y_test = load_data()
    num_classes = 40
elif dataset == "scanobjectnn_pb_t50_rs":
    x_train, y_train = data_utils.load_h5("data/h5_files/main_split/training_objectdataset_augmentedrot_scale75.h5")
    x_test, y_test = data_utils.load_h5("data/h5_files/main_split/test_objectdataset_augmentedrot_scale75.h5")
    y_train = np.reshape(y_train, newshape=(y_train.shape[0], 1))
    y_test = np.reshape(y_test, newshape=(y_test.shape[0], 1))
    num_classes = 15
elif dataset == "scanobjectnn_obj_bg":
    x_train, y_train = data_utils.load_h5("data/h5_files/main_split/training_objectdataset.h5")
    x_test, y_test = data_utils.load_h5("data/h5_files/main_split/test_objectdataset.h5")
    y_train = np.reshape(y_train, newshape=(y_train.shape[0], 1))
    y_test = np.reshape(y_test, newshape=(y_test.shape[0], 1))
    num_classes = 15
elif dataset == "scanobjectnn_pb_t50_r":
    x_train, y_train = data_utils.load_h5("data/h5_files/main_split/training_objectdataset_augmentedrot.h5")
    x_test, y_test = data_utils.load_h5("data/h5_files/main_split/test_objectdataset_augmentedrot.h5")
    y_train = np.reshape(y_train, newshape=(y_train.shape[0], 1))
    y_test = np.reshape(y_test, newshape=(y_test.shape[0], 1))
    num_classes = 15
elif dataset == "scanobjectnn_pb_t25_r":
    x_train, y_train = data_utils.load_h5("data/h5_files/main_split/training_objectdataset_augmented25rot.h5")
    x_test, y_test = data_utils.load_h5("data/h5_files/main_split/test_objectdataset_augmented25rot.h5")
    y_train = np.reshape(y_train, newshape=(y_train.shape[0], 1))
    y_test = np.reshape(y_test, newshape=(y_test.shape[0], 1))
    num_classes = 15
elif dataset == "scanobjectnn_pb_t25":
    x_train, y_train = data_utils.load_h5("data/h5_files/main_split/training_objectdataset_augmented25_norot.h5")
    x_test, y_test = data_utils.load_h5("data/h5_files/main_split/test_objectdataset_augmented25_norot.h5")
    y_train = np.reshape(y_train, newshape=(y_train.shape[0], 1))
    y_test = np.reshape(y_test, newshape=(y_test.shape[0], 1))
    num_classes = 15

cpu


In [5]:
def log_string(str):
    logger.info(str)
    print(str)

In [7]:
logger = logging.getLogger("Model")
logger.setLevel(logging.INFO)
experiment_dir = 'log/classification/' + log_dir
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
file_handler = logging.FileHandler('%s/eval.txt' % experiment_dir)
file_handler.setLevel(logging.INFO)
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)
log_string('PARAMETER ...')

PARAMETER ...


In [8]:
MAX_SAMPLE = 10
classifier = get_model(k=40, normal_channel=False)
classifier.to(device)
test_dataset = PoisonDataset(
    data_set=list(zip(x_test[0:MAX_SAMPLE], y_test[0:MAX_SAMPLE])),
    n_class=NUM_CLASSES,
    target=TARGETED_CLASS,
    name="test",
    is_sampling=False,
    uniform=False,
    data_augmentation=False,
    is_testing=True,
)

Getting original data : 100%|██████████| 10/10 [00:00<00:00, 1152.76it/s]


In [11]:
test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    # batch_size=args.batch_size,
    batch_size=1,
    shuffle=False,
    num_workers=8,
)
print(len(test_dataset))

checkpoint = torch.load(str(experiment_dir) + '/checkpoints/best_model.pth',
                        map_location=lambda storage, loc: storage)
classifier.load_state_dict(checkpoint['model_state_dict'])
classifier.to(device)
classifier = classifier.eval()
sum_correct = 0.0

with torch.no_grad():
    for data in tqdm(test_loader):
        points, label, mask = data
        target = label[:, 0]
        # pyplot_draw_point_cloud(points.numpy().reshape(-1, 3))
        points = points.transpose(2, 1)
        points, target = points.to(device), target.to(device)
        predictions, feat_trans, hx, max_pool = classifier(points)
        points = points.transpose(2, 1)
        hx = hx.transpose(2, 1).cpu().numpy().reshape(-1, 1024)
        print(hx.shape)
        critical_mask = make_one_critical(hx=hx)
        # visualize_point_cloud_critical_point(points.cpu().numpy().reshape(-1, 3), critical_mask)
        pred_choice = predictions.max(1)[1]
        # print(categories[pred_choice.cpu().numpy()[0]])
        correct = pred_choice.eq(target.data).cpu().sum()

        sum_correct += correct

log_string('accuracy: %f' % (sum_correct / len(test_dataset)))


  0%|          | 0/10 [00:00<?, ?it/s]

10


 10%|█         | 1/10 [00:00<00:06,  1.36it/s]

(2048, 1024)


 20%|██        | 2/10 [00:01<00:05,  1.39it/s]

(2048, 1024)


 30%|███       | 3/10 [00:02<00:04,  1.42it/s]

(2048, 1024)


 40%|████      | 4/10 [00:02<00:04,  1.43it/s]

(2048, 1024)


 50%|█████     | 5/10 [00:03<00:03,  1.45it/s]

(2048, 1024)


 60%|██████    | 6/10 [00:04<00:02,  1.45it/s]

(2048, 1024)


 70%|███████   | 7/10 [00:04<00:02,  1.43it/s]

(2048, 1024)


 80%|████████  | 8/10 [00:05<00:01,  1.41it/s]

(2048, 1024)


 90%|█████████ | 9/10 [00:06<00:00,  1.42it/s]

(2048, 1024)


100%|██████████| 10/10 [00:06<00:00,  1.43it/s]

(2048, 1024)
accuracy: 0.900000



