
[BUG] mlx crashes with msg - uncaught exception of type std::invalid_argument: [Scatter::eval_gpu] Does not support int64 #1076

Closed
Tracked by #19571
lkarthee opened this issue May 4, 2024 · 4 comments · Fixed by #1077

Comments


lkarthee commented May 4, 2024

Describe the bug
Calling the keras scatter op with the mlx backend on int64 inputs (the NumPy defaults) terminates the Python process with an uncaught C++ exception instead of raising a Python error.

To Reproduce

Code snippet:

import numpy as np
import mlx.core as mx
from keras.src.ops import core
indices = np.array([[1], [3], [4], [7]])
values = np.array([9, 10, 11, 12])
from keras.src import backend
backend.backend()
# >>> 'mlx'
x = core.scatter(indices, values, (8,))
x
# libc++abi: terminating due to uncaught exception of type std::invalid_argument: [Scatter::eval_gpu] Does not support int64
zsh: abort      python
# keras.ops.scatter implementation for the mlx backend; convert_to_tensor is the backend helper that returns mlx arrays
def scatter(indices, values, shape):
    indices = convert_to_tensor(indices)
    values = convert_to_tensor(values)
    zeros = mx.zeros(shape, dtype=values.dtype)
    indices = tuple(indices[..., i] for i in range(indices.shape[-1]))
    zeros = zeros.at[indices].add(values)

    return zeros
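For reference, a minimal MLX-only reduction (my own sketch, not from the issue) that hits the same int64 scatter path when the default device is the GPU:

import mlx.core as mx

# NumPy integer arrays default to int64, which the keras backend converts to
# mlx int64 arrays; here the dtype is forced explicitly.
zeros = mx.zeros((8,), dtype=mx.int64)
idx = mx.array([1, 3, 4, 7])
vals = mx.array([9, 10, 11, 12], dtype=mx.int64)
out = zeros.at[idx].add(vals)
mx.eval(out)  # before #1077 this aborted with [Scatter::eval_gpu] Does not support int64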

Expected behavior
MLX should not crash the process; it should raise a catchable exception or error instead.

Desktop:

  • OS Version: macOS 14.1
  • MLX Version: 0.12.2


awni (Member) commented May 5, 2024

If you are just asking for a catchable exception then #1077 should close this. We would like to eventually allow int64 and other 8-byte types to work with scatter, but that is more involved.
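For reference, a rough sketch of what a catchable failure could look like once #1077 lands, assuming std::invalid_argument surfaces as a Python ValueError:

import mlx.core as mx

zeros = mx.zeros((8,), dtype=mx.int64)
try:
    mx.eval(zeros.at[mx.array([0, 1])].add(mx.array([9, 10], dtype=mx.int64)))
except ValueError as e:  # assumed mapping of std::invalid_argument
    print("int64 scatter not supported on the GPU:", e)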

lkarthee (Author) commented May 5, 2024

Thank you Awni. Some observations:

  • The crash message is confusing: it does not say whether the problem is with the array, the indices, or the values. Could it be improved by mentioning a workaround in the error message?
  • Can the scatter ops use the CPU device for int64 and uint64?
  • Adding a note about the supported dtypes and devices (CPU, GPU) to mlx.core.array.at would be helpful.
  • Are there any other ops that are not supported on the GPU and run on the CPU?

The relevant lines in the keras scatter are:

zeros = mx.zeros(shape, dtype=values.dtype)
zeros = zeros.at[indices].add(values)

I tried the following, but it does not work because add does not take a device kwarg:

if zeros.dtype in (mx.int64, mx.uint64) and mx.default_device().type == mx.DeviceType.gpu:
  device = mx.Device(type=mx.DeviceType.cpu)
  # fails: add() does not accept a device keyword argument
  zeros = zeros.at[indices].add(values, device=device)
else:
  zeros = zeros.at[indices].add(values)

It would be helpful if mlx could either fall back to the CPU for scatter ops that are not supported on the GPU, or accept a device kwarg on all scatter ops.

Additional ops that are impacted by this bug (a minimal sketch follows the list):

  • mx.cumsum
  • mx.cumprod
  • mx.diag
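A short sketch exercising those ops with an int64 input on the default GPU device; per the report above they hit the same unsupported-dtype path:

import mlx.core as mx

a = mx.array([1, 2, 3, 4], dtype=mx.int64)
mx.eval(mx.cumsum(a))   # reported to fail for int64 on the GPU
mx.eval(mx.cumprod(a))  # likewise
mx.eval(mx.diag(a))     # likewise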

awni (Member) commented May 14, 2024

The crash message is confusing: it does not say whether the problem is with the array, the indices, or the values. Could it be improved by mentioning a workaround in the error message?

I improved the message in #1077. The problem is with the values.

Can the scatter ops use the CPU device for int64 and uint64?

We prefer not to silently route to the CPU for ops without a GPU back-end. You can do this in the API by changing the default stream to the CPU before calling the scatter when the dtype is int64/uint64.
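For example, one illustrative way to do that (my own sketch, and an assumption about the exact calls; it switches the default device rather than a specific stream):

import mlx.core as mx

prev = mx.default_device()
mx.set_default_device(mx.cpu)  # subsequent ops, including the scatter, run on the CPU
try:
    out = mx.zeros((8,), dtype=mx.int64).at[mx.array([1, 3])].add(
        mx.array([9, 10], dtype=mx.int64))
    mx.eval(out)
finally:
    mx.set_default_device(prev)  # restore the previous default device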

Are there any other ops that are not supported on the GPU and run on the CPU?

Just a few: FFT and some of the LAPACK ops (QR / inverse). Metal support for FFT is coming soon in #981.

I tried the following, but it does not work because add does not take a device kwarg:

You can use a context manager. For most free functions the stream kwarg also works, e.g.:

v = mx.array([1, 2, 3])
u = mx.array([1, 2])
idx = mx.array([0, 1])

with mx.stream(mx.cpu):
    out = v.at[idx].add(u)
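Applied to the keras-style scatter from the issue description, that workaround could look roughly like this (a sketch only, assuming the inputs are already mlx arrays; not the actual keras implementation):

import mlx.core as mx

def scatter_with_cpu_fallback(indices, values, shape):
    # hypothetical helper: route 8-byte integer scatters to the CPU stream,
    # since the GPU kernel does not support those dtypes yet
    zeros = mx.zeros(shape, dtype=values.dtype)
    idx = tuple(indices[..., i] for i in range(indices.shape[-1]))
    if values.dtype in (mx.int64, mx.uint64):
        with mx.stream(mx.cpu):
            return zeros.at[idx].add(values)
    return zeros.at[idx].add(values)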

lkarthee (Author) commented

Thank you @awni for the fix.
