In [1]:
import tensorflow as tf
import numpy as np
import math
na = np.newaxis

import matplotlib.pyplot as plt

# Tensorflow Test

In [2]:
# https://github.com/GPflow/GPflow/issues/439
def tf_tril_indices(N, k=0):
    M1 = tf.tile(tf.expand_dims(tf.range(N), axis=0), [N,1])
    M2 = tf.tile(tf.expand_dims(tf.range(N), axis=1), [1,N])
    mask = (M1-M2) >= -k
    ix1 = tf.boolean_mask(M2, tf.transpose(mask))
    ix2 = tf.boolean_mask(M1, tf.transpose(mask))
    return ix1, ix2

In [3]:
def load_lookup_table(file = 'g_lookup_table.npy'):
    
    return tf.convert_to_tensor(np.load(file), dtype=tf.float32)

In [4]:
def table_lookup_op_parallel(table, keys):
    
    table_keys = table[0]
    table_vals = table[1]
    
    num_keys = tf.shape(keys)[0]
    
    # index from table value with closest table_key to given key
    table_ind = tf.argmin( tf.abs(tf.expand_dims(table_keys, 0) - tf.expand_dims(keys, 1) ) , output_type=tf.int32, axis=1)
    
    top_keys = tf.gather(table_keys, table_ind)
    
    # difference from closest table_key to given key
    shift     = keys - top_keys
    
    # -1 if table_ind == 0, 1 if table_ind > 0 (table ind always >= 0)
    ti_zero_indicator = - tf.sign( tf.cast(tf.subtract(tf.ones([num_keys], dtype=tf.int32), tf.sign(table_ind)), dtype=tf.float32) - tf.constant(.5))
    
    # shift to next table entry (used for gradient computation)
    # if table_key == key:
    # if key != 0 : next smaller table_key is used
    # if key == 0 : next greater table_key is used
    nonzero_shift = (1 - tf.sign(tf.abs(shift))) * (-1. * ti_zero_indicator) + shift
    
    shift_step        = tf.cast(tf.sign(nonzero_shift), tf.int32) 
    table_ind_shifted = table_ind + shift_step
    
    table_val      = tf.gather(table_vals, table_ind)
    next_table_val = tf.gather(table_vals, table_ind_shifted)
    
    table_key      = tf.gather(table_keys, table_ind)
    next_table_key = tf.gather(table_keys, table_ind_shifted)
    
    dx = (next_table_key - table_key)
    dy = (next_table_val - table_val)
    
    gradient               = dy / dx
    interpolated_fun_value = table_val + shift * gradient
    
    return tf.stop_gradient(gradient) * keys + tf.stop_gradient(interpolated_fun_value - gradient * keys)

In [5]:
def ard_kernel(X1, X2, gamma=1., alphas=None):
    # X1:  (n1 x d)
    # X2:  (n2 x d)
    # out: (n1 x n2
    with tf.name_scope('ard_kernel'):
        if alphas is None:
            alphas = tf.ones([tf.shape(X1)[1]])
        return gamma * tf.reduce_prod(tf.exp(- (tf.expand_dims(X1, 1) - tf.expand_dims(X2, 0))**2 / (2 * tf.expand_dims(tf.expand_dims(alphas, 0), 0))), axis=2) 


def mu_tilde_square(X_data, Z, S, m, Kzz_inv):
    k_zx = ard_kernel( Z,X_data, alphas=a_const)
    k_xz = tf.transpose(k_zx)
    K_xx = ard_kernel(X_data, X_data, alphas=a_const)
    mu_sqr = tf.matmul(tf.matmul(tf.transpose(tf.expand_dims(m,1)),Kzz_inv)
                                                     ,k_zx)**2
 
    sig_sqr = K_xx - tf.matmul(tf.matmul(k_xz,K_zz_inv),k_zx) + tf.matmul(tf.matmul(tf.matmul(tf.matmul(k_xz,Kzz_inv),S),Kzz_inv),k_zx)

    return mu_sqr,sig_sqr

