Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 53 additions & 17 deletions lib/axon/loop.ex
Original file line number Diff line number Diff line change
Expand Up @@ -1868,7 +1868,9 @@ defmodule Axon.Loop do
Logger.debug("Axon.Loop fired event #{inspect(event)}")
end

if filter.(state) do
state = update_counts(state, event)

if filter.(state, event) do
case handler.(state) do
{:continue, %State{} = state} ->
if debug? do
Expand Down Expand Up @@ -1908,6 +1910,10 @@ defmodule Axon.Loop do
end)
end

defp update_counts(%State{event_counts: event_counts} = state, event) do
%{state | event_counts: Map.update(event_counts, event, 1, fn x -> x + 1 end)}
end

# Halts an epoch during looping
defp halt_epoch(handler_fns, batch_fn, final_metrics_map, loop_state, debug?) do
case fire_event(:epoch_halted, handler_fns, loop_state, debug?) do
Expand Down Expand Up @@ -2155,40 +2161,70 @@ defmodule Axon.Loop do
# Builds a filter function from an atom, keyword list, or function. A
# valid filter is an atom which matches on of the valid predicates `:always`
# or `:once`, a keyword which matches one of the valid predicate-value pairs
# such as `every: N`, or a function which takes loop state and returns `true`
# or `false`.
#
# TODO(seanmor5): In order to handle custom events and predicate filters,
# we will need to track event firings in the loop state.
# such as `every: N`, or a function which takes loop state and the current event
# and returns `true` to run the handler of `false` to avoid it.
defp build_filter_fn(filter) do
case filter do
:always ->
fn _ -> true end
fn _, _ -> true end

:first ->
fn
%State{epoch: 0, iteration: 0} -> true
_ -> false
fn %State{event_counts: counts}, event ->
counts[event] == 1
end

[{:every, n} | _] ->
fn %State{iteration: iter} ->
Kernel.rem(iter, n) == 0
end
filters when is_list(filters) ->
Enum.reduce(filters, fn _, _ -> true end, fn
{:every, n}, acc ->
fn state, event ->
acc.(state, event) and filter_every_n(state, event, n)
end

{:before, n}, acc ->
fn state, event ->
acc.(state, event) and filter_before_n(state, event, n)
end

{:after, n}, acc ->
fn state, event ->
acc.(state, event) and filter_after_n(state, event, n)
end

{:once, n}, acc ->
fn state, event ->
acc.(state, event) and filter_once_n(state, event, n)
end
end)

fun when is_function(fun, 1) ->
fun when is_function(fun, 2) ->
fun

invalid ->
raise ArgumentError,
"Invalid filter #{inspect(invalid)}, a valid filter" <>
" is an atom which matches a valid filter predicate" <>
" such as :always or :once, a keyword of predicate-value" <>
" pairs such as every: N, or an arity-1 function which takes" <>
" loop state and returns true or false"
" pairs such as every: N, or an arity-2 function which takes" <>
" loop state and current event and returns true or false"
end
end

defp filter_every_n(%State{event_counts: counts}, event, n) do
rem(counts[event] - 1, n) == 0
end

defp filter_after_n(%State{event_counts: counts}, event, n) do
counts[event] > n
end

defp filter_before_n(%State{event_counts: counts}, event, n) do
counts[event] < n
end

defp filter_once_n(%State{event_counts: counts}, event, n) do
counts[event] == n
end

# JIT-compiles the given function if jit_compile? is true
# otherwise just applies the function with the given arguments
defp maybe_jit(fun, args, jit_compile?, jit_opts) do
Expand Down
15 changes: 14 additions & 1 deletion lib/axon/loop/state.ex
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ defmodule Axon.Loop.State do
`handler_metadata` is a metadata field for storing loop handler metadata.
For example, loop checkpoints with specific metric criteria can store
previous best metrics in the handler meta for use between iterations.

