Skip to content

Commit

Permalink
Merge pull request #8140 from steppi/add-lambertw
Browse files Browse the repository at this point in the history
Add lambertw function
  • Loading branch information
asi1024 committed Jan 27, 2024
2 parents 64cc70d + 67d2d5e commit f5d8ed4
Show file tree
Hide file tree
Showing 9 changed files with 412 additions and 0 deletions.
10 changes: 10 additions & 0 deletions cupy/_core/include/cupy/special/.clang-format
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
BasedOnStyle: LLVM
Standard: Cpp11
UseTab: Never
IndentWidth: 4
BreakBeforeBraces: Attach
Cpp11BracedListStyle: true
NamespaceIndentation: Inner
AlwaysBreakTemplateDeclarations: true
SpaceAfterCStyleCast: true
ColumnLimit: 120
136 changes: 136 additions & 0 deletions cupy/_core/include/cupy/special/config.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
#pragma once

#ifdef __CUDACC__
// When compiling as CUDA (__CUDACC__ is defined), functions tagged with this
// macro are compiled for both the host and the device.
#define SPECFUN_HOST_DEVICE __host__ __device__

// Define the usual <math.h> M_* constants if they are not available.
// Each definition is guarded so that an existing definition always wins.
// Euler's number, e
#ifndef M_E
#define M_E 2.71828182845904523536
#endif

// log2(e)
#ifndef M_LOG2E
#define M_LOG2E 1.44269504088896340736
#endif

// log10(e)
#ifndef M_LOG10E
#define M_LOG10E 0.434294481903251827651
#endif

// ln(2)
#ifndef M_LN2
#define M_LN2 0.693147180559945309417
#endif

// ln(10)
#ifndef M_LN10
#define M_LN10 2.30258509299404568402
#endif

// pi
#ifndef M_PI
#define M_PI 3.14159265358979323846
#endif

// pi / 2
#ifndef M_PI_2
#define M_PI_2 1.57079632679489661923
#endif

// pi / 4
#ifndef M_PI_4
#define M_PI_4 0.785398163397448309616
#endif

// 1 / pi
#ifndef M_1_PI
#define M_1_PI 0.318309886183790671538
#endif

// 2 / pi
#ifndef M_2_PI
#define M_2_PI 0.636619772367581343076
#endif

// 2 / sqrt(pi)
#ifndef M_2_SQRTPI
#define M_2_SQRTPI 1.12837916709551257390
#endif

// sqrt(2)
#ifndef M_SQRT2
#define M_SQRT2 1.41421356237309504880
#endif

// 1 / sqrt(2)
#ifndef M_SQRT1_2
#define M_SQRT1_2 0.707106781186547524401
#endif

#include <cuda/std/cmath>
#include <cuda/std/limits>

// Fallback to global namespace for functions unsupported on NVRTC Jit
#ifdef _LIBCUDACXX_COMPILER_NVRTC
#include <cuda_runtime.h>
#endif

/* Shim that makes the std:: spellings used by the special-function headers
 * resolve to device-capable implementations when compiling as CUDA: real
 * math forwards to cuda::std, complex math forwards to thrust.
 */
namespace std {

// ---- double overloads, forwarded to cuda::std ----

SPECFUN_HOST_DEVICE inline double abs(double x) {
    return cuda::std::abs(x);
}

SPECFUN_HOST_DEVICE inline double exp(double x) {
    return cuda::std::exp(x);
}

SPECFUN_HOST_DEVICE inline double log(double x) {
    return cuda::std::log(x);
}

SPECFUN_HOST_DEVICE inline double sqrt(double x) {
    return cuda::std::sqrt(x);
}

SPECFUN_HOST_DEVICE inline bool isnan(double x) {
    return cuda::std::isnan(x);
}

SPECFUN_HOST_DEVICE inline bool isfinite(double x) {
    return cuda::std::isfinite(x);
}

SPECFUN_HOST_DEVICE inline double pow(double base, double exponent) {
    return cuda::std::pow(base, exponent);
}

SPECFUN_HOST_DEVICE inline double sin(double angle) {
    return cuda::std::sin(angle);
}

// floor and fma: fall back to the global-namespace functions when building
// with NVRTC, where the cuda::std versions are unsupported.
#ifdef _LIBCUDACXX_COMPILER_NVRTC
SPECFUN_HOST_DEVICE inline double floor(double value) {
    return ::floor(value);
}
SPECFUN_HOST_DEVICE inline double fma(double a, double b, double c) {
    return ::fma(a, b, c);
}
#else
SPECFUN_HOST_DEVICE inline double floor(double value) {
    return cuda::std::floor(value);
}
SPECFUN_HOST_DEVICE inline double fma(double a, double b, double c) {
    return cuda::std::fma(a, b, c);
}
#endif

// ---- type aliases ----

template <typename T>
using numeric_limits = cuda::std::numeric_limits<T>;

// Must use thrust for complex types in order to support CuPy
template <typename T>
using complex = thrust::complex<T>;

// ---- complex overloads, forwarded to thrust ----

template <typename T>
SPECFUN_HOST_DEVICE T abs(const complex<T> &value) {
    return thrust::abs(value);
}

template <typename T>
SPECFUN_HOST_DEVICE complex<T> exp(const complex<T> &value) {
    return thrust::exp(value);
}

template <typename T>
SPECFUN_HOST_DEVICE complex<T> log(const complex<T> &value) {
    return thrust::log(value);
}

template <typename T>
SPECFUN_HOST_DEVICE T norm(const complex<T> &value) {
    return thrust::norm(value);
}

template <typename T>
SPECFUN_HOST_DEVICE complex<T> sqrt(const complex<T> &value) {
    return thrust::sqrt(value);
}

} // namespace std