def kl_term(m, S, K_zz, K_zz_inv, u_ovln):
    # mean_diff = (u_ovln * tf.ones([tf.shape(Z_ph)[0]]) - m)
    mean_diff = tf.expand_dims(u_ovln * tf.ones([tf.shape(Z_ph)[0]]) - m, 1)
    first  = tf.trace(tf.matmul(K_zz_inv, S))
    second = tf.log(tf.matrix_determinant(K_zz) / tf.matrix_determinant(S))
    third  = tf.to_float(tf.shape(m)[0])
    # fourth = tf.reduce_sum(tf.multiply(tf.reduce_sum(tf.multiply(mean_diff, tf.transpose(K_zz_inv)), axis=1) , mean_diff))
    
    fourth = tf.squeeze(tf.matmul(tf.matmul(tf.transpose(mean_diff), K_zz_inv), mean_diff))
    
    return 0.5 * (first  + second - third + fourth)

def psi_term(Z1, Z2,a,g,Tmin,Tmax):
    z_ovln = (tf.expand_dims(Z1,1)+tf.expand_dims(Z2,0))/2
    a_r = tf.expand_dims(tf.expand_dims(a,0),1)
    
    pi = tf.constant(math.pi)
    
    return g**2 * tf.reduce_prod(-(tf.sqrt(pi * a_r)/2
                   ) * tf.exp(-tf.pow(tf.expand_dims(Z1,1) - tf.expand_dims(Z2,0),2) / (4 * a_r)
                             ) * tf.erf((z_ovln-tf.expand_dims(tf.expand_dims(Tmax,0),1))/tf.sqrt(a_r)
                                     ) - tf.erf((z_ovln-tf.expand_dims(tf.expand_dims(Tmin,0),1))/tf.sqrt(a_r)),2)

def T_Integral(m, S, Kzz_inv,psi, g,Tmin, Tmax):
    #e_qf = tf.matmul(m,tf.matmul(Kzz_inv,tf.matmul(psi,tf.matmul(Kzz_inv,m))))
    e_qf = tf.matmul(tf.matmul(tf.matmul(tf.matmul(tf.transpose(tf.expand_dims(m,1)),Kzz_inv),psi),Kzz_inv),tf.expand_dims(m,1))
    T = tf.reduce_prod(Tmax-Tmin)
    var_qf = g * T - tf.trace(tf.matmul(Kzz_inv,psi)) + tf.trace(tf.matmul(tf.matmul(tf.matmul(Kzz_inv,S),Kzz_inv),psi))
    return e_qf + var_qf

def G(mu_sqr,sig_sqr_matrix):
    
    sig_sqr = tf.diag_part(sig_sqr_matrix)
    lookup_x = - tf.squeeze(mu_sqr) / (2*sig_sqr)
    
    # return lookup_x
    
    # TODO: einkomment
    lookup_table = load_lookup_table()
    return table_lookup_op_parallel(lookup_table, lookup_x)
    
    
def exp_at_datapoints(mu_sqr,sig_sqr,C):
    return tf.reduce_sum(-G(mu_sqr,sig_sqr)+tf.log(mu_sqr/2)-C,axis=1)

In [6]:
tf.reset_default_graph()

Z_ph = tf.placeholder(tf.float32, [None, None], name='inducing_point_locations')
u_ph = tf.placeholder(tf.float32, [],           name='inducing_point_mean')
n_ph = tf.placeholder(tf.int32,   [],           name='number_samples')

X_ph =tf.placeholder(tf.float32, [None, None])

a_const = tf.ones([1]) # dimension = tf.shape(Z_ph)[1]
g_const = tf.ones([1]) # later we have to define gamma as variable
C = tf.constant(0.57721566)

#Tlims
Tmins = tf.reduce_min(Z_ph, axis=0)
Tmaxs = tf.reduce_max(Z_ph, axis=0)

num_inducing_points = 11 # tf.shape(Z_ph)[0] TODO: use shape of Z_ph instead? Right now, the number is defined twice (once here, one above in the definition of Z)

# mean
m_init = tf.ones([num_inducing_points])
m = tf.Variable(m_init)

# vectorized version of covariance matrix S (ensure valid covariance matrix)
vech_size   = (num_inducing_points * (num_inducing_points+1)) / 2 
vech_indices= tf.transpose(tf_tril_indices(num_inducing_points))
L_vech_init = tf.ones([vech_size])
L_vech = tf.Variable(L_vech_init)
L_shape = tf.constant([num_inducing_points, num_inducing_points])
L_st = tf.SparseTensor(tf.to_int64(vech_indices), L_vech, tf.to_int64(L_shape))
L = tf.sparse_add(tf.zeros(L_shape), L_st)
S = tf.matmul(L, tf.transpose(L))

