In [2]:
import os
import os.path as osp
import sys
import time

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

import tf_lab.point_clouds.in_out as pio
from tf_lab.point_clouds.in_out import PointCloudDataSet
from tf_lab.point_clouds.vae import VariationalAutoencoder
from tf_lab.point_clouds.autoencoder import Configuration as Conf
import tf_lab.models.point_net_based_AE as pnAE
from tf_lab.fundamentals.utils import set_visible_GPUs

from general_tools.in_out.basics import create_dir, delete_files_in_directory
from general_tools.in_out.basics import files_in_subdirs
from geo_tool import Point_Cloud

from tf_lab.autopredictors.scripts.helper import shape_net_category_to_synth_id



In [8]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [4]:
# Experiment identity and global seeds.
experiment_name = 'chamfer_full_shapes_all_SN_vae'

seed = 42
np.random.seed(seed)
# NOTE: tf.set_random_seed seeds only the *current* default graph; any later
# tf.reset_default_graph() discards it, so re-seed after every graph reset.
tf.set_random_seed(seed)

# Root of the data tree. Set the TOP_DATA_DIR environment variable to run
# outside the original cluster; when unset/empty the original path is used.
top_data_dir = os.environ.get('TOP_DATA_DIR') or '/orions4-zfs/projects/lins2/Panos_Space/DATA/'
full_pclouds_path = osp.join(top_data_dir, 'ShapeNetPointClouds/from_manifold_meshes/1024/')

# Output directory for this experiment: <top>/OUT/models/variational/<experiment_name>.
train_dir = osp.join(top_data_dir, 'OUT/models/variational', experiment_name)
create_dir(train_dir)

'/orions4-zfs/projects/lins2/Panos_Space/DATA/OUT/models/variational/chamfer_full_shapes_all_SN_vae'

In [5]:
# Load every complete point cloud stored under full_pclouds_path.
full_file_names = pio.load_filenames_of_input_data(full_pclouds_path)
full_pclouds, full_model_names, class_ids = pio.load_crude_point_clouds(file_names=full_file_names, n_threads=11)
# Fail fast on a wrong or unmounted data path instead of training on nothing.
assert len(full_pclouds) > 0, 'No point clouds found under %s' % (full_pclouds_path, )
# Single-argument parenthesised print behaves identically on Python 2 and 3.
print('%d files containing complete point clouds were found.' % (len(full_pclouds), ))

52103 files containing complete point clouds were found.


In [6]:
# 95% train / 5% test split of the (point cloud, model name) pairs, seeded for
# reproducibility. The validation share is 0, so val_data_ is presumably empty.
train_data_, val_data_, test_data_ = pio.train_validate_test_split(
    [full_pclouds, full_model_names],
    train_perc=0.95,
    validate_perc=0.0,
    test_perc=0.05,
    seed=seed)

# Each split is indexed as [point clouds, model names]; the model names are
# attached as the labels of the resulting PointCloudDataSet.
train_data = PointCloudDataSet(train_data_[0], labels=train_data_[1])
test_data = PointCloudDataSet(test_data_[0], labels=test_data_[1])

In [None]:
def reset_graph():
    """Close a notebook-level ``sess`` (if any) and clear the TF default graph.

    Looks up a module-global named ``sess``; when it exists and is truthy its
    ``close()`` is called. Afterwards ``tf.reset_default_graph()`` drops every
    op/variable built by earlier runs of the training cell.
    """
    existing_session = globals().get('sess')
    if existing_session:
        existing_session.close()
    tf.reset_default_graph()

# Presumably restricts TensorFlow to GPU 0 (tf_lab helper) — TODO confirm.
set_visible_GPUs([0])

# Hyper-parameters for the PointNet-based VAE trained with the Chamfer loss.
conf = Conf(n_input=[1024, 3],
            training_epochs=1000,
            batch_size=45,
            loss='chamfer',
            train_dir=train_dir,
            loss_display_step=1,
            saver_step=5,
            learning_rate=0.00005,
            saver_max_to_keep=500,
            gauss_augment={'mu': 0, 'sigma': 0.02},
            encoder=pnAE.encoder,
            decoder=pnAE.decoder,
            spatial_trans=True,
            denoising=False,
            n_z=1024,
            latent_vs_recon=100,
            z_rotate=True)

# Start from a fresh default graph so re-running this cell does not stack a
# second model on top of the first one.
reset_graph()
# tf.reset_default_graph() throws away the graph-level seed set earlier by
# tf.set_random_seed, so without this the model (presumably built in the
# default graph) would be initialised unseeded. numpy is re-seeded too so that
# re-running only this cell starts from the same random state.
np.random.seed(seed)
tf.set_random_seed(seed)

vae = VariationalAutoencoder(experiment_name, conf)
vae.train(train_data, conf)

('Epoch:', '0001', 'training time (minutes)=', '1.9959', 'loss=', '13.434676707')
('Epoch:', '0002', 'training time (minutes)=', '1.9957', 'loss=', '0.008600579')
('Epoch:', '0003', 'training time (minutes)=', '2.0012', 'loss=', '0.008249746')
('Epoch:', '0004', 'training time (minutes)=', '2.0025', 'loss=', '0.008151806')
('Epoch:', '0005', 'training time (minutes)=', '2.0067', 'loss=', '0.008115648')
('Epoch:', '0006', 'training time (minutes)=', '2.0008', 'loss=', '0.008099382')
('Epoch:', '0007', 'training time (minutes)=', '2.0034', 'loss=', '0.008085662')
('Epoch:', '0008', 'training time (minutes)=', '2.0043', 'loss=', '0.008075480')
('Epoch:', '0009', 'training time (minutes)=', '2.0019', 'loss=', '0.008052266')
('Epoch:', '0010', 'training time (minutes)=', '2.0101', 'loss=', '0.008046586')
('Epoch:', '0011', 'training time (minutes)=', '1.9992', 'loss=', '0.008033528')