From f0743e1f84289afdf268ddeafbf595d0ce102a40 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Valim?= Date: Wed, 14 Dec 2022 12:31:32 +0100 Subject: [PATCH] Build evaluator and zero metrics only once --- lib/axon/loop.ex | 33 ++++++++------------------------- 1 file changed, 8 insertions(+), 25 deletions(-) diff --git a/lib/axon/loop.ex b/lib/axon/loop.ex index 56b6c39b3..ebeee56f5 100644 --- a/lib/axon/loop.ex +++ b/lib/axon/loop.ex @@ -1048,18 +1048,13 @@ defmodule Axon.Loop do opts = Keyword.validate!(opts, event: :epoch_completed, filter: :always) event = opts[:event] || :epoch_completed filter = opts[:filter] || :always + evaluator = evaluator(model) validation_loop = fn %State{metrics: metrics, step_state: step_state} = state -> %{model_state: model_state} = step_state metrics = - model - |> evaluator() - |> then( - &Enum.reduce(metric_fns, &1, fn {k, {_, v}}, loop -> - metric(loop, v, k) - end) - ) + Enum.reduce(metric_fns, evaluator, fn {k, {_, v}}, loop -> metric(loop, v, k) end) |> log(fn _ -> "\n" end, event: :completed) |> run(validation_data, model_state) |> Access.get(0) @@ -1640,18 +1635,14 @@ defmodule Axon.Loop do Logger.debug("Axon.Loop finished initializing loop state in #{us_to_ms(time)}ms") end + # TODO: Can we infer here? + zero_metrics = Map.new(metric_fns, fn {k, _} -> {k, Nx.tensor(0, type: :f32)} end) + final_metrics_map = - for i <- epoch_start..epoch_end do - {i, Map.new(metric_fns, fn {k, _} -> {k, Nx.tensor(0)} end)} - end - |> Map.new() + epoch_start..epoch_end + |> Map.new(&{&1, zero_metrics}) |> Map.merge(loop_state.metrics) - # TODO: Can we infer here? - zero_metrics = - metric_fns - |> Map.new(fn {k, _} -> {k, Nx.tensor(0, type: :f32)} end) - loop_state = %{loop_state | metrics: zero_metrics} {status, final_metrics, state} = @@ -1707,16 +1698,8 @@ defmodule Axon.Loop do {:halt, {final_metrics_map, state}} {:continue, state} -> - zero_metrics = - loop_state.metrics - |> Map.take(Map.keys(metric_fns)) - |> Map.new(fn {k, v} -> {k, zeros_like(v)} end) - - final_metrics_map = - Map.replace!(final_metrics_map, epoch, state.metrics) - {:cont, - {batch_fn, final_metrics_map, + {batch_fn, %{final_metrics_map | epoch => state.metrics}, %State{ state | epoch: epoch + 1,