In [1]:
import gpflow
import tensorflow as tf
import numpy as np

from gpflow import Parameterized, Param

from gpflow.kernels import RBF

In [2]:
def xavier(dim_in, dim_out):
    """Sample a (dim_in, dim_out) weight matrix with Xavier (Glorot) normal scaling.

    Entries are drawn i.i.d. from a standard normal using NumPy's global
    random state and then multiplied by sqrt(2 / (dim_in + dim_out)).

    Args:
        dim_in: number of rows (input units of the layer).
        dim_out: number of columns (output units of the layer).

    Returns:
        A float ndarray of shape (dim_in, dim_out).
    """
    fan_total = dim_in + dim_out
    std = (2. / fan_total) ** 0.5
    standard_normal = np.random.randn(dim_in, dim_out)
    return standard_normal * std

class NN(Parameterized):
    """Fully connected tanh network whose weights and biases are GPflow ``Param``s.

    ``dims`` lists the layer widths, input width first. Each consecutive pair
    ``(dims[i], dims[i + 1])`` defines one layer, stored on the instance as
    ``W_<i>`` (initialised with ``xavier``) and ``b_<i>`` (initialised to zeros).
    """

    def __init__(self, dims):
        """Create one weight matrix and one bias vector per layer.

        Args:
            dims: sequence of layer widths, e.g. ``[5, 8, 2]``.
        """
        super().__init__()
        self.dims = dims
        layer_shapes = zip(dims[:-1], dims[1:])
        for layer, (n_in, n_out) in enumerate(layer_shapes):
            weight_name = 'W_{}'.format(layer)
            setattr(self, weight_name, Param(xavier(n_in, n_out)))
            bias_name = 'b_{}'.format(layer)
            setattr(self, bias_name, Param(np.zeros(n_out)))

    def forward(self, X):
        """Push ``X`` through every layer, applying tanh after each one.

        Args:
            X: input to the first layer, or ``None``.

        Returns:
            The output of the last layer (tanh is applied there as well), or
            ``None`` when ``X`` is ``None``.
        """
        # A missing input is passed through untouched so callers can forward
        # an optional second argument without checking it themselves.
        if X is None:
            return None

        # NOTE(review): W/b are used directly in TF ops here; this presumably
        # relies on the caller having put the parameters in tensor mode - confirm.
        activations = X
        n_layers = len(self.dims) - 1
        for layer in range(n_layers):
            weight = getattr(self, 'W_{}'.format(layer))
            bias = getattr(self, 'b_{}'.format(layer))
            pre_activation = tf.matmul(activations, weight) + bias
            activations = tf.nn.tanh(pre_activation)
        return activations


In [3]:
class NN_RBF(RBF):
    """RBF kernel that measures distances between neural-network features.

    Both kernel arguments are mapped through ``nn.forward`` before the parent
    class computes its scaled squared distance. Everything after ``nn`` in the
    constructor is handed to ``RBF.__init__`` unchanged.
    """

    def __init__(self, nn, *args, **kw):
        """Initialise the base RBF kernel, then attach the feature network.

        Args:
            nn: object exposing ``forward(X)``; used to transform kernel inputs.
            *args, **kw: forwarded verbatim to ``RBF.__init__``.
        """
        super().__init__(*args, **kw)
        # Attached only after the base class has been initialised.
        self.nn = nn

    def scaled_square_dist(self, X, X2):
        """Return the parent's scaled squared distance on network-mapped inputs.

        Args:
            X: first kernel input.
            X2: second kernel input; handed to ``nn.forward`` like ``X``.
        """
        embed = self.nn.forward
        features = embed(X)
        features2 = embed(X2)
        return super().scaled_square_dist(features, features2)
          

In [4]:
from tensorflow.python.ops import variables

# Seed NumPy's global RNG: both the network's initial weights (drawn inside
# NN via xavier) and the synthetic training data below come from np.random,
# so without a seed every kernel restart gives different numbers.
SEED = 0
np.random.seed(SEED)

# Problem sizes, named once so the network, kernel and data stay consistent.
INPUT_DIM = 5     # width of the raw inputs
HIDDEN_DIM = 8    # width of the hidden layer
FEATURE_DIM = 2   # width of the network output the RBF kernel operates on
N_TRAIN = 100     # number of synthetic training points

net = NN([INPUT_DIM, HIDDEN_DIM, FEATURE_DIM])  # 5D inputs -> 2D features for the GP
# INPUT_DIM is forwarded positionally to RBF.__init__ by NN_RBF.
kern = NN_RBF(net, INPUT_DIM)

