42 changes: 18 additions & 24 deletions lib/axon/activations.ex
@@ -85,18 +85,16 @@ defmodule Axon.Activations do
  """
  defn celu(x, opts \\ []) do
    opts = keyword!(opts, alpha: 1.0)

-   transform(
-     opts[:alpha],
-     fn x ->
-       if Elixir.Kernel.==(x, 0),
-         do: raise(ArgumentError, ":alpha must be non-zero in CELU activation")
-     end
-   )
+   validate_celu_alpha!(opts[:alpha])

    Nx.select(Nx.greater(x, 0), x, opts[:alpha] * Nx.expm1(x / opts[:alpha]))
  end

+ deftransformp validate_celu_alpha!(alpha) do
+   if alpha == 0,
+     do: raise(ArgumentError, ":alpha must be non-zero in CELU activation")
+ end
+
  @doc ~S"""
  Exponential linear unit activation.

@@ -376,7 +374,7 @@
  """
  defn log_sumexp(x, opts \\ []) do
    opts = keyword!(opts, axis: -1)
-   axes = transform(opts[:axis], &List.wrap/1)
+   axes = wrap(opts[:axis])

    # This is a scaling term designed to prevent over/under flow when x is very
    # large. Consider cases where the intermediate value e^x with large positive
@@ -457,12 +455,6 @@
  defn log_softmax(x, opts \\ []) do
    opts = keyword!(opts, axis: -1)

-   transform({x, opts}, fn {x, opts} ->
-     if Elixir.Kernel.<=(Nx.rank(x), opts[:axis]) do
-       raise ArgumentError, "log_softmax axis must be within rank of tensor"
-     end
-   end)
-
    shifted = x - stop_grad(Nx.reduce_max(x, axes: [opts[:axis]], keep_axes: true))

    shifted
@@ -594,7 +586,7 @@
  defn sigmoid(x) do
    # Cache logits so they are available in certain calculations,
    # e.g. binary_cross_entropy and categorical_cross_entropy
-   transform(Nx.sigmoid(x), &Nx.Defn.Expr.metadata(&1, %{logits: x}))
+   cache_logits(x, Nx.sigmoid(x))
  end

  @doc ~S"""
@@ -708,13 +700,7 @@
  """
  defn softmax(x, opts \\ []) do
    opts = keyword!(opts, axis: -1)
-   axes = transform(opts[:axis], &List.wrap/1)
-
-   transform({x, axes}, fn {x, axes} ->
-     Enum.each(axes, fn axis ->
-       Nx.Shape.normalize_axis(Nx.shape(x), axis, Nx.names(x))
-     end)
-   end)
+   axes = wrap(opts[:axis])

    # This is a scaling term designed to prevent over/under flow when x is very
    # large. Consider cases where the intermediate value e^x with large positive
@@ -745,7 +731,7 @@

    # Cache logits so they are available in certain calculations,
    # e.g. binary_cross_entropy and categorical_cross_entropy
-   transform(res, &Nx.Defn.Expr.metadata(&1, %{logits: x}))
+   cache_logits(x, res)
  end

  @doc ~S"""
@@ -837,4 +823,12 @@

  """
  defn tanh(x), do: Nx.tanh(x)
+
+ ## Helpers
+
+ deftransformp cache_logits(input, output) do
+   Nx.Defn.Expr.metadata(output, %{logits: input})
+ end
+
+ deftransformp wrap(axis), do: List.wrap(axis)
end
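
Every change in this file is the same migration: logic that previously ran through an inline transform/2 call inside defn now lives in a named deftransformp helper, which executes as ordinary Elixir at trace time rather than inside the numerical definition. Below is a minimal, self-contained sketch of the pattern, assuming an Nx version that provides deftransformp; the module and function names (PatternSketch, scaled_relu, validate_scale!) are illustrative, not part of this PR.

defmodule PatternSketch do
  import Nx.Defn

  defn scaled_relu(x, opts \\ []) do
    opts = keyword!(opts, scale: 1.0)

    # Plain Elixir validation, executed once at trace time.
    validate_scale!(opts[:scale])

    Nx.max(x, 0) * opts[:scale]
  end

  # deftransformp defines a private trace-time function: ordinary Elixir,
  # so == and raise/2 operate on regular terms here, with no need for the
  # Elixir.Kernel prefix the old inline transform/2 blocks required.
  deftransformp validate_scale!(scale) do
    if scale == 0 do
      raise ArgumentError, ":scale must be non-zero"
    end
  end
end

The cache_logits/2 and wrap/1 helpers added at the bottom of the file rely on the same mechanism: cache_logits/2 attaches the pre-activation input as %{logits: input} metadata on the result expression so that calculations such as binary_cross_entropy can recover the raw logits, and wrap/1 lifts a bare axis into a list at trace time.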
52 changes: 25 additions & 27 deletions lib/axon/initializers.ex
@@ -684,33 +684,20 @@ defmodule Axon.Initializers do

    assert_min_rank!("Axon.Initializers.orthogonal", "input_shape", shape, 2)

-   {{m, n}, random_seed} =
-     transform({key, shape, distribution, type}, fn {key, shape, distribution, type} ->
-       flat_shape =
-         if tuple_size(shape) > 2 do
-           tuple_list = shape |> Tuple.to_list() |> Enum.reverse()
-           n = hd(tuple_list)
-           m = Enum.reduce(tl(tuple_list), 1, &(&1 * &2))
-           {m, n}
-         else
-           shape
-         end
-
-       out =
-         case distribution do
-           :uniform ->
-             Nx.Random.uniform_split(key, 0.0, 1.0, shape: flat_shape, type: type)
-
-           :normal ->
-             Nx.Random.normal_split(key, 0.0, 1.0, shape: flat_shape, type: type)
-
-           dist ->
-             raise ArgumentError,
-                   "invalid distribution #{inspect(dist)} passed to orthogonal/1"
-         end
-
-       {flat_shape, out}
-     end)
+   {m, n} = get_flat_shape(shape)
+
+   random_seed =
+     case distribution do
+       :uniform ->
+         Nx.Random.uniform_split(key, 0.0, 1.0, shape: {m, n}, type: type)
+
+       :normal ->
+         Nx.Random.normal_split(key, 0.0, 1.0, shape: {m, n}, type: type)
+
+       dist ->
+         raise ArgumentError,
+               "invalid distribution #{inspect(dist)} passed to orthogonal/1"
+     end

    {q, _r} = Nx.LinAlg.qr(random_seed, mode: :complete)
@@ -722,6 +709,17 @@
    rand
  end

+ deftransformp get_flat_shape(shape) do
+   if tuple_size(shape) > 2 do
+     tuple_list = shape |> Tuple.to_list() |> Enum.reverse()
+     n = hd(tuple_list)
+     m = Enum.reduce(tl(tuple_list), 1, &(&1 * &2))
+     {m, n}
+   else
+     shape
+   end
+ end
+
  # Variance scaling branches

  defnp var_normal(key, variance, opts \\ []) do
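
The refactor keeps orthogonal/1's flattening behavior intact: a shape of rank greater than 2 collapses to {m, n}, where n is the trailing dimension and m is the product of all leading dimensions, and the QR factorization of the resulting m x n random matrix supplies the orthonormal vectors for the kernel. A quick sketch of that flattening as plain Elixir, with hypothetical names (FlattenSketch, flatten_for_qr/1):

defmodule FlattenSketch do
  # Mirrors the logic of the private get_flat_shape/1 helper above.
  def flatten_for_qr(shape) when tuple_size(shape) > 2 do
    [n | rest] = shape |> Tuple.to_list() |> Enum.reverse()
    m = Enum.reduce(rest, 1, &(&1 * &2))
    {m, n}
  end

  def flatten_for_qr(shape), do: shape
end

FlattenSketch.flatten_for_qr({3, 4, 5})
#=> {12, 5}

For example, a {3, 4, 5} kernel is drawn as a 12 x 5 random matrix, factored with Nx.LinAlg.qr/2, and the orthogonal factor is then shaped back to the requested dimensions.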