In [None]:
from __future__ import print_function
import argparse
import os
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.autograd import Variable
from data import CTDataset
from regression import PointNetReg
import torch.nn.functional as F

# --- Configuration and test-set loading ---
# Evaluation settings: loader batch size / worker count, the threshold window
# handed to CTDataset, and the number of points sampled per example.
batch_size = 2
workers = 4
# NOTE(review): presumably an intensity window applied by CTDataset — confirm units there.
threshold_min = 1700
threshold_max = 2700
npoints = 50000
# Only the first batch of the loader is inspected below, so an unseeded shuffle
# would make the reported loss and every figure change on each Restart & Run All.
# Keep this False for reproducible results; set True to browse random examples.
shuffle_test = False

test_dataset = CTDataset(root='../data',
                         threshold_min=int(threshold_min),
                         threshold_max=int(threshold_max),
                         npoints=npoints,
                         train=False, dim4=True)

print("# of  testing examples: {0}".format(len(test_dataset)))
num_classes = test_dataset.nclasses
print("# of    object classes: {0}".format(num_classes))

testdataloader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size,
                                             shuffle=shuffle_test, num_workers=int(workers))

In [None]:
# --- Load the trained regressor ---
model = "./pth/model_99.pth"  # path to the saved state_dict

classifier = PointNetReg(num_points=npoints)
# map_location="cpu" makes loading independent of the device the checkpoint was
# saved from; the model is moved to the GPU right after.
# NOTE: torch.load unpickles the file — only load checkpoints from trusted sources.
classifier.load_state_dict(torch.load(model, map_location="cpu"))
# eval() is required for inference: in train mode, normalization/dropout layers
# (if PointNetReg has any) use per-batch statistics / random masks, making
# predictions depend on the batch and differ between runs.
# Assigning the result (rather than leaving a bare expression) keeps Jupyter
# from echoing the whole module repr.
classifier = classifier.cuda().eval()

In [None]:
# --- Predict on the first test batch and report MSE ---
points, target, centroid, index = next(iter(testdataloader))
# Swap the last two axes before feeding the network
# (presumably points-major -> channels-first, as PointNet-style models expect — confirm).
points = points.transpose(2, 1)
points, centroid = points.cuda(), centroid.cuda()
# No gradients are needed for evaluation; no_grad avoids building an autograd
# graph and saves GPU memory.
with torch.no_grad():
    pred, _ = classifier(points)
pred = pred.view(-1)
centroid = centroid.view(-1)

loss = F.mse_loss(pred, centroid)
print("test loss: {0} ".format(
    loss.item()
))

In [None]:
import visualize

# Bring the first example of the batch back to the host for plotting:
# its points (undoing the transpose applied before inference) and its prediction.
data = points[0, :, :].transpose(1, 0).cpu().numpy()
# Reshape by the batch's actual size, not the configured batch_size, so this
# still works when the batch is smaller (e.g. a test set with one example).
prediction = pred.view(points.size(0), -1)[0, :].detach().cpu().numpy()
print(prediction)

In [None]:
# Ground-truth target of the same example; reshaped by the batch's actual size
# so a batch smaller than batch_size does not break the view.
gt = target.view(target.size(0), -1)[0, :].cpu().numpy()
visualize.scatter_with_target(data, gt)

In [None]:
from utils import region_grow

# Round the predicted coordinates to integer indices and use them as the
# region-growing seed. Builtin int replaces np.int, which was only an alias for
# it and was removed in NumPy 1.24 (AttributeError on current NumPy).
seed = tuple(np.round(prediction).astype(int))
print(seed)

# Reload the full volume for the same test example. Note that `centroid` is
# rebound here to the value returned by load(); the next cell seeds from it.
_, _, volume, centroid, _, _ = test_dataset.load(index[0])


# NOTE(review): the meaning of the third region_grow argument (1) is not visible
# here — presumably a tolerance/label; confirm in utils.region_grow.
seg = region_grow(volume, seed, 1)

# Coordinates of every position where the mask compares equal to True
# (elementwise comparison on the array, so `== True` is intentional).
result = np.argwhere(seg == True)

visualize.scatter(result)

In [None]:
# Same segmentation, but seeded from the centroid returned by test_dataset.load()
# for comparison with the prediction-seeded result.
# Builtin int replaces np.int, which was removed in NumPy 1.24.
gt_seed = tuple(np.round(centroid).astype(int))
gt_seg = region_grow(volume, gt_seed, 1)
gt_result = np.argwhere(gt_seg == True)  # elementwise comparison on the mask
visualize.scatter(gt_result)