# Test TF2 port of IHVP

YJ Choe (yjchoe@cmu.edu)

In [1]:
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

import torch
import tensorflow as tf

from infopt.ihvp import IterativeIHVP, LowRankIHVP
from infopt.ihvp_tf import IterativeIHVPTF

%matplotlib inline

In [2]:
print("torch:", torch.__version__)
print("tensorflow:", tf.__version__)

torch: 1.7.1
tensorflow: 2.4.1


In [3]:
t_torch = torch.randn(4, 3)
t_tf = tf.constant(t_torch.numpy())
print("torch:", t_torch)
print("tf:", t_tf)

torch: tensor([[-0.1396, -1.8070,  0.2753],
        [-0.1218,  0.7186,  0.5395],
        [ 1.2287,  0.6572, -0.8190],
        [-1.7459,  0.5130,  0.2059]])
tf: tf.Tensor(
[[-0.1395616  -1.8070085   0.2752801 ]
 [-0.12175678  0.7186479   0.5395054 ]
 [ 1.228726    0.657225   -0.81902796]
 [-1.7459302   0.51304966  0.2058618 ]], shape=(4, 3), dtype=float32)


## Test Case

Optimize
$$
f(x, y) = \frac{1}{2} [x, y]^T M [x, y]
$$

for some well-conditioned matrix $M \in \mathbb{R}^{n \times n}$.

In [4]:
from tests.test_ihvp import TestIHVP

In [5]:
test_torch = TestIHVP()
test_torch.setUp()

In [6]:
test_torch.params

[tensor([[ 0.6020, -1.2975],
         [ 0.7429,  0.6762]], requires_grad=True),
 tensor([[ 0.7552, -1.2698],
         [ 0.7202, -0.1917],
         [ 1.5910,  0.7936]], requires_grad=True)]

In [7]:
test_torch.test_iterative_ihvp()

In [8]:
test_torch.out

tensor(2.2972, grad_fn=<MulBackward0>)

In [9]:
torch.autograd.grad(test_torch.out, test_torch.params, create_graph=True)

(tensor([[ 0.4912, -0.5777],
         [ 0.7830,  0.1143]], grad_fn=<ViewBackward>),
 tensor([[ 0.6828, -0.9290],
         [ 0.1508, -0.1868],
         [ 0.5814,  0.1581]], grad_fn=<ViewBackward>))

## TF2

In [10]:
def allclose_tf(A, B, tol=1e-5):
    """Checks whether two tensors are close in squared-error terms.

    Args:
        A: first tensor.
        B: second tensor; `A - B` must be a valid (broadcastable) subtraction.
        tol: threshold on the sum of squared entrywise differences.

    Returns:
        True if the sum of squared entries of `A - B` is strictly below `tol`.
    """
    squared_error = tf.reduce_sum((A - B) ** 2)
    return squared_error.numpy() < tol

In [11]:
from tests.test_ihvp_tf import TestIHVP_TF

In [12]:
test_tf = TestIHVP_TF()
test_tf.setUp()

In [13]:
test_tf.params

[<tf.Variable 'Variable:0' shape=(2, 2) dtype=float32, numpy=
 array([[ 1.2110573 ,  0.01387285],
        [-1.611183  ,  1.6539273 ]], dtype=float32)>,
 <tf.Variable 'Variable:0' shape=(3, 2) dtype=float32, numpy=
 array([[ 1.6525732 , -0.5374252 ],
        [-1.1671807 , -1.680566  ],
        [ 1.2428781 , -0.05604479]], dtype=float32)>]

In [14]:
test_tf.out

<tf.Tensor: shape=(), dtype=float32, numpy=3.5493422>

In [18]:
test_tf.test_iterative_ihvp()

IterativeIHVP_TF.get_ihvp: 100%|██████████| 1000/1000 [00:01<00:00, 883.18it/s]
IterativeIHVP_TF.get_ihvp: 100%|██████████| 1000/1000 [00:01<00:00, 797.80it/s]
IterativeIHVP_TF.get_ihvp: 100%|██████████| 1000/1000 [00:01<00:00, 655.40it/s]
IterativeIHVP_TF.get_ihvp: 100%|██████████| 1000/1000 [00:01<00:00, 652.98it/s]
IterativeIHVP_TF.get_ihvp: 100%|██████████| 1000/1000 [00:01<00:00, 738.48it/s]
IterativeIHVP_TF.get_ihvp: 100%|██████████| 1000/1000 [00:01<00:00, 790.49it/s]
IterativeIHVP_TF.get_ihvp: 100%|██████████| 1000/1000 [00:01<00:00, 963.79it/s]
IterativeIHVP_TF.get_ihvp: 100%|██████████| 1000/1000 [00:01<00:00, 974.01it/s]
IterativeIHVP_TF.get_ihvp: 100%|██████████| 1000/1000 [00:01<00:00, 906.93it/s]
IterativeIHVP_TF.get_ihvp: 100%|██████████| 1000/1000 [00:01<00:00, 920.41it/s]


In [22]:
test_tf.outer_tape._persistent

True

In [33]:
test_tf.setUp()
ihvp = IterativeIHVPTF(test_tf.params, iters=1000)
test_tf._test_ihvp(ihvp)

IterativeIHVP_TF.get_ihvp: 100%|██████████| 1000/1000 [00:01<00:00, 819.33it/s]
IterativeIHVP_TF.get_ihvp: 100%|██████████| 1000/1000 [00:01<00:00, 796.17it/s]
IterativeIHVP_TF.get_ihvp: 100%|██████████| 1000/1000 [00:01<00:00, 693.52it/s]
IterativeIHVP_TF.get_ihvp: 100%|██████████| 1000/1000 [00:01<00:00, 747.30it/s]
IterativeIHVP_TF.get_ihvp: 100%|██████████| 1000/1000 [00:01<00:00, 756.22it/s]
IterativeIHVP_TF.get_ihvp: 100%|██████████| 1000/1000 [00:01<00:00, 937.27it/s]
IterativeIHVP_TF.get_ihvp: 100%|██████████| 1000/1000 [00:01<00:00, 968.70it/s]
IterativeIHVP_TF.get_ihvp: 100%|██████████| 1000/1000 [00:01<00:00, 719.44it/s]
IterativeIHVP_TF.get_ihvp: 100%|██████████| 1000/1000 [00:01<00:00, 887.98it/s]
IterativeIHVP_TF.get_ihvp: 100%|██████████| 1000/1000 [00:00<00:00, 1031.02it/s]


In [40]:
hvps = [ihvp.compute_hvp(v) for v in test_tf.vs]
[tf.concat([tf.reshape(h, [-1]) for h in hvp], axis=0) for hvp in hvps]

[<tf.Tensor: shape=(10,), dtype=float32, numpy=
 array([-0.25477085, -0.05003049,  0.42217764, -0.08275922,  0.26389915,
        -0.0652815 ,  0.40807134,  0.0514034 ,  0.03918216, -0.30609643],
       dtype=float32)>,
 <tf.Tensor: shape=(10,), dtype=float32, numpy=
 array([-0.44654477,  0.05398908, -0.04285225, -0.34855127, -0.37472948,
        -0.18207447,  0.28177   ,  0.0738418 , -0.0641168 , -0.3288384 ],
       dtype=float32)>,
 <tf.Tensor: shape=(10,), dtype=float32, numpy=
 array([-0.35499635,  0.24794406,  0.41747937, -0.3864661 , -0.28803688,
         0.986034  ,  0.28840446, -0.3360384 , -0.21421148,  0.03048006],
       dtype=float32)>,
 <tf.Tensor: shape=(10,), dtype=float32, numpy=
 array([ 0.30575866,  0.25215325, -0.17385244,  0.6436982 , -0.19871779,
        -0.7207403 , -0.00875807,  0.5046706 ,  0.7557608 ,  0.38369173],
       dtype=float32)>,
 <tf.Tensor: shape=(10,), dtype=float32, numpy=
 array([ 0.55017054, -0.01878751, -0.8293436 ,  0.6875292 , -0.4026634 ,
   

In [27]:
tf.reduce_sum(tf.multiply(tf.reshape(test_tf.params[0], [-1]), tf.reshape(test_tf.params[0], [-1])))

<tf.Tensor: shape=(), dtype=float32, numpy=6.7982388>

In [None]:
tf.random.set_seed(0)

# 1. x and y are two parameter tensors
x = tf.Variable(tf.random.normal(shape=(2, 2)))
nx = x.shape.num_elements()
y = tf.Variable(tf.random.normal(shape=(3, 2)))
ny = y.shape.num_elements()
n = nx + ny
print("x:", x)
print("y:", y)
print(n)

In [None]:
params = [x, y]
# Flatten each parameter tensor. Plain tensors are enough here: only the final
# stacked vector z has to be a variable, so no intermediate tf.Variable is made.
x_flat = tf.reshape(x, [-1])
y_flat = tf.reshape(y, [-1])

# z stacks all parameter entries into a single (n, 1) column-vector variable.
z = tf.Variable(tf.expand_dims(tf.concat([x_flat, y_flat], -1), 1))
print(z)

In [None]:
# Build a random symmetric positive-definite matrix M and its inverse.
A = tf.random.normal(shape=(n, n))
M = tf.transpose(A) @ A + 0.01 * tf.linalg.eye(n)
# tf.linalg.eigh returns eigenvalues in non-decreasing order, so E[-1] is the largest.
E, V = tf.linalg.eigh(M)
emin_good = E[-1] / 5
E = tf.maximum(E, emin_good)  # clip small eigenvalues: E_max/E_min <= 5 (E_min/E_max >= 0.2)
if E[-1] > 1:
    E = E / E[-1]  # rescale so that the largest eigenvalue is 1
# Reassemble M from the clipped (and possibly rescaled) spectrum.
M = V @ tf.linalg.diag(E) @ tf.transpose(V)
M_inv = tf.linalg.inv(M)

In [None]:
# objective
objective = 0.5 * tf.reduce_sum(tf.multiply(z, M @ z))
objective

In [None]:
# true ihvp outputs
vs, ihvp_true_flats = [], []
for _ in range(10):
    v = [tf.random.normal(p.shape) for p in params]
    vs.append(v)
    v_flat = tf.expand_dims(tf.concat([tf.reshape(v_i, -1) for v_i in v], -1), -1)
    ihvp_true_flat = M_inv @ v_flat
    ihvp_true_flats.append(ihvp_true_flat)

In [None]:
vs[0], ihvp_true_flats[0]

## TF2 Gradients & Hessians

...are managed by `tf.GradientTape`. 

According to [this documentation](https://www.tensorflow.org/api_docs/python/tf/autodiff/ForwardAccumulator), reverse mode (`tf.GradientTape`) is the more efficient choice for computing gradients of many-inputs-to-scalar-output functions such as NNs, while forward mode (`tf.autodiff.ForwardAccumulator`) works best for functions with few inputs and many outputs.
Forward mode computes Jacobian-vector products (JVPs) directly and, when wrapped around a `tf.GradientTape`, also efficiently computes Hessian-vector products (HVPs) without explicitly constructing the Jacobian or the Hessian.

In [None]:
v_flats = [
    tf.expand_dims(tf.concat([tf.reshape(v_i, -1) for v_i in v], -1), -1)
    for v in vs
]
v_flat = v_flats[0]
print(v_flat)

### forward-over-backward

This could be faster, but requires that v's are known prior to the gradient computation.

In [None]:
z = tf.Variable(z)

with tf.autodiff.ForwardAccumulator(z, v_flat) as acc:
    with tf.GradientTape() as tape:
        # objective = 0.5 * (tf.transpose(z) @ (M @ z))
        objective = 0.5 * tf.reduce_sum(tf.multiply(z, M @ z))
    grad = tape.gradient(objective, z)

# gradient = M @ z
print(allclose_tf(grad, M @ z))

In [None]:
# hessian-vector product with v = M @ v
hvp = acc.jvp(grad)
print(allclose_tf(hvp, M @ v_flat))

### backward-over-backward

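Differentiate $g^T g$ with the outer tape, where $g = \nabla_\theta f_\theta(z) = Mz$. Because both factors depend on $z$, the result is $2 H_\theta \nabla_\theta f_\theta(z) = 2(M)(Mz) = 2M^2z$, i.e., twice $\mathrm{HVP}(\nabla_\theta f_\theta(z))$.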

In [None]:
z = tf.Variable(z)
# The outer tape records the inner gradient computation so that the gradient
# itself can be differentiated a second time.
with tf.GradientTape(persistent=True) as outer_tape:
    with tf.GradientTape() as inner_tape:
        objective = 0.5 * tf.reduce_sum(tf.multiply(z, M @ z))
        grad = inner_tape.gradient(objective, z)  # p x 1 (same shape as z)
    print(allclose_tf(grad, M @ z))
    jvp = tf.transpose(grad) @ grad  # 1 x 1

In [None]:
jvp

In [None]:
# Differentiating g^T g w.r.t. z gives 2 * H g because both factors depend on z,
# hence the factor of 2 in the check below.
hvp = outer_tape.gradient(jvp, z)  # p x 1 (same shape as z)
print(allclose_tf(hvp, 2 * M @ (M @ z)))

In [None]:
hvp

In [None]:
# vector-valued objective
loss_fn = tf.keras.losses.MeanSquaredError(reduction=tf.keras.losses.Reduction.NONE)
z = tf.Variable(z)  # (p, 1)
with tf.GradientTape(persistent=True) as outer_tape:
    with tf.GradientTape() as inner_tape:
        objective = loss_fn(M @ z, z)  # (p, )
    grad = inner_tape.jacobian(objective, [z, z])  # (p, p, 1)
#    print(allclose_tf(grad, M ))
    #jvp = tf.transpose(grad) @ grad  # 1 x 1

The following is a workaround for keeping gradient tapes active (without Python contexts) from [here](https://stackoverflow.com/questions/62452614/how-to-reuse-the-inner-gradient-in-nested-gradient-tapes).

In [None]:
def compute_hvp(params, gradient, v, outer_tape):
    """Computes the HVP knowing gradients and their outer tape.

    This runs eagerly (no `tf.function`): it resumes an eagerly recorded tape,
    and tracing it into a graph would sever the link between the recorded
    `gradient` tensors and the tape.

    Args:
        params: list of parameters to differentiate with respect to.
        gradient: per-parameter gradients, one entry per element of `params`.
            They must have been computed while `outer_tape` was recording.
        v: list of tensors, one per parameter, to multiply the Hessian with.
        outer_tape: the `tf.GradientTape` that recorded `gradient`. It must not
            be recording when this is called, and should be persistent if it
            is reused across calls.

    Returns:
        The result of `outer_tape.gradient` on the per-parameter dot products
        <gradient_i, v_i> with respect to `params`: one entry per parameter.
    """
    assert isinstance(v, list)
    assert len(params) == len(gradient) == len(v)
    # Resume recording only around the dot products; always pause again, even
    # on failure, so the tape does not keep recording unrelated later ops.
    outer_tape._push_tape()
    try:
        jvp = [tf.reduce_sum(tf.multiply(g_i, v_i))
               for g_i, v_i in zip(gradient, v)]
    finally:
        outer_tape._pop_tape()
    # Gradients of a list of targets are summed, so this differentiates
    # sum_i <gradient_i, v_i> with respect to every parameter.
    return outer_tape.gradient(jvp, params)

In [None]:
@tf.function
def setup(x, y):
    """Stacks two parameter tensors into a single column vector.

    Args:
        x: first parameter tensor.
        y: second parameter tensor.

    Returns:
        A tensor with one column holding the flattened entries of `x` followed
        by the flattened entries of `y`.
    """
    flat_parts = [tf.reshape(x, [-1]), tf.reshape(y, [-1])]
    return tf.expand_dims(tf.concat(flat_parts, -1), 1)

In [None]:
z = setup(x, y)
z

In [None]:
# Start recording on the outer tape manually (no `with` block) so that it can
# be paused below and resumed later from another cell.
outer_tape = tf.GradientTape(persistent=True)
outer_tape._push_tape()

# Recorded by both outer and inner.
# z must be built from x and y while the tapes are recording; otherwise there
# is no recorded path from the objective back to x and y, and the gradients
# below come out as None.
with tf.GradientTape() as inner_tape:
    z = setup(x, y)
    objective = 0.5 * (tf.transpose(z) @ (M @ z))

# One gradient per parameter, each with the shape of that parameter.
grad = inner_tape.gradient(objective, [x, y])
print(grad)

# Stop recording by outer, for now
outer_tape._pop_tape()

In [None]:
# compute_hvp expects per-parameter lists, so pass the parameters [x, y]
# (not the stacked z) and a detached copy of the gradient as the vector v.
v_grad = [tf.stop_gradient(g) for g in grad]
hvp = compute_hvp([x, y], grad, v_grad, outer_tape)
# Flatten the per-parameter HVPs into one column to compare against M (M z).
hvp_flat = tf.expand_dims(tf.concat([tf.reshape(h, [-1]) for h in hvp], -1), 1)
print(allclose_tf(hvp_flat, M @ (M @ z)))

In [None]:
outer_tape.watched_variables()

In [None]:
# vs holds each v as a per-parameter list (what compute_hvp expects);
# v_flats holds the same vectors flattened, for comparison against M @ v.
for v, v_flat in zip(vs, v_flats):
    hvp = compute_hvp([x, y], grad, v, outer_tape)
    hvp_flat = tf.expand_dims(tf.concat([tf.reshape(h, [-1]) for h in hvp], -1), 1)
    print(allclose_tf(hvp_flat, M @ v_flat))

In [None]:
# close tapes when over
inner_tape._tape = None
outer_tape._tape = None

In [None]:
tf.stop_gradient(grad)

### forward-over-backward

In [None]:
# The gradient has to be computed inside the accumulator's context for its JVP
# to be tracked; a gradient computed beforehand carries no tangent.
with tf.autodiff.ForwardAccumulator(z, v_flat) as acc:
    with tf.GradientTape() as fwd_tape:
        fwd_tape.watch(z)  # z may be a plain tensor here, so watch it explicitly
        fwd_objective = 0.5 * tf.reduce_sum(tf.multiply(z, M @ z))
    grad_z = fwd_tape.gradient(fwd_objective, z)
hvp = acc.jvp(grad_z)
print(allclose_tf(hvp, M @ v_flat))

In [None]:
tf.keras.losses.MSE(grad, M@z, reduction=tf.keras.losses.Reduction.NONE)

## Gradients: torch to TF2

In [None]:
test_torch.out

x_torch, y_torch = test_torch.params

grad_params_torch = torch.autograd.grad(test_torch.out, test_torch.params, create_graph=True)

In [None]:
grad_params_torch

### torch to numpy

In [None]:
x_np, y_np = [t.detach().numpy() for t in test_torch.params]
print("x:", x_np)
print("y:", y_np)

In [None]:
M_np = test_torch.M.detach().numpy()
M_inv_np = test_torch.M_inv.detach().numpy()
print(M_np.shape, M_inv_np.shape)

In [None]:
grad_params_np = [grad_param.detach().numpy() for grad_param in grad_params_torch]
grad_params_np

In [None]:
vs_np = [[v.detach().numpy() for v in v_params] 
         for v_params in test_torch.vs]
vs_np[0]

### numpy to tf2

In [None]:
x = tf.Variable(x_np)
y = tf.Variable(y_np)

params = [x, y]
x_flat = tf.reshape(x, -1)
y_flat = tf.reshape(y, -1)

z = tf.expand_dims(tf.concat([x_flat, y_flat], -1), 1)
print(z)

In [None]:
M = tf.constant(M_np)
objective = 0.5 * (tf.transpose(z) @ (M @ z))
print("tf2 objective == torch objective?", np.allclose(objective.numpy(), test_torch.out.detach().numpy()))

### tf2 gradients

In [None]:
with tf.GradientTape(persistent=True) as tape:
    x_flat = tf.reshape(x, -1)
    y_flat = tf.reshape(y, -1)
    z = tf.expand_dims(tf.concat([x_flat, y_flat], -1), 1)
    objective = 0.5 * (tf.transpose(z) @ (M @ z))
grad_params = tape.gradient(objective, params)
print(grad_params)

In [None]:
# torch
grad_params_torch

In [None]:
vs_tf = [[tf.Variable(v) for v in v_param] for v_param in vs_np]
vs_tf[0]

In [None]:
x = tf.Variable(x_np)
y = tf.Variable(y_np)
params = [x, y]
n_params = len(params)

# Only the first of the v's is used for this single-step comparison with torch.
v_param = vs_tf[0]

with tf.GradientTape(persistent=True) as outer_tape:
    with tf.GradientTape(persistent=True) as tape:
        x_flat = tf.reshape(x, -1)
        y_flat = tf.reshape(y, -1)
        z = tf.expand_dims(tf.concat([x_flat, y_flat], -1), 1)
        objective = 0.5 * (tf.transpose(z) @ (M @ z))
    grad_params = tape.gradient(objective, params)
    print("grad_params:", grad_params)

    # Initialize the estimate with v and form <grad_i, ihvp_i> per parameter
    # while the outer tape is still recording.
    ihvp = v_param[:]
    assert len(ihvp) == n_params
    grad_params_ihvp = [
        tf.reduce_sum(grad_params[i] * ihvp[i])
        for i in range(n_params)
    ]
    print("grad_params_ihvp:", grad_params_ihvp)
H_ihvp = outer_tape.gradient(grad_params_ihvp, params)
print("H_ihvp:", H_ihvp)
# One step of the recursion ihvp <- v + ihvp - H * ihvp. The v here must be the
# same v_param that initialized ihvp, not a leftover `v` from an earlier cell.
for i in range(n_params):
    ihvp[i] = v_param[i] + (1.0) * ihvp[i] - H_ihvp[i]
print("ihvp:", ihvp)

In [None]:
x_torch, y_torch = test_torch.params
grad_params_torch = torch.autograd.grad(test_torch.out, test_torch.params, create_graph=True)

In [None]:
# torch
v_torch = test_torch.vs[0]
ihvp_torch = v_torch[:]
for _ in range(1):
    # Apply the recursion ihvp <- v + ihvp - H*ihvp
    grad_params_ihvp_torch = [
        grad_params_torch[i].view(-1) @ ihvp_torch[i].view(-1) for i in range(len(ihvp_torch))
    ]
    with torch.no_grad():
        H_ihvp_torch = torch.autograd.grad(
            grad_params_ihvp_torch, test_torch.params, create_graph=True
        )
        for i in range(len(ihvp_torch)):
            ihvp_torch[i] = v_torch[i] + (1.0) * ihvp_torch[i] - H_ihvp_torch[i]
            ihvp_torch[i] = ihvp_torch[i].detach()
print(ihvp_torch)

In [None]:
grad_params_torch[0]

In [None]:
ihvp_torch[0]

In [None]:
grad_params_torch[0].view(-1) @ ihvp_torch[0].view(-1)

## TF2 Models

In [None]:
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.boston_housing.load_data(
    path='boston_housing.npz', test_split=0.2, seed=113
)
print(x_train.shape, y_train.shape)

In [None]:
model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(13, )),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.Dropout(0.1),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.Dropout(0.1),
    tf.keras.layers.Dense(1)
])
model.summary()

In [None]:
x_train = tf.convert_to_tensor(x_train, dtype=tf.float32)
y_train = tf.convert_to_tensor(y_train, dtype=tf.float32)
train = tf.data.Dataset.from_tensor_slices((x_train, y_train))
len(train)

In [None]:
x_train

In [None]:
# Reduction.NONE keeps one loss value per example instead of averaging.
loss_fn = tf.keras.losses.MeanSquaredError(reduction=tf.keras.losses.Reduction.NONE)

# Start recording on the outer tape manually (no `with` block) so that it can
# be paused at the end of this cell and resumed later.
# NOTE(review): _push_tape/_pop_tape are private GradientTape methods and may
# change between TF versions.
outer_tape = tf.GradientTape(persistent=True)
outer_tape._push_tape()

# Recorded by both outer and inner
# NOTE: this rebinds the names x and y to one batch of 8 training examples.
x, y = iter(train.batch(8)).get_next()
with tf.GradientTape() as inner_tape:
    preds = model(x)
    # NOTE(review): Keras losses are called as loss_fn(y_true, y_pred); the
    # predictions are passed first here, which should not matter for squared
    # error. Confirm the result has one entry per example.
    losses = loss_fn(preds[:, tf.newaxis], y[:, tf.newaxis])

# jacobian returns one tensor per trainable variable; list(g) splits each along
# its first axis, and zip(*...) regroups the pieces so that gradients[k] holds
# the k-th slice of every variable's Jacobian.
gradients = list(zip(*[
    list(g) for g in inner_tape.jacobian(
        losses, model.trainable_variables
    )]))  # gradients per example (first) per layer
assert len(gradients) == 8
assert len(gradients[0]) == len(model.trainable_variables)
mean_loss = tf.reduce_mean(losses)

# Stop recording by outer, for now
outer_tape._pop_tape()

In [None]:
[t.shape for t in model.trainable_variables]

In [None]:
[t.shape for t in gradients[0]]

In [None]:
tf.reshape(gradient[0], -1)[tf.newaxis, :] @ tf.reshape(v[0], -1)[:, tf.newaxis]

In [None]:
tf.reduce_sum(tf.multiply(gradient[0], v[0]))

In [None]:
def compute_hvp(params, gradient, v, outer_tape):
    """Computes the HVP knowing gradients and their outer tape.

    Args:
        params: list of parameters to differentiate with respect to.
        gradient: per-parameter gradients, one entry per element of `params`.
            They must have been computed while `outer_tape` was recording.
        v: list of tensors, one per parameter, to multiply the Hessian with.
        outer_tape: the `tf.GradientTape` that recorded `gradient`. It must not
            be recording when this is called, and should be persistent if it
            is reused across calls.

    Returns:
        The result of `outer_tape.gradient` on the per-parameter dot products
        <gradient_i, v_i> with respect to `params`: one entry per parameter.
    """
    assert isinstance(v, list)
    assert len(params) == len(gradient) == len(v)
    # Resume recording only around the dot products; always pause again, even
    # on failure, so the tape does not keep recording unrelated later ops.
    outer_tape._push_tape()
    try:
        jvp = [tf.reduce_sum(tf.multiply(g_i, v_i))
               for g_i, v_i in zip(gradient, v)]
    finally:
        outer_tape._pop_tape()
    # Gradients of a list of targets are summed, so this differentiates
    # sum_i <gradient_i, v_i> with respect to every parameter.
    return outer_tape.gradient(jvp, params)

In [None]:
vs = [[tf.stop_gradient(tf.identity(g)) for g in gradient]
      for gradient in gradients]  # influence
# loop over batch; outer_tape is recycled
for gradient, v in zip(gradients, vs):
    hvp = compute_hvp(model.trainable_variables, gradient, v, outer_tape)
    print([t.shape for t in hvp])

In [None]:
loss_fn = tf.keras.losses.MSE
loss_fn(predictions, y_train[:1])

In [None]:
model.compile(
    optimizer="adam",
    loss=tf.keras.losses.MSE,
    metrics=["mse"],
)

In [None]:
model.fit(x_train, y_train, validation_split=0.1, epochs=200)

In [None]:
model.evaluate(x_test, y_test, verbose=1)

In [None]:
predictions = model(x_test)

print("predicted:", predictions.numpy().squeeze()[:5])
print("true:", y_test[:5])

In [None]:
v_flat = v_flats[0]

In [None]:
v_flat