From 333f22c21e931ab899543c2e623746d6ab3f24f6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Valim?= Date: Tue, 27 Dec 2022 11:35:58 -0300 Subject: [PATCH 1/2] Only compute random key on init Also changed Axon.build/2 to require a `:seed` to normalize with other Axon APIs. --- lib/axon.ex | 2 +- lib/axon/compiler.ex | 9 +++++---- lib/axon/loop.ex | 41 +++++++++-------------------------------- 3 files changed, 15 insertions(+), 37 deletions(-) diff --git a/lib/axon.ex b/lib/axon.ex index c8ae9fb1a..60738c3cb 100644 --- a/lib/axon.ex +++ b/lib/axon.ex @@ -2804,7 +2804,7 @@ defmodule Axon do "#{parent_name}_#{state_name}_hidden_state" end - fun = fn inputs, key, opts -> + fun = fn inputs, key, _opts -> shape = Axon.Shape.rnn_hidden_state(Nx.shape(inputs), units, rnn_type) case initializer do diff --git a/lib/axon/compiler.ex b/lib/axon/compiler.ex index ddc103508..d52b2b8ed 100644 --- a/lib/axon/compiler.ex +++ b/lib/axon/compiler.ex @@ -33,7 +33,7 @@ defmodule Axon.Compiler do def build(%Axon{output: id, nodes: nodes}, opts) do debug? 
= Keyword.get(opts, :debug, false) mode = Keyword.get(opts, :mode, :inference) - key = Keyword.get_lazy(opts, :key, fn -> Nx.Random.key(:erlang.system_time()) end) + seed = Keyword.get_lazy(opts, :seed, fn -> :erlang.system_time() end) config = %{mode: mode, debug?: debug?} {time, {root_id, {cache, _op_counts}}} = @@ -98,7 +98,6 @@ defmodule Axon.Compiler do end init_cache = Map.new(cache, fn {_, {int_id, funs}} -> {int_id, funs} end) - key = Nx.backend_copy(key, Nx.BinaryBackend) init_fun = fn template, init_params -> {:current_stacktrace, [_process_info, _fn | stacktrace]} = @@ -106,7 +105,7 @@ defmodule Axon.Compiler do {time, params} = :timer.tc(fn -> - param_keys = get_keys(nodes, key) + param_keys = get_keys(nodes, seed) {_, {params, _}} = init_cache[root_id][:init].(template, init_cache, %{}, stacktrace, param_keys) @@ -126,7 +125,7 @@ defmodule Axon.Compiler do {init_fun, predict_fun} end - defp get_keys(nodes, key) do + defp get_keys(nodes, seed) do {ids_and_data, _op_counts} = Enum.reduce(nodes, {[], %{}}, fn {_, %Axon.Node{id: id, op: op, name: name_fn, parameters: params}}, {keys, op_counts} -> @@ -163,6 +162,8 @@ defmodule Axon.Compiler do %{} [_ | _] = ids -> + key = Nx.Random.key(seed) + keys_tensor = data |> Nx.tensor(type: :u32) diff --git a/lib/axon/loop.ex b/lib/axon/loop.ex index 6aec9b0d4..93cab57e9 100644 --- a/lib/axon/loop.ex +++ b/lib/axon/loop.ex @@ -326,11 +326,10 @@ defmodule Axon.Loop do def train_step(model, loss, optimizer, opts \\ []) do opts = Keyword.validate!(opts, [:seed, loss_scale: :identity, gradient_accumulation_steps: 1]) - seed = opts[:seed] loss_scale = opts[:loss_scale] || :identity gradient_accumulation_steps = opts[:gradient_accumulation_steps] || 1 - {init_model_fn, forward_model_fn} = build_model_fns(model, :train, seed) + {init_model_fn, forward_model_fn} = build_model_fns(model, :train, opts) loss_fn = build_loss_fn(loss) {init_optimizer_fn, update_optimizer_fn} = build_optimizer_fns(optimizer) {init_loss_scale, 
scale_loss, unscale_grads} = build_loss_scale_fns(loss_scale) @@ -507,7 +506,7 @@ defmodule Axon.Loop do single evaluation step. """ def eval_step(model) do - {_, forward_model_fn} = build_model_fns(model, :inference, nil) + {_, forward_model_fn} = build_model_fns(model, :inference, []) init_fn = fn {inp, tar}, state -> @@ -688,29 +687,14 @@ defmodule Axon.Loop do between steps, increasing the effective batch size on smaller devices. Defaults to 1. """ def trainer(model, loss, optimizer, opts \\ []) do - opts = - Keyword.validate!(opts, [ - :seed, - log: 50, - loss_scale: :identity, - gradient_accumulation_steps: 1 - ]) - - log_interval = opts[:log] || 50 - gradient_accumulation_steps = opts[:gradient_accumulation_steps] || 1 - loss_scale = opts[:loss_scale] || :identity - seed = opts[:seed] + opts = Keyword.validate!(opts, [:seed, :loss_scale, :gradient_accumulation_steps, log: 50]) # Build loss now so we can use it as a metric loss_fn = build_loss_fn(loss) + step_opts = Keyword.take(opts, [:gradient_accumulation_steps, :loss_scale, :seed]) + {init_fn, step_fn} = train_step(model, loss_fn, optimizer, step_opts) - {init_fn, step_fn} = - train_step(model, loss_fn, optimizer, - loss_scale: loss_scale, - gradient_accumulation_steps: gradient_accumulation_steps, - seed: seed - ) - + log_interval = opts[:log] || 50 output_transform = fn state -> state.step_state[:model_state] end loop = @@ -2002,18 +1986,11 @@ defmodule Axon.Loop do # a tuple of Axon structs, or a tuple of init / forward # functions. Model functions are essentially just model # init / apply functions. 
- defp build_model_fns(%Axon{} = model, mode, seed) do - opts = - if seed != nil do - [mode: mode, key: Nx.Random.key(seed)] - else - [mode: mode] - end - - Axon.build(model, opts) + defp build_model_fns(%Axon{} = model, mode, opts) do + Axon.build(model, [mode: mode] ++ opts) end - defp build_model_fns({init_fn, forward_fn}, _, _seed) + defp build_model_fns({init_fn, forward_fn}, _, _opts) when is_function(init_fn, 2) and is_function(forward_fn, 2) do {init_fn, forward_fn} end From 51128809e4142f4f307e54c9e5a761b633967ca0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jonatan=20K=C5=82osko?= Date: Tue, 27 Dec 2022 16:31:17 +0100 Subject: [PATCH 2/2] Same when building Axon layers --- lib/axon.ex | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/lib/axon.ex b/lib/axon.ex index 60738c3cb..45845a60b 100644 --- a/lib/axon.ex +++ b/lib/axon.ex @@ -1415,7 +1415,6 @@ defmodule Axon do defp dropout(%Axon{} = x, dropout, opts) do opts = Keyword.validate!(opts, [:name, :seed, rate: 0.5]) seed = Keyword.get_lazy(opts, :seed, fn -> :erlang.system_time() end) - key = Nx.Random.key(seed) |> Nx.backend_copy(Nx.BinaryBackend) if opts[:rate] < 0 or opts[:rate] >= 1 do raise ArgumentError, @@ -1423,7 +1422,10 @@ defmodule Axon do end key_state = - param("key", fn _ -> Nx.shape(key) end, type: {:u, 32}, initializer: fn _, _ -> key end) + param("key", fn _ -> {2} end, + type: {:u, 32}, + initializer: fn _, _ -> Nx.Random.key(seed) end + ) layer(dropout, [x, key_state], name: opts[:name], @@ -2788,10 +2790,12 @@ defmodule Axon do defp rnn_state(x, units, rnn_type, parent_name, state_name, initializer, seed) do initializer = initializer || :glorot_uniform - key = Nx.Random.key(seed) |> Nx.backend_copy(Nx.BinaryBackend) key_state = - param("key", fn _ -> Nx.shape(key) end, type: {:u, 32}, initializer: fn _, _ -> key end) + param("key", fn _ -> {2} end, + type: {:u, 32}, + initializer: fn _, _ -> Nx.Random.key(seed) end + ) name = case parent_name do