13 changes: 10 additions & 3 deletions opacus/grad_sample/conv.py
@@ -49,16 +49,23 @@ def compute_conv_grad_sample(
ret[layer.bias] = torch.zeros_like(layer.bias).unsqueeze(0)
return ret

# FSDPWrapper adds a prefix 'FSDP' to the layer type, e.g. FSDPConv2d.
# Therefore the layer type cannot be determined directly via type(layer).
layer_type = (
layer.__class__.__bases__[1]
if isinstance(layer, torch.distributed.fsdp.FSDPModule)
else type(layer)
)
# get activations and backprops in shape depending on the Conv layer
if type(layer) is nn.Conv2d:
if layer_type is nn.Conv2d:
activations = unfold2d(
activations,
kernel_size=layer.kernel_size,
padding=layer.padding,
stride=layer.stride,
dilation=layer.dilation,
)
elif type(layer) is nn.Conv1d:
elif layer_type is nn.Conv1d:
activations = activations.unsqueeze(-2) # add the H dimension
# set arguments to tuples with appropriate second element
if layer.padding == "same":
@@ -76,7 +83,7 @@ def compute_conv_grad_sample(
stride=(1, layer.stride[0]),
dilation=(1, layer.dilation[0]),
)
elif type(layer) is nn.Conv3d:
elif layer_type is nn.Conv3d:
activations = unfold3d(
activations,
kernel_size=layer.kernel_size,
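For context on why the diff falls back to layer.__class__.__bases__[1]: FSDP2 wraps a module by swapping its class for a dynamically created subclass (e.g. FSDPConv2d) whose bases include the original layer class, so type(layer) no longer matches nn.Conv2d. Below is a minimal, self-contained sketch of that behavior; mimic_fsdp_wrap and the local FSDPModule class are stand-ins for illustration, not PyTorch APIs, so the snippet runs without initializing a process group.

import torch.nn as nn


class FSDPModule:
    # Stand-in for torch.distributed.fsdp.FSDPModule, used only so this
    # sketch runs without any distributed setup.
    pass


def mimic_fsdp_wrap(module: nn.Module) -> nn.Module:
    # Hypothetical helper: replace the instance's class with a dynamic
    # subclass, mimicking how FSDP2 wraps modules in place.
    wrapped_cls = type(f"FSDP{type(module).__name__}", (FSDPModule, type(module)), {})
    module.__class__ = wrapped_cls
    return module


layer = mimic_fsdp_wrap(nn.Conv2d(3, 8, kernel_size=3))

print(type(layer).__name__)           # FSDPConv2d
print(type(layer) is nn.Conv2d)       # False -> the old type(layer) check falls through
print(isinstance(layer, FSDPModule))  # True
print(layer.__class__.__bases__[1])   # <class 'torch.nn.modules.conv.Conv2d'>

# Same recovery the diff performs before dispatching on the Conv layer type:
layer_type = (
    layer.__class__.__bases__[1] if isinstance(layer, FSDPModule) else type(layer)
)
assert layer_type is nn.Conv2d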