Add @paulkogni's implementation of Batchnorm1d HVP
f-dangel committed Jan 27, 2020
1 parent 140b94f commit f5c2fe8
Showing 3 changed files with 40 additions and 29 deletions.
9 changes: 7 additions & 2 deletions backpack/core/derivatives/basederivatives.py
@@ -15,6 +15,8 @@
     weight_jac_mat_prod_check_shapes,
     make_hessian_mat_prod_accept_vectors,
     make_hessian_mat_prod_check_shapes,
+    R_mat_prod_accept_vectors,
+    R_mat_prod_check_shapes,
 )


@@ -116,8 +118,8 @@ def hessian_diagonal(self):
     def hessian_is_psd(self):
         raise NotImplementedError
 
-    # TODO make accept vectors
-    # TODO add shape check
+    @R_mat_prod_accept_vectors
+    @R_mat_prod_check_shapes
     def make_residual_mat_prod(self, module, g_inp, g_out):
         """Return multiplication routine with the residual term.
 
@@ -129,6 +131,9 @@ def make_residual_mat_prod(self, module, g_inp, g_out):
         This function only has to be implemented if the residual is not
         zero and not diagonal (for instance, `BatchNorm`).
         """
+        return self._make_residual_mat_prod(module, g_inp, g_out)
+
+    def _make_residual_mat_prod(self, module, g_inp, g_out):
         raise NotImplementedError
 
     # TODO Refactor and remove
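Editor's note for orientation, not part of the commit: the "residual term" named in the docstrings above is the second term of the Hessian backpropagation equation. For a module output z = f(x) and a scalar loss φ, the second-order chain rule reads (in LaTeX, with J_x z the Jacobian and δz = ∇_z φ the backpropagated gradient):

    \nabla_x^2 \varphi
      = (J_x z)^\top \, \nabla_z^2 \varphi \, (J_x z)
      + \sum_k \nabla_x^2 z_k \, (\delta z)_k

The sum is the ∑_{k} Hz_k(x) 𝛿z_k that R_mat_prod multiplies onto a matrix. It vanishes when f is linear and is diagonal for elementwise activations, which is why, as the docstring states, only layers like BatchNorm need to implement `_make_residual_mat_prod`.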
40 changes: 29 additions & 11 deletions backpack/core/derivatives/batchnorm1d.py
@@ -4,10 +4,6 @@
 
 from backpack.utils.ein import einsum
 from backpack.core.derivatives.basederivatives import BaseParameterDerivatives
-from backpack.core.derivatives.shape_check import (
-    R_mat_prod_accept_vectors,
-    R_mat_prod_check_shapes,
-)
 
 
 class BatchNorm1dDerivatives(BaseParameterDerivatives):
@@ -59,19 +55,41 @@ def get_normalized_input_and_var(self, module):
         var = input.var(dim=0, unbiased=False)
         return (input - mean) / (var + module.eps).sqrt(), var
 
-    @R_mat_prod_accept_vectors
-    @R_mat_prod_check_shapes
-    def make_residual_mat_prod(self, module, g_inp, g_out):
-        # TODO: Implement R_mat_prod for BatchNorm
+    def _make_residual_mat_prod(self, module, g_inp, g_out):
+
+        N = self.get_batch(module)
+        x_hat, var = self.get_normalized_input_and_var(module)
+        gamma = module.weight
+        eps = module.eps
+
         def R_mat_prod(mat):
             """Multiply with the residual: mat → [∑_{k} Hz_k(x) 𝛿z_k] mat.
 
             Second term of the module input Hessian backpropagation equation.
             """
-            raise NotImplementedError
+            factor = gamma / (N * (var + eps))
+
+            sum_127 = einsum("nc,vnc->vc", (x_hat, mat))
+            sum_24 = einsum("nc->c", g_out[0])
+            sum_3 = einsum("nc,vnc->vc", (g_out[0], mat))
+            sum_46 = einsum("vnc->vc", mat)
+            sum_567 = einsum("nc,nc->c", (x_hat, g_out[0]))
+
+            r_mat = -einsum("nc,vc->vnc", (g_out[0], sum_127))
+            r_mat += (1.0 / N) * einsum("c,vc->vc", (sum_24, sum_127)).unsqueeze(
+                1
+            ).expand(-1, N, -1)
+            r_mat -= einsum("nc,vc->vnc", (x_hat, sum_3))
+            r_mat += (1.0 / N) * einsum("nc,c,vc->vnc", (x_hat, sum_24, sum_46))
+
+            r_mat -= einsum("vnc,c->vnc", (mat, sum_567))
+            r_mat += (1.0 / N) * einsum("c,vc->vc", (sum_567, sum_46)).unsqueeze(
+                1
+            ).expand(-1, N, -1)
+            r_mat += (3.0 / N) * einsum("nc,vc,c->vnc", (x_hat, sum_127, sum_567))
+
+            return einsum("c,vnc->vnc", (factor, r_mat))
 
-        # TODO: Enable tests in test/automated_bn_test.py
-        raise NotImplementedError
         return R_mat_prod
 
     def _weight_jac_mat_prod(self, module, g_inp, g_out, mat):
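Editor's sketch, not part of the commit: what R_mat_prod computes can be cross-checked with plain double backpropagation, because multiplying the residual onto a vector v is the Hessian-vector product of s(x) = ⟨z(x), 𝛿z⟩ with the backpropagated gradient 𝛿z held constant. The snippet below uses only torch, not BackPACK's internals; N, C, g and v are illustrative placeholders.

    import torch

    # Hypothetical check: g stands in for the constant g_out[0],
    # v is the vector to be multiplied with the residual.
    N, C = 8, 3
    bn = torch.nn.BatchNorm1d(C)
    x = torch.randn(N, C, requires_grad=True)
    g = torch.randn(N, C)
    v = torch.randn(N, C)

    # Scalar whose input Hessian is exactly the residual term.
    s = (bn(x) * g).sum()
    (grad_x,) = torch.autograd.grad(s, x, create_graph=True)
    # Residual-vector product via double backprop.
    (Rv,) = torch.autograd.grad(grad_x, x, grad_outputs=v)
    print(Rv.shape)  # torch.Size([8, 3]), same shape as the input

The leading "v" index in the einsum equations above plays the role of stacking several such vectors into a matrix.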
20 changes: 4 additions & 16 deletions test/automated_bn_test.py
@@ -75,27 +75,20 @@ def test_ggn_vp(problem, device):
 
 
 @pytest.mark.parametrize("problem,device", ALL_CONFIGURATIONS, ids=CONFIGURATION_IDS)
-def test_hvp_is_not_implemented(problem, device):
-    # TODO: Rename after implementing BatchNorm R_mat_prod
+def test_hvp(problem, device):
     problem.to(device)
 
     vecs = [torch.randn(*p.shape, device=device) for p in problem.model.parameters()]
 
-    # TODO: Implement BatchNorm R_mat_prod in backpack/core/derivatives/batchnorm1d.py
-    try:
-        backpack_res = BpextImpl(problem).hvp(vecs)
-    except NotImplementedError:
-        return
-
+    backpack_res = BpextImpl(problem).hvp(vecs)
     autograd_res = AutogradImpl(problem).hvp(vecs)
 
     check_sizes(autograd_res, backpack_res)
     check_values(autograd_res, backpack_res)
 
 
 @pytest.mark.parametrize("problem,device", ALL_CONFIGURATIONS, ids=CONFIGURATION_IDS)
-def test_hmp_is_not_implemented(problem, device):
-    # TODO: Rename after implementing BatchNorm R_mat_prod
+def test_hmp(problem, device):
     problem.to(device)
 
     NUM_COLS = 10
@@ -104,12 +97,7 @@ def test_hmp_is_not_implemented(problem, device):
         for p in problem.model.parameters()
     ]
 
-    # TODO: Implement BatchNorm R_mat_prod in backpack/core/derivatives/batchnorm1d.py
-    try:
-        backpack_res = BpextImpl(problem).hmp(matrices)
-    except NotImplementedError:
-        return
-
+    backpack_res = BpextImpl(problem).hmp(matrices)
     autograd_res = AutogradImpl(problem).hmp(matrices)
 
     check_sizes(autograd_res, backpack_res)
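With the try/except guards removed, the renamed tests now assert that BackPACK's HVP and HMP results agree with autograd. As a reading aid, hypothetical stand-ins for the comparison helpers might look as follows; the actual check_sizes and check_values live in the repository's test utilities and may differ in tolerances and details.

    import torch

    def check_sizes(autograd_res, backpack_res):
        # One result tensor per model parameter, shapes must match.
        assert len(autograd_res) == len(backpack_res)
        for a, b in zip(autograd_res, backpack_res):
            assert a.shape == b.shape

    def check_values(autograd_res, backpack_res, atol=1e-5, rtol=1e-5):
        # Elementwise agreement up to floating-point tolerance.
        for a, b in zip(autograd_res, backpack_res):
            assert torch.allclose(a, b, rtol=rtol, atol=atol)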
