Skip to content

Commit

Permalink
Test output param type in integration tests
Browse files Browse the repository at this point in the history
  • Loading branch information
seanmor5 committed Jun 2, 2023
1 parent c19ff70 commit 0558146
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions test/axon/integration_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -358,9 +358,10 @@ defmodule Axon.IntegrationTest do

assert %{0 => %{"accuracy" => final_model_val_accuracy}} = eval_results

assert_greater_equal(last_epoch_metrics["validation_accuracy"], 0.7)
assert_greater_equal(last_epoch_metrics["validation_accuracy"], 0.60)
assert_all_close(final_model_val_accuracy, last_epoch_metrics["validation_accuracy"])
assert Nx.shape(Axon.predict(model, model_state, x_test)) == {10, 2}
assert Nx.type(model_state["dense_0"]["kernel"]) == unquote(Macro.escape(policy)).params
end)
end

Expand Down Expand Up @@ -405,9 +406,10 @@ defmodule Axon.IntegrationTest do

assert %{0 => %{"accuracy" => final_model_val_accuracy}} = eval_results

assert_greater_equal(last_epoch_metrics["validation_accuracy"], 0.7)
assert_greater_equal(last_epoch_metrics["validation_accuracy"], 0.60)
assert_all_close(final_model_val_accuracy, last_epoch_metrics["validation_accuracy"])
assert Nx.shape(Axon.predict(model, model_state, x_test)) == {10, 2}
assert Nx.type(model_state["dense_0"]["kernel"]) == unquote(Macro.escape(policy)).params
end)
end
end
Expand Down

0 comments on commit 0558146

Please sign in to comment.