In [2]:
import numpy as np
import tensorflow as tf
from collections import OrderedDict
import tqdm
import matplotlib.pyplot as plt
from plotting import plot_3d, plot_2d
from tf2kfac import tf2KFAC
%matplotlib widget

### First, let's compare the pure loss computation

In [3]:
# Read the stored inputs, initial weights and regression targets from disk and
# turn each NumPy array into a float32 tensor.
input_data, weights, target = [
    tf.convert_to_tensor(np.load(npy_path), dtype=tf.float32)
    for npy_path in ('data/input.npy', 'data/weights.npy', 'data/target.npy')
]

In [4]:
# Reference values: mean-squared-error loss of the linear model
# `input_data @ weights` and its gradient with respect to `weights`.
# NOTE: computing the Fisher would additionally need an output distribution
# p(y | x) (e.g. via a softmax layer); this cell only evaluates the plain loss.

with tf.GradientTape() as g:
    # `weights` is a plain tensor (not a tf.Variable), so the tape has to be
    # told explicitly to track it.
    g.watch(weights)
    residual = input_data @ weights - target
    loss = tf.reduce_mean(residual ** 2)
weight_grads = g.gradient(loss, weights)
print('Loss', loss)
print('Gradients', weight_grads)

Loss tf.Tensor(5.943384, shape=(), dtype=float32)
Gradients tf.Tensor(
[[ 0.47266078]
 [-5.364277  ]], shape=(2, 1), dtype=float32)


### SGD

In [6]:
# Gradient descent on the mean-squared-error loss of the linear model
# `input_data @ weights`; every step uses all of `input_data`.
# The tensors are re-loaded here because the loop rebinds `weights`: re-running
# this cell therefore always starts from the stored initial weights.
input_data = tf.convert_to_tensor(np.load('data/input.npy'), dtype=tf.float32)
weights = tf.convert_to_tensor(np.load('data/weights.npy'), dtype=tf.float32)
target = tf.convert_to_tensor(np.load('data/target.npy'), dtype=tf.float32)

SGD_STEPS = 100            # number of weight updates
SGD_LEARNING_RATE = 0.01   # step size applied to the gradient
SGD_LOG_EVERY = 10         # print the loss every this many steps

# One row per step: the flattened weights *before* the update, followed by the
# loss evaluated at those weights.
history_rows = []
for i in range(SGD_STEPS):
    # A non-persistent tape is enough: gradient() is called exactly once per
    # tape, after which the tape releases its resources immediately.
    with tf.GradientTape() as g:
        # `weights` is a plain tensor, so it must be watched explicitly.
        g.watch(weights)
        prediction = input_data @ weights
        loss = tf.reduce_mean((prediction - target) ** 2)
    grads = g.gradient(loss, weights)
    history_rows.append(
        np.concatenate([weights.numpy().reshape(-1), loss.numpy()[None]])
    )
    # Tensors are immutable; this rebinds `weights` to a new tensor.
    weights = weights - SGD_LEARNING_RATE * grads
    if i % SGD_LOG_EVERY == 0:
        print(i, 'Loss', loss)
# Final trajectory as a float32 array with one row per step.
history = np.array(history_rows, dtype=np.float32)

0 Loss tf.Tensor(5.943384, shape=(), dtype=float32)
10 Loss tf.Tensor(3.706224, shape=(), dtype=float32)
20 Loss tf.Tensor(2.4449236, shape=(), dtype=float32)
30 Loss tf.Tensor(1.7306938, shape=(), dtype=float32)
40 Loss tf.Tensor(1.3236729, shape=(), dtype=float32)
50 Loss tf.Tensor(1.0895984, shape=(), dtype=float32)
60 Loss tf.Tensor(0.9532451, shape=(), dtype=float32)
70 Loss tf.Tensor(0.8724056, shape=(), dtype=float32)
80 Loss tf.Tensor(0.8233486, shape=(), dtype=float32)
90 Loss tf.Tensor(0.7926893, shape=(), dtype=float32)


In [7]:
# Close the current figure (if any) so the interactive widget backend does not
# keep a stale canvas around when this cell is re-run.
plt.close()
# `plot_3d` comes from the local `plotting` module; presumably it draws the
# recorded (weights, loss) trajectory in 3D -- confirm against its definition.
plot_3d(history, input_data, target)
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [9]:
# Close the current figure (if any) before drawing a fresh one.
plt.close()
# The title is set on the current axes *before* plot_2d is called, so the
# statement order matters here.
plt.title('TF2 SGD')
# `plot_2d` comes from the local `plotting` module. The 4-tuple is presumably
# the plotted weight range (xmin, xmax, ymin, ymax) -- TODO confirm.
plot_2d(history, input_data, target, (-0.5, 1, -3, 1))
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …