diff --git a/keras/applications/convnext.py b/keras/applications/convnext.py index 01a0a5e2b8a..bf488f051ef 100644 --- a/keras/applications/convnext.py +++ b/keras/applications/convnext.py @@ -324,7 +324,7 @@ def apply(x): return apply -def Head(num_classes=1000, name=None): +def Head(num_classes=1000, classifier_activation=None, name=None): """Implementation of classification head of RegNet. Args: @@ -342,7 +342,9 @@ def apply(x): x = layers.LayerNormalization( epsilon=1e-6, name=name + "_head_layernorm" )(x) - x = layers.Dense(num_classes, name=name + "_head_dense")(x) + x = layers.Dense(num_classes, + activation=classifier_activation, + name=name + "_head_dense")(x) return x return apply @@ -522,7 +524,9 @@ def ConvNeXt( cur += depths[i] if include_top: - x = Head(num_classes=classes, name=model_name)(x) + x = Head(num_classes=classes, + classifier_activation=classifier_activation, + name=model_name)(x) imagenet_utils.validate_activation(classifier_activation, weights) else: