From f4883447fb67b64e3e3da1943a411ec151038167 Mon Sep 17 00:00:00 2001 From: Felix Dangel Date: Mon, 28 Jun 2021 17:02:48 +0200 Subject: [PATCH] [REF] Spell out `add_axes`, replace list with bool --- backpack/extensions/firstorder/batch_l2_grad/linear.py | 4 ++-- .../extensions/firstorder/sum_grad_squared/linear.py | 4 ++-- backpack/utils/linear.py | 10 +++++----- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/backpack/extensions/firstorder/batch_l2_grad/linear.py b/backpack/extensions/firstorder/batch_l2_grad/linear.py index 4a562797a..4978fd059 100644 --- a/backpack/extensions/firstorder/batch_l2_grad/linear.py +++ b/backpack/extensions/firstorder/batch_l2_grad/linear.py @@ -40,9 +40,9 @@ def weight( Returns: batch_l2 for weight """ - add_axes = list(range(1, g_out[0].dim() - 1)) + has_additional_axes = g_out[0].dim() > 2 - if add_axes: + if has_additional_axes: # TODO Compare `torch.einsum`, `opt_einsum.contract` and the base class # implementation: https://github.com/fKunstner/backpack-discuss/issues/111 dE_dY = g_out[0].flatten(start_dim=1, end_dim=-2) diff --git a/backpack/extensions/firstorder/sum_grad_squared/linear.py b/backpack/extensions/firstorder/sum_grad_squared/linear.py index 2aa2a549a..8239da936 100644 --- a/backpack/extensions/firstorder/sum_grad_squared/linear.py +++ b/backpack/extensions/firstorder/sum_grad_squared/linear.py @@ -18,9 +18,9 @@ def weight(self, ext, module, g_inp, g_out, backproped): For details, see page 12 (paragraph about "second moment") of the paper (https://arxiv.org/pdf/1912.10985.pdf). """ - add_axes = list(range(1, g_out[0].dim() - 1)) + has_additional_axes = g_out[0].dim() > 2 - if add_axes: + if has_additional_axes: # TODO Compare `torch.einsum`, `opt_einsum.contract` and the base class # implementation: https://github.com/fKunstner/backpack-discuss/issues/111 dE_dY = g_out[0].flatten(start_dim=1, end_dim=-2) diff --git a/backpack/utils/linear.py b/backpack/utils/linear.py index e2b825323..b61c72cab 100644 --- a/backpack/utils/linear.py +++ b/backpack/utils/linear.py @@ -19,9 +19,9 @@ def extract_weight_diagonal( ``(N, module.weight.shape)`` with batch size ``N``) or summed weight diagonal if ``sum_batch=True`` (shape ``module.weight.shape``). """ - add_axes = list(range(1, module.input0.dim() - 1)) + has_additional_axes = module.input0.dim() > 2 - if add_axes: + if has_additional_axes: S_flat = S.flatten(start_dim=2, end_dim=-2) X_flat = module.input0.flatten(start_dim=1, end_dim=-2) equation = f"vnmo,nmi,vnko,nki->{'' if sum_batch else 'n'}oi" @@ -48,10 +48,10 @@ def extract_bias_diagonal(module: Linear, S: Tensor, sum_batch: bool = True) -> ``(N, module.bias.shape)`` with batch size ``N``) or summed bias diagonal if ``sum_batch=True`` (shape ``module.bias.shape``). """ - add_axes = list(range(2, module.input0.dim())) + additional_axes = list(range(2, module.input0.dim())) - if add_axes: - JS = S.sum(add_axes) + if additional_axes: + JS = S.sum(additional_axes) else: JS = S