Skip to content

Commit

Permalink
[REF] Spell out add_axes, replace list with bool
Browse files Browse the repository at this point in the history
  • Loading branch information
f-dangel committed Jun 28, 2021
1 parent 2221a9d commit f488344
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 9 deletions.
4 changes: 2 additions & 2 deletions backpack/extensions/firstorder/batch_l2_grad/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,9 @@ def weight(
Returns:
batch_l2 for weight
"""
add_axes = list(range(1, g_out[0].dim() - 1))
has_additional_axes = g_out[0].dim() > 2

if add_axes:
if has_additional_axes:
# TODO Compare `torch.einsum`, `opt_einsum.contract` and the base class
# implementation: https://github.com/fKunstner/backpack-discuss/issues/111
dE_dY = g_out[0].flatten(start_dim=1, end_dim=-2)
Expand Down
4 changes: 2 additions & 2 deletions backpack/extensions/firstorder/sum_grad_squared/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@ def weight(self, ext, module, g_inp, g_out, backproped):
For details, see page 12 (paragraph about "second moment") of the
paper (https://arxiv.org/pdf/1912.10985.pdf).
"""
add_axes = list(range(1, g_out[0].dim() - 1))
has_additional_axes = g_out[0].dim() > 2

if add_axes:
if has_additional_axes:
# TODO Compare `torch.einsum`, `opt_einsum.contract` and the base class
# implementation: https://github.com/fKunstner/backpack-discuss/issues/111
dE_dY = g_out[0].flatten(start_dim=1, end_dim=-2)
Expand Down
10 changes: 5 additions & 5 deletions backpack/utils/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ def extract_weight_diagonal(
``(N, module.weight.shape)`` with batch size ``N``) or summed weight diagonal
if ``sum_batch=True`` (shape ``module.weight.shape``).
"""
add_axes = list(range(1, module.input0.dim() - 1))
has_additional_axes = module.input0.dim() > 2

if add_axes:
if has_additional_axes:
S_flat = S.flatten(start_dim=2, end_dim=-2)
X_flat = module.input0.flatten(start_dim=1, end_dim=-2)
equation = f"vnmo,nmi,vnko,nki->{'' if sum_batch else 'n'}oi"
Expand All @@ -48,10 +48,10 @@ def extract_bias_diagonal(module: Linear, S: Tensor, sum_batch: bool = True) ->
``(N, module.bias.shape)`` with batch size ``N``) or summed bias diagonal
if ``sum_batch=True`` (shape ``module.bias.shape``).
"""
add_axes = list(range(2, module.input0.dim()))
additional_axes = list(range(2, module.input0.dim()))

if add_axes:
JS = S.sum(add_axes)
if additional_axes:
JS = S.sum(additional_axes)
else:
JS = S

Expand Down

0 comments on commit f488344

Please sign in to comment.