Skip to content

Commit

Permalink
Merge ff8313d into e0bc8d6
Browse files Browse the repository at this point in the history
  • Loading branch information
sbharadwajj committed Jun 10, 2020
2 parents e0bc8d6 + ff8313d commit 239883a
Show file tree
Hide file tree
Showing 4 changed files with 275 additions and 132 deletions.
6 changes: 6 additions & 0 deletions backpack/core/derivatives/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from torch.nn import (
AvgPool2d,
Conv1d,
Conv2d,
Conv3d,
CrossEntropyLoss,
Dropout,
ELU,
Expand All @@ -17,7 +19,9 @@
)

from .avgpool2d import AvgPool2DDerivatives
from .conv1d import Conv1DDerivatives
from .conv2d import Conv2DDerivatives
from .conv3d import Conv3DDerivatives
from .crossentropyloss import CrossEntropyLossDerivatives
from .elu import ELUDerivatives
from .dropout import DropoutDerivatives
Expand All @@ -34,7 +38,9 @@

derivatives_for = {
Linear: LinearDerivatives,
Conv1d: Conv1DDerivatives,
Conv2d: Conv2DDerivatives,
Conv3d: Conv3DDerivatives,
AvgPool2d: AvgPool2DDerivatives,
MaxPool2d: MaxPool2DDerivatives,
ZeroPad2d: ZeroPad2dDerivatives,
Expand Down
2 changes: 1 addition & 1 deletion backpack/utils/conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,4 +83,4 @@ def get_conv():
groups=C_in,
)

return unfold.reshape(N, -1, kernel_size_numel)
return unfold.reshape(N, C_in * kernel_size_numel, -1)
167 changes: 36 additions & 131 deletions test/utils/test_conv.py
Original file line number Diff line number Diff line change
@@ -1,54 +1,23 @@
"""Test generalization of unfold to 3d convolutions."""

# TODO: @sbharadwajj: impose test suite structure
# TODO: @sbharadwajj: Test for groups ≠ 1

import torch
import pytest

from backpack.utils.conv import unfold_by_conv, unfold_func

from test.utils.test_conv_settings import SETTINGS
from ..automated_test import check_sizes_and_values
from test.core.derivatives.problem import make_test_problems

###############################################################################
# Get the unfolded input with a convolution instead of torch.unfold #
###############################################################################

torch.manual_seed(0)

# check
UNFOLD_SETTINGS = [
[torch.nn.Conv2d(1, 1, kernel_size=2, bias=False), (1, 1, 3, 3)],
[torch.nn.Conv2d(1, 2, kernel_size=2, bias=False), (1, 1, 3, 3)],
[torch.nn.Conv2d(2, 1, kernel_size=2, bias=False), (1, 2, 3, 3)],
[torch.nn.Conv2d(2, 2, kernel_size=2, bias=False), (1, 2, 3, 3)],
[torch.nn.Conv2d(2, 3, kernel_size=2, bias=False), (3, 2, 11, 13)],
[torch.nn.Conv2d(2, 3, kernel_size=2, padding=1, bias=False), (3, 2, 11, 13)],
[
torch.nn.Conv2d(2, 3, kernel_size=2, padding=1, stride=2, bias=False),
(3, 2, 11, 13),
],
[
torch.nn.Conv2d(
2, 3, kernel_size=2, padding=1, stride=2, dilation=2, bias=False
),
(3, 2, 11, 13),
],
]


def test_unfold_by_conv():
for module, in_shape in UNFOLD_SETTINGS:
input = torch.rand(in_shape)

result_unfold = unfold_func(module)(input).flatten()
result_unfold_by_conv = unfold_by_conv(input, module).flatten()

check_sizes_and_values(result_unfold, result_unfold_by_conv)
PROBLEMS = make_test_problems(SETTINGS)
IDS = [problem.make_id() for problem in PROBLEMS]


###############################################################################
# Perform a convolution with the unfolded input matrix #
###############################################################################
CONV2D_PROBLEMS = [
problem
for problem in PROBLEMS
if isinstance(problem.make_module(), torch.nn.Conv2d)
]
CONV2D_IDS = [problem.make_id() for problem in CONV2D_PROBLEMS]


def convolution_with_unfold(input, module):
Expand Down Expand Up @@ -82,100 +51,36 @@ def get_output_shape(input, module):
return result.reshape(N, C_out, *spatial_out_size)


