2 changes: 1 addition & 1 deletion benchmarks/layers.py
@@ -179,7 +179,7 @@ def __init__(
out_channels: int,
kernel_size: Union[int, Tuple[int, ...]],
stride: Union[int, Tuple[int, ...]] = 1,
padding: Union[int, Tuple[int, ...]] = 0,
padding: Union[str, int, Tuple[int, ...]] = 0,
dilation: Union[int, Tuple[int, ...]] = 1,
groups: int = 1,
bias: bool = True,
19 changes: 10 additions & 9 deletions opacus/grad_sample/README.md
@@ -3,26 +3,26 @@
Computing per-sample gradients is an integral part of the Opacus framework. We strive to provide out-of-the-box support for
a wide range of models, while keeping computations efficient.

We currently provide two independent approaches for computing per-sample gradients: the hooks-based ``GradSampleModule``
(a stable implementation that has existed since the very first version of Opacus) and ``GradSampleModuleExpandedWeights``
(based on beta functionality available in PyTorch 1.12).

Each of the two implementations comes with its own set of limitations, and we leave the choice of which one to use
up to the client.

``GradSampleModuleExpandedWeights`` is currently in early beta and can produce unexpected errors, but it potentially
improves upon ``GradSampleModule`` in performance and functionality.

**TL;DR:** If you want a stable implementation, use ``GradSampleModule`` (`grad_sample_mode="hooks"`).
If you want to experiment with the new functionality, try ``GradSampleModuleExpandedWeights`` (`grad_sample_mode="ew"`)
and switch back to ``GradSampleModule`` if you encounter strange errors or unexpected behaviour.
We'd also appreciate it if you report such issues to us.
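
For concreteness, here is a minimal sketch of how the mode is selected. The toy model, optimizer, data loader and noise/clipping values below are placeholders of our own choosing; the only argument this README is about is `grad_sample_mode`.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from opacus import PrivacyEngine

# Placeholder training objects, just to make the call concrete
model = nn.Linear(16, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
data_loader = DataLoader(
    TensorDataset(torch.randn(64, 16), torch.randint(0, 2, (64,))), batch_size=8
)

privacy_engine = PrivacyEngine()
model, optimizer, data_loader = privacy_engine.make_private(
    module=model,
    optimizer=optimizer,
    data_loader=data_loader,
    noise_multiplier=1.0,
    max_grad_norm=1.0,
    grad_sample_mode="hooks",  # or "ew" to try GradSampleModuleExpandedWeights
)
```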

## Hooks-based approach
- Model wrapping class: ``opacus.grad_sample.grad_sample_module.GradSampleModule``
- Keyword argument for ``PrivacyEngine.make_private()``: `grad_sample_mode="hooks"`

Computes per-sample gradients for a model using backward hooks. It requires custom grad sampler methods for every
trainable layer in the model. We provide such methods for the most popular PyTorch layers. Additionally, clients can
provide their own grad sampler for any new, unsupported layer (see the [tutorial](https://github.com/pytorch/opacus/blob/main/tutorials/guide_to_grad_sampler.ipynb))
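
As a rough illustration of what a client-provided grad sampler looks like, here is a hypothetical sketch: the `ScaleLayer` module and its sampler are invented for this example, and we assume the `register_grad_sampler` decorator exported from `opacus.grad_sample`; see the linked tutorial for the authoritative walkthrough.

```python
import torch
import torch.nn as nn
from opacus.grad_sample import register_grad_sampler


class ScaleLayer(nn.Module):
    """Toy layer: multiplies its input by a single learnable scalar."""

    def __init__(self):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(1))

    def forward(self, x):
        return x * self.scale


@register_grad_sampler(ScaleLayer)
def compute_scale_grad_sample(layer, activations, backprops):
    # Per-sample gradient of `scale`: sum of (input * upstream grad) over all
    # non-batch dimensions, keeping the batch dimension first.
    per_sample = (activations * backprops).flatten(start_dim=1).sum(dim=1)
    return {layer.scale: per_sample.unsqueeze(-1)}
```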

@@ -32,7 +32,7 @@ provide their own grad sampler for any new unsupported layer (see [tutorial](htt

Computes per-sample gradients for a model using core functionality available in PyTorch 1.12+. Unlike the hooks-based
grad sampler, which works at the module level, ExpandedWeights works at the function level, i.e. if your layer is not
explicitly supported but only uses known operations, ExpandedWeights will support it out of the box.

At the time of writing, the coverage for custom grad samplers between ``GradSampleModule`` and ``GradSampleModuleExpandedWeights``
is roughly the same.
@@ -51,4 +51,5 @@ Please note that these are known limitations and we plan to improve Expanded Wei
| Client-provided grad sampler | ✅ Supported | Not supported |
| `batch_first=False` | ✅ Supported | Not supported |
| Most popular nn.* layers | ✅ Supported | ✅ Supported |
| Recurrent networks | ✅ Supported | Not supported |
| Padding `same` in Conv | ✅ Supported | Not supported |
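
To illustrate the last row of the table, a small hooks-mode sketch (the toy layer and shapes are our own, not taken from this repo) showing `padding="same"` handled by ``GradSampleModule``:

```python
import torch
import torch.nn as nn
from opacus.grad_sample import GradSampleModule

# Toy conv layer with 'same' padding, wrapped by the hooks-based grad sampler
conv = nn.Conv2d(in_channels=3, out_channels=8, kernel_size=3, padding="same")
model = GradSampleModule(conv)

x = torch.randn(4, 3, 16, 16)  # batch of 4 samples
model(x).sum().backward()

# Each trainable parameter now carries a per-sample gradient with a leading batch dim
print(conv.weight.grad_sample.shape)  # expected: torch.Size([4, 8, 3, 3, 3])
print(conv.bias.grad_sample.shape)    # expected: torch.Size([4, 8])
```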
12 changes: 11 additions & 1 deletion opacus/grad_sample/conv.py
@@ -13,11 +13,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import Dict, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from opacus.utils.tensor_utils import unfold2d, unfold3d
from opt_einsum import contract

@@ -51,10 +53,18 @@ def compute_conv_grad_sample(
elif type(layer) == nn.Conv1d:
activations = activations.unsqueeze(-2) # add the H dimension
# set arguments to tuples with appropriate second element
if layer.padding == "same":
total_pad = layer.dilation[0] * (layer.kernel_size[0] - 1)
left_pad = math.floor(total_pad / 2)
right_pad = total_pad - left_pad
elif layer.padding == "valid":
left_pad, right_pad = 0, 0
else:
left_pad, right_pad = layer.padding[0], layer.padding[0]
activations = F.pad(activations, (left_pad, right_pad))
activations = torch.nn.functional.unfold(
activations,
kernel_size=(1, layer.kernel_size[0]),
padding=(0, layer.padding[0]),
stride=(1, layer.stride[0]),
dilation=(1, layer.dilation[0]),
)
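
For intuition about the asymmetric split computed above, a standalone sketch (the helper name and example values are ours, not part of the PR):

```python
import math


def same_padding_1d(kernel_size: int, dilation: int = 1):
    # Total padding needed so a stride-1 conv preserves the input length,
    # accounting for the dilated kernel extent.
    total_pad = dilation * (kernel_size - 1)
    # 'same' splits it as evenly as possible; the odd leftover goes to the right.
    left_pad = math.floor(total_pad / 2)
    right_pad = total_pad - left_pad
    return left_pad, right_pad


print(same_padding_1d(3))              # (1, 1) -- symmetric for odd kernels
print(same_padding_1d(4))              # (1, 2) -- asymmetric for even kernels
print(same_padding_1d(3, dilation=2))  # (2, 2)
```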
6 changes: 4 additions & 2 deletions opacus/tests/grad_samples/conv1d_test.py
@@ -20,7 +20,7 @@
import torch.nn as nn
from hypothesis import given, settings

from .common import GradSampleHooks_test, expander, shrinker
from .common import expander, GradSampleHooks_test, shrinker


class Conv1d_test(GradSampleHooks_test):
@@ -31,7 +31,7 @@ class Conv1d_test(GradSampleHooks_test):
out_channels_mapper=st.sampled_from([expander, shrinker]),
kernel_size=st.integers(2, 3),
stride=st.integers(1, 2),
padding=st.integers(0, 2),
padding=st.sampled_from([0, 1, 2, 'same', 'valid']),
dilation=st.integers(1, 2),
groups=st.integers(1, 12),
)
@@ -49,6 +49,8 @@ def test_conv1d(
groups: int,
):

if (padding == 'same' and stride != 1):
return
out_channels = out_channels_mapper(C)
if (
C % groups != 0 or out_channels % groups != 0
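
Context for the early-return guard on `padding == 'same' and stride != 1` above: PyTorch itself refuses to build such a convolution, so there is nothing to test. A quick sketch (the exact error text may vary by PyTorch version):

```python
import torch.nn as nn

try:
    nn.Conv1d(3, 6, kernel_size=3, stride=2, padding="same")
except ValueError as err:
    # Recent PyTorch versions raise something like:
    # "padding='same' is not supported for strided convolutions"
    print(err)
```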
31 changes: 20 additions & 11 deletions opacus/tests/grad_samples/conv2d_test.py
@@ -22,7 +22,7 @@
from opacus.utils.tensor_utils import unfold2d
from torch.testing import assert_allclose

from .common import GradSampleHooks_test, expander, shrinker
from .common import expander, GradSampleHooks_test, shrinker


class Conv2d_test(GradSampleHooks_test):
@@ -34,7 +34,7 @@ class Conv2d_test(GradSampleHooks_test):
out_channels_mapper=st.sampled_from([expander, shrinker]),
kernel_size=st.integers(2, 3),
stride=st.integers(1, 2),
padding=st.sampled_from([0, 2]),
padding=st.sampled_from([0, 2, 'same', 'valid']),
dilation=st.integers(1, 3),
groups=st.integers(1, 16),
)
@@ -52,7 +52,8 @@ def test_conv2d(
dilation: int,
groups: int,
):

if (padding == 'same' and stride != 1):
return
out_channels = out_channels_mapper(C)
if (
C % groups != 0 or out_channels % groups != 0
@@ -69,21 +70,29 @@
dilation=dilation,
groups=groups,
)
self.run_test(x, conv, batch_first=True, atol=10e-5, rtol=10e-4)
is_ew_compatible = padding != 'same' # TODO add support for padding = 'same' with EW
**Reviewer (Contributor):** does "valid" padding work fine with EW?

**Author (Contributor):** Good question. It should, since it's basically zero padding. But let me run a test locally and verify this. Good catch!

**Author (Contributor):** Well, actually, on line 37 of this file we have unit tests for both "valid" and "same", so it should work just fine 😎
self.run_test(
x,
conv,
batch_first=True,
atol=10e-5,
rtol=10e-3,
ew_compatible=is_ew_compatible,
)

@given(
B=st.integers(1, 4),
C=st.sampled_from([1, 3, 32]),
H=st.integers(11, 17),
W=st.integers(11, 17),
k_w=st.integers(2, 3),
k_h=st.integers(2, 3),
stride_w=st.integers(1, 2),
k_w=st.integers(2, 3),
stride_h=st.integers(1, 2),
stride_w=st.integers(1, 2),
pad_h=st.sampled_from([0, 2]),
pad_w=st.sampled_from([0, 2]),
dilation_w=st.integers(1, 3),
dilation_h=st.integers(1, 3),
dilation_w=st.integers(1, 3),
)
@settings(deadline=10000)
def test_unfold2d(
@@ -92,14 +101,14 @@ def test_unfold2d(
C: int,
H: int,
W: int,
k_w: int,
k_h: int,
pad_w: int,
k_w: int,
pad_h: int,
stride_w: int,
pad_w: int,
stride_h: int,
dilation_w: int,
stride_w: int,
dilation_h: int,
dilation_w: int,
):
X = torch.randn(B, C, H, W)
X_unfold_torch = torch.nn.functional.unfold(
8 changes: 5 additions & 3 deletions opacus/tests/grad_samples/conv3d_test.py
@@ -20,7 +20,7 @@
import torch.nn as nn
from hypothesis import given, settings

from .common import GradSampleHooks_test, expander, shrinker
from .common import expander, GradSampleHooks_test, shrinker


class Conv3d_test(GradSampleHooks_test):
@@ -33,7 +33,7 @@ class Conv3d_test(GradSampleHooks_test):
out_channels_mapper=st.sampled_from([expander, shrinker]),
kernel_size=st.sampled_from([2, 3, (1, 2, 3)]),
stride=st.sampled_from([1, 2, (1, 2, 3)]),
padding=st.sampled_from([0, 2, (1, 2, 3)]),
padding=st.sampled_from([0, 2, (1, 2, 3), 'same', 'valid']),
dilation=st.sampled_from([1, (1, 2, 2)]),
groups=st.integers(1, 16),
)
@@ -53,6 +53,8 @@ def test_conv3d(
groups: int,
):

if (padding == 'same' and stride != 1):
return
out_channels = out_channels_mapper(C)
if (
C % groups != 0 or out_channels % groups != 0
@@ -68,7 +70,7 @@
dilation=dilation,
groups=groups,
)
is_ew_compatible = dilation == 1
is_ew_compatible = (dilation == 1 and padding != 'same')  # TODO add support for padding = 'same' with EW
self.run_test(
x,
conv,
60 changes: 54 additions & 6 deletions opacus/utils/tensor_utils.py
@@ -16,6 +16,7 @@
"""
Utils for generating stats from torch tensors.
"""
import math
from typing import Iterator, List, Tuple, Union

import numpy as np
@@ -116,22 +117,46 @@ def unfold2d(
input,
*,
kernel_size: Tuple[int, int],
padding: Tuple[int, int],
padding: Union[str, Tuple[int, int]],
stride: Tuple[int, int],
dilation: Tuple[int, int],
):
"""
See :meth:`~torch.nn.functional.unfold`
"""
*shape, H, W = input.shape
if padding == "same":
total_pad_H = dilation[0] * (kernel_size[0] - 1)
total_pad_W = dilation[1] * (kernel_size[1] - 1)
pad_H_left = math.floor(total_pad_H / 2)
pad_H_right = total_pad_H - pad_H_left
pad_W_left = math.floor(total_pad_W / 2)
pad_W_right = total_pad_W - pad_W_left

elif padding == "valid":
pad_H_left, pad_H_right, pad_W_left, pad_W_right = (0, 0, 0, 0)
else:
pad_H_left, pad_H_right, pad_W_left, pad_W_right = (
padding[0],
padding[0],
padding[1],
padding[1],
)

H_effective = (
H + 2 * padding[0] - (kernel_size[0] + (kernel_size[0] - 1) * (dilation[0] - 1))
H
+ pad_H_left
+ pad_H_right
- (kernel_size[0] + (kernel_size[0] - 1) * (dilation[0] - 1))
) // stride[0] + 1
W_effective = (
W + 2 * padding[1] - (kernel_size[1] + (kernel_size[1] - 1) * (dilation[1] - 1))
W
+ pad_W_left
+ pad_W_right
- (kernel_size[1] + (kernel_size[1] - 1) * (dilation[1] - 1))
) // stride[1] + 1
# F.pad's first argument is the padding of the *last* dimension
input = F.pad(input, (padding[1], padding[1], padding[0], padding[0]))
input = F.pad(input, (pad_W_left, pad_W_right, pad_H_left, pad_H_right))
*shape_pad, H_pad, W_pad = input.shape
strides = list(input.stride())
strides = strides[:-2] + [
@@ -196,13 +221,36 @@ def unfold3d(
if isinstance(dilation, int):
dilation = (dilation, dilation, dilation)

if padding == "same":
total_pad_D = dilation[0] * (kernel_size[0] - 1)
total_pad_H = dilation[1] * (kernel_size[1] - 1)
total_pad_W = dilation[2] * (kernel_size[2] - 1)
pad_D_left = math.floor(total_pad_D / 2)
pad_D_right = total_pad_D - pad_D_left
pad_H_left = math.floor(total_pad_H / 2)
pad_H_right = total_pad_H - pad_H_left
pad_W_left = math.floor(total_pad_W / 2)
pad_W_right = total_pad_W - pad_W_left

elif padding == "valid":
pad_D_left, pad_D_right, pad_H_left, pad_H_right, pad_W_left, pad_W_right = (0, 0, 0, 0, 0, 0)
else:
pad_D_left, pad_D_right, pad_H_left, pad_H_right, pad_W_left, pad_W_right = (
padding[0],
padding[0],
padding[1],
padding[1],
padding[2],
padding[2],
)

batch_size, channels, _, _, _ = tensor.shape

# Input shape: (B, C, D, H, W)
tensor = F.pad(
tensor, (padding[2], padding[2], padding[1], padding[1], padding[0], padding[0])
tensor, (pad_W_left, pad_W_right, pad_H_left, pad_H_right, pad_D_left, pad_D_right)
)
# Output shape: (B, C, D+2*padding[2], H+2*padding[1], W+2*padding[0])
# Output shape: (B, C, D+pad_D_left+pad_D_right, H+pad_H_left+pad_H_right, W+pad_W_left+pad_W_right)

dilated_kernel_size = (
kernel_size[0] + (kernel_size[0] - 1) * (dilation[0] - 1),
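
Finally, a sanity-check sketch for the new `"same"` branch in `unfold2d` (assuming this PR's version of `opacus.utils.tensor_utils.unfold2d`, which the existing unit tests compare against `torch.nn.functional.unfold`): with stride 1 and a 3x3 kernel the `"same"` split is (1, 1), so the result should match an explicitly padded `F.unfold` and keep one patch per input position.

```python
import torch
import torch.nn.functional as F
from opacus.utils.tensor_utils import unfold2d

x = torch.randn(2, 3, 10, 10)

# 'same' padding with stride 1: one patch per spatial position (10 * 10 of them)
out = unfold2d(x, kernel_size=(3, 3), padding="same", stride=(1, 1), dilation=(1, 1))
assert out.shape == (2, 3 * 3 * 3, 10 * 10)

# For a 3x3 kernel the 'same' split is symmetric, so it equals explicit padding=(1, 1)
ref = F.unfold(x, kernel_size=(3, 3), padding=(1, 1), stride=(1, 1), dilation=(1, 1))
assert torch.allclose(out, ref)
print("unfold2d 'same' matches the explicitly padded F.unfold")
```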