Skip to content

Commit

Permalink
Remove get_batch
Browse files Browse the repository at this point in the history
  • Loading branch information
f-dangel committed May 15, 2020
1 parent 31fa520 commit 1b1e4dd
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 8 deletions.
5 changes: 0 additions & 5 deletions backpack/core/derivatives/basederivatives.py
Expand Up @@ -129,11 +129,6 @@ def make_residual_mat_prod(self, module, g_inp, g_out):
"""
raise NotImplementedError


# TODO Refactor and remove
def get_batch(self, module):
return module.input0.size(0)

@staticmethod
def _reshape_like(mat, like):
"""Reshape as like with trailing and additional 0th dimension.
Expand Down
4 changes: 2 additions & 2 deletions backpack/core/derivatives/batchnorm1d.py
Expand Up @@ -40,7 +40,7 @@ def _jac_t_mat_prod(self, module, g_inp, g_out, mat):
"""
assert module.affine is True

N = self.get_batch(module)
N = module.input0.size(0)
x_hat, var = self.get_normalized_input_and_var(module)
ivar = 1.0 / (var + module.eps).sqrt()

Expand Down Expand Up @@ -90,7 +90,7 @@ def _weight_jac_t_mat_prod(self, module, g_inp, g_out, mat, sum_batch):
return einsum(equation, operands)

def _bias_jac_mat_prod(self, module, g_inp, g_out, mat):
N = self.get_batch(module)
N = module.input0.size(0)
return mat.unsqueeze(1).repeat(1, N, 1)

def _bias_jac_t_mat_prod(self, module, g_inp, g_out, mat, sum_batch=True):
Expand Down
2 changes: 1 addition & 1 deletion backpack/core/derivatives/linear.py
Expand Up @@ -48,7 +48,7 @@ def _weight_jac_t_mat_prod(self, module, g_inp, g_out, mat, sum_batch=True):

def _bias_jac_mat_prod(self, module, g_inp, g_out, mat):
"""Apply Jacobian of the output w.r.t. the bias."""
N = self.get_batch(module)
N = module.input0.size(0)
return mat.unsqueeze(1).expand(-1, N, -1)

def _bias_jac_t_mat_prod(self, module, g_inp, g_out, mat, sum_batch=True):
Expand Down

0 comments on commit 1b1e4dd

Please sign in to comment.