diff --git a/src/zennit/composites.py b/src/zennit/composites.py index 8ed29c7..892a37d 100644 --- a/src/zennit/composites.py +++ b/src/zennit/composites.py @@ -494,6 +494,6 @@ def __init__(self, beta_smooth=10., layer_map=None, zero_params=None, canonizers layer_map = [] layer_map = layer_map + [ - (torch.nn.ReLU, ReLUBetaSmooth()), + (torch.nn.ReLU, ReLUBetaSmooth(beta_smooth=beta_smooth)), ] super().__init__(layer_map=layer_map, canonizers=canonizers)