### In this notebook, we generate the results of the animal experiment of the article with precomputed persistence diagrams obtained from the authors code.

### Do not forget to move the .so files to the dist-packages repo of your python version! 

In [1]:
import numpy              as np
np.set_printoptions(threshold='nan')
import math       
import sys
from   random             import shuffle   
import matplotlib.pyplot  as plt
import h5py

import tensorflow as tf
import _persistence_vector_grad
persistence_vector_module = tf.load_op_library('persistence_vector.so')

  from ._conv import register_converters as _register_converters


## Data reading
This code assumes that precomputed persistence diagrams are available in a h5 file called npht_animal_32dirs.h5 in a repo called animal. This file can be obtained by running code provided by the authors, or by computing the diagrams manually with Gudhi using the notebook "Persistence Diagram Computations for Binary Images". 

In [2]:
f = h5py.File("animal/npht_animal_32dirs.h5")
animals = ('bird' ,'butterfly','cat','cow','crocodile','deer','dog','dolphine','duck','elephant','fish',
           'flyingbird','hen','horse','leopard','monkey','rabbit','rat','spider','tortoise')

num_labels = 20
num_dir    = 32

labels = []
curr_lab = 0
for animal in animals:
    for _ in range(100):
        vector_lab = np.zeros(num_labels)
        vector_lab[curr_lab] = 1
        labels.append(vector_lab)
    curr_lab += 1

D = []
C = []
for k in range(num_dir):
    diag_dir = []
    card_dir = []    
    for animal in animals:
        for name in f['data_views']['dim_0_dir_' + str(k)][animal].keys():
            diag = f['data_views']['dim_0_dir_' + str(k)][animal][name]
            diag_dir.append(diag)
            card_dir.append(diag.shape[0])
    D.append(diag_dir)
    C.append(card_dir)
    
num_pts = len(labels)

## Definition of the network as described in the article

In [6]:
# Network parameters as described in the article
num_gaussians = 75
nu_tensor = tf.Variable([[0.01]], trainable=False)
keep_prob = 0.9
    
# Random initialization of Gaussians
gaussians = []
for j in range(num_dir):
    mu     = tf.Variable(tf.random_uniform([num_gaussians,2]))
    sigma  = tf.Variable(np.float32(3*np.ones([num_gaussians,2])))
    gaussians.append(tf.concat([mu,sigma],1))
    
# Random initialization for layer weights
def weight_variable(shape):
    initial = tf.truncated_normal(shape, stddev=0.1)
    return tf.Variable(initial)

def bias_variable(shape):
    initial = tf.constant(0.1, shape=shape)
    return tf.Variable(initial)

# Placeholders
diag      = [tf.placeholder(tf.float32, shape=(None, 2)) for _ in xrange(num_dir)]
card      = [tf.placeholder(tf.float32, shape=(None, 1)) for _ in xrange(num_dir)]
label     = tf.placeholder(tf.float32, shape=[None, num_labels])

layer10_list = []

