
Time to complete an epoch depends on the number of epochs #430

@nickgnd


Hi 👋

This is the second issue I'm opening in Axon in a short period of time. I hope I'm not misusing the medium; I'm learning ML, so there's a high chance I'm doing something funky. Please bear with me.

Again, thanks for all the work you are doing. It's great, and I'm enjoying learning ML with the Nx* libraries so far.


The issue

While playing with Axon I noticed that training slows down as the number of epochs increases. I'm not referring to the overall time, which of course grows as expected, but to the time to complete a single epoch.

Here is a quick benchmark I put together.

Mix.install(
  [
    {:exla, "~> 0.4"},
    {:nx, "~> 0.4"},
    {:axon, "~> 0.3.1"},
    {:benchee, "~> 1.1.0"}
  ],
  config: [nx: [default_backend: EXLA.Backend]]
)

# Generate the data

# 2 inputs
inputs =
  Nx.iota({9000, 2}, type: :f32)
  |> Nx.divide(9000)
  |> Nx.subtract(0.5)
  |> Nx.shuffle()

# one-hot encode the labels
labels =
  Enum.map(0..8999, fn _ -> Enum.random([0, 1]) end)
  |> Nx.tensor()
  |> Nx.reshape({:auto, 1})
  |> Nx.equal(Nx.tensor([0, 1]))

# split the dataset in batches
batch_size = 250
inputs_batches = Nx.to_batched(inputs, batch_size)
labels_batches = Nx.to_batched(labels, batch_size)
train_batches = Stream.zip(inputs_batches, labels_batches)

defmodule CustomEventHandler do
  def maybe_exit(%Axon.Loop.State{epoch: epoch} = state) do
    if epoch > 100 do
      IO.puts("Early exit!")
      {:halt_loop, state}
    else
      {:continue, state}
    end
  end
end

# Create a loop that early exits after the 100th epoch
loop =
  Axon.input("data")
  |> Axon.dense(100, activation: :sigmoid)
  |> Axon.dense(2, activation: :softmax)
  |> Axon.Loop.trainer(:categorical_cross_entropy, :adam)
  |> Axon.Loop.handle(:epoch_started, &CustomEventHandler.maybe_exit/1)

Benchee.run(
  %{
    "100" => fn -> Axon.Loop.run(loop, train_batches, %{}, epochs: 100, compiler: EXLA) end,
    "1000" => fn -> Axon.Loop.run(loop, train_batches, %{}, epochs: 1000, compiler: EXLA) end,
    "10_000" => fn -> Axon.Loop.run(loop, train_batches, %{}, epochs: 10_000, compiler: EXLA) end
  },
  time: 4,
  memory_time: 2
)

And here are the results:

Name             ips        average  deviation         median         99th %
100             0.39         2.58 s     ±0.32%         2.58 s         2.59 s
1000            0.25         4.04 s     ±0.00%         4.04 s         4.04 s
10_000        0.0559        17.90 s     ±0.00%        17.90 s        17.90 s

Comparison:
100             0.39
1000            0.25 - 1.56x slower +1.46 s
10_000        0.0559 - 6.92x slower +15.31 s

Memory usage statistics:

Name      Memory usage
100          246.69 MB
1000         250.49 MB - 1.02x memory usage +3.80 MB
10_000       263.74 MB - 1.07x memory usage +17.05 MB

**All measurements for memory usage were the same**

Given that the training always exits early after the 100th epoch, I'd expect comparable results; instead there is a remarkable slowdown when the max number of epochs is set to 10,000. Is this expected?

Thanks in advance, and please let me know if there is anything else I can do for you 🙇‍♂️

Best,
Nicolò
