-
-
Notifications
You must be signed in to change notification settings - Fork 813
/
solve_triangular.py
99 lines (80 loc) · 3.14 KB
/
solve_triangular.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import numpy
import cupy
from cupy.cuda import cublas
from cupy.cuda import device
from cupy.linalg import _util
def solve_triangular(a, b, trans=0, lower=False, unit_diagonal=False,
                     overwrite_b=False, check_finite=False):
    """Solve the equation a x = b for x, assuming a is a triangular matrix.

    Args:
        a (cupy.ndarray): The matrix with dimension ``(M, M)``.
        b (cupy.ndarray): The matrix with dimension ``(M,)`` or
            ``(M, N)``.
        lower (bool): Use only data contained in the lower triangle of ``a``.
            Default is to use upper triangle.
        trans (0, 1, 2, 'N', 'T' or 'C'): Type of system to solve:

            - *'0'* or *'N'* -- :math:`a x = b`
            - *'1'* or *'T'* -- :math:`a^T x = b`
            - *'2'* or *'C'* -- :math:`a^H x = b`

        unit_diagonal (bool): If ``True``, diagonal elements of ``a`` are
            assumed to be 1 and will not be referenced.
        overwrite_b (bool): Allow overwriting data in b (may enhance
            performance)
        check_finite (bool): Whether to check that the input matrices contain
            only finite numbers. Disabling may give a performance gain, but may
            result in problems (crashes, non-termination) if the inputs do
            contain infinities or NaNs.

    Returns:
        cupy.ndarray:
            The matrix with dimension ``(M,)`` or ``(M, N)``. Its dtype is
            the floating-point (real or complex) dtype obtained by promoting
            the dtypes of ``a`` and ``b``.

    Raises:
        ValueError: If ``a`` is not a square matrix, if ``b`` is not 1-D or
            2-D, if the leading dimensions of ``a`` and ``b`` differ, if
            ``trans`` is not one of the accepted values, or if
            ``check_finite`` is ``True`` and an input contains infs or NaNs.

    .. seealso:: :func:`scipy.linalg.solve_triangular`
    """
    _util._assert_cupy_array(a, b)

    if a.ndim != 2 or a.shape[0] != a.shape[1]:
        raise ValueError('expected square matrix')
    if b.ndim not in (1, 2):
        raise ValueError('expected b to be a 1-D or 2-D array')
    if a.shape[0] != b.shape[0]:
        raise ValueError('incompatible dimensions')

    # Map every accepted spelling of ``trans`` to the cuBLAS operation enum
    # explicitly instead of relying on 0/1/2 coinciding with the enum values.
    trans_ops = {
        0: cublas.CUBLAS_OP_N, 'N': cublas.CUBLAS_OP_N,
        1: cublas.CUBLAS_OP_T, 'T': cublas.CUBLAS_OP_T,
        2: cublas.CUBLAS_OP_C, 'C': cublas.CUBLAS_OP_C,
    }
    try:
        trans = trans_ops[trans]
    except (KeyError, TypeError):
        raise ValueError(
            "trans must be one of 0, 1, 2, 'N', 'T' or 'C'") from None

    # Pick the smallest float32/float64/complex64/complex128 dtype that can
    # hold both operands. Considering ``b`` as well keeps a complex or
    # higher-precision right-hand side from being silently truncated.
    dtype = numpy.promote_types(
        numpy.promote_types(a.dtype, b.dtype), 'f')

    # cuBLAS expects column-major storage; ``b`` is solved in place, so it
    # is copied unless the caller allowed overwriting it.
    a = cupy.array(a, dtype=dtype, order='F', copy=False)
    b = cupy.array(b, dtype=dtype, order='F', copy=(not overwrite_b))

    if check_finite:
        # Both arrays are real or complex floating point after the cast, so
        # the check applies to every input (complex ones included).
        for array in (a, b):
            if not cupy.isfinite(array).all():
                raise ValueError(
                    'array must not contain infs or NaNs')

    # Nothing to solve; also avoids calling cuBLAS with a zero leading
    # dimension, which it rejects.
    if b.size == 0:
        return b

    m, n = (b.size, 1) if b.ndim == 1 else b.shape
    cublas_handle = device.get_cublas_handle()

    if dtype.char == 'f':
        trsm = cublas.strsm
    elif dtype.char == 'd':
        trsm = cublas.dtrsm
    elif dtype.char == 'F':
        trsm = cublas.ctrsm
    else:  # dtype.char == 'D'
        trsm = cublas.ztrsm

    # Host-side scalar alpha = 1, passed to cuBLAS by pointer.
    one = numpy.array(1, dtype=dtype)

    if lower:
        uplo = cublas.CUBLAS_FILL_MODE_LOWER
    else:
        uplo = cublas.CUBLAS_FILL_MODE_UPPER

    if unit_diagonal:
        diag = cublas.CUBLAS_DIAG_UNIT
    else:
        diag = cublas.CUBLAS_DIAG_NON_UNIT

    # Solves op(a) x = b in place; both leading dimensions are m because
    # ``a`` is (m, m) and ``b`` is (m, n) in column-major order.
    trsm(
        cublas_handle, cublas.CUBLAS_SIDE_LEFT, uplo,
        trans, diag,
        m, n, one.ctypes.data, a.data.ptr, m, b.data.ptr, m)
    return b