diff --git a/opacus/grad_sample/conv.py b/opacus/grad_sample/conv.py
index e272e4c5..b29c47ad 100644
--- a/opacus/grad_sample/conv.py
+++ b/opacus/grad_sample/conv.py
@@ -49,8 +49,15 @@ def compute_conv_grad_sample(
             ret[layer.bias] = torch.zeros_like(layer.bias).unsqueeze(0)
         return ret
 
+    # FSDPWrapper adds a prefix 'FSDP' to layer type, e.g. FSDPConv2d.
+    # Therefore the layer type can not be directly determined by type(layer).
+    layer_type = (
+        layer.__class__.__bases__[1]
+        if isinstance(layer, torch.distributed.fsdp.FSDPModule)
+        else type(layer)
+    )
     # get activations and backprops in shape depending on the Conv layer
-    if type(layer) is nn.Conv2d:
+    if layer_type is nn.Conv2d:
         activations = unfold2d(
             activations,
             kernel_size=layer.kernel_size,
@@ -58,7 +65,7 @@ def compute_conv_grad_sample(
             stride=layer.stride,
             dilation=layer.dilation,
         )
-    elif type(layer) is nn.Conv1d:
+    elif layer_type is nn.Conv1d:
         activations = activations.unsqueeze(-2)  # add the H dimension
         # set arguments to tuples with appropriate second element
         if layer.padding == "same":
@@ -76,7 +83,7 @@ def compute_conv_grad_sample(
             stride=(1, layer.stride[0]),
             dilation=(1, layer.dilation[0]),
         )
-    elif type(layer) is nn.Conv3d:
+    elif layer_type is nn.Conv3d:
         activations = unfold3d(
             activations,
             kernel_size=layer.kernel_size,
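
The class swap this check relies on can be mimicked without an initialized process group. Below is a minimal sketch (not Opacus or PyTorch code) of why `layer.__class__.__bases__[1]` recovers the original layer class: FSDP2's `fully_shard` replaces a wrapped module's class with a dynamically created `FSDP<Cls>` whose bases are `(FSDPModule, Cls)`, so the second base is the unwrapped type. The names `_FakeFSDPModule`, `_mimic_fsdp_wrap`, and `resolve_layer_type` are illustrative stand-ins, not real APIs; the patch itself checks against `torch.distributed.fsdp.FSDPModule`.

```python
import torch.nn as nn


class _FakeFSDPModule:
    """Stand-in for torch.distributed.fsdp.FSDPModule (illustration only)."""


def _mimic_fsdp_wrap(module: nn.Module) -> nn.Module:
    # Mimic the class swap performed by fully_shard: the new class is named
    # FSDP<Cls> and its bases are (FSDPModule, Cls).
    cls = module.__class__
    module.__class__ = type(f"FSDP{cls.__name__}", (_FakeFSDPModule, cls), {})
    return module


def resolve_layer_type(layer: nn.Module) -> type:
    # Same resolution logic as the patch, written against the stand-in mixin.
    if isinstance(layer, _FakeFSDPModule):
        return layer.__class__.__bases__[1]  # second base is the original class
    return type(layer)


conv = _mimic_fsdp_wrap(nn.Conv2d(3, 8, kernel_size=3))
assert type(conv).__name__ == "FSDPConv2d"    # type(layer) no longer matches nn.Conv2d
assert resolve_layer_type(conv) is nn.Conv2d  # __bases__[1] recovers the real layer type
```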