In [None]:
import tensorflow as tf

import numpy as np
import scipy.sparse

import os
from os.path import join
import sys
sys.path.append('../')

from networks.spectral import spectral_ae

from training.model_config import AttrDict
from training.spectral import data_loader
from training.spectral.log import *
from training.spectral.loss import loss_function

from utils.utils import export_obj

# Raise TensorFlow's log threshold to ERROR. This hides the TF 1.x deprecation
# warnings (e.g. about APIs removed in TF 2.0) together with all other INFO/WARN logs.
tf.logging.set_verbosity(tf.logging.ERROR)

## Configure the model and paths, and load the mesh template

In [None]:
# Start from an empty default graph so that re-running the notebook from here does
# not stack a second copy of the model's ops and variables onto the previous graph.
tf.reset_default_graph()

# Define the model/training hyper-parameters. The architecture fields must match
# the run that produced the checkpoint restored further below.
def define_config():
    """Build the spectral-autoencoder configuration.

    Returns:
        AttrDict holding the architecture fields (latent size, filter banks,
        number of sampling levels, polynomial order per level, batch-norm and
        residual switches) and the optimisation fields (epochs, batch size,
        learning rate, regularisation weights, optimizer name).
    """
    filter_banks = [16, 32, 32, 48]
    n_levels = len(filter_banks)

    cfg = AttrDict()

    # Architecture: one sampling level per filter bank.
    cfg.latent_variable = 64
    cfg.filters = filter_banks
    cfg.sampling_steps = n_levels
    cfg.poly_order = [3] * n_levels

    # Optimisation.
    cfg.n_epochs = 1
    cfg.batch_size = 64
    cfg.lr = 0.001
    cfg.l2_reg = 0.00005
    cfg.z_l2_penalty = 0.0000005

    cfg.batch_norm = False
    cfg.residual = False

    # Names the directory of pre-computed graph operators, e.g. 'sampling_4'.
    cfg.type = 'sampling_{}'.format(n_levels)
    cfg.optimizer = 'AdamW'

    cfg.info = "data: synth"
    return cfg

config = define_config()
latent_variable, sampling_steps, filters, n_epochs = config.latent_variable, config.sampling_steps, config.filters, config.n_epochs
poly_order, batch_size, lr, l2_reg, batch_norm = config.poly_order, config.batch_size, config.lr, config.l2_reg, config.batch_norm
z_l2_penalty = config.z_l2_penalty

# Identifier of the trained run to evaluate; selects the model output directory below.
model_id = 'z64_d4_1550943246.5350132'

# Define file paths.
ROOT = '..'
DATA_DIR = 'data'
DATASET_PATH = join(ROOT, DATA_DIR, 'datasets/mesh-samples', 'data_splits_sampler.pkl')

TEMPLATE_DATA_PATH = join(ROOT, DATA_DIR, 'template')
# Directory of pre-computed spectral operators for this sampling configuration
# (config.type, e.g. 'sampling_4').
GRAPH_STRUCTURE = join(TEMPLATE_DATA_PATH, config.type)
TRILIST_PATH = join(TEMPLATE_DATA_PATH, 'trilist.npy')
OUTPUT_PATH = join(ROOT, DATA_DIR, 'models/spectral-ae', model_id)
CHECKPOINT_PATH = join(OUTPUT_PATH, 'models')

# Triangle list of the template mesh; shared by every exported .obj file.
trilist = np.load(TRILIST_PATH)

## Load the test data

In [None]:
# Held-out split plus the mean/std used to (de)normalise vertex positions.
test_db_np, mean_points, std_points = data_loader.load_test_data(DATASET_PATH)

# Single-pass, non-shuffled dataset over the test split; fp_test / lp_test are
# the inputs the iterator initializer is fed with.
test_db, fp_test, lp_test = data_loader.create_dataset(
    test_db_np,
    n_epochs=1,
    batch_size=batch_size,
    reshuffle=False,
)

# A feedable iterator: the graph reads from whichever dataset's string handle
# is fed into `handle` at run time.
handle, iterator = data_loader.create_feedable_iterator(test_db)
next_X, next_Y = iterator.get_next()
test_db_it = test_db.make_initializable_iterator()

# Graph operators consumed by the spectral network (passed to build_model as
# L, D, U, A); their exact meaning is defined in data_loader.
L, A, D, U, p = data_loader.load_spectral_operators(GRAPH_STRUCTURE)

## Load the model

In [None]:
# Build the spectral autoencoder. `is_train` is fed False everywhere in this
# notebook (presumably it disables train-only behaviour such as batch-norm updates).
is_train = tf.placeholder(tf.bool, name="is_train")
net, mesh_embedding = spectral_ae.build_model(
    next_X, L, D, U, A, filters, latent_variable, poly_order, lr, is_train,
    batch_norm=batch_norm,
)
loss = loss_function(net, next_Y, l2_reg, z_l2_penalty, mesh_embedding)

