[BatchDotGrad] Implement other parameterized layers, add tests
f-dangel committed Aug 21, 2020
1 parent d1f76f4 commit 97c3b5b
Showing 9 changed files with 176 additions and 4 deletions.
29 changes: 27 additions & 2 deletions backpack/extensions/firstorder/batch_dot_grad/__init__.py
@@ -1,8 +1,26 @@
from torch.nn import Linear
from torch.nn import (
BatchNorm1d,
Conv1d,
Conv2d,
Conv3d,
ConvTranspose1d,
ConvTranspose2d,
ConvTranspose3d,
Linear,
)

from backpack.extensions.backprop_extension import BackpropExtension

from . import linear
from . import (
batchnorm1d,
conv1d,
conv2d,
conv3d,
conv_transpose1d,
conv_transpose2d,
conv_transpose3d,
linear,
)


class BatchDotGrad(BackpropExtension):
@@ -33,5 +51,12 @@ def __init__(self):
fail_mode="WARNING",
module_exts={
Linear: linear.BatchDotGradLinear(),
Conv1d: conv1d.BatchDotGradConv1d(),
Conv2d: conv2d.BatchDotGradConv2d(),
Conv3d: conv3d.BatchDotGradConv3d(),
ConvTranspose1d: conv_transpose1d.BatchDotGradConvTranspose1d(),
ConvTranspose2d: conv_transpose2d.BatchDotGradConvTranspose2d(),
ConvTranspose3d: conv_transpose3d.BatchDotGradConvTranspose3d(),
BatchNorm1d: batchnorm1d.BatchDotGradBatchNorm1d(),
},
)
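For context, a minimal usage sketch (not part of this commit) of how the extension could be invoked once the conv support above is registered. It assumes `BatchDotGrad` is exposed via `backpack.extensions` and writes its result to a `batch_dot` attribute of each parameter; consult the extension's docstring for the actual savefield and result shape.

```python
import torch
from backpack import backpack, extend
from backpack.extensions import BatchDotGrad

# Toy conv network matching one of the new test settings below.
model = extend(
    torch.nn.Sequential(
        torch.nn.Conv1d(3, 2, 2),
        torch.nn.ReLU(),
        torch.nn.Flatten(),
        torch.nn.Linear(12, 5),
    )
)
lossfunc = extend(torch.nn.CrossEntropyLoss(reduction="sum"))

X, y = torch.rand(3, 3, 7), torch.randint(0, 5, (3,))
loss = lossfunc(model(X), y)

with backpack(BatchDotGrad()):
    loss.backward()

for p in model.parameters():
    # assumed savefield: [N, N] matrix of pairwise dot products of the
    # individual gradients (here N = 3)
    print(p.batch_dot.shape)
```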
9 changes: 9 additions & 0 deletions backpack/extensions/firstorder/batch_dot_grad/batchnorm1d.py
@@ -0,0 +1,9 @@
from backpack.core.derivatives.batchnorm1d import BatchNorm1dDerivatives
from backpack.extensions.firstorder.batch_dot_grad.base import BatchDotGradBase


class BatchDotGradBatchNorm1d(BatchDotGradBase):
def __init__(self):
super().__init__(
derivatives=BatchNorm1dDerivatives(), params=["bias", "weight"]
)
7 changes: 7 additions & 0 deletions backpack/extensions/firstorder/batch_dot_grad/conv1d.py
@@ -0,0 +1,7 @@
from backpack.core.derivatives.conv1d import Conv1DDerivatives
from backpack.extensions.firstorder.batch_dot_grad.base import BatchDotGradBase


class BatchDotGradConv1d(BatchDotGradBase):
def __init__(self):
super().__init__(derivatives=Conv1DDerivatives(), params=["bias", "weight"])
7 changes: 7 additions & 0 deletions backpack/extensions/firstorder/batch_dot_grad/conv2d.py
@@ -0,0 +1,7 @@
from backpack.core.derivatives.conv2d import Conv2DDerivatives
from backpack.extensions.firstorder.batch_dot_grad.base import BatchDotGradBase


class BatchDotGradConv2d(BatchDotGradBase):
def __init__(self):
super().__init__(derivatives=Conv2DDerivatives(), params=["bias", "weight"])
7 changes: 7 additions & 0 deletions backpack/extensions/firstorder/batch_dot_grad/conv3d.py
@@ -0,0 +1,7 @@
from backpack.core.derivatives.conv3d import Conv3DDerivatives
from backpack.extensions.firstorder.batch_dot_grad.base import BatchDotGradBase


class BatchDotGradConv3d(BatchDotGradBase):
def __init__(self):
super().__init__(derivatives=Conv3DDerivatives(), params=["bias", "weight"])
9 changes: 9 additions & 0 deletions backpack/extensions/firstorder/batch_dot_grad/conv_transpose1d.py
@@ -0,0 +1,9 @@
from backpack.core.derivatives.conv_transpose1d import ConvTranspose1DDerivatives
from backpack.extensions.firstorder.batch_dot_grad.base import BatchDotGradBase


class BatchDotGradConvTranspose1d(BatchDotGradBase):
def __init__(self):
super().__init__(
derivatives=ConvTranspose1DDerivatives(), params=["bias", "weight"]
)
9 changes: 9 additions & 0 deletions backpack/extensions/firstorder/batch_dot_grad/conv_transpose2d.py
@@ -0,0 +1,9 @@
from backpack.core.derivatives.conv_transpose2d import ConvTranspose2DDerivatives
from backpack.extensions.firstorder.batch_dot_grad.base import BatchDotGradBase


class BatchDotGradConvTranspose2d(BatchDotGradBase):
def __init__(self):
super().__init__(
derivatives=ConvTranspose2DDerivatives(), params=["bias", "weight"]
)
9 changes: 9 additions & 0 deletions backpack/extensions/firstorder/batch_dot_grad/conv_transpose3d.py
@@ -0,0 +1,9 @@
from backpack.core.derivatives.conv_transpose3d import ConvTranspose3DDerivatives
from backpack.extensions.firstorder.batch_dot_grad.base import BatchDotGradBase


class BatchDotGradConvTranspose3d(BatchDotGradBase):
def __init__(self):
super().__init__(
derivatives=ConvTranspose3DDerivatives(), params=["bias", "weight"]
)
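Each of the layer modules above follows the same pattern: it pairs the layer's derivatives object with `BatchDotGradBase` for the `weight` and `bias` parameters. Conceptually, the quantity collected per parameter is the matrix of pairwise dot products between individual gradients; a hedged sketch of that core computation (not the actual `BatchDotGradBase` code, which may compute the dots without materializing the individual gradients):

```python
import torch


def pairwise_grad_dots(grad_batch: torch.Tensor) -> torch.Tensor:
    """Pairwise dot products of individual gradients.

    ``grad_batch`` has shape [N, *param.shape]; the result is an [N, N]
    matrix with entry (n, m) = <grad_n, grad_m>.
    """
    flat = grad_batch.flatten(start_dim=1)  # [N, D]
    return torch.einsum("nd,md->nm", flat, flat)
```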
94 changes: 92 additions & 2 deletions in the BatchDotGrad test settings file
@@ -50,13 +50,103 @@
"loss_function_fn": lambda: torch.nn.CrossEntropyLoss(reduction="sum"),
"target_fn": lambda: classification_targets((4,), 3),
},
# classification
{
"input_fn": lambda: torch.rand(3, 10),
"module_fn": lambda: torch.nn.Sequential(
torch.nn.Linear(10, 7), torch.nn.Linear(7, 5)
),
"loss_function_fn": lambda: torch.nn.CrossEntropyLoss(reduction="mean"),
"target_fn": lambda: classification_targets((3,), 5),
},
{
"input_fn": lambda: torch.rand(3, 10),
"module_fn": lambda: torch.nn.Sequential(
torch.nn.Linear(10, 7), torch.nn.ReLU(), torch.nn.Linear(7, 5)
),
"loss_function_fn": lambda: torch.nn.CrossEntropyLoss(reduction="sum"),
"target_fn": lambda: classification_targets((3,), 5),
},
# Regression
{
"input_fn": lambda: torch.rand(3, 10),
"module_fn": lambda: torch.nn.Sequential(
torch.nn.Linear(10, 7), torch.nn.Sigmoid(), torch.nn.Linear(7, 5)
),
"loss_function_fn": lambda: torch.nn.MSELoss(reduction="mean"),
"target_fn": lambda: regression_targets((3, 5)),
},
]

###############################################################################
# test setting: Convolutional Layers #
###############################################################################

BATCHDOTGRAD_SETTINGS += [
# TODO: Implement `BatchDotGrad` for conv layers
# TODO: Add more settings with convolutional layers
{
"input_fn": lambda: torch.rand(3, 3, 7),
"module_fn": lambda: torch.nn.Sequential(
torch.nn.Conv1d(3, 2, 2),
torch.nn.ReLU(),
torch.nn.Flatten(),
torch.nn.Linear(12, 5),
),
"loss_function_fn": lambda: torch.nn.CrossEntropyLoss(reduction="sum"),
"target_fn": lambda: classification_targets((3,), 5),
},
{
"input_fn": lambda: torch.rand(3, 3, 7, 7),
"module_fn": lambda: torch.nn.Sequential(
torch.nn.Conv2d(3, 2, 2),
torch.nn.ReLU(),
torch.nn.Flatten(),
torch.nn.Linear(72, 5),
),
"loss_function_fn": lambda: torch.nn.CrossEntropyLoss(reduction="mean"),
"target_fn": lambda: classification_targets((3,), 5),
},
{
"input_fn": lambda: torch.rand(3, 3, 2, 7, 7),
"module_fn": lambda: torch.nn.Sequential(
torch.nn.Conv3d(3, 2, 2),
torch.nn.ReLU(),
torch.nn.Flatten(),
torch.nn.Linear(72, 5),
),
"loss_function_fn": lambda: torch.nn.CrossEntropyLoss(reduction="sum"),
"target_fn": lambda: classification_targets((3,), 5),
},
{
"input_fn": lambda: torch.rand(3, 3, 7),
"module_fn": lambda: torch.nn.Sequential(
torch.nn.ConvTranspose1d(3, 2, 2),
torch.nn.ReLU(),
torch.nn.Flatten(),
torch.nn.Linear(16, 5),
),
"loss_function_fn": lambda: torch.nn.CrossEntropyLoss(reduction="mean"),
"target_fn": lambda: classification_targets((3,), 5),
},
{
"input_fn": lambda: torch.rand(3, 3, 7, 7),
"module_fn": lambda: torch.nn.Sequential(
torch.nn.ConvTranspose2d(3, 2, 2),
torch.nn.ReLU(),
torch.nn.Flatten(),
torch.nn.Linear(128, 5),
),
"loss_function_fn": lambda: torch.nn.CrossEntropyLoss(reduction="sum"),
"target_fn": lambda: classification_targets((3,), 5),
},
{
"input_fn": lambda: torch.rand(3, 3, 2, 7, 7),
"module_fn": lambda: torch.nn.Sequential(
torch.nn.ConvTranspose3d(3, 2, 2),
torch.nn.ReLU(),
torch.nn.Flatten(),
torch.nn.Linear(384, 5),
),
"loss_function_fn": lambda: torch.nn.CrossEntropyLoss(reduction="mean"),
"target_fn": lambda: classification_targets((3,), 5),
},
]
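The settings above could be checked against a plain-autograd reference such as the hedged sketch below (helper names are illustrative, not BackPACK's actual test harness): individual gradients are computed sample by sample, then their pairwise dot products are formed per parameter. It assumes the loss uses `reduction="sum"` so that per-sample losses add up to the batch loss.

```python
import torch


def batch_dot_autograd(model, lossfunc, X, y):
    """Return one [N, N] matrix of pairwise individual-gradient dots per parameter."""
    N = X.shape[0]
    params = list(model.parameters())
    per_sample = [[] for _ in params]  # flattened per-sample gradients

    for n in range(N):
        loss_n = lossfunc(model(X[n].unsqueeze(0)), y[n].unsqueeze(0))
        grads_n = torch.autograd.grad(loss_n, params)
        for i, g in enumerate(grads_n):
            per_sample[i].append(g.flatten())

    stacked = [torch.stack(g) for g in per_sample]  # each [N, D_i]
    return [torch.einsum("nd,md->nm", G, G) for G in stacked]
```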
