In [1]:
from general_tools.notebook.gpu_utils import setup_one_gpu
# Pick the GPU before TensorFlow is imported in the next cell.
# NOTE(review): presumably setup_one_gpu restricts the visible devices (e.g. via
# CUDA_VISIBLE_DEVICES) -- confirm; if so it must keep running before `import tensorflow`.
GPU = 0  # index of the GPU this notebook should use
setup_one_gpu(GPU)

Picking GPU 0


In [2]:
import sys
import numpy as np
import os.path as osp
import tensorflow as tf
import matplotlib.pyplot as plt
from sklearn import mixture

from geo_tool import Point_Cloud

from general_tools.notebook.tf import reset_tf_graph
from general_tools.in_out.basics import create_dir

from tf_lab.in_out.basics import Data_Splitter, read_saved_epochs
from tf_lab.point_clouds.ae_templates import mlp_architecture_ala_iclr_18, default_train_params
from tf_lab.point_clouds.autoencoder import Configuration as Conf
from tf_lab.point_clouds.point_net_ae import PointNetAutoEncoder
from tf_lab.point_clouds.in_out import load_point_clouds_from_filenames, PointCloudDataSet
from tf_lab.point_clouds.convenience import reconstruct_pclouds, get_latent_codes

from tf_lab.data_sets.shape_net import pc_loader as snc_loader
from tf_lab.data_sets.shape_net import snc_category_to_synth_id
from tf_lab.nips.helper import pclouds_centered_and_half_sphere
from tf_lab.iclr.helper import load_multiple_version_of_pcs, find_best_validation_epoch_from_train_stats

from pcloud_benchmark.evaluate_gan import entropy_of_occupancy_grid, jensen_shannon_divergence

PyTorch not working. MMD measurement won't be available


In [3]:
%load_ext autoreload
%matplotlib inline
%autoreload 2

In [4]:
# Base directory; trained auto-encoders are looked up underneath it.
top_data_dir = '/orions4-zfs/projects/optas/DATA'

# Shape category under study and its ShapeNet synth id.
class_name = 'chair'
syn_id = snc_category_to_synth_id()[class_name]

# Auto-encoder selection: points per cloud, and how many versions of each
# point cloud the AE to load was trained with.
n_pc_points = 2048
n_pc_versions = 3

# Occupancy-grid settings used by every JSD evaluation in this notebook.
voxel_resolution = 28
cmp_in_sphere = True  # forwarded as `in_sphere=` to entropy_of_occupancy_grid

In [5]:
# Load the point clouds of the selected class and keep the train / test splits.
in_data = load_multiple_version_of_pcs('uniform_one', syn_id, n_classes=1)
train_data, test_data = (in_data[split].point_clouds for split in ('train', 'test'))

# Ground truth for the JSD comparisons: same normalisation that is applied to the
# reconstructed and GMM-generated clouds before their grids are computed.
gt_train_data, gt_test_data = map(pclouds_centered_and_half_sphere, (train_data, test_data))

Loading test data.
/orions4-zfs/projects/optas/DATA/Point_Clouds/Shape_Net/Splits/single_class_splits/03001627/85_5_10/test.txt
679 pclouds were loaded. They belong in 1 shape-classes.
Loading train data.
/orions4-zfs/projects/optas/DATA/Point_Clouds/Shape_Net/Splits/single_class_splits/03001627/85_5_10/train.txt
5761 pclouds were loaded. They belong in 1 shape-classes.
Loading val data.
/orions4-zfs/projects/optas/DATA/Point_Clouds/Shape_Net/Splits/single_class_splits/03001627/85_5_10/val.txt
338 pclouds were loaded. They belong in 1 shape-classes.


In [6]:
# Occupancy-grid statistic (second return value of entropy_of_occupancy_grid) of the
# ground-truth train and test clouds. The +1 offset -- presumably so no cell is zero
# when the JSD is taken -- is also applied to the reconstructed / GMM grids later on.
train_grid_var, test_grid_var = (
    entropy_of_occupancy_grid(gt_pcs, voxel_resolution, in_sphere=cmp_in_sphere)[1] + 1
    for gt_pcs in (gt_train_data, gt_test_data)
)

In [8]:
# Which auto-encoders to evaluate.
ae_loss = 'chamfer'      # reconstruction loss, part of the model id
b_necks = [128, 256]     # bottleneck sizes

# GMM sweep. Only 'diag' has its own list of component counts; every other
# covariance type uses full_n_clusters.
cov_types = ['full']
full_n_clusters = list(range(12, 17, 2))   # [12, 14, 16]; 15 and 17-20 were also tried earlier
diag_n_clusters = list(range(2, 17, 2))    # [2, 4, ..., 16]

