In [1]:
import sys
sys.path.append("..")


In [2]:
import jax 
import jax.numpy as jnp
from jax import random

In [3]:
from models import VGG16
from utils import  tree_ones_like


model = VGG16(num_classes=10, projection_dim=512, width_multiplier=2)
batch_x = jnp.ones((10, 224, 224, 3))  # (N, H, W, C) format
labels = jax.nn.one_hot(jnp.arange(10), 10)
rng = random.PRNGKey(0)
params = model.init(rng, batch_x)['params']
logits = model.apply({'params':params}, batch_x)


2024-07-31 09:09:13.968382: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:479] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-07-31 09:09:14.087052: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:10575] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-07-31 09:09:14.087116: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1442] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
  from pandas.core.computation.check import NUMEXPR_INSTALLED
  from pandas.core import (
2024-07-31 09:09:19.738434: W tensorflow/core/common_runtime/gpu/gpu_device.cc:2251] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/

In [18]:
from utils import tree_zeros_like, tree_scalar_multiply, tree_add
from hvp import compute_hessian_vector_product_for_batch


def compute_hessian_vector_product(loss_fn, data_iter, params, v, nbatches=None, axis_name=None):
    if axis_name is not None: 
        batch_hvp = jax.pmap(compute_hessian_vector_product_for_batch, static_argnums=(0, 4))
    else:
        batch_hvp = jax.jit(compute_hessian_vector_product_for_batch, static_argnums=(0, 4))
    N = 0
    Hv = tree_zeros_like(v)
    ix = 0
    
    if axis_name is not None:
        N_fn = lambda batch: batch['label'].shape[0] * batch['label'].shape[1]
    else:
        N_fn = lambda batch: batch['label'].shape[0] 
    
    for batch in data_iter:
        batch_n = N_fn(batch)
        N += batch_n
        _Hv = batch_hvp(loss_fn, batch, params, v, axis_name)
        Hv = tree_add(Hv, tree_scalar_multiply(batch_n, _Hv))
        if nbatches is not None:
          if ix >= nbatches:
              break
    
    return tree_scalar_multiply(1./N, Hv)


In [19]:
labels.shape

(10, 10)

In [5]:
from hvp import compute_hessian_vector_product_for_batch
from pyhessian import compute_eigenvalues

In [6]:
v = tree_ones_like(params)
Hv = compute_hessian_vector_product_for_batch(model.apply, params, {'image': batch_x, 'label': labels}, v)

In [7]:
compute_hessian_vector_product_for_batch = jax.jit(compute_hessian_vector_product_for_batch, static_argnums=(0,))

In [8]:
Hv = compute_hessian_vector_product_for_batch(model.apply, params, {'image': batch_x, 'label': labels}, v)

In [12]:
Hv = compute_hessian_vector_product_for_batch(model.apply, params, {'image': batch_x, 'label': labels}, v)

In [16]:
Hv = compute_hessian_vector_product_for_batch(model.apply, params, {'image': 100* jnp.ones((10, 224, 224, 3)), 'label': labels}, v)

In [17]:
Hv

{'classifier': {'Dense_0': {'kernel': Array([[  48.9051   ,  -43.269436 ,   42.20083  , ...,   64.68891  ,
            -23.592323 ,  -30.585482 ],
          [  36.11164  ,   63.613514 ,   72.35839  , ...,    9.9805565,
             64.35591  ,  -84.99729  ],
          [ 256.31006  , -112.39426  ,  270.48178  , ...,  293.80682  ,
            -25.768944 , -234.99896  ],
          ...,
          [ -95.39591  ,  -97.54407  , -160.75497  , ...,  -54.24258  ,
           -109.67699  ,  178.49092  ],
          [  83.47262  , -120.65212  ,   51.85486  , ...,  128.917    ,
            -80.31489  ,  -21.639889 ],
          [ 104.9346   , -173.82886  ,   55.636387 , ...,  170.82364  ,
           -119.923996 ,  -12.734077 ]], dtype=float32)}},
 'encoder': {'Conv_0': {'kernel': Array([[[[ -0.69958717,  -0.6960474 ,  -0.6538208 , ...,  -0.73528147,
             -34.687683  ,  -1.4787024 ],
            [ -0.6995873 ,  -0.6960475 ,  -0.65382105, ...,  -0.7352816 ,
             -34.687683  ,  -1.4787022