Skip to content

Commit

Permalink
Merge pull request #7966 from andfoy/fix_solve_lapack
Browse files Browse the repository at this point in the history
Import cupyx.lapack inside cupy.linalg.solve
  • Loading branch information
kmaehashi committed Oct 29, 2023
2 parents f6b794a + 1fbb3ab commit 1cd6db5
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions cupy/linalg/_solve.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def solve(a, b):
.. seealso:: :func:`numpy.linalg.solve`
"""
from cupyx import lapack
from cupy.cublas import batched_gesv, get_batched_gesv_limit

if a.ndim > 2 and a.shape[-1] <= get_batched_gesv_limit():
Expand Down Expand Up @@ -64,7 +65,7 @@ def solve(a, b):
# prevent 'a' and 'b' to be overwritten
a = a.astype(dtype, copy=True, order='F')
b = b.astype(dtype, copy=True, order='F')
cupyx.lapack.gesv(a, b)
lapack.gesv(a, b)
return b.astype(out_dtype, copy=False)

# prevent 'a' to be overwritten
Expand All @@ -75,7 +76,7 @@ def solve(a, b):
index = numpy.unravel_index(i, shape)
# prevent 'b' to be overwritten
bi = b[index].astype(dtype, copy=True, order='F')
cupyx.lapack.gesv(a[index], bi)
lapack.gesv(a[index], bi)
x[index] = bi
return x

Expand Down

0 comments on commit 1cd6db5

Please sign in to comment.