
Commit

Merge 368786c into 3a92411
fKunstner committed Jan 20, 2020
2 parents 3a92411 + 368786c commit 6a7c683
Showing 2 changed files with 60 additions and 1 deletion.
4 changes: 3 additions & 1 deletion backpack/core/derivatives/conv2d.py
@@ -205,8 +205,10 @@ def weight_jac_t_mat_prod(self, module, g_inp, g_out, mat, sum_batch=True):
         )
 
         grad_weight = grad_weight.view(
-            num_cols, batch, out_channels * in_channels, k_x, k_y
+            num_cols, batch, -1, grad_weight.shape[2], grad_weight.shape[3]
         )
+        grad_weight = grad_weight.narrow(3, 0, k_x).narrow(4, 0, k_y)
+
         if sum_batch is True:
             grad_weight = grad_weight.sum(1)
             batch = 1
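
Why the extra narrow() is needed: when the stride does not evenly divide H + 2*padding - dilation*(kernel_size - 1) - 1, the convolution that reconstructs the weight gradient can come out with a spatial extent larger than the kernel, so the result can no longer be viewed directly as (..., k_x, k_y); narrowing keeps only the first k_x and k_y entries, which are the valid ones. A standalone sketch of this effect (toy sizes taken from the test grid below; this is an illustration, not BackPACK's internal code):

# Standalone sketch (not BackPACK's internal code): the weight gradient of a
# convolution can be computed as a convolution of the input with the output
# gradient, with stride and dilation swapped. When the stride does not divide
# H + 2*pad - dil*(K - 1) - 1, that reconstruction can have a spatial extent
# larger than the kernel, and only the first K entries per spatial dim are valid.
import torch
import torch.nn.functional as F

H = W = 6
K, stride, pad, dil = 3, 3, 2, 1  # a configuration from the test grid that triggers the bug

x = torch.randn(1, 1, H, W)
conv = torch.nn.Conv2d(1, 1, K, stride=stride, padding=pad, dilation=dil, bias=False)
out = conv(x)

# d(sum(out))/d(weight), reconstructed by convolving the input with a kernel of
# ones shaped like the output, with stride and dilation swapped.
grad_out = torch.ones(1, 1, out.shape[2], out.shape[3])
grad_weight = F.conv2d(x, grad_out, stride=dil, padding=pad, dilation=stride)

print(grad_weight.shape[-2:])  # torch.Size([4, 4]) -- larger than the 3x3 kernel

# Narrowing to the first K entries along each spatial dim recovers the true
# gradient, which is what the fix above does (with extra leading dimensions).
valid = grad_weight.narrow(2, 0, K).narrow(3, 0, K)
out.sum().backward()
print(torch.allclose(valid, conv.weight.grad))  # True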
57 changes: 57 additions & 0 deletions test/bugfixes_test.py
@@ -0,0 +1,57 @@
import itertools
import pytest
import torch
import backpack


def parameters_issue_30():
    possible_values = {
        "N": [4],
        "C_in": [4],
        "C_out": [6],
        "H": [6],
        "W": [6],
        "K": [3],
        "S": [1, 3],
        "pad": [0, 2],
        "dil": [1, 2],
    }

    configs = [
        dict(zip(possible_values.keys(), config_tuple))
        for config_tuple in itertools.product(*possible_values.values())
    ]

    return {
        "argvalues": configs,
        "ids": [str(config) for config in configs],
    }


@pytest.mark.parametrize("params", **parameters_issue_30())
def test_convolutions_stride_issue_30(params):
"""
https://github.com/f-dangel/backpack/issues/30
The gradient for the convolution is wrong when `stride` is not a multiple of
`D + 2*padding - dilation*(kernel-1) - 1`.
"""
    torch.manual_seed(0)

    mod = torch.nn.Conv2d(
        in_channels=params["C_in"],
        out_channels=params["C_out"],
        kernel_size=params["K"],
        stride=params["S"],
        padding=params["pad"],
        dilation=params["dil"],
    )
    backpack.extend(mod)
    x = torch.randn(size=(params["N"], params["C_in"], params["W"], params["H"]))

    with backpack.backpack(backpack.extensions.BatchGrad()):
        loss = torch.sum(mod(x))
        loss.backward()

    for p in mod.parameters():
        assert torch.allclose(p.grad, p.grad_batch.sum(0))
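
For reference, a quick way to see which hyperparameter combinations in the grid above violate the divisibility condition stated in the docstring (a sketch; the values mirror possible_values, with D = H = W = 6 and K = 3):

# Sketch: evaluate the docstring's divisibility condition over the grid used in
# parameters_issue_30() (D = H = W = 6, K = 3 there).
import itertools

D, K = 6, 3
for S, pad, dil in itertools.product([1, 3], [0, 2], [1, 2]):
    extent = D + 2 * pad - dil * (K - 1) - 1
    if extent % S != 0:
        print(f"S={S}, pad={pad}, dil={dil}: extent {extent} is not a multiple of stride {S}")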
