In [1]:
import tensorflow as tf
import numpy as np
import warnings
import os.path as osp
import time

from scipy.stats import entropy
from general_tools.simpletons import iterate_in_chunks

from tf_lab.external.Chamfer_EMD_losses.tf_nndistance import nn_distance
from tf_lab.external.Chamfer_EMD_losses.tf_approxmatch import approx_match, match_cost

import time

In [20]:
def minimum_mathing_distance(sample_pcs, ref_pcs, batch_size, normalize=False, sess=None, verbose=False, use_sqrt=False, use_EMD=False):
    ''' normalize (boolean): if True the Chamfer distance between two point-clouds is the average of matched
                             point-distances. Alternatively, is their sum.
    '''
    s = time.time()
    if normalize:
        reducer = tf.reduce_mean
    else:
        reducer = tf.reduce_sum

    if sess is None:
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        sess = tf.Session(config=config)

    n_ref, n_pc_points, pc_dim = ref_pcs.shape
    _, n_pc_points_s, pc_dim_s = sample_pcs.shape

    if n_pc_points != n_pc_points_s or pc_dim != pc_dim_s:
        raise ValueError('Incompatible Point-Clouds.')

    # TF Graph Operations
    ref_pl = tf.placeholder(tf.float32, shape=(1, n_pc_points, pc_dim))
    sample_pl = tf.placeholder(tf.float32, shape=(None, n_pc_points, pc_dim))

#     repeat_times = tf.shape(sample_pl)[0]   # slower- could be used to use entire set of samples.
    repeat_times = batch_size
    ref_repeat = tf.tile(ref_pl, [repeat_times, 1, 1])
    ref_repeat = tf.reshape(ref_repeat, [repeat_times, n_pc_points, pc_dim])

    if not use_EMD:
        ref_to_s, _, s_to_ref, _ = nn_distance(ref_repeat, sample_pl)

        if use_sqrt:
            ref_to_s = tf.sqrt(ref_to_s)
            s_to_ref = tf.sqrt(s_to_ref)

        chamfer_dist_batch = reducer(ref_to_s, 1) + reducer(s_to_ref, 1)
    else:
        match = approx_match(ref_repeat, sample_pl)
        chamfer_dist_batch = reducer(match_cost(ref_repeat, sample_pl, match))

    best_in_batch = tf.reduce_min(chamfer_dist_batch)   # Best distance, of those that were matched to single ref pc.
    print time.time()-s
    matched_dists = []
    for i in xrange(n_ref):
        best_in_all_batches = []
        if verbose and i % 50 == 0:
            print i
        for sample_chunk in iterate_in_chunks(sample_pcs, batch_size):
            if len(sample_chunk) != batch_size:
                continue
            feed_dict = {ref_pl: np.expand_dims(ref_pcs[i], 0), sample_pl: sample_chunk}
            b = sess.run(best_in_batch, feed_dict=feed_dict)
            best_in_all_batches.append(b)

        matched_dists.append(np.min(best_in_all_batches))

    mmd = np.mean(matched_dists)
    sess.close()
    return mmd, matched_dists

In [14]:
a = np.random.randn(100, 2048, 3)
b = np.random.randn(100, 2048, 3)


In [21]:
# batch_size, normalize=False, sess=None, verbose=False, use_sqrt=False, use_EMD=False):
minimum_mathing_distance(a, b, 10, use_EMD=True)

0.0418698787689


(7030.9985,
 [6985.9863,
  7057.9023,
  7506.1123,
  6847.8872,
  7092.9609,
  7093.8433,
  7016.083,
  6828.1416,
  7064.2026,
  7064.2549,
  7090.9551,
  7193.2124,
  6991.4473,
  7210.5063,
  7040.8296,
  6946.7637,
  7059.5664,
  6902.0117,
  6844.4365,
  6846.5977,
  6935.8066,
  7035.7363,
  7070.8657,
  7258.5732,
  7195.7227,
  6931.6025,
  7140.8228,
  6960.124,
  7052.2197,
  6952.4004,
  6794.0679,
  7360.8062,
  7001.1997,
  6977.5947,
  6973.3037,
  7047.3091,
  7049.1719,
  7033.8516,
  6993.2773,
  7025.46,
  7158.8623,
  6870.4238,
  7018.7314,
  7057.2939,
  7042.9062,
  7088.4648,
  6913.9189,
  6997.7402,
  6946.5645,
  7050.0186,
  7168.9268,
  7239.6753,
  6932.8857,
  7014.0449,
  6781.6182,
  7123.126,
  7291.2974,
  6900.0908,
  6990.958,
  7014.8999,
  6950.6646,
  7166.1533,
  7204.5195,
  7126.0498,
  6969.6074,
  7170.2051,
  7153.6201,
  6890.3726,
  6840.1475,
  6939.8008,
  7041.8682,
  6823.1279,
  6933.084,
  7059.9092,
  6968.498,
  6966.4517,
  7009.9