Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions lib/axon.ex
Original file line number Diff line number Diff line change
Expand Up @@ -336,15 +336,17 @@ defmodule Axon do
@doc type: :special
# Builds a trainable parameter struct attached to a layer.
#
# `name` identifies the parameter within its layer; `shape` is either a
# literal tuple or an arity-1 function of the input shapes. Options:
#
#   * `:initializer` - initializer name or function. Defaults to
#     `:glorot_uniform` (validated by `validate_initializer!/1`).
#   * `:type` - parameter numeric type. Defaults to `{:f, 32}`.
#
# NOTE(review): the diff scrape carried both the pre- and post-change
# `Keyword.validate!` lines; only the updated one (with the `:type`
# default) is kept, otherwise passing `:type` would raise.
def param(name, shape, opts \\ [])
    when is_binary(name) and (is_tuple(shape) or is_function(shape)) do
  opts = Keyword.validate!(opts, initializer: :glorot_uniform, type: {:f, 32})
  initializer = validate_initializer!(opts[:initializer])
  # Guard against a caller passing an explicit `type: nil`.
  type = opts[:type] || {:f, 32}

  id = System.unique_integer([:positive, :monotonic])

  %Axon.Parameter{
    id: id,
    name: name,
    shape: shape,
    type: type,
    initializer: initializer
  }
end
Expand Down Expand Up @@ -1382,14 +1384,19 @@ defmodule Axon do
end

defp dropout(%Axon{} = x, dropout, opts) do
opts = Keyword.validate!(opts, [:name, rate: 0.5])
opts = Keyword.validate!(opts, [:name, :seed, rate: 0.5])
seed = Keyword.get_lazy(opts, :seed, fn -> :erlang.system_time() end)
key = Nx.Random.key(seed) |> Nx.backend_copy(Nx.Defn.Expr)

if opts[:rate] < 0 or opts[:rate] >= 1 do
raise ArgumentError,
"The dropout rate needs to be >= 0 and < 1, got #{inspect(opts[:rate])}"
end

