-
-
Notifications
You must be signed in to change notification settings - Fork 782
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #8140 from steppi/add-lambertw
Add lambertw function
- Loading branch information
Showing
9 changed files
with
412 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
BasedOnStyle: LLVM | ||
Standard: Cpp11 | ||
UseTab: Never | ||
IndentWidth: 4 | ||
BreakBeforeBraces: Attach | ||
Cpp11BracedListStyle: true | ||
NamespaceIndentation: Inner | ||
AlwaysBreakTemplateDeclarations: true | ||
SpaceAfterCStyleCast: true | ||
ColumnLimit: 120 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,136 @@ | ||
#pragma once | ||
|
||
#ifdef __CUDACC__ | ||
#define SPECFUN_HOST_DEVICE __host__ __device__ | ||
|
||
// Define math constants if they are not available | ||
#ifndef M_E | ||
#define M_E 2.71828182845904523536 | ||
#endif | ||
|
||
#ifndef M_LOG2E | ||
#define M_LOG2E 1.44269504088896340736 | ||
#endif | ||
|
||
#ifndef M_LOG10E | ||
#define M_LOG10E 0.434294481903251827651 | ||
#endif | ||
|
||
#ifndef M_LN2 | ||
#define M_LN2 0.693147180559945309417 | ||
#endif | ||
|
||
#ifndef M_LN10 | ||
#define M_LN10 2.30258509299404568402 | ||
#endif | ||
|
||
#ifndef M_PI | ||
#define M_PI 3.14159265358979323846 | ||
#endif | ||
|
||
#ifndef M_PI_2 | ||
#define M_PI_2 1.57079632679489661923 | ||
#endif | ||
|
||
#ifndef M_PI_4 | ||
#define M_PI_4 0.785398163397448309616 | ||
#endif | ||
|
||
#ifndef M_1_PI | ||
#define M_1_PI 0.318309886183790671538 | ||
#endif | ||
|
||
#ifndef M_2_PI | ||
#define M_2_PI 0.636619772367581343076 | ||
#endif | ||
|
||
#ifndef M_2_SQRTPI | ||
#define M_2_SQRTPI 1.12837916709551257390 | ||
#endif | ||
|
||
#ifndef M_SQRT2 | ||
#define M_SQRT2 1.41421356237309504880 | ||
#endif | ||
|
||
#ifndef M_SQRT1_2 | ||
#define M_SQRT1_2 0.707106781186547524401 | ||
#endif | ||
|
||
#include <cuda/std/cmath> | ||
#include <cuda/std/limits> | ||
|
||
// Fallback to global namespace for functions unsupported on NVRTC Jit | ||
#ifdef _LIBCUDACXX_COMPILER_NVRTC | ||
#include <cuda_runtime.h> | ||
#endif | ||
|
||
namespace std { | ||
|
||
SPECFUN_HOST_DEVICE inline double abs(double num) { return cuda::std::abs(num); } | ||
|
||
SPECFUN_HOST_DEVICE inline double exp(double num) { return cuda::std::exp(num); } | ||
|
||
SPECFUN_HOST_DEVICE inline double log(double num) { return cuda::std::log(num); } | ||
|
||
SPECFUN_HOST_DEVICE inline double sqrt(double num) { return cuda::std::sqrt(num); } | ||
|
||
SPECFUN_HOST_DEVICE inline bool isnan(double num) { return cuda::std::isnan(num); } | ||
|
||
SPECFUN_HOST_DEVICE inline bool isfinite(double num) { return cuda::std::isfinite(num); } | ||
|
||
SPECFUN_HOST_DEVICE inline double pow(double x, double y) { return cuda::std::pow(x, y); } | ||
|
||
SPECFUN_HOST_DEVICE inline double sin(double x) { return cuda::std::sin(x); } | ||
|
||
// Fallback to global namespace for functions unsupported on NVRTC | ||
#ifndef _LIBCUDACXX_COMPILER_NVRTC | ||
SPECFUN_HOST_DEVICE inline double floor(double x) { return cuda::std::floor(x); } | ||
SPECFUN_HOST_DEVICE inline double fma(double x, double y, double z) { return cuda::std::fma(x, y, z); } | ||
#else | ||
SPECFUN_HOST_DEVICE inline double floor(double x) { return ::floor(x); } | ||
SPECFUN_HOST_DEVICE inline double fma(double x, double y, double z) { return ::fma(x, y, z); } | ||
#endif | ||
|
||
template <typename T> | ||
using numeric_limits = cuda::std::numeric_limits<T>; | ||
|
||
// Must use thrust for complex types in order to support CuPy | ||
template <typename T> | ||
using complex = thrust::complex<T>; | ||
|
||
template <typename T> | ||
SPECFUN_HOST_DEVICE T abs(const complex<T> &z) { | ||
return thrust::abs(z); | ||
} | ||
|
||
template <typename T> | ||
SPECFUN_HOST_DEVICE complex<T> exp(const complex<T> &z) { | ||
return thrust::exp(z); | ||
} | ||
|
||
template <typename T> | ||
SPECFUN_HOST_DEVICE complex<T> log(const complex<T> &z) { | ||
return thrust::log(z); | ||
} | ||
|
||
template <typename T> | ||
SPECFUN_HOST_DEVICE T norm(const complex<T> &z) { | ||
return thrust::norm(z); | ||
} | ||
|
||
template <typename T> | ||
SPECFUN_HOST_DEVICE complex<T> sqrt(const complex<T> &z) { | ||
return thrust::sqrt(z); | ||
} | ||
|
||
} // namespace std | ||
|
||
#else | ||
#define SPECFUN_HOST_DEVICE | ||
|
||
#include <cmath> | ||
#include <complex> | ||
#include <limits> | ||
#include <math.h> | ||
|
||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
#pragma once | ||
|
||
// should be included from config.h, but that won't work until we've cleanly separated out the C and C++ parts of the | ||
// code | ||
#ifdef __CUDACC__ | ||
#define SPECFUN_HOST_DEVICE __host__ __device__ | ||
#else | ||
#define SPECFUN_HOST_DEVICE | ||
#endif | ||
|
||
#ifdef __cplusplus | ||
extern "C" { | ||
#endif | ||
|
||
typedef enum { | ||
SF_ERROR_OK = 0, /* no error */ | ||
SF_ERROR_SINGULAR, /* singularity encountered */ | ||
SF_ERROR_UNDERFLOW, /* floating point underflow */ | ||
SF_ERROR_OVERFLOW, /* floating point overflow */ | ||
SF_ERROR_SLOW, /* too many iterations required */ | ||
SF_ERROR_LOSS, /* loss of precision */ | ||
SF_ERROR_NO_RESULT, /* no result obtained */ | ||
SF_ERROR_DOMAIN, /* out of domain */ | ||
SF_ERROR_ARG, /* invalid input parameter */ | ||
SF_ERROR_OTHER, /* unclassified error */ | ||
SF_ERROR__LAST | ||
} sf_error_t; | ||
|
||
#ifdef __cplusplus | ||
namespace special { | ||
|
||
#ifndef SP_SPECFUN_ERROR | ||
SPECFUN_HOST_DEVICE inline void set_error(const char *func_name, sf_error_t code, const char *fmt, ...) { | ||
// nothing | ||
} | ||
#else | ||
void set_error(const char *func_name, sf_error_t code, const char *fmt, ...); | ||
#endif | ||
} // namespace special | ||
|
||
} // closes extern "C" | ||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
/* Evaluate polynomials. | ||
* | ||
* All of the coefficients are stored in reverse order, i.e. if the | ||
* polynomial is | ||
* | ||
* u_n x^n + u_{n - 1} x^{n - 1} + ... + u_0, | ||
* | ||
* then coeffs[0] = u_n, coeffs[1] = u_{n - 1}, ..., coeffs[n] = u_0. | ||
* | ||
* References | ||
* ---------- | ||
* [1] Knuth, "The Art of Computer Programming, Volume II" | ||
*/ | ||
|
||
#pragma once | ||
|
||
#include "config.h" | ||
|
||
namespace special { | ||
|
||
SPECFUN_HOST_DEVICE inline std::complex<double> cevalpoly(const double *coeffs, int degree, std::complex<double> z) { | ||
/* Evaluate a polynomial with real coefficients at a complex point. | ||
* | ||
* Uses equation (3) in section 4.6.4 of [1]. Note that it is more | ||
* efficient than Horner's method. | ||
*/ | ||
double a = coeffs[0]; | ||
double b = coeffs[1]; | ||
double r = 2 * z.real(); | ||
double s = std::norm(z); | ||
double tmp; | ||
|
||
for (int j = 2; j < degree + 1; j++) { | ||
tmp = b; | ||
b = std::fma(-s, a, coeffs[j]); | ||
a = std::fma(r, a, tmp); | ||
} | ||
|
||
return z * a + b; | ||
} | ||
|
||
} // namespace special |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,141 @@ | ||
/* Implementation of the Lambert W function [1]. Based on MPMath | ||
* Implementation [2], and documentation [3]. | ||
* | ||
* Copyright: Yosef Meller, 2009 | ||
* Author email: mellerf@netvision.net.il | ||
* | ||
* Distributed under the same license as SciPy | ||
* Translated into C++ by SciPy developers, 2023. | ||
* | ||
* References: | ||
* [1] On the Lambert W function, Adv. Comp. Math. 5 (1996) 329-359, | ||
* available online: https://web.archive.org/web/20230123211413/https://cs.uwaterloo.ca/research/tr/1993/03/W.pdf | ||
* [2] mpmath source code, | ||
https://github.com/mpmath/mpmath/blob/c5939823669e1bcce151d89261b802fe0d8978b4/mpmath/functions/functions.py#L435-L461 | ||
* [3] | ||
https://web.archive.org/web/20230504171447/https://mpmath.org/doc/current/functions/powers.html#lambert-w-function | ||
* | ||
* TODO: use a series expansion when extremely close to the branch point | ||
* at `-1/e` and make sure that the proper branch is chosen there. | ||
*/ | ||
|
||
#pragma once | ||
|
||
#include "config.h" | ||
#include "error.h" | ||
#include "evalpoly.h" | ||
|
||
namespace special { | ||
constexpr double EXPN1 = 0.36787944117144232159553; // exp(-1) | ||
constexpr double OMEGA = 0.56714329040978387299997; // W(1, 0) | ||
|
||
namespace detail { | ||
SPECFUN_HOST_DEVICE inline std::complex<double> lambertw_branchpt(std::complex<double> z) { | ||
// Series for W(z, 0) around the branch point; see 4.22 in [1]. | ||
double coeffs[] = {-1.0 / 3.0, 1.0, -1.0}; | ||
std::complex<double> p = std::sqrt(2.0 * (M_E * z + 1.0)); | ||
|
||
return cevalpoly(coeffs, 2, p); | ||
} | ||
|
||
SPECFUN_HOST_DEVICE inline std::complex<double> lambertw_pade0(std::complex<double> z) { | ||
// (3, 2) Pade approximation for W(z, 0) around 0. | ||
double num[] = {12.85106382978723404255, 12.34042553191489361902, 1.0}; | ||
double denom[] = {32.53191489361702127660, 14.34042553191489361702, 1.0}; | ||
|
||
/* This only gets evaluated close to 0, so we don't need a more | ||
* careful algorithm that avoids overflow in the numerator for | ||
* large z. */ | ||
return z * cevalpoly(num, 2, z) / cevalpoly(denom, 2, z); | ||
} | ||
|
||
SPECFUN_HOST_DEVICE inline std::complex<double> lambertw_asy(std::complex<double> z, long k) { | ||
/* Compute the W function using the first two terms of the | ||
* asymptotic series. See 4.20 in [1]. | ||
*/ | ||
std::complex<double> w = std::log(z) + 2.0 * M_PI * k * std::complex<double>(0, 1); | ||
return w - std::log(w); | ||
} | ||
|
||
} // namespace detail | ||
|
||
SPECFUN_HOST_DEVICE inline std::complex<double> lambertw(std::complex<double> z, long k, double tol) { | ||
double absz; | ||
std::complex<double> w; | ||
std::complex<double> ew, wew, wewz, wn; | ||
|
||
if (std::isnan(z.real()) || std::isnan(z.imag())) { | ||
return z; | ||
} | ||
if (z.real() == std::numeric_limits<double>::infinity()) { | ||
return z + 2.0 * M_PI * k * std::complex<double>(0, 1); | ||
} | ||
if (z.real() == -std::numeric_limits<double>::infinity()) { | ||
return -z + (2.0 * M_PI * k + M_PI) * std::complex<double>(0, 1); | ||
} | ||
if (z == 0.0) { | ||
if (k == 0) { | ||
return z; | ||
} | ||
set_error("lambertw", SF_ERROR_SINGULAR, NULL); | ||
return -std::numeric_limits<double>::infinity(); | ||
} | ||
if (z == 1.0 && k == 0) { | ||
// Split out this case because the asymptotic series blows up | ||
return OMEGA; | ||
} | ||
|
||
absz = std::abs(z); | ||
// Get an initial guess for Halley's method | ||
if (k == 0) { | ||
if (std::abs(z + EXPN1) < 0.3) { | ||
w = detail::lambertw_branchpt(z); | ||
} else if (-1.0 < z.real() && z.real() < 1.5 && std::abs(z.imag()) < 1.0 && | ||
-2.5 * std::abs(z.imag()) - 0.2 < z.real()) { | ||
/* Empirically determined decision boundary where the Pade | ||
* approximation is more accurate. */ | ||
w = detail::lambertw_pade0(z); | ||
} else { | ||
w = detail::lambertw_asy(z, k); | ||
} | ||
} else if (k == -1) { | ||
if (absz <= EXPN1 && z.imag() == 0.0 && z.real() < 0.0) { | ||
w = std::log(-z.real()); | ||
} else { | ||
w = detail::lambertw_asy(z, k); | ||
} | ||
} else { | ||
w = detail::lambertw_asy(z, k); | ||
} | ||
|
||
// Halley's method; see 5.9 in [1] | ||
if (w.real() >= 0) { | ||
// Rearrange the formula to avoid overflow in exp | ||
for (int i = 0; i < 100; i++) { | ||
ew = std::exp(-w); | ||
wewz = w - z * ew; | ||
wn = w - wewz / (w + 1.0 - (w + 2.0) * wewz / (2.0 * w + 2.0)); | ||
if (std::abs(wn - w) <= tol * std::abs(wn)) { | ||
return wn; | ||
} | ||
w = wn; | ||
} | ||
} else { | ||
for (int i = 0; i < 100; i++) { | ||
ew = std::exp(w); | ||
wew = w * ew; | ||
wewz = wew - z; | ||
wn = w - wewz / (wew + ew - (w + 2.0) * wewz / (2.0 * w + 2.0)); | ||
if (std::abs(wn - w) <= tol * std::abs(wn)) { | ||
return wn; | ||
} | ||
w = wn; | ||
} | ||
} | ||
|
||
set_error("lambertw", SF_ERROR_SLOW, "iteration failed to converge: %g + %gj", z.real(), z.imag()); | ||
return std::complex<double>(std::numeric_limits<double>::quiet_NaN(), std::numeric_limits<double>::quiet_NaN()); | ||
} | ||
|
||
} // namespace special |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.