Skip to content

Commit

Permalink
Add correct stride support - Issue #30
Browse files Browse the repository at this point in the history
  • Loading branch information
fKunstner committed Jan 20, 2020
1 parent 31518ca commit 368786c
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion backpack/core/derivatives/conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,8 +205,10 @@ def weight_jac_t_mat_prod(self, module, g_inp, g_out, mat, sum_batch=True):
)

grad_weight = grad_weight.view(
num_cols, batch, out_channels * in_channels, k_x, k_y
num_cols, batch, -1, grad_weight.shape[2], grad_weight.shape[3]
)
grad_weight = grad_weight.narrow(3, 0, k_x).narrow(4, 0, k_y)

if sum_batch is True:
grad_weight = grad_weight.sum(1)
batch = 1
Expand Down

0 comments on commit 368786c

Please sign in to comment.