From e98286e26ecfd51723fe11ab7781297a65285719 Mon Sep 17 00:00:00 2001 From: Sean Moriarity Date: Tue, 13 Dec 2022 05:01:24 -0800 Subject: [PATCH] Use a layer state to manage dropout state so it changes between runs --- lib/axon.ex | 13 +++-- lib/axon/compiler.ex | 43 +++++++--------- lib/axon/layers.ex | 100 ++++++++++++------------------------ lib/axon/parameter.ex | 2 +- test/axon/compiler_test.exs | 70 ++++++++++++++++++++----- test/axon/layers_test.exs | 4 +- 6 files changed, 123 insertions(+), 109 deletions(-) diff --git a/lib/axon.ex b/lib/axon.ex index 00bbef87c..7d07b0de2 100644 --- a/lib/axon.ex +++ b/lib/axon.ex @@ -336,8 +336,9 @@ defmodule Axon do @doc type: :special def param(name, shape, opts \\ []) when is_binary(name) and (is_tuple(shape) or is_function(shape)) do - opts = Keyword.validate!(opts, initializer: :glorot_uniform) + opts = Keyword.validate!(opts, initializer: :glorot_uniform, type: {:f, 32}) initializer = validate_initializer!(opts[:initializer]) + type = opts[:type] || {:f, 32} id = System.unique_integer([:positive, :monotonic]) @@ -345,6 +346,7 @@ defmodule Axon do id: id, name: name, shape: shape, + type: type, initializer: initializer } end @@ -1382,14 +1384,19 @@ defmodule Axon do end defp dropout(%Axon{} = x, dropout, opts) do - opts = Keyword.validate!(opts, [:name, rate: 0.5]) + opts = Keyword.validate!(opts, [:name, :seed, rate: 0.5]) + seed = Keyword.get_lazy(opts, :seed, fn -> :erlang.system_time() end) + key = Nx.Random.key(seed) |> Nx.backend_copy(Nx.Defn.Expr) if opts[:rate] < 0 or opts[:rate] >= 1 do raise ArgumentError, "The dropout rate needs to be >= 0 and < 1, got #{inspect(opts[:rate])}" end - layer(dropout, [x], + key_state = + param("key", fn _ -> Nx.shape(key) end, type: {:u, 32}, initializer: fn _, _ -> key end) + + layer(dropout, [x, key_state], name: opts[:name], rate: opts[:rate], op_name: dropout diff --git a/lib/axon/compiler.ex b/lib/axon/compiler.ex index cca09ac92..c83428bff 100644 --- a/lib/axon/compiler.ex +++ b/lib/axon/compiler.ex @@ -470,8 +470,6 @@ defmodule Axon.Compiler do {id, {Map.put(cache, id, model_funs), op_counts}} end - @dropout_layers [:dropout, :spatial_dropout, :feature_alpha_dropout, :alpha_dropout] - defp recur_model_funs( %Axon.Node{ id: id, @@ -526,7 +524,6 @@ defmodule Axon.Compiler do layer_params, hooks, mode, - key, stacktrace ) @@ -606,7 +603,6 @@ defmodule Axon.Compiler do layer_params, hooks, mode, - key, layer_stacktrace ) do # Recurse graph inputs and invoke cache to get parent results, @@ -646,18 +642,23 @@ defmodule Axon.Compiler do # parameter map, so we just need to extract them and then apply # freezing and dtype policy parameter_inputs = - Enum.map(layer_params, fn %{name: v, frozen: frz} -> + Enum.map(layer_params, fn %{type: type, name: v, frozen: frz} -> param = params[name][v] - if param != nil do - safe_as_type(maybe_freeze(param, frz), compute) - else - raise ArgumentError, - "parameter #{inspect(v)} for layer: #{inspect(name)} in" <> - " was not present in the given parameter map, this can" <> - " happen if you are using parameters intended for another" <> - " model or did not initialize portions of your model with" <> - " Axon.init/3" + cond do + param != nil and should_cast?(type, compute) -> + safe_as_type(maybe_freeze(param, frz), compute) + + param != nil -> + maybe_freeze(param, frz) + + true -> + raise ArgumentError, + "parameter #{inspect(v)} for layer: #{inspect(name)} in" <> + " was not present in the given parameter map, this can" <> + " happen if you are using parameters intended for another" <> + " model or did not initialize portions of your model with" <> + " Axon.init/3" end end) @@ -672,16 +673,6 @@ defmodule Axon.Compiler do {layer_inputs, rest, [param | inputs]} end) - # TODO: Hack for dropout with key, fix with a better implementation - opts = - if op in @dropout_layers do - <> = :erlang.md5(name) - dropout_key = Nx.Random.fold_in(key, data) - opts ++ [key: dropout_key] - else - opts - end - # Compute arguments to be forwarded and ensure `:mode` is included # for inference/training behavior dependent functions args = Enum.reverse(tensor_inputs) ++ [Keyword.put(opts, :mode, mode)] @@ -887,6 +878,10 @@ defmodule Axon.Compiler do end end + defp should_cast?(type1, type2) do + not Nx.Type.integer?(type1) and not Nx.Type.integer?(type2) + end + defp safe_shape(container_or_tensor) do case container_or_tensor do %Axon.None{} = none -> diff --git a/lib/axon/layers.ex b/lib/axon/layers.ex index 8f9237dea..b0177c5b4 100644 --- a/lib/axon/layers.ex +++ b/lib/axon/layers.ex @@ -1607,35 +1607,28 @@ defmodule Axon.Layers do * [Dropout: A Simple Way to Prevent Neural Networks from Overfitting](https://jmlr.org/papers/v15/srivastava14a.html) """ @doc type: :dropout - defn dropout(input, opts \\ []) do - opts = keyword!(opts, [:key, :rate, noise_shape: Nx.shape(input), mode: :inference]) + defn dropout(input, key, opts \\ []) do + opts = keyword!(opts, [:rate, noise_shape: Nx.shape(input), mode: :inference]) keep_prob = Nx.tensor(1, type: Nx.type(input)) - Nx.tensor(opts[:rate], type: Nx.type(input)) - mask = - Nx.less( - Nx.Random.uniform_split(opts[:key], 0, 1, shape: opts[:noise_shape], type: Nx.type(input)), - keep_prob - ) + {rand, new_key} = + Nx.Random.uniform(key, 0, 1, shape: opts[:noise_shape], type: Nx.type(input)) - mask = - transform( - {mask, Nx.shape(input)}, - fn {mask, input_shape} -> - if Elixir.Kernel.==(Nx.shape(mask), input_shape), - do: mask, - else: Nx.broadcast(mask, input_shape) - end - ) + rand + |> Nx.less(keep_prob) + |> Nx.broadcast(input) + |> Nx.select(input / keep_prob, Nx.tensor(0, type: Nx.type(input))) + |> dropout_mode_transform(input, new_key, opts[:mode]) + end - out = Nx.select(mask, input / keep_prob, Nx.tensor(0, type: Nx.type(input))) + deftransformp dropout_mode_transform(output, input, new_key, mode) do + case mode do + :train -> + %Axon.StatefulOutput{output: output, state: %{"key" => new_key}} - transform({input, out, opts[:mode]}, fn - {input, _, :inference} -> + :inference -> input - - {_, out, :train} -> - out - end) + end end @doc """ @@ -1662,18 +1655,17 @@ defmodule Axon.Layers do * [Efficient Object Localization Using Convolutional Networks](https://arxiv.org/abs/1411.4280) """ @doc type: :dropout - defn spatial_dropout(input, opts \\ []) do + defn spatial_dropout(input, key, opts \\ []) do assert_min_rank!("Axon.Layers.spatial_dropout", "input", input, 3) - opts = keyword!(opts, [:key, rate: 0.5, channels: :last, mode: :inference]) + opts = keyword!(opts, rate: 0.5, channels: :last, mode: :inference) noise_shape = transform({Nx.shape(input), opts[:channels]}, fn {shape, channels} -> Axon.Shape.spatial_dropout_noise_shape(shape, channels) end) - dropout(input, - key: opts[:key], + dropout(input, key, rate: opts[:rate], noise_shape: noise_shape, mode: opts[:mode] @@ -1701,8 +1693,8 @@ defmodule Axon.Layers do * [Self-Normalizing Neural Networks](https://arxiv.org/abs/1706.02515) """ @doc type: :dropout - defn alpha_dropout(input, opts \\ []) do - opts = keyword!(opts, [:key, rate: 0.5, mode: :inference]) + defn alpha_dropout(input, key, opts \\ []) do + opts = keyword!(opts, rate: 0.5, mode: :inference) rate = opts[:rate] alpha = Nx.tensor(1.6732632423543772848170429916717, type: Nx.type(input)) @@ -1710,11 +1702,9 @@ defmodule Axon.Layers do alpha_p = -alpha * scale keep_prob = Nx.tensor(1, type: Nx.type(input)) - rate - mask = - Nx.less( - Nx.Random.uniform_split(opts[:key], 0, 1, shape: Nx.shape(input), type: Nx.type(input)), - keep_prob - ) + {rand, new_key} = Nx.Random.uniform(key, 0, 1, shape: Nx.shape(input), type: Nx.type(input)) + + mask = Nx.less(rand, keep_prob) a = Nx.rsqrt(keep_prob * Nx.power(Nx.tensor(1, type: Nx.type(input)) * alpha_p, 2)) b = -a * alpha_p * rate @@ -1722,13 +1712,7 @@ defmodule Axon.Layers do x = Nx.select(mask, input, alpha_p) out = a * x + b - transform({input, out, opts[:mode]}, fn - {input, _, :inference} -> - input - - {_, out, :train} -> - out - end) + dropout_mode_transform(out, input, new_key, opts[:mode]) end @doc """ @@ -1749,10 +1733,10 @@ defmodule Axon.Layers do Defaults to shape of input tensor. """ @doc type: :dropout - defn feature_alpha_dropout(input, opts \\ []) do + defn feature_alpha_dropout(input, key, opts \\ []) do assert_min_rank!("Axon.Layers.feature_alpha_dropout", "input", input, 3) - opts = keyword!(opts, [:key, rate: 0.5, channels: :last, mode: :inference]) + opts = keyword!(opts, rate: 0.5, channels: :last, mode: :inference) noise_shape = transform({Nx.shape(input), opts[:channels]}, fn {shape, channels} -> @@ -1761,31 +1745,13 @@ defmodule Axon.Layers do keep_prob = 1 - opts[:rate] - mask = - Nx.less( - Nx.Random.uniform_split(opts[:key], 0, 1, shape: noise_shape, type: Nx.type(input)), - keep_prob - ) + {rand, new_key} = Nx.Random.uniform(key, 0, 1, shape: noise_shape, type: Nx.type(input)) - mask = - transform( - {mask, Nx.shape(input)}, - fn {mask, input_shape} -> - if Elixir.Kernel.==(Nx.shape(mask), input_shape), - do: mask, - else: Nx.broadcast(mask, input_shape) - end - ) - - out = Nx.select(mask, input / keep_prob, Nx.negate(Axon.Activations.selu(input))) - - transform({input, out, opts[:mode]}, fn - {input, _, :inference} -> - input - - {_, out, :train} -> - out - end) + rand + |> Nx.less(keep_prob) + |> Nx.broadcast(input) + |> Nx.select(input / keep_prob, Nx.negate(Axon.Activations.selu(input))) + |> dropout_mode_transform(input, new_key, opts[:mode]) end ## Global Pooling diff --git a/lib/axon/parameter.ex b/lib/axon/parameter.ex index 2c21b672a..4c1c7bded 100644 --- a/lib/axon/parameter.ex +++ b/lib/axon/parameter.ex @@ -1,4 +1,4 @@ defmodule Axon.Parameter do @moduledoc false - defstruct [:id, :name, :shape, :initializer, frozen: false] + defstruct [:id, :name, :shape, :initializer, type: {:f, 32}, frozen: false] end diff --git a/test/axon/compiler_test.exs b/test/axon/compiler_test.exs index 538a66c1f..1cc70d144 100644 --- a/test/axon/compiler_test.exs +++ b/test/axon/compiler_test.exs @@ -1015,14 +1015,59 @@ defmodule CompilerTest do @dropout_layers [:dropout, :feature_alpha_dropout, :spatial_dropout, :alpha_dropout] describe "dropout" do - test "initializes with no params" do + test "initializes with key" do for dropout <- @dropout_layers do - model = apply(Axon, dropout, [Axon.input("input", shape: {nil, 1, 32})]) + model = + apply(Axon, dropout, [ + Axon.input("input", shape: {nil, 1, 32}), + [name: "dropout", seed: 0] + ]) input = Nx.random_uniform({1, 1, 32}) assert {init_fn, _predict_fn} = Axon.build(model) - assert %{} = init_fn.(input, %{}) + assert %{"dropout" => %{"key" => key}} = init_fn.(input, %{}) + assert_equal(key, Nx.Random.key(0)) + end + end + + test "same key results in same mask" do + for dropout <- @dropout_layers do + model = + apply(Axon, dropout, [ + Axon.input("input", shape: {nil, 1, 32}), + [name: "dropout", seed: 0] + ]) + + input = Nx.random_uniform({1, 1, 32}) + + assert {init_fn, predict_fn} = Axon.build(model, mode: :train) + + params = init_fn.(input, %{}) + result1 = predict_fn.(params, input) + result2 = predict_fn.(params, input) + + assert_equal(result1, result2) + end + end + + test "does not return same mask with updated key in training mode" do + for dropout <- @dropout_layers do + model = + apply(Axon, dropout, [ + Axon.input("input", shape: {nil, 32, 32}), + [rate: 0.5, name: "dropout", seed: 0] + ]) + + input = Nx.random_uniform({1, 16, 32}) + + assert {init_fn, predict_fn} = Axon.build(model, mode: :train) + + params = init_fn.(input, %{}) + %{prediction: result1, state: new_state} = predict_fn.(params, input) + %{prediction: result2} = predict_fn.(new_state, input) + + assert_not_equal(result1, result2) end end @@ -1031,8 +1076,8 @@ defmodule CompilerTest do model1 = apply(Axon, dropout, [Axon.input("input", shape: {nil, 1, 32})]) input1 = Nx.random_uniform({1, 1, 32}, type: {:f, 32}) - assert {_, predict_fn} = Axon.build(model1, mode: :train) - %{prediction: result1} = predict_fn.(%{}, input1) + assert {init_fn, predict_fn} = Axon.build(model1, mode: :train) + %{prediction: result1} = predict_fn.(init_fn.(input1, %{}), input1) assert Nx.shape(result1) == {1, 1, 32} assert Nx.type(result1) == {:f, 32} @@ -1041,8 +1086,8 @@ defmodule CompilerTest do model2 = apply(Axon, dropout, [Axon.input("input", shape: {nil, 1, 8, 4})]) input2 = Nx.random_uniform({1, 1, 8, 4}, type: {:f, 32}) - assert {_, predict_fn} = Axon.build(model2, mode: :train) - %{prediction: result2} = predict_fn.(%{}, input2) + assert {init_fn, predict_fn} = Axon.build(model2, mode: :train) + %{prediction: result2} = predict_fn.(init_fn.(input2, %{}), input2) assert Nx.shape(result2) == {1, 1, 8, 4} assert Nx.type(result2) == {:f, 32} @@ -1051,8 +1096,8 @@ defmodule CompilerTest do model3 = apply(Axon, dropout, [Axon.input("input", shape: {nil, 1, 8, 4, 2})]) input3 = Nx.random_uniform({1, 1, 8, 4, 2}, type: {:f, 32}) - assert {_, predict_fn} = Axon.build(model3, mode: :train) - %{prediction: result3} = predict_fn.(%{}, input3) + assert {init_fn, predict_fn} = Axon.build(model3, mode: :train) + %{prediction: result3} = predict_fn.(init_fn.(input3, %{}), input3) assert Nx.shape(result3) == {1, 1, 8, 4, 2} assert Nx.type(result3) == {:f, 32} @@ -1066,9 +1111,9 @@ defmodule CompilerTest do model1 = apply(Axon, dropout, [Axon.input("input", shape: {nil, 1, 32}), opts1]) input1 = Nx.random_uniform({1, 1, 32}, type: {:f, 32}) - assert {_, predict_fn} = Axon.build(model1, mode: :train) + assert {init_fn, predict_fn} = Axon.build(model1, mode: :train) - %{prediction: result} = predict_fn.(%{}, input1) + %{prediction: result} = predict_fn.(init_fn.(input1, %{}), input1) assert Nx.shape(result) == {1, 1, 32} assert Nx.type(result) == {:f, 32} @@ -1094,7 +1139,8 @@ defmodule CompilerTest do model = apply(Axon, dropout, [Axon.input("input", shape: {nil, 1, 32})]) input = Nx.random_uniform({1, 1, 32}) - assert_equal(Axon.predict(model, %{}, input), input) + {init_fn, predict_fn} = Axon.build(model) + assert_equal(predict_fn.(init_fn.(input, %{}), input), input) end end end diff --git a/test/axon/layers_test.exs b/test/axon/layers_test.exs index 8f40ba0c4..2781145cf 100644 --- a/test/axon/layers_test.exs +++ b/test/axon/layers_test.exs @@ -736,7 +736,7 @@ defmodule Axon.LayersTest do assert_raise ArgumentError, ~r/Axon.Layers.spatial_dropout: expected input shape to have at least rank 3/, fn -> - Axon.Layers.spatial_dropout(inp) + Axon.Layers.spatial_dropout(inp, Nx.Random.key(0)) end end end @@ -748,7 +748,7 @@ defmodule Axon.LayersTest do assert_raise ArgumentError, ~r/Axon.Layers.feature_alpha_dropout: expected input shape to have at least rank 3/, fn -> - Axon.Layers.feature_alpha_dropout(inp) + Axon.Layers.feature_alpha_dropout(inp, Nx.Random.key(0)) end end end