Skip to content

Commit

Permalink
new CubicSpline interface (#2800)
Browse files Browse the repository at this point in the history
Co-authored-by: Wenfei Li <38569667+wenfei-li@users.noreply.github.com>
  • Loading branch information
jinzx10 and wenfei-li committed Aug 8, 2023
1 parent 74a5bf2 commit 2dc1ec3
Show file tree
Hide file tree
Showing 4 changed files with 314 additions and 137 deletions.
217 changes: 151 additions & 66 deletions source/module_base/cubic_spline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,40 @@
namespace ModuleBase
{

CubicSpline::~CubicSpline()
CubicSpline::CubicSpline(CubicSpline const& other)
{
cleanup();
n_ = other.n_;
is_uniform_ = other.is_uniform_;
uniform_thr_ = other.uniform_thr_;

x_ = new double[n_];
y_ = new double[n_];
s_ = new double[n_];

std::memcpy(x_, other.x_, n_ * sizeof(double));
std::memcpy(y_, other.y_, n_ * sizeof(double));
std::memcpy(s_, other.s_, n_ * sizeof(double));
}

CubicSpline& CubicSpline::operator=(CubicSpline const& other)
{
if (this != &other)
{
cleanup();
n_ = other.n_;
is_uniform_ = other.is_uniform_;
uniform_thr_ = other.uniform_thr_;

x_ = new double[n_];
y_ = new double[n_];
s_ = new double[n_];

std::memcpy(x_, other.x_, n_ * sizeof(double));
std::memcpy(y_, other.y_, n_ * sizeof(double));
std::memcpy(s_, other.s_, n_ * sizeof(double));
}

return *this;
}

void CubicSpline::cleanup()
Expand All @@ -25,13 +56,12 @@ void CubicSpline::cleanup()
s_ = nullptr;
}

void CubicSpline::sanity_check(const int n,
const double* const x,
const double* const y,
BoundaryCondition bc_start,
BoundaryCondition bc_end)
void CubicSpline::check_build(const int n,
const double* const x,
const double* const y,
BoundaryCondition bc_start,
BoundaryCondition bc_end)
{

assert(n > 1);

// periodic boundary condition must apply to both ends
Expand All @@ -49,40 +79,48 @@ void CubicSpline::sanity_check(const int n,
}
}

void CubicSpline::check_interp(const int n,
const double* const x,
const double* const y,
const double* const s,
const int n_interp,
const double* const x_interp,
double* const y_interp,
double* const dy_interp)
{
assert(n > 1 && x && y && s); // make sure the interpolant exists
assert(n_interp > 0 && x_interp); // make sure the interpolation points exist
assert(y_interp || dy_interp); // make sure at least one of y or dy is not null

// check that x_interp is in the range of the interpolant
assert(std::all_of(x_interp, x_interp + n_interp,
[n, x](const double xi) { return xi >= x[0] && xi <= x[n - 1]; }));
}

void CubicSpline::build(const int n,
const double* const x,
const double* const y,
double* const s,
BoundaryCondition bc_start,
BoundaryCondition bc_end,
const double deriv_start,
const double deriv_end)
{
sanity_check(n, x, y, bc_start, bc_end);

cleanup();

n_ = n;
x_ = new double[n];
y_ = new double[n];
std::memcpy(x_, x, sizeof(double) * n);
std::memcpy(y_, y, sizeof(double) * n);

// to be computed
s_ = new double[n];
check_build(n, x, y, bc_start, bc_end);

if (n == 2 && bc_start == BoundaryCondition::periodic)
{ // in this case the polynomial is a constant
s_[0] = s_[1] = 0.0;
s[0] = s[1] = 0.0;
}
else if (n == 3 && bc_start == BoundaryCondition::not_a_knot && bc_end == BoundaryCondition::not_a_knot)
{ // in this case two conditions coincide; simply build a parabola that passes through the three data points
double idx10 = 1. / (x[1] - x[0]);
double idx21 = 1. / (x[2] - x[1]);
double idx20 = 1. / (x[2] - x[0]);

s_[0] = -y[0] * (idx10 + idx20) + y[1] * (idx21 + idx10) + y[2] * (idx20 - idx21);
s_[1] = -y[1] * (-idx10 + idx21) + y[0] * (idx20 - idx10) + y[2] * (idx21 - idx20);
s_[2] = s_[1] + 2.0 * (-y[1] * idx10 + y[2] * idx20) + 2.0 * y[0] * idx10 * idx20 * (x[2] - x[1]);
s[0] = -y[0] * (idx10 + idx20) + y[1] * (idx21 + idx10) + y[2] * (idx20 - idx21);
s[1] = -y[1] * (-idx10 + idx21) + y[0] * (idx20 - idx10) + y[2] * (idx21 - idx20);
s[2] = s[1] + 2.0 * (-y[1] * idx10 + y[2] * idx20) + 2.0 * y[0] * idx10 * idx20 * (x[2] - x[1]);
}
else
{
Expand All @@ -94,18 +132,10 @@ void CubicSpline::build(const int n,
double* subdiag = new double[n - 1];
double* supdiag = new double[n - 1];

is_uniform_ = true;
double dx_avg = (x[n - 1] - x[0]) / (n - 1);

for (int i = 0; i != n - 1; ++i)
{
dx[i] = x[i + 1] - x[i];
dd[i] = (y[i + 1] - y[i]) / dx[i];

if (std::abs(dx[i] - dx_avg) > uniform_thr_)
{
is_uniform_ = false;
}
}

// common part of the tridiagonal linear system
Expand All @@ -114,20 +144,18 @@ void CubicSpline::build(const int n,
diag[i] = 2.0 * (dx[i - 1] + dx[i]);
supdiag[i] = dx[i - 1];
subdiag[i - 1] = dx[i];
s_[i] = 3.0 * (dd[i - 1] * dx[i] + dd[i] * dx[i - 1]);
s[i] = 3.0 * (dd[i - 1] * dx[i] + dd[i] * dx[i - 1]);
}

if (bc_start == BoundaryCondition::periodic)
{

// exclude s[n-1] and solve a a cyclic tridiagonal linear system of size n-1
diag[0] = 2.0 * (dx[n - 2] + dx[0]);
supdiag[0] = dx[n - 2];
subdiag[n - 2] = dx[0];
s_[0] = 3.0 * (dd[0] * dx[n - 2] + dd[n - 2] * dx[0]);
;
solve_cyctri(n - 1, diag, supdiag, subdiag, s_);
s_[n - 1] = s_[0];
s[0] = 3.0 * (dd[0] * dx[n - 2] + dd[n - 2] * dx[0]);
solve_cyctri(n - 1, diag, supdiag, subdiag, s);
s[n - 1] = s[0];
}
else
{
Expand All @@ -136,35 +164,35 @@ void CubicSpline::build(const int n,
case BoundaryCondition::first_deriv:
diag[0] = 1.0 * dx[0];
supdiag[0] = 0.0;
s_[0] = deriv_start * dx[0];
s[0] = deriv_start * dx[0];
break;
case BoundaryCondition::second_deriv:
diag[0] = 2.0 * dx[0];
supdiag[0] = 1.0 * dx[0];
s_[0] = (3.0 * dd[0] - 0.5 * deriv_start * dx[0]) * dx[0];
s[0] = (3.0 * dd[0] - 0.5 * deriv_start * dx[0]) * dx[0];
break;
default: // BoundaryCondition::not_a_knot
diag[0] = dx[1];
supdiag[0] = x[2] - x[0];
s_[0] = (dd[0] * dx[1] * (dx[0] + 2 * (x[2] - x[0])) + dd[1] * dx[0] * dx[0]) / (x[2] - x[0]);
s[0] = (dd[0] * dx[1] * (dx[0] + 2 * (x[2] - x[0])) + dd[1] * dx[0] * dx[0]) / (x[2] - x[0]);
}

switch (bc_end)
{
case BoundaryCondition::first_deriv:
diag[n - 1] = 1.0 * dx[n - 2];
subdiag[n - 2] = 0.0;
s_[n - 1] = deriv_end * dx[n - 2];
s[n - 1] = deriv_end * dx[n - 2];
break;
case BoundaryCondition::second_deriv:
diag[n - 1] = 2.0 * dx[n - 2];
subdiag[n - 2] = 1.0 * dx[n - 2];
s_[n - 1] = (3.0 * dd[n - 2] + 0.5 * deriv_end * dx[n - 2]) * dx[n - 2];
s[n - 1] = (3.0 * dd[n - 2] + 0.5 * deriv_end * dx[n - 2]) * dx[n - 2];
break;
default: // BoundaryCondition::not_a_knot
diag[n - 1] = dx[n - 3];
subdiag[n - 2] = x[n - 1] - x[n - 3];
s_[n - 1] = (dd[n - 2] * dx[n - 3] * (dx[n - 2] + 2 * (x[n - 1] - x[n - 3]))
s[n - 1] = (dd[n - 2] * dx[n - 3] * (dx[n - 2] + 2 * (x[n - 1] - x[n - 3]))
+ dd[n - 3] * dx[n - 2] * dx[n - 2])
/ (x[n - 1] - x[n - 3]);
}
Expand All @@ -174,7 +202,7 @@ void CubicSpline::build(const int n,
int INFO = 0;
int N = n;

dgtsv_(&N, &NRHS, subdiag, diag, supdiag, s_, &LDB, &INFO);
dgtsv_(&N, &NRHS, subdiag, diag, supdiag, s, &LDB, &INFO);
}

delete[] diag;
Expand All @@ -185,41 +213,98 @@ void CubicSpline::build(const int n,
}
}

void CubicSpline::interp(const int n, const double* const x, double* const y, double* const dy)
void CubicSpline::build(const int n,
const double* const x,
const double* const y,
BoundaryCondition bc_start,
BoundaryCondition bc_end,
const double deriv_start,
const double deriv_end)
{
assert(x_ && y_ && s_); // make sure the interpolant exists
assert(y || dy); // make sure at least one of y or dy is not null
cleanup();

// check that x is in the range of the interpolant
assert(std::all_of(x, x + n, [this](double x) -> bool { return x >= x_[0] && x <= x_[n_ - 1]; }));
n_ = n;
x_ = new double[n];
y_ = new double[n];
s_ = new double[n];

std::function<int(double)> search;
if (is_uniform_)
std::memcpy(x_, x, sizeof(double) * n);
std::memcpy(y_, y, sizeof(double) * n);
build(n_, x_, y_, s_, bc_start, bc_end, deriv_start, deriv_end);

double dx_avg = (x[n - 1] - x[0]) / (n - 1);
is_uniform_ = std::all_of(x, x + n,
[dx_avg, &x](const double& xi) -> bool { return std::abs(xi - (&xi - x) * dx_avg) < 1e-15; });
}

void CubicSpline::eval(const int n,
const double* const x,
const double* const y,
const double* const s,
const int n_interp,
const double* const x_interp,
double* const y_interp,
double* const dy_interp)
{
check_interp(n, x, y, s, n_interp, x_interp, y_interp, dy_interp);
_eval(x, y, s, n_interp, x_interp, y_interp, dy_interp, _gen_search(n, x));
}

void CubicSpline::eval(const int n_interp,
const double* const x_interp,
double* const y_interp,
double* const dy_interp)
{
check_interp(n_, x_, y_, s_, n_interp, x_interp, y_interp, dy_interp);
_eval(x_, y_, s_, n_interp, x_interp, y_interp, dy_interp, _gen_search(n_, x_, is_uniform_));
}

std::function<int(double)> CubicSpline::_gen_search(const int n, const double* const x, int is_uniform)
{
if (is_uniform != 0 && is_uniform != 1)
{
double dx = x_[1] - x_[0];
search = [this, dx](double x) -> int { return x == x_[n_ - 1] ? n_ - 2 : x / dx; };
double dx_avg = (x[n - 1] - x[0]) / (n - 1);
is_uniform = std::all_of(x, x + n,
[dx_avg, &x] (const double& xi) { return std::abs(xi - (&xi - x) * dx_avg) < 1e-15; });
}

if (is_uniform)
{
double dx = x[1] - x[0];
return [dx, n, x](double xi) -> int { return xi == x[n - 1] ? n - 2 : xi / dx; };
}
else
{
search = [this](double x) -> int { return (std::upper_bound(x_, x_ + n_, x) - x_) - 1; };
return [n, x](double xi) -> int { return (std::upper_bound(x, x + n, xi) - x) - 1; };
}
}

void (*eval)(double, double, double, double, double, double*, double*) = y ? (dy ? _eval_y_dy : _eval_y) : _eval_dy;
void CubicSpline::_eval(const double* const x,
const double* const y,
const double* const s,
const int n_interp,
const double* const x_interp,
double* const y_interp,
double* const dy_interp,
std::function<int(double)> search)
{
void (*poly_eval)(double, double, double, double, double, double*, double*) =
y_interp ? (dy_interp ? _poly_eval<true, true>: _poly_eval<true, false>) : _poly_eval<false, true>;

for (int i = 0; i != n; ++i)
for (int i = 0; i != n_interp; ++i)
{
int p = search(x[i]);
double w = x[i] - x_[p];
int p = search(x_interp[i]);
double w = x_interp[i] - x[p];

double dx = x_[p + 1] - x_[p];
double dd = (y_[p + 1] - y_[p]) / dx;
double dx = x[p + 1] - x[p];
double dd = (y[p + 1] - y[p]) / dx;

double c0 = y_[p];
double c1 = s_[p];
double c3 = (s_[p] + s_[p + 1] - 2.0 * dd) / (dx * dx);
double c2 = (dd - s_[p]) / dx - c3 * dx;
double c0 = y[p];
double c1 = s[p];
double c3 = (s[p] + s[p + 1] - 2.0 * dd) / (dx * dx);
double c2 = (dd - s[p]) / dx - c3 * dx;

eval(w, c0, c1, c2, c3, y + i, dy + i);
poly_eval(w, c0, c1, c2, c3, y_interp + i, dy_interp + i);
}
}

Expand Down
Loading

0 comments on commit 2dc1ec3

Please sign in to comment.