ADD: ConvTranspose1/3d derivatives (#80)
Refactor the `unfold` tests for 1d/2d/3d transpose convolution and implement the derivatives, with tests.
sbharadwajj committed Jul 2, 2020
1 parent a666eb0 commit f08a697
Showing 7 changed files with 690 additions and 125 deletions.
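The new derivative classes are what BackPACK dispatches to when a network contains transposed convolutions. As an illustration of what this enables — a minimal sketch, not part of the commit, and assuming the first-order extensions (e.g. `BatchGrad`) are wired to these derivatives in the corresponding extension modules — individual gradients for a `ConvTranspose1d` model could be requested like this:

import torch
from torch.nn import ConvTranspose1d, Flatten, Linear, MSELoss, Sequential

from backpack import backpack, extend
from backpack.extensions import BatchGrad

torch.manual_seed(0)
N, C_in, L_in = 4, 3, 10

# ConvTranspose1d(3, 2, kernel_size=3) maps length 10 to length 12, so Flatten yields 24 features.
model = extend(
    Sequential(ConvTranspose1d(C_in, 2, kernel_size=3), Flatten(), Linear(2 * 12, 1))
)
lossfunc = extend(MSELoss())

X, y = torch.randn(N, C_in, L_in), torch.randn(N, 1)
loss = lossfunc(model(X), y)

with backpack(BatchGrad()):
    loss.backward()

for name, p in model.named_parameters():
    print(name, p.grad_batch.shape)  # one gradient per sample: shape (N, *p.shape)

Each `grad_batch` entry stacks one gradient per sample — exactly the quantity the `weight_jac_t_mat_prod` routines below produce when called with `sum_batch=False`.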
8 changes: 7 additions & 1 deletion backpack/core/derivatives/__init__.py
@@ -5,7 +5,9 @@
    Conv1d,
    Conv2d,
    Conv3d,
    ConvTranspose1d,
    ConvTranspose2d,
    ConvTranspose3d,
    CrossEntropyLoss,
    Dropout,
    LeakyReLU,
@@ -21,9 +23,11 @@

from .avgpool2d import AvgPool2DDerivatives
from .conv1d import Conv1DDerivatives
from .conv_transpose1d import ConvTranspose1DDerivatives
from .conv2d import Conv2DDerivatives
from .conv3d import Conv3DDerivatives
from .conv_transpose2d import ConvTranspose2DDerivatives
from .conv3d import Conv3DDerivatives
from .conv_transpose3d import ConvTranspose3DDerivatives
from .crossentropyloss import CrossEntropyLossDerivatives
from .dropout import DropoutDerivatives
from .elu import ELUDerivatives
@@ -50,7 +54,9 @@
    ReLU: ReLUDerivatives,
    Tanh: TanhDerivatives,
    Sigmoid: SigmoidDerivatives,
    ConvTranspose1d: ConvTranspose1DDerivatives,
    ConvTranspose2d: ConvTranspose2DDerivatives,
    ConvTranspose3d: ConvTranspose3DDerivatives,
    LeakyReLU: LeakyReLUDerivatives,
    LogSigmoid: LogSigmoidDerivatives,
    ELU: ELUDerivatives,
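The mapping above is how BackPACK looks up the derivative implementation for a given layer class. A sketch of that lookup — assuming the dict shown in this hunk is exposed at module level under the name `derivatives_for`, which is illustrative here and not confirmed by the diff:

from torch.nn import ConvTranspose3d

from backpack.core.derivatives import derivatives_for  # hypothetical name for the mapping above

layer = ConvTranspose3d(in_channels=2, out_channels=4, kernel_size=2)
derivatives = derivatives_for[type(layer)]()

assert derivatives.get_module() is ConvTranspose3d
assert derivatives.hessian_is_zero()  # transposed convolution is linear in its input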
137 changes: 137 additions & 0 deletions backpack/core/derivatives/conv_transpose1d.py
@@ -0,0 +1,137 @@
"""Partial derivatives for `torch.nn.ConvTranspose1d`."""

import torch
from torch.nn import ConvTranspose1d
from torch.nn.functional import conv1d

from backpack.core.derivatives.basederivatives import BaseParameterDerivatives
from backpack.utils.conv_transpose import unfold_by_conv_transpose
from backpack.utils.ein import eingroup


class ConvTranspose1DDerivatives(BaseParameterDerivatives):
def get_module(self):
return ConvTranspose1d

def hessian_is_zero(self):
return True

def _bias_jac_t_mat_prod(self, module, g_inp, g_out, mat, sum_batch=True):
N_axis, L_axis = 1, 3
axes = [L_axis]
if sum_batch:
axes = [N_axis] + axes

return mat.sum(axes)

def _bias_jac_mat_prod(self, module, g_inp, g_out, mat):
# expand for each batch and for each channel
N_axis, L_axis = 1, 3
jac_mat = mat.unsqueeze(N_axis).unsqueeze(L_axis)

N, _, L_out = module.output_shape
return jac_mat.expand(-1, N, -1, L_out,)

def _weight_jac_mat_prod(self, module, g_inp, g_out, mat):
V = mat.shape[0]
G = module.groups
C_in = module.input0.shape[1]
_, _, K_X = module.weight.shape
N, C_out, L_out = module.output.shape

mat_reshape = mat.reshape(V, C_in, G, C_out // G, K_X)
u = unfold_by_conv_transpose(module.input0, module).reshape(
N, C_in // G, G, K_X, L_out
)

jac_mat = torch.einsum("nigxl,vigox->vngol", u, mat_reshape)

return self.reshape_like_output(jac_mat, module)

def ea_jac_t_mat_jac_prod(self, module, g_inp, g_out, mat):
_, C_in, L_in = module.input0.size()
in_features = C_in * L_in
_, C_out, H_out = module.output.size()
out_features = C_out * H_out

mat = mat.reshape(out_features, C_out, H_out)
jac_t_mat = self.__jac_t(module, mat).reshape(out_features, in_features)

mat_t_jac = jac_t_mat.t().reshape(in_features, C_out, H_out)
jac_t_mat_t_jac = self.__jac_t(module, mat_t_jac).reshape(
in_features, in_features
)

return jac_t_mat_t_jac.t()

def _weight_jac_t_mat_prod(self, module, g_inp, g_out, mat, sum_batch=True):
V = mat.shape[0]
G = module.groups
N, C_out, L_out = module.output.shape

mat_reshape = mat.reshape(V, N, G, C_out // G, L_out)

C_in = module.input0.shape[1]
_, _, K_X = module.weight.shape

u = unfold_by_conv_transpose(module.input0, module).reshape(
N, C_in // G, G, K_X, L_out
)

result_str = "vigox" if sum_batch else "vnigox"
equation = "nigxl,vngol->{}".format(result_str)

final_shape = (
(V, *module.weight.shape) if sum_batch else (V, N, *module.weight.shape)
)

return torch.einsum(equation, u, mat_reshape).reshape(final_shape)

def _jac_mat_prod(self, module, g_inp, g_out, mat):
mat_as_conv = eingroup("v,n,c,l->vn,c,l", mat)
jmp_as_conv = self.__jac(module, mat_as_conv)
return self.reshape_like_output(jmp_as_conv, module)

def __jac(self, module, mat):
C_in = module.input0.shape[1]
_, C_out, L_out = module.output.shape
L_axis = 2

conv1d_t = ConvTranspose1d(
in_channels=C_in,
out_channels=C_out,
kernel_size=module.kernel_size,
stride=module.stride,
padding=module.padding,
bias=False,
dilation=module.dilation,
groups=module.groups,
).to(module.input0.device)

conv1d_t.weight.data = module.weight

V_N = mat.size(0)
output_size = (V_N, C_out, L_out)

jac_mat = conv1d_t(mat, output_size=output_size).narrow(L_axis, 0, L_out)
return jac_mat

def _jac_t_mat_prod(self, module, g_inp, g_out, mat):
mat_as_conv = eingroup("v,n,c,l->vn,c,l", mat)
jmp_as_conv = self.__jac_t(module, mat_as_conv)
return self.reshape_like_input(jmp_as_conv, module)

def __jac_t(self, module, mat):
"""Apply ConvTranspose1d backward operation."""
L_axis = 2
L_in = module.input0.size(L_axis)

return conv1d(
mat,
module.weight,
bias=None,
stride=module.stride,
padding=module.padding,
dilation=module.dilation,
groups=module.groups,
).narrow(L_axis, 0, L_in)
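`__jac_t` relies on the fact that the Jacobian-transpose of a transposed convolution is an ordinary convolution with the same kernel, cropped back to the input length. A self-contained sanity check of that relation against autograd — an illustrative sketch using default hyperparameters (stride 1, no padding, no dilation), where the final `narrow` would be a no-op:

import torch
from torch.nn import ConvTranspose1d
from torch.nn.functional import conv1d

torch.manual_seed(0)
N, C_in, C_out, L_in, K = 2, 3, 5, 8, 3

module = ConvTranspose1d(C_in, C_out, kernel_size=K, bias=False)
x = torch.randn(N, C_in, L_in, requires_grad=True)
y = module(x)  # shape (N, C_out, L_in + K - 1)

v = torch.randn_like(y)  # vector for the vector-Jacobian product

# Jacobian-transpose product via autograd ...
(vjp_autograd,) = torch.autograd.grad(y, x, grad_outputs=v)

# ... and via the convolution used in __jac_t above.
vjp_manual = conv1d(v, module.weight)

assert torch.allclose(vjp_autograd, vjp_manual, atol=1e-6)

The `narrow(L_axis, 0, L_in)` in `__jac_t` guards against hyperparameter settings where the convolution yields more than `L_in` positions.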
155 changes: 155 additions & 0 deletions backpack/core/derivatives/conv_transpose3d.py
@@ -0,0 +1,155 @@
"""Partial derivatives for `torch.nn.ConvTranspose3d`."""

import torch
from torch.nn import ConvTranspose3d
from torch.nn.functional import conv3d

from backpack.core.derivatives.basederivatives import BaseParameterDerivatives
from backpack.utils.conv_transpose import unfold_by_conv_transpose
from backpack.utils.ein import eingroup


class ConvTranspose3DDerivatives(BaseParameterDerivatives):
def get_module(self):
return ConvTranspose3d

def hessian_is_zero(self):
return True

def _bias_jac_t_mat_prod(self, module, g_inp, g_out, mat, sum_batch=True):
N_axis, D_axis, H_axis, W_axis = 1, 3, 4, 5
axes = [D_axis, H_axis, W_axis]
if sum_batch:
axes = [N_axis] + axes

return mat.sum(axes)

def _bias_jac_mat_prod(self, module, g_inp, g_out, mat):
# expand for each batch and for each channel
N_axis, D_axis, H_axis, W_axis = 1, 3, 4, 5
jac_mat = (
mat.unsqueeze(N_axis).unsqueeze(D_axis).unsqueeze(H_axis).unsqueeze(W_axis)
)

N, _, D_out, H_out, W_out = module.output_shape
return jac_mat.expand(-1, N, -1, D_out, H_out, W_out)

def _weight_jac_mat_prod(self, module, g_inp, g_out, mat):
V = mat.shape[0]
G = module.groups
C_in = module.input0.shape[1]
_, _, K_X, K_Y, K_Z = module.weight.shape
N, C_out, D_out, H_out, W_out = module.output.shape

mat_reshape = mat.reshape(V, C_in, G, C_out // G, K_X, K_Y, K_Z)
u = unfold_by_conv_transpose(module.input0, module).reshape(
N, C_in // G, G, K_X, K_Y, K_Z, D_out, H_out, W_out
)

jac_mat = torch.einsum("nigxyzdhw,vigoxyz->vngodhw", u, mat_reshape)

return self.reshape_like_output(jac_mat, module)

def ea_jac_t_mat_jac_prod(self, module, g_inp, g_out, mat):
_, C_in, D_in, H_in, W_in = module.input0.size()
in_features = C_in * D_in * H_in * W_in
_, C_out, D_out, H_out, W_out = module.output.size()
out_features = C_out * D_out * H_out * W_out

mat = mat.reshape(out_features, C_out, D_out, H_out, W_out)
jac_t_mat = self.__jac_t(module, mat).reshape(out_features, in_features)

mat_t_jac = jac_t_mat.t().reshape(in_features, C_out, D_out, H_out, W_out)
jac_t_mat_t_jac = self.__jac_t(module, mat_t_jac).reshape(
in_features, in_features
)

return jac_t_mat_t_jac.t()

def _weight_jac_t_mat_prod(self, module, g_inp, g_out, mat, sum_batch=True):
V = mat.shape[0]
G = module.groups
N, C_out, D_out, H_out, W_out = module.output.shape

mat_reshape = mat.reshape(V, N, G, C_out // G, D_out, H_out, W_out)

C_in = module.input0.shape[1]
_, _, K_X, K_Y, K_Z = module.weight.shape

u = unfold_by_conv_transpose(module.input0, module).reshape(
N, C_in // G, G, K_X, K_Y, K_Z, D_out, H_out, W_out
)

result_str = "vigoxyz" if sum_batch else "vnigoxyz"
equation = "nigxyzdhw,vngodhw->{}".format(result_str)

final_shape = (
(V, *module.weight.shape) if sum_batch else (V, N, *module.weight.shape)
)

return torch.einsum(equation, u, mat_reshape).reshape(final_shape)

def _jac_mat_prod(self, module, g_inp, g_out, mat):
mat_as_conv = eingroup("v,n,c,d,h,w->vn,c,d,h,w", mat)
jmp_as_conv = self.__jac(module, mat_as_conv)
return self.reshape_like_output(jmp_as_conv, module)

def __jac(self, module, mat):
C_in = module.input0.shape[1]
_, C_out, D_out, H_out, W_out = module.output.shape
D_axis = 2
H_axis = 3
W_axis = 4

conv3d_t = ConvTranspose3d(
in_channels=C_in,
out_channels=C_out,
kernel_size=module.kernel_size,
stride=module.stride,
padding=module.padding,
bias=False,
dilation=module.dilation,
groups=module.groups,
).to(module.input0.device)

conv3d_t.weight.data = module.weight

V_N = mat.size(0)
output_size = (V_N, C_out, D_out, H_out, W_out)

jac_mat = (
conv3d_t(mat, output_size=output_size)
.narrow(D_axis, 0, D_out)
.narrow(H_axis, 0, H_out)
.narrow(W_axis, 0, W_out)
)
return jac_mat

def _jac_t_mat_prod(self, module, g_inp, g_out, mat):
mat_as_conv = eingroup("v,n,c,d,h,w->vn,c,d,h,w", mat)
jmp_as_conv = self.__jac_t(module, mat_as_conv)
return self.reshape_like_input(jmp_as_conv, module)

def __jac_t(self, module, mat):
"""Apply ConvTranspose3d backward operation."""
D_axis = 2
H_axis = 3
W_axis = 4
D_in = module.input0.size(D_axis)
H_in = module.input0.size(H_axis)
W_in = module.input0.size(W_axis)

return (
conv3d(
mat,
module.weight,
bias=None,
stride=module.stride,
padding=module.padding,
dilation=module.dilation,
groups=module.groups,
)
.narrow(D_axis, 0, D_in)
.narrow(H_axis, 0, H_in)
.narrow(W_axis, 0, W_in)
)
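`_bias_jac_t_mat_prod` reduces to a sum over the batch and the three spatial axes, because the bias is added uniformly at every output location. A standalone check of that reduction against autograd — an illustrative sketch for the `sum_batch=True` case, independent of BackPACK's internals:

import torch
from torch.nn import ConvTranspose3d

torch.manual_seed(0)
module = ConvTranspose3d(in_channels=2, out_channels=3, kernel_size=2)
x = torch.randn(4, 2, 5, 5, 5)

y = module(x)
v = torch.randn_like(y)

# Gradient of <v, y> with respect to the bias via autograd ...
(bias_grad,) = torch.autograd.grad((v * y).sum(), module.bias)

# ... equals v summed over batch, depth, height and width.
assert torch.allclose(bias_grad, v.sum(dim=(0, 2, 3, 4)), atol=1e-6)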
2 changes: 1 addition & 1 deletion backpack/utils/conv_transpose.py
@@ -31,7 +31,7 @@ def get_conv_transpose():
    conv_transpose = get_conv_transpose()
    unfold = conv_transpose(
        input,
        make_weight().to(module.weight.device),
        bias=None,
        stride=module.stride,
        padding=module.padding,
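The one-line fix above moves the synthetic unfolding weight onto the same device as the module's parameters, so GPU inputs no longer hit a CPU/GPU mismatch inside the unfolding convolution. The idiom in isolation — an illustrative sketch, not code from the diff:

import torch

# Stand-in for module.weight, which may live on the GPU.
weight = torch.empty(3, 3, device="cuda" if torch.cuda.is_available() else "cpu")

# Tensors built on the fly default to the CPU; follow the reference
# parameter's device before combining the two.
aux = torch.eye(3).to(weight.device)
assert aux.device == weight.device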