Implement fast tanh #6

Open · wants to merge 7 commits into master

Conversation

@vks (Contributor) commented Jan 8, 2019

Partially addresses #1.

@codecov-io commented Jan 13, 2019

Codecov Report

Merging #6 into master will increase coverage by 0.36%.
The diff coverage is 95.69%.


@@            Coverage Diff             @@
##           master       #6      +/-   ##
==========================================
+ Coverage   94.35%   94.72%   +0.36%     
==========================================
  Files           5        6       +1     
  Lines         248      341      +93     
==========================================
+ Hits          234      323      +89     
- Misses         14       18       +4
Impacted Files    Coverage           Δ
src/lib.rs        0% <ø>             (ø) ⬆️
src/tanh.rs       95.69% <95.69%>    (ø)

Δ = absolute <relative> (impact), ø = not affected, ? = missing data

@huonw mentioned this pull request Jan 13, 2019
@huonw (Owner) left a comment

Thanks for the pull request!

My main comment is that this approximation uses quite a few operations, and we could possibly get away with something a bit simpler/faster.

Also, I've landed some large(ish) changes on master (#7), which require some adjustments (criterion for benchmarks now), and bring in the ieee754 library for convenient access to some basic float operations.
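
For orientation, a criterion benchmark for this might look like the minimal sketch below (the bench ID and the assumption that the new `tanh` is re-exported from the crate root are illustrative, not the repository's actual bench file):

use criterion::{black_box, criterion_group, criterion_main, Criterion};

// Hypothetical benchmark shape; `fast_math::tanh` assumes this PR's
// function is re-exported from the crate root.
fn bench_tanh(c: &mut Criterion) {
    c.bench_function("scalar/tanh/full", |b| {
        b.iter(|| fast_math::tanh(black_box(0.5_f32)))
    });
}

criterion_group!(benches, bench_tanh);
criterion_main!(benches);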

src/tanh.rs (outdated)
((28. * x2 + 3150.) * x2 + 62370.) * x2 + 135135.
}

/// Compute a fast approximation of the hyperbolic tangent of `x`.
@huonw (Owner):

Suggested change
/// Compute a fast approximation of the hyperbolic tangent of `x`.
/// Compute a fast approximation of the hyperbolic tangent of `x` for -4 < `x` < 4.


/// Compute a fast approximation of the hyperbolic tangent of `x`.
///
/// For large |x|, the output may be outside of [-1, 1].
@huonw (Owner):

I'd prefer this to just make no guarantees about the behaviour at all, e.g.

/// This will return unspecified nonsense if `x` doesn't
/// satisfy those constraints. Use `tanh` if correct handling is
/// required (at the expense of some speed).

/// See `tanh_raw` for a faster version that may return incorrect results for
/// large `|x|` and `nan`.
#[inline]
pub fn tanh(x: f32) -> f32 {
@huonw (Owner):

I wonder if this could be something like:

pub fn tanh(x: f32) -> f32 {
    if x < -4.97 {
        -1.
    } else if x > 4.97 {
        1.
    } else {
        // if x is NaN, it will propagate through the arithmetic
        tanh_raw(x)
    }
}

This is likely to be easier to vectorize, and does fewer operations. If you rebase/merge this PR onto the latest master (and so can use Ieee754::copy_sign), this could even be:

pub fn tanh(x: f32) -> f32 {
    if x.abs() > 4.97 {
        // the true value |tanh(x)| > 0.9999 when |x| > 4.97, so
        // rounding to ±1 is close enough
        1_f32.copy_sign(x)
    } else {
        // |tanh_raw(x)| < 1 when |x| <= 4.97, so no post-processing is needed,
        // and x being NaN is handled by propagating through the arithmetic
        tanh_raw(x)
    }
}

With this adjustment, a and b no longer need to be separate functions and can be inlined straight into tanh_raw.
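
For reference, a minimal sketch of that inlining. The numerator coefficients follow from the (7, 6) Lambert continued fraction referenced in the PR; only the denominator appears in the diff above, so treat the exact shape as an assumption:

#[inline]
pub fn tanh_raw(x: f32) -> f32 {
    let x2 = x * x;
    // numerator of the (7, 6) continued fraction:
    // x^7 + 378 x^5 + 17325 x^3 + 135135 x
    let a = x * (((x2 + 378.) * x2 + 17325.) * x2 + 135135.);
    // denominator: 28 x^6 + 3150 x^4 + 62370 x^2 + 135135
    let b = ((28. * x2 + 3150.) * x2 + 62370.) * x2 + 135135.;
    a / b
}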

@vks (Contributor, Author) commented Mar 6, 2019:

Don't you think clipping like this might be problematic, because it results in discontinuities?

@huonw (Owner):

That's a potential problem. An alternative would be to find when the approximation is exactly +/-1, and clip there instead of 4.97 (I think it should be symmetric?), so that the tanh approximation is continuous (although its derivative won't be).
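
A minimal sketch of how one might locate that crossing numerically (hypothetical helper; since the approximation is an odd function, the crossings at +1 and -1 are indeed symmetric, so only the positive one needs finding):

// Bisect for the smallest positive x where the (7, 6) approximation
// reaches 1. Assumes tanh_raw from this PR, and that it crosses 1
// exactly once in [4, 6] (it grows like x/28 for large x, so it keeps
// increasing past 1).
fn clip_threshold() -> f32 {
    let (mut lo, mut hi) = (4.0_f32, 6.0_f32);
    for _ in 0..32 {
        let mid = 0.5 * (lo + hi);
        if tanh_raw(mid) < 1.0 { lo = mid } else { hi = mid }
    }
    0.5 * (lo + hi)
}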

pub fn tanh_raw(x: f32) -> f32 {
    // Implementation based on
    // https://varietyofsound.wordpress.com/2011/02/14/efficient-tanh-computation-using-lamberts-continued-fraction
    a(x) / b(x)
@huonw (Owner):

Out of interest, did you consider other approaches? E.g.

  1. using a lower-degree continued fraction approximation instead of (7, 6), such as cutting it off at the level with the 5 (this seems to have maximum relative and absolute errors of about 0.02 if it's used on [-2.3, 2.3] and clipped to +/-1 outside that):

    (x2 + 15.) * x / (6. * x2 + 15.)
  2. optimize the parameters of the approximation (just truncating a series, whether a Taylor series or a continued fraction, won't give the most accurate approximation of a given degree); for the form (x2 + a) * x / (b * x2 + a) on some interval [-limit, limit] (like the above), I get a = 21.350693, b = 7.8355837, limit = 2.933833 as the best, with relative and absolute errors of approximately 0.0057, which is about as accurate as other functions in fast-math. (I used the approx.py script referenced at the end of this comment; a Rust sketch of this form follows the script below.) I suspect the more expensive (7, 6) form in this PR could benefit from optimizing its coefficients too.

  3. Use the tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x)) form, with an approximate exp such as that described in https://stackoverflow.com/a/50379934/1256624, which could look something like the following in Rust (haven't tested):

    /// Computes an approximation to (exp(x), exp(-x))
    #[inline]
    fn pm_exp(x: f32) -> (f32, f32) {
        use std::f32::consts::LN_2; // import needed for the constant below
        const A: f32 = (1 << 23) as f32 / LN_2;
        const B: u32 = 127 << 23;
        let r = (A * x) as i32 as u32;
        (f32::from_bits(B.wrapping_add(r)),
         f32::from_bits(B.wrapping_sub(r)))
    }
    pub fn tanh_raw(x: f32) -> f32 {
        let (plus, minus) = pm_exp(x);
        (plus - minus) / (plus + minus)
    }

    It could also use the exp now in the library, but it's more expensive (uses a quadratic approximation that requires pulling more info out of the floats), and I'm led to believe that the above typically benefits from some cancellation of errors that doesn't occur for an isolated exp.

My suspicion is that 2. will be a good balance of speed and accuracy, but 3. could surprise me. Do you know any details about the above?

approx.py

Run like `python approx.py`; may require Python 3.

import numpy as np
from scipy import optimize

def rel_error(approx, true):
    # if the true value is 0, the approximate one must be too
    return np.where(true != 0,
                    np.abs((approx - true) / true),
                    np.where(approx == 0, 0, np.inf))
def abs_error(approx, true):
    return np.abs(approx - true)

f32 = np.float32
def approx(coeffs, points):
    a, b, limit = coeffs
    # approximate with (x^3 + a x) / (b x ^ 2 + a) on the interval
    # [-limit, limit].
    #
    # Why this form?  tanh is odd, so we should have odd / even (and
    # so the unlisted coeffs must be zero), tanh(x) ~= x for small x
    # (so we can share a in the top and bottom so it approximates a x
    # / a = x when x is small).
    points2 = points * points
    poly = (points2 + a) * points / (b * points2 + a)
    return np.where(np.abs(points) <= limit, poly, np.sign(points))

def evaluation(coeffs, points):
    a = approx(coeffs, points)
    t = np.tanh(points)
    rel = rel_error(a, t).max()
    abs = abs_error(a, t).max()
    return (rel, abs)

start = np.array([15, 6, 2.3])
opt_points = np.linspace(-5, 5, 100001)

# optimize on the relative error
result = optimize.fmin(lambda c: evaluation(f32(c), f32(opt_points))[0], start, maxiter=10000)

final_values = approx(result, opt_points)
assert np.all((final_values >= -1) & (final_values <= 1)), "not allowed to overshoot"

rel, abs = evaluation(result, opt_points)
print("a = %s, b = %s, limit = %s" % tuple(f32(result)))
print("on [-5, 5]: rel error = %.6f, abs error = %.6f" % (rel, abs))
for x in np.arange(-5, 5.01, 0.5):
    print("%7.4f: %f (%f)" % (x, approx(result, x), np.tanh(x)))

@vks changed the title from "Implement fast atanh" to "Implement fast tanh" on Jan 14, 2019
@vks (Contributor, Author) commented Mar 6, 2019

I tried the implementations you suggested:

Current implementation:

scalar/tanh/baseline    time:   [6.2435 ns 6.2553 ns 6.2696 ns]                                  
scalar/tanh/raw         time:   [78.785 ns 78.925 ns 79.074 ns]                            
scalar/tanh/full        time:   [101.70 ns 101.85 ns 102.01 ns]                             
scalar/tanh/std         time:   [362.90 ns 363.99 ns 365.11 ns]                            

vector/tanh/baseline    time:   [4.8044 ns 4.8112 ns 4.8177 ns]                                  
vector/tanh/raw         time:   [79.245 ns 79.367 ns 79.482 ns]                            
vector/tanh/full        time:   [99.708 ns 99.943 ns 100.22 ns]                             
vector/tanh/std         time:   [365.35 ns 366.00 ns 366.72 ns]                            

Suggested clipping:

scalar/tanh/baseline    time:   [6.2427 ns 6.2547 ns 6.2670 ns]                                  
scalar/tanh/raw         time:   [81.526 ns 81.854 ns 82.318 ns]                            
scalar/tanh/full        time:   [96.473 ns 96.680 ns 96.911 ns]                             
scalar/tanh/std         time:   [359.96 ns 360.61 ns 361.29 ns]                            

vector/tanh/baseline    time:   [4.8269 ns 4.8392 ns 4.8526 ns]                                  
vector/tanh/raw         time:   [86.026 ns 86.170 ns 86.317 ns]                            
vector/tanh/full        time:   [96.642 ns 96.800 ns 96.959 ns]                             
vector/tanh/std         time:   [356.28 ns 357.21 ns 358.32 ns]

It looks like there might be a small improvement, but the results are weird (with unexpected changes for tanh/raw and tanh/std, which were not changed). I'm not sure it's worth introducing the discontinuity.

Suggested clipping with optimized lower-order approximation:

scalar/tanh/baseline    time:   [6.1683 ns 6.1784 ns 6.1889 ns]                                  
scalar/tanh/raw         time:   [37.826 ns 37.905 ns 37.987 ns]                             
scalar/tanh/full        time:   [33.853 ns 33.925 ns 34.004 ns]                              
scalar/tanh/std         time:   [386.83 ns 387.95 ns 389.08 ns]                            

vector/tanh/baseline    time:   [4.8231 ns 4.8306 ns 4.8384 ns]                                  
vector/tanh/raw         time:   [9.4154 ns 9.4496 ns 9.4851 ns]                             
vector/tanh/full        time:   [10.176 ns 10.203 ns 10.235 ns]                              
vector/tanh/std         time:   [356.60 ns 357.30 ns 358.03 ns]                            

This seems to result in good performance improvements (54% for scalar/tanh/raw, and 90% for the vectorized code).

exp-based implementation:

scalar/tanh/baseline    time:   [6.2289 ns 6.2519 ns 6.2815 ns]                                  
scalar/tanh/raw         time:   [38.523 ns 38.623 ns 38.736 ns]                             
scalar/tanh/full        time:   [38.654 ns 38.774 ns 38.902 ns]                              
scalar/tanh/std         time:   [360.46 ns 361.12 ns 361.82 ns]                            

vector/tanh/baseline    time:   [4.8622 ns 4.8726 ns 4.8830 ns]                                  
vector/tanh/raw         time:   [10.374 ns 10.434 ns 10.518 ns]                             
vector/tanh/full        time:   [13.021 ns 13.056 ns 13.093 ns]                              
vector/tanh/std         time:   [363.48 ns 366.64 ns 371.14 ns]                            

This is a bit slower than the truncated continued fraction.

@vks (Contributor, Author) commented Mar 6, 2019

Should I switch to the implementation optimized for 0.0057 error tolerance?

@huonw (Owner) left a comment

I think switching would be good, given how much of a performance improvement it is.

I think this library potentially needs to be restructured to give more control over errors, but for now, I think the 0.0057 error tolerance is fine. Do you have something for which you might use this? If so, is that error tolerance acceptable?

