
Commit a666eb0
Merge branch 'development' into conv-transpose
f-dangel committed Jun 24, 2020
2 parents 82d07b2 + f2509cb
Showing 7 changed files with 709 additions and 4 deletions.
6 changes: 6 additions & 0 deletions backpack/core/derivatives/__init__.py
@@ -2,7 +2,9 @@
    ELU,
    SELU,
    AvgPool2d,
    Conv1d,
    Conv2d,
    Conv3d,
    ConvTranspose2d,
    CrossEntropyLoss,
    Dropout,
@@ -18,7 +20,9 @@
)

from .avgpool2d import AvgPool2DDerivatives
from .conv1d import Conv1DDerivatives
from .conv2d import Conv2DDerivatives
from .conv3d import Conv3DDerivatives
from .conv_transpose2d import ConvTranspose2DDerivatives
from .crossentropyloss import CrossEntropyLossDerivatives
from .dropout import DropoutDerivatives
@@ -36,7 +40,9 @@

derivatives_for = {
    Linear: LinearDerivatives,
    Conv1d: Conv1DDerivatives,
    Conv2d: Conv2DDerivatives,
    Conv3d: Conv3DDerivatives,
    AvgPool2d: AvgPool2DDerivatives,
    MaxPool2d: MaxPool2DDerivatives,
    ZeroPad2d: ZeroPad2dDerivatives,
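derivatives_for is a dispatch table from torch.nn module classes to their derivative implementations; this commit registers the new Conv1d and Conv3d entries. A minimal sketch of the lookup (hypothetical usage, not part of the diff):

from torch.nn import Conv1d

from backpack.core.derivatives import derivatives_for

# Hypothetical lookup: map a layer's class to its derivatives object.
layer = Conv1d(in_channels=3, out_channels=8, kernel_size=5)
derivatives = derivatives_for[layer.__class__]()

assert derivatives.get_module() is Conv1d
assert derivatives.hessian_is_zero()  # convolution is linear in its input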
134 changes: 134 additions & 0 deletions backpack/core/derivatives/conv1d.py
@@ -0,0 +1,134 @@
from torch import einsum
from torch.nn import Conv1d, ConvTranspose1d
from torch.nn.functional import conv1d

from backpack.core.derivatives.basederivatives import BaseParameterDerivatives
from backpack.utils import conv as convUtils
from backpack.utils.ein import eingroup


class Conv1DDerivatives(BaseParameterDerivatives):
    def get_module(self):
        return Conv1d

    def hessian_is_zero(self):
        return True

    def get_unfolded_input(self, module):
        return convUtils.unfold_by_conv(module.input0, module)

    def ea_jac_t_mat_jac_prod(self, module, g_inp, g_out, mat):
        _, C_in, L_in = module.input0.size()
        in_features = C_in * L_in
        _, C_out, L_out = module.output.size()
        out_features = C_out * L_out

        mat = mat.reshape(out_features, C_out, L_out)
        jac_t_mat = self.__jac_t(module, mat).reshape(out_features, in_features)

        mat_t_jac = jac_t_mat.t().reshape(in_features, C_out, L_out)
        jac_t_mat_t_jac = self.__jac_t(module, mat_t_jac).reshape(
            in_features, in_features
        )

        return jac_t_mat_t_jac.t()

    def _jac_mat_prod(self, module, g_inp, g_out, mat):
        mat_as_conv = eingroup("v,n,c,l->vn,c,l", mat)
        jmp_as_conv = conv1d(
            mat_as_conv,
            module.weight.data,
            stride=module.stride,
            padding=module.padding,
            dilation=module.dilation,
            groups=module.groups,
        )
        return self.reshape_like_output(jmp_as_conv, module)

    def _jac_t_mat_prod(self, module, g_inp, g_out, mat):
        mat_as_conv = eingroup("v,n,c,l->vn,c,l", mat)
        jmp_as_conv = self.__jac_t(module, mat_as_conv)
        return self.reshape_like_input(jmp_as_conv, module)

    def __jac_t(self, module, mat):
        """Apply Conv1d backward operation."""
        _, C_in, L_in = module.input0.size()
        _, C_out, _ = module.output.size()
        L_axis = 2

        # The transposed Jacobian is a transposed convolution with the same
        # hyperparameters that shares the layer's weight.
        conv1d_t = ConvTranspose1d(
            in_channels=C_out,
            out_channels=C_in,
            kernel_size=module.kernel_size,
            stride=module.stride,
            padding=module.padding,
            bias=False,
            dilation=module.dilation,
            groups=module.groups,
        ).to(module.input0.device)

        conv1d_t.weight.data = module.weight

        V_N = mat.size(0)
        output_size = (V_N, C_in, L_in)

        jac_t_mat = conv1d_t(mat, output_size=output_size).narrow(L_axis, 0, L_in)
        return jac_t_mat

    def _bias_jac_mat_prod(self, module, g_inp, g_out, mat):
        """mat has shape [V, C_out]"""
        # expand over the batch and length dimensions
        N_axis, L_axis = 1, 3
        jac_mat = mat.unsqueeze(N_axis).unsqueeze(L_axis)

        N, _, L_out = module.output_shape
        return jac_mat.expand(-1, N, -1, L_out)

    def _bias_jac_t_mat_prod(self, module, g_inp, g_out, mat, sum_batch=True):
        N_axis, L_axis = 1, 3
        axes = [L_axis]
        if sum_batch:
            axes = [N_axis] + axes

        return mat.sum(axes)

    # TODO: Improve performance by using conv instead of unfold

    def _weight_jac_mat_prod(self, module, g_inp, g_out, mat):
        jac_mat = eingroup("v,o,i,l->v,o,il", mat)
        X = self.get_unfolded_input(module)
        jac_mat = einsum("nij,vki->vnkj", (X, jac_mat))
        return self.reshape_like_output(jac_mat, module)

    def _weight_jac_t_mat_prod(self, module, g_inp, g_out, mat, sum_batch=True):
        V = mat.shape[0]
        N, C_out, _ = module.output_shape
        _, C_in, _ = module.input0_shape

        mat = eingroup("v,n,c,l->vn,c,l", mat).repeat(1, C_in, 1)
        C_in_axis = 1
        # a,b represent the combined/repeated dimensions
        mat = eingroup("a,b,l->ab,l", mat).unsqueeze(C_in_axis)

        N_axis = 0
        input = eingroup("n,c,l->nc,l", module.input0).unsqueeze(N_axis)
        input = input.repeat(1, V, 1)

        # Grouped convolution that computes the weight gradient; note that
        # the layer's stride and dilation swap roles here.
        grad_weight = conv1d(
            input,
            mat,
            bias=None,
            stride=module.dilation,
            padding=module.padding,
            dilation=module.stride,
            groups=C_in * N * V,
        ).squeeze(0)

        K_L_axis = 1
        _, _, K_L = module.weight.shape
        grad_weight = grad_weight.narrow(K_L_axis, 0, K_L)

        eingroup_eq = "vnio,x->v,{}o,i,x".format("" if sum_batch else "n,")
        return eingroup(
            eingroup_eq, grad_weight, dim={"v": V, "n": N, "i": C_in, "o": C_out}
        )
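A quick way to sanity-check these Jacobian products is against autograd. A minimal sketch with hypothetical test values; BackPACK's extension hooks would normally attach input0 and output to the module, so here they are set by hand:

from torch import allclose, manual_seed, randn
from torch.nn import Conv1d

from backpack.core.derivatives.conv1d import Conv1DDerivatives

manual_seed(0)
module = Conv1d(in_channels=3, out_channels=2, kernel_size=5)
x = randn(4, 3, 20, requires_grad=True)
out = module(x)

# Hypothetical setup: BackPACK's hooks normally store these attributes.
module.input0 = x
module.output = out

derivs = Conv1DDerivatives()

# J^T v for a single vector (V = 1) should match autograd's backward pass.
v = randn(1, *out.shape)
jac_t_v = derivs._jac_t_mat_prod(module, None, None, v)

(out * v.squeeze(0)).sum().backward()
assert allclose(jac_t_v.squeeze(0), x.grad, atol=1e-6)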
147 changes: 147 additions & 0 deletions backpack/core/derivatives/conv3d.py
@@ -0,0 +1,147 @@
from torch import einsum
from torch.nn import Conv3d, ConvTranspose3d
from torch.nn.functional import conv3d

from backpack.core.derivatives.basederivatives import BaseParameterDerivatives
from backpack.utils import conv as convUtils
from backpack.utils.ein import eingroup


class Conv3DDerivatives(BaseParameterDerivatives):
    def get_module(self):
        return Conv3d

    def hessian_is_zero(self):
        return True

    def get_unfolded_input(self, module):
        return convUtils.unfold_by_conv(module.input0, module)

    def ea_jac_t_mat_jac_prod(self, module, g_inp, g_out, mat):
        _, C_in, D_in, H_in, W_in = module.input0.size()
        in_features = C_in * D_in * H_in * W_in
        _, C_out, D_out, H_out, W_out = module.output.size()
        out_features = C_out * D_out * H_out * W_out

        mat = mat.reshape(out_features, C_out, D_out, H_out, W_out)
        jac_t_mat = self.__jac_t(module, mat).reshape(out_features, in_features)

        mat_t_jac = jac_t_mat.t().reshape(in_features, C_out, D_out, H_out, W_out)
        jac_t_mat_t_jac = self.__jac_t(module, mat_t_jac).reshape(
            in_features, in_features
        )

        return jac_t_mat_t_jac.t()

    def _jac_mat_prod(self, module, g_inp, g_out, mat):
        mat_as_conv = eingroup("v,n,c,d,h,w->vn,c,d,h,w", mat)
        jmp_as_conv = conv3d(
            mat_as_conv,
            module.weight.data,
            stride=module.stride,
            padding=module.padding,
            dilation=module.dilation,
            groups=module.groups,
        )
        return self.reshape_like_output(jmp_as_conv, module)

    def _jac_t_mat_prod(self, module, g_inp, g_out, mat):
        mat_as_conv = eingroup("v,n,c,d,h,w->vn,c,d,h,w", mat)
        jmp_as_conv = self.__jac_t(module, mat_as_conv)
        return self.reshape_like_input(jmp_as_conv, module)

    def __jac_t(self, module, mat):
        """Apply Conv3d backward operation."""
        _, C_in, D_in, H_in, W_in = module.input0.size()
        _, C_out, _, _, _ = module.output.size()
        D_axis = 2
        H_axis = 3
        W_axis = 4

        # As in the 1d case, the transposed Jacobian is a transposed
        # convolution that shares the layer's weight.
        conv3d_t = ConvTranspose3d(
            in_channels=C_out,
            out_channels=C_in,
            kernel_size=module.kernel_size,
            stride=module.stride,
            padding=module.padding,
            bias=False,
            dilation=module.dilation,
            groups=module.groups,
        ).to(module.input0.device)

        conv3d_t.weight.data = module.weight

        V_N = mat.size(0)
        output_size = (V_N, C_in, D_in, H_in, W_in)

        jac_t_mat = (
            conv3d_t(mat, output_size=output_size)
            .narrow(D_axis, 0, D_in)
            .narrow(H_axis, 0, H_in)
            .narrow(W_axis, 0, W_in)
        )
        return jac_t_mat

    def _bias_jac_mat_prod(self, module, g_inp, g_out, mat):
        """mat has shape [V, C_out]"""
        # expand over the batch and all three spatial dimensions
        N_axis, D_axis, H_axis, W_axis = 1, 3, 4, 5
        jac_mat = (
            mat.unsqueeze(N_axis).unsqueeze(D_axis).unsqueeze(H_axis).unsqueeze(W_axis)
        )

        N, _, D_out, H_out, W_out = module.output_shape
        return jac_mat.expand(-1, N, -1, D_out, H_out, W_out)

    def _bias_jac_t_mat_prod(self, module, g_inp, g_out, mat, sum_batch=True):
        N_axis, D_axis, H_axis, W_axis = 1, 3, 4, 5
        axes = [D_axis, H_axis, W_axis]
        if sum_batch:
            axes = [N_axis] + axes

        return mat.sum(axes)

    # TODO: Improve performance by using conv instead of unfold

    def _weight_jac_mat_prod(self, module, g_inp, g_out, mat):
        jac_mat = eingroup("v,o,i,d,h,w->v,o,idhw", mat)
        X = self.get_unfolded_input(module)
        jac_mat = einsum("nij,vki->vnkj", (X, jac_mat))
        return self.reshape_like_output(jac_mat, module)

    def _weight_jac_t_mat_prod(self, module, g_inp, g_out, mat, sum_batch=True):
        V = mat.shape[0]
        N, C_out, _, _, _ = module.output_shape
        _, C_in, _, _, _ = module.input0_shape

        mat = eingroup("v,n,c,d,h,w->vn,c,d,h,w", mat).repeat(1, C_in, 1, 1, 1)
        C_in_axis = 1
        # a,b represent the combined/repeated dimensions
        mat = eingroup("a,b,d,h,w->ab,d,h,w", mat).unsqueeze(C_in_axis)

        N_axis = 0
        input = eingroup("n,c,d,h,w->nc,d,h,w", module.input0).unsqueeze(N_axis)
        input = input.repeat(1, V, 1, 1, 1)

        # Same grouped-convolution trick as in Conv1DDerivatives: the layer's
        # stride and dilation swap roles.
        grad_weight = conv3d(
            input,
            mat,
            bias=None,
            stride=module.dilation,
            padding=module.padding,
            dilation=module.stride,
            groups=C_in * N * V,
        ).squeeze(0)

        K_D_axis, K_H_axis, K_W_axis = 1, 2, 3
        _, _, K_D, K_H, K_W = module.weight.shape
        grad_weight = (
            grad_weight.narrow(K_D_axis, 0, K_D)
            .narrow(K_H_axis, 0, K_H)
            .narrow(K_W_axis, 0, K_W)
        )

        eingroup_eq = "vnio,x,y,z->v,{}o,i,x,y,z".format("" if sum_batch else "n,")
        return eingroup(
            eingroup_eq, grad_weight, dim={"v": V, "n": N, "i": C_in, "o": C_out}
        )
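The same cross-check works for the 3d bias derivatives. A minimal sketch under the same hypothetical setup as above:

from torch import allclose, manual_seed, randn
from torch.nn import Conv3d

from backpack.core.derivatives.conv3d import Conv3DDerivatives

manual_seed(0)
module = Conv3d(in_channels=2, out_channels=3, kernel_size=2)
out = module(randn(4, 2, 5, 6, 7))

# Hypothetical setup: BackPACK's hooks normally record the output shape.
module.output_shape = out.shape

derivs = Conv3DDerivatives()

# J v for the bias broadcasts v over the batch and spatial axes (V = 1).
v = randn(1, 3)
assert derivs._bias_jac_mat_prod(module, None, None, v).shape == (1, *out.shape)

# J^T u sums u over the batch and spatial axes, matching autograd.
u = randn(1, *out.shape)
jac_t_u = derivs._bias_jac_t_mat_prod(module, None, None, u)
(out * u.squeeze(0)).sum().backward()
assert allclose(jac_t_u.squeeze(0), module.bias.grad, atol=1e-6)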
41 changes: 41 additions & 0 deletions backpack/utils/conv.py
@@ -1,5 +1,7 @@
import torch
from torch import einsum
from torch.nn import Unfold
from torch.nn.functional import conv1d, conv2d, conv3d

from backpack.utils.ein import eingroup

@@ -43,3 +45,42 @@ def extract_bias_diagonal(module, sqrt):
    V_axis, N_axis = 0, 1
    bias_diagonal = (einsum("vnchw->vnc", sqrt) ** 2).sum([V_axis, N_axis])
    return bias_diagonal


def unfold_by_conv(input, module):
    """Return the unfolded input using convolution."""
    N, C_in = input.shape[0], input.shape[1]
    kernel_size = module.kernel_size
    kernel_size_numel = int(torch.prod(torch.Tensor(kernel_size)))

    def make_weight():
        # One one-hot kernel per kernel element: output channel i of the
        # grouped convolution extracts the i-th entry of each patch.
        weight = torch.zeros(kernel_size_numel, 1, *kernel_size)

        for i in range(kernel_size_numel):
            extraction = torch.zeros(kernel_size_numel)
            extraction[i] = 1.0
            weight[i] = extraction.reshape(1, *kernel_size)

        repeat = [C_in, 1] + [1 for _ in kernel_size]
        return weight.repeat(*repeat)

    def get_conv():
        functional_for_module_cls = {
            torch.nn.Conv1d: conv1d,
            torch.nn.Conv2d: conv2d,
            torch.nn.Conv3d: conv3d,
        }
        return functional_for_module_cls[module.__class__]

    conv = get_conv()
    unfold = conv(
        input,
        make_weight().to(input.device),
        bias=None,
        stride=module.stride,
        padding=module.padding,
        dilation=module.dilation,
        groups=C_in,
    )

    return unfold.reshape(N, C_in * kernel_size_numel, -1)
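For Conv2d, unfold_by_conv should agree with torch.nn.Unfold; the convolution-based formulation is what extends unfolding to 1d and 3d inputs. A minimal equivalence sketch (hypothetical test values):

from torch import allclose, manual_seed, randn
from torch.nn import Conv2d, Unfold

from backpack.utils.conv import unfold_by_conv

manual_seed(0)
module = Conv2d(in_channels=3, out_channels=8, kernel_size=3, stride=2, padding=1)
x = randn(4, 3, 16, 16)

unfold = Unfold(
    kernel_size=module.kernel_size,
    dilation=module.dilation,
    padding=module.padding,
    stride=module.stride,
)
assert allclose(unfold_by_conv(x, module), unfold(x))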
