Skip to content

Commit

Permalink
Merge Allow ResidualNorm stop to compute residual in gen/check
Browse files Browse the repository at this point in the history
This PR allows ResidualNorm criterion to compute needed information in generation or check.
Solver does not need to implement the residual computation if the algorithm does not contain residual computation.
The solver only compute residual when user passes the residual related stopping criterion

Related PR: #818
  • Loading branch information
yhmtsai committed Jul 15, 2021
2 parents 7aaba5b + 0d7a0e9 commit 26ee87e
Show file tree
Hide file tree
Showing 3 changed files with 310 additions and 10 deletions.
20 changes: 20 additions & 0 deletions core/stop/residual_norm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,26 @@ bool ResidualNormBase<ValueType>::check_impl(
dense_r->compute_norm2(u_dense_tau_.get());
}
dense_tau = u_dense_tau_.get();
} else if (updater.solution_ != nullptr && system_matrix_ != nullptr &&
b_ != nullptr) {
auto exec = this->get_executor();
// when LinOp is real but rhs is complex, we use real view on complex,
// so it still uses the same type of scalar in apply.
if (auto vec_b = std::dynamic_pointer_cast<const Vector>(b_)) {
auto dense_r = vec_b->clone();
system_matrix_->apply(neg_one_.get(), updater.solution_, one_.get(),
dense_r.get());
dense_r->compute_norm2(u_dense_tau_.get());
} else if (auto vec_b =
std::dynamic_pointer_cast<const ComplexVector>(b_)) {
auto dense_r = vec_b->clone();
system_matrix_->apply(neg_one_.get(), updater.solution_, one_.get(),
dense_r.get());
dense_r->compute_norm2(u_dense_tau_.get());
} else {
GKO_NOT_SUPPORTED(nullptr);
}
dense_tau = u_dense_tau_.get();
} else {
GKO_NOT_SUPPORTED(nullptr);
}
Expand Down
51 changes: 41 additions & 10 deletions include/ginkgo/core/stop/residual_norm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,21 +100,47 @@ class ResidualNormBase
: EnablePolymorphicObject<ResidualNormBase, Criterion>(exec),
device_storage_{exec, 2},
reduction_factor_{reduction_factor},
baseline_{baseline}
baseline_{baseline},
system_matrix_{args.system_matrix},
b_{args.b},
one_{gko::initialize<Vector>({1}, exec)},
neg_one_{gko::initialize<Vector>({-1}, exec)}
{
switch (baseline_) {
case mode::initial_resnorm: {
if (args.initial_residual == nullptr) {
GKO_NOT_SUPPORTED(nullptr);
}
this->starting_tau_ = NormVector::create(
exec, dim<2>{1, args.initial_residual->get_size()[1]});
if (dynamic_cast<const ComplexVector *>(args.initial_residual)) {
auto dense_r = as<ComplexVector>(args.initial_residual);
dense_r->compute_norm2(this->starting_tau_.get());
if (args.system_matrix == nullptr || args.b == nullptr ||
args.x == nullptr) {
GKO_NOT_SUPPORTED(nullptr);
} else {
this->starting_tau_ = NormVector::create(
exec, dim<2>{1, args.b->get_size()[1]});
auto b_clone = share(args.b->clone());
args.system_matrix->apply(neg_one_.get(), args.x,
one_.get(), b_clone.get());
if (auto vec =
std::dynamic_pointer_cast<const ComplexVector>(
b_clone)) {
vec->compute_norm2(this->starting_tau_.get());
} else if (auto vec =
std::dynamic_pointer_cast<const Vector>(
b_clone)) {
vec->compute_norm2(this->starting_tau_.get());
} else {
GKO_NOT_SUPPORTED(nullptr);
}
}
} else {
auto dense_r = as<Vector>(args.initial_residual);
dense_r->compute_norm2(this->starting_tau_.get());
this->starting_tau_ = NormVector::create(
exec, dim<2>{1, args.initial_residual->get_size()[1]});
if (dynamic_cast<const ComplexVector *>(
args.initial_residual)) {
auto dense_r = as<ComplexVector>(args.initial_residual);
dense_r->compute_norm2(this->starting_tau_.get());
} else {
auto dense_r = as<Vector>(args.initial_residual);
dense_r->compute_norm2(this->starting_tau_.get());
}
}
break;
}
Expand Down Expand Up @@ -157,6 +183,11 @@ class ResidualNormBase

private:
mode baseline_{mode::rhs_norm};
std::shared_ptr<const LinOp> system_matrix_{};
std::shared_ptr<const LinOp> b_{};
/* one/neg_one for residual computation */
std::shared_ptr<const Vector> one_{};
std::shared_ptr<const Vector> neg_one_{};
};


Expand Down
249 changes: 249 additions & 0 deletions reference/test/stop/residual_norm_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,205 @@ TYPED_TEST(ResidualNorm, WaitsTillResidualGoal)
}


