In [1]:
"""Testing On Segmentation Task."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import sys
import math
import h5py
import argparse
import importlib
import data_utils
import numpy as np
import tensorflow as tf
from datetime import datetime


  from ._conv import register_converters as _register_converters


In [2]:
args = argparse.Namespace(
    filelist = '../data/Amsterdam/test_files_utrecht2-las.txt',
    load_ckpt = '../models/seg/pointcnn_seg_amsterdam_x4_12288_fps_2019-02-05-22-40-26_20662/ckpts/iter-25173',
#     load_ckpt = '../models/seg/pointcnn_seg_amsterdam_x4_12288_fps_2019-02-10-21-10-01_17799/ckpts/iter-50346',
    max_point_num = 24576,
    repeat_num = 1,
    model = 'pointcnn_seg',
    setting = 'amsterdam_x4_12288_fps',
    save_ply = False
)

In [3]:
#     parser = argparse.ArgumentParser()
#     parser.add_argument('--filelist', '-t', help='Path to input .h5 filelist (.txt)', required=True)
#     parser.add_argument('--load_ckpt', '-l', help='Path to a check point file for load', required=True)
#     parser.add_argument('--max_point_num', '-p', help='Max point number of each sample', type=int, default=8192)
#     parser.add_argument('--repeat_num', '-r', help='Repeat number', type=int, default=1)
#     parser.add_argument('--model', '-m', help='Model to use', required=True)
#     parser.add_argument('--setting', '-x', help='Setting to use', required=True)
#     parser.add_argument('--save_ply', '-s', help='Save results as ply', action='store_true')
#     args = parser.parse_args()
#     print(args)


In [4]:
model = importlib.import_module(args.model)
setting_path = os.path.join(args.model)
sys.path.append(setting_path)
setting = importlib.import_module(args.setting)

sample_num = setting.sample_num
max_point_num = args.max_point_num
batch_size = args.repeat_num * math.ceil(max_point_num / sample_num)

######################################################################
# Placeholders
indices = tf.placeholder(tf.int32, shape=(batch_size, None, 2), name="indices")
is_training = tf.placeholder(tf.bool, name='is_training')
pts_fts = tf.placeholder(tf.float32, shape=(batch_size, max_point_num, setting.data_dim), name='points')
######################################################################

######################################################################
pts_fts_sampled = tf.gather_nd(pts_fts, indices=indices, name='pts_fts_sampled')
if setting.data_dim > 3:
    points_sampled, features_sampled = tf.split(pts_fts_sampled,
                                                [3, setting.data_dim - 3],
                                                axis=-1,
                                                name='split_points_features')
    if not setting.use_extra_features:
        features_sampled = None
else:
    points_sampled = pts_fts_sampled
    features_sampled = None

net = model.Net(points_sampled, features_sampled, is_training, setting)
seg_probs_op = tf.nn.softmax(net.logits, name='seg_probs')

# for restore model
saver = tf.train.Saver()

parameter_num = np.sum([np.prod(v.shape.as_list()) for v in tf.trainable_variables()])
print('{}-Parameter number: {:d}.'.format(datetime.now(), parameter_num))
    

Instructions for updating:
keep_dims is deprecated, use keepdims instead
Instructions for updating:
Use the retry module or similar alternatives.
2019-03-08 14:02:05.672273-Parameter number: 3278203.


In [5]:
import time
start_time_dict = {}
total_time_dict = {}

current_time_ms = lambda:int(round(time.time()*1000))

def timer_start(msg): 
    global start_time_dict
    global total_time_dict
    start_time_dict[msg] = current_time_ms()
    if not msg in total_time_dict:
        total_time_dict[msg] = 0
    
def timer_pause(msg):
    global start_time_dict
    global total_time_dict
    total_time_dict[msg] += current_time_ms() - start_time_dict[msg]
    
def timer_stop(msg):
    global total_time_dict
    timer_pause(msg)
    print("{} completed in {}ms".format(msg, total_time_dict[msg]))
    total_time_dict[msg] = 0

In [6]:
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
config.allow_soft_placement = True

In [7]:
with tf.Session(config=config) as sess:
    # Load the model
    saver.restore(sess, args.load_ckpt)
    print('{}-Checkpoint loaded from {}!'.format(datetime.now(), args.load_ckpt))

    indices_batch_indices = np.tile(np.reshape(np.arange(batch_size), (batch_size, 1, 1)), (1, sample_num, 1))

    folder = os.path.dirname(args.filelist)
    filenames = [os.path.join(folder, line.strip()) for line in open(args.filelist)]
    for filename in filenames:
        msg = 'Reading {}...'.format(filename)
        timer_start(msg)
        data_h5 = h5py.File(filename)
        timer_stop(msg)
        
        data = data_h5['data'][...].astype(np.float32)
        data_num = data_h5['data_num'][...].astype(np.int32)
        batch_num = data.shape[0]

        labels_pred = np.full((batch_num, max_point_num), -1, dtype=np.int32)
        confidences_pred = np.zeros((batch_num, max_point_num), dtype=np.float32)

        print('{}-{:d} testing batches.'.format(datetime.now(), batch_num))
        for batch_idx in range(batch_num):
            if batch_idx % 100 == 0 and batch_idx > 0 or batch_idx == batch_num-1:
                print('{}-Processing {} of {} batches.'.format(datetime.now(), batch_idx, batch_num))
                
            msg = "Preprocessing"
            timer_start(msg)
            points_batch = data[[batch_idx] * batch_size, ...]
            point_num = data_num[batch_idx]

            tile_num = math.ceil((sample_num * batch_size) / point_num)
            indices_shuffle = np.tile(np.arange(point_num), tile_num)[0:sample_num * batch_size]
            np.random.shuffle(indices_shuffle)
            indices_batch_shuffle = np.reshape(indices_shuffle, (batch_size, sample_num, 1))
            indices_batch = np.concatenate((indices_batch_indices, indices_batch_shuffle), axis=2)
            timer_pause(msg)

            msg = "Inference"
            timer_start(msg)
            seg_probs = sess.run([seg_probs_op],
                                    feed_dict={
                                        pts_fts: points_batch,
                                        indices: indices_batch,
                                        is_training: False,
                                    })
            timer_pause(msg)
            
            msg = "Postprocessing"
            timer_start(msg)
            probs_2d = np.reshape(seg_probs, (sample_num * batch_size, -1))

            predictions = [(-1, 0.0)] * point_num
            for idx in range(sample_num * batch_size):
                point_idx = indices_shuffle[idx]
                probs = probs_2d[idx, :]
                confidence = np.amax(probs)
                label = np.argmax(probs)
                if confidence > predictions[point_idx][1]:
                    predictions[point_idx] = [label, confidence]
            labels_pred[batch_idx, 0:point_num] = np.array([label for label, _ in predictions])
            confidences_pred[batch_idx, 0:point_num] = np.array([confidence for _, confidence in predictions])
            timer_pause(msg)            

        timer_stop("Preprocessing")
        timer_stop("Inference")
        timer_stop("Postprocessing")
            
        filename_pred = filename[:-3] + '_pred.h5'
        msg = 'Saving {}...'.format(filename_pred)
        timer_start(msg)
        file = h5py.File(filename_pred, 'w')
        file.create_dataset('data_num', data=data_num)
        file.create_dataset('label_seg', data=labels_pred)
        file.create_dataset('confidence', data=confidences_pred)
        has_indices = 'indices_split_to_full' in data_h5
        if has_indices:
            file.create_dataset('indices_split_to_full', data=data_h5['indices_split_to_full'][...])
        file.close()
        timer_stop(msg)

        if args.save_ply:
            msg = 'Saving ply of {}...'.format(filename_pred)
            timer_start(msg)
            filepath_label_ply = os.path.join(filename_pred[:-3] + 'ply_label')
            data_utils.save_ply_property_batch(data[:, :, 0:3], labels_pred[...],
                                               filepath_label_ply, data_num[...], setting.num_class)
            timer_stop(msg)
        ######################################################################
        
    print('{}-Done!'.format(datetime.now()))



INFO:tensorflow:Restoring parameters from ../models/seg/pointcnn_seg_amsterdam_x4_12288_fps_2019-02-05-22-40-26_20662/ckpts/iter-25173
2019-03-08 14:02:09.260587-Checkpoint loaded from ../models/seg/pointcnn_seg_amsterdam_x4_12288_fps_2019-02-05-22-40-26_20662/ckpts/iter-25173!
Reading ../data/Amsterdam/./tmp/utrecht_las_00107_zero_0.h5... completed in 1ms
2019-03-08 14:02:09.503222-385 testing batches.
2019-03-08 14:02:40.138440-Processing 100 of 385 batches.
2019-03-08 14:03:07.840315-Processing 200 of 385 batches.
2019-03-08 14:03:34.685784-Processing 300 of 385 batches.
2019-03-08 14:03:57.222791-Processing 384 of 385 batches.
Preprocessing completed in 547ms
Inference completed in 35853ms
Postprocessing completed in 72219ms
Saving ../data/Amsterdam/./tmp/utrecht_las_00107_zero_0_pred.h5... completed in 122ms
Reading ../data/Amsterdam/./tmp/utrecht_las_00107_half_0.h5... completed in 0ms
2019-03-08 14:03:57.904224-395 testing batches.
2019-03-08 14:04:25.874746-Processing 100 of 39