In [1]:
import os
import sys
import math
import utils
import random
import layers
import dataset
import binvox_rw
import numpy as np
import pandas as pd
from PIL import Image
from datetime import datetime
import matplotlib.pyplot as plt
import tensorflow as tf
# tf.contrib.eager.enable_eager_execution() 

In [2]:
# placeholders: x is a batch of 24-view image sequences (137x137x3 per view),
# y is the 32x32x32 target voxel grid (shapes as declared below)
x=tf.placeholder(tf.float32,[None,24,137,137,3])
y=tf.placeholder(tf.float32,[None,32,32,32])

# encoder network: six conv -> 2x2 max-pool -> ReLU blocks, then a 1024-unit
# fully connected projection per view. Every intermediate is kept in
# encoder_outputs (index 0 is the raw input) for the debug cell.
cur_tensor=x
encoder_outputs=[cur_tensor]
print(cur_tensor.shape)
k_s = [3,3]
first_k_s = [7,7]
conv_filter_count = [96, 128, 256, 256, 256, 256]
for i, n_filters in enumerate(conv_filter_count):
    # The first block uses the larger 7x7 kernel. Previously this was computed
    # but the conv was always given k_s, so every block was 3x3.
    ks = first_k_s if i == 0 else k_s
    with tf.name_scope("encoding_block"):
        # tf.map_fn maps over the batch axis: each call receives one sample's
        # (24, H, W, C) stack of views, which the 2-D layers treat as their batch.
        # Filter count / kernel are bound as defaults so the lambda never depends
        # on the loop variables' later values.
        cur_tensor=tf.map_fn(lambda a, f=n_filters, k=ks: tf.layers.conv2d(a,filters=f,padding='SAME',kernel_size=k,activation=None),cur_tensor)
        cur_tensor=tf.map_fn(lambda a:tf.layers.max_pooling2d(a,2,2),cur_tensor)
        cur_tensor=tf.map_fn(tf.nn.relu,cur_tensor)
        print(cur_tensor.shape)
        encoder_outputs.append(cur_tensor)

# flatten each view's feature map and project it to 1024 features (no activation)
cur_tensor=tf.map_fn(tf.contrib.layers.flatten,cur_tensor)
cur_tensor=tf.map_fn(lambda a:tf.contrib.layers.fully_connected(a,1024,activation_fn=None),cur_tensor)
encoder_outputs.append(cur_tensor)
print(cur_tensor.shape)

(?, 24, 137, 137, 3)
(?, 24, 68, 68, 96)
(?, 24, 34, 34, 128)
(?, 24, 17, 17, 256)
(?, 24, 8, 8, 256)
(?, 24, 4, 4, 256)
(?, 24, 2, 2, 256)
(?, 24, 1024)


In [3]:
# recurrent module (GRU cell from the project-local `layers` module)
recurrent_module = layers.GRU_R2N2()

# prepare input: replicate the per-view encoder features onto a 4x4x4 grid by
# stacking four copies along a new leading axis, three times over
# -> (4, 4, 4, batch, 24, 1024)
cur_tensor = encoder_outputs[-1]
stacked_input = cur_tensor
for i in range(3):
    copies = [stacked_input] * 4
    stacked_input = tf.stack(copies, axis=0)
print(stacked_input.shape)

# initial hidden state: zeros shaped like the first 256 channels of a single
# time step -> (4, 4, 4, batch, 256)
hidden_state = tf.zeros_like(stacked_input[..., 0, 0:256])

# feed the sequence one time step (view) at a time, threading the hidden state
n_time_steps = stacked_input.shape.as_list()[4]
for t in range(n_time_steps):
    input_frames = stacked_input[..., t, :]
    hidden_state = recurrent_module.call(input_frames, hidden_state)
print(hidden_state.shape)

(4, 4, 4, ?, 24, 1024)
(4, 4, 4, ?, 256)


In [4]:
# decoding network: move the batch axis first -> (batch, 4, 4, 4, hidden)
cur_tensor=tf.transpose(hidden_state,[3,0,1,2,4])
print(cur_tensor.shape)