# For each direction
for k in range(num_dir):
    
    # Use current direction
    v_middle = persistence_vector_module.persistence_vector(diag[k], 
                                                            card[k], 
                                                            gaussians[k], nu_tensor)
    v_middle = tf.reshape(v_middle,[-1,num_gaussians,1])
    
    # Use previous direction
    v_left   = persistence_vector_module.persistence_vector(diag[(k-1) % num_dir], 
                                                            card[(k-1) % num_dir], 
                                                            gaussians[k], nu_tensor)
    v_left   = tf.reshape(v_left,[-1,num_gaussians,1])
    
    # Use next direction
    v_right  = persistence_vector_module.persistence_vector(diag[(k+1) % num_dir], 
                                                            card[(k+1) % num_dir],  
                                                            gaussians[k], nu_tensor)
    v_right  = tf.reshape(v_right,[-1,num_gaussians,1])
        
    # Conv1D -> Conv1D -> Max Pool -> Linear -> Batch Norm -> Linear -> ReLu -> Dropout
    layer1 = tf.concat([v_middle,v_left,v_right], 2)
    W2 = weight_variable([1,3,16])
    b2 = bias_variable([16])
    layer2 = tf.nn.conv1d(layer1, W2, stride=1, padding='SAME') + b2
    W3 = weight_variable([1,16,4])  
    b3 = bias_variable([4])
    layer3 = tf.nn.conv1d(layer2, W3, stride=1, padding='SAME') + b3
    layer4 = tf.reduce_max(input_tensor=layer3, axis=[2])
    W5 = weight_variable([num_gaussians,25])  
    b5 = bias_variable([25])
    layer5 = tf.matmul(layer4,W5) + b5
    moments5 = tf.nn.moments(layer5, axes=[0])
    layer6   = tf.nn.batch_normalization(layer5, moments5[0], moments5[1], None, None, 1e-10)
    W7 = weight_variable([25,25])  
    b7 = bias_variable([25])
    layer7 = tf.matmul(layer6,W7) + b7
    layer8 = tf.nn.relu(layer7)
    layer9 = tf.nn.dropout(layer8, keep_prob)
    layer10_list.append(layer9)

# Concatenate all branches
layer10 = tf.concat(layer10_list, 1)

# Linear -> Batch Norm -> Dropout -> Linear
W11 = weight_variable([num_dir*25, 100])
b11 = bias_variable([100])
layer11 = tf.matmul(layer10,W11) + b11
moments11 = tf.nn.moments(layer11, axes=[0])
layer12   = tf.nn.batch_normalization(layer11, moments11[0], moments11[1], None, None, 1e-10)
layer13 = tf.nn.dropout(layer12, keep_prob)
W14 = weight_variable([100,num_labels])
b14 = bias_variable([num_labels])
layer14 = tf.matmul(layer13, W14) + b14

## Shuffle Data

In [4]:
# Shuffle data so as to be homogeneous
p = [i for i in range(num_pts)]
shuffle(p)
labels = [labels[i] for i in p]
for k in range(num_dir):
    D[k] = [D[k][i] for i in p]
    C[k] = [C[k][i] for i in p]

## Do the optimization and evaluate on dataset

In [8]:
# Use 90% of data as training
training_ratio = 0.9
training_size  = np.round(training_ratio*num_pts).astype(np.int)

# Split data into batches of size 128 and optimize over 30 epochs 
batch_size     = 128
nb_epoch       = 30

# Compute number of batches
if training_size % batch_size == 0:
    num_batches = training_size/batch_size
else:
    num_batches = training_size/batch_size + 1 
    
# Define accuracy
correct_prediction    = tf.equal(tf.argmax(layer14, 1), tf.argmax(label, 1))
accuracy              = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

# Optimize cross entropy loss with Gradient Descent and decaying learning rate
cross_entropy         = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=label, logits=layer14))
global_step           = tf.Variable(0, trainable=False)
learning_rate         = tf.train.exponential_decay(0.1, global_step, 25*num_batches, 0.5, staircase=True)
learning              = tf.train.GradientDescentOptimizer(learning_rate)
learning_step         = learning.minimize(cross_entropy, global_step=global_step)

# Define training batches
chunked_data = []

for i in range(0,training_size,batch_size):
    
    if (training_size < i + batch_size):
        num_pts_in_batch = training_size - i
    else:
        num_pts_in_batch = batch_size
    
    batch_diag     = [[] for _ in range(num_dir)]
    batch_card     = [[] for _ in range(num_dir)]
    batch_labels   = []
    for j in range(num_pts_in_batch):
        batch_labels.append(labels[i+j])
        for k in range(num_dir):
            batch_diag[k].append(D[k][i+j])
            batch_card[k].append(C[k][i+j])
            
    # Reshape data for dimensionality agreement with network 
    for k in range(num_dir):
        batch_diag[k] = np.concatenate(batch_diag[k],0)
        batch_card[k] = np.reshape(np.float32(batch_card[k]),[-1,1])
    chunked_data.append((batch_labels, batch_diag, batch_card))

