
Provide deterministic builds #427

Merged (11 commits, Dec 21, 2022)
Conversation

josevalim (Contributor) commented on Dec 13, 2022:

  • Parameter IDs were removed
  • Dropouts are completely removed from the network via a new :mode option (see the sketch after this list)
  • Freezing traverses the nodes directly without relying on IDs
  • Removed almost all usage of backend_copy(Nx.Defn.Expr)
  • We use integers as cache keys after the cache is built
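
A minimal sketch of the new :mode option (hedged against the Axon docs, not copied from this PR's diff): when a model is built for inference, dropout layers are pruned from the graph entirely.

```elixir
model =
  Axon.input("features", shape: {nil, 8})
  |> Axon.dense(16, activation: :relu)
  |> Axon.dropout(rate: 0.5)
  |> Axon.dense(1)

# Build for training: dropout stays in the computation.
{init_fn, train_predict_fn} = Axon.build(model, mode: :train)

# Build for inference (the default): dropout is removed from the
# graph, so the build is deterministic and dropout costs nothing.
{_init_fn, predict_fn} = Axon.build(model, mode: :inference)
```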

lib/axon.ex: outdated review thread (resolved)
josevalim changed the title from "Initial work on deterministic builds" to "Provide deterministic builds" on Dec 13, 2022
```elixir
)

# Names are computed lazily, so compute name from current
# op and aggregate op_counts.
name = name_fn.(op_name, op_counts)
op_counts = Map.update(op_counts, op_name, 1, fn x -> x + 1 end)

# TODO: Hack for dropout with key, fix with a better implementation
```
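
For illustration, a hypothetical name_fn in the spirit of the excerpt above (this is a stand-in, not the PR's actual implementation): deriving the name from the aggregate op counts makes naming deterministic for a fixed traversal order.

```elixir
# Hypothetical default name function: op name plus the count seen so far.
name_fn = fn op_name, op_counts ->
  "#{op_name}_#{Map.get(op_counts, op_name, 0)}"
end

op_counts = %{}
name = name_fn.(:dense, op_counts)
# => "dense_0"

op_counts = Map.update(op_counts, :dense, 1, fn x -> x + 1 end)
name = name_fn.(:dense, op_counts)
# => "dense_1"
```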
josevalim (Contributor, Author) commented:

Perhaps it should not be based on the name but I would say it should be based on the key. Thoughts?

```diff
  # Compute arguments to be forwarded and ensure `:mode` is included
  # for inference/training behavior dependent functions
- args = Enum.reverse(tensor_inputs) ++ [Keyword.put(opts, :mode, mode)]
+ args = Enum.reverse(tensor_inputs, [Keyword.put(opts, :mode, mode)])
```
josevalim (Contributor, Author) commented:

Enum.reverse(list, tail) is an efficient version of Enum.reverse(list) ++ tail.
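
For example (standard Elixir; both forms build the same list, but reverse/2 prepends onto the tail directly instead of reversing and then copying the left operand for ++):

```elixir
iex> Enum.reverse([1, 2, 3]) ++ [4, 5]
[3, 2, 1, 4, 5]

iex> Enum.reverse([1, 2, 3], [4, 5])
[3, 2, 1, 4, 5]
```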

```diff
-      {_, out, :train} ->
-        out
-    end)
+    Nx.select(mask, input / keep_prob, Nx.tensor(0, type: Nx.type(input)))
```
josevalim (Contributor, Author) commented:

I removed the mode check from dropout because it is no longer relevant.
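
For reference, a minimal sketch of inverted dropout outside of Axon (plain Nx; the names and key handling here are illustrative, not this PR's implementation):

```elixir
key = Nx.Random.key(42)
rate = 0.5
keep_prob = 1.0 - rate
input = Nx.iota({2, 4}, type: :f32)

# Draw uniform noise and keep each element with probability keep_prob.
{noise, _key} = Nx.Random.uniform(key, shape: Nx.shape(input))
mask = Nx.less(noise, keep_prob)

# Scale kept activations by 1 / keep_prob so the expected value of the
# output matches the input ("inverted" dropout).
out = Nx.select(mask, Nx.divide(input, keep_prob), Nx.tensor(0, type: Nx.type(input)))
```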

lib/axon/losses.ex: outdated review thread (resolved)
```diff
@@ -114,7 +114,7 @@ defmodule CompilerTest do
   x2 = Axon.dense(input, 64)
   model = Axon.add(x1, x2)

-  {init_fn, _predict_fn} = Axon.build(model)
+  {init_fn, _predict_fn} = Axon.build(model, debug: true)
```
josevalim (Contributor, Author) commented:
We only get fn stacktraces with debug: true now.

lib/axon.ex (outdated diff):
```diff
@@ -2722,7 +2721,8 @@ defmodule Axon do

   defp rnn_state(x, units, rnn_type, parent_name, state_name, initializer) do
     initializer = initializer || :glorot_uniform
     key = Nx.Random.key(:erlang.system_time()) |> Nx.backend_copy(Nx.Defn.Expr)
+    # TODO: This key should be managed by the compiler
```
josevalim (Contributor, Author) commented:

Similar to dropout.
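
For context, a minimal sketch of what compiler-managed keys could look like (hypothetical, assuming one root key supplied by the caller rather than :erlang.system_time/0): each stateful layer gets its own subkey via Nx.Random.split, so the same seed reproduces the same initial state on every build.

```elixir
# Hypothetical: one deterministic root key for the whole build.
root_key = Nx.Random.key(1701)

# Split into per-layer subkeys instead of reading the system clock.
keys = Nx.Random.split(root_key, parts: 2)
rnn_key = keys[0]
dropout_key = keys[1]
```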

josevalim (Contributor, Author) commented:

This is good to go!

lib/axon/losses.ex: outdated review threads (resolved)