Skip to content

Commit

Permalink
Merge pull request #8206 from ev-br/expm_complex
Browse files Browse the repository at this point in the history
Fix `expm(complex matrix)`
  • Loading branch information
takagi committed Feb 26, 2024
2 parents dcd325a + 8e2d1b8 commit 1144fca
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 4 deletions.
16 changes: 12 additions & 4 deletions cupyx/scipy/linalg/_matfuncs.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import math
import cmath

import cupy
from cupy.linalg import _util
Expand Down Expand Up @@ -90,9 +91,13 @@ def expm(a):

n = a.shape[0]

# follow scipy.linalg.expm dtype handling
a_dtype = a.dtype if cupy.issubdtype(
a.dtype, cupy.inexact) else cupy.float64

# try reducing the norm
mu = cupy.diag(a).sum() / n
A = a - cupy.eye(n)*mu
A = a - cupy.eye(n, dtype=a_dtype)*mu

# scale factor
nrmA = cupy.linalg.norm(A, ord=1).item()
Expand All @@ -110,9 +115,10 @@ def expm(a):
A4 = A2 @ A2
A6 = A2 @ A4

E = cupy.eye(A.shape[0])
E = cupy.eye(A.shape[0], dtype=a_dtype)
bb = cupy.asarray(b, dtype=a_dtype)

u1, u2, v1, v2 = _expm_inner(E, A, A2, A4, A6, cupy.asarray(b))
u1, u2, v1, v2 = _expm_inner(E, A, A2, A4, A6, bb)
u = A @ (A6 @ u1 + u2)
v = A6 @ v1 + v2

Expand All @@ -124,7 +130,9 @@ def expm(a):
x = x @ x

# undo preprocessing
x *= math.exp(mu)
emu = cmath.exp(mu) if cupy.issubdtype(
mu.dtype, cupy.complexfloating) else math.exp(mu)
x *= emu

return x

Expand Down
6 changes: 6 additions & 0 deletions tests/cupyx_tests/scipy_tests/linalg_tests/test_matfuncs.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,3 +103,9 @@ def test_2x2_input(self, xp, scp):
def test_nx2x2_input(self, xp, scp, a):
a = xp.asarray(a)
return scp.linalg.expm(a)

@testing.for_all_dtypes(no_bool=True, no_float16=True)
@testing.numpy_cupy_allclose(scipy_name='scp', contiguous_check=False)
def test_dtypes(self, xp, scp, dtype):
a = xp.eye(2, dtype=dtype)
return scp.linalg.expm(a)

0 comments on commit 1144fca

Please sign in to comment.