Commit e2b5a46
Add mish activation
seanmor5 committed May 3, 2021
1 parent 272e20f commit e2b5a46
Showing 3 changed files with 28 additions and 2 deletions.
lib/axon.ex (2 changes: 1 addition & 1 deletion)

@@ -615,7 +615,7 @@ defmodule Axon do
   end
 
   @activation_layers [:celu, :elu, :exp, :gelu, :hard_sigmoid, :hard_silu, :hard_tanh] ++
-                       [:leaky_relu, :linear, :log_sigmoid, :relu, :relu6] ++
+                       [:leaky_relu, :linear, :log_sigmoid, :mish, :relu, :relu6] ++
                        [:sigmoid, :silu, :selu, :softmax, :softplus, :softsign, :tanh]

   @doc """
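
With :mish registered in @activation_layers, the new activation is exposed through Axon's layer API like any other entry in that list. A minimal usage sketch (the input shape and dense layer size are illustrative assumptions, not part of this commit):

    model =
      Axon.input({nil, 784})
      |> Axon.dense(128)
      |> Axon.activation(:mish)

Since the same list drives Axon's generated per-activation helpers, Axon.mish(model) should behave identically to the Axon.activation(model, :mish) call above.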
lib/axon/activations.ex (26 changes: 26 additions & 0 deletions)

@@ -377,6 +377,32 @@ defmodule Axon.Activations do
"""
defn log_sigmoid(x), do: -softplus(-x)

@doc ~S"""
Mish activation.
$$f(x_i) = x_i* \tanh(\log(1 + e^x_i))$$
## Examples
iex> Axon.Activations.mish(Nx.tensor([-3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0], type: {:f, 32}, names: [:data]))
#Nx.Tensor<
f32[data: 7]
[-0.14564745128154755, -0.2525014877319336, -0.30340147018432617, 0.0, 0.8650984168052673, 1.9439589977264404, 2.98653507232666]
>
iex> Axon.Activations.mish(Nx.tensor([[-1.0, -2.0, -3.0], [1.0, 2.0, 3.0]], type: {:bf, 16}, names: [:batch, :data]))
#Nx.Tensor<
bf16[batch: 2][data: 3]
[
[-0.30078125, -0.25, -0.1435546875],
[0.86328125, 1.9375, 2.96875]
]
>
"""
defn mish(x) do
x * tanh(softplus(x))
end

   @doc ~S"""
   Rectified linear unit activation.
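
Since defn compiles down to Nx operations, mish/1 can be spot-checked against its closed form x * tanh(log(1 + e^x)) with plain Nx calls. A quick sketch (assuming an iex session with Nx available; values agree with the f32 doctest above):

    x = Nx.tensor([-1.0, 0.0, 1.0])
    Nx.multiply(x, Nx.tanh(Nx.log(Nx.add(Nx.exp(x), 1))))
    # => approximately [-0.3034, 0.0, 0.8651], matching Axon.Activations.mish(x)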
lib/axon/compiler.ex (2 changes: 1 addition & 1 deletion)

@@ -217,7 +217,7 @@ defmodule Axon.Compiler do
   ## Activation Layers
 
   @activation_layers [:celu, :elu, :exp, :gelu, :hard_sigmoid, :hard_silu, :hard_tanh] ++
-                       [:leaky_relu, :linear, :log_sigmoid, :relu, :relu6] ++
+                       [:leaky_relu, :linear, :log_sigmoid, :mish, :relu, :relu6] ++
                        [:sigmoid, :silu, :selu, :softmax, :softplus, :softsign, :tanh]

   defp recur_predict_fun(%Axon{op: op, parent: parent}, cache, param_map, input_map)
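
Because each op atom in @activation_layers matches a function of the same name in Axon.Activations, the compiler's guard clause can dispatch to the activation dynamically instead of special-casing mish. Conceptually (a simplified illustration using apply/3, not the actual compiler body):

    op = :mish
    x = Nx.tensor([1.0, 2.0, 3.0])
    apply(Axon.Activations, op, [x])
    # => same result as calling Axon.Activations.mish(x) directly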