# The optimizer is never run here; it is presumably kept so the graph holds the
# same variables (global step, AdamW slots) as the training graph.
global_step = tf.Variable(0, name='global_step', trainable=False)
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    opt = tf.contrib.opt.AdamWOptimizer(learning_rate=lr, weight_decay=0.000001).minimize(loss, global_step=global_step)

init = tf.global_variables_initializer()
# Built once all variables exist so it covers every global variable of the graph.
saver = tf.train.Saver(max_to_keep=5, keep_checkpoint_every_n_hours=2)

with tf.Session() as sess:
    test_handle = sess.run(test_db_it.string_handle())

    sess.run(init, feed_dict={handle: test_handle})
    sess.run(test_db_it.initializer,
             feed_dict={fp_test: test_db_np, lp_test: test_db_np})

    # Overwrite the freshly initialised variables with the trained weights.
    saver.restore(sess, join(CHECKPOINT_PATH, 'mesh_ae.1550965497.0982995-599'))

    # Pull one test batch: network output, (normalised) input and label.
    predictions, gt_X, label = sess.run(
        [net, next_X, next_Y],
        feed_dict={handle: test_handle, is_train: False},
    )

## Reconstruct meshes

In [None]:
PREDICTIONS_PATH = join(OUTPUT_PATH, 'mesh-predictions')

# Undo the data loader's normalisation for both the network output and the
# ground-truth batch (presumably back to the original mesh coordinates).
reconstruction, gt = (batch * std_points + mean_points for batch in (predictions, gt_X))

In [None]:
n_samples = 8

# export_obj may not create missing directories, so make sure the target exists.
# (isdir + makedirs instead of exist_ok keeps this Python 2 compatible.)
if not os.path.isdir(PREDICTIONS_PATH):
    os.makedirs(PREDICTIONS_PATH)

# Write predicted and ground-truth meshes side by side. Never index past the
# batch that was actually decoded.
for index in range(min(n_samples, len(reconstruction))):
    export_obj(reconstruction[index], trilist,
               join(PREDICTIONS_PATH, '{}_pred.obj'.format(index)))

    export_obj(gt[index], trilist,
               join(PREDICTIONS_PATH, '{}_gt.obj'.format(index)))

## Encode meshes

In [None]:
# Encode the first (up to) 32 ground-truth meshes of the test batch into latent codes.
with tf.Session() as sess:
    sess.run(init)
    saver.restore(sess, join(CHECKPOINT_PATH, 'mesh_ae.1550965497.0982995-599'))

    # Feeding next_X directly overrides the dataset iterator's output.
    embedding = sess.run(
        mesh_embedding,
        feed_dict={is_train: False, next_X: gt_X[:32]},
    )

## Interpolate meshes

In [None]:
INTERPOLATE_PATH = join(OUTPUT_PATH, 'mesh-interpolate')
# export_obj may not create missing directories, so make sure the target exists.
if not os.path.isdir(INTERPOLATE_PATH):
    os.makedirs(INTERPOLATE_PATH)

# Linearly interpolate between two encoded test meshes in latent space.
n_iterations = 5
source = embedding[5]
target = embedding[10]

with tf.Session() as sess:
    sess.run(init)
    saver.restore(sess, join(CHECKPOINT_PATH, 'mesh_ae.1550965497.0982995-599'))

    # n_iterations + 1 steps so both end points (source and target) are exported.
    for iteration in range(n_iterations + 1):
        z_interpolate = source + (target - source) * iteration / n_iterations
        # Add a batch axis of size 1 before feeding the latent code to the decoder.
        z_interpolate = np.expand_dims(z_interpolate, axis=0)
        output_interpolate = net.eval(feed_dict={mesh_embedding: z_interpolate, is_train: False})

        mesh_interpolate = output_interpolate[0] * std_points + mean_points

        export_obj(mesh_interpolate, trilist,
                   join(INTERPOLATE_PATH, '{}.obj'.format(iteration)))

## Visualize the latent space

In [None]:
from mayavi import mlab

from traits.api import HasTraits, Range, Instance, on_trait_change, Button, Enum, Bool
from traitsui.api import View, Item, Group, VGroup, HGroup, HSplit, Tabbed
from mayavi.core.ui.api import MayaviScene, SceneEditor, MlabSceneModel
from traitsui.menu import RevertButton
from traits.api import HasTraits, Property, Array, Font
from traitsui.api import View, Item, TabularEditor
from traitsui.tabular_adapter import TabularAdapter

