Skip to content

Commit

Permalink
Remove batch_flat
Browse files Browse the repository at this point in the history
  • Loading branch information
f-dangel committed May 15, 2020
1 parent 5f9e0f3 commit 31fa520
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 11 deletions.
6 changes: 0 additions & 6 deletions backpack/core/derivatives/basederivatives.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,12 +129,6 @@ def make_residual_mat_prod(self, module, g_inp, g_out):
"""
raise NotImplementedError

# TODO Refactor and remove
def batch_flat(self, tensor):
batch = tensor.size(0)
# TODO Removing the clone().detach() will destroy the computation graph
# Tests will fail
return batch, tensor.clone().detach().view(batch, -1)

# TODO Refactor and remove
def get_batch(self, module):
Expand Down
5 changes: 3 additions & 2 deletions backpack/core/derivatives/elementwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@ def _jac_mat_prod(self, module, g_inp, g_out, mat):
def ea_jac_t_mat_jac_prod(self, module, g_inp, g_out, mat):
self._no_inplace(module)

batch, df_flat = self.batch_flat(self.df(module, g_inp, g_out))
return einsum("ni,nj,ij->ij", (df_flat, df_flat, mat)) / batch
N = module.input0.size(0)
df_flat = self.df(module, g_inp, g_out).reshape(N, -1)
return einsum("ni,nj,ij->ij", (df_flat, df_flat, mat)) / N

def hessian_diagonal(self, module, g_inp, g_out):
self._no_inplace(module)
Expand Down
6 changes: 3 additions & 3 deletions backpack/extensions/secondorder/hbp/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,9 @@ def _bias_for_batch_average(self, backproped):
return [backproped]

def __mean_input(self, module):
_, flat_input = self.derivatives.batch_flat(module.input0)
return flat_input.mean(0)
return module.input0.mean(0).flatten()

def __mean_input_outer(self, module):
N, flat_input = self.derivatives.batch_flat(module.input0)
N = module.input0.size(0)
flat_input = module.input0.reshape(N, -1)
return einsum("ni,nj->ij", (flat_input, flat_input)) / N

0 comments on commit 31fa520

Please sign in to comment.