diff --git a/lib/axon/loop.ex b/lib/axon/loop.ex index 56b6c39b3..4e18c5f65 100644 --- a/lib/axon/loop.ex +++ b/lib/axon/loop.ex @@ -1868,7 +1868,9 @@ defmodule Axon.Loop do Logger.debug("Axon.Loop fired event #{inspect(event)}") end - if filter.(state) do + state = update_counts(state, event) + + if filter.(state, event) do case handler.(state) do {:continue, %State{} = state} -> if debug? do @@ -1908,6 +1910,10 @@ defmodule Axon.Loop do end) end + defp update_counts(%State{event_counts: event_counts} = state, event) do + %{state | event_counts: Map.update(event_counts, event, 1, fn x -> x + 1 end)} + end + # Halts an epoch during looping defp halt_epoch(handler_fns, batch_fn, final_metrics_map, loop_state, debug?) do case fire_event(:epoch_halted, handler_fns, loop_state, debug?) do @@ -2155,28 +2161,42 @@ defmodule Axon.Loop do # Builds a filter function from an atom, keyword list, or function. A # valid filter is an atom which matches on of the valid predicates `:always` # or `:once`, a keyword which matches one of the valid predicate-value pairs - # such as `every: N`, or a function which takes loop state and returns `true` - # or `false`. - # - # TODO(seanmor5): In order to handle custom events and predicate filters, - # we will need to track event firings in the loop state. + # such as `every: N`, or a function which takes loop state and the current event + # and returns `true` to run the handler of `false` to avoid it. defp build_filter_fn(filter) do case filter do :always -> - fn _ -> true end + fn _, _ -> true end :first -> - fn - %State{epoch: 0, iteration: 0} -> true - _ -> false + fn %State{event_counts: counts}, event -> + counts[event] == 1 end - [{:every, n} | _] -> - fn %State{iteration: iter} -> - Kernel.rem(iter, n) == 0 - end + filters when is_list(filters) -> + Enum.reduce(filters, fn _, _ -> true end, fn + {:every, n}, acc -> + fn state, event -> + acc.(state, event) and filter_every_n(state, event, n) + end + + {:before, n}, acc -> + fn state, event -> + acc.(state, event) and filter_before_n(state, event, n) + end + + {:after, n}, acc -> + fn state, event -> + acc.(state, event) and filter_after_n(state, event, n) + end + + {:once, n}, acc -> + fn state, event -> + acc.(state, event) and filter_once_n(state, event, n) + end + end) - fun when is_function(fun, 1) -> + fun when is_function(fun, 2) -> fun invalid -> @@ -2184,11 +2204,27 @@ defmodule Axon.Loop do "Invalid filter #{inspect(invalid)}, a valid filter" <> " is an atom which matches a valid filter predicate" <> " such as :always or :once, a keyword of predicate-value" <> - " pairs such as every: N, or an arity-1 function which takes" <> - " loop state and returns true or false" + " pairs such as every: N, or an arity-2 function which takes" <> + " loop state and current event and returns true or false" end end + defp filter_every_n(%State{event_counts: counts}, event, n) do + rem(counts[event] - 1, n) == 0 + end + + defp filter_after_n(%State{event_counts: counts}, event, n) do + counts[event] > n + end + + defp filter_before_n(%State{event_counts: counts}, event, n) do + counts[event] < n + end + + defp filter_once_n(%State{event_counts: counts}, event, n) do + counts[event] == n + end + # JIT-compiles the given function if jit_compile? is true # otherwise just applies the function with the given arguments defp maybe_jit(fun, args, jit_compile?, jit_opts) do diff --git a/lib/axon/loop/state.ex b/lib/axon/loop/state.ex index f0dfa1e8e..09782bb03 100644 --- a/lib/axon/loop/state.ex +++ b/lib/axon/loop/state.ex @@ -39,6 +39,9 @@ defmodule Axon.Loop.State do `handler_metadata` is a metadata field for storing loop handler metadata. For example, loop checkpoints with specific metric criteria can store previous best metrics in the handler meta for use between iterations. + + `event_counts` is a metadata field which stores information about the number + of times each event has been fired. This is useful when creating custom filters. """ @enforce_keys [:step_state] defstruct [ @@ -49,6 +52,16 @@ defmodule Axon.Loop.State do iteration: 0, max_iteration: -1, metrics: %{}, - times: %{} + times: %{}, + event_counts: %{ + started: 0, + epoch_started: 0, + iteration_started: 0, + iteration_completed: 0, + epoch_completed: 0, + epoch_halted: 0, + halted: 0, + completed: 0 + } ] end diff --git a/test/axon/loop_test.exs b/test/axon/loop_test.exs index 8ca0b3779..ec0316cd4 100644 --- a/test/axon/loop_test.exs +++ b/test/axon/loop_test.exs @@ -645,7 +645,7 @@ defmodule Axon.LoopTest do end) model - |> Axon.Loop.trainer(:binary_cross_entropy, :sgd) + |> Axon.Loop.trainer(:binary_cross_entropy, :sgd, log: -1) |> send_handler(event, filter) |> Axon.Loop.run(data, %{}, epochs: epochs, iterations: iterations) end @@ -663,9 +663,7 @@ defmodule Axon.LoopTest do end test "supports an :always filter" do - ExUnit.CaptureIO.capture_io(fn -> - run_dummy_loop!(:iteration_started, :always, 5, 10) - end) + run_dummy_loop!(:iteration_started, :always, 5, 10) for _ <- 1..50 do assert_received :iteration_started @@ -675,15 +673,97 @@ defmodule Axon.LoopTest do end test "supports an every: n filter" do - ExUnit.CaptureIO.capture_io(fn -> - run_dummy_loop!(:iteration_started, [every: 2], 5, 10) - end) + run_dummy_loop!(:iteration_started, [every: 2], 5, 10) for _ <- 1..25 do assert_received :iteration_started end refute_received :iteration_started + + run_dummy_loop!(:iteration_completed, [every: 3], 3, 10) + + for _ <- 1..10 do + assert_received :iteration_completed + end + + refute_received :iteration_completed + end + + test "supports after: n filter" do + run_dummy_loop!(:iteration_started, [after: 10], 5, 10) + + for _ <- 1..40 do + assert_received :iteration_started + end + + refute_received :iteration_started + + run_dummy_loop!(:iteration_completed, [after: 10], 5, 10) + + for _ <- 1..40 do + assert_received :iteration_completed + end + + refute_received :iteration_completed + end + + test "supports before: n filter" do + run_dummy_loop!(:iteration_started, [before: 10], 5, 10) + + for _ <- 1..9 do + assert_received :iteration_started + end + + refute_received :iteration_started + + run_dummy_loop!(:iteration_completed, [before: 10], 5, 10) + + for _ <- 1..9 do + assert_received :iteration_completed + end + + refute_received :iteration_completed + end + + test "supports once: n filter" do + run_dummy_loop!(:iteration_started, [once: 30], 5, 10) + + assert_received :iteration_started + refute_received :iteration_started + + run_dummy_loop!(:iteration_completed, [once: 30], 5, 10) + + assert_received :iteration_completed + refute_received :iteration_completed + end + + test "supports hybrid filter" do + run_dummy_loop!(:iteration_started, [every: 2, after: 10, before: 40], 5, 10) + + for _ <- 1..15 do + assert_received :iteration_started + end + + refute_received :iteration_started + end + + test "supports :first filter" do + run_dummy_loop!(:iteration_started, :first, 5, 10) + + assert_received :iteration_started + refute_received :iteration_started + end + + test "supports function filter" do + fun = fn + %{event_counts: counts}, event -> counts[event] == 5 + end + + run_dummy_loop!(:iteration_started, fun, 5, 10) + + assert_received :iteration_started + refute_received :iteration_started end end @@ -814,7 +894,7 @@ defmodule Axon.LoopTest do assert Map.has_key?(metrics, "validation_accuracy") {:continue, state} end, - fn %{epoch: epoch} -> epoch == 1 end + fn %{epoch: epoch}, _ -> epoch == 1 end ) |> Axon.Loop.run(data, %{}, epochs: 5, iterations: 5) end) @@ -846,7 +926,7 @@ defmodule Axon.LoopTest do {:continue, state} end, - fn %{epoch: epoch} -> epoch == 1 end + fn %{epoch: epoch}, _ -> epoch == 1 end ) |> Axon.Loop.run(data, %{}, epochs: 5, iterations: 5) end) @@ -934,7 +1014,7 @@ defmodule Axon.LoopTest do {:continue, state} end, - fn %{epoch: epoch} -> epoch == 1 end + fn %{epoch: epoch}, _ -> epoch == 1 end ) |> Axon.Loop.run(data, %{}, epochs: 5, iterations: 5) end)