In [1]:
import sys
sys.path.append("..")


In [2]:
import jax 
import jax.numpy as jnp
from jax import random

In [None]:
from pyhessian import compute_density, compute_trace
%load_ext autoreload
%autoreload 2

In [3]:
from models import VGG16
from utils import  tree_ones_like


model = VGG16(num_classes=10, projection_dim=512, width_multiplier=2)
batch_x = jnp.ones((10, 224, 224, 3))  # (N, H, W, C) format
labels = jax.nn.one_hot(jnp.arange(10), 10)
rng = random.PRNGKey(0)
theta = model.init(rng, batch_x)['params']
logits = model.apply({'params':theta}, batch_x)

batch= {'image': batch_x, 
        'label': labels}


2024-07-31 11:42:19.615264: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:479] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-07-31 11:42:19.646777: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:10575] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-07-31 11:42:19.646830: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1442] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
  from pandas.core.computation.check import NUMEXPR_INSTALLED
  from pandas.core import (
2024-07-31 11:42:23.270579: W tensorflow/core/common_runtime/gpu/gpu_device.cc:2251] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/

In [4]:
from utils import tree_zeros_like, tree_scalar_multiply, tree_add
from pyhessian.hvp import compute_hessian_vector_product_for_batch


def compute_hessian_vector_product(loss_fn, data_iter, theta, v, nbatches=None, axis_name=None):
    if axis_name is not None: 
        batch_hvp = jax.pmap(compute_hessian_vector_product_for_batch, static_argnums=(0, 4))
    else:
        batch_hvp = jax.jit(compute_hessian_vector_product_for_batch, static_argnums=(0, 4))
    N = 0
    Hv = tree_zeros_like(v)
    ix = 0
    
    if axis_name is not None:
        N_fn = lambda batch: batch['label'].shape[0] * batch['label'].shape[1]
    else:
        N_fn = lambda batch: batch['label'].shape[0] 
    
    for batch in data_iter:
        batch_n = N_fn(batch)
        N += batch_n
        _Hv = batch_hvp(loss_fn, batch, theta, v, axis_name)
        Hv = tree_add(Hv, tree_scalar_multiply(batch_n, _Hv))
        if nbatches is not None:
          if ix >= nbatches:
              break
    
    return tree_scalar_multiply(1./N, Hv)


In [5]:
from trainer import cross_entropy_loss

def loss_fn(batch, params):    
    logits = model.apply(
        {'params': params},
        batch['image'])
    loss = cross_entropy_loss(logits, batch['label'])
    return loss



In [6]:
Hv = compute_hessian_vector_product(loss_fn, [batch], theta, tree_ones_like(theta))

In [7]:
HVP = lambda theta, v: compute_hessian_vector_product(loss_fn, [batch], theta, v)

In [13]:
trace = compute_trace(random.PRNGKey(0),HVP, theta)
print(jnp.mean(jnp.stack(trace)))

5161.396


In [10]:
density = compute_density(random.PRNGKey(0),HVP, theta, n_eigs=65)


100%|██████████| 64/64 [00:21<00:00,  2.95it/s]


In [11]:
density

([array([-3.27187897e+02, -3.27111267e+02, -2.95718384e+02, -2.90161621e+02,
         -2.78119171e+02, -2.67021484e+02, -2.49264252e+02, -2.38545853e+02,
         -2.24129471e+02, -2.04580673e+02, -1.87019012e+02, -1.61635406e+02,
         -1.39716049e+02, -1.14151871e+02, -8.75241241e+01, -6.12279358e+01,
         -3.23569946e+01, -5.54385376e+00,  1.70898438e-03,  2.52340698e+01,
          5.37284546e+01,  8.61958008e+01,  1.14474457e+02,  1.41534485e+02,
          1.70649323e+02,  1.99142334e+02,  2.20908997e+02,  2.33808289e+02,
          2.58133179e+02,  4.20188751e+02,  4.20269928e+02,  5.42360046e+02,
          5.42585449e+02,  5.56015198e+02,  6.74210083e+02,  6.74257141e+02,
          9.43235352e+02,  9.43384216e+02,  9.43452332e+02,  1.02174756e+03,
          1.02177820e+03,  1.02182587e+03,  1.36194336e+03,  1.36200977e+03,
          1.36204346e+03,  1.36208752e+03,  1.70859814e+03,  1.70877112e+03,
          1.70877429e+03,  1.70880310e+03,  2.13389868e+03,  2.13764478e+03,

In [16]:
from utils import normal_tree_like
Hv = HVP(theta, normal_tree_like(rng, theta))

In [17]:
Hv

{'classifier': {'Dense_0': {'kernel': Array([[ 2.3916593 , -2.1892796 ,  1.600219  , ...,  1.8660086 ,
           -1.3288367 , -0.11625004],
          [ 6.934653  , -3.9132428 ,  4.486203  , ...,  0.76188874,
           -2.083085  ,  1.7786335 ],
          [12.39119   , -6.1026363 ,  7.960019  , ..., -0.3374951 ,
           -3.0753405 ,  3.9513505 ],
          ...,
          [-8.06523   ,  0.39032364, -4.954995  , ...,  7.0587335 ,
           -0.60218   , -5.684491  ],
          [ 1.4406937 , -2.4653025 ,  1.0363044 , ...,  3.3132145 ,
           -1.633959  , -1.0663664 ],
          [ 2.4651036 , -4.939668  ,  1.8187027 , ...,  7.0465364 ,
           -3.3202343 , -2.4515195 ]], dtype=float32)}},
 'encoder': {'Conv_0': {'kernel': Array([[[[-0.1094439 , -0.11272006, -0.11166426, ..., -0.11009119,
              0.46964812, -0.17074658],
            [-0.10944386, -0.11272004, -0.11166423, ..., -0.11009123,
              0.46964842, -0.17074655],
            [-0.1094439 , -0.11272006, -0.11