In [34]:
import pickle
import os
import argparse
import torch
from jax import random
import json
import datetime
from src.losses import sse_loss
from src.helper import calculate_exact_ggn, tree_random_normal_like, compute_num_params
from src.sampling.predictive_samplers import sample_predictive, sample_hessian_predictive
from jax import numpy as jnp
import jax
from jax import flatten_util
import matplotlib.pyplot as plt
import tree_math as tm
from src.laplace.diagonal import hutchinson_diagonal
from flax import linen as nn
from jax import nn as jnn



In [88]:
class FC_NN(nn.Module):  # create a Flax Module dataclass
    out_dims: 1
    hidden_dim: 64
    num_layers: 3

    @nn.compact
    def __call__(self, x):
        x = x.reshape((x.shape[0], -1))
        for _ in range(self.num_layers):
            x = nn.Dense(self.hidden_dim)(x)
            x = jnn.tanh(x)
        x = self.perturb('last_layer', x)
        x = nn.Dense(self.out_dims)(x)  # shape inference
        return x

In [89]:
output_dim = 1
model = FC_NN(output_dim, 10, 2)

In [90]:
x = jnp.ones((5, 4)) # random data
y = jnp.empty((5, 1))

In [91]:
variables = model.init(jax.random.key(1), x)


In [92]:
params, perturbations = variables['params'], variables['perturbations']


In [93]:
def loss_fn(params, perturbations, x, y):
  y_pred = model.apply({'params': params, 'perturbations': perturbations}, x)
  residual = y_pred - y
  return 0.5 * jnp.sum(residual**2)


In [97]:
model_fn_ = lambda p, pert, x: model.apply({'params': p, 'perturbations': pert}, x)
intermediate_grads = jax.grad(loss_fn, argnums=1)(params, perturbations, x, y)


In [98]:
intermediate_grads

{'last_layer': Array([[ 0.7660077 , -0.3599549 ,  0.05328136,  0.29065937, -0.7825169 ,
          0.61560875, -0.11423077, -0.3838355 ,  0.6061546 ,  0.4577563 ],
        [ 0.7660077 , -0.3599549 ,  0.05328136,  0.29065937, -0.7825169 ,
          0.61560875, -0.11423077, -0.3838355 ,  0.6061546 ,  0.4577563 ],
        [ 0.7660077 , -0.3599549 ,  0.05328136,  0.29065937, -0.7825169 ,
          0.61560875, -0.11423077, -0.3838355 ,  0.6061546 ,  0.4577563 ],
        [ 0.7660077 , -0.3599549 ,  0.05328136,  0.29065937, -0.7825169 ,
          0.61560875, -0.11423077, -0.3838355 ,  0.6061546 ,  0.4577563 ],
        [ 0.7660077 , -0.3599549 ,  0.05328136,  0.29065937, -0.7825169 ,
          0.61560875, -0.11423077, -0.3838355 ,  0.6061546 ,  0.4577563 ]],      dtype=float32)}

In [101]:
full_grads = jax.grad(loss_fn, argnums=0)(params, perturbations, x, y)


In [102]:
full_grads

{'Dense_0': {'bias': Array([-0.20086901, -0.02301684, -1.0536946 ,  0.38661367, -0.4682266 ,
         -0.8961696 , -0.91292334,  1.5687138 , -0.5860333 , -0.748496  ],      dtype=float32),
  'kernel': Array([[-0.20086901, -0.02301684, -1.0536946 ,  0.38661367, -0.4682266 ,
          -0.8961696 , -0.91292334,  1.5687138 , -0.5860333 , -0.748496  ],
         [-0.20086901, -0.02301684, -1.0536946 ,  0.38661367, -0.4682266 ,
          -0.8961696 , -0.91292334,  1.5687138 , -0.5860333 , -0.748496  ],
         [-0.20086901, -0.02301684, -1.0536946 ,  0.38661367, -0.4682266 ,
          -0.8961696 , -0.91292334,  1.5687138 , -0.5860333 , -0.748496  ],
         [-0.20086901, -0.02301684, -1.0536946 ,  0.38661367, -0.4682266 ,
          -0.8961696 , -0.91292334,  1.5687138 , -0.5860333 , -0.748496  ]],      dtype=float32)},
 'Dense_1': {'bias': Array([ 3.7293425 , -1.7791559 ,  0.25436106,  0.58609045, -1.3698171 ,
          2.6196804 , -0.35633272, -1.3191556 ,  2.2977252 ,  0.45500964],      d

In [99]:
ggn_la = 0
for i in range(5):
    J = (intermediate_grads['last_layer'][i]).reshape(1, -1) 
    ggn_la += J.T @ J

In [100]:
ggn_la

Array([[ 2.9338393 , -1.3786411 ,  0.20406967,  1.1132367 , -2.9970698 ,
         2.3578053 , -0.43750826, -1.4701047 ,  2.3215957 ,  1.7532243 ],
       [-1.3786411 ,  0.6478376 , -0.09589443, -0.5231213 ,  1.4083539 ,
        -1.1079569 ,  0.20558962,  0.6908173 , -1.0909415 , -0.82385814],
       [ 0.20406967, -0.09589443,  0.01419452,  0.07743363, -0.20846783,
         0.16400234, -0.03043185, -0.10225639,  0.16148372,  0.12194939],
       [ 1.1132367 , -0.5231213 ,  0.07743363,  0.42241436, -1.1372293 ,
         0.89466226, -0.16601121, -0.5578269 ,  0.8809226 ,  0.6652558 ],
       [-2.9970698 ,  1.4083539 , -0.20846783, -1.1372293 ,  3.0616636 ,
        -2.4086213 ,  0.44693753,  1.5017889 , -2.3716311 , -1.7910101 ],
       [ 2.3578053 , -1.1079569 ,  0.16400234,  0.89466226, -2.4086213 ,
         1.8948708 , -0.35160732, -1.1814625 ,  1.8657706 ,  1.408994  ],
       [-0.43750826,  0.20558962, -0.03043185, -0.16601121,  0.44693753,
        -0.35160732,  0.06524334,  0.2192291 

In [66]:
jac = jax.jacfwd(lambda p: model.apply({'params': p, 'perturbations': perturbations}, x))(params)

In [79]:
B = jac['Dense_2']['bias'].reshape(-1, 5)
print(B.T @ B)

[[1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1.]]


In [None]:
K = jac['Dense_2']['kernel'].reshape(5, 10)
print(K.T @ K)

[[ 0.13145538  0.08677496 -0.17239244 -0.6262651  -0.6535753   0.312854
  -0.4972057  -0.45331508  0.3987159   0.72567916]
 [ 0.08677496  0.05728098 -0.1137979  -0.41340366 -0.43143135  0.20651793
  -0.32821026 -0.29923764  0.2631962   0.4790278 ]
 [-0.17239244 -0.1137979   0.2260779   0.821293    0.85710794 -0.41028115
   0.6520426   0.59448385 -0.5228817  -0.951666  ]
 [-0.6262651  -0.41340366  0.821293    2.9835832   3.113691   -1.4904643
   2.3687322   2.1596336  -1.8995183  -3.4572005 ]
 [-0.6535753  -0.43143135  0.85710794  3.113691    3.2494729  -1.5554606
   2.472028    2.253811   -1.9823523  -3.607962  ]
 [ 0.312854    0.20651793 -0.41028115 -1.4904643  -1.5554606   0.7445692
  -1.1833125  -1.0788561   0.94891423  1.7270623 ]
 [-0.4972057  -0.32821026  0.6520426   2.3687322   2.472028   -1.1833125
   1.880589    1.7145807  -1.5080694  -2.7447476 ]
 [-0.45331508 -0.29923764  0.59448385  2.1596336   2.253811   -1.0788561
   1.7145807   1.5632268  -1.3749452  -2.5024562 ]
 [ 0.39

In [60]:
model_fn = lambda params, x: model.apply({'params': params, 'perturbations': perturbations}, x[None, ...])[0]
n_params = compute_num_params(params) 
def sse_loss(preds, y):
    residual = preds - y
    return 0.5 * jnp.sum(residual**2)

ggn = calculate_exact_ggn(sse_loss, model_fn, params, x, y, n_params)


In [61]:
ggn[-10: , -10:]

Array([[ 0.13145538,  0.08677496, -0.17239244, -0.6262651 , -0.6535753 ,
         0.312854  , -0.4972057 , -0.45331508,  0.3987159 ,  0.72567916],
       [ 0.08677496,  0.05728098, -0.1137979 , -0.41340366, -0.43143135,
         0.20651793, -0.32821026, -0.29923764,  0.2631962 ,  0.4790278 ],
       [-0.17239244, -0.1137979 ,  0.2260779 ,  0.821293  ,  0.85710794,
        -0.41028115,  0.6520426 ,  0.59448385, -0.5228817 , -0.951666  ],
       [-0.6262651 , -0.41340366,  0.821293  ,  2.9835832 ,  3.113691  ,
        -1.4904643 ,  2.3687322 ,  2.1596336 , -1.8995183 , -3.4572005 ],
       [-0.6535753 , -0.43143135,  0.85710794,  3.113691  ,  3.2494729 ,
        -1.5554606 ,  2.472028  ,  2.253811  , -1.9823523 , -3.607962  ],
       [ 0.312854  ,  0.20651793, -0.41028115, -1.4904643 , -1.5554606 ,
         0.7445692 , -1.1833125 , -1.0788561 ,  0.94891423,  1.7270623 ],
       [-0.4972057 , -0.32821026,  0.6520426 ,  2.3687322 ,  2.472028  ,
        -1.1833125 ,  1.880589  ,  1.7145807 