
[BUG] dask.array.linalg.inv fails on CuPy backed arrays #4899

Open
beckernick opened this issue Jun 7, 2019 · 8 comments

@beckernick
Member

beckernick commented Jun 7, 2019

linalg.inv appears to fail for the same general reason as #4898. Filing this for visibility in case others run into it before #4731 is resolved.

import dask.array as da
import cupy as cp
import numpy as np

np.random.seed(12)
arr = np.random.normal(10, 10, (1000, 1000))
arr_cp = cp.asarray(arr)
darr_cp = da.from_array(arr_cp, asarray=False)
da.linalg.inv(darr_cp).compute()
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-1-a9c479838e0d> in <module>
      7 arr_cp = cp.asarray(arr)
      8 darr_cp = da.from_array(arr_cp, asarray=False)
----> 9 da.linalg.inv(darr_cp).compute()

/conda/envs/rapids/lib/python3.7/site-packages/dask/base.py in compute(self, **kwargs)
    154         dask.base.compute
    155         """
--> 156         (result,) = compute(self, traverse=False, **kwargs)
    157         return result
    158 

/conda/envs/rapids/lib/python3.7/site-packages/dask/base.py in compute(*args, **kwargs)
    396     keys = [x.__dask_keys__() for x in collections]
    397     postcomputes = [x.__dask_postcompute__() for x in collections]
--> 398     results = schedule(dsk, keys, **kwargs)
    399     return repack([f(r, *a) for r, (f, a) in zip(results, postcomputes)])
    400 

/conda/envs/rapids/lib/python3.7/site-packages/dask/threaded.py in get(dsk, result, cache, num_workers, pool, **kwargs)
     74     results = get_async(pool.apply_async, len(pool._pool), dsk, result,
     75                         cache=cache, get_id=_thread_get_id,
---> 76                         pack_exception=pack_exception, **kwargs)
     77 
     78     # Cleanup pools associated to dead threads

/conda/envs/rapids/lib/python3.7/site-packages/dask/local.py in get_async(apply_async, num_workers, dsk, result, cache, get_id, rerun_exceptions_locally, pack_exception, raise_exception, callbacks, dumps, loads, **kwargs)
    460                         _execute_task(task, data)  # Re-execute locally
    461                     else:
--> 462                         raise_exception(exc, tb)
    463                 res, worker_id = loads(res_info)
    464                 state['cache'][key] = res

/conda/envs/rapids/lib/python3.7/site-packages/dask/compatibility.py in reraise(exc, tb)
    110         if exc.__traceback__ is not tb:
    111             raise exc.with_traceback(tb)
--> 112         raise exc
    113 
    114     import pickle as cPickle

/conda/envs/rapids/lib/python3.7/site-packages/dask/local.py in execute_task(key, task_info, dumps, loads, get_id, pack_exception)
    228     try:
    229         task, data = loads(task_info)
--> 230         result = _execute_task(task, data)
    231         id = get_id()
    232         result = dumps((result, id))

/conda/envs/rapids/lib/python3.7/site-packages/dask/core.py in _execute_task(arg, cache, dsk)
    117         func, args = arg[0], arg[1:]
    118         args2 = [_execute_task(a, cache) for a in args]
--> 119         return func(*args2)
    120     elif not ishashable(arg):
    121         return arg

/conda/envs/rapids/lib/python3.7/site-packages/scipy/linalg/decomp_lu.py in lu(a, permute_l, overwrite_a, check_finite)
    209     """
    210     if check_finite:
--> 211         a1 = asarray_chkfinite(a)
    212     else:
    213         a1 = asarray(a)

/conda/envs/rapids/lib/python3.7/site-packages/numpy/lib/function_base.py in asarray_chkfinite(a, dtype, order)
    493 
    494     """
--> 495     a = asarray(a, dtype=dtype, order=order)
    496     if a.dtype.char in typecodes['AllFloat'] and not np.isfinite(a).all():
    497         raise ValueError(

/conda/envs/rapids/lib/python3.7/site-packages/numpy/core/numeric.py in asarray(a, dtype, order)
    536 
    537     """
--> 538     return array(a, dtype, copy=False, order=order)
    539 
    540 

ValueError: object __array__ method not producing an array

I do have the NUMPY_EXPERIMENTAL_ARRAY_FUNCTION=1 environment variable set for this test

import os
os.environ['NUMPY_EXPERIMENTAL_ARRAY_FUNCTION']
'1'
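Beyond checking the environment variable, a quick runtime probe (a hypothetical helper, not part of the issue) can confirm that NEP-18 dispatch is actually active: if `__array_function__` is honored, a dispatched NumPy function returns our sentinel instead of trying to coerce the object.

```python
import numpy as np

class Probe:
    # Any object defining __array_function__ participates in NEP-18 dispatch.
    def __array_function__(self, func, types, args, kwargs):
        return "dispatch enabled"

# np.mean is a dispatched function; on NumPy >= 1.17 (or 1.16 with
# NUMPY_EXPERIMENTAL_ARRAY_FUNCTION=1) this returns the sentinel.
print(np.mean(Probe()))
```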
@mrocklin
Member

mrocklin commented Jun 7, 2019

This is somewhat tangential, but are you sure that you want inv? The inv function is almost never appropriate to call. Perhaps you wanted to call solve instead? It is usually much faster and more numerically stable.
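A minimal sketch of the inv-vs-solve point, in plain NumPy (the same reasoning carries over to the dask.array and CuPy equivalents): solving `A x = b` with `solve` avoids forming the explicit inverse, which is usually faster and more numerically stable, while giving the same answer.

```python
import numpy as np

np.random.seed(12)
A = np.random.normal(10, 10, (100, 100))
b = np.random.normal(0, 1, 100)

x_solve = np.linalg.solve(A, b)   # LU factorization plus triangular solves
x_inv = np.linalg.inv(A) @ b      # explicit inverse, then a matmul

# Both approaches agree, and the result actually solves the system.
assert np.allclose(x_solve, x_inv)
assert np.allclose(A @ x_solve, b)
```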

@jakirkham jakirkham added the array label Jun 7, 2019
@jakirkham
Member

It would be interesting to know more about the context where inv is used (if that can be shared).

@beckernick
Member Author

beckernick commented Jun 7, 2019

Let's take the inv vs. solve discussion offline.

@pentschev
Member

For the record, we'll need #4883 for this.

@jakirkham
Member

It sounds like this is less important.

@beckernick
Member Author

Agreed. This does not need to be prioritized.

@jsignell
Member

It seems like this can be worked on now, since #4883 is in.

@pentschev
Member

Not really, this still doesn't work: dask.array's linalg module relies mostly on SciPy, which today supports neither NEP-18 nor NEP-35. I think we'll need special-casing to handle functions such as linalg.inv moving forward.
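The distinction can be illustrated with a hypothetical minimal "duck array" (not part of dask or CuPy): NumPy functions dispatch through `__array_function__` (NEP-18), so `np.linalg.inv` works on it, but SciPy routines like the `scipy.linalg.lu` call in the traceback coerce their input through `np.asarray` instead, which trips `__array__` exactly as CuPy does.

```python
import numpy as np

class Duck:
    """Minimal NEP-18 duck array that refuses silent host conversion."""

    def __init__(self, data):
        self.data = np.asarray(data)

    def __array_function__(self, func, types, args, kwargs):
        # Unwrap Duck arguments, run the NumPy function, re-wrap the result.
        args = [a.data if isinstance(a, Duck) else a for a in args]
        return Duck(func(*args, **kwargs))

    def __array__(self, dtype=None, copy=None):
        # Mimic a device array that cannot be silently converted to NumPy.
        raise ValueError("object __array__ method not producing an array")

d = Duck([[2.0, 0.0], [0.0, 2.0]])
inv = np.linalg.inv(d)   # NEP-18 dispatch: works, returns a Duck
# scipy.linalg.lu(d) would instead call np.asarray(d), raising the
# ValueError above -- the same failure mode as in this issue.
```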
