In [1]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from plotting import plot_3d, plot_2d
%matplotlib widget

### First, let's compute the plain loss and gradients (no K-FAC) for comparison

In [2]:
# Reference computation with plain TensorFlow (no K-FAC). Arrays keep whatever
# dtype np.load returns (presumably float64, given the precision of the printed
# loss), so results may differ in the last digits from the float32 K-FAC graph.
input_data, weights, target = (
    tf.convert_to_tensor(np.load(path))
    for path in ('data/input.npy', 'data/weights.npy', 'data/target.npy')
)
# NOTE(review): assumes target has the same shape as predictions (e.g. (N, 1));
# an (N,) target would silently broadcast the difference to (N, N) — confirm.
predictions = input_data@weights
loss = tf.reduce_mean((predictions - target) ** 2)
# tf.gradients returns a list with one gradient tensor per entry of `xs`.
grads = tf.gradients(loss, weights)

Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where


In [3]:
# Evaluate the reference graph once, then report the loss and its gradient.
with tf.Session() as sess:
    loss_, grads_ = sess.run([loss, grads])
for label, value in [('Loss', loss_), ('Grads', grads_)]:
    print(label, value)

Loss 5.943384089129164
Grads [array([[ 0.47266078],
       [-5.36427606]])]


### This time with K-FAC

In [4]:
import kfac






In [5]:
# Rebuild the same linear model for K-FAC. Everything is cast to float32 here
# (the reference cell kept np.load's dtype), and the weights become a
# tf.Variable so the optimizer can update them.
def _load_float32(path):
    """Load a .npy file and convert it to a float32 tensor."""
    return tf.convert_to_tensor(np.load(path), dtype=tf.float32)


input_data = _load_float32('data/input.npy')
weights = tf.Variable(_load_float32('data/weights.npy'))
target = _load_float32('data/target.npy')
# NOTE(review): assumes target has the same shape as prediction (e.g. (N, 1));
# an (N,) target would silently broadcast the difference to (N, N) — confirm.
prediction = input_data@weights
loss = tf.reduce_mean((prediction - target) ** 2)

In [6]:
# Describe the model to K-FAC: register the model output `prediction` as the
# mean of a normal predictive distribution (kfac's default variance is
# presumably 0.5, making the negative log-likelihood a squared error up to
# scale — TODO confirm against the installed kfac version), then let kfac scan
# the graph for layers to approximate (presumably just the single matmul here).
# NOTE(review): keep the loss registration before auto_register_layers();
# the automatic scan presumably starts from the registered losses.
layer_collection = kfac.LayerCollection()
layer_collection.register_normal_predictive_distribution(prediction)
layer_collection.auto_register_layers()







In [7]:
# K-FAC optimizer bound to the layer collection above. Judging by its name it
# refreshes the covariance / inverse estimates periodically, as part of the
# train op (period left at the kfac default — TODO confirm value).
kfac_optimizer = kfac.PeriodicInvCovUpdateKfacOpt(
    layer_collection=layer_collection,
    learning_rate=0.005,
    damping=0.001,
)
kfac_train_op = kfac_optimizer.minimize(loss)

Instructions for updating:
Colocations handled automatically by placer.


In [8]:
# Run K-FAC for a fixed number of steps. Each step records one row
# [w_0, ..., w_k, loss] so the trajectory can be plotted on the loss surface.
num_steps = 100
history = []
with tf.train.MonitoredTrainingSession() as sess:
    for step in range(num_steps):
        # Read loss and weights in their own run, *before* the update. In TF1,
        # fetching a variable in the same sess.run as an op that assigns to it
        # has no guaranteed order, so the weights could come back pre- or
        # post-update and not correspond to the fetched loss.
        loss_, weights_ = sess.run([loss, weights])
        sess.run(kfac_train_op)
        history.append(np.concatenate([weights_.reshape(-1), loss_[None]]))
history = np.array(history)

INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.


In [9]:
# Interactive 3-D view of the recorded K-FAC trajectory (`history`), drawn by
# the local plotting helper (presumably over the loss surface built from the
# raw input/target data — see plotting.plot_3d).
inputs_np = np.load('data/input.npy')
targets_np = np.load('data/target.npy')
plt.close()
plot_3d(history, inputs_np, targets_np)
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [10]:
# 2-D view of the same K-FAC trajectory.
inputs_np = np.load('data/input.npy')
targets_np = np.load('data/target.npy')
# NOTE(review): presumably the plot extents as (w0_min, w0_max, w1_min, w1_max)
# — confirm against plotting.plot_2d.
view_limits = (-0.5, 1, -3, 1)
plt.close()
# Title is set before plot_2d runs; this assumes plot_2d draws on the current
# axes rather than opening a new figure — TODO confirm.
plt.title('TF1 KFAC')
plot_2d(history, inputs_np, targets_np, view_limits)
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …