Skip to content

Commit

Permalink
add mish activation function
Browse files Browse the repository at this point in the history
  • Loading branch information
Ely-S committed Jun 10, 2020
1 parent cc79eb4 commit 28cf011
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 0 deletions.
2 changes: 2 additions & 0 deletions efficientdet/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ def activation_fn(features: tf.Tensor, act_type: Text):
return tf.nn.relu(features)
elif act_type == 'relu6':
return tf.nn.relu6(features)
elif act_type == 'mish':
return features * tf.math.tanh(tf.math.softplus(features))
else:
raise ValueError('Unsupported act_type {}'.format(act_type))

Expand Down
27 changes: 27 additions & 0 deletions efficientdet/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,33 @@ def _model(inputs):
self.assertIs(out.dtype, tf.float16) # output should be float16.


class ActivationTest(tf.test.TestCase):

def test_swish(self):
features = tf.constant([.5, 10])

result = utils.activation_fn(features, "swish")
expected = features * tf.sigmoid(features)
self.assertAllClose(result, expected)

result = utils.activation_fn(features, "swish_native")
self.assertAllClose(result, expected)

def test_relu(self):
features = tf.constant([.5, 10])
result = utils.activation_fn(features, "relu")
self.assertAllClose(result, [0.5, 10])

def test_relu6(self):
features = tf.constant([.5, 10])
result = utils.activation_fn(features, "relu6")
self.assertAllClose(result, [0.5, 6])

def test_mish(self):
features = tf.constant([.5, 10])
result = utils.activation_fn(features, "mish")
self.assertAllClose(result, [0.37524524, 10.0])

if __name__ == '__main__':
logging.set_verbosity(logging.WARNING)
tf.disable_eager_execution()
Expand Down

0 comments on commit 28cf011

Please sign in to comment.