In [1]:
import numpy as np
import os.path as osp

from general_tools.in_out.basics import create_dir
from tf_lab.evaluate.generative_pc_nets import entropy_of_occupancy_grid, jensen_shannon_divergence
from tf_lab.nips.helper import center_pclouds_in_unit_sphere, pclouds_centered_and_half_sphere, pclouds_with_zero_mean_in_unit_sphere

In [2]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [3]:
n_pc_samples = 2048
cmp_in_sphere = True
voxel_resolution = 28
class_name = 'chair'
save_res = False

def identity(x):
    return x

# pc_normalizer = pclouds_with_zero_mean_in_unit_sphere
pc_normalizer = identity

In [26]:
is_lwgan = False
is_rgan = False
is_lgan = True

assert(np.sum([is_lwgan, is_rgan, is_lgan]) == 1)

if is_lwgan:
    epochs_to_check = np.hstack([np.array([1, 5, 10]), np.arange(100, 2001, 100)])
    gan_tag = 'l_wgan_gp'
elif is_rgan:
    epochs_to_check = np.hstack([np.array([1, 5, 10]), np.arange(50, 2001, 50)])
    gan_tag = 'r_gan'
elif is_lgan:
    epochs_to_check = np.hstack([np.array([1, 5, 10]), np.arange(100, 2001, 100)])
    gan_tag = 'l_gan'
    

epochs_to_check = [1000]

In [7]:
# Load Ground-Truth Train Data.
top_gt_dir = '/orions4-zfs/projects/optas/DATA/OUT/iclr/evaluations/gt_data/'
gt_train_file = osp.join(top_gt_dir, class_name + '_train.npz')
gt_train_data = np.load(gt_train_file)  
gt_train_data = pc_normalizer(gt_train_data[gt_train_data.keys()[0]])
_, gt_train_var = entropy_of_occupancy_grid(gt_train_data, voxel_resolution, in_sphere=cmp_in_sphere)    

# Load Ground-Truth Test Data.
gt_test_file = osp.join(top_gt_dir, class_name + '_test.npz')
gt_test_data = np.load(gt_test_file)  
gt_test_data = pc_normalizer(gt_test_data[gt_test_data.keys()[0]])
_, gt_test_var = entropy_of_occupancy_grid(gt_test_data, voxel_resolution, in_sphere=cmp_in_sphere)

In [27]:
# Specify where to load synthetic data.
top_in_dir = '/orions4-zfs/projects/optas/DATA/OUT/iclr/synthetic_samples/'

if is_lwgan:
    top_synthetic_dir = osp.join(top_in_dir, 'l_w_gan_chair_disc_512_1024_emd_bneck_128_this_was_also_using_3pc_ae/lam_10/')
    special_tag = 'chair_disc_512_1024_emd_bneck_128_lam_10'
elif is_rgan:
    top_synthetic_dir = osp.join(top_in_dir, 'r_gan/chair_test_raw_gan_2048_pts')
#                                  chair_mlp_disc_4_fc_gen_raw_gan_2048_pts/')    
    special_tag = 'chair_mlp_disc_4_fc_gen_raw_gan_2048_pts'
    
elif is_lgan:        
    top_synthetic_dir = osp.join(top_in_dir, 'l_gan_chair_mlp_with_split_1pc_usampled_bnorm_on_encoder_only_emd_bneck_128')
    special_tag = 'l_gan_chair_mlp_with_split_1pc_usampled_bnorm_on_encoder_only_emd_bneck_128'
    
top_out_dir = '/orions4-zfs/projects/optas/DATA/OUT/iclr/evaluations/jsd'
create_dir(top_out_dir)

'/orions4-zfs/projects/optas/DATA/OUT/iclr/evaluations/jsd'

In [28]:
if save_res:
    out_file = '_'.join([gan_tag, special_tag, pc_normalizer.__name__])
    out_file = osp.join(top_out_dir, out_file + '.txt')
    fout = open(out_file, 'w', 1)
    fout.write('#Metric Epoch (Train-Test) Measurements\n')
    print 'Saving measurements at: ' + out_file
    
for epoch in epochs_to_check:
    sample_file = osp.join(top_synthetic_dir, 'epoch_%d.npz' % (epoch,) )
    sample_data = np.load(sample_file)
    sample_data = sample_data[sample_data.keys()[0]]
    sample_data = pc_normalizer(sample_data)
                    
    _, sample_var = entropy_of_occupancy_grid(sample_data, voxel_resolution, in_sphere=cmp_in_sphere)        
    jsd_train = jensen_shannon_divergence(gt_train_var, sample_var)
    jsd_test = jsd_train
#     jensen_shannon_divergence(gt_test_var, sample_var)
        
    log_data = 'JSD %d %f %f' % (epoch, jsd_train, jsd_test)
    print log_data
    if save_res:
        fout.write(log_data + '\n')
if save_res:
    fout.close()

JSD 1000 0.076609 0.076609


In [None]:
# OLD STUFF
# import warnings
# import os
# # Ignore TF related warnings.
# warnings.filterwarnings("ignore")
# os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

# parser = argparse.ArgumentParser()
# parser.add_argument('--sample_dir', type=str, default = '', help='Directory of point-cloud samples.', required=True)
# parser.add_argument('--ref', type=str, default = '', help='Path to reference point-cloud.', required=True)
# parser.add_argument('--out_file', type=str, help='Save results in this file.', required=True)
# parser.add_argument('--epochs', type=list, default = [1, 3, 10, 30, 100, 300, 400, 500], help='Epochs to evaluate.')
# opt = parser.parse_args()