CONV_SETTINGS = UNFOLD_SETTINGS + [
[
torch.nn.Conv2d(
2, 6, kernel_size=2, padding=1, stride=2, dilation=2, bias=False, groups=2
),
(3, 2, 11, 13),
],
[
torch.nn.Conv2d(
3, 6, kernel_size=2, padding=1, stride=2, dilation=2, bias=False, groups=3
),
(5, 3, 11, 13),
],
[
torch.nn.Conv2d(
16,
33,
kernel_size=(3, 5),
stride=(2, 1),
padding=(4, 2),
dilation=(3, 1),
bias=False,
),
(20, 16, 50, 100),
],
]
@pytest.mark.parametrize("problem", CONV2D_PROBLEMS, ids=CONV2D_IDS)
def test_unfold_by_conv(problem):
"""Test the Unfold by convolution for torch.nn.Conv2d.
Args:
problem (ConvProblem): Problem for testing unfold operation.
"""
problem.set_up()
input = torch.rand(problem.input_shape)

def test_convolution2d_with_unfold():
for module, in_shape in CONV_SETTINGS:
input = torch.rand(in_shape)

result_conv = module(input)
result_conv_by_unfold = convolution_with_unfold(input, module)

check_sizes_and_values(result_conv, result_conv_by_unfold)


CONV_1D_SETTINGS = [
[torch.nn.Conv1d(1, 1, kernel_size=2, bias=False), (1, 1, 3)],
[torch.nn.Conv1d(1, 2, kernel_size=2, bias=False), (1, 1, 3)],
[torch.nn.Conv1d(2, 1, kernel_size=2, bias=False), (1, 2, 3)],
[torch.nn.Conv1d(2, 2, kernel_size=2, bias=False), (1, 2, 3)],
[torch.nn.Conv1d(2, 3, kernel_size=2, bias=False), (3, 2, 11)],
[torch.nn.Conv1d(2, 3, kernel_size=2, padding=1, bias=False), (3, 2, 11)],
[
torch.nn.Conv1d(2, 3, kernel_size=2, padding=1, stride=2, bias=False),
(3, 2, 11),
],
[
torch.nn.Conv1d(
2, 3, kernel_size=2, padding=1, stride=2, dilation=2, bias=False
),
(3, 2, 11),
],
]
result_unfold = unfold_func(problem.module)(input)
result_unfold_by_conv = unfold_by_conv(input, problem.module)

check_sizes_and_values(result_unfold, result_unfold_by_conv)
problem.tear_down()

def test_convolution1d_with_unfold():
for module, in_shape in CONV_1D_SETTINGS:
input = torch.rand(in_shape)

result_conv = module(input)
result_conv_by_unfold = convolution_with_unfold(input, module)

check_sizes_and_values(result_conv, result_conv_by_unfold)


CONV_3D_SETTINGS = [
[torch.nn.Conv3d(1, 1, kernel_size=2, bias=False), (1, 1, 3, 3, 3)],
[torch.nn.Conv3d(1, 2, kernel_size=2, bias=False), (1, 1, 3, 3, 3)],
[torch.nn.Conv3d(2, 1, kernel_size=2, bias=False), (1, 2, 3, 3, 3)],
[torch.nn.Conv3d(2, 2, kernel_size=2, bias=False), (1, 2, 3, 3, 3)],
[torch.nn.Conv3d(2, 3, kernel_size=2, bias=False), (3, 2, 11, 13, 17)],
[torch.nn.Conv3d(2, 3, kernel_size=2, padding=1, bias=False), (3, 2, 11, 13, 17)],
[
torch.nn.Conv3d(2, 3, kernel_size=2, padding=1, stride=2, bias=False),
(3, 2, 11, 13, 17),
],
[
torch.nn.Conv3d(
2, 3, kernel_size=2, padding=1, stride=2, dilation=2, bias=False
),
(3, 2, 11, 13, 17),
],
]

@pytest.mark.parametrize("problem", PROBLEMS, ids=IDS)
def test_convolution_with_unfold(problem):
"""Test the Unfold operation of torch.nn.Conv1d and torch.nn.Conv3d
by using convolution.
def test_convolution3d_with_unfold():
print("Conv via unfold check")
for module, in_shape in CONV_3D_SETTINGS:
input = torch.rand(in_shape)
Args:
problem (ConvProblem): Problem for testing unfold operation.
"""
problem.set_up()
input = torch.rand(problem.input_shape)

result_conv = module(input)
result_conv_by_unfold = convolution_with_unfold(input, module)
result_conv = problem.module(input)
result_conv_by_unfold = convolution_with_unfold(input, problem.module)

check_sizes_and_values(result_conv, result_conv_by_unfold)
check_sizes_and_values(result_conv, result_conv_by_unfold)
problem.tear_down()

0 comments on commit 239883a

Please sign in to comment.