layer(dropout, [x],
key_state =
param("key", fn _ -> Nx.shape(key) end, type: {:u, 32}, initializer: fn _, _ -> key end)

layer(dropout, [x, key_state],
name: opts[:name],
rate: opts[:rate],
op_name: dropout
Expand Down
43 changes: 19 additions & 24 deletions lib/axon/compiler.ex
Original file line number Diff line number Diff line change
Expand Up @@ -470,8 +470,6 @@ defmodule Axon.Compiler do
{id, {Map.put(cache, id, model_funs), op_counts}}
end

@dropout_layers [:dropout, :spatial_dropout, :feature_alpha_dropout, :alpha_dropout]

defp recur_model_funs(
%Axon.Node{
id: id,
Expand Down Expand Up @@ -526,7 +524,6 @@ defmodule Axon.Compiler do
layer_params,
hooks,
mode,
key,
stacktrace
)

Expand Down Expand Up @@ -606,7 +603,6 @@ defmodule Axon.Compiler do
layer_params,
hooks,
mode,
key,
layer_stacktrace
) do
# Recurse graph inputs and invoke cache to get parent results,
Expand Down Expand Up @@ -646,18 +642,23 @@ defmodule Axon.Compiler do
# parameter map, so we just need to extract them and then apply
# freezing and dtype policy
parameter_inputs =
Enum.map(layer_params, fn %{name: v, frozen: frz} ->
Enum.map(layer_params, fn %{type: type, name: v, frozen: frz} ->
param = params[name][v]

if param != nil do
safe_as_type(maybe_freeze(param, frz), compute)
else
raise ArgumentError,
"parameter #{inspect(v)} for layer: #{inspect(name)} in" <>
" was not present in the given parameter map, this can" <>
" happen if you are using parameters intended for another" <>
" model or did not initialize portions of your model with" <>
" Axon.init/3"
cond do
param != nil and should_cast?(type, compute) ->
safe_as_type(maybe_freeze(param, frz), compute)

param != nil ->
maybe_freeze(param, frz)

true ->
raise ArgumentError,
"parameter #{inspect(v)} for layer: #{inspect(name)} in" <>
" was not present in the given parameter map, this can" <>
" happen if you are using parameters intended for another" <>
" model or did not initialize portions of your model with" <>
" Axon.init/3"
end
end)

Expand All @@ -672,16 +673,6 @@ defmodule Axon.Compiler do
{layer_inputs, rest, [param | inputs]}
end)

# TODO: Hack for dropout with key, fix with a better implementation
opts =
if op in @dropout_layers do
<<data::unsigned-size(32), _rest::binary>> = :erlang.md5(name)
dropout_key = Nx.Random.fold_in(key, data)
opts ++ [key: dropout_key]
else
opts
end

# Compute arguments to be forwarded and ensure `:mode` is included
# for inference/training behavior dependent functions
args = Enum.reverse(tensor_inputs) ++ [Keyword.put(opts, :mode, mode)]
Expand Down Expand Up @@ -887,6 +878,10 @@ defmodule Axon.Compiler do
end
end

# A cast between parameter type and compute policy type is only safe
# when neither side is an integer type — integer-typed parameters
# (e.g. `{:u, 32}` PRNG keys) must keep their exact representation.
defp should_cast?(type1, type2) do
  not (Nx.Type.integer?(type1) or Nx.Type.integer?(type2))
end

defp safe_shape(container_or_tensor) do
case container_or_tensor do
%Axon.None{} = none ->
Expand Down
100 changes: 33 additions & 67 deletions lib/axon/layers.ex
Original file line number Diff line number Diff line change
Expand Up @@ -1607,35 +1607,28 @@ defmodule Axon.Layers do
* [Dropout: A Simple Way to Prevent Neural Networks from Overfitting](https://jmlr.org/papers/v15/srivastava14a.html)
"""
@doc type: :dropout
# Dropout layer implementation.
#
# Takes the input tensor and a PRNG `key`; in `:train` mode it zeroes
# elements with probability `:rate`, scales survivors by `1 / keep_prob`,
# and threads the split key back out as layer state (under "key") so each
# application draws fresh randomness. In `:inference` mode the input
# passes through unchanged.
#
# Options:
#
#   * `:rate` - dropout rate.
#   * `:noise_shape` - shape of the random mask; must be broadcastable
#     to the input shape. Defaults to the input shape.
#   * `:mode` - `:train` or `:inference`. Defaults to `:inference`.
#
# NOTE(review): reconstructed post-PR version; the interleaved pre-PR
# `uniform_split`/`transform` lines from the diff scrape are dropped.
defn dropout(input, key, opts \\ []) do
  opts = keyword!(opts, [:rate, noise_shape: Nx.shape(input), mode: :inference])
  keep_prob = Nx.tensor(1, type: Nx.type(input)) - Nx.tensor(opts[:rate], type: Nx.type(input))

  {rand, new_key} =
    Nx.Random.uniform(key, 0, 1, shape: opts[:noise_shape], type: Nx.type(input))

  rand
  |> Nx.less(keep_prob)
  # Noise shape may be a broadcastable subset of the input shape.
  |> Nx.broadcast(input)
  |> Nx.select(input / keep_prob, Nx.tensor(0, type: Nx.type(input)))
  |> dropout_mode_transform(input, new_key, opts[:mode])
end

out = Nx.select(mask, input / keep_prob, Nx.tensor(0, type: Nx.type(input)))
# Applies the train/inference switch shared by all dropout variants.
#
# In `:train` mode, wraps the computed `output` in an
# `Axon.StatefulOutput` so the updated PRNG key is carried forward as
# layer state under the "key" entry. In `:inference` mode dropout is a
# no-op and the untouched `input` is returned.
#
# NOTE(review): the diff scrape interleaved the removed pre-PR
# `transform(...)` clauses inside this `case`; only the post-PR clauses
# are kept.
deftransformp dropout_mode_transform(output, input, new_key, mode) do
  case mode do
    :train ->
      %Axon.StatefulOutput{output: output, state: %{"key" => new_key}}

    :inference ->
      input
  end
end

@doc """
Expand All @@ -1662,18 +1655,17 @@ defmodule Axon.Layers do
* [Efficient Object Localization Using Convolutional Networks](https://arxiv.org/abs/1411.4280)
"""
@doc type: :dropout
defn spatial_dropout(input, opts \\ []) do
defn spatial_dropout(input, key, opts \\ []) do
assert_min_rank!("Axon.Layers.spatial_dropout", "input", input, 3)

opts = keyword!(opts, [:key, rate: 0.5, channels: :last, mode: :inference])
opts = keyword!(opts, rate: 0.5, channels: :last, mode: :inference)

noise_shape =
transform({Nx.shape(input), opts[:channels]}, fn {shape, channels} ->
Axon.Shape.spatial_dropout_noise_shape(shape, channels)
end)

dropout(input,
key: opts[:key],
dropout(input, key,
rate: opts[:rate],
noise_shape: noise_shape,
mode: opts[:mode]
Expand Down Expand Up @@ -1701,34 +1693,26 @@ defmodule Axon.Layers do
* [Self-Normalizing Neural Networks](https://arxiv.org/abs/1706.02515)
"""
@doc type: :dropout
# Alpha dropout (Self-Normalizing Neural Networks).
#
# Randomly replaces elements with `alpha_p` (the negative saturation
# value of SELU) with probability `:rate`, then applies the affine
# correction `a * x + b` so the output keeps zero mean and unit
# variance. Threads the split PRNG key out as layer state; in
# `:inference` mode the input passes through unchanged.
#
# Options:
#
#   * `:rate` - dropout rate. Defaults to `0.5`.
#   * `:mode` - `:train` or `:inference`. Defaults to `:inference`.
#
# NOTE(review): reconstructed post-PR version; the interleaved pre-PR
# `uniform_split`/`transform` lines from the diff scrape are dropped.
defn alpha_dropout(input, key, opts \\ []) do
  opts = keyword!(opts, rate: 0.5, mode: :inference)
  rate = opts[:rate]

  # SELU's fixed alpha/scale constants.
  alpha = Nx.tensor(1.6732632423543772848170429916717, type: Nx.type(input))
  scale = Nx.tensor(1.0507009873554804934193349852946, type: Nx.type(input))
  alpha_p = -alpha * scale
  keep_prob = Nx.tensor(1, type: Nx.type(input)) - rate

  {rand, new_key} = Nx.Random.uniform(key, 0, 1, shape: Nx.shape(input), type: Nx.type(input))

  mask = Nx.less(rand, keep_prob)

  # Affine correction restoring zero mean / unit variance after masking.
  a = Nx.rsqrt(keep_prob * Nx.power(Nx.tensor(1, type: Nx.type(input)) * alpha_p, 2))
  b = -a * alpha_p * rate

  x = Nx.select(mask, input, alpha_p)
  out = a * x + b

  dropout_mode_transform(out, input, new_key, opts[:mode])
end

@doc """
Expand All @@ -1749,10 +1733,10 @@ defmodule Axon.Layers do
Defaults to shape of input tensor.
"""
@doc type: :dropout
defn feature_alpha_dropout(input, opts \\ []) do
defn feature_alpha_dropout(input, key, opts \\ []) do
assert_min_rank!("Axon.Layers.feature_alpha_dropout", "input", input, 3)

opts = keyword!(opts, [:key, rate: 0.5, channels: :last, mode: :inference])
opts = keyword!(opts, rate: 0.5, channels: :last, mode: :inference)

noise_shape =
transform({Nx.shape(input), opts[:channels]}, fn {shape, channels} ->
Expand All @@ -1761,31 +1745,13 @@ defmodule Axon.Layers do

keep_prob = 1 - opts[:rate]

mask =
Nx.less(
Nx.Random.uniform_split(opts[:key], 0, 1, shape: noise_shape, type: Nx.type(input)),
keep_prob
)
{rand, new_key} = Nx.Random.uniform(key, 0, 1, shape: noise_shape, type: Nx.type(input))

mask =
transform(
{mask, Nx.shape(input)},
fn {mask, input_shape} ->
if Elixir.Kernel.==(Nx.shape(mask), input_shape),
do: mask,
else: Nx.broadcast(mask, input_shape)
end
)

out = Nx.select(mask, input / keep_prob, Nx.negate(Axon.Activations.selu(input)))

transform({input, out, opts[:mode]}, fn
{input, _, :inference} ->
input

{_, out, :train} ->
out
end)
rand
|> Nx.less(keep_prob)
|> Nx.broadcast(input)
|> Nx.select(input / keep_prob, Nx.negate(Axon.Activations.selu(input)))
|> dropout_mode_transform(input, new_key, opts[:mode])
end

## Global Pooling
Expand Down
2 changes: 1 addition & 1 deletion lib/axon/parameter.ex
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
defmodule Axon.Parameter do
  @moduledoc false
  # Internal representation of a single trainable parameter.
  # `:type` defaults to single-precision float; `:frozen` marks
  # parameters excluded from gradient updates.
  # NOTE(review): the diff scrape carried both the pre- and post-change
  # `defstruct` lines (a second `defstruct` is a compile error); only
  # the updated one with the `:type` field is kept.
  defstruct [:id, :name, :shape, :initializer, type: {:f, 32}, frozen: false]
end
Loading