In [1]:
'''
Stanford CS233, HW4.

Note: Code tested with Python 2.7, adapted to work with Python 3.0 too.

Written by Panos Achlioptas, 2018.
'''

import os.path as osp
import sys

import matplotlib.pylab as plt
import numpy as np
import six
import tensorflow as tf

from hw4_code.numpy_dataset import NumpyDataset
from hw4_code.cs233_point_auto_encoder import cs233PointAutoEncoder
from hw4_code.neural_net import Neural_Net_Conf
from hw4_code.encoders_decoders import encoder_with_convs_and_symmetry, decoder_with_fc_only
from hw4_code.in_out_utils import unpickle_data, pickle_data, create_dir


# Unpickling switch: the pickled inputs need Python-2 -> 3 conversion exactly when
# this notebook runs under Python 3 (e.g. on Azure). Derive it from the running
# interpreter instead of relying on students to flip it by hand.
python_2to3 = sys.version_info[0] >= 3


# Students: Default options for splits. Do NOT change.
split_loads = [0.75, 0.15, 0.10]  # presumably the train / val / test fractions
random_seed = 42                  # shared seed for subsampling, TF and NumPy
verbose = True
n_total_shapes = 1000             # number of shapes kept when subsampling the data

# Training options.
do_training = True
batch_size = 50        # Students: do NOT change this one.
held_out_step = 10     # interval (in epochs) at which val/test losses are reported
n_epochs = 400

# Load the part-labelled point clouds and the split-name -> indices table.
# unpickle_data returns an iterator; six.next takes the first object it yields.
hw4_data = six.next(unpickle_data('data/in/part_labeled_point_clouds.pkl', python_2to3))
sids = six.next(unpickle_data('data/in/randomized_ids.pkl', python_2to3))

# Store the per-point part labels as 32-bit integers.
hw4_data.part_masks = hw4_data.part_masks.astype(np.int32)

# Keep a seeded random subset of shapes, then carve out one dataset per split
# name found in `sids` (this file reads the 'train' and 'test' entries).
hw4_data = hw4_data.subsample(n_total_shapes, replace=False, seed=random_seed)
net_data = {}
for s, idx in sids.items():
    net_data[s] = hw4_data.extract(idx)

# Test split passed through NumpyDataset.freeze() (presumably a frozen copy).
test_data = net_data['test'].freeze()


# ---- Point auto-encoder (Point-AE) configuration --------------------------
pc_ae_conf = Neural_Net_Conf()

# Points per cloud: the second-to-last axis of the training point clouds.
n_pc_per_model = net_data['train'].pcs.shape[-2]
bneck = 128  # presumably the latent/bottleneck width (equals the last encoder entry below)

# Encoder / decoder builders from hw4_code.encoders_decoders.
pc_ae_conf.encoder, pc_ae_conf.decoder = encoder_with_convs_and_symmetry, decoder_with_fc_only
pc_ae_conf.n_points = n_pc_per_model

# Students add your encoder's options.
# NOTE(review): the meaning of these keys is defined by the student's encoder;
# confirm that n_filters is the layer count and filter_sizes the per-layer widths.
pc_ae_conf.encoder_args = dict(n_filters=5,
                               filter_sizes=[32, 64, 64, 128, 128],
                               verbose=False)

# Decoder layer widths; the last layer has 3 outputs per point (presumably x, y, z).
pc_ae_conf.decoder_args = dict(layer_sizes=[256, 256, 3 * n_pc_per_model],
                               verbose=False)

# Training hyper-parameters and session/saver options.
pc_ae_conf.learning_rate = 0.0009
pc_ae_conf.saver_max_to_keep = 1
pc_ae_conf.allow_gpu_growth = True


# Part segmentation. Students: set use_parts to True to also predict parts.
pc_ae_conf.use_parts = False
pc_ae_conf.n_parts = 4
pc_ae_conf.part_pred_with_one_layer = True