# Define test set
test_diag     = [[] for _ in range(num_dir)]
test_card     = [[] for _ in range(num_dir)]
test_labels   = []
for i in range(training_size,num_pts):
    test_labels.append(labels[i])
    for k in range(num_dir):
        test_diag[k].append(D[k][i])
        test_card[k].append(C[k][i])
        
# Reshape data for dimensionality agreement with network 
for k in range(num_dir):
    test_diag[k] = np.concatenate(test_diag[k],0)
    test_card[k] = np.reshape(np.float32(test_card[k]),[-1,1])
    
with tf.Session() as sess:

    # Initialize
    sess.run(tf.global_variables_initializer())

    # For each epoch
    for ep in range(nb_epoch):
                
        # Compute training accuracy   
        acc = 0
        for b in range(0,num_batches):
            acc += accuracy.eval(feed_dict={label:  chunked_data[b][0],
                                            diag[0]:chunked_data[b][1][0], 
                                            diag[1]:chunked_data[b][1][1], 
                                            diag[2]:chunked_data[b][1][2], 
                                            diag[3]:chunked_data[b][1][3], 
                                            diag[4]:chunked_data[b][1][4], 
                                            diag[5]:chunked_data[b][1][5], 
                                            diag[6]:chunked_data[b][1][6], 
                                            diag[7]:chunked_data[b][1][7], 
                                            diag[8]:chunked_data[b][1][8], 
                                            diag[9]:chunked_data[b][1][9], 
                                            diag[10]:chunked_data[b][1][10], 
                                            diag[11]:chunked_data[b][1][11], 
                                            diag[12]:chunked_data[b][1][12], 
                                            diag[13]:chunked_data[b][1][13], 
                                            diag[14]:chunked_data[b][1][14], 
                                            diag[15]:chunked_data[b][1][15],
                                            diag[16]:chunked_data[b][1][16], 
                                            diag[17]:chunked_data[b][1][17], 
                                            diag[18]:chunked_data[b][1][18], 
                                            diag[19]:chunked_data[b][1][19], 
                                            diag[20]:chunked_data[b][1][20], 
                                            diag[21]:chunked_data[b][1][21], 
                                            diag[22]:chunked_data[b][1][22], 
                                            diag[23]:chunked_data[b][1][23],
                                            diag[24]:chunked_data[b][1][24], 
                                            diag[25]:chunked_data[b][1][25], 
                                            diag[26]:chunked_data[b][1][26], 
                                            diag[27]:chunked_data[b][1][27], 
                                            diag[28]:chunked_data[b][1][28], 
                                            diag[29]:chunked_data[b][1][29], 
                                            diag[30]:chunked_data[b][1][30], 
                                            diag[31]:chunked_data[b][1][31],
                                            card[0]:chunked_data[b][2][0], 
                                            card[1]:chunked_data[b][2][1], 
                                            card[2]:chunked_data[b][2][2], 
                                            card[3]:chunked_data[b][2][3], 
                                            card[4]:chunked_data[b][2][4], 
                                            card[5]:chunked_data[b][2][5], 
                                            card[6]:chunked_data[b][2][6], 
                                            card[7]:chunked_data[b][2][7], 
                                            card[8]:chunked_data[b][2][8], 
                                            card[9]:chunked_data[b][2][9], 
                                            card[10]:chunked_data[b][2][10], 
                                            card[11]:chunked_data[b][2][11], 
                                            card[12]:chunked_data[b][2][12], 
                                            card[13]:chunked_data[b][2][13], 
                                            card[14]:chunked_data[b][2][14], 
                                            card[15]:chunked_data[b][2][15],
                                            card[16]:chunked_data[b][2][16], 
                                            card[17]:chunked_data[b][2][17], 
                                            card[18]:chunked_data[b][2][18], 
                                            card[19]:chunked_data[b][2][19], 
                                            card[20]:chunked_data[b][2][20], 
                                            card[21]:chunked_data[b][2][21], 
                                            card[22]:chunked_data[b][2][22], 
                                            card[23]:chunked_data[b][2][23],
                                            card[24]:chunked_data[b][2][24], 
                                            card[25]:chunked_data[b][2][25], 
                                            card[26]:chunked_data[b][2][26], 
                                            card[27]:chunked_data[b][2][27], 
                                            card[28]:chunked_data[b][2][28], 
                                            card[29]:chunked_data[b][2][29], 
                                            card[30]:chunked_data[b][2][30], 
                                            card[31]:chunked_data[b][2][31]})
        acc /= num_batches
        
        # Compute test accuracy
        acc_test = accuracy.eval(feed_dict={    label:  test_labels,
                                                diag[0]:test_diag[0], 
                                                diag[1]:test_diag[1], 
                                                diag[2]:test_diag[2], 
                                                diag[3]:test_diag[3], 
                                                diag[4]:test_diag[4], 
                                                diag[5]:test_diag[5], 
                                                diag[6]:test_diag[6], 
                                                diag[7]:test_diag[7], 
                                                diag[8]:test_diag[8], 
                                                diag[9]:test_diag[9], 
                                                diag[10]:test_diag[10], 
                                                diag[11]:test_diag[11], 
                                                diag[12]:test_diag[12], 
                                                diag[13]:test_diag[13], 
                                                diag[14]:test_diag[14], 
                                                diag[15]:test_diag[15],
                                                diag[16]:test_diag[16], 
                                                diag[17]:test_diag[17], 
                                                diag[18]:test_diag[18], 
                                                diag[19]:test_diag[19], 
                                                diag[20]:test_diag[20], 
                                                diag[21]:test_diag[21], 
                                                diag[22]:test_diag[22], 
                                                diag[23]:test_diag[23],
                                                diag[24]:test_diag[24], 
                                                diag[25]:test_diag[25], 
                                                diag[26]:test_diag[26], 
                                                diag[27]:test_diag[27], 
                                                diag[28]:test_diag[28], 
                                                diag[29]:test_diag[29], 
                                                diag[30]:test_diag[30], 
                                                diag[31]:test_diag[31],
                                                card[0]:test_card[0], 
                                                card[1]:test_card[1], 
                                                card[2]:test_card[2], 
                                                card[3]:test_card[3], 
                                                card[4]:test_card[4], 
                                                card[5]:test_card[5], 
                                                card[6]:test_card[6], 
                                                card[7]:test_card[7], 
                                                card[8]:test_card[8], 
                                                card[9]:test_card[9], 
                                                card[10]:test_card[10], 
                                                card[11]:test_card[11], 
                                                card[12]:test_card[12], 
                                                card[13]:test_card[13], 
                                                card[14]:test_card[14], 
                                                card[15]:test_card[15],
                                                card[16]:test_card[16], 
                                                card[17]:test_card[17], 
                                                card[18]:test_card[18], 
                                                card[19]:test_card[19], 
                                                card[20]:test_card[20], 
                                                card[21]:test_card[21], 
                                                card[22]:test_card[22], 
                                                card[23]:test_card[23],
                                                card[24]:test_card[24], 
                                                card[25]:test_card[25], 
                                                card[26]:test_card[26], 
                                                card[27]:test_card[27], 
                                                card[28]:test_card[28], 
                                                card[29]:test_card[29], 
                                                card[30]:test_card[30], 
                                                card[31]:test_card[31]})
        sys.stdout.write("Epoch %d, Train Accuracy %g %s, Test Accuracy %g %s \n" 
                         % (ep,100*acc, '%', 100*acc_test, '%')) 
            
        # Optimize each batch
        for b in range(0,num_batches):
            learning_step.run(feed_dict={       label:  chunked_data[b][0],
                                                diag[0]:chunked_data[b][1][0], 
                                                diag[1]:chunked_data[b][1][1], 
                                                diag[2]:chunked_data[b][1][2], 
                                                diag[3]:chunked_data[b][1][3], 
                                                diag[4]:chunked_data[b][1][4], 
                                                diag[5]:chunked_data[b][1][5], 
                                                diag[6]:chunked_data[b][1][6], 
                                                diag[7]:chunked_data[b][1][7], 
                                                diag[8]:chunked_data[b][1][8], 
                                                diag[9]:chunked_data[b][1][9], 
                                                diag[10]:chunked_data[b][1][10], 
                                                diag[11]:chunked_data[b][1][11], 
                                                diag[12]:chunked_data[b][1][12], 
                                                diag[13]:chunked_data[b][1][13], 
                                                diag[14]:chunked_data[b][1][14], 
                                                diag[15]:chunked_data[b][1][15],
                                                diag[16]:chunked_data[b][1][16], 
                                                diag[17]:chunked_data[b][1][17], 
                                                diag[18]:chunked_data[b][1][18], 
                                                diag[19]:chunked_data[b][1][19], 
                                                diag[20]:chunked_data[b][1][20], 
                                                diag[21]:chunked_data[b][1][21], 
                                                diag[22]:chunked_data[b][1][22], 
                                                diag[23]:chunked_data[b][1][23],
                                                diag[24]:chunked_data[b][1][24], 
                                                diag[25]:chunked_data[b][1][25], 
                                                diag[26]:chunked_data[b][1][26], 
                                                diag[27]:chunked_data[b][1][27], 
                                                diag[28]:chunked_data[b][1][28], 
                                                diag[29]:chunked_data[b][1][29], 
                                                diag[30]:chunked_data[b][1][30], 
                                                diag[31]:chunked_data[b][1][31],
                                                card[0]:chunked_data[b][2][0], 
                                                card[1]:chunked_data[b][2][1], 
                                                card[2]:chunked_data[b][2][2], 
                                                card[3]:chunked_data[b][2][3], 
                                                card[4]:chunked_data[b][2][4], 
                                                card[5]:chunked_data[b][2][5], 
                                                card[6]:chunked_data[b][2][6], 
                                                card[7]:chunked_data[b][2][7], 
                                                card[8]:chunked_data[b][2][8], 
                                                card[9]:chunked_data[b][2][9], 
                                                card[10]:chunked_data[b][2][10], 
                                                card[11]:chunked_data[b][2][11], 
                                                card[12]:chunked_data[b][2][12], 
                                                card[13]:chunked_data[b][2][13], 
                                                card[14]:chunked_data[b][2][14], 
                                                card[15]:chunked_data[b][2][15],
                                                card[16]:chunked_data[b][2][16], 
                                                card[17]:chunked_data[b][2][17], 
                                                card[18]:chunked_data[b][2][18], 
                                                card[19]:chunked_data[b][2][19], 
                                                card[20]:chunked_data[b][2][20], 
                                                card[21]:chunked_data[b][2][21], 
                                                card[22]:chunked_data[b][2][22], 
                                                card[23]:chunked_data[b][2][23],
                                                card[24]:chunked_data[b][2][24], 
                                                card[25]:chunked_data[b][2][25], 
                                                card[26]:chunked_data[b][2][26], 
                                                card[27]:chunked_data[b][2][27], 
                                                card[28]:chunked_data[b][2][28], 
                                                card[29]:chunked_data[b][2][29], 
                                                card[30]:chunked_data[b][2][30], 
                                                card[31]:chunked_data[b][2][31]})
        
    # Compute final accuracy
    acc_test = accuracy.eval(feed_dict={label:  test_labels,
                                                diag[0]:test_diag[0], 
                                                diag[1]:test_diag[1], 
                                                diag[2]:test_diag[2], 
                                                diag[3]:test_diag[3], 
                                                diag[4]:test_diag[4], 
                                                diag[5]:test_diag[5], 
                                                diag[6]:test_diag[6], 
                                                diag[7]:test_diag[7], 
                                                diag[8]:test_diag[8], 
                                                diag[9]:test_diag[9], 
                                                diag[10]:test_diag[10], 
                                                diag[11]:test_diag[11], 
                                                diag[12]:test_diag[12], 
                                                diag[13]:test_diag[13], 
                                                diag[14]:test_diag[14], 
                                                diag[15]:test_diag[15],
                                                diag[16]:test_diag[16], 
                                                diag[17]:test_diag[17], 
                                                diag[18]:test_diag[18], 
                                                diag[19]:test_diag[19], 
                                                diag[20]:test_diag[20], 
                                                diag[21]:test_diag[21], 
                                                diag[22]:test_diag[22], 
                                                diag[23]:test_diag[23],
                                                diag[24]:test_diag[24], 
                                                diag[25]:test_diag[25], 
                                                diag[26]:test_diag[26], 
                                                diag[27]:test_diag[27], 
                                                diag[28]:test_diag[28], 
                                                diag[29]:test_diag[29], 
                                                diag[30]:test_diag[30], 
                                                diag[31]:test_diag[31],
                                                card[0]:test_card[0], 
                                                card[1]:test_card[1], 
                                                card[2]:test_card[2], 
                                                card[3]:test_card[3], 
                                                card[4]:test_card[4], 
                                                card[5]:test_card[5], 
                                                card[6]:test_card[6], 
                                                card[7]:test_card[7], 
                                                card[8]:test_card[8], 
                                                card[9]:test_card[9], 
                                                card[10]:test_card[10], 
                                                card[11]:test_card[11], 
                                                card[12]:test_card[12], 
                                                card[13]:test_card[13], 
                                                card[14]:test_card[14], 
                                                card[15]:test_card[15],
                                                card[16]:test_card[16], 
                                                card[17]:test_card[17], 
                                                card[18]:test_card[18], 
                                                card[19]:test_card[19], 
                                                card[20]:test_card[20], 
                                                card[21]:test_card[21], 
                                                card[22]:test_card[22], 
                                                card[23]:test_card[23],
                                                card[24]:test_card[24], 
                                                card[25]:test_card[25], 
                                                card[26]:test_card[26], 
                                                card[27]:test_card[27], 
                                                card[28]:test_card[28], 
                                                card[29]:test_card[29], 
                                                card[30]:test_card[30], 
                                                card[31]:test_card[31]})
    sys.stdout.write("Test Accuracy = %g %s " % (100*acc_test, '%'))