# Pure-noise toy data: standard-normal inputs and a single standard-normal target column.
train_data_x = np.random.randn(N_TRAIN, INPUT_DIM)
train_data_y = np.random.randn(N_TRAIN, 1)




In [8]:
def _report(model):
    """Print the model's trainable parameter values, then its log likelihood."""
    print(model.read_trainables())
    print(model.compute_log_likelihood())


# GP regression on the toy data, using the neural-network RBF kernel.
m = gpflow.models.GPR(train_data_x, train_data_y, kern=kern)
m.compile()

# Snapshot before optimisation.
_report(m)

opt = gpflow.train.ScipyOptimizer()
opt.minimize(m)

# Snapshot after optimisation, for comparison with the one above.
_report(m)

{'GPR/kern/variance': array(1.), 'GPR/kern/lengthscales': array(1.), 'GPR/kern/nn/b_1': array([0., 0.]), 'GPR/kern/nn/W_0': array([[-4.75919158e-01,  1.88486857e-01, -3.84147178e-02,
         2.62720026e-01,  9.06124464e-01, -5.94537359e-01,
         1.46138407e-01, -2.08862906e-01],
       [ 3.51798937e-01,  3.66754050e-01,  1.52819541e-01,
         5.76621505e-02, -1.24355239e-01, -1.29296224e-01,
         9.73364549e-01, -1.25636947e-01],
       [ 1.01263429e-03,  3.56751715e-02, -3.60008216e-01,
        -1.43553584e-01,  1.52471062e-01, -1.08164789e-01,
        -8.35118749e-01, -5.67854701e-01],
       [-6.38571955e-01,  2.20556593e-01,  2.21126659e-02,
         6.51005736e-01, -4.89658665e-01,  5.52800170e-02,
        -3.14614542e-01, -2.86436014e-02],
       [ 4.74934825e-02,  3.03534535e-01,  1.02455474e+00,
         3.81645273e-01,  1.06951060e-01, -5.84495480e-01,
        -1.13905403e+00,  1.74627356e-01]]), 'GPR/likelihood/variance': array(1.), 'GPR/kern/nn/b_0': array([0., 0

In [7]:
# Sanity check that the likelihood is differentiable w.r.t. the network's
# weights, i.e. that the neural network really is part of the GP's graph.
with tf.Session() as sess:
    # NOTE(review): presumably m.initialize() picks up the session opened by
    # this `with` block as TensorFlow's default session - confirm.
    m.initialize()
    print(tf.global_variables())
    print()
    # Take the second-layer weight variable from the model itself rather than
    # by position in the global variable list: that order is not obvious
    # (W_0, b_0, b_1, W_1 in the recorded run) and shifts whenever another
    # variable is created in the graph.
    w1_variable = m.kern.nn.W_1.unconstrained_tensor
    grads = tf.gradients(m._likelihood_tensor, w1_variable)
    print(grads)
    # tf.gradients returns one entry per differentiation target; only one was given.
    mygrads = grads[0].eval()

    print(mygrads)

[<tf.Variable 'NN-48ce325b-0/W_0/unconstrained:0' shape=(5, 8) dtype=float64_ref>, <tf.Variable 'NN-48ce325b-0/b_0/unconstrained:0' shape=(8,) dtype=float64_ref>, <tf.Variable 'NN-48ce325b-0/b_1/unconstrained:0' shape=(2,) dtype=float64_ref>, <tf.Variable 'NN-48ce325b-0/W_1/unconstrained:0' shape=(8, 2) dtype=float64_ref>, <tf.Variable 'NN_RBF-7ec9a3f1-5/lengthscales/unconstrained:0' shape=() dtype=float64_ref>, <tf.Variable 'NN_RBF-7ec9a3f1-5/variance/unconstrained:0' shape=() dtype=float64_ref>, <tf.Variable 'GPR-0603eec7-12/Y/dataholder:0' shape=<unknown> dtype=float64_ref>, <tf.Variable 'GPR-0603eec7-12/X/dataholder:0' shape=<unknown> dtype=float64_ref>, <tf.Variable 'GPR-0603eec7-12/likelihood/variance/unconstrained:0' shape=() dtype=float64_ref>]

[<tf.Tensor 'gradients_1/GPR-0603eec7-12/likelihood_1/MatMul_1_grad/MatMul_1:0' shape=(8, 2) dtype=float64>]
[[-0.35956263 -0.92269742]
 [-0.2451445   0.25261824]
 [-0.94166083  0.23038071]
 [ 0.54413949  0.83356337]
 [ 0.64446429  0.47