From 31518ca4e4ddbd6df02dc4c4c4c2836ec6508eaf Mon Sep 17 00:00:00 2001
From: fkunstner
Date: Sun, 19 Jan 2020 18:44:11 -0800
Subject: [PATCH 1/2] Add tests for Issue #30

---
 test/bugfixes_test.py | 57 +++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 57 insertions(+)
 create mode 100644 test/bugfixes_test.py

diff --git a/test/bugfixes_test.py b/test/bugfixes_test.py
new file mode 100644
index 000000000..de0b2d824
--- /dev/null
+++ b/test/bugfixes_test.py
@@ -0,0 +1,57 @@
+import itertools
+import pytest
+import torch
+import backpack
+
+
+def parameters_issue_30():
+    possible_values = {
+        "N": [4],
+        "C_in": [4],
+        "C_out": [6],
+        "H": [6],
+        "W": [6],
+        "K": [3],
+        "S": [1, 3],
+        "pad": [0, 2],
+        "dil": [1, 2],
+    }
+
+    configs = [
+        dict(zip(possible_values.keys(), config_tuple))
+        for config_tuple in itertools.product(*possible_values.values())
+    ]
+
+    return {
+        "argvalues": configs,
+        "ids": [str(config) for config in configs],
+    }
+
+
+@pytest.mark.parametrize("params", **parameters_issue_30())
+def test_convolutions_stride_issue_30(params):
+    """Regression test for https://github.com/f-dangel/backpack/issues/30.
+
+    The gradient for the convolution was computed incorrectly when
+    `D + 2*padding - dilation*(kernel-1) - 1` is not a multiple of `stride`.
+    The test checks that `grad_batch` summed over dim 0 matches `grad`.
+    """
+    torch.manual_seed(0)
+
+    mod = torch.nn.Conv2d(
+        in_channels=params["C_in"],
+        out_channels=params["C_out"],
+        kernel_size=params["K"],
+        stride=params["S"],
+        padding=params["pad"],
+        dilation=params["dil"],
+    )
+    backpack.extend(mod)
+    x = torch.randn(size=(params["N"], params["C_in"], params["W"], params["H"]))
+
+    with backpack.backpack(backpack.extensions.BatchGrad()):
+        loss = torch.sum(mod(x))
+        loss.backward()
+
+    for p in mod.parameters():
+        assert torch.allclose(p.grad, p.grad_batch.sum(0))

From 368786c3de77826bb3e6db976a9ff07ab207d2ca Mon Sep 17 00:00:00 2001
From: fkunstner
Date: Sun, 19 Jan 2020 18:55:15 -0800
Subject: [PATCH 2/2] Add correct stride support - Issue #30

---
 backpack/core/derivatives/conv2d.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/backpack/core/derivatives/conv2d.py b/backpack/core/derivatives/conv2d.py
index 5bc813b1e..f936469cc 100644
--- a/backpack/core/derivatives/conv2d.py
+++ b/backpack/core/derivatives/conv2d.py
@@ -205,8 +205,10 @@ def weight_jac_t_mat_prod(self, module, g_inp, g_out, mat, sum_batch=True):
         )
 
         grad_weight = grad_weight.view(
-            num_cols, batch, out_channels * in_channels, k_x, k_y
+            num_cols, batch, -1, grad_weight.shape[2], grad_weight.shape[3]
         )
+        grad_weight = grad_weight.narrow(3, 0, k_x).narrow(4, 0, k_y)
+
         if sum_batch is True:
             grad_weight = grad_weight.sum(1)
             batch = 1