class Visualization(HasTraits):
    """Interactive TraitsUI/Mayavi viewer that decodes a latent vector into a mesh.

    One slider per latent dimension (``shape_0`` ... ``shape_{latent_variable-1}``)
    is added dynamically in ``__init__``. Moving any slider re-runs the decoder
    and updates the rendered surface in place.

    Relies on notebook globals: ``latent_variable``, ``net``, ``mesh_embedding``,
    ``is_train``, ``std_points``, ``mean_points`` and ``trilist``.
    """
    # Comma-separated names of the dynamic slider traits, used as the listener
    # spec for ``shape_sliders_action`` (the traits themselves are only created in
    # ``__init__`` via ``add_trait``).
    weird_hack =  ",".join(["shape_{}".format(ii) for ii in range(latent_variable)])

    scene = Instance(MlabSceneModel, ())
    btn_reset_shape = Button('Reset Shape')
    pose_deformations = Bool
    # True while the reset button zeroes the sliders, so individual slider
    # changes do not each trigger a forward pass and redraw.
    _resetting = Bool(False)

    shape_sliders = [Item('shape_' + str(ii)) for ii in range(latent_variable)]
    shape_group = VGroup(shape_sliders,
                         Item('btn_reset_shape', show_label=False),
                         label='Shape')

    view = View(HSplit(
        shape_group,
        Item('scene', editor=SceneEditor(scene_class=MayaviScene), height=550, width=750, show_label=False)),
                resizable=True)

    def __init__(self, sess, embedding):
        """Create one slider per latent dimension and draw the initial mesh.

        Args:
            sess: TensorFlow session holding the restored weights; used for every
                forward pass of ``net``.
            embedding: latent array indexed as ``embedding[0, ii]`` (presumably
                shape (1, latent_variable) -- TODO confirm). Stored as
                ``self.betas`` and overwritten in place by ``update_betas``.
        """
        HasTraits.__init__(self)
        self.betas = embedding
        self.sess = sess
        RANGE_INT = 4.
        for ii in range(latent_variable):
            # Each slider spans [-RANGE_INT, RANGE_INT], starting at the given code.
            R = Range(-RANGE_INT, RANGE_INT, embedding[0, ii])
            self.add_trait("shape_" + str(ii), R)

        self.v = self._decode(embedding)

        x, y, z, f = self.v[:, 0], self.v[:, 1], self.v[:, 2], trilist
        self.mesh = mlab.triangular_mesh(x, y, z, f, figure=self.scene.mayavi_scene, color=(0, 1, 1))

    def _decode(self, betas):
        """Decode ``betas`` with ``self.sess`` and return de-normalised vertices of sample 0.

        Uses the stored session explicitly instead of relying on TensorFlow's
        implicit default session, so the viewer keeps working even if the UI
        outlives the ``with tf.Session()`` block that created it.
        """
        output = self.sess.run(net, feed_dict={mesh_embedding: betas, is_train: False})
        return output[0] * std_points + mean_points

    def update_betas(self):
        """Copy the current slider values into row 0 of ``self.betas`` in place."""
        self.betas[0, :] = [getattr(self, "shape_{}".format(ii)) for ii in range(latent_variable)]

    def update_plot(self):
        """Decode ``self.betas`` and push the new vertex positions to the existing mesh."""
        self.v = self._decode(self.betas)
        x, y, z = self.v[:, 0], self.v[:, 1], self.v[:, 2]
        self.mesh.mlab_source.set(x=x, y=y, z=z)

    @on_trait_change(weird_hack)
    def shape_sliders_action(self):
        """Re-decode and redraw whenever any latent slider moves."""
        if self._resetting:
            # The reset handler redraws once after all sliders have been zeroed.
            return
        self.update_betas()
        self.update_plot()

    def _btn_reset_shape_fired(self):
        """Reset every latent slider to 0 and redraw the mesh once."""
        self.betas = np.zeros((1, latent_variable))
        self._resetting = True
        try:
            for ii in range(latent_variable):
                setattr(self, "shape_{}".format(ii), 0)
        finally:
            self._resetting = False
        self.update_betas()
        self.update_plot()


In [None]:
# The UI was designed for a smaller latent vector, so the slider panel may need scrolling.
with tf.Session() as sess:
    sess.run(init)
    saver.restore(sess, join(CHECKPOINT_PATH, 'mesh_ae.1550965497.0982995-599'))
    # Start from the latent code of the first encoded test mesh. The copy stops the
    # sliders (which write into row 0 in place) from mutating `embedding`.
    # Pass np.zeros((1, latent_variable)) instead to start from the latent origin.
    z = np.array(embedding[0], copy=True).reshape(1, -1)
    Visualization(sess, z).configure_traits()