# kernel calls
K_zz  = ard_kernel(Z_ph, Z_ph, alphas=a_const)
K_zz_inv = tf.matrix_inverse(K_zz)

with tf.name_scope('intergration-over-region-T'):
    psi_matrix = psi_term(Z_ph,Z_ph,a_const,g_const,Tmins,Tmaxs)
    integral_over_T = T_Integral(m,S,K_zz_inv,psi_matrix,g_const,Tmins,Tmaxs)
    
with tf.name_scope('expectation_at_datapoints'):
    mu_t_sqr, sig_t_sqr = mu_tilde_square(X_ph,Z_ph,S,m,K_zz_inv)
    exp_term = exp_at_datapoints(mu_t_sqr,sig_t_sqr,C)
    
    
    sig_sqr = tf.diag_part(sig_t_sqr)
    in_g_val = - tf.squeeze(mu_t_sqr) / (2*sig_sqr) 

with tf.name_scope('KL-divergence'):
    kl_term_op = kl_term(m, S, K_zz, K_zz_inv, u_ph)
    tf.summary.scalar('kl_div', kl_term_op)

with tf.name_scope('calculate_bound'):
    lower_bound = -integral_over_T + exp_term - kl_term_op
    
with tf.name_scope('optimization'):
    train_step = tf.train.GradientDescentOptimizer(0.001).minimize(-lower_bound)

with tf.name_scope('prior_sampling'):
    cov  = K_zz
    mean = u_ph * tf.ones([num_inducing_points])
    ind_point_dist = tf.contrib.distributions.MultivariateNormalFullCovariance(mean, cov)
    samples = ind_point_dist.sample(n_ph)
    
m_grad = tf.gradients(kl_term_op, [m])[0]  
L_vech_grad = tf.gradients(kl_term_op, [L_vech])[0]

    
merged = tf.summary.merge_all()

In [7]:
max_iterations = 1500

# inducing point location
Zx = np.linspace(0, 20, 11)[:,na]
Zy = np.linspace(0, 10, 11)[:,na]

Z = np.concatenate((Zx,Zy),axis=1)

X = np.random.rand(10,2)

wr_means = []
wr_covar = []

with tf.Session() as sess:
    
    sess.run(tf.global_variables_initializer())
    writer = tf.summary.FileWriter('logs', sess.graph)
    
    for i in range(max_iterations):
        _, summary, g_in, kl, exp_val,mu,sig,integral,bound, m_val, S_val = sess.run([train_step, merged,  in_g_val, kl_term_op,exp_term,mu_t_sqr, sig_t_sqr , integral_over_T ,lower_bound, m, S], 
                                                     feed_dict={Z_ph:Z, u_ph:0.,X_ph:X})
        writer.add_summary(summary, i)
        
        if i % 100 == 0:
            print('..........')
            print(bound)
            #print(exp_val)
            #print(sig)
            #print(mu.shape)
            print(g_in)
            #print(sig)
            #print(np.max(np.absolute(S_val)))
            #print(np.all(np.linalg.eigvals(S_val) > 0))
            
            plt.scatter(i, bound)
            
        if np.isclose(kl, 0):
            print('KL is zero after {} iterations... break'.format(i))
            break

plt.show()

..........
[[-554.01104736]]
[-0.30560243 -0.40847099 -0.35915816 -0.31738538 -0.33168599 -0.35080644
 -0.3362048  -0.33171618 -0.34115359 -0.44091654]
..........
[[-222.37870789]]
[ -3.43889785 -10.14654827  -5.87465572  -3.67795753  -4.09985781
  -5.12037563  -4.52327013  -4.34702682  -4.56320429 -17.16475105]
..........
[[-158.84500122]]
[ -6.92291307 -20.36220932 -11.64884472  -7.48857784  -8.46676731
 -10.39805508  -9.11035538  -8.74897289  -9.35828781 -33.08942795]
..........
[[-122.99356079]]
[-11.52023125 -33.48563004 -19.0013485  -12.62526417 -14.44328594
 -17.40238571 -15.13853931 -14.53036976 -15.83311844 -52.7058754 ]
