Skip to content

Commit

Permalink
Add implicit_sq_resnorm check for ConvLogger
Browse files Browse the repository at this point in the history
  • Loading branch information
pratikvn committed Mar 29, 2021
1 parent 9eaa705 commit 36e9a9d
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 0 deletions.
14 changes: 14 additions & 0 deletions core/log/convergence.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,20 @@ void Convergence<ValueType>::on_criterion_check_completed(
}


template <typename ValueType>
void Convergence<ValueType>::on_criterion_check_completed(
const stop::Criterion *criterion, const size_type &num_iterations,
const LinOp *residual, const LinOp *residual_norm, const LinOp *solution,
const uint8 &stopping_id, const bool &set_finalized,
const Array<stopping_status> *status, const bool &one_changed,
const bool &converged) const
{
this->on_criterion_check_completed(
criterion, num_iterations, residual, residual_norm, nullptr, solution,
stopping_id, set_finalized, status, one_changed, converged);
}


#define GKO_DECLARE_CONVERGENCE(_type) class Convergence<_type>
GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_CONVERGENCE);

Expand Down
28 changes: 28 additions & 0 deletions include/ginkgo/core/log/convergence.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,15 @@ class Convergence : public Logger {
const bool &set_finalized, const Array<stopping_status> *status,
const bool &one_changed, const bool &all_converged) const override;

void on_criterion_check_completed(
const stop::Criterion *criterion, const size_type &num_iterations,
const LinOp *residual, const LinOp *residual_norm,
const LinOp *implicit_sq_resnorm, const LinOp *solution,
const uint8 &stopping_id, const bool &set_finalized,
const Array<stopping_status> *status, const bool &one_changed,
const bool &all_converged) const override;


/**
* Creates a convergence logger. This dynamically allocates the memory,
* constructs the object and returns an std::unique_ptr to this object.
Expand All @@ -94,6 +103,13 @@ class Convergence : public Logger {
new Convergence(exec, enabled_events));
}

/**
* Returns true if the solver has converged.
*
* @return the bool flag for convergence status
*/
bool has_converged() const noexcept { return convergence_status_; }

/**
* Returns the number of iterations
*
Expand Down Expand Up @@ -121,6 +137,16 @@ class Convergence : public Logger {
return residual_norm_.get();
}

/**
* Returns the implicit squared residual norm
*
* @return the implicit squared residual norm
*/
const LinOp *get_implicit_sq_resnorm() const noexcept
{
return implicit_sq_resnorm_.get();
}

protected:
/**
* Creates a Convergence logger.
Expand All @@ -136,9 +162,11 @@ class Convergence : public Logger {
{}

private:
mutable bool convergence_status_{false};
mutable size_type num_iterations_{};
mutable std::unique_ptr<LinOp> residual_{};
mutable std::unique_ptr<LinOp> residual_norm_{};
mutable std::unique_ptr<LinOp> implicit_sq_resnorm_{};
};


Expand Down
12 changes: 12 additions & 0 deletions include/ginkgo/core/log/logger.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,18 @@ public: \
const uint8 &stopping_id, const bool &set_finalized,
const Array<stopping_status> *status, const bool &one_changed,
const bool &all_converged)
protected:
virtual void on_criterion_check_completed(
const stop::Criterion *criterion, const size_type &it, const LinOp *r,
const LinOp *tau, const LinOp *implicit_sq_tau, const LinOp *x,
const uint8 &stopping_id, const bool &set_finalized,
const Array<stopping_status> *status, const bool &one_changed,
const bool &all_converged) const
{
this->on_criterion_check_completed(criterion, it, r, x, tau, x,
stopping_id, set_finalized, status,
one_changed, all_converged);
}

/**
* Register the `iteration_complete` event which logs every completed
Expand Down
31 changes: 31 additions & 0 deletions reference/test/log/convergence.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,37 @@ TYPED_TEST(Convergence, CatchesCriterionCheckCompleted)
RelativeStoppingId, true, &stop_status, true, true);

ASSERT_EQ(logger->get_num_iterations(), 1);
ASSERT_EQ(logger->has_converged(), true);
GKO_ASSERT_MTX_NEAR(gko::as<Mtx>(logger->get_residual()),
l({1.0, 2.0, 2.0}), 0.0);
GKO_ASSERT_MTX_NEAR(gko::as<NormVector>(logger->get_residual_norm()),
l({3.0}), 0.0);
}


TYPED_TEST(Convergence, CatchesCriterionCheckCompletedWithImplicitNorm)
{
auto exec = gko::ReferenceExecutor::create();
auto logger = gko::log::Convergence<TypeParam>::create(
exec, gko::log::Logger::criterion_check_completed_mask);
auto criterion =
gko::stop::Iteration::build().with_max_iters(3u).on(exec)->generate(
nullptr, nullptr, nullptr);
constexpr gko::uint8 RelativeStoppingId{42};
gko::Array<gko::stopping_status> stop_status(exec, 1);
using Mtx = gko::matrix::Dense<TypeParam>;
using NormVector = gko::matrix::Dense<gko::remove_complex<TypeParam>>;
auto residual = gko::initialize<Mtx>({1.0, 2.0, 2.0}, exec);
auto implicit_sq_resnorm = gko::initialize<Mtx>({4.0}, exec);

logger->template on<gko::log::Logger::criterion_check_completed>(
criterion.get(), 1, residual.get(), nullptr, implicit_sq_resnorm.get(),
nullptr, RelativeStoppingId, true, &stop_status, true, true);

ASSERT_EQ(logger->get_num_iterations(), 1);
ASSERT_EQ(logger->has_converged(), true);
GKO_ASSERT_MTX_NEAR(gko::as<Mtx>(logger->get_implicit_sq_resnorm()),
l({4.0}), 0.0);
GKO_ASSERT_MTX_NEAR(gko::as<Mtx>(logger->get_residual()),
l({1.0, 2.0, 2.0}), 0.0);
GKO_ASSERT_MTX_NEAR(gko::as<NormVector>(logger->get_residual_norm()),
Expand Down

0 comments on commit 36e9a9d

Please sign in to comment.