# Relative importance of part prediction vs. point-cloud reconstruction.
# Students: leave this value unchanged for the (non-bonus) questions.
pc_ae_conf.part_weight = 0.005

# The model name and number of loss terms follow from the part options above.
if not pc_ae_conf.use_parts:
    pc_ae_conf.name = 'vanilla_ae'
    n_losses = 1
else:
    pc_ae_conf.name = ('pc_aware_ae' if pc_ae_conf.part_pred_with_one_layer
                       else 'pc_aware_ae_bonus')
    n_losses = 2


# Seed TensorFlow's graph-level RNG *before* any ops exist. tf.set_random_seed only
# affects ops created after the call, so it must precede graph construction for the
# weight initialisation to be reproducible.
# NOTE(review): assumes cs233PointAutoEncoder builds its ops in the default graph.
tf.set_random_seed(random_seed)

# Establish tensor-flow graph.
ae = cs233PointAutoEncoder(pc_ae_conf.name, pc_ae_conf)


if do_training:
    # NOTE(review): inputs are read from 'data/in/...' relative to the working
    # directory, but outputs go to '../data/out/...'. Confirm this is intended.
    save_dir = create_dir(osp.join('../data/out/Neural_nets', pc_ae_conf.name))
    # Seed NumPy immediately before training, so that any np.random use inside
    # train_model (e.g. batch shuffling) is reproducible.
    np.random.seed(random_seed)
    with open(osp.join(save_dir, 'net_stats.txt'), 'w') as file_out:
        # Per-epoch statistics are also written to net_stats.txt via `fout`.
        train_loss, val_loss, test_loss = ae.train_model(net_data, n_epochs, batch_size, save_dir,
                                                         held_out_step, fout=file_out)