..........
[[-79.61817169]]
[-17.87254143 -51.38672256 -28.93497658 -19.83187675 -22.9222908
 -27.13115883 -23.45102501 -22.49812698 -24.93323708 -79.00804901]
..........
[[-10.64123535]]
[ -27.15439415  -77.32461548  -43.2389946   -30.4603157   -35.49303818
  -41.38575745  -35.57977295  -34.12007523  -38.35613251 -116.72941589]
..........
[[ 109.70430756]]
[ 

InvalidArgumentError: indices[9] = -1 is not in [0, 10000)
	 [[Node: expectation_at_datapoints/Gather_4 = Gather[Tindices=DT_INT32, Tparams=DT_FLOAT, validate_indices=true, _device="/job:localhost/replica:0/task:0/device:CPU:0"](expectation_at_datapoints/strided_slice, expectation_at_datapoints/add_2)]]

Caused by op 'expectation_at_datapoints/Gather_4', defined at:
  File "/usr/lib/python3.5/runpy.py", line 184, in _run_module_as_main
    "__main__", mod_spec)
  File "/usr/lib/python3.5/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/usr/local/lib/python3.5/dist-packages/ipykernel_launcher.py", line 16, in <module>
    app.launch_new_instance()
  File "/usr/local/lib/python3.5/dist-packages/traitlets/config/application.py", line 658, in launch_instance
    app.start()
  File "/usr/local/lib/python3.5/dist-packages/ipykernel/kernelapp.py", line 477, in start
    ioloop.IOLoop.instance().start()
  File "/usr/local/lib/python3.5/dist-packages/zmq/eventloop/ioloop.py", line 177, in start
    super(ZMQIOLoop, self).start()
  File "/usr/local/lib/python3.5/dist-packages/tornado/ioloop.py", line 888, in start
    handler_func(fd_obj, events)
  File "/usr/local/lib/python3.5/dist-packages/tornado/stack_context.py", line 277, in null_wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.5/dist-packages/zmq/eventloop/zmqstream.py", line 440, in _handle_events
    self._handle_recv()
  File "/usr/local/lib/python3.5/dist-packages/zmq/eventloop/zmqstream.py", line 472, in _handle_recv
    self._run_callback(callback, msg)
  File "/usr/local/lib/python3.5/dist-packages/zmq/eventloop/zmqstream.py", line 414, in _run_callback
    callback(*args, **kwargs)
  File "/usr/local/lib/python3.5/dist-packages/tornado/stack_context.py", line 277, in null_wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.5/dist-packages/ipykernel/kernelbase.py", line 283, in dispatcher
    return self.dispatch_shell(stream, msg)
  File "/usr/local/lib/python3.5/dist-packages/ipykernel/kernelbase.py", line 235, in dispatch_shell
    handler(stream, idents, msg)
  File "/usr/local/lib/python3.5/dist-packages/ipykernel/kernelbase.py", line 399, in execute_request
    user_expressions, allow_stdin)
  File "/usr/local/lib/python3.5/dist-packages/ipykernel/ipkernel.py", line 196, in do_execute
    res = shell.run_cell(code, store_history=store_history, silent=silent)
  File "/usr/local/lib/python3.5/dist-packages/ipykernel/zmqshell.py", line 533, in run_cell
    return super(ZMQInteractiveShell, self).run_cell(*args, **kwargs)
  File "/usr/local/lib/python3.5/dist-packages/IPython/core/interactiveshell.py", line 2728, in run_cell
    interactivity=interactivity, compiler=compiler, result=result)
  File "/usr/local/lib/python3.5/dist-packages/IPython/core/interactiveshell.py", line 2850, in run_ast_nodes
    if self.run_code(code, result):
  File "/usr/local/lib/python3.5/dist-packages/IPython/core/interactiveshell.py", line 2910, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-6-2e2349e335a1>", line 43, in <module>
    exp_term = exp_at_datapoints(mu_t_sqr,sig_t_sqr,C)
  File "<ipython-input-5-2940b43b21cd>", line 65, in exp_at_datapoints
    return tf.reduce_sum(-G(mu_sqr,sig_sqr)+tf.log(mu_sqr/2)-C,axis=1)
  File "<ipython-input-5-2940b43b21cd>", line 61, in G
    return table_lookup_op_parallel(lookup_table, lookup_x)
  File "<ipython-input-4-73e313693e7d>", line 32, in table_lookup_op_parallel
    next_table_key = tf.gather(table_keys, table_ind_shifted)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/ops/array_ops.py", line 2486, in gather
    params, indices, validate_indices=validate_indices, name=name)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/ops/gen_array_ops.py", line 1834, in gather
    validate_indices=validate_indices, name=name)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/op_def_library.py", line 787, in _apply_op_helper
    op_def=op_def)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/ops.py", line 2956, in create_op
    op_def=op_def)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/ops.py", line 1470, in __init__
    self._traceback = self._graph._extract_stack()  # pylint: disable=protected-access

