Skip to content

Commit

Permalink
Formatting issues
Browse files Browse the repository at this point in the history
  • Loading branch information
fKunstner committed Dec 9, 2020
1 parent f2a5b52 commit 16f902e
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 8 deletions.
30 changes: 23 additions & 7 deletions backpack/core/derivatives/conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,10 +100,14 @@ def _weight_jac_t_mat_prod(self, module, g_inp, g_out, mat, sum_batch=True):
N, C_out = module.output.shape[0], module.output.shape[1]
C_in = module.input0.shape[1]

mat = repeat(mat, "v n c_out ... -> v n (repeat_c_in c_out) ...", repeat_c_in=C_in)
mat = repeat(
mat, "v n c_out ... -> v n (repeat_c_in c_out) ...", repeat_c_in=C_in
)
mat = repeat(mat, "v n c_in_c_out ... -> (v n c_in_c_out) dummy ...", dummy=1)

input = repeat(module.input0, "n c ... -> dummy (repeat n c) ...", dummy=1, repeat=V)
input = repeat(
module.input0, "n c ... -> dummy (repeat n c) ...", dummy=1, repeat=V
)

grad_weight = self.conv_func(
input,
Expand All @@ -125,13 +129,19 @@ def _weight_jac_t_mat_prod(self, module, g_inp, g_out, mat, sum_batch=True):
grad_weight,
"(v n C_in C_out) ... -> v C_out C_in ...",
reduction="sum",
v=V, n=N, C_in=C_in, C_out=C_out,
v=V,
n=N,
C_in=C_in,
C_out=C_out,
)
else:
return rearrange(
grad_weight,
"(v n C_in C_out) ... -> v n C_out C_in ...",
v=V, n=N, C_in=C_in, C_out=C_out,
v=V,
n=N,
C_in=C_in,
C_out=C_out,
)

def ea_jac_t_mat_jac_prod(self, module, g_inp, g_out, mat):
Expand All @@ -150,14 +160,20 @@ def ea_jac_t_mat_jac_prod(self, module, g_inp, g_out, mat):

class Conv1DDerivatives(ConvNDDerivatives):
    """Derivatives for 1d convolution layers.

    Specializes the N-dimensional base class by wiring in torch's ``Conv1d``
    module together with the functional ``conv1d``/``conv_transpose1d`` ops.
    """

    def __init__(self):
        # NOTE: the scraped diff showed the pre-change one-line call AND the
        # reformatted call back to back; only one super().__init__ belongs here.
        super().__init__(
            N=1, module=Conv1d, conv_func=conv1d, conv_transpose_func=conv_transpose1d
        )


class Conv2DDerivatives(ConvNDDerivatives):
    """Derivatives for 2d convolution layers.

    Specializes the N-dimensional base class by wiring in torch's ``Conv2d``
    module together with the functional ``conv2d``/``conv_transpose2d`` ops.
    """

    def __init__(self):
        # NOTE: the scraped diff showed the pre-change one-line call AND the
        # reformatted call back to back; only one super().__init__ belongs here.
        super().__init__(
            N=2, module=Conv2d, conv_func=conv2d, conv_transpose_func=conv_transpose2d
        )


class Conv3DDerivatives(ConvNDDerivatives):
    """Derivatives for 3d convolution layers.

    Specializes the N-dimensional base class by wiring in torch's ``Conv3d``
    module together with the functional ``conv3d``/``conv_transpose3d`` ops.
    """

    def __init__(self):
        # NOTE: the scraped diff showed the pre-change one-line call AND the
        # reformatted call back to back; only one super().__init__ belongs here.
        super().__init__(
            N=3, module=Conv3d, conv_func=conv3d, conv_transpose_func=conv_transpose3d
        )
2 changes: 1 addition & 1 deletion backpack/core/derivatives/conv_transposend.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from torch.nn.functional import conv1d, conv2d, conv3d
from torch.nn.functional import conv_transpose1d, conv_transpose2d, conv_transpose3d

from einops import rearrange, repeat, reduce
from einops import rearrange, reduce
from backpack.core.derivatives.basederivatives import BaseParameterDerivatives
from backpack.utils.conv_transpose import unfold_by_conv_transpose

Expand Down

0 comments on commit 16f902e

Please sign in to comment.