In [None]:
import tensorflow as tf
import numpy as np
from utils import pc_io
from tqdm.notebook import tqdm, trange
import os
from datetime import datetime
from pathlib import Path
from utils.perceptual_model import PerceptualModel, input_fn
from utils.distance_grid import distance_grid_3d

In [None]:
# Constants and parameters
train_glob = '../ModelNet40_200_pc512_oct3_4k/**/*.ply'

# Each block is a resolution^3 grid with a single leading channel axis
# (channels_first). Derived from `resolution` so the two cannot drift apart.
resolution = 64
dense_tensor_shape = np.array([1, resolution, resolution, resolution])
# In 'binary' mode alpha/gamma are the focal-loss parameters; in 'tdf' mode
# alpha weights voxels whose distance is < 1 and 1 - alpha weights the rest.
alpha = 0.75
gamma = 2.0
data_format = 'channels_first'
# Data representation fed to the model:
#   'binary': binary occupancy map
#   'tdf':    truncated distance field
# To reproduce results, two models should be trained: one with
# data_mode = 'binary' and one with data_mode = 'tdf'. The checkpoint
# directory follows from data_mode, so only that one line needs to change.
data_mode = 'tdf'
checkpoint_dirs = {'binary': 'data/model', 'tdf': 'data/model_tdf'}
if data_mode not in checkpoint_dirs:
    raise RuntimeError(f'Unknown data mode {data_mode}')
checkpoint_dir = checkpoint_dirs[data_mode]
# tdf upper bound (passed as `ub` to distance_grid_3d)
tdf_ub = 3
# Training
batch_size = 32
# Validate (and possibly checkpoint) every validation_interval steps.
validation_interval = 500
# Stop once this many steps pass without a validation improvement (4 rounds).
early_stop_patience = validation_interval * 4
# Number of validation batches averaged per validation round.
validation_steps = 8
# Write training summaries every summary_interval steps.
summary_interval = 250
max_steps = 100000

In [None]:
# Load dataset
files = pc_io.get_files(train_glob)
# Fail with an actionable message: an empty match usually means a wrong
# relative path or a dataset that has not been generated yet.
assert len(files) > 0, f'No point cloud files matched {train_glob!r}'
points = pc_io.load_points(files)
print(f'Loaded {len(files)} files')

In [None]:
# The name of each file's parent directory is its split label; files under
# 'train' are used for training and files under 'test' for validation.
files_cat = np.array([Path(f).parent.name for f in files])
points_train = points[files_cat == 'train']
points_val = points[files_cat == 'test']

In [None]:
# Start from an empty default graph so that re-running the cells below does not
# add duplicate ops to a graph left over from a previous run of the notebook.
tf.reset_default_graph()

In [None]:
if data_mode == 'tdf':
    def _to_tdf_block(pc):
        """Run distance_grid_3d on one point cloud; return float32 with a leading channel axis."""
        grid = distance_grid_3d(pc, dense_tensor_shape[1:], ub=tdf_ub)
        return grid.astype(np.float32)[np.newaxis]

    points_train_tdf = [_to_tdf_block(pc) for pc in tqdm(points_train)]
    points_val_tdf = [_to_tdf_block(pc) for pc in tqdm(points_val)]

In [None]:
def input_fn_tdf(points, batch_size, dense_tensor_shape, data_format, repeat=True, shuffle=True, prefetch_size=1):
    """Create a batched tf.data input pipeline over precomputed TDF blocks.

    Args:
        points: sequence of float32 arrays, each of shape `dense_tensor_shape`.
        batch_size: number of blocks per batch.
        dense_tensor_shape: static shape of a single element.
        data_format: accepted for signature parity with `input_fn`; not used here.
            NOTE(review): elements are emitted as given, so they must already be
            laid out in the desired format.
        repeat: when True, loop over the data indefinitely.
        shuffle: when True, shuffle with a buffer that holds the whole dataset.
        prefetch_size: number of batches to prefetch.

    Returns:
        A `tf.data.Dataset` yielding float32 batches.
    """
    element_shape = tf.TensorShape(dense_tensor_shape)
    # Build the generator-backed pipeline on the CPU.
    with tf.device('/cpu:0'):
        pipeline = tf.data.Dataset.from_generator(lambda: iter(points), tf.float32, element_shape)
        if shuffle:
            pipeline = pipeline.shuffle(len(points))
        if repeat:
            pipeline = pipeline.repeat()
        pipeline = pipeline.batch(batch_size).prefetch(prefetch_size)
    return pipeline

