diff --git a/core/log/record.cpp b/core/log/record.cpp index 32be42931c9..ba249edf3e6 100644 --- a/core/log/record.cpp +++ b/core/log/record.cpp @@ -244,7 +244,8 @@ void Record::on_criterion_check_started( void Record::on_criterion_check_completed( const stop::Criterion *criterion, const size_type &num_iterations, - const LinOp *residual, const LinOp *residual_norm, const LinOp *solution, + const LinOp *residual, const LinOp *residual_norm, + const LinOp *implicit_residual_norm_sq, const LinOp *solution, const uint8 &stopping_id, const bool &set_finalized, const Array *status, const bool &oneChanged, const bool &converged) const @@ -257,6 +258,19 @@ void Record::on_criterion_check_completed( } +void Record::on_criterion_check_completed( + const stop::Criterion *criterion, const size_type &num_iterations, + const LinOp *residual, const LinOp *residual_norm, const LinOp *solution, + const uint8 &stopping_id, const bool &set_finalized, + const Array *status, const bool &oneChanged, + const bool &converged) const +{ + this->on_criterion_check_completed( + criterion, num_iterations, residual, residual_norm, nullptr, solution, + stopping_id, set_finalized, status, oneChanged, converged); +} + + void Record::on_iteration_complete(const LinOp *solver, const size_type &num_iterations, const LinOp *residual, const LinOp *solution, diff --git a/core/test/log/record.cpp b/core/test/log/record.cpp index ffde5bb0e23..a9a0947d199 100644 --- a/core/test/log/record.cpp +++ b/core/test/log/record.cpp @@ -464,7 +464,7 @@ TEST(Record, CatchesCriterionCheckStarted) } -TEST(Record, CatchesCriterionCheckCompleted) +TEST(Record, CatchesCriterionCheckCompletedOld) { auto exec = gko::ReferenceExecutor::create(); auto logger = gko::log::Record::create( @@ -494,6 +494,36 @@ TEST(Record, CatchesCriterionCheckCompleted) } +TEST(Record, CatchesCriterionCheckCompleted) +{ + auto exec = gko::ReferenceExecutor::create(); + auto logger = gko::log::Record::create( + exec, gko::log::Logger::criterion_check_completed_mask); + auto criterion = + gko::stop::Iteration::build().with_max_iters(3u).on(exec)->generate( + nullptr, nullptr, nullptr); + constexpr gko::uint8 RelativeStoppingId{42}; + gko::Array stop_status(exec, 1); + + logger->on( + criterion.get(), 1, nullptr, nullptr, nullptr, nullptr, + RelativeStoppingId, true, &stop_status, true, true); + + stop_status.get_data()->reset(); + stop_status.get_data()->stop(RelativeStoppingId); + auto &data = logger->get().criterion_check_completed.back(); + ASSERT_NE(data->criterion, nullptr); + ASSERT_EQ(data->stopping_id, RelativeStoppingId); + ASSERT_EQ(data->set_finalized, true); + ASSERT_EQ(data->status->get_const_data()->has_stopped(), true); + ASSERT_EQ(data->status->get_const_data()->get_id(), + stop_status.get_const_data()->get_id()); + ASSERT_EQ(data->status->get_const_data()->is_finalized(), true); + ASSERT_EQ(data->oneChanged, true); + ASSERT_EQ(data->converged, true); +} + + TEST(Record, CatchesIterations) { using Dense = gko::matrix::Dense<>; diff --git a/include/ginkgo/core/log/logger.hpp b/include/ginkgo/core/log/logger.hpp index 22a809bab75..049d964fa51 100644 --- a/include/ginkgo/core/log/logger.hpp +++ b/include/ginkgo/core/log/logger.hpp @@ -417,7 +417,7 @@ public: \ const Array *status, const bool &one_changed, const bool &all_converged) const { - this->on_criterion_check_completed(criterion, it, r, x, tau, x, + this->on_criterion_check_completed(criterion, it, r, tau, x, stopping_id, set_finalized, status, one_changed, all_converged); } diff --git a/include/ginkgo/core/log/record.hpp b/include/ginkgo/core/log/record.hpp index eb20ae0a40a..71b31e9e794 100644 --- a/include/ginkgo/core/log/record.hpp +++ b/include/ginkgo/core/log/record.hpp @@ -363,6 +363,14 @@ class Record : public Logger { const uint8 &stopping_id, const bool &set_finalized) const override; + void on_criterion_check_completed( + const stop::Criterion *criterion, const size_type &num_iterations, + const LinOp *residual, const LinOp *residual_norm, + const LinOp *implicit_residual_norm_sq, const LinOp *solution, + const uint8 &stopping_id, const bool &set_finalized, + const Array *status, const bool &one_changed, + const bool &all_converged) const override; + void on_criterion_check_completed( const stop::Criterion *criterion, const size_type &num_iterations, const LinOp *residual, const LinOp *residual_norm,