#else
#define SPECFUN_HOST_DEVICE

#include <cmath>
#include <complex>
#include <limits>
#include <math.h>

#endif
42 changes: 42 additions & 0 deletions cupy/_core/include/cupy/special/error.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
#pragma once

// should be included from config.h, but that won't work until we've cleanly separated out the C and C++ parts of the
// code
#ifdef __CUDACC__
#define SPECFUN_HOST_DEVICE __host__ __device__
#else
#define SPECFUN_HOST_DEVICE
#endif

#ifdef __cplusplus
extern "C" {
#endif

/* Error categories reported by the special-function code through
 * special::set_error(). Values start at 0 (SF_ERROR_OK) and increase by one
 * per entry; SF_ERROR__LAST is presumably an end-of-list sentinel rather than
 * a real error code (TODO confirm against callers).
 */
typedef enum {
SF_ERROR_OK = 0, /* no error */
SF_ERROR_SINGULAR, /* singularity encountered */
SF_ERROR_UNDERFLOW, /* floating point underflow */
SF_ERROR_OVERFLOW, /* floating point overflow */
SF_ERROR_SLOW, /* too many iterations required */
SF_ERROR_LOSS, /* loss of precision */
SF_ERROR_NO_RESULT, /* no result obtained */
SF_ERROR_DOMAIN, /* out of domain */
SF_ERROR_ARG, /* invalid input parameter */
SF_ERROR_OTHER, /* unclassified error */
SF_ERROR__LAST
} sf_error_t;

#ifdef __cplusplus
namespace special {

/* Error-reporting hook used by the special functions.
 *
 *   func_name  name of the function reporting the problem
 *   code       error category (sf_error_t)
 *   fmt, ...   printf-style message and its arguments
 *
 * Unless SP_SPECFUN_ERROR is defined, the hook is an inline no-op that
 * ignores all of its arguments. When SP_SPECFUN_ERROR is defined, only the
 * declaration is provided here and the definition must come from elsewhere.
 */
#ifndef SP_SPECFUN_ERROR
SPECFUN_HOST_DEVICE inline void set_error(const char *func_name, sf_error_t code, const char *fmt, ...) {
// nothing
}
#else
void set_error(const char *func_name, sf_error_t code, const char *fmt, ...);
#endif
} // namespace special

} // closes extern "C"
#endif
42 changes: 42 additions & 0 deletions cupy/_core/include/cupy/special/evalpoly.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
/* Evaluate polynomials.
*
* All of the coefficients are stored in reverse order, i.e. if the
* polynomial is
*
* u_n x^n + u_{n - 1} x^{n - 1} + ... + u_0,
*
* then coeffs[0] = u_n, coeffs[1] = u_{n - 1}, ..., coeffs[n] = u_0.
*
* References
* ----------
* [1] Knuth, "The Art of Computer Programming, Volume II"
*/

#pragma once

#include "config.h"

