193 changes: 115 additions & 78 deletions lib/axon/loop.ex
@@ -313,22 +313,26 @@ defmodule Axon.Loop do
{init_optimizer_fn, update_optimizer_fn} = build_optimizer_fns(optimizer)
{init_loss_scale, scale_loss, unscale_grads} = build_loss_scale_fns(loss_scale)

init_fn = fn {inp, _}, init_model_state ->
model_state = init_model_fn.(inp, init_model_state)
optimizer_state = init_optimizer_fn.(model_state)
loss_scale_state = init_loss_scale.()
init_fn = fn
{inp, _}, %{} = init_model_state ->
model_state = init_model_fn.(inp, init_model_state)
optimizer_state = init_optimizer_fn.(model_state)
loss_scale_state = init_loss_scale.()

%{
i: Nx.tensor(0),
y_true: Nx.tensor(0.0),
y_pred: Nx.tensor(0.0),
loss: Nx.tensor(0.0),
gradient_step: Nx.tensor(0),
model_state: model_state,
gradient_state: zeros_like(model_state),
optimizer_state: optimizer_state,
loss_scale_state: loss_scale_state
}
%{
i: Nx.tensor(0),
y_true: Nx.tensor(0.0),
y_pred: Nx.tensor(0.0),
loss: Nx.tensor(0.0),
gradient_step: Nx.tensor(0),
model_state: model_state,
gradient_state: zeros_like(model_state),
optimizer_state: optimizer_state,
loss_scale_state: loss_scale_state
}

data, state ->
raise_bad_training_inputs!(data, state)
end
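# init_fn is now a multi-clause anonymous function: the first clause
# pattern-matches well-formed input, the catch-all clause raises. A minimal
# standalone sketch of the same dispatch pattern (names here are
# illustrative, not part of this PR):
#
#     handle = fn
#       {inp, _tar}, %{} = state -> {inp, state}
#       data, state -> raise ArgumentError, "bad input: #{inspect({data, state})}"
#     end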

# TODO: We should probably compute in same compute policy as MP
@@ -345,59 +349,63 @@ defmodule Axon.Loop do
{model_out, loss}
end

step_fn = fn {inp, tar}, state ->
%{
i: i,
gradient_step: gradient_step,
loss_scale_state: loss_scale_state,
gradient_state: gradient_state,
model_state: model_state,
optimizer_state: optimizer_state,
loss: loss
} = state

{{model_out, batch_loss}, gradients} =
Nx.Defn.value_and_grad(
model_state,
&objective_fn.(&1, loss_scale_state, inp, tar),
fn x -> elem(x, 1) end
)

{gradients, new_loss_scale_state} = unscale_grads.(gradients, loss_scale_state)
step_fn = fn
{inp, tar}, %{} = state ->
%{
i: i,
gradient_step: gradient_step,
loss_scale_state: loss_scale_state,
gradient_state: gradient_state,
model_state: model_state,
optimizer_state: optimizer_state,
loss: loss
} = state

{{model_out, batch_loss}, gradients} =
Nx.Defn.value_and_grad(
model_state,
&objective_fn.(&1, loss_scale_state, inp, tar),
fn x -> elem(x, 1) end
)

preds = model_out.prediction
new_state = model_out.state
{gradients, new_loss_scale_state} = unscale_grads.(gradients, loss_scale_state)

new_loss =
loss
|> Nx.multiply(i)
|> Nx.add(Nx.multiply(batch_loss, steps))
|> Nx.divide(Nx.add(i, 1))
preds = model_out.prediction
new_state = model_out.state

{new_model_state, new_optimizer_state, new_gradient_state, new_gradient_step} =
if Nx.greater_equal(gradient_step, steps - 1) do
{updates, new_optimizer_state} =
update_optimizer_fn.(gradients, optimizer_state, model_state)
new_loss =
loss
|> Nx.multiply(i)
|> Nx.add(Nx.multiply(batch_loss, steps))
|> Nx.divide(Nx.add(i, 1))

{new_model_state, new_optimizer_state, new_gradient_state, new_gradient_step} =
if Nx.greater_equal(gradient_step, steps - 1) do
{updates, new_optimizer_state} =
update_optimizer_fn.(gradients, optimizer_state, model_state)

new_gradient_state = zeros_like(model_state)
new_model_state = Axon.Updates.apply_updates(model_state, updates, new_state)
{new_model_state, new_optimizer_state, new_gradient_state, 0}
else
{model_state, optimizer_state, gradient_state + gradients, gradient_step + 1}
end

new_gradient_state = zeros_like(model_state)
new_model_state = Axon.Updates.apply_updates(model_state, updates, new_state)
{new_model_state, new_optimizer_state, new_gradient_state, 0}
else
{model_state, optimizer_state, gradient_state + gradients, gradient_step + 1}
end
%{
state
| i: Nx.add(i, 1),
gradient_step: new_gradient_step,
y_true: tar,
y_pred: preds,
loss: new_loss,
model_state: new_model_state,
gradient_state: new_gradient_state,
optimizer_state: new_optimizer_state,
loss_scale_state: new_loss_scale_state
}

%{
state
| i: Nx.add(i, 1),
gradient_step: new_gradient_step,
y_true: tar,
y_pred: preds,
loss: new_loss,
model_state: new_model_state,
gradient_state: new_gradient_state,
optimizer_state: new_optimizer_state,
loss_scale_state: new_loss_scale_state
}
data, state ->
raise_bad_training_inputs!(data, state)
end
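# The running loss above is an incremental mean scaled by the number of
# gradient-accumulation steps. Written out as a plain function (a sketch
# that mirrors the pipeline above):
#
#     running_loss = fn loss, batch_loss, i, steps ->
#       loss
#       |> Nx.multiply(i)
#       |> Nx.add(Nx.multiply(batch_loss, steps))
#       |> Nx.divide(Nx.add(i, 1))
#     end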

{
@@ -406,6 +414,21 @@
}
end

defp raise_bad_training_inputs!(data, state) do
raise ArgumentError,
"invalid arguments given to train-step initialization," <>
" this usually happens when you pass a invalid parameters" <>
" to Axon.Loop.run with a loop constructed using Axon.Loop.trainer" <>
" or Axon.Loop.evaluator, supervised training and evaluation loops"

" expect a stream or enumerable of inputs" <>
" of the form {x_train, y_train} where x_train and y_train" <>
" are batches of tensors, you must also provide an initial model" <>
" state such as an empty map: Axon.Loop.run(loop, data, %{}), got" <>
" input data: #{inspect(data)} and initial model state: " <>
" #{inspect(state)}"
end
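# For reference, the call shape this message asks for (x_batches and
# y_batches are placeholder enumerables, not part of this PR):
#
#     data = Stream.zip(x_batches, y_batches)
#     Axon.Loop.run(loop, data, %{})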

@doc """
Creates a supervised evaluation step from a model and model state.

@@ -425,12 +448,16 @@
}
end

step_fn = fn {inp, tar}, %{model_state: model_state} ->
%{
model_state: model_state,
y_true: tar,
y_pred: forward_model_fn.(model_state, inp)
}
step_fn = fn
{inp, tar}, %{model_state: model_state} ->
%{
model_state: model_state,
y_true: tar,
y_pred: forward_model_fn.(model_state, inp)
}

data, state ->
raise_bad_training_inputs!(data, state)
end
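# Same two-clause guard as the training step: valid {input, target} batches
# hit the first clause, anything else raises. Hypothetical usage sketch
# (trained_state stands in for a real trained model state):
#
#     model
#     |> Axon.Loop.evaluator()
#     |> Axon.Loop.metric(:mean_absolute_error)
#     |> Axon.Loop.run(data, trained_state)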

{
@@ -587,8 +614,11 @@ defmodule Axon.Loop do

if log_interval > 0 do
loop
|> log(:iteration_completed, &supervised_log_message_fn/1, :stdio, every: log_interval)
|> log(:epoch_completed, fn _ -> "\n" end, :stdio)
|> log(&supervised_log_message_fn/1,
event: :iteration_completed,
filter: [every: log_interval]
)
|> log(fn _ -> "\n" end, event: :epoch_completed)
else
loop
end
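# Spelled out, the reworked keyword API for log/3 takes the message function
# first, with event/filter/device as options (device shown with its default):
#
#     log(loop, &supervised_log_message_fn/1,
#       event: :iteration_completed,
#       filter: [every: log_interval],
#       device: :stdio
#     )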
@@ -655,7 +685,7 @@ defmodule Axon.Loop do
output_transform = fn state -> state.metrics end

loop(step_fn, init_fn, output_transform)
|> log(:iteration_completed, &supervised_log_message_fn(&1, false), :stdio)
|> log(&supervised_log_message_fn(&1, false), event: :iteration_completed)
end

@doc """
@@ -829,8 +859,12 @@ defmodule Axon.Loop do
`message_fn` should take the loop state and return a binary
representing the message to be written to the IO device.
"""
def log(%Loop{} = loop, event, message_fn, device \\ :stdio, filter \\ :always)
when is_function(message_fn, 1) do
def log(%Loop{} = loop, message_fn, opts \\ []) when is_function(message_fn, 1) do
opts = Keyword.validate!(opts, event: :iteration_completed, filter: :always, device: :stdio)
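# Keyword.validate!/2 rejects unknown options and fills in defaults; the
# || fallbacks below additionally guard against explicitly passed nil values.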
event = opts[:event] || :iteration_completed
filter = opts[:filter] || :always
device = opts[:device] || :stdio

log_fn = fn %State{} = state ->
try do
msg = message_fn.(state)
@@ -888,16 +922,19 @@ defmodule Axon.Loop do
model
|> Axon.Loop.trainer(:mean_squared_error, :sgd)
|> Axon.Loop.metric(:mean_absolute_error)
|> Axon.Loop.validate(model, validation_data, :iteration_completed, every: 10_000)
|> Axon.Loop.validate(model, validation_data, event: :iteration_completed, filter: [every: 10_000])
|> Axon.Loop.metric(:binary_cross_entropy)
"""
def validate(
%Loop{metrics: metric_fns} = loop,
model,
validation_data,
event \\ :epoch_completed,
filter \\ :always
opts \\ []
) do
opts = Keyword.validate!(opts, event: :epoch_completed, filter: :always)
event = opts[:event] || :epoch_completed
filter = opts[:filter] || :always

validation_loop = fn %State{metrics: metrics, step_state: step_state} = state ->
%{model_state: model_state} = step_state

Expand All @@ -909,7 +946,7 @@ defmodule Axon.Loop do
metric(loop, v, k)
end)
)
|> log(:completed, fn _ -> "\n" end)
|> log(fn _ -> "\n" end, event: :completed)
|> run(validation_data, model_state)
|> Access.get(0)
|> Map.new(fn {k, v} ->
32 changes: 29 additions & 3 deletions test/axon/loop_test.exs
@@ -198,9 +198,9 @@ defmodule Axon.LoopTest do
assert %Loop{} = loop = Loop.evaluator(model)
assert %Loop{} = loop = Loop.metric(loop, :mean_absolute_error)

ExUnit.CaptureIO.capture_io(fn ->
assert %{0 => %{"mean_absolute_error" => _}} = Loop.run(loop, data, model_state)
end)
assert ExUnit.CaptureIO.capture_io(fn ->
assert %{0 => %{"mean_absolute_error" => _}} = Loop.run(loop, data, model_state)
end) =~ "Batch"
end

test "eval_step/1 evalutes model on a single batch" do
@@ -431,6 +431,32 @@
end
end

describe "trainer" do
test "returns clear error on bad inputs" do
model = Axon.input("input")
data = Stream.repeatedly(fn -> Nx.tensor(5) end)

assert_raise ArgumentError, ~r/invalid arguments/, fn ->
model
|> Axon.Loop.trainer(:categorical_cross_entropy, :adam)
|> Axon.Loop.run(data, %{})
end
end
end

describe "evaluator" do
test "returns clear error on bad inputs" do
model = Axon.input("input")
data = Stream.repeatedly(fn -> Nx.tensor(5) end)

assert_raise ArgumentError, ~r/invalid arguments/, fn ->
model
|> Axon.Loop.evaluator()
|> Axon.Loop.run(data, %{})
end
end
end
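# A similar test could exercise the reworked log/3 options (a sketch only,
# not part of this PR; the model/data setup mirrors the cases above):
#
#     describe "log" do
#       test "accepts event and filter options" do
#         model = Axon.input("input", shape: {nil, 1}) |> Axon.dense(1)
#         data = Stream.repeatedly(fn -> {Nx.tensor([[1.0]]), Nx.tensor([[1.0]])} end)
#
#         output =
#           ExUnit.CaptureIO.capture_io(fn ->
#             model
#             |> Axon.Loop.trainer(:mean_squared_error, :sgd)
#             |> Axon.Loop.log(fn _ -> "tick" end, event: :epoch_completed)
#             |> Axon.Loop.run(data, %{}, epochs: 1, iterations: 1)
#           end)
#
#         assert output =~ "tick"
#       end
#     end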

describe "serialization" do
test "serialize_state/deserialize_state preserve loop state" do
model = Axon.input("input", shape: {nil, 1}) |> Axon.dense(2)