TYPED_TEST(ResidualNorm, SelfCalulatesThrowWithoutMatrix)
{
using Mtx = typename TestFixture::Mtx;
using NormVector = typename TestFixture::NormVector;
using T = TypeParam;
using T_nc = gko::remove_complex<TypeParam>;
auto initial_res = gko::initialize<Mtx>({100.0}, this->exec_);

T rhs_val = 10.0;
std::shared_ptr<gko::LinOp> rhs =
gko::initialize<Mtx>({rhs_val}, this->exec_);
auto rhs_criterion =
this->rhs_factory_->generate(nullptr, rhs, nullptr, initial_res.get());
auto rel_criterion =
this->rel_factory_->generate(nullptr, rhs, nullptr, initial_res.get());
auto abs_criterion =
this->abs_factory_->generate(nullptr, rhs, nullptr, initial_res.get());
{
auto solution = gko::initialize<Mtx>({rhs_val - T{10.0}}, this->exec_);
auto rhs_norm = gko::initialize<NormVector>({100.0}, this->exec_);
gko::as<Mtx>(rhs)->compute_norm2(rhs_norm.get());
constexpr gko::uint8 RelativeStoppingId{1};
bool one_changed{};
gko::Array<gko::stopping_status> stop_status(this->exec_, 1);
stop_status.get_data()[0].reset();

ASSERT_THROW(
rhs_criterion->update()
.solution(solution.get())
.check(RelativeStoppingId, true, &stop_status, &one_changed),
gko::NotSupported);
}
{
T initial_norm = 100.0;
auto solution =
gko::initialize<Mtx>({rhs_val - initial_norm}, this->exec_);
constexpr gko::uint8 RelativeStoppingId{1};
bool one_changed{};
gko::Array<gko::stopping_status> stop_status(this->exec_, 1);
stop_status.get_data()[0].reset();

ASSERT_THROW(
rel_criterion->update()
.solution(solution.get())
.check(RelativeStoppingId, true, &stop_status, &one_changed),
gko::NotSupported);
}
{
auto solution = gko::initialize<Mtx>({rhs_val - T{100.0}}, this->exec_);
constexpr gko::uint8 RelativeStoppingId{1};
bool one_changed{};
gko::Array<gko::stopping_status> stop_status(this->exec_, 1);
stop_status.get_data()[0].reset();

ASSERT_THROW(
abs_criterion->update()
.solution(solution.get())
.check(RelativeStoppingId, true, &stop_status, &one_changed),
gko::NotSupported);
}
}


TYPED_TEST(ResidualNorm, RelativeSelfCalulatesThrowWithoutRhs)
{
// only relative residual norm allows generation without rhs.
using Mtx = typename TestFixture::Mtx;
using NormVector = typename TestFixture::NormVector;
using T = TypeParam;
using T_nc = gko::remove_complex<TypeParam>;
auto initial_res = gko::initialize<Mtx>({100.0}, this->exec_);

T rhs_val = 10.0;
auto rel_criterion = this->rel_factory_->generate(nullptr, nullptr, nullptr,
initial_res.get());
T initial_norm = 100.0;
auto solution = gko::initialize<Mtx>({rhs_val - initial_norm}, this->exec_);
constexpr gko::uint8 RelativeStoppingId{1};
bool one_changed{};
gko::Array<gko::stopping_status> stop_status(this->exec_, 1);
stop_status.get_data()[0].reset();

ASSERT_THROW(
rel_criterion->update()
.solution(solution.get())
.check(RelativeStoppingId, true, &stop_status, &one_changed),
gko::NotSupported);
}