namespace special {

SPECFUN_HOST_DEVICE inline std::complex<double> cevalpoly(const double *coeffs, int degree, std::complex<double> z) {
    /* Evaluate a polynomial with real coefficients at a complex point.
     *
     * coeffs  polynomial coefficients, highest order first; must hold
     *         degree + 1 entries
     * degree  degree of the polynomial
     * z       complex point at which to evaluate
     *
     * Returns the value of the polynomial at z. A negative degree is treated
     * as the empty polynomial and yields 0.
     *
     * Uses equation (3) in section 4.6.4 of [1]. Note that it is more
     * efficient than Horner's method.
     */
    if (degree < 0) {
        // No coefficients at all: do not touch coeffs.
        return std::complex<double>(0.0, 0.0);
    }
    if (degree == 0) {
        // Constant polynomial: only coeffs[0] exists, so the recurrence
        // below (which reads coeffs[1]) must not run.
        return std::complex<double>(coeffs[0], 0.0);
    }

    double a = coeffs[0];
    double b = coeffs[1];
    const double r = 2 * z.real();  // z + conj(z)
    const double s = std::norm(z);  // z * conj(z)

    for (int j = 2; j < degree + 1; j++) {
        const double tmp = b;
        b = std::fma(-s, a, coeffs[j]);
        a = std::fma(r, a, tmp);
    }

    return z * a + b;
}

} // namespace special
141 changes: 141 additions & 0 deletions cupy/_core/include/cupy/special/lambertw.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
/* Implementation of the Lambert W function [1]. Based on MPMath
* Implementation [2], and documentation [3].
*
* Copyright: Yosef Meller, 2009
* Author email: mellerf@netvision.net.il
*
* Distributed under the same license as SciPy
* Translated into C++ by SciPy developers, 2023.
*
* References:
* [1] On the Lambert W function, Adv. Comp. Math. 5 (1996) 329-359,
* available online: https://web.archive.org/web/20230123211413/https://cs.uwaterloo.ca/research/tr/1993/03/W.pdf
* [2] mpmath source code,
*     https://github.com/mpmath/mpmath/blob/c5939823669e1bcce151d89261b802fe0d8978b4/mpmath/functions/functions.py#L435-L461
* [3] mpmath documentation of the Lambert W function,
*     https://web.archive.org/web/20230504171447/https://mpmath.org/doc/current/functions/powers.html#lambert-w-function
*
* TODO: use a series expansion when extremely close to the branch point
* at `-1/e` and make sure that the proper branch is chosen there.
*/

#pragma once

#include "config.h"
#include "error.h"
#include "evalpoly.h"

