Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Clean up implementation in SolverBicgstab #14257

Merged
merged 2 commits into from
Sep 14, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
252 changes: 78 additions & 174 deletions include/deal.II/lac/solver_bicgstab.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,54 +36,6 @@ DEAL_II_NAMESPACE_OPEN
* @{
*/

namespace internal
{
/**
* Class containing the non-parameter non-template values used by the
* SolverBicgstab class.
*/
class SolverBicgstabData
{
protected:
/**
* Auxiliary value.
*/
double alpha;
/**
* Auxiliary value.
*/
double beta;
/**
* Auxiliary value.
*/
double omega;
/**
* Auxiliary value.
*/
double rho;
/**
* Auxiliary value.
*/
double rhobar;

/**
* Current iteration step.
*/
unsigned int step;

/**
* Residual.
*/
double res;

/**
* Default constructor. This is protected so that only SolverBicgstab can
* create instances.
*/
SolverBicgstabData();
};
} // namespace internal

/**
* Bicgstab algorithm by van der Vorst.
*
Expand Down Expand Up @@ -124,8 +76,7 @@ namespace internal
* to observe the progress of the iteration.
*/
template <typename VectorType = Vector<double>>
class SolverBicgstab : public SolverBase<VectorType>,
protected internal::SolverBicgstabData
class SolverBicgstab : public SolverBase<VectorType>
Comment on lines -127 to +79
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is possibly an incompatible change, but since the base class was marked as internal, I don't think we need to create a way out for the user.

{
public:
/**
Expand Down Expand Up @@ -194,57 +145,15 @@ class SolverBicgstab : public SolverBase<VectorType>,
const PreconditionerType &preconditioner);

protected:
/**
* A pointer to the solution vector passed to solve().
*/
VectorType *Vx;

/**
* Auxiliary vector.
*/
typename VectorMemory<VectorType>::Pointer Vr;

/**
* Auxiliary vector.
*/
typename VectorMemory<VectorType>::Pointer Vrbar;

/**
* Auxiliary vector.
*/
typename VectorMemory<VectorType>::Pointer Vp;

/**
* Auxiliary vector.
*/
typename VectorMemory<VectorType>::Pointer Vy;

/**
* Auxiliary vector.
*/
typename VectorMemory<VectorType>::Pointer Vz;

/**
* Auxiliary vector.
*/
typename VectorMemory<VectorType>::Pointer Vt;

/**
* Auxiliary vector.
*/
typename VectorMemory<VectorType>::Pointer Vv;

/**
* A pointer to the right hand side vector passed to solve().
*/
const VectorType *Vb;

/**
* Computation of the stopping criterion.
*/
template <typename MatrixType>
double
criterion(const MatrixType &A, const VectorType &x, const VectorType &b);
criterion(const MatrixType &A,
const VectorType &x,
const VectorType &b,
VectorType & t);

/**
* Interface for derived class. This function gets the current iteration
Expand Down Expand Up @@ -286,9 +195,14 @@ class SolverBicgstab : public SolverBase<VectorType>,
*/
template <typename MatrixType, typename PreconditionerType>
IterationResult
iterate(const MatrixType &A, const PreconditionerType &preconditioner);
iterate(const MatrixType & A,
VectorType & x,
const VectorType & b,
const PreconditionerType &preconditioner,
const unsigned int step);
};


/** @} */
/*-------------------------Inline functions -------------------------------*/

Expand All @@ -314,8 +228,6 @@ SolverBicgstab<VectorType>::SolverBicgstab(SolverControl & cn,
VectorMemory<VectorType> &mem,
const AdditionalData & data)
: SolverBase<VectorType>(cn, mem)
, Vx(nullptr)
, Vb(nullptr)
, additional_data(data)
{}

Expand All @@ -325,8 +237,6 @@ template <typename VectorType>
SolverBicgstab<VectorType>::SolverBicgstab(SolverControl & cn,
const AdditionalData &data)
: SolverBase<VectorType>(cn)
, Vx(nullptr)
, Vb(nullptr)
, additional_data(data)
{}

Expand All @@ -337,13 +247,11 @@ template <typename MatrixType>
double
SolverBicgstab<VectorType>::criterion(const MatrixType &A,
const VectorType &x,
const VectorType &b)
const VectorType &b,
VectorType & t)
{
A.vmult(*Vt, x);
Vt->add(-1., b);
res = Vt->l2_norm();

return res;
A.vmult(t, x);
return std::sqrt(t.add_and_dot(-1.0, b, t));
}