In [None]:
# The mode only selects the dataset builder and its source data; both splits
# are then built the same way (repeated and shuffled).
if data_mode == 'binary':
    make_dataset, train_source, val_source = input_fn, points_train, points_val
elif data_mode == 'tdf':
    make_dataset, train_source, val_source = input_fn_tdf, points_train_tdf, points_val_tdf
else:
    raise RuntimeError(f'Unknown data mode {data_mode}')

train_ds = make_dataset(train_source, batch_size, dense_tensor_shape, data_format, repeat=True, shuffle=True)
val_ds = make_dataset(val_source, batch_size, dense_tensor_shape, data_format, repeat=True, shuffle=True)

In [None]:
# Initialize input pipeline
# Feedable iterator: the same `x` tensor reads from the training or the
# validation dataset depending on which string handle is fed to `handle`.
train_iterator = tf.data.make_one_shot_iterator(train_ds)
val_iterator = tf.data.make_one_shot_iterator(val_ds)
handle = tf.placeholder(tf.string, shape=[], name='handle')
output_shapes = tf.data.get_output_shapes(train_ds)
iterator = tf.data.Iterator.from_string_handle(handle, tf.data.get_output_types(train_ds))
# placeholder_with_default: `x` pulls from the iterator unless a value for it
# is fed directly in a feed_dict.
x = tf.placeholder_with_default(iterator.get_next(), output_shapes, name='x')

In [None]:
# Build the model graph; `x_tilde` is the model output that the training loss
# compares against the input `x`.
model = PerceptualModel()
x_tilde = model(x)

In [None]:
# Init optimization
if data_mode == 'binary':
    from utils.focal_loss import focal_loss
    train_loss = focal_loss(x, x_tilde, alpha=alpha, gamma=gamma)
elif data_mode == 'tdf':
    # Per-voxel weights: `alpha` where the distance is < 1, `1 - alpha` elsewhere.
    # A convex blend of the 0/1 indicator gives these weights for any alpha in
    # [0, 1]; clamping the indicator with max(min(ind, alpha), 1 - alpha) only
    # works for alpha >= 0.5 and degenerates to a constant 1 - alpha otherwise.
    dist_lt_1 = tf.cast(x < 1, tf.float32)
    loss_mask = alpha * dist_lt_1 + (1 - alpha) * (1 - dist_lt_1)
    train_loss = tf.reduce_mean(tf.square(x - x_tilde) * loss_mask)
else:
    raise RuntimeError(f'Unknown data mode {data_mode}')

step = tf.train.get_or_create_global_step()
main_optimizer = tf.train.AdamOptimizer()
train_op = main_optimizer.minimize(train_loss, global_step=step)

In [None]:
graph = tf.get_default_graph()
model_name = model.name
# Internal activations to log as histograms, keyed by summary name.
activation_tensor_names = {
    'x1': f'{model_name}/conv3d_0/Relu:0',
    'x2': f'{model_name}/conv3d_1/Relu:0',
    'y0': f'{model_name}/conv3d_2/Relu:0',
    'y1': f'{model_name}/conv3dt_0/Relu:0',
    'y2': f'{model_name}/conv3dt_1/Relu:0',
    'x_tilde': f'{model_name}/conv3dt_2/Sigmoid:0',
}
tensors = {'x': x}
for key, tensor_name in activation_tensor_names.items():
    tensors[key] = graph.get_tensor_by_name(tensor_name)
for name, tensor in tensors.items():
    tf.summary.histogram(name, tensor)
tf.summary.scalar('train_loss', train_loss)
merged_summary = tf.summary.merge_all()