('Training epoch/loss/duration: ', 1, array([ 0.03178302]), 5.2698869705200195)
('Training epoch/loss/duration: ', 2, array([ 0.00595877]), 2.6379458904266357)
('Training epoch/loss/duration: ', 3, array([ 0.00417842]), 2.6383249759674072)
('Training epoch/loss/duration: ', 4, array([ 0.00386667]), 2.629732847213745)
('Training epoch/loss/duration: ', 5, array([ 0.00375417]), 2.632874011993408)
('Training epoch/loss/duration: ', 6, array([ 0.00370139]), 2.624575138092041)
('Training epoch/loss/duration: ', 7, array([ 0.00362947]), 2.621664047241211)
('Training epoch/loss/duration: ', 8, array([ 0.00358747]), 2.640124797821045)
('Training epoch/loss/duration: ', 9, array([ 0.00354283]), 2.64296293258667)
('Training epoch/loss/duration: ', 10, array([ 0.00350767]), 2.6305699348449707)
('Val/Test epoch/loss:', 10, array([ 0.00360934]), array([ 0.00385782]))
('Training epoch/loss/duration: ', 11, array([ 0.00348691]), 2.631300926208496)
('Training epoch/loss/duration: ', 12, array([ 0.0034

('Training epoch/loss/duration: ', 95, array([ 0.00122414]), 2.677034854888916)
('Training epoch/loss/duration: ', 96, array([ 0.00121391]), 2.687134027481079)
('Training epoch/loss/duration: ', 97, array([ 0.00120778]), 2.6720571517944336)
('Training epoch/loss/duration: ', 98, array([ 0.00120461]), 2.6776950359344482)
('Training epoch/loss/duration: ', 99, array([ 0.00124788]), 2.6928088665008545)
('Training epoch/loss/duration: ', 100, array([ 0.00122546]), 2.681439161300659)
('Val/Test epoch/loss:', 100, array([ 0.00138094]), array([ 0.00154701]))
('Training epoch/loss/duration: ', 101, array([ 0.00119061]), 2.6799559593200684)
('Training epoch/loss/duration: ', 102, array([ 0.0011785]), 2.6708719730377197)
('Training epoch/loss/duration: ', 103, array([ 0.00119411]), 2.675503969192505)
('Training epoch/loss/duration: ', 104, array([ 0.00118489]), 2.6777279376983643)
('Training epoch/loss/duration: ', 105, array([ 0.00117423]), 2.6735920906066895)
('Training epoch/loss/duration: ',

('Training epoch/loss/duration: ', 188, array([ 0.0009382]), 2.679776906967163)
('Training epoch/loss/duration: ', 189, array([ 0.0009432]), 2.685349941253662)
('Training epoch/loss/duration: ', 190, array([ 0.00093097]), 2.685971975326538)
('Val/Test epoch/loss:', 190, array([ 0.00124719]), array([ 0.00140964]))
('Training epoch/loss/duration: ', 191, array([ 0.00096467]), 2.673844814300537)
('Training epoch/loss/duration: ', 192, array([ 0.00094737]), 2.687716007232666)
('Training epoch/loss/duration: ', 193, array([ 0.00095161]), 2.6754558086395264)
('Training epoch/loss/duration: ', 194, array([ 0.00093548]), 2.6843888759613037)
('Training epoch/loss/duration: ', 195, array([ 0.00092504]), 2.6837940216064453)
('Training epoch/loss/duration: ', 196, array([ 0.00092858]), 2.6907198429107666)
('Training epoch/loss/duration: ', 197, array([ 0.00092654]), 2.685966968536377)
('Training epoch/loss/duration: ', 198, array([ 0.00092346]), 2.68147611618042)
('Training epoch/loss/duration: ',

('Val/Test epoch/loss:', 280, array([ 0.00117187]), array([ 0.00131139]))
('Training epoch/loss/duration: ', 281, array([ 0.0008056]), 2.6795310974121094)
('Training epoch/loss/duration: ', 282, array([ 0.00079888]), 2.6949198246002197)
('Training epoch/loss/duration: ', 283, array([ 0.0007971]), 2.6892471313476562)
('Training epoch/loss/duration: ', 284, array([ 0.00081624]), 2.6853699684143066)
('Training epoch/loss/duration: ', 285, array([ 0.00082176]), 2.683816909790039)
('Training epoch/loss/duration: ', 286, array([ 0.00081429]), 2.6868410110473633)
('Training epoch/loss/duration: ', 287, array([ 0.00080725]), 2.6802918910980225)
('Training epoch/loss/duration: ', 288, array([ 0.00080003]), 2.676831007003784)
('Training epoch/loss/duration: ', 289, array([ 0.00079906]), 2.6878631114959717)
('Training epoch/loss/duration: ', 290, array([ 0.00079429]), 2.6910319328308105)
('Val/Test epoch/loss:', 290, array([ 0.00116881]), array([ 0.00131248]))
('Training epoch/loss/duration: ', 2

('Training epoch/loss/duration: ', 373, array([ 0.00073345]), 2.6685729026794434)
('Training epoch/loss/duration: ', 374, array([ 0.00074872]), 2.682281017303467)
('Training epoch/loss/duration: ', 375, array([ 0.00078547]), 2.6599531173706055)
('Training epoch/loss/duration: ', 376, array([ 0.00077605]), 2.663541078567505)
('Training epoch/loss/duration: ', 377, array([ 0.00074447]), 2.6764800548553467)
('Training epoch/loss/duration: ', 378, array([ 0.00073466]), 2.664583921432495)
('Training epoch/loss/duration: ', 379, array([ 0.00071939]), 2.678179979324341)
('Training epoch/loss/duration: ', 380, array([ 0.00071588]), 2.672420024871826)
('Val/Test epoch/loss:', 380, array([ 0.00115106]), array([ 0.00128983]))
('Training epoch/loss/duration: ', 381, array([ 0.00071894]), 2.669562816619873)
('Training epoch/loss/duration: ', 382, array([ 0.00072132]), 2.6917450428009033)
('Training epoch/loss/duration: ', 383, array([ 0.00073354]), 2.672029972076416)
('Training epoch/loss/duration: