diff --git a/backpack/core/derivatives/conv.py b/backpack/core/derivatives/conv.py index 2ba677bd..b4fd6c89 100644 --- a/backpack/core/derivatives/conv.py +++ b/backpack/core/derivatives/conv.py @@ -100,10 +100,14 @@ def _weight_jac_t_mat_prod(self, module, g_inp, g_out, mat, sum_batch=True): N, C_out = module.output.shape[0], module.output.shape[1] C_in = module.input0.shape[1] - mat = repeat(mat, "v n c_out ... -> v n (repeat_c_in c_out) ...", repeat_c_in=C_in) + mat = repeat( + mat, "v n c_out ... -> v n (repeat_c_in c_out) ...", repeat_c_in=C_in + ) mat = repeat(mat, "v n c_in_c_out ... -> (v n c_in_c_out) dummy ...", dummy=1) - input = repeat(module.input0, "n c ... -> dummy (repeat n c) ...", dummy=1, repeat=V) + input = repeat( + module.input0, "n c ... -> dummy (repeat n c) ...", dummy=1, repeat=V + ) grad_weight = self.conv_func( input, @@ -125,13 +129,19 @@ def _weight_jac_t_mat_prod(self, module, g_inp, g_out, mat, sum_batch=True): grad_weight, "(v n C_in C_out) ... -> v C_out C_in ...", reduction="sum", - v=V, n=N, C_in=C_in, C_out=C_out, + v=V, + n=N, + C_in=C_in, + C_out=C_out, ) else: return rearrange( grad_weight, "(v n C_in C_out) ... -> v n C_out C_in ...", - v=V, n=N, C_in=C_in, C_out=C_out, + v=V, + n=N, + C_in=C_in, + C_out=C_out, ) def ea_jac_t_mat_jac_prod(self, module, g_inp, g_out, mat): @@ -150,14 +160,20 @@ def ea_jac_t_mat_jac_prod(self, module, g_inp, g_out, mat): class Conv1DDerivatives(ConvNDDerivatives): def __init__(self): - super().__init__(N=1, module=Conv1d, conv_func=conv1d, conv_transpose_func=conv_transpose1d) + super().__init__( + N=1, module=Conv1d, conv_func=conv1d, conv_transpose_func=conv_transpose1d + ) class Conv2DDerivatives(ConvNDDerivatives): def __init__(self): - super().__init__(N=2, module=Conv2d, conv_func=conv2d, conv_transpose_func=conv_transpose2d) + super().__init__( + N=2, module=Conv2d, conv_func=conv2d, conv_transpose_func=conv_transpose2d + ) class Conv3DDerivatives(ConvNDDerivatives): def __init__(self): - super().__init__(N=3, module=Conv3d, conv_func=conv3d, conv_transpose_func=conv_transpose3d) + super().__init__( + N=3, module=Conv3d, conv_func=conv3d, conv_transpose_func=conv_transpose3d + ) diff --git a/backpack/core/derivatives/conv_transposend.py b/backpack/core/derivatives/conv_transposend.py index 14ca7635..0926726d 100644 --- a/backpack/core/derivatives/conv_transposend.py +++ b/backpack/core/derivatives/conv_transposend.py @@ -5,7 +5,7 @@ from torch.nn.functional import conv1d, conv2d, conv3d from torch.nn.functional import conv_transpose1d, conv_transpose2d, conv_transpose3d -from einops import rearrange, repeat, reduce +from einops import rearrange, reduce from backpack.core.derivatives.basederivatives import BaseParameterDerivatives from backpack.utils.conv_transpose import unfold_by_conv_transpose