diff --git a/test/axon/integration_test.exs b/test/axon/integration_test.exs index c617b843..67024d38 100644 --- a/test/axon/integration_test.exs +++ b/test/axon/integration_test.exs @@ -358,9 +358,10 @@ defmodule Axon.IntegrationTest do assert %{0 => %{"accuracy" => final_model_val_accuracy}} = eval_results - assert_greater_equal(last_epoch_metrics["validation_accuracy"], 0.7) + assert_greater_equal(last_epoch_metrics["validation_accuracy"], 0.60) assert_all_close(final_model_val_accuracy, last_epoch_metrics["validation_accuracy"]) assert Nx.shape(Axon.predict(model, model_state, x_test)) == {10, 2} + assert Nx.type(model_state["dense_0"]["kernel"]) == unquote(Macro.escape(policy)).params end) end @@ -405,9 +406,10 @@ defmodule Axon.IntegrationTest do assert %{0 => %{"accuracy" => final_model_val_accuracy}} = eval_results - assert_greater_equal(last_epoch_metrics["validation_accuracy"], 0.7) + assert_greater_equal(last_epoch_metrics["validation_accuracy"], 0.60) assert_all_close(final_model_val_accuracy, last_epoch_metrics["validation_accuracy"]) assert Nx.shape(Axon.predict(model, model_state, x_test)) == {10, 2} + assert Nx.type(model_state["dense_0"]["kernel"]) == unquote(Macro.escape(policy)).params end) end end