Skip to content

Commit

Permalink
Add convTranspose weight_jac_mat_prod
Browse files Browse the repository at this point in the history
  • Loading branch information
f-dangel committed May 26, 2020
1 parent 4c35fad commit abb90f5
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 3 deletions.
33 changes: 31 additions & 2 deletions backpack/core/derivatives/conv_transpose2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,41 @@ def _bias_jac_mat_prod(self, module, g_inp, g_out, mat):
N, _, H_out, W_out = module.output_shape
return jac_mat.expand(-1, N, -1, H_out, W_out)

def weight_jac_mat_prod(self, module, g_inp, g_out, mat):
def _weight_jac_mat_prod(self, module, g_inp, g_out, mat):
"""Apply weight-output Jacobian to a matrix.
For 1d transpose convolution (x, W) ↦ y with u = unfold(x):
y[n,g,o,h] = ∑_{J,Y} W[J,g,o,Y] u[n,J,g,Y,h]
result[n,g,o,h]
= ∑_{I,O,X} ∂y[n,g,o,h]/∂W[I,g,O,X] mat[I,g,O,X]
= ∑_{I,O,X,J,Y} ∂( W[J,g,o,Y] u[n,J,g,Y,h] )/∂W[I,g,O,X] mat[I,g,O,X]
= ∑_{O,J,Y} ∂( W[J,g,O,Y] u[n,J,g,Y,h] )/∂W[J,g,O,Y] mat[J,g,O,Y]
= ∑_{O,J,Y} u[n,J,g,Y,h] mat[J,g,O,Y]
= ∑_{o,i,x} u[n,i,g,x,h] mat[i,g,o,x]
"""
V = mat.shape[0]
G = module.groups
C_in = module.input0.shape[1]
_, _, K_X, K_Y = module.weight.shape
N, C_out, H_out, W_out = module.output.shape

mat_reshape = mat.reshape(V, C_in, G, C_out // G, K_X, K_Y)
u = unfold_by_conv_transpose(module.input0, module).reshape(
N, C_in // G, G, K_X, K_Y, H_out, W_out
)

jac_mat = torch.einsum("nigxyhw,vigoxy->vngohw", u, mat_reshape)

return self.reshape_like_output(jac_mat, module)

def ea_jac_t_mat_jac_prod(self, module, g_inp, g_out, mat):
# TODO Implement with unfold
raise NotImplementedError

def _weight_jac_t_mat_prod(self, module, g_inp, g_out, mat, sum_batch=True):
"""Apply weight-output Jacobian to a matrix.
"""Apply transposed weight-output Jacobian to a matrix.
For 1d transpose convolution (x, W) ↦ y with u = unfold(x):
Expand Down
6 changes: 5 additions & 1 deletion test/core/derivatives/derivatives_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,11 @@ def test_weight_jac_t_mat_prod(problem, sum_batch, V=3):
problem.tear_down()


@pytest.mark.parametrize("problem", PROBLEMS_WITH_WEIGHTS, ids=IDS_WITH_WEIGHTS)
@pytest.mark.parametrize(
"problem",
PROBLEMS_WITH_WEIGHTS + CONV_T_PROBLEMS,
ids=IDS_WITH_WEIGHTS + CONV_T_IDS,
)
def test_weight_jac_mat_prod(problem, V=3):
"""Test the Jacobian-matrix product w.r.t. to the weights.
Expand Down

0 comments on commit abb90f5

Please sign in to comment.