decoder_outputs=[cur_tensor]
cur_tensor=layers.unpool3D(cur_tensor)
print(cur_tensor.shape)
decoder_outputs.append(cur_tensor)

k_s = [3,3,3]
# NOTE(review): entries 0 and 1 are never used -- the first stage above is an
# unpool without a conv. Confirm this is the intended architecture.
deconv_filter_count = [128, 128, 128, 64, 32, 2]
for i in range(2,4):
    with tf.name_scope("decoding_block"):
        cur_tensor=tf.layers.conv3d(cur_tensor,padding='SAME',filters=deconv_filter_count[i],kernel_size= k_s,activation=None)
        cur_tensor=layers.unpool3D(cur_tensor)
        cur_tensor=tf.nn.relu(cur_tensor)
        print(cur_tensor.shape)
        decoder_outputs.append(cur_tensor)

last_layer = len(deconv_filter_count) - 1
for i in range(4, len(deconv_filter_count)):
    with tf.name_scope("decoding_block_without_unpooling"):
        cur_tensor=tf.layers.conv3d(cur_tensor,padding='SAME',filters=deconv_filter_count[i],kernel_size= k_s,activation=None)
        # The last conv produces the 2-channel logits fed to the softmax. A ReLU
        # there clamps both logits at 0 and kills their gradients, so it is only
        # applied to the hidden layers.
        if i < last_layer:
            cur_tensor=tf.nn.relu(cur_tensor)
        print(cur_tensor.shape)
        decoder_outputs.append(cur_tensor)

(?, 4, 4, 4, 256)
(?, 8, 8, 8, 256)
(?, 16, 16, 16, 128)
(?, 32, 32, 32, 64)
(?, 32, 32, 32, 32)
(?, 32, 32, 32, 2)


In [5]:
# 3D voxel-wise softmax over the 2 decoder channels. The loss pairs channel 0
# with y and channel 1 with 1 - y.
# NOTE(review): a step size of 1 is very large for a loss summed over 32**3 voxels.
LEARNING_RATE = 1.0
logits = decoder_outputs[-1]
y_hat = tf.nn.softmax(logits)
p = y_hat[:,:,:,:,0]
q = y_hat[:,:,:,:,1]

# Take log-probabilities straight from the logits. tf.log(softmax) returns -inf
# (and NaN gradients) as soon as a probability underflows to 0.
log_y_hat = tf.nn.log_softmax(logits)
log_p = log_y_hat[:,:,:,:,0]
log_q = log_y_hat[:,:,:,:,1]
# per-sample cross-entropy summed over all voxels, then averaged over the batch
cross_entropies = tf.reduce_sum(-y * log_p - (1 - y) * log_q, [1,2,3])
loss = tf.reduce_mean(cross_entropies)
optimizer = tf.train.GradientDescentOptimizer(LEARNING_RATE).minimize(loss)

In [19]:
# debug: dump every encoder activation for one small batch to
# out_dir/<sample>/<view>/<layer>.png
sess=tf.InteractiveSession()
tf.global_variables_initializer().run()

# tensorboard: write the graph and flush so it is on disk immediately
writer = tf.summary.FileWriter("./logs/")
writer.add_graph(sess.graph)
writer.flush()

# load dataset
shapenet=dataset.ShapeNet()
shapenet.batch_size=2
data,label=shapenet.next_train_batch()
feed_dict={x:data, y: label}

# encoder debug: evaluate all encoder outputs in ONE run rather than re-running
# the whole graph for every (sample, view, layer) combination
out_root_dir="out_dir"
encoder_values = sess.run(encoder_outputs, feed_dict=feed_dict)
# encoder_outputs[0..6] are image-like maps; encoder_outputs[7] is the FC feature
n_layers = 7
batch_size = len(data)                    # samples actually fed in this batch
time_steps = encoder_values[0].shape[1]   # views per sample (24 for x)
for b in range(batch_size):
    for t in range(time_steps):
        out_dir=os.path.join(out_root_dir,str(b),str(t))
        if not os.path.isdir(out_dir):
            os.makedirs(out_dir)

        for i in range(n_layers):
            utils.imsave_multichannel(encoder_values[i][b,t],os.path.join(out_dir,"{}.png".format(i)))

        # The FC output is a 1-D vector per view. Draw it on a fresh figure and
        # close it, otherwise every view's curve piles onto the same axes.
        fig, ax = plt.subplots()
        ax.plot(encoder_values[n_layers][b,t])
        fig.savefig(os.path.join(out_dir,"{}.png".format(n_layers)))
        plt.close(fig)