`event_counts` is a metadata field which stores information about the number
of times each event has been fired. This is useful when creating custom filters.
"""
@enforce_keys [:step_state]
defstruct [
Expand All @@ -49,6 +52,16 @@ defmodule Axon.Loop.State do
iteration: 0,
max_iteration: -1,
metrics: %{},
times: %{}
times: %{},
event_counts: %{
started: 0,
epoch_started: 0,
iteration_started: 0,
iteration_completed: 0,
epoch_completed: 0,
epoch_halted: 0,
halted: 0,
completed: 0
}
]
end
100 changes: 90 additions & 10 deletions test/axon/loop_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -645,7 +645,7 @@ defmodule Axon.LoopTest do
end)

model
|> Axon.Loop.trainer(:binary_cross_entropy, :sgd)
|> Axon.Loop.trainer(:binary_cross_entropy, :sgd, log: -1)
|> send_handler(event, filter)
|> Axon.Loop.run(data, %{}, epochs: epochs, iterations: iterations)
end
Expand All @@ -663,9 +663,7 @@ defmodule Axon.LoopTest do
end

test "supports an :always filter" do
ExUnit.CaptureIO.capture_io(fn ->
run_dummy_loop!(:iteration_started, :always, 5, 10)
end)
run_dummy_loop!(:iteration_started, :always, 5, 10)

for _ <- 1..50 do
assert_received :iteration_started
Expand All @@ -675,15 +673,97 @@ defmodule Axon.LoopTest do
end

test "supports an every: n filter" do
ExUnit.CaptureIO.capture_io(fn ->
run_dummy_loop!(:iteration_started, [every: 2], 5, 10)
end)
run_dummy_loop!(:iteration_started, [every: 2], 5, 10)

for _ <- 1..25 do
assert_received :iteration_started
end

refute_received :iteration_started

run_dummy_loop!(:iteration_completed, [every: 3], 3, 10)

for _ <- 1..10 do
assert_received :iteration_completed
end

refute_received :iteration_completed
end

test "supports after: n filter" do
run_dummy_loop!(:iteration_started, [after: 10], 5, 10)

for _ <- 1..40 do
assert_received :iteration_started
end

refute_received :iteration_started

run_dummy_loop!(:iteration_completed, [after: 10], 5, 10)

for _ <- 1..40 do
assert_received :iteration_completed
end

refute_received :iteration_completed
end

test "supports before: n filter" do
run_dummy_loop!(:iteration_started, [before: 10], 5, 10)

for _ <- 1..9 do
assert_received :iteration_started
end

refute_received :iteration_started

run_dummy_loop!(:iteration_completed, [before: 10], 5, 10)

for _ <- 1..9 do
assert_received :iteration_completed
end

refute_received :iteration_completed
end

test "supports once: n filter" do
run_dummy_loop!(:iteration_started, [once: 30], 5, 10)

assert_received :iteration_started
refute_received :iteration_started

run_dummy_loop!(:iteration_completed, [once: 30], 5, 10)

assert_received :iteration_completed
refute_received :iteration_completed
end

test "supports hybrid filter" do
run_dummy_loop!(:iteration_started, [every: 2, after: 10, before: 40], 5, 10)

for _ <- 1..15 do
assert_received :iteration_started
end

refute_received :iteration_started
end

test "supports :first filter" do
run_dummy_loop!(:iteration_started, :first, 5, 10)

assert_received :iteration_started
refute_received :iteration_started
end

test "supports function filter" do
fun = fn
%{event_counts: counts}, event -> counts[event] == 5
end

run_dummy_loop!(:iteration_started, fun, 5, 10)

assert_received :iteration_started
refute_received :iteration_started
end
end

Expand Down Expand Up @@ -814,7 +894,7 @@ defmodule Axon.LoopTest do
assert Map.has_key?(metrics, "validation_accuracy")
{:continue, state}
end,
fn %{epoch: epoch} -> epoch == 1 end
fn %{epoch: epoch}, _ -> epoch == 1 end
)
|> Axon.Loop.run(data, %{}, epochs: 5, iterations: 5)
end)
Expand Down Expand Up @@ -846,7 +926,7 @@ defmodule Axon.LoopTest do

{:continue, state}
end,
fn %{epoch: epoch} -> epoch == 1 end
fn %{epoch: epoch}, _ -> epoch == 1 end
)
|> Axon.Loop.run(data, %{}, epochs: 5, iterations: 5)
end)
Expand Down Expand Up @@ -934,7 +1014,7 @@ defmodule Axon.LoopTest do

{:continue, state}
end,
fn %{epoch: epoch} -> epoch == 1 end
fn %{epoch: epoch}, _ -> epoch == 1 end
)
|> Axon.Loop.run(data, %{}, epochs: 5, iterations: 5)
end)
Expand Down