In [18]:
import os
# Hide all CUDA devices so TensorFlow runs on CPU only. Set before tensorflow is
# imported so the variable is already in place when TF initialises CUDA.
os.environ['CUDA_VISIBLE_DEVICES'] = ''
import tensorflow as tf
from utils import load_pk  # project helper used below to load the *.pk fixtures
from tf2KFAC import tf2KFAC  # project K-FAC implementation exercised below
import numpy as np

In [None]:
# K-FAC optimiser under test. Hyperparameter semantics live in tf2KFAC.
# NOTE(review): inp_dim/out_dim presumably must match the weight matrix loaded
# below (data @ weights yields out_dim columns) — TODO confirm.
kfac = tf2KFAC(
    tinv=10,
    tcov=1,
    lr0=1e-4,
    cov_weight=1.0,
    cov_moving_weight=0.95,
    damping=1e-3,
    conv_approx='mg',
    damping_method='factored_tikhonov',
    ft_method='original',
    norm_constraint=0.95,
    inp_dim=10,
    out_dim=5,
)

# Only the first `batch_size` examples of the fixtures are used.
batch_size = 4

# Load the fixtures in order: weights, data, labels.
# NOTE(review): load_pk presumably unpickles — only use trusted fixture files.
weights, data, labels = (load_pk(fname) for fname in ('weights.pk', 'data.pk', 'labels.pk'))

# Slice to the batch, cast, and flatten every example into a single row.
data = tf.reshape(
    tf.convert_to_tensor(data[:batch_size, :], dtype=tf.float32),
    (batch_size, -1),
)
# Integer class labels as a flat vector.
labels = tf.reshape(
    tf.convert_to_tensor(labels[:batch_size], dtype=tf.int32),
    (-1,),
)
weights = tf.convert_to_tensor(weights, dtype=tf.float32)

# Build the model: a single linear layer, logits = data @ weights.
# The tape is persistent because gradients are taken from it more than once.
weights = tf.Variable(weights)
with tf.GradientTape(persistent=True) as g:
    logits = data @ weights

    dist = tf.nn.softmax(logits, axis=-1)  # this is the model distribution

    # Elementwise -log p. log_softmax equals log(softmax(logits)) mathematically,
    # but does not produce -inf/NaN when a probability underflows to 0.
    nl_dist = -tf.nn.log_softmax(logits, axis=-1)

    # Per-example cross-entropy; the scalar training loss is its batch mean
    # (reused rather than computing the same cross-entropy a second time).
    no_mean_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits)
    loss = tf.reduce_mean(no_mean_loss)

# Sensitivities = gradient of a target w.r.t. the logits. For a non-scalar target
# the tape differentiates its sum, so this is d/dlogits of sum(-log p).
# Other targets that were tried (each optionally scaled by batch_size):
#   g.gradient(dist, logits)
#   g.gradient(no_mean_loss, logits)
#   g.gradient(loss, logits)
sensitivities = g.gradient(nl_dist, logits)  # * batch_size

print(sensitivities.shape)

grads = g.gradient(loss, weights)

# A persistent tape holds its resources until the reference is dropped;
# all gradients have been taken, so release it.
del g

# Wrap in lists (presumably one entry per layer; this model has a single layer).
sensitivities = [sensitivities]
activations = [data]
grads = [grads]

# NOTE(review): the trailing 0 is presumably the step/iteration index — TODO confirm.
ng = kfac.compute_updates(activations, sensitivities, grads, 0)

# Inspect the K-FAC statistics for layer 0.
# NOTE(review): m_ss is presumably what the TARGET matrix corresponds to — TODO confirm.
print(kfac.m_aa[0])
print(kfac.m_ss[0])

In [None]:
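# TARGET: reference 5x5 matrix the K-FAC statistics above should reproduce
# (presumably the expected `kfac.m_ss[0]`, given out_dim=5 — TODO confirm).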
# [array([[ 0.03410191, -0.0048112 , -0.01537001, -0.05537548,  0.04145478],
#        [-0.0048112 ,  0.17231973, -0.04160231, -0.10379265, -0.02211356],
#        [-0.01537001, -0.04160231,  0.2006559 , -0.12125484, -0.02242876],
#        [-0.05537548, -0.10379265, -0.12125484,  0.32963288, -0.04920992],
#        [ 0.04145478, -0.02211356, -0.02242876, -0.04920992,  0.05229746]],
#       dtype=float32)]

In [19]:
# Hand-copied sensitivity matrix: one row per example (the batch), one column
# per output (presumably the out_dim=5 logits). Each row sums to ~0.
s = np.array(
    [
        [0.00789194, -0.11012796, -0.04974077, -0.11113364, 0.26311067],
        [-0.28675118, -0.14651796, 0.62029225, 0.16511022, -0.3521331],
        [0.08187593, 0.13277604, -0.2824541, -0.09003425, 0.15783665],
        [-0.15432523, 0.20151532, -0.275451, -0.14067313, 0.3689341],
    ]
)
# Batch size is the number of rows.
batch_size = len(s)

In [20]:
# Sum over examples of the outer products s_i s_i^T, i.e. s^T s (5x5).
# Not normalised: divide by batch_size for a per-example average.
s1 = tf.linalg.matmul(a=s, b=s, transpose_a=True)
s1

<tf.Tensor: id=42, shape=(5, 5), dtype=float64, numpy=
array([[ 0.11280847,  0.02091734, -0.15887924, -0.03388484,  0.05903822],
       [ 0.02091734,  0.09183358, -0.17841684, -0.05225487,  0.11792078],
       [-0.15887924, -0.17841684,  0.54289019,  0.17212356, -0.37771764],
       [-0.03388484, -0.05225487,  0.17212356,  0.06750717, -0.15349104],
       [ 0.05903822,  0.11792078, -0.37771764, -0.15349104,  0.35424972]])>

In [21]:
def rm(x):
    """Mean of ``x`` over axis 0 (the batch), kept as a single-row tensor."""
    return tf.reduce_mean(x, axis=0, keepdims=True)


# Outer product of the batch-mean row of s with itself: mean(s)^T mean(s).
s1 = tf.matmul(rm(s), rm(s), transpose_a=True)
s1

<tf.Tensor: id=49, shape=(5, 5), dtype=float64, numpy=
array([[ 7.71360564e-03, -1.70484414e-03, -2.77673831e-04,
         3.88043996e-03, -9.61154520e-03],
       [-1.70484414e-03,  3.76800897e-04,  6.13708587e-05,
        -8.57646295e-04,  2.12432256e-03],
       [-2.77673831e-04,  6.13708587e-05,  9.99568294e-06,
        -1.39687803e-04,  3.45995725e-04],
       [ 3.88043996e-03, -8.57646295e-04, -1.39687803e-04,
         1.95211098e-03, -4.83522567e-03],
       [-9.61154520e-03,  2.12432256e-03,  3.45995725e-04,
        -4.83522567e-03,  1.19764745e-02]])>

In [22]:
# Outer product of the column sums of s, normalised by the batch size:
# sum(s)^T sum(s) / batch_size. Since sum = batch_size * mean, this equals
# batch_size * mean(s)^T mean(s), i.e. batch_size times the previous cell's result.
# A dedicated name is used here: rebinding `rm` (the reduce-*mean* helper) to a
# reduce-*sum* lambda silently shadowed it and made the name misleading.
s_col_sum = tf.reduce_sum(s, axis=0, keepdims=True)
s1 = tf.matmul(s_col_sum, s_col_sum, transpose_a=True) / batch_size
s1

<tf.Tensor: id=58, shape=(5, 5), dtype=float64, numpy=
array([[ 3.08544226e-02, -6.81937654e-03, -1.11069532e-03,
         1.55217598e-02, -3.84461808e-02],
       [-6.81937654e-03,  1.50720359e-03,  2.45483435e-04,
        -3.43058518e-03,  8.49729023e-03],
       [-1.11069532e-03,  2.45483435e-04,  3.99827318e-05,
        -5.58751214e-04,  1.38398290e-03],
       [ 1.55217598e-02, -3.43058518e-03, -5.58751214e-04,
         7.80844392e-03, -1.93409027e-02],
       [-3.84461808e-02,  8.49729023e-03,  1.38398290e-03,
        -1.93409027e-02,  4.79058979e-02]])>