namespace special {
// exp(-1). lambertw() compares |z + EXPN1| against 0.3 to decide whether z is
// close enough to the branch point at -1/e to use the branch-point series.
constexpr double EXPN1 = 0.36787944117144232159553; // exp(-1)
// Value of W on branch 0 at z = 1; lambertw() returns it directly for
// z == 1, k == 0.
constexpr double OMEGA = 0.56714329040978387299997; // W(1, 0)

namespace detail {
SPECFUN_HOST_DEVICE inline std::complex<double> lambertw_branchpt(std::complex<double> z) {
    /* Approximate W(z, 0) near the branch point with the truncated series
     *
     *     -1 + p - p^2 / 3,    p = sqrt(2 * (e * z + 1)),
     *
     * see 4.22 in [1].
     */
    const std::complex<double> p = std::sqrt(2.0 * (M_E * z + 1.0));

    // Highest-order coefficient first, matching cevalpoly's convention.
    const double series[] = {-1.0 / 3.0, 1.0, -1.0};
    return cevalpoly(series, 2, p);
}

SPECFUN_HOST_DEVICE inline std::complex<double> lambertw_pade0(std::complex<double> z) {
    /* (3, 2) Pade approximation for W(z, 0) around 0: z times a quadratic,
     * divided by a quadratic. Coefficients are listed highest order first.
     */
    const double numer[] = {12.85106382978723404255, 12.34042553191489361902, 1.0};
    const double denom[] = {32.53191489361702127660, 14.34042553191489361702, 1.0};

    /* No overflow-avoiding rearrangement of the numerator is needed for
     * large z, because this is only evaluated close to 0. */
    const std::complex<double> top = z * cevalpoly(numer, 2, z);
    const std::complex<double> bottom = cevalpoly(denom, 2, z);
    return top / bottom;
}

SPECFUN_HOST_DEVICE inline std::complex<double> lambertw_asy(std::complex<double> z, long k) {
    /* First two terms of the asymptotic series for W(z, k); see 4.20 in [1].
     *
     * With L = log(z) + 2 * pi * k * i, the result is L - log(L).
     */
    const std::complex<double> imag_unit(0, 1);
    const std::complex<double> lead = std::log(z) + 2.0 * M_PI * k * imag_unit;
    return lead - std::log(lead);
}

} // namespace detail

/* Lambert W function of a complex argument.
 *
 *   z    point at which to evaluate W
 *   k    branch index
 *   tol  relative tolerance used as the stopping criterion of the iteration
 *
 * Special inputs (NaN, infinite real part, z == 0, and z == 1 on branch 0)
 * are answered directly. Otherwise an initial guess is chosen according to
 * the branch and the location of z, and refined with Halley's method for at
 * most 100 iterations.
 *
 * Returns the computed value of W(z, k). If the iteration does not meet the
 * tolerance, set_error is called with SF_ERROR_SLOW and NaN + NaN*j is
 * returned. For z == 0 with k != 0, set_error is called with
 * SF_ERROR_SINGULAR and -infinity is returned.
 */
SPECFUN_HOST_DEVICE inline std::complex<double> lambertw(std::complex<double> z, long k, double tol) {
double absz;
std::complex<double> w;
std::complex<double> ew, wew, wewz, wn;

// NaN in either component: hand the input back unchanged.
if (std::isnan(z.real()) || std::isnan(z.imag())) {
return z;
}
// Infinite real part: return closed-form values instead of iterating.
if (z.real() == std::numeric_limits<double>::infinity()) {
return z + 2.0 * M_PI * k * std::complex<double>(0, 1);
}
if (z.real() == -std::numeric_limits<double>::infinity()) {
return -z + (2.0 * M_PI * k + M_PI) * std::complex<double>(0, 1);
}
// z == 0: branch 0 returns z itself; every other branch is reported as
// singular and returns -infinity.
if (z == 0.0) {
if (k == 0) {
return z;
}
set_error("lambertw", SF_ERROR_SINGULAR, NULL);
return -std::numeric_limits<double>::infinity();
}
if (z == 1.0 && k == 0) {
// Split out this case because the asymptotic series blows up
return OMEGA;
}

absz = std::abs(z);
// Get an initial guess for Halley's method
if (k == 0) {
// Branch 0: branch-point series within 0.3 of -1/e, Pade approximant in
// a region around 0, asymptotic series everywhere else.
if (std::abs(z + EXPN1) < 0.3) {
w = detail::lambertw_branchpt(z);
} else if (-1.0 < z.real() && z.real() < 1.5 && std::abs(z.imag()) < 1.0 &&
-2.5 * std::abs(z.imag()) - 0.2 < z.real()) {
/* Empirically determined decision boundary where the Pade
* approximation is more accurate. */
w = detail::lambertw_pade0(z);
} else {
w = detail::lambertw_asy(z, k);
}
} else if (k == -1) {
// Branch -1: for real z with -1/e <= z < 0 start from log(-z), otherwise
// from the asymptotic series.
if (absz <= EXPN1 && z.imag() == 0.0 && z.real() < 0.0) {
w = std::log(-z.real());
} else {
w = detail::lambertw_asy(z, k);
}
} else {
// All other branches start from the asymptotic series.
w = detail::lambertw_asy(z, k);
}

// Halley's method; see 5.9 in [1]
// The form of the update is selected once, from the sign of the real part
// of the initial guess, and then kept for all iterations.
if (w.real() >= 0) {
// Rearrange the formula to avoid overflow in exp
for (int i = 0; i < 100; i++) {
ew = std::exp(-w);
wewz = w - z * ew;
wn = w - wewz / (w + 1.0 - (w + 2.0) * wewz / (2.0 * w + 2.0));
// Converged once the step is small relative to the new iterate.
if (std::abs(wn - w) <= tol * std::abs(wn)) {
return wn;
}
w = wn;
}
} else {
for (int i = 0; i < 100; i++) {
ew = std::exp(w);
wew = w * ew;
wewz = wew - z;
wn = w - wewz / (wew + ew - (w + 2.0) * wewz / (2.0 * w + 2.0));
// Converged once the step is small relative to the new iterate.
if (std::abs(wn - w) <= tol * std::abs(wn)) {
return wn;
}
w = wn;
}
}

// Neither loop returned: the tolerance was not met within 100 iterations.
set_error("lambertw", SF_ERROR_SLOW, "iteration failed to converge: %g + %gj", z.real(), z.imag());
return std::complex<double>(std::numeric_limits<double>::quiet_NaN(), std::numeric_limits<double>::quiet_NaN());
}

} // namespace special
1 change: 1 addition & 0 deletions cupyx/scipy/special/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@
from cupyx.scipy.special._logsoftmax import log_softmax # NOQA
from cupyx.scipy.special._zeta import zeta # NOQA
from cupyx.scipy.special._zetac import zetac # NOQA
from cupyx.scipy.special._lambertw import lambertw # NOQA

# Convenience functions
from cupyx.scipy.special._basic import cbrt # NOQA
Expand Down

0 comments on commit f5d8ed4

Please sign in to comment.