Skip to content

Commit

Permalink
Merge implicit residual log for solvers and benchmarks
Browse files Browse the repository at this point in the history
This adds the implicit residual to all suitable solvers and few additional loggers, namely ResidualLogger
used in solver benchmarks, the custom-logger example and a few overloads of Logger that we missed before.

Related PR: #714
  • Loading branch information
upsj committed Mar 6, 2021
2 parents e0e0cc4 + e3b780a commit 61f7b7c
Show file tree
Hide file tree
Showing 18 changed files with 226 additions and 95 deletions.
6 changes: 6 additions & 0 deletions benchmark/solver/solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,8 @@ void solve_system(const std::string &solver_name,
rapidjson::Value(rapidjson::kArrayType), allocator);
add_or_set_member(solver_json, "true_residuals",
rapidjson::Value(rapidjson::kArrayType), allocator);
add_or_set_member(solver_json, "implicit_residuals",
rapidjson::Value(rapidjson::kArrayType), allocator);
add_or_set_member(solver_json, "iteration_timestamps",
rapidjson::Value(rapidjson::kArrayType), allocator);
if (b->get_size()[1] == 1 && !FLAGS_overhead) {
Expand Down Expand Up @@ -457,9 +459,13 @@ void solve_system(const std::string &solver_name,
exec, lend(system_matrix), b,
solver_json["recurrent_residuals"],
solver_json["true_residuals"],
solver_json["implicit_residuals"],
solver_json["iteration_timestamps"], allocator);
solver->add_logger(res_logger);
solver->apply(lend(b), lend(x_clone));
if (!res_logger->has_implicit_res_norms()) {
solver_json.RemoveMember("implicit_residuals");
}
}
exec->synchronize();
}
Expand Down
30 changes: 29 additions & 1 deletion benchmark/utils/loggers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.


#include <chrono>
#include <cmath>
#include <mutex>
#include <regex>
#include <unordered_map>
Expand Down Expand Up @@ -202,10 +203,21 @@ template <typename ValueType>
struct ResidualLogger : gko::log::Logger {
using rc_vtype = gko::remove_complex<ValueType>;

void on_iteration_complete(const gko::LinOp *, const gko::size_type &,
// TODO2.0: Remove when deprecating simple overload
void on_iteration_complete(const gko::LinOp *solver,
const gko::size_type &it,
const gko::LinOp *residual,
const gko::LinOp *solution,
const gko::LinOp *residual_norm) const override
{
on_iteration_complete(solver, it, residual, solution, residual_norm,
nullptr);
}

void on_iteration_complete(
const gko::LinOp *, const gko::size_type &, const gko::LinOp *residual,
const gko::LinOp *solution, const gko::LinOp *residual_norm,
const gko::LinOp *implicit_sq_residual_norm) const override
{
timestamps.PushBack(std::chrono::duration<double>(
std::chrono::steady_clock::now() - start)
Expand All @@ -226,12 +238,22 @@ struct ResidualLogger : gko::log::Logger {
} else {
true_res_norms.PushBack(-1.0, alloc);
}
if (implicit_sq_residual_norm) {
implicit_res_norms.PushBack(
std::sqrt(get_norm(
gko::as<vec<rc_vtype>>(implicit_sq_residual_norm))),
alloc);
has_implicit_res_norm = true;
} else {
implicit_res_norms.PushBack(-1.0, alloc);
}
}

ResidualLogger(std::shared_ptr<const gko::Executor> exec,
const gko::LinOp *matrix, const vec<ValueType> *b,
rapidjson::Value &rec_res_norms,
rapidjson::Value &true_res_norms,
rapidjson::Value &implicit_res_norms,
rapidjson::Value &timestamps,
rapidjson::MemoryPoolAllocator<> &alloc)
: gko::log::Logger(exec, gko::log::Logger::iteration_complete_mask),
Expand All @@ -240,16 +262,22 @@ struct ResidualLogger : gko::log::Logger {
start{std::chrono::steady_clock::now()},
rec_res_norms{rec_res_norms},
true_res_norms{true_res_norms},
has_implicit_res_norm{},
implicit_res_norms{implicit_res_norms},
timestamps{timestamps},
alloc{alloc}
{}

bool has_implicit_res_norms() const { return has_implicit_res_norm; }

private:
const gko::LinOp *matrix;
const vec<ValueType> *b;
std::chrono::steady_clock::time_point start;
rapidjson::Value &rec_res_norms;
rapidjson::Value &true_res_norms;
mutable bool has_implicit_res_norm;
rapidjson::Value &implicit_res_norms;
rapidjson::Value &timestamps;
rapidjson::MemoryPoolAllocator<> &alloc;
};
Expand Down
11 changes: 11 additions & 0 deletions core/log/papi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,17 @@ void Papi<ValueType>::on_iteration_complete(const LinOp *solver,
const LinOp *residual,
const LinOp *solution,
const LinOp *residual_norm) const
{
this->on_iteration_complete(solver, num_iterations, residual, solution,
residual_norm, nullptr);
}


template <typename ValueType>
void Papi<ValueType>::on_iteration_complete(
const LinOp *solver, const size_type &num_iterations, const LinOp *residual,
const LinOp *solution, const LinOp *residual_norm,
const LinOp *implicit_sq_residual_norm) const
{
iteration_complete.get_counter(solver) = num_iterations;
}
Expand Down
14 changes: 13 additions & 1 deletion core/log/record.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -261,11 +261,23 @@ void Record::on_iteration_complete(const LinOp *solver,
const size_type &num_iterations,
const LinOp *residual, const LinOp *solution,
const LinOp *residual_norm) const
{
this->on_iteration_complete(solver, num_iterations, residual, solution,
residual_norm, nullptr);
}


void Record::on_iteration_complete(const LinOp *solver,
const size_type &num_iterations,
const LinOp *residual, const LinOp *solution,
const LinOp *residual_norm,
const LinOp *implicit_sq_residual_norm) const
{
append_deque(
data_.iteration_completed,
(std::unique_ptr<iteration_complete_data>(new iteration_complete_data{
solver, num_iterations, residual, solution, residual_norm})));
solver, num_iterations, residual, solution, residual_norm,
implicit_sq_residual_norm})));
}


