Skip to content

Commit

Permalink
Pass classifier_activation arg to "Head"
Browse files Browse the repository at this point in the history
  • Loading branch information
Frightera committed Feb 8, 2023
1 parent 0046e0d commit eedabb6
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions keras/applications/convnext.py
Expand Up @@ -324,7 +324,7 @@ def apply(x):
return apply


def Head(num_classes=1000, name=None):
def Head(num_classes=1000, classifier_activation=None, name=None):
"""Implementation of classification head of RegNet.
Args:
Expand All @@ -342,7 +342,9 @@ def apply(x):
x = layers.LayerNormalization(
epsilon=1e-6, name=name + "_head_layernorm"
)(x)
x = layers.Dense(num_classes, name=name + "_head_dense")(x)
x = layers.Dense(num_classes,
activation=classifier_activation,
name=name + "_head_dense")(x)
return x

return apply
Expand Down Expand Up @@ -522,7 +524,9 @@ def ConvNeXt(
cur += depths[i]

if include_top:
x = Head(num_classes=classes, name=model_name)(x)
x = Head(num_classes=classes,
classifier_activation=classifier_activation,
name=model_name)(x)
imagenet_utils.validate_activation(classifier_activation, weights)

else:
Expand Down

0 comments on commit eedabb6

Please sign in to comment.