Description
Hi 👋
this is the second issue I've opened in Axon in a short period of time. I hope I'm not misusing the medium; I'm learning ML, so there is a high chance I'm doing something funky. Please bear with me.
Again, thanks for all the work you are doing. It's great, and I'm enjoying learning ML with the Nx* libraries so far.
The issue
While playing with Axon I noticed that training slows down as the configured number of epochs increases. I'm not referring to the overall time, which of course grows, but to the time needed to complete a single epoch.
Here is a quick benchmark I assembled:
Mix.install(
  [
    {:exla, "~> 0.4"},
    {:nx, "~> 0.4"},
    {:axon, "~> 0.3.1"},
    {:benchee, "~> 1.1.0"}
  ],
  config: [nx: [default_backend: EXLA.Backend]]
)
# Generate the data
# 2 inputs
inputs =
  Nx.iota({9000, 2}, type: :f32)
  |> Nx.divide(9000)
  |> Nx.subtract(0.5)
  |> Nx.shuffle()
# one-hot encode the labels
labels =
  Enum.map(0..8999, fn _ -> Enum.random([0, 1]) end)
  |> Nx.tensor()
  |> Nx.reshape({:auto, 1})
  |> Nx.equal(Nx.tensor([0, 1]))
# split the dataset in batches
batch_size = 250
inputs_batches = Nx.to_batched(inputs, batch_size)
labels_batches = Nx.to_batched(labels, batch_size)
train_batches = Stream.zip(inputs_batches, labels_batches)
defmodule CustomEventHandler do
  def maybe_exit(%Axon.Loop.State{epoch: epoch} = state) do
    if epoch > 100 do
      IO.puts("Early exit!")
      {:halt_loop, state}
    else
      {:continue, state}
    end
  end
end
# Create a loop that early exits after the 100th epoch
loop =
  Axon.input("data")
  |> Axon.dense(100, activation: :sigmoid)
  |> Axon.dense(2, activation: :softmax)
  |> Axon.Loop.trainer(:categorical_cross_entropy, :adam)
  |> Axon.Loop.handle(:epoch_started, &CustomEventHandler.maybe_exit/1)
Benchee.run(
  %{
    "100" => fn -> Axon.Loop.run(loop, train_batches, %{}, epochs: 100, compiler: EXLA) end,
    "1000" => fn -> Axon.Loop.run(loop, train_batches, %{}, epochs: 1000, compiler: EXLA) end,
    "10_000" => fn -> Axon.Loop.run(loop, train_batches, %{}, epochs: 10_000, compiler: EXLA) end
  },
  time: 4,
  memory_time: 2
)

And here are the results:
Name       ips      average   deviation   median    99th %
100        0.39     2.58 s    ±0.32%      2.58 s    2.59 s
1000       0.25     4.04 s    ±0.00%      4.04 s    4.04 s
10_000     0.0559   17.90 s   ±0.00%      17.90 s   17.90 s

Comparison:
100        0.39
1000       0.25     - 1.56x slower +1.46 s
10_000     0.0559   - 6.92x slower +15.31 s

Memory usage statistics:

Name       Memory usage
100        246.69 MB
1000       250.49 MB - 1.02x memory usage +3.80 MB
10_000     263.74 MB - 1.07x memory usage +17.05 MB
**All measurements for memory usage were the same**
Given that training always exits early after the 100th epoch, I'd expect comparable results for the three runs; instead there is a remarkable bump when the max epoch count is set to 10_000. Is this expected?
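In case it helps with triage, here is a rough per-epoch timing sketch of my own (the `EpochTimer` module and its process-dictionary bookkeeping are hypothetical diagnostic helpers, not anything from the Axon docs). It attaches a second handler to the same loop and prints how long each epoch took:

defmodule EpochTimer do
  # Diagnostic hack: keep the previous timestamp in the process
  # dictionary and print the elapsed wall time for each epoch.
  def log_epoch_time(%Axon.Loop.State{epoch: epoch} = state) do
    now = System.monotonic_time(:millisecond)

    case Process.put(:last_epoch_ts, now) do
      nil -> :ok
      last -> IO.puts("epoch #{epoch} took #{now - last} ms")
    end

    {:continue, state}
  end
end

# Attach to the loop above before running it:
# loop = Axon.Loop.handle(loop, :epoch_completed, &EpochTimer.log_epoch_time/1)

If the individual epoch times themselves grow with the configured :epochs value, that would confirm the slowdown is per-epoch rather than one-off startup or compilation overhead.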
Thanks in advance, and please let me know if there is anything else I can do for you 🙇‍♂️
Best,
Nicolò