In [7]:
def load_an_auto_encoder(b_neck, ae_loss, n_pc_versions, n_pc_points=2048):
    """Restore a trained PointNet auto-encoder at its best validation epoch.

    The model directory is built from the module-level ``top_data_dir`` and
    ``class_name`` together with the arguments below.

    Args:
        b_neck: bottleneck size, as encoded in the model id.
        ae_loss: reconstruction-loss name, as encoded in the model id (e.g. 'chamfer').
        n_pc_versions: number of point-cloud versions the AE was trained with.
        n_pc_points: number of points per cloud, as encoded in the model id.

    Returns:
        A ``PointNetAutoEncoder`` restored at the epoch reported by
        ``find_best_validation_epoch_from_train_stats``, rounded up to a multiple of
        ``saver_step``. ``reset_tf_graph()`` is called before the model is built.
    """
    ae_experiment_tag = 'mlp_with_split_' + str(n_pc_versions) + 'pc_usampled_bnorm_on_encoder_only'
    ae_id = '_'.join(['ae', class_name, ae_experiment_tag, str(n_pc_points), 'pts', str(b_neck), 'bneck', ae_loss])

    ae_train_dir = osp.join(top_data_dir, 'OUT/iclr/nn_models/', ae_id)
    ae_conf = Conf.load(osp.join(ae_train_dir, 'configuration'))

    _, best_epoch = find_best_validation_epoch_from_train_stats(osp.join(ae_train_dir, 'train_stats.txt'))

    # Snap to a multiple of saver_step (presumably the checkpointing period) by rounding
    # UP. Adding the remainder itself, as before, does not yield a multiple in general
    # (487 -> 494 for a step of 10).
    remainder = best_epoch % ae_conf.saver_step
    if remainder != 0:
        best_epoch += ae_conf.saver_step - remainder

    # Silence per-layer logging while the graph is rebuilt.
    ae_conf.encoder_args['verbose'] = False
    ae_conf.decoder_args['verbose'] = False

    reset_tf_graph()
    ae = PointNetAutoEncoder(ae_conf.experiment_name, ae_conf)
    # NOTE(review): restores from ae_conf.train_dir, not from ae_train_dir where the
    # configuration was read -- confirm these agree if model folders are ever moved.
    ae.restore_model(ae_conf.train_dir, best_epoch, verbose=True)
    return ae

def jsd_on_reconstructed_data(ae_model, pclouds, cmp_grid_var, voxel_resolution, cmp_in_sphere):
    """Jensen-Shannon divergence between AE reconstructions and a reference grid.

    ``pclouds`` are reconstructed by ``ae_model`` (batches of 100), normalised with
    ``pclouds_centered_and_half_sphere``, turned into an occupancy-grid statistic at
    ``voxel_resolution`` (``in_sphere=cmp_in_sphere``), offset by +1, and compared
    against ``cmp_grid_var`` with ``jensen_shannon_divergence``.
    """
    reconstructions, _unused = reconstruct_pclouds(ae_model, pclouds, batch_size=100)
    normalised = pclouds_centered_and_half_sphere(reconstructions)
    _entropy, grid_stats = entropy_of_occupancy_grid(normalised, voxel_resolution, in_sphere=cmp_in_sphere)
    grid_stats += 1  # same +1 offset the reference grids receive
    return jensen_shannon_divergence(grid_stats, cmp_grid_var)

In [9]:
# For every bottleneck size: (1) report how well the AE reconstructions match the GT
# occupancy statistics, then (2) fit GMMs on the train latent codes, decode samples
# drawn from them and compare their occupancy statistics with train / test GT (JSD).
gmm_seed = 0  # fixes GMM initialisation AND sampling so the sweep is reproducible

for b_neck in b_necks:
    ae_model = load_an_auto_encoder(b_neck, ae_loss, n_pc_versions, n_pc_points=n_pc_points)
    latent_codes = get_latent_codes(ae_model, train_data)

    # Single-argument print(...) produces the same text under Python 2 and 3.
    print('bneck size: %s' % b_neck)
    j1 = jsd_on_reconstructed_data(ae_model, train_data, train_grid_var, voxel_resolution, cmp_in_sphere)
    j2 = jsd_on_reconstructed_data(ae_model, test_data, test_grid_var, voxel_resolution, cmp_in_sphere)
    print('Train-Test JSD of the AE-decoded data: %s %s' % (j1, j2))

    for cov_t in cov_types:
        # Only 'diag' has its own list; any other covariance type uses full_n_clusters.
        choose_from = diag_n_clusters if cov_t == 'diag' else full_n_clusters

        for n_cluster in choose_from:
            # Keyword arguments: recent scikit-learn releases reject hyper-parameters
            # passed positionally after the first one.
            gmm = mixture.GaussianMixture(n_components=n_cluster, covariance_type=cov_t,
                                          random_state=gmm_seed)
            gmm.fit(latent_codes)
            # Draw as many synthetic codes as there are training shapes and decode them.
            sample_codes = gmm.sample(len(latent_codes))[0]
            gmm_pcs = pclouds_centered_and_half_sphere(ae_model.decode(sample_codes))
            _, gmm_grid_var = entropy_of_occupancy_grid(gmm_pcs, voxel_resolution, in_sphere=cmp_in_sphere)
            gmm_grid_var += 1  # same +1 offset the GT grids received

            tr_jsd = jensen_shannon_divergence(gmm_grid_var, train_grid_var)
            te_jsd = jensen_shannon_divergence(gmm_grid_var, test_grid_var)
            # Columns: covariance type, n components, JSD vs train, JSD vs test, BIC, AIC.
            print(' '.join(str(v) for v in (cov_t, n_cluster, tr_jsd, te_jsd,
                                            gmm.bic(latent_codes), gmm.aic(latent_codes))))

Model restored in epoch 490.
bneck size: 128
Train-Test JSD of the AE-decoded data: 0.0187125541936 0.02083304752
full 12 0.0210072940229 0.024661735538 -3994511.48619 -4664519.95944
full 14 0.0209798193459 0.0241189138689 -3915769.66441 -4697447.32633
full 16 0.0204943885502 0.0237265109537 -3824222.41814 -4717569.26876
Model restored in epoch 500.
bneck size: 256
Train-Test JSD of the AE-decoded data: 0.0183122861793 0.020093400088
full 12 0.0206991891001 0.0235051755183 -6903476.4561 -9552606.55035
full 14 0.0210930908708 0.0242374984359 -6523038.32507 -9613691.2115
full 16 0.0207186018536 0.0238218711497 -6129024.40855 -9661200.08716