Expand Down
21 changes: 19 additions & 2 deletions core/log/stream.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -428,12 +428,24 @@ void Stream<ValueType>::on_iteration_complete(const LinOp *solver,
const LinOp *residual,
const LinOp *solution,
const LinOp *residual_norm) const
{
this->on_iteration_complete(solver, num_iterations, residual, solution,
residual_norm, nullptr);
}


template <typename ValueType>
void Stream<ValueType>::on_iteration_complete(
const LinOp *solver, const size_type &num_iterations, const LinOp *residual,
const LinOp *solution, const LinOp *residual_norm,
const LinOp *implicit_sq_residual_norm) const
{
os_ << prefix_ << "iteration " << num_iterations
<< " completed with solver " << demangle_name(solver)
<< " with residual " << demangle_name(residual) << ", solution "
<< demangle_name(solution) << " and residual_norm "
<< demangle_name(residual_norm) << std::endl;
<< demangle_name(solution) << ", residual_norm "
<< demangle_name(residual_norm) << " and implicit_sq_residual_norm "
<< demangle_name(implicit_sq_residual_norm) << std::endl;
if (verbose_) {
os_ << demangle_name(residual)
<< as<gko::matrix::Dense<ValueType>>(residual) << std::endl;
Expand All @@ -446,6 +458,11 @@ void Stream<ValueType>::on_iteration_complete(const LinOp *solver,
<< as<gko::matrix::Dense<ValueType>>(residual_norm)
<< std::endl;
}
if (implicit_sq_residual_norm != nullptr) {
os_ << demangle_name(implicit_sq_residual_norm)
<< as<gko::matrix::Dense<ValueType>>(implicit_sq_residual_norm)
<< std::endl;
}
}
}

Expand Down
4 changes: 2 additions & 2 deletions core/solver/bicg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -199,8 +199,8 @@ void Bicg<ValueType>::apply_impl(const LinOp *b, LinOp *x) const
z->compute_dot(r2.get(), rho.get());