Expand All @@ -362,18 +270,21 @@ template <typename VectorType>
template <typename MatrixType, typename PreconditionerType>
typename SolverBicgstab<VectorType>::IterationResult
SolverBicgstab<VectorType>::iterate(const MatrixType & A,
const PreconditionerType &preconditioner)
VectorType & x,
const VectorType & b,
const PreconditionerType &preconditioner,
const unsigned int last_step)
{
A.vmult(*Vr, *Vx);
Vr->sadd(-1., 1., *Vb);
res = Vr->l2_norm();

SolverControl::State state = this->iteration_status(step, res, *Vx);
if (state == SolverControl::State::success)
return IterationResult(false, state, step, res);

alpha = omega = rho = 1.;

// Allocate temporary memory.
typename VectorMemory<VectorType>::Pointer Vr(this->memory);
typename VectorMemory<VectorType>::Pointer Vrbar(this->memory);
typename VectorMemory<VectorType>::Pointer Vp(this->memory);
typename VectorMemory<VectorType>::Pointer Vy(this->memory);
typename VectorMemory<VectorType>::Pointer Vz(this->memory);
typename VectorMemory<VectorType>::Pointer Vt(this->memory);
typename VectorMemory<VectorType>::Pointer Vv(this->memory);

// Define a few aliases for simpler use of the vectors
VectorType &r = *Vr;
VectorType &rbar = *Vrbar;
VectorType &p = *Vp;
Expand All @@ -382,24 +293,49 @@ SolverBicgstab<VectorType>::iterate(const MatrixType & A,
VectorType &t = *Vt;
VectorType &v = *Vv;

rbar = r;
bool startup = true;
r.reinit(x, true);
rbar.reinit(x, true);
p.reinit(x, true);
y.reinit(x, true);
z.reinit(x, true);
t.reinit(x, true);
v.reinit(x, true);

using value_type = typename VectorType::value_type;
using real_type = typename numbers::NumberTraits<value_type>::real_type;

A.vmult(r, x);
r.sadd(-1., 1., b);
value_type res = r.l2_norm();

unsigned int step = last_step;

SolverControl::State state = this->iteration_status(step, res, x);
if (state == SolverControl::State::success)
return IterationResult(false, state, step, res);

rbar = r;

value_type alpha = 1.;
value_type rho = 1.;
value_type omega = 1.;

do
{
++step;

rhobar = r * rbar;
const value_type rhobar = (step == 1 + last_step) ? res * res : r * rbar;

if (std::fabs(rhobar) < additional_data.breakdown)
{
return IterationResult(true, state, step, res);
}
beta = rhobar * alpha / (rho * omega);
rho = rhobar;
if (startup == true)

const value_type beta = rhobar * alpha / (rho * omega);
rho = rhobar;
if (step == last_step + 1)
{
p = r;
startup = false;
p = r;
}
else
{
Expand All @@ -409,50 +345,50 @@ SolverBicgstab<VectorType>::iterate(const MatrixType & A,

preconditioner.vmult(y, p);
A.vmult(v, y);
rhobar = rbar * v;
if (std::fabs(rhobar) < additional_data.breakdown)
const value_type rbar_dot_v = rbar * v;
if (std::fabs(rbar_dot_v) < additional_data.breakdown)
{
return IterationResult(true, state, step, res);
}

alpha = rho / rhobar;
alpha = rho / rbar_dot_v;

res = std::sqrt(r.add_and_dot(-alpha, v, r));
res = std::sqrt(real_type(r.add_and_dot(-alpha, v, r)));

// check for early success, see the lac/bicgstab_early testcase as to
// why this is necessary
//
// note: the vector *Vx we pass to the iteration_status signal here is
// only the current approximation, not the one we will return with, which
// will be x=*Vx + alpha*y
if (this->iteration_status(step, res, *Vx) == SolverControl::success)
if (this->iteration_status(step, res, x) == SolverControl::success)
{
Vx->add(alpha, y);
print_vectors(step, *Vx, r, y);
x.add(alpha, y);
print_vectors(step, x, r, y);
return IterationResult(false, SolverControl::success, step, res);
}

preconditioner.vmult(z, r);
A.vmult(t, z);
rhobar = t * r;
auto t_squared = t * t;
const value_type t_dot_r = t * r;
const real_type t_squared = t * t;
if (t_squared < additional_data.breakdown)
{
return IterationResult(true, state, step, res);
}
omega = rhobar / (t * t);
Vx->add(alpha, y, omega, z);
omega = t_dot_r / t_squared;
x.add(alpha, y, omega, z);

if (additional_data.exact_residual)
{
r.add(-omega, t);
res = criterion(A, *Vx, *Vb);
res = criterion(A, x, b, t);
}
else
res = std::sqrt(r.add_and_dot(-omega, t, r));
res = std::sqrt(real_type(r.add_and_dot(-omega, t, r)));

state = this->iteration_status(step, res, *Vx);
print_vectors(step, *Vx, r, y);
state = this->iteration_status(step, res, x);
print_vectors(step, x, r, y);
}
while (state == SolverControl::iterate);

Expand All @@ -471,45 +407,13 @@ SolverBicgstab<VectorType>::solve(const MatrixType & A,
{
LogStream::Prefix prefix("Bicgstab");

// Allocate temporary memory.
Vr = typename VectorMemory<VectorType>::Pointer(this->memory);
Vrbar = typename VectorMemory<VectorType>::Pointer(this->memory);
Vp = typename VectorMemory<VectorType>::Pointer(this->memory);
Vy = typename VectorMemory<VectorType>::Pointer(this->memory);
Vz = typename VectorMemory<VectorType>::Pointer(this->memory);
Vt = typename VectorMemory<VectorType>::Pointer(this->memory);
Vv = typename VectorMemory<VectorType>::Pointer(this->memory);

Vr->reinit(x, true);
Vrbar->reinit(x, true);
Vp->reinit(x, true);
Vy->reinit(x, true);
Vz->reinit(x, true);
Vt->reinit(x, true);
Vv->reinit(x, true);

Vx = &x;
Vb = &b;

step = 0;

IterationResult state(false, SolverControl::failure, 0, 0);
do
{
state = iterate(A, preconditioner);
state = iterate(A, x, b, preconditioner, state.last_step);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Personally, I would suggest to inline this function (in a follow-up PR).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I actually tried to move everything into solve, but then I was discouraged by the various places where the iterate function might jump out and request a new attempt with a clean residual and no history, so I gave up. I am sure it can be further restructured, but I realized it would take more time than I wanted to spend (which is probably better spent once we do the bigger task of integrating similar tricks as in the SolverCG and need to restructure the code a bit more).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. Let's keep it like this for now.

}
while (state.state == SolverControl::iterate);


// Release the temporary memory again.
Vr.reset();
Vrbar.reset();
Vp.reset();
Vy.reset();
Vz.reset();
Vt.reset();
Vv.reset();

// In case of failure: throw exception
AssertThrow(state.state == SolverControl::success,
SolverControl::NoConvergence(state.last_step,
Expand Down
1 change: 0 additions & 1 deletion source/lac/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ SET(_unity_include_src
relaxation_block.cc
read_write_vector.cc
solver.cc
solver_bicgstab.cc
solver_control.cc
sparse_decomposition.cc
sparse_direct.cc
Expand Down