In [None]:
# The full graph (iterator, optimizer and summary ops) has far too many nodes to
# read; keep the complete list in a variable and display only the op names under
# the model's scope, which are the ones the summary cell looks up by name.
all_node_names = [n.name for n in tf.get_default_graph().as_graph_def().node]
[node_name for node_name in all_node_names if node_name.startswith(f'{model.name}/')]

In [None]:
# Train model
# Summary writers
train_writer = tf.summary.FileWriter(os.path.join(checkpoint_dir, 'train'))
val_writer = tf.summary.FileWriter(os.path.join(checkpoint_dir, 'val'))

# Checkpoints
saver = tf.train.Saver(save_relative_paths=True)
checkpoint_path = os.path.join(checkpoint_dir, 'model.ckpt')
# Init
init = tf.global_variables_initializer()

print('Starting session')
tf_config = tf.ConfigProto()
tf_config.gpu_options.allow_growth = True
try:
    with tf.Session(config=tf_config) as sess:
        print('Init session')
        sess.run(init)

        # String handles select which dataset feeds `x` through the `handle` placeholder.
        train_handle, test_handle = sess.run([train_iterator.string_handle(), val_iterator.string_handle()])

        # Resume from the latest checkpoint in checkpoint_dir, if there is one.
        checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
        if checkpoint is not None:
            print(f'Restoring checkpoint {checkpoint}')
            saver.restore(sess, checkpoint)
        train_writer.add_graph(sess.graph)

        step_val = sess.run(step)
        first_step_val = step_val
        # Context manager guarantees the progress bar is closed, even on interrupt.
        with tqdm(total=max_steps) as pbar:
            print('Starting training')
            # Effectively +inf, so the first validation round always checkpoints.
            best_loss = 1e+32
            best_loss_step = step_val
            while step_val <= max_steps:
                pbar.update(step_val - pbar.n)

                # Validation (skipped at the step training started or resumed from)
                if step_val != first_step_val and step_val % validation_interval == 0:
                    print(f'{datetime.now().isoformat()} Executing validation')
                    losses = []
                    for i in trange(validation_steps):
                        summary, vloss = sess.run([merged_summary, train_loss], feed_dict={handle: test_handle})
                        losses.append(vloss)
                        val_writer.add_summary(summary, step_val + i)
                    loss = np.mean(losses)
                    print('')

                    # Early stopping: an improvement must lower the loss by more
                    # than 0.1% relative to the best loss so far.
                    if (loss - best_loss) / best_loss < -1e-3:
                        print(f'Val loss {loss:.3E}@{step_val} lower than previous best {best_loss:.3E}@{best_loss_step}')
                        best_loss_step = step_val
                        best_loss = loss
                        save_path = saver.save(sess, checkpoint_path, global_step=step_val)
                        print(f'Model saved to {save_path}')
                    elif step_val - best_loss_step >= early_stop_patience:
                        print(f'Val loss {loss:.3E}@{step_val} higher than previous best {best_loss:.3E}@{best_loss_step}')
                        print('Early stopping')
                        break
                    else:
                        print(f'Val loss {loss:.3E}@{step_val} higher than previous best {best_loss:.3E}@{best_loss_step}')

                # Do not train past max_steps: an update taken here would never
                # be validated or checkpointed.
                if step_val >= max_steps:
                    break

                # Training
                get_summary = step_val % summary_interval == 0
                sess_args = {'train_op': train_op, 'train_loss': train_loss}
                if get_summary:
                    sess_args['merged_summary'] = merged_summary
                sess_output = sess.run(sess_args, feed_dict={handle: train_handle})

                step_val += 1
                if get_summary:
                    train_writer.add_summary(sess_output['merged_summary'], step_val)

                pbar.set_description(f"loss: {sess_output['train_loss']:.3E}")
finally:
    # Flush pending events to disk; FileWriter otherwise only flushes periodically,
    # so an interrupted run could lose its most recent summaries.
    train_writer.close()
    val_writer.close()


In [None]:
# Mark the checkpoint directory as finished training.
(Path(checkpoint_dir) / 'done').touch()