++iter;
this->template log<log::Logger::iteration_complete>(this, iter, r.get(),
dense_x);
this->template log<log::Logger::iteration_complete>(
this, iter, r.get(), dense_x, nullptr, rho.get());
if (stop_criterion->update()
.num_iterations(iter)
.residual(r.get())
Expand Down
4 changes: 2 additions & 2 deletions core/solver/bicgstab.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -151,8 +151,8 @@ void Bicgstab<ValueType>::apply_impl(const LinOp *b, LinOp *x) const
*/
while (true) {
++iter;
this->template log<log::Logger::iteration_complete>(this, iter, r.get(),
dense_x);
this->template log<log::Logger::iteration_complete>(
this, iter, r.get(), dense_x, nullptr, rho.get());
rr->compute_dot(r.get(), rho.get());

if (stop_criterion->update()
Expand Down
4 changes: 2 additions & 2 deletions core/solver/cg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -142,8 +142,8 @@ void Cg<ValueType>::apply_impl(const LinOp *b, LinOp *x) const
r->compute_dot(z.get(), rho.get());

++iter;
this->template log<log::Logger::iteration_complete>(this, iter, r.get(),
dense_x);
this->template log<log::Logger::iteration_complete>(
this, iter, r.get(), dense_x, nullptr, rho.get());
if (stop_criterion->update()
.num_iterations(iter)
.residual(r.get())
Expand Down
4 changes: 2 additions & 2 deletions core/solver/cgs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -174,8 +174,8 @@ void Cgs<ValueType>::apply_impl(const LinOp *b, LinOp *x) const
alpha.get(), &stop_status));

++iter;
this->template log<log::Logger::iteration_complete>(this, iter, r.get(),
dense_x);
this->template log<log::Logger::iteration_complete>(
this, iter, r.get(), dense_x, nullptr, rho.get());
if (stop_criterion->update()
.num_iterations(iter)
.residual(r.get())
Expand Down
4 changes: 2 additions & 2 deletions core/solver/fcg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -145,8 +145,8 @@ void Fcg<ValueType>::apply_impl(const LinOp *b, LinOp *x) const
t->compute_dot(z.get(), rho_t.get());

++iter;
this->template log<log::Logger::iteration_complete>(this, iter, r.get(),
dense_x);
this->template log<log::Logger::iteration_complete>(
this, iter, r.get(), dense_x, nullptr, rho.get());
if (stop_criterion->update()
.num_iterations(iter)
.residual(r.get())
Expand Down
2 changes: 1 addition & 1 deletion core/test/log/papi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -482,7 +482,7 @@ TYPED_TEST(Papi, CatchesIterationComplete)

this->start();
this->logger->template on<gko::log::Logger::iteration_complete>(
A.get(), 42, nullptr, nullptr, nullptr);
A.get(), 42, nullptr, nullptr, nullptr, nullptr);
long long int value = 0;
this->stop(&value);

Expand Down
5 changes: 4 additions & 1 deletion core/test/log/record.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -509,11 +509,12 @@ TEST(Record, CatchesIterations)
auto residual = gko::initialize<Dense>({-4.4}, exec);
auto solution = gko::initialize<Dense>({-2.2}, exec);
auto residual_norm = gko::initialize<Dense>({-3.3}, exec);
auto implicit_sq_residual_norm = gko::initialize<Dense>({-3.5}, exec);


logger->on<gko::log::Logger::iteration_complete>(
solver.get(), num_iters, residual.get(), solution.get(),
residual_norm.get());
residual_norm.get(), implicit_sq_residual_norm.get());

auto &data = logger->get().iteration_completed.back();
ASSERT_NE(data->solver.get(), nullptr);
Expand All @@ -522,6 +523,8 @@ TEST(Record, CatchesIterations)
GKO_ASSERT_MTX_NEAR(gko::as<Dense>(data->solution.get()), solution, 0);
GKO_ASSERT_MTX_NEAR(gko::as<Dense>(data->residual_norm.get()),
residual_norm, 0);
GKO_ASSERT_MTX_NEAR(gko::as<Dense>(data->implicit_sq_residual_norm.get()),
implicit_sq_residual_norm, 0);
}


Expand Down
4 changes: 3 additions & 1 deletion core/test/log/stream.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -708,13 +708,15 @@ TYPED_TEST(Stream, CatchesIterations)
auto residual = Dense::create(exec);
auto solution = Dense::create(exec);
auto residual_norm = Dense::create(exec);
auto implicit_sq_residual_norm = Dense::create(exec);
std::stringstream ptrstream_solver;
ptrstream_solver << solver.get();
std::stringstream ptrstream_residual;
ptrstream_residual << residual.get();

logger->template on<gko::log::Logger::iteration_complete>(
solver.get(), num_iters, residual.get());
solver.get(), num_iters, residual.get(), solution.get(),
residual_norm.get(), implicit_sq_residual_norm.get());

GKO_ASSERT_STR_CONTAINS(out.str(),
"iteration " + std::to_string(num_iters));
Expand Down
Loading

0 comments on commit 61f7b7c

Please sign in to comment.