Implementation of np.interp #3860

Closed
AdrienCorenflos opened this issue Jul 25, 2020 · 19 comments
Labels
enhancement New feature or request question Questions for the JAX team

Comments

@AdrienCorenflos
Contributor

Hi,

As mentioned here, I am happy to leverage some code I'm putting together for my own project (sorted interpolation) to provide jax with an implementation of np.interp.

Modulo some high level checks on dimensionality of the arrays + the "period" argument which is trivial to transfer to jax, the below is my best attempt so far at it.

I've tested that the results were almost equal to numpy ones, plus the gradients match the numerical ones. Do the org members/collaborators have any problem with the approach (or ideas to make it more JAX-y), or should I iron it out, put it in a PR and send it over?

Adrien

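import jax.numpy as jnp
from jax import jit, lax


@jit
def sorted_interp(x, xp, fp):
    # x is expected to be sorted in increasing order (see _interp below)
    n = xp.shape[0]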

    x = jnp.atleast_1d(x)

    j = 0
    xp_0 = xp[0]
    fp_0 = fp[0]


    def inner_fun(args):
        x_i, j = args
        def cond_fun(state):
            is_continuing, *_ = state
            return is_continuing

        def body_fun(state):
            _, _, curr_j, curr_xp_j, curr_fp_j = state

            next_xp_j = xp[curr_j + 1]
            next_fp_j = fp[curr_j + 1]

            cond = x_i > next_xp_j

            def cond_true(_):
                inner_cond = curr_j + 1 == n - 1

                def fun_true(_): return False, True, curr_j, next_xp_j, next_fp_j
                def fun_false(_): return True, False, curr_j + 1, next_xp_j, next_fp_j

                return lax.cond(inner_cond, fun_true, fun_false, None)

            def cond_false(_):
                # Zero-width interval (repeated xp values): take fp[j + 1] directly
                inner_cond = curr_xp_j == next_xp_j
                
                def fun_true(_):  return False, True, curr_j, next_xp_j, next_fp_j
                def fun_false(_):  return False, False, curr_j, next_xp_j, next_fp_j

                return lax.cond(inner_cond, fun_true, fun_false, None)

            return lax.cond(cond, cond_true, cond_false, None)

        _, use_next_fp_j, new_j, *_ = lax.while_loop(cond_fun, body_fun, (True, False, j, xp[j], fp[j]))
        # We don't compute the result inside the loop to allow for seamless backward mode differentiability
        return new_j, lax.cond(use_next_fp_j,
                               lambda _: fp[new_j + 1],
                               lambda _: fp[new_j] + (fp[new_j + 1] - fp[new_j]) * (x_i - xp[new_j]) / (xp[new_j + 1] - xp[new_j]),
                               None)

    def body_fun(j, x_i):
        return lax.cond(x_i <= xp_0, lambda *_: (j, fp_0), inner_fun, (x_i, j))

    _, f = lax.scan(body_fun, 0, x)
    return f


@jit
def _interp(x, xp, fp):
    x = jnp.atleast_1d(x)
    argsort = jnp.argsort(x)
    sorted_res = sorted_interp(x[argsort], xp, fp)
    return jnp.empty_like(sorted_res).at[argsort].set(sorted_res)

def interp(x, xp, fp):
    # Do the checks like in the numpy version
    return _interp(x, xp, fp)
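
The two checks mentioned above (values close to numpy, gradients close to numerical ones) could be scripted roughly as follows. This is only a sketch: `check_against_numpy` is a hypothetical helper, and `jnp.interp` from current JAX releases is used purely as a stand-in so that the snippet runs; the function under test (for example `interp` above) would be passed in its place.

```python
import numpy as np
import jax
import jax.numpy as jnp


def check_against_numpy(interp_fn, x, xp, fp, eps=1e-1, tol=1e-3):
    """Compare interp_fn with np.interp, and its fp-gradient with central differences."""
    reference = np.interp(np.asarray(x), np.asarray(xp), np.asarray(fp))
    values_ok = np.allclose(np.asarray(interp_fn(x, xp, fp)), reference, atol=tol)

    # Scalar loss so that jax.grad applies; differentiate with respect to fp.
    def loss(fp_):
        return jnp.sum(interp_fn(x, xp, fp_))

    analytic = np.asarray(jax.grad(loss)(fp))

    # Central finite differences, one fp entry at a time.
    numeric = np.zeros_like(analytic)
    for k in range(fp.shape[0]):
        step = np.zeros(fp.shape[0], dtype=np.float32)
        step[k] = eps
        numeric[k] = (float(loss(fp + step)) - float(loss(fp - step))) / (2 * eps)

    grads_ok = np.allclose(analytic, numeric, atol=tol)
    return values_ok, grads_ok


xp = jnp.linspace(0.0, 1.0, 6)
fp = jnp.sin(3.0 * xp)
x = jnp.array([0.05, 0.33, 0.5, 0.71, 0.98])  # all inside [xp[0], xp[-1]]

values_ok, grads_ok = check_against_numpy(jnp.interp, x, xp, fp)
print(values_ok, grads_ok)
```

Linear interpolation is linear in fp, so a fairly large finite-difference step is acceptable here; a gradient check with respect to x would need a smaller step and query points away from the knots.
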
@jakevdp
Collaborator

jakevdp commented Jul 25, 2020

This looks great! Note that there are a couple additional arguments to np.interp that we should probably handle. I took a stab a while back at a jax.numpy.interp, but it stalled because I had trouble matching the behavior of np.interp in corner cases (period boundaries & repeated values, in particular).

I suspect your approach is more efficient for larger x arrays, but FWIW here's what I came up with:

def interp(x, xp, fp, left=None, right=None, period=None):
  x, xp, fp = map(np.asarray, (x, xp, fp))
  if period:
    x = x % period
    xp = xp % period
    i = np.argsort(xp)
    xp = xp[i]
    fp = fp[i]
    xp = np.concatenate([xp[-1:] - period, xp, xp[:1] + period])
    fp = np.concatenate([fp[-1:], fp, fp[1:]])
  
  i = np.clip(np.searchsorted(xp, x, side='right'), 1, len(xp) - 1)
  f = (fp[i - 1] *  (xp[i] - x) + fp[i] * (x - xp[i - 1])) / (xp[i] - xp[i - 1])

  if not period:
    if left is None:
      left = fp[0]
    if right is None:
      right = fp[-1]
    f = np.where(x < xp[0], left, f)
    f = np.where(x > xp[-1], right, f)
  return f

@AdrienCorenflos
Contributor Author

I concur about the additional arguments but they are trivially implemented and I didn't want to take the attention away from the big piece of inner logic.

@AdrienCorenflos
Contributor Author

Also there is probably a lot of commenting to do in there too... It's probably a bit complex at first sight.

@jakevdp
Collaborator

jakevdp commented Jul 25, 2020

My main concern with the sorted_interp approach is the maintenance burden of adding such a complex implementation. Ignoring optional parameters & edge cases, the core of the interpolation using searchsorted is just a few lines of code:

def interp(x, xp, fp):
  i = clip(searchsorted(xp, x, side='right'), 1, len(xp) - 1)
  return (fp[i - 1] *  (xp[i] - x) + fp[i] * (x - xp[i - 1])) / (xp[i] - xp[i - 1])

If we're replacing this with a substantially more complicated implementation, we should make sure we're getting a commensurate performance increase. Another thought: if the searchsorted version proves too slow, maybe focusing our optimization effort on searchsorted would give more bang for the buck?
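
For reference, a runnable JAX version of the two-line core above might look like the following. It is a minimal sketch: the name `interp_core` and the sample data are illustrative only, and the optional arguments and edge cases are still ignored.

```python
import numpy as np
import jax
import jax.numpy as jnp


@jax.jit
def interp_core(x, xp, fp):
    # Index of the right-hand knot of the interval containing each x,
    # clipped so that both i - 1 and i are valid indices into xp.
    i = jnp.clip(jnp.searchsorted(xp, x, side='right'), 1, len(xp) - 1)
    # Weighted average of the two neighbouring fp values.
    return (fp[i - 1] * (xp[i] - x) + fp[i] * (x - xp[i - 1])) / (xp[i] - xp[i - 1])


xp = jnp.array([0.0, 1.0, 2.0, 4.0])
fp = jnp.array([0.0, 10.0, 20.0, 0.0])
x = jnp.array([0.5, 1.0, 3.0, 4.0])

y = interp_core(x, xp, fp)
print(y)  # values: 5, 10, 10, 0
```

For x outside [xp[0], xp[-1]] this core extrapolates linearly from the first or last interval instead of clamping, which is the gap that the left and right arguments fill in the longer version.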

@AdrienCorenflos
Contributor Author

I very much agree with that. The only reason I offered it in the first place is because I needed the sorted one for myself anyway.
For the record, the numpy core implementation is actually a hybrid between your method and mine: they use a divide and conquer trick like the one you would find in the np.searchsorted implementation, but one that also leverages (at first sight) potential regular spacing in xp and potential sorting in x.

@mattjj mattjj added question Questions for the JAX team enhancement New feature or request labels Jul 25, 2020
@peterdsharpe

Hi all,

While we're discussing the topic of interpolation, I wanted to throw in a link to this repository that has developed some pretty slick JAX-compatible interpolators that may be of interest: https://github.com/DifferentiableUniverseInitiative/jax_cosmo/blob/master/jax_cosmo/scipy/interpolate.py

In addition to np.interp, they have some more sophisticated interpolators like scipy.interpolate.InterpolatedUnivariateSpline.

There's an issue ticket in jax_cosmo where discussion of PR'ing into JAX was started: DifferentiableUniverseInitiative/jax_cosmo#29 (comment)

@AdrienCorenflos
Contributor Author

My main concern with the sorted_interp approach is the maintenance burden of adding such a complex implementation. Ignoring optional parameters & edge cases, the core of the interpolation using searchsorted is just a few lines of code:

def interp(x, xp, fp):
  i = clip(searchsorted(xp, x, side='right'), 1, len(xp) - 1)
  return (fp[i - 1] *  (xp[i] - x) + fp[i] * (x - xp[i - 1])) / (xp[i] - xp[i - 1])

If we're replacing this with a substantially more complicated implementation, we should make sure we're getting a commensurate performance increase. Another thought: if the searchsorted version proves too slow, maybe focusing our optimization effort on searchsorted would give more bang for the buck?

I just did a rough test of your method (removing the left, right and period arguments), and it does seem to be substantially (10 times) slower. I would expect this to be due to the way searchsorted is implemented in jax, as it is "only" between 3 and 5 times slower in raw numpy, depending on the input size.

Note that np.interp is twice as fast as my implementation (probably due to the "guessing" of the index they do, and some contiguity they manage to ensure by not sorting the input array, but I've not checked exactly).
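
A rough timing harness for this kind of comparison could look like the sketch below. The helper `best_time`, the array sizes and the choice to time searchsorted alone are assumptions made for illustration; they are not the benchmark that produced the numbers quoted above.

```python
import timeit

import numpy as np
import jax
import jax.numpy as jnp


def best_time(fn, *args, repeats=5, number=10):
    """Best wall-clock time per call, in seconds."""
    def call():
        out = fn(*args)
        # JAX dispatches work asynchronously: wait for the result if it is a JAX array.
        return getattr(out, "block_until_ready", lambda: out)()

    call()  # warm-up call, so that jit compilation is excluded from the timing
    return min(timeit.repeat(call, repeat=repeats, number=number)) / number


rng = np.random.default_rng(0)
xp_np = np.sort(rng.uniform(size=1_000)).astype(np.float32)
x_np = rng.uniform(size=10_000).astype(np.float32)
xp_j, x_j = jnp.asarray(xp_np), jnp.asarray(x_np)

search_jit = jax.jit(lambda a, v: jnp.searchsorted(a, v, side='right'))

t_numpy = best_time(lambda a, v: np.searchsorted(a, v, side='right'), xp_np, x_np)
t_jax = best_time(search_jit, xp_j, x_j)
print(f"numpy: {t_numpy:.2e} s per call, jax (jitted): {t_jax:.2e} s per call")
```

The warm-up call and the explicit wait matter: without them the first measurement includes compilation, and later ones may only measure dispatch rather than the computation itself.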

@AdrienCorenflos
Contributor Author

AdrienCorenflos commented Jul 27, 2020

My main concern with the sorted_interp approach is the maintenance burden of adding such a complex implementation. Ignoring optional parameters & edge cases, the core of the interpolation using searchsorted is just a few lines of code:

def interp(x, xp, fp):
  i = clip(searchsorted(xp, x, side='right'), 1, len(xp) - 1)
  return (fp[i - 1] *  (xp[i] - x) + fp[i] * (x - xp[i - 1])) / (xp[i] - xp[i - 1])

If we're replacing this with a substantially more complicated implementation, we should make sure we're getting a commensurate performance increase. Another thought: if the searchsorted version proves too slow, maybe focusing our optimization effort on searchsorted would give more bang for the buck?

I just did a rough test of your method (removing the left, right and period arguments), and it does seem to be substantially (10 times) slower. I would expect this to be due to the way searchsorted is implemented in jax, as it is "only" between 3 and 5 times slower in raw numpy, depending on the input size.

Note that np.interp is twice as fast as my implementation (probably due to the "guessing" of the index they do, and some contiguity they manage to ensure by not sorting the input array, but I've not checked exactly).

Actually, what I just said was wrong: your method is faster than mine when jitted, contrary to what happens in raw numpy.
Maybe my jax loops are inefficient then? Do you know how I can check the generated code to see what's going on under the hood?

@jakevdp
Collaborator

jakevdp commented Jul 27, 2020

Interesting, thanks for doing the benchmarks!

In general, loopy code in XLA will not be as fast as array/matrix operations, and the extent of the slowdown will vary depending on the accelerator (the hit is not so bad on CPU, but on GPU or TPU loopy code can be extremely slow).

One way to get a sense of the XLA code that's being generated is via the make_jaxpr function.
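
As a small illustration, the sketch below inspects a toy loop. `count_up` is a made-up function standing in for loop-based code such as sorted_interp. Note that `make_jaxpr` prints JAX's own intermediate representation (the jaxpr); in recent JAX releases the text that is handed to the XLA compiler can additionally be obtained from a jitted function via `.lower(...).as_text()`.

```python
import jax
from jax import lax


def count_up(n):
    # A tiny while_loop, standing in for loop-heavy code.
    return lax.while_loop(lambda i: i < n, lambda i: i + 1, 0)


# The jaxpr shows which primitives the Python function was traced into.
jaxpr = jax.make_jaxpr(count_up)(5)
print(jaxpr)

# Lowered module text for the same function (recent JAX versions).
lowered_text = jax.jit(count_up).lower(5).as_text()
print(lowered_text[:300])
```

In the printed jaxpr, the loop shows up as a single `while` primitive with its condition and body as nested jaxprs, which makes it easy to spot where sequential control flow remains.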

Note that searchsorted itself is currently implemented as a while_loop, so it will suffer from this as well. I've been thinking of experimenting with changing this to a scan over binary search depth, which could potentially yield a vast improvement in efficiency for multiple searches.
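
One possible reading of "a scan over binary search depth" is sketched below: every query takes the same, statically known number of bisection steps, so all queries advance together as array operations. This is an illustration under that assumption, not the code that later went into JAX; `searchsorted_right` is a hypothetical name.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax import lax


@jax.jit
def searchsorted_right(xp, x):
    """Binary search for all queries at once, with a fixed number of steps (side='right')."""
    n = xp.shape[0]
    depth = n.bit_length()  # static: enough halvings to empty a window of size n
    lo = jnp.zeros(x.shape, dtype=jnp.int32)  # search window is [lo, hi)
    hi = jnp.full(x.shape, n, dtype=jnp.int32)

    def step(_, bounds):
        lo, hi = bounds
        active = lo < hi  # queries whose window is already empty stay put
        mid = jnp.minimum((lo + hi) // 2, n - 1)
        go_right = x >= xp[mid]  # insertion point lies strictly after mid
        lo = jnp.where(active & go_right, mid + 1, lo)
        hi = jnp.where(active & ~go_right, mid, hi)
        return lo, hi

    # With static loop bounds, fori_loop has a fixed trip count.
    lo, _ = lax.fori_loop(0, depth, step, (lo, hi))
    return lo


xp = jnp.array([0.0, 1.0, 1.0, 2.0, 5.0])
x = jnp.array([-1.0, 0.0, 0.5, 1.0, 1.5, 5.0, 7.0])
idx = searchsorted_right(xp, x)
print(idx)
```

The loop runs about log2(n) times regardless of how many queries there are, and each iteration is a handful of vectorized gathers and selects, which is the shape of computation that accelerators tend to handle well.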

@AdrienCorenflos
Contributor Author

Note that searchsorted itself is currently implemented as a while_loop, so it will suffer from this as well. I've been thinking of experimenting with changing this to a scan over binary search depth, which could potentially yield a vast improvement in efficiency for multiple searches.

I guess it's not suffering from it as much, since the depth of every individual loop you do is shallower than that of my single one.

I expect lax.scan is also going to suffer from the slowdown then?

@jakevdp
Collaborator

jakevdp commented Jul 27, 2020

I made some searchsorted improvements here: #3873

@AdrienCorenflos
Contributor Author

AdrienCorenflos commented Jul 28, 2020

So when using your method, the repeated values bit is easy to fix using either:

  • np.unique with return_index=True on xp at the very beginning
  • the following
i = np.clip(np.searchsorted(xp, x, side='right'), 1, len(xp) - 1)
xp_i = xp[i]
xp_i_1 = xp[i-1]
fp_i = fp[i]
fp_i_1 = fp[i-1]
f = np.where(xp_i > xp_i_1, (fp_i_1 * (xp_i - x) + fp_i * (x - xp_i_1)) / (xp_i - xp_i_1), fp_i)

I'm not sure which is more efficient (how does np.unique behave on GPU/TPU? Its output shape depends on the values...)
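
A JAX sketch of the second option is given below, with one extra precaution that is an addition for illustration: the denominator is replaced by 1 wherever the interval has zero width, so that the branch discarded by `where` never evaluates 0/0. Without this, the forward values are still masked, but gradients taken through `where` can come out as nan. The name `interp_guarded` and the sample data are hypothetical.

```python
import numpy as np
import jax
import jax.numpy as jnp


def interp_guarded(x, xp, fp):
    i = jnp.clip(jnp.searchsorted(xp, x, side='right'), 1, len(xp) - 1)
    xp_i, xp_i_1 = xp[i], xp[i - 1]
    fp_i, fp_i_1 = fp[i], fp[i - 1]
    dx = xp_i - xp_i_1
    # Use 1 in place of a zero-width interval so the division stays finite;
    # those entries are overwritten with fp_i by the outer where anyway.
    safe_dx = jnp.where(dx > 0, dx, 1.0)
    f = (fp_i_1 * (xp_i - x) + fp_i * (x - xp_i_1)) / safe_dx
    return jnp.where(dx > 0, f, fp_i)


# xp ends with a repeated value, so the last interval has zero width.
xp = jnp.array([0.0, 1.0, 2.0, 2.0])
fp = jnp.array([0.0, 1.0, 2.0, 5.0])
x = jnp.array([0.5, 2.0, 3.0])

y = interp_guarded(x, xp, fp)
grad_fp = jax.grad(lambda fp_: jnp.sum(interp_guarded(x, xp, fp_)))(fp)
print(y)        # values: 0.5, 5, 5
print(grad_fp)  # all entries finite
```

On the np.unique route: under jit, array shapes have to be known at trace time, so a value-dependent output size is awkward there; `jnp.unique` accepts a `size` argument for that purpose, at the cost of padding.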

I don't understand what the problem with the period boundaries would be. Seems like everything should work just fine to me.

@AdrienCorenflos
Contributor Author

Also, very minimal, but that would probably be a bit more readable (and save a multiplication but that's secondary):

fp_i_1 + (x - xp_i_1) * (fp_i - fp_i_1) / (xp_i - xp_i_1)
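
That this form is equivalent to the weighted-average form can be checked by rewriting x_i - x as (x_i - x_{i-1}) - (x - x_{i-1}). The notation below is an assumption made for the derivation: f_i stands for fp_i and x_i for xp_i.

```latex
\begin{align*}
f(x) &= \frac{f_{i-1}\,(x_i - x) + f_i\,(x - x_{i-1})}{x_i - x_{i-1}} \\
     &= \frac{f_{i-1}\,\bigl[(x_i - x_{i-1}) - (x - x_{i-1})\bigr] + f_i\,(x - x_{i-1})}{x_i - x_{i-1}} \\
     &= f_{i-1} + (x - x_{i-1})\,\frac{f_i - f_{i-1}}{x_i - x_{i-1}}
\end{align*}
```

The first form needs two multiplications and one division per point, the last one a single multiplication and one division.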

@jakevdp
Collaborator

jakevdp commented Jul 28, 2020

I don't understand what the problem with the period boundaries would be. Seems like everything should work just fine to me.

For example:

xp = np.linspace(0, 10, 10)
fp = np.sin(xp)
x = np.linspace(0, 10, 100)

y1 = np.interp(x, xp, fp, period=5)
y2 = interp(x, xp, fp, period=5)

import matplotlib.pyplot as plt
plt.plot(xp % 5, fp, '.k', label='input')
plt.plot(x % 5, y1, '.', label='np.interp')
plt.plot(x % 5, y2, '.', label='interp')
plt.legend();

[Image: periodic]

@AdrienCorenflos
Contributor Author

AdrienCorenflos commented Jul 28, 2020

I don't understand what the problem with the period boundaries would be. Seems like everything should work just fine to me.

For example:

xp = np.linspace(0, 10, 10)
fp = np.sin(xp)
x = np.linspace(0, 10, 100)

y1 = np.interp(x, xp, fp, period=5)
y2 = interp(x, xp, fp, period=5)

import matplotlib.pyplot as plt
plt.plot(xp % 5, fp, '.k', label='input')
plt.plot(x % 5, y1, '.', label='np.interp')
plt.plot(x % 5, y2, '.', label='interp')
plt.legend();

[Image: periodic]

I think you actually made a typo when you copy pasted that part of the code from the numpy function :)

You wrote

xp = np.concatenate([xp[-1:] - period, xp, xp[:1] + period])
fp = np.concatenate([fp[-1:], fp, fp[1:]])

In the numpy code it's

xp = np.concatenate([xp[-1:] - period, xp, xp[0:1] + period])
fp = np.concatenate([fp[-1:], fp, fp[0:1]])

(note that I don't know why they wrote [0:1] instead of [:1])

When I fix it it works just fine

[Image: Screenshot from 2020-07-28 17-02-44]
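
For completeness, the periodic branch with the corrected padding can be sketched in JAX as follows. The function name `interp_periodic` and the test data are illustrative assumptions (chosen so that no two xp values coincide after the modulo); the steps mirror the numpy lines quoted above.

```python
import numpy as np
import jax.numpy as jnp


def interp_periodic(x, xp, fp, period):
    # Fold queries and knots into one period, then sort the knots.
    x = jnp.asarray(x) % period
    xp = jnp.asarray(xp) % period
    order = jnp.argsort(xp)
    xp, fp = xp[order], jnp.asarray(fp)[order]
    # Pad with one wrapped knot on each side: the last knot shifted down by one
    # period and the first knot shifted up by one period (fp[:1], not fp[1:]).
    xp = jnp.concatenate([xp[-1:] - period, xp, xp[:1] + period])
    fp = jnp.concatenate([fp[-1:], fp, fp[:1]])
    i = jnp.clip(jnp.searchsorted(xp, x, side='right'), 1, len(xp) - 1)
    return (fp[i - 1] * (xp[i] - x) + fp[i] * (x - xp[i - 1])) / (xp[i] - xp[i - 1])


xp = jnp.array([0.5, 1.5, 3.0, 4.5])
fp = jnp.sin(xp)
x = jnp.linspace(-3.0, 12.0, 31)

y_jax = interp_periodic(x, xp, fp, period=5.0)
y_np = np.interp(np.asarray(x), np.asarray(xp), np.asarray(fp), period=5.0)
print(np.max(np.abs(np.asarray(y_jax) - y_np)))
```

With `fp[1:]` in the padding, the padded fp array would be longer than the padded xp array and the value paired with the wrapped first knot would be wrong, which is consistent with the mismatch near the period boundary in the plot above.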

@jakevdp
Collaborator

jakevdp commented Jul 28, 2020

If I'd copy-pasted, there wouldn't be a typo 😁

Another set of eyes is always helpful – thanks!

@jakevdp
Collaborator

jakevdp commented Aug 3, 2020

I've prepared a PR with a searchsorted-based interp in #3949, because we've had some feature requests for it.

@AdrienCorenflos – Please consider this to be a reference implementation, and if and when it's merged, feel free to prepare a PR improving on it. Hopefully the test suite that is part of that PR will be useful!

@AdrienCorenflos
Contributor Author

@jakevdp solved the above

@AdrienCorenflos
Contributor Author

I am leaving it here for reference, but there might be a way to leverage associative scan to implement this more efficiently:

https://www.researchgate.net/publication/225920730_A_parallel_method_for_fast_and_practical_high-order_newton_interpolation

Caveat being that I don't think it has been done in any mainstream library.

Adrien