TYPED_TEST(ResidualNorm, SelfCalulatesAndWaitsTillResidualGoal)
{
using Mtx = typename TestFixture::Mtx;
using NormVector = typename TestFixture::NormVector;
using T = TypeParam;
using T_nc = gko::remove_complex<TypeParam>;
auto initial_res = gko::initialize<Mtx>({100.0}, this->exec_);
auto system_mtx = share(gko::initialize<Mtx>({1.0}, this->exec_));

T rhs_val = 10.0;
std::shared_ptr<gko::LinOp> rhs =
gko::initialize<Mtx>({rhs_val}, this->exec_);
auto rhs_criterion = this->rhs_factory_->generate(system_mtx, rhs, nullptr,
initial_res.get());
auto rel_criterion = this->rel_factory_->generate(system_mtx, rhs, nullptr,
initial_res.get());
auto abs_criterion = this->abs_factory_->generate(system_mtx, rhs, nullptr,
initial_res.get());
{
auto solution = gko::initialize<Mtx>({rhs_val - T{10.0}}, this->exec_);
auto rhs_norm = gko::initialize<NormVector>({100.0}, this->exec_);
gko::as<Mtx>(rhs)->compute_norm2(rhs_norm.get());
constexpr gko::uint8 RelativeStoppingId{1};
bool one_changed{};
gko::Array<gko::stopping_status> stop_status(this->exec_, 1);
stop_status.get_data()[0].reset();

ASSERT_FALSE(
rhs_criterion->update()
.solution(solution.get())
.check(RelativeStoppingId, true, &stop_status, &one_changed));

solution->at(0) = rhs_val - r<T>::value * T{1.1} * rhs_norm->at(0);
ASSERT_FALSE(
rhs_criterion->update()
.solution(solution.get())
.check(RelativeStoppingId, true, &stop_status, &one_changed));
ASSERT_EQ(stop_status.get_data()[0].has_converged(), false);
ASSERT_EQ(one_changed, false);

solution->at(0) = rhs_val - r<T>::value * T{0.9} * rhs_norm->at(0);
ASSERT_TRUE(
rhs_criterion->update()
.solution(solution.get())
.check(RelativeStoppingId, true, &stop_status, &one_changed));
ASSERT_EQ(stop_status.get_data()[0].has_converged(), true);
ASSERT_EQ(one_changed, true);
}
{
T initial_norm = 100.0;
auto solution =
gko::initialize<Mtx>({rhs_val - initial_norm}, this->exec_);
constexpr gko::uint8 RelativeStoppingId{1};
bool one_changed{};
gko::Array<gko::stopping_status> stop_status(this->exec_, 1);
stop_status.get_data()[0].reset();

ASSERT_FALSE(
rel_criterion->update()
.solution(solution.get())
.check(RelativeStoppingId, true, &stop_status, &one_changed));

solution->at(0) = rhs_val - r<T>::value * T{1.1} * initial_norm;
ASSERT_FALSE(
rel_criterion->update()
.solution(solution.get())
.check(RelativeStoppingId, true, &stop_status, &one_changed));
ASSERT_EQ(stop_status.get_data()[0].has_converged(), false);
ASSERT_EQ(one_changed, false);

solution->at(0) = rhs_val - r<T>::value * T{0.9} * initial_norm;
ASSERT_TRUE(
rel_criterion->update()
.solution(solution.get())
.check(RelativeStoppingId, true, &stop_status, &one_changed));
ASSERT_EQ(stop_status.get_data()[0].has_converged(), true);
ASSERT_EQ(one_changed, true);
}
{
auto solution = gko::initialize<Mtx>({rhs_val - T{100.0}}, this->exec_);
constexpr gko::uint8 RelativeStoppingId{1};
bool one_changed{};
gko::Array<gko::stopping_status> stop_status(this->exec_, 1);
stop_status.get_data()[0].reset();

ASSERT_FALSE(
abs_criterion->update()
.solution(solution.get())
.check(RelativeStoppingId, true, &stop_status, &one_changed));

solution->at(0) = rhs_val - r<T>::value * T{1.2};
ASSERT_FALSE(
abs_criterion->update()
.solution(solution.get())
.check(RelativeStoppingId, true, &stop_status, &one_changed));
ASSERT_EQ(stop_status.get_data()[0].has_converged(), false);
ASSERT_EQ(one_changed, false);

solution->at(0) = rhs_val - r<T>::value * T{0.9};
ASSERT_TRUE(
abs_criterion->update()
.solution(solution.get())
.check(RelativeStoppingId, true, &stop_status, &one_changed));
ASSERT_EQ(stop_status.get_data()[0].has_converged(), true);
ASSERT_EQ(one_changed, true);
}
}