Epoch 0, Train Accuracy 5.98958 %, Test Accuracy 10 % 
Epoch 1, Train Accuracy 40.3646 %, Test Accuracy 32 % 
Epoch 2, Train Accuracy 52.5 %, Test Accuracy 47 % 
Epoch 3, Train Accuracy 61.9792 %, Test Accuracy 54 % 
Epoch 4, Train Accuracy 65.5729 %, Test Accuracy 57.5 % 
Epoch 5, Train Accuracy 68.9062 %, Test Accuracy 61.5 % 
Epoch 6, Train Accuracy 70.625 %, Test Accuracy 61 % 
Epoch 7, Train Accuracy 71.6146 %, Test Accuracy 59.5 % 
Epoch 8, Train Accuracy 72.3438 %, Test Accuracy 62 % 
Epoch 9, Train Accuracy 74.0625 %, Test Accuracy 68 % 
Epoch 10, Train Accuracy 76.5625 %, Test Accuracy 64.5 % 
Epoch 11, Train Accuracy 77.8125 %, Test Accuracy 62.5 % 
Epoch 12, Train Accuracy 79.0625 %, Test Accuracy 65 % 
Epoch 13, Train Accuracy 80.4688 %, Test Accuracy 67.5 % 
Epoch 14, Train Accuracy 80.2083 %, Test Accuracy 66 % 
Epoch 15, Train Accuracy 81.3021 %, Test Accuracy 64.5 % 
Epoch 16, Train Accuracy 82.1354 %, Test Accuracy 66.5 % 
Epoch 17, Train Accuracy 82.6562 %, Test Accur