In [101]:
%load_ext autoreload

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [102]:
%autoreload 2

In [108]:
from datautils import get_datasets
from utils import one_hot, train_epoch, eval_model, create_train_state, model_class, cross_entropy_loss
from models import CNN, simpleMLP
from influence_utils import hvp, tree2NormalTree

In [5]:
from jax import value_and_grad, grad, random
import jax.numpy as jnp

In [6]:
train_ds, test_ds = get_datasets()

In [31]:
num_epochs = 5
batch_size = 64
rng = random.PRNGKey(0)
rng, init_rng = random.split(rng)
learning_rate = 0.1
momentum = 0.9
state = create_train_state(init_rng, learning_rate, momentum)
del init_rng  # Must not be used anymore.

In [8]:
""" Training """
for epoch in range(1, num_epochs + 1):
  # Use a separate PRNG key to permute image data during shuffling
  rng, input_rng = random.split(rng)
  # Run an optimization step over a training batch
  state = train_epoch(state, train_ds, batch_size, epoch, input_rng)
  # Evaluate on the test set after each training epoch 
  test_loss, test_accuracy = eval_model(state.params, test_ds)
  print(' test epoch: %d, loss: %.2f, accuracy: %.2f' % (
      epoch, test_loss, test_accuracy * 100))

train epoch: 1, loss: 0.2453, accuracy: 92.47
 test epoch: 1, loss: 0.15, accuracy: 95.30
train epoch: 2, loss: 0.1149, accuracy: 96.58
 test epoch: 2, loss: 0.14, accuracy: 95.68
train epoch: 3, loss: 0.0906, accuracy: 97.17
 test epoch: 3, loss: 0.11, accuracy: 96.71
train epoch: 4, loss: 0.0723, accuracy: 97.68
 test epoch: 4, loss: 0.11, accuracy: 96.77
train epoch: 5, loss: 0.0601, accuracy: 98.04
 test epoch: 5, loss: 0.12, accuracy: 96.46


In [26]:

from jax import jvp, grad
model_class = simpleMLP
def hvp(f, x, v):
    return jvp(grad(f), x, v)[1]

def loss_fn(params, batch):
    logits = model_class().apply({'params': params}, batch['image'])
    loss = cross_entropy_loss(logits=logits, labels=batch['label'])
    return loss


In [45]:
key = random.PRNGKey(0)
idx = random.randint(key, (10,), 0, 100)
batch_ = {k: v[idx, ...] for k, v in train_ds.items()}

In [119]:

f = lambda x: loss_fn(x, batch_)

Dense_0
Dense_1


In [117]:
hvp(f, (state.params,) , (j,) )

FrozenDict({
    Dense_0: {
        bias: DeviceArray([ 0.00388239,  0.        ,  0.07091968,  0.00703276,
                      0.02698388,  0.08050614,  0.00492003,  0.01639154,
                      0.02078302,  0.00424677, -0.00664338,  0.00224208,
                      0.        ,  0.01827   , -0.04312547,  0.07010701,
                      0.0098291 ,  0.02584918,  0.01241394,  0.00508985,
                      0.        ,  0.06919228,  0.        ,  0.00167242,
                     -0.04295541,  0.02107832,  0.03713409,  0.02528018,
                      0.04704814, -0.00055497,  0.00119073,  0.04118587,
                     -0.04206506, -0.01093167,  0.04885335,  0.        ,
                     -0.00079792,  0.03822491,  0.02522604,  0.01532115,
                      0.01830385,  0.05647188,  0.03697992,  0.02376277,
                      0.01765823, -0.0043818 , -0.02464719,  0.04425436,
                      0.02752791,  0.05659475,  0.02768526,  0.        ,
                 

In [12]:
from jax.tree_util import tree_flatten, tree_unflatten

In [13]:
value_flat, value_tree = tree_flatten(state.params)

In [15]:
len(value_flat)

4

In [104]:
for val in value_flat:
    print(val.shape)

(100,)
(784, 100)
(10,)
(100, 10)


In [17]:
value_tree

PyTreeDef(CustomNode(FrozenDict[()], [{'Dense_0': {'bias': *, 'kernel': *}, 'Dense_1': {'bias': *, 'kernel': *}}]))

In [58]:
v = [jnp.ones(shape) for shape in [val.shape for val in value_flat]]

In [59]:
v_ = tree_unflatten(value_tree, v)

In [60]:
for val in v:
    print(val.shape)

(100,)
(784, 100)
(10,)
(100, 10)


In [62]:
from jax import jacfwd, jacrev

In [63]:
H = jacfwd(jacrev(f))(state.params)

In [69]:
h, _ = tree_flatten(H)

In [71]:
for val in h:x
    print(val.shape)

(100, 100)
(100, 784, 100)
(100, 10)
(100, 100, 10)
(784, 100, 100)
(784, 100, 784, 100)
(784, 100, 10)
(784, 100, 100, 10)
(10, 100)
(10, 784, 100)
(10, 10)
(10, 100, 10)
(100, 10, 100)
(100, 10, 784, 100)
(100, 10, 10)
(100, 10, 100, 10)


TypeError: primal and tangent arguments to jax.jvp must be tuples or lists; found FrozenDict and FrozenDict.

In [112]:
J, _ = tree_flatten(j)

In [115]:
params, _ = tree_flatten(state.params)

In [122]:
hvp(f, (state.params,), (j,))

FrozenDict({
    Dense_0: {
        bias: DeviceArray([ 0.00388239,  0.        ,  0.07091968,  0.00703276,
                      0.02698388,  0.08050614,  0.00492003,  0.01639154,
                      0.02078302,  0.00424677, -0.00664338,  0.00224208,
                      0.        ,  0.01827   , -0.04312547,  0.07010701,
                      0.0098291 ,  0.02584918,  0.01241394,  0.00508985,
                      0.        ,  0.06919228,  0.        ,  0.00167242,
                     -0.04295541,  0.02107832,  0.03713409,  0.02528018,
                      0.04704814, -0.00055497,  0.00119073,  0.04118587,
                     -0.04206506, -0.01093167,  0.04885335,  0.        ,
                     -0.00079792,  0.03822491,  0.02522604,  0.01532115,
                      0.01830385,  0.05647188,  0.03697992,  0.02376277,
                      0.01765823, -0.0043818 , -0.02464719,  0.04425436,
                      0.02752791,  0.05659475,  0.02768526,  0.        ,
                 