
A case where jnp.sinc + jit make a function unstable and non-deterministic #16129

Closed
astanziola opened this issue May 25, 2023 · 9 comments
Labels
bug Something isn't working

Comments

@astanziola

astanziola commented May 25, 2023

Description

Hello jax team!

While working on my jax-based wave simulator, I have encountered an edge case with quite weird behavior. Below is the best MRE I managed to prepare by stripping down the simulation code. Sorry it is quite long, but as I explain below, I don't see an obvious way to make it shorter.

from matplotlib import pyplot as plt

from jax import numpy as jnp
from jax import jit
from jax.lax import scan

def fft_gradient(u, axis, k_space_op, k_filter):
    # Spectral gradient of u along `axis`: apply the (half-cell-shifted) i*k
    # filter and the k-space operator in Fourier space, then transform back.
    Fu = jnp.fft.fftn(u)
    Fx = jnp.moveaxis(Fu, axis, -1)
    k_op = jnp.moveaxis(k_space_op, axis, -1)
    iku = jnp.moveaxis(Fx * k_filter[axis] * k_op, -1, axis)
    return jnp.fft.ifftn(iku).real

def mass_conservation_rhs(N, dx, u, params):
    k_vec, k_space_op = params
    # i*k filter with a backward half-cell shift (staggered grid)
    k_filter = [1j * k * jnp.exp(-1j * k * delta / 2) for k, delta in zip(k_vec, dx)]
    du = jnp.stack([fft_gradient(u[..., axis], axis, k_space_op, k_filter) for axis in range(len(N))], axis=-1)
    return -du

def momentum_conservation_rhs(N, dx, p, params):
    k_vec, k_space_op = params
    # i*k filter with a forward half-cell shift (staggered grid)
    k_filter = [1j * k * jnp.exp(1j * k * delta / 2) for k, delta in zip(k_vec, dx)]
    dp = jnp.stack([fft_gradient(p[..., 0], axis, k_space_op, k_filter) for axis in range(len(N))], axis=-1)
    return -dp

def kspace_op(N, dx):
    # Angular wavenumbers along each axis
    k_vec = [jnp.fft.fftfreq(n, delta) * 2 * jnp.pi for n, delta in zip(N, dx)]
    k_space_op = jnp.sinc(jnp.zeros(N)) # <-- This should be equivalent to jnp.ones(N), but weirdly is not.
    # k_space_op = jnp.ones(N) # <-- Using this makes the computation deterministic and equal to non-jit version
    return k_vec, k_space_op

def integrator(N, dx, p0=None):
    # Setup parameters
    dt, output_steps = 2e-5, jnp.arange(0, 200, 1)
    
    # Initialize variables
    u0 = jnp.zeros(N + (len(N),))
    rho = jnp.stack([p0[..., 0] for i in range(len(N))], axis=-1) / len(N)
    
    # Numerical simulation parameters
    params = kspace_op(N, dx)

    def scan_fun(fields, n):
        p, u, rho = fields
        
        du = momentum_conservation_rhs(N, dx, p, params)
        u = u + dt * du

        drho = mass_conservation_rhs(N, dx, u, params)
        rho = rho + dt * drho

        p = jnp.sum(rho, -1, keepdims=True)
        observation = p[12,12,12] if len(N) == 3 else p[12,12]
        return [p, u, rho], observation

    # Semi-implicit integration
    y0 = [p0, u0, rho]
    _, ys = scan(scan_fun, y0, output_steps)

    return ys

# Domain size
N, dx = (64, 64, 64), (0.1e-3, 0.1e-3, 0.1e-3)

# Initial pressure is a localized spike
p0 = jnp.zeros(N)
p0 = p0.at[32,32,32].set(1.0) if len(N) == 3 else p0.at[32,32].set(1.0)
p0 = jnp.expand_dims(p0, -1)

def f(p0):
    return integrator(N, dx, p0=p0)

g = jit(f) # <-- This function is the weird one

## Running the functions 5 times and recording output
for _ in range(5):
    y = f(p0)                           # Uncompiled function
    plt.plot(jnp.abs(y), color="k")
    y = g(p0)                           # Compiled function
    plt.plot(jnp.abs(y), color="r")
plt.yscale("log")
plt.show()

The last part of the code plots the results of running the function 5 times with jit (red) and without jit (black):

[Figure: |observation| vs. time step for 5 runs; non-jitted runs in black, jitted runs in red; log-scale y-axis]

Please note the log y-axis: where the red traces disappear, the output is NaN.

As you can see, the function without jit is stable, while the outputs with jit are unstable and, more importantly I believe, non-deterministic.

Some extra observations that make things even weirder:

  • If I replace the line k_space_op = jnp.sinc(jnp.zeros(N)) with k_space_op = jnp.ones(N), which should be numerically equivalent in every way since N is a static argument (see the sketch after this list), the results are almost correct and the function becomes deterministic. I say almost because the very first values differ from the non-jitted version; this is probably just a matter of numerical precision.
  • This is derived from a 3D simulation. If I make it 2D by setting N, dx = (64, 64), (0.1e-3, 0.1e-3), the jitted function is stable and deterministic again. It also exactly matches the non-jitted code.
  • The issue arises only on GPU. If I disable the GPU and run this on the CPU, the problem almost disappears. Again, by almost I mean that the function becomes deterministic, but the very first values still differ from the non-jitted version.
  • This happens both on my desktop computer (specs below) and on Colab.
  • This code comes from a function that was working correctly at least up to jax 0.3.20.
  • If I simplify the code further, the problem seems to disappear.

What jax/jaxlib version are you using?

jax 0.4.10, jaxlib 0.4.10

Which accelerator(s) are you using?

GPU

Additional system info

Python 3.10.11, Ubuntu 22.04.2 LTS

NVIDIA GPU info

$ nvidia-smi

+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 530.41.03              Driver Version: 530.41.03    CUDA Version: 12.1     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                  Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf            Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA GeForce GTX 1080         Off| 00000000:01:00.0  On |                  N/A |
| 36%   53C    P2               34W / 180W|    963MiB /  8192MiB |      2%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                                         
+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
+---------------------------------------------------------------------------------------+

@hawkinsp
Member

I'm trying to run your reproducer, but Fx is not defined in fft_gradient.

Can you fix the example code?

@astanziola
Author

Sorry, I'm not sure how I skipped that line; I have added it back now.

It should be correct now; I have just tried it on Colab and confirmed that I get the problem.

@astanziola
Author

Just adding to this: it seems not to matter whether one uses jnp.sinc specifically. The same problem happens if one replaces jnp.sinc with jnp.cos, for example.

@hawkinsp
Member

I just tried to reproduce this again: it reproduces for me with jax/jaxlib 0.4.10, but not with 0.4.12. Can you try updating to 0.4.12?

@astanziola
Author

Yes, I have just tested it on 0.4.11 and I don't see the issue anymore; not sure what changed.

@hawkinsp
Member

I'm going to see if I can bisect it to a change, but fixed is fixed!

@astanziola
Author

Indeed! 😄

@astanziola
Author

Not sure if it can be helpful, but #14302 also got solved by 0.4.11.
The main component the two issues share is the FFT-based algorithm for calculating gradients.

@hawkinsp
Member

I bisected the fix to this XLA change: openxla/xla@215705b

I'm not sure why it fixes the problem, but it seems to do the trick!