out_dir/0/0


# debug decoder net: voxelize one prediction and write it as a .binvox file.
# The loss pairs softmax channel 0 with y (channel 1 with 1 - y), so a voxel is
# predicted as y == 1 where channel 0 wins. argmax would give 1 for the
# opposite case and invert the grid.
DEBUG_SAMPLE = 0
predicted = tf.greater(y_hat[:,:,:,:,0], y_hat[:,:,:,:,1])
out = sess.run(predicted, feed_dict=feed_dict)   # boolean, (batch, 32, 32, 32)
sample_voxels = out[DEBUG_SAMPLE]
if not os.path.isdir("out"):
    os.makedirs("out")
# Voxels receives a single 3-D grid, matching the 3-component translate
# argument, rather than the whole batch.
outvoxel=binvox_rw.Voxels(sample_voxels,sample_voxels.shape,[0,0,0],1,'xzy')
# NOTE(review): binvox is a binary format; if binvox_rw writes bytes (Python 3
# ports do), the mode must be 'wb' -- confirm against binvox_rw.
with open("out/cur_output.binvox",'w') as f:
    outvoxel.write(f)

# setup training: release any earlier (debug) session, then open a fresh one
prev_sess = globals().get('sess')
if prev_sess is not None:
    prev_sess.close()  # otherwise the InteractiveSession and its device memory leak
sess=tf.Session()
root_train_dir = "train_dir"
start_time = datetime.now()
# human-readable timestamp, used for log messages
cur_time = start_time.strftime('%I:%M:%S%p %Y/%m/%d')
# The display format contains '/' (silently creates nested directories) and ':'
# (invalid on Windows), so name the run directory with a sortable, path-safe stamp.
train_dir=os.path.join(root_train_dir,'session_{}'.format(start_time.strftime('%Y-%m-%d_%H-%M-%S')))
saver = tf.train.Saver()
init=tf.global_variables_initializer()
sess.run(init)

# train network: `epoch` passes over the training set. Each epoch gets a
# checkpoint and a loss plot in its own directory.
print("starting training at {}".format(cur_time))
loss_session=[]   # one list of batch losses per epoch
loss_all=[]       # flat list of every batch loss so far
epoch=5
for e in range(epoch):
    loss_epoch=[]
    epoch_dir="{}/epoch_{:03d}".format(train_dir,e)
    os.makedirs(epoch_dir)
    batch_number=0
    shapenet.reset()
    # the loop condition implies next_train_batch() yields data None when the
    # epoch is exhausted -- TODO confirm against dataset.ShapeNet
    data,label=shapenet.next_train_batch()
    while data is not None:
        feed_dict={x:data, y: label}
        # Run the train op. Previously only the loss (plus a host copy of every
        # weight) was fetched, so the parameters were never updated.
        batch_info=sess.run([optimizer,loss],feed_dict=feed_dict)
        loss_batch=batch_info[-1]
        loss_epoch.append(loss_batch)
        # Advance to the next batch. Assigning the result to an unused name kept
        # `data` unchanged, so the loop never terminated.
        data,label=shapenet.next_train_batch()
        batch_number+=1
        # show info about current batch
        if batch_number%100==0:
            print("epoch_{:03d}-batch_{:03d}: loss={}".format(e,batch_number,loss_batch))

    loss_session.append(loss_epoch)
    loss_all+=loss_epoch
    # record parameters and plot the full per-batch loss curve so far
    fig, ax = plt.subplots()
    ax.plot(loss_all)
    ax.set(title="training loss", xlabel="batch", ylabel="loss")
    fig.savefig("{}/loss.png".format(epoch_dir),bbox_inches='tight')
    saver.save(sess,"{}/model.ckpt".format(epoch_dir))
    plt.close(fig)
    # save epoch losses