Implementation of np.interp #3860
This looks great! Note that there are a couple of additional arguments to … I suspect your approach is more efficient for larger …

```python
import numpy as np

def interp(x, xp, fp, left=None, right=None, period=None):
    x, xp, fp = map(np.asarray, (x, xp, fp))
    if period:
        # Map queries and knots into a single period, then sort the knots.
        x = x % period
        xp = xp % period
        i = np.argsort(xp)
        xp = xp[i]
        fp = fp[i]
        xp = np.concatenate([xp[-1:] - period, xp, xp[:1] + period])
        fp = np.concatenate([fp[-1:], fp, fp[1:]])
    # Index of the right end of the bracketing interval, kept in [1, len(xp) - 1].
    i = np.clip(np.searchsorted(xp, x, side='right'), 1, len(xp) - 1)
    # Linear interpolation between (xp[i-1], fp[i-1]) and (xp[i], fp[i]).
    f = (fp[i - 1] * (xp[i] - x) + fp[i] * (x - xp[i - 1])) / (xp[i] - xp[i - 1])
    if not period:
        # Outside the data range, use the edge values unless left/right are given.
        if left is None:
            left = fp[0]
        if right is None:
            right = fp[-1]
        f = np.where(x < xp[0], left, f)
        f = np.where(x > xp[-1], right, f)
    return f
```
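To see why the index is clipped, here is a small NumPy check (our own illustration, with made-up sample points) of what `np.searchsorted(..., side='right')` returns and how clipping to `[1, len(xp) - 1]` keeps both `i - 1` and `i` valid for queries outside the data range.

```python
import numpy as np

xp = np.array([0.0, 1.0, 2.0, 3.0])
x = np.array([-0.5, 0.0, 0.5, 2.0, 3.0, 4.0])

# side='right' returns the number of knots <= x, i.e. the index of the
# first knot strictly greater than x.
raw = np.searchsorted(xp, x, side='right')
# Clipping guarantees that xp[i - 1] and xp[i] are both valid indices,
# even when x lies below xp[0] or at/above xp[-1].
i = np.clip(raw, 1, len(xp) - 1)

print(raw)
print(i)
```

Here `raw` is `[0, 1, 1, 3, 4, 4]` and the clipped `i` is `[1, 1, 1, 3, 3, 3]`. For `x = 3.0` (the last knot) the formula then evaluates to exactly `fp[3]`, and queries beyond the range are overwritten by `left`/`right` through `np.where`.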
I agree about the additional arguments, but they are trivial to implement and I didn't want to draw attention away from the core piece of inner logic.
There is probably a lot of commenting still to do in there, too; it's a bit complex at first sight.
My main concern with the …

```python
from jax.numpy import clip, searchsorted  # assumed import; the original used bare names

def interp(x, xp, fp):
    # Clipped index of the right end of the bracketing interval.
    i = clip(searchsorted(xp, x, side='right'), 1, len(xp) - 1)
    return (fp[i - 1] * (xp[i] - x) + fp[i] * (x - xp[i - 1])) / (xp[i] - xp[i - 1])
```

If we're replacing this with a substantially more complicated implementation, we should make sure we're getting a commensurate performance increase. Another thought: if the …
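The return expression is a weighted average of the two bracketing values. Writing $x_i$ for `xp[i]` and $f_i$ for `fp[i]` at the clipped index $i$, it can be rearranged by plain algebra:

$$
f(x) = \frac{f_{i-1}\,(x_i - x) + f_i\,(x - x_{i-1})}{x_i - x_{i-1}}
     = (1 - t)\,f_{i-1} + t\,f_i
     = f_{i-1} + t\,(f_i - f_{i-1}),
\qquad t = \frac{x - x_{i-1}}{x_i - x_{i-1}}.
$$

For $x$ between $x_{i-1}$ and $x_i$, $t$ lies in $[0, 1]$, so the result stays between $f_{i-1}$ and $f_i$. The last form needs one multiplication and one division per query point instead of two multiplications and one division.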
I very much agree. The only reason I offered it in the first place is that I needed the sorted version for myself anyway.
Hi all, while we're on the topic of interpolation, I wanted to share a link to a repository that has developed some pretty slick JAX-compatible interpolators that may be of interest: https://github.com/DifferentiableUniverseInitiative/jax_cosmo/blob/master/jax_cosmo/scipy/interpolate.py In addition to … There's an issue ticket in …
I just did a rough test of your method (removing the left, right and period arguments), and it does seem substantially (10 times) slower than mine. I expect this is due to the way searchsorted is implemented in JAX, since it is "only" 3 to 5 times slower in raw NumPy, depending on the input size. Note that np.interp is twice as fast as mine (probably due to the index "guessing" it does, and some contiguity it manages to ensure by not sorting the input array, but I haven't checked exactly).
Actually, I was wrong: your method is faster than mine when jitted, contrary to the raw NumPy results.
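For anyone who wants to repeat this kind of rough comparison, here is a minimal timing sketch. The function name `interp_searchsorted`, the array sizes and the sample data are our own choices, and the printed numbers will depend heavily on hardware and backend, so no results are claimed here. JAX dispatches work asynchronously, so `block_until_ready()` is needed for meaningful timings, and the first call is run once beforehand so that compilation is not timed.

```python
import timeit

import jax
import jax.numpy as jnp
import numpy as np


@jax.jit
def interp_searchsorted(x, xp, fp):
    # Bracketing index, clipped so that i - 1 and i are always valid.
    i = jnp.clip(jnp.searchsorted(xp, x, side='right'), 1, len(xp) - 1)
    # Linear interpolation between (xp[i-1], fp[i-1]) and (xp[i], fp[i]).
    return (fp[i - 1] * (xp[i] - x) + fp[i] * (x - xp[i - 1])) / (xp[i] - xp[i - 1])


rng = np.random.default_rng(0)
xp = np.linspace(0.0, 10.0, 1_000)          # strictly increasing knots
fp = np.sin(xp)
x = rng.uniform(0.0, 10.0, size=100_000)    # queries inside the data range

xp_j, fp_j, x_j = map(jnp.asarray, (xp, fp, x))
interp_searchsorted(x_j, xp_j, fp_j).block_until_ready()  # compile once

t_numpy = timeit.timeit(lambda: np.interp(x, xp, fp), number=20)
t_jax = timeit.timeit(
    lambda: interp_searchsorted(x_j, xp_j, fp_j).block_until_ready(), number=20
)
print(f"np.interp: {t_numpy:.4f}s, jitted searchsorted interp: {t_jax:.4f}s")
```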
Interesting, thanks for doing the benchmarks! In general, loopy code in XLA will not be as fast as array/matrix operations, and the extent of the slowdown varies with the accelerator (the hit is not so bad on CPU, but on GPU or TPU loopy code can be extremely slow). One way to get a sense of what is being generated is the make_jaxpr function, which shows the jaxpr that JAX then lowers to XLA. Note that …
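As a concrete illustration (the sample arrays are arbitrary), `jax.make_jaxpr` can be applied to the simple searchsorted-based `interp`. What appears inside the searchsorted part depends on the JAX version, since the implementation of `jnp.searchsorted` has changed over time.

```python
import jax
import jax.numpy as jnp


def interp(x, xp, fp):
    # Same searchsorted-based formula as in the thread (no left/right/period).
    i = jnp.clip(jnp.searchsorted(xp, x, side='right'), 1, len(xp) - 1)
    return (fp[i - 1] * (xp[i] - x) + fp[i] * (x - xp[i - 1])) / (xp[i] - xp[i - 1])


xp = jnp.linspace(0.0, 1.0, 5)
fp = xp ** 2
x = jnp.array([0.1, 0.6, 0.9])

# make_jaxpr traces the function and returns its jaxpr, JAX's intermediate
# representation; any loop primitives (such as while or scan) appear in it explicitly.
jaxpr = jax.make_jaxpr(interp)(x, xp, fp)
print(jaxpr)
```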
I guess yours doesn't suffer from it as much because each of your individual loops is shallower than my single one. I expect lax.scan will suffer from the same slowdown, then?
I made some …
So when using your method, the repeated-values issue is easy to fix, for example with:

```python
i = np.clip(np.searchsorted(xp, x, side='right'), 1, len(xp) - 1)
xp_i = xp[i]
xp_i_1 = xp[i - 1]
fp_i = fp[i]
fp_i_1 = fp[i - 1]
# Interpolate where the bracketing interval has nonzero width; otherwise use fp[i].
f = np.where(xp_i > xp_i_1, (fp_i_1 * (xp_i - x) + fp_i * (x - xp_i_1)) / (xp_i - xp_i_1), fp_i)
```

I'm not sure which option is the most efficient (how does np.unique behave on GPU/TPU? Its output shape depends on the values...). I don't understand what the problem with the period boundaries would be; it seems to me that everything should work just fine.
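A self-contained NumPy sketch of that fix, with one small variation of our own: the denominator is also guarded, so the branch discarded by `np.where` never divides by zero. (The JAX FAQ recommends the same "safe denominator" pattern to keep `jnp.where` from producing NaN gradients.) The function name `interp_with_repeats` and the data are illustrative.

```python
import numpy as np


def interp_with_repeats(x, xp, fp):
    x, xp, fp = map(np.asarray, (x, xp, fp))
    i = np.clip(np.searchsorted(xp, x, side='right'), 1, len(xp) - 1)
    xp_i, xp_i_1 = xp[i], xp[i - 1]
    fp_i, fp_i_1 = fp[i], fp[i - 1]
    same = xp_i == xp_i_1
    # Replace zero-width denominators by 1 so the unused branch stays finite
    # and NumPy emits no divide-by-zero warning.
    denom = np.where(same, 1.0, xp_i - xp_i_1)
    f = (fp_i_1 * (xp_i - x) + fp_i * (x - xp_i_1)) / denom
    # On a zero-width interval, fall back to fp[i], as the proposed fix does.
    return np.where(same, fp_i, f)


xp = np.array([0.0, 0.0, 1.0, 2.0])   # first knot repeated
fp = np.array([1.0, 2.0, 3.0, 4.0])
x = np.array([-1.0, 0.5, 1.5])
print(interp_with_repeats(x, xp, fp))
```

The first query lies left of the repeated knot pair, so its clipped bracket has zero width and it takes the fallback value `fp[1] = 2.0`; the other two queries are interpolated normally.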
Also, very minor, but that would probably be a bit more readable (and would save a multiplication, though that's secondary):
For example:

```python
import matplotlib.pyplot as plt
import numpy as np

xp = np.linspace(0, 10, 10)
fp = np.sin(xp)
x = np.linspace(0, 10, 100)

y1 = np.interp(x, xp, fp, period=5)
y2 = interp(x, xp, fp, period=5)  # the searchsorted-based interp posted earlier

plt.plot(xp % 5, fp, '.k', label='input')
plt.plot(x % 5, y1, '.', label='np.interp')
plt.plot(x % 5, y2, '.', label='interp')
plt.legend()
```
I think you actually made a typo when you copy-pasted that part of the code from the NumPy function :) You wrote

```python
xp = np.concatenate([xp[-1:] - period, xp, xp[:1] + period])
fp = np.concatenate([fp[-1:], fp, fp[1:]])
```

In the NumPy code it's

```python
xp = np.concatenate([xp[-1:] - period, xp, xp[0:1] + period])
fp = np.concatenate([fp[-1:], fp, fp[0:1]])
```

(Note that I don't know why they wrote …) When I fix it, it works just fine.
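A self-contained sketch of the periodic branch with the corrected padding (`fp[0:1]` at the end), checked numerically against `np.interp(..., period=...)` instead of by plotting. The function name `interp_periodic` and the sample points are illustrative; the points were chosen so that no two knots coincide after reduction modulo the period.

```python
import numpy as np


def interp_periodic(x, xp, fp, period):
    x, xp, fp = map(np.asarray, (x, xp, fp))
    # Map queries and knots into a single period [0, period).
    x = x % period
    xp = xp % period
    order = np.argsort(xp)
    xp, fp = xp[order], fp[order]
    # Pad one knot on each side so that interpolation wraps around:
    # the last knot shifted down by one period, the first shifted up.
    xp = np.concatenate([xp[-1:] - period, xp, xp[0:1] + period])
    fp = np.concatenate([fp[-1:], fp, fp[0:1]])
    i = np.clip(np.searchsorted(xp, x, side='right'), 1, len(xp) - 1)
    return (fp[i - 1] * (xp[i] - x) + fp[i] * (x - xp[i - 1])) / (xp[i] - xp[i - 1])


xp = np.array([0.3, 1.7, 2.9, 4.1, 6.2, 8.8])
fp = np.cos(xp)
x = np.linspace(0.0, 10.0, 50)
diff = interp_periodic(x, xp, fp, 5.0) - np.interp(x, xp, fp, period=5.0)
print(np.max(np.abs(diff)))
```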
If I'd copy-pasted, there wouldn't be a typo 😁 Another set of eyes is always helpful – thanks!
I've prepared a PR with a searchsorted-based interp in #3949, because we've had some feature requests for it. @AdrienCorenflos – please consider this a reference implementation, and if and when it's merged, feel free to prepare a PR improving on it. Hopefully the test suite that is part of that PR will be useful!
@jakevdp solved the above.
I'm leaving this here for reference: there might be a way to leverage an associative scan to implement this more efficiently: … The caveat is that I don't think it has been done in any mainstream library. Adrien
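One way an associative scan could enter the picture (our assumption; the linked idea is not preserved here) is to compute the `searchsorted` indices themselves as a prefix sum over the merged, sorted knots and queries. A minimal JAX sketch follows; `searchsorted_right_via_scan` is a hypothetical helper, not a JAX API, and whether it is faster than `jnp.searchsorted` on any given accelerator is untested.

```python
import jax.numpy as jnp
import numpy as np
from jax import lax


def searchsorted_right_via_scan(xp, x):
    """Same result as searchsorted(xp, x, side='right') for sorted xp."""
    n = xp.shape[0]
    values = jnp.concatenate([xp, x])
    # 1 marks a knot from xp, 0 marks a query point from x.
    is_knot = jnp.concatenate(
        [jnp.ones(n, dtype=jnp.int32), jnp.zeros(x.shape[0], dtype=jnp.int32)]
    )
    # Sort by value; on ties, knots come first (secondary key 1 - is_knot),
    # so each query is preceded by every knot <= it.
    order = jnp.lexsort((1 - is_knot, values))
    # Inclusive prefix sum, computed as an associative scan: the number of
    # knots at or before each position in sorted order.
    counts = lax.associative_scan(jnp.add, is_knot[order])
    # Scatter the counts back to the original positions and keep the queries.
    ranks = jnp.zeros_like(counts).at[order].set(counts)
    return ranks[n:]


xp = jnp.array([0.0, 1.0, 1.0, 2.0, 3.5])
x = jnp.array([-1.0, 0.0, 1.0, 1.5, 3.5, 4.0])
print(searchsorted_right_via_scan(xp, x))
print(np.searchsorted(np.asarray(xp), np.asarray(x), side='right'))
```

The resulting indices could then replace the `searchsorted` call in the interpolation formula, followed by the same clipping to `[1, len(xp) - 1]`.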
Hi,
As mentioned here, I'm happy to reuse some code I'm putting together for my own project (sorted interpolation) to provide JAX with an implementation of np.interp.
Modulo some high-level checks on the dimensionality of the arrays, plus the "period" argument (which is trivial to transfer to JAX), below is my best attempt at it so far.
I've tested that the results are almost equal to the NumPy ones, and that the gradients match the numerical ones. Do the org members/collaborators have any concerns about the approach (or ideas to make it more JAX-y), or should I iron it out, put it in a PR and send it over?
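A minimal sketch of the kind of gradient check described here, comparing `jax.grad` of a searchsorted-based interpolant against a central finite difference. This is not the sorted implementation from the original post, which isn't reproduced here; the function and test points are illustrative.

```python
import jax
import jax.numpy as jnp
import numpy as np


def interp(x, xp, fp):
    # searchsorted-based linear interpolation (no left/right/period handling).
    i = jnp.clip(jnp.searchsorted(xp, x, side='right'), 1, len(xp) - 1)
    return (fp[i - 1] * (xp[i] - x) + fp[i] * (x - xp[i - 1])) / (xp[i] - xp[i - 1])


xp = jnp.linspace(0.0, 3.0, 7)
fp = jnp.sin(xp)
x0 = 1.3  # not on a knot, so the derivative is well defined

# Derivative with respect to the query point: the slope of the bracketing segment.
grad_x = jax.grad(lambda x: interp(x, xp, fp))(x0)
# Central finite difference for comparison (both points stay in the same segment).
eps = 1e-3
fd_x = (interp(x0 + eps, xp, fp) - interp(x0 - eps, xp, fp)) / (2 * eps)

# Derivative with respect to fp: only the two bracketing values get nonzero
# weights, namely 1 - t and t from the weighted-average form.
grad_fp = jax.grad(lambda f: interp(x0, xp, f))(fp)
print(grad_x, fd_x)
print(grad_fp)
```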
Adrien