Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions lib/axon.ex
Original file line number Diff line number Diff line change
Expand Up @@ -1415,15 +1415,17 @@ defmodule Axon do
defp dropout(%Axon{} = x, dropout, opts) do
opts = Keyword.validate!(opts, [:name, :seed, rate: 0.5])
seed = Keyword.get_lazy(opts, :seed, fn -> :erlang.system_time() end)
key = Nx.Random.key(seed) |> Nx.backend_copy(Nx.BinaryBackend)

if opts[:rate] < 0 or opts[:rate] >= 1 do
raise ArgumentError,
"The dropout rate needs to be >= 0 and < 1, got #{inspect(opts[:rate])}"
end

key_state =
param("key", fn _ -> Nx.shape(key) end, type: {:u, 32}, initializer: fn _, _ -> key end)
param("key", fn _ -> {2} end,
type: {:u, 32},
initializer: fn _, _ -> Nx.Random.key(seed) end
)

layer(dropout, [x, key_state],
name: opts[:name],
Expand Down Expand Up @@ -2788,10 +2790,12 @@ defmodule Axon do

defp rnn_state(x, units, rnn_type, parent_name, state_name, initializer, seed) do
initializer = initializer || :glorot_uniform
key = Nx.Random.key(seed) |> Nx.backend_copy(Nx.BinaryBackend)

key_state =
param("key", fn _ -> Nx.shape(key) end, type: {:u, 32}, initializer: fn _, _ -> key end)
param("key", fn _ -> {2} end,
type: {:u, 32},
initializer: fn _, _ -> Nx.Random.key(seed) end
)

name =
case parent_name do
Expand All @@ -2804,7 +2808,7 @@ defmodule Axon do
"#{parent_name}_#{state_name}_hidden_state"
end

fun = fn inputs, key, opts ->
fun = fn inputs, key, _opts ->
shape = Axon.Shape.rnn_hidden_state(Nx.shape(inputs), units, rnn_type)

case initializer do
Expand Down
9 changes: 5 additions & 4 deletions lib/axon/compiler.ex
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ defmodule Axon.Compiler do
def build(%Axon{output: id, nodes: nodes}, opts) do
debug? = Keyword.get(opts, :debug, false)
mode = Keyword.get(opts, :mode, :inference)
key = Keyword.get_lazy(opts, :key, fn -> Nx.Random.key(:erlang.system_time()) end)
seed = Keyword.get_lazy(opts, :seed, fn -> :erlang.system_time() end)
config = %{mode: mode, debug?: debug?}

{time, {root_id, {cache, _op_counts}}} =
Expand Down Expand Up @@ -98,15 +98,14 @@ defmodule Axon.Compiler do
end

init_cache = Map.new(cache, fn {_, {int_id, funs}} -> {int_id, funs} end)
key = Nx.backend_copy(key, Nx.BinaryBackend)

init_fun = fn template, init_params ->
{:current_stacktrace, [_process_info, _fn | stacktrace]} =
Process.info(self(), :current_stacktrace)

{time, params} =
:timer.tc(fn ->
param_keys = get_keys(nodes, key)
param_keys = get_keys(nodes, seed)

{_, {params, _}} =
init_cache[root_id][:init].(template, init_cache, %{}, stacktrace, param_keys)
Expand All @@ -126,7 +125,7 @@ defmodule Axon.Compiler do
{init_fun, predict_fun}
end

defp get_keys(nodes, key) do
defp get_keys(nodes, seed) do
{ids_and_data, _op_counts} =
Enum.reduce(nodes, {[], %{}}, fn
{_, %Axon.Node{id: id, op: op, name: name_fn, parameters: params}}, {keys, op_counts} ->
Expand Down Expand Up @@ -163,6 +162,8 @@ defmodule Axon.Compiler do
%{}

[_ | _] = ids ->
key = Nx.Random.key(seed)

keys_tensor =
data
|> Nx.tensor(type: :u32)
Expand Down
41 changes: 9 additions & 32 deletions lib/axon/loop.ex
Original file line number Diff line number Diff line change
Expand Up @@ -326,11 +326,10 @@ defmodule Axon.Loop do
def train_step(model, loss, optimizer, opts \\ []) do
opts = Keyword.validate!(opts, [:seed, loss_scale: :identity, gradient_accumulation_steps: 1])

seed = opts[:seed]
loss_scale = opts[:loss_scale] || :identity
gradient_accumulation_steps = opts[:gradient_accumulation_steps] || 1

{init_model_fn, forward_model_fn} = build_model_fns(model, :train, seed)
{init_model_fn, forward_model_fn} = build_model_fns(model, :train, opts)
loss_fn = build_loss_fn(loss)
{init_optimizer_fn, update_optimizer_fn} = build_optimizer_fns(optimizer)
{init_loss_scale, scale_loss, unscale_grads} = build_loss_scale_fns(loss_scale)
Expand Down Expand Up @@ -507,7 +506,7 @@ defmodule Axon.Loop do
single evaluation step.
"""
def eval_step(model) do
{_, forward_model_fn} = build_model_fns(model, :inference, nil)
{_, forward_model_fn} = build_model_fns(model, :inference, [])

init_fn = fn
{inp, tar}, state ->
Expand Down Expand Up @@ -688,29 +687,14 @@ defmodule Axon.Loop do
between steps, increasing the effective batch size on smaller devices. Defaults to 1.
"""
def trainer(model, loss, optimizer, opts \\ []) do
opts =
Keyword.validate!(opts, [
:seed,
log: 50,
loss_scale: :identity,
gradient_accumulation_steps: 1
])

log_interval = opts[:log] || 50
gradient_accumulation_steps = opts[:gradient_accumulation_steps] || 1
loss_scale = opts[:loss_scale] || :identity
seed = opts[:seed]
opts = Keyword.validate!(opts, [:seed, :loss_scale, :gradient_accumulation_steps, log: 50])

# Build loss now so we can use it as a metric
loss_fn = build_loss_fn(loss)
step_opts = Keyword.take(opts, [:gradient_accumulation_steps, :loss_scale, :seed])
{init_fn, step_fn} = train_step(model, loss_fn, optimizer, step_opts)

{init_fn, step_fn} =
train_step(model, loss_fn, optimizer,
loss_scale: loss_scale,
gradient_accumulation_steps: gradient_accumulation_steps,
seed: seed
)

log_interval = opts[:log] || 50
output_transform = fn state -> state.step_state[:model_state] end

loop =
Expand Down Expand Up @@ -2002,18 +1986,11 @@ defmodule Axon.Loop do
# a tuple of Axon structs, or a tuple of init / forward
# functions. Model functions are essentially just model
# init / apply functions.
defp build_model_fns(%Axon{} = model, mode, seed) do
opts =
if seed != nil do
[mode: mode, key: Nx.Random.key(seed)]
else
[mode: mode]
end

Axon.build(model, opts)
# Builds init/forward functions from an `%Axon{}` model. The compile mode is
# prepended to `opts` so an explicit `:mode` entry here takes lookup precedence
# over anything the caller passed through.
defp build_model_fns(%Axon{} = model, mode, opts) do
  Axon.build(model, [{:mode, mode} | opts])
end

defp build_model_fns({init_fn, forward_fn}, _, _seed)
# Pass-through clause: the caller already supplied prebuilt init/forward
# functions (each arity-2), so there is nothing to build — return them as-is.
defp build_model_fns({init_fun, forward_fun}, _mode, _opts)
     when is_function(forward_fun, 2) and is_function(init_fun, 2) do
  {init_fun, forward_fun}
end
Expand Down