TYPED_TEST(ResidualNorm, WaitsTillResidualGoalMultipleRHS)
{
using Mtx = typename TestFixture::Mtx;
Expand Down Expand Up @@ -370,6 +569,20 @@ class ResidualNormReduction : public ::testing::Test {
TYPED_TEST_SUITE(ResidualNormReduction, gko::test::ValueTypes);


TYPED_TEST(ResidualNormReduction,
CanCreateCriterionWithMtxRhsXWithoutInitialRes)
{
using Mtx = typename TestFixture::Mtx;
std::shared_ptr<gko::LinOp> x = gko::initialize<Mtx>({100.0}, this->exec_);
std::shared_ptr<gko::LinOp> mtx = gko::initialize<Mtx>({1.0}, this->exec_);
std::shared_ptr<gko::LinOp> b = gko::initialize<Mtx>({10.0}, this->exec_);

auto criterion = this->factory_->generate(mtx, b, x.get());

ASSERT_NE(criterion, nullptr);
}


TYPED_TEST(ResidualNormReduction, WaitsTillResidualGoal)
{
using Mtx = typename TestFixture::Mtx;
Expand Down Expand Up @@ -407,6 +620,42 @@ TYPED_TEST(ResidualNormReduction, WaitsTillResidualGoal)
}


TYPED_TEST(ResidualNormReduction, WaitsTillResidualGoalWithoutInitialRes)
{
using T = TypeParam;
using Mtx = typename TestFixture::Mtx;
using NormVector = typename TestFixture::NormVector;
T initial_res = 100;
T rhs_val = 10;
std::shared_ptr<gko::LinOp> rhs =
gko::initialize<Mtx>({rhs_val}, this->exec_);
std::shared_ptr<Mtx> x =
gko::initialize<Mtx>({rhs_val - initial_res}, this->exec_);
std::shared_ptr<gko::LinOp> mtx = gko::initialize<Mtx>({1.0}, this->exec_);

auto criterion = this->factory_->generate(mtx, rhs, x.get());
bool one_changed{};
constexpr gko::uint8 RelativeStoppingId{1};
gko::Array<gko::stopping_status> stop_status(this->exec_, 1);
stop_status.get_data()[0].reset();

ASSERT_FALSE(criterion->update().solution(x.get()).check(
RelativeStoppingId, true, &stop_status, &one_changed));

x->at(0) = rhs_val - r<T>::value * T{1.1} * initial_res;
ASSERT_FALSE(criterion->update().solution(x.get()).check(
RelativeStoppingId, true, &stop_status, &one_changed));
ASSERT_EQ(stop_status.get_data()[0].has_converged(), false);
ASSERT_EQ(one_changed, false);

x->at(0) = rhs_val - r<T>::value * T{0.9} * initial_res;
ASSERT_TRUE(criterion->update().solution(x.get()).check(
RelativeStoppingId, true, &stop_status, &one_changed));
ASSERT_EQ(stop_status.get_data()[0].has_converged(), true);
ASSERT_EQ(one_changed, true);
}


TYPED_TEST(ResidualNormReduction, WaitsTillResidualGoalMultipleRHS)
{
using Mtx = typename TestFixture::Mtx;
Expand Down

0 comments on commit 26ee87e

Please sign in to comment.