InvalidArgumentError (see above for traceback): indices[9] = -1 is not in [0, 10000)
	 [[Node: expectation_at_datapoints/Gather_4 = Gather[Tindices=DT_INT32, Tparams=DT_FLOAT, validate_indices=true, _device="/job:localhost/replica:0/task:0/device:CPU:0"](expectation_at_datapoints/strided_slice, expectation_at_datapoints/add_2)]]


In [None]:
# numpy ard_kernel calculation
def ard_kernel_np(X1, X2, gamma = 1., alphas = None):
    
    if alphas == None:
        alphas = np.ones(X1.shape[1])
    
    return gamma * np.prod(np.exp( - (X1[:,None,:] - X2[None,:,:])**2 / (2 * alphas[None,None,:])), axis=2)

In [None]:
# TEST

r_mean = np.zeros(num_inducing_points)
r_cov = ard_kernel_np(Z, Z)

print(np.allclose(r_mean, m_val))
print(np.allclose(r_cov, S_val))

print(np.sum(r_cov - S_val)**2)

# values not allclose but almost the same

### example sampling:

In [None]:
# sampling
num_samples = 10

with tf.Session() as sess:
    
    sess.run(tf.global_variables_initializer())
    
    writer = tf.summary.FileWriter('logs', sess.graph)
    sample_res, mean_res, cov_res = sess.run([samples, mean, cov], feed_dict={Z_ph:Z, u_ph:0., n_ph:num_samples})

In [None]:
for i in range(num_samples):
    plt.plot(Z, sample_res[i])
plt.show()

# Planning:

# 1. Inputs:
Domain $X = \mathbb{R}^R$ ($R$-Dimensional)
$T \subset X$


### Fixed:

#### General:
- R
- Tlims: (Rx2)
- Data D: (NxR) (all points in T)

#### Inducing points:
    -> fix nbr M 
    -> Z: (MxR) location
    -> u: (M) (each sample of function values at the inducing points is M dimensional) 

#### Hyperparameters ($\Theta$)
    -> fixed at first, might become be optimized later as well
$\Theta = (\gamma, \alpha_1,...,  \alpha_R, \overline{u})$


### Parameters:
variational dist at inducing points u: q(u) = N (u;m,S)

m: (M)

S: (MxM)

In [None]:
# pseudo stuff, no working code!!!

# constants:
# - R, T, M
# - D (NxR)
# - Z (MxR)
# - Theta values (gamma, alphas, ustrich)

# placeholders:
# - m (M)
# - S (MxM)


# kernel stuff:

def kernel_function(X, Y):
    return K_XY

# kernels to compute: K_zz, K_zd, trace(K_dd)

def lower_bound(D, m, S, Theta, T):
    K_zz = ...
    K_zz_inv = ...
    
    return - region_integral(m, S, K_zz_inv, Z, Theta, T) + datapoint_expectations(D, m, S, K_zz_inv, Theta) - kl_term(m,S,K_zz, K_zz_inv, Theta) 

def kl_term(m, S, K_zz, K_zz_inv, Theta):
    return scalar_value_node

def datapoint_expectations(D, m, S, K_zz_inv, Theta):
    
    k_zd = ...
    k_dd = ...
    
    musqare_N = ...
    sigsquare_N = ...
    
    C = 0.57721... #Euler-Masceroni constant
    
    lookup values = ... # problem: how to implement lookup table
    
    return scalar_value_node

def region_integral(m, S, K_zz_inv, Z, Theta, T):
    Psi = ... (MxM) 
    return scalar_value_node