Skip to content

Commit d00e1cd

Browse files
committed
linalg : triangular_matrix_vector_solveを追加 (#1233)
Signed-off-by: Yuya Asano <64895419+sukeya@users.noreply.github.com>
1 parent dcd9949 commit d00e1cd

File tree

2 files changed

+390
-1
lines changed

2 files changed

+390
-1
lines changed

reference/linalg.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ BLAS 1, 2, 3のアルゴリズムでテンプレートパラメータが特に
8585
| [`symmetric_matrix_vector_product`](linalg/symmetric_matrix_vector_product.md) | xSYMV: 対称行列とベクトルの積を求める (function template) | C++26 |
8686
| [`hermitian_matrix_vector_product`](linalg/hermitian_matrix_vector_product.md) | xHEMV: ハミルトニアン行列とベクトルの積を求める (function template) | C++26 |
8787
| [`triangular_matrix_vector_product`](linalg/triangular_matrix_vector_product.md) | xTRMV: 三角行列とベクトルの積を求める (function template) | C++26 |
88-
| `triangular_matrix_vector_solve` | xTRSV: 三角行列を係数とする行列方程式を解く (function template) | C++26 |
88+
| [`triangular_matrix_vector_solve`](linalg/triangular_matrix_vector_solve.md) | xTRSV: 三角行列を係数とする行列方程式を解く (function template) | C++26 |
8989
| `matrix_rank_1_update` | xGER, xGERU: 行列のRank-1更新 (function template) | C++26 |
9090
| `matrix_rank_1_update_c` | xGERC: 複素行列のRank-1更新 (function template) | C++26 |
9191
| `symmetric_matrix_rank_1_update` | xSYR: 対称行列のRank-1更新 (function template) | C++26 |
Lines changed: 389 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,389 @@
1+
# triangular_matrix_vector_solve
2+
3+
4+
* [mathjax enable]
5+
* linalg[meta header]
6+
* function template[meta id-type]
7+
* std::linalg[meta namespace]
8+
* cpp26[meta cpp]
9+
10+
11+
```cpp
12+
namespace std::linalg {
13+
template<in-matrix InMat,
14+
class Triangle,
15+
class DiagonalStorage,
16+
in-vector InVec,
17+
out-vector OutVec,
18+
class BinaryDivideOp>
19+
void triangular_matrix_vector_solve(
20+
InMat A,
21+
Triangle t,
22+
DiagonalStorage d,
23+
InVec b,
24+
OutVec x,
25+
BinaryDivideOp divide); // (1)
26+
27+
template<class ExecutionPolicy,
28+
in-matrix InMat,
29+
class Triangle,
30+
class DiagonalStorage,
31+
in-vector InVec,
32+
out-vector OutVec,
33+
class BinaryDivideOp>
34+
void triangular_matrix_vector_solve(
35+
ExecutionPolicy&& exec,
36+
InMat A,
37+
Triangle t,
38+
DiagonalStorage d,
39+
InVec b,
40+
OutVec x,
41+
BinaryDivideOp divide); // (2)
42+
43+
template<in-matrix InMat,
44+
class Triangle,
45+
class DiagonalStorage,
46+
in-vector InVec,
47+
out-vector OutVec>
48+
void triangular_matrix_vector_solve(
49+
InMat A,
50+
Triangle t,
51+
DiagonalStorage d,
52+
InVec b,
53+
OutVec x); // (3)
54+
55+
template<class ExecutionPolicy,
56+
in-matrix InMat,
57+
class Triangle,
58+
class DiagonalStorage,
59+
in-vector InVec,
60+
out-vector OutVec>
61+
void triangular_matrix_vector_solve(
62+
ExecutionPolicy&& exec,
63+
InMat A,
64+
Triangle t,
65+
DiagonalStorage d,
66+
InVec b,
67+
OutVec x); // (4)
68+
69+
template<in-matrix InMat,
70+
class Triangle,
71+
class DiagonalStorage,
72+
inout-vector InOutVec,
73+
class BinaryDivideOp>
74+
void triangular_matrix_vector_solve(
75+
InMat A,
76+
Triangle t,
77+
DiagonalStorage d,
78+
InOutVec b,
79+
BinaryDivideOp divide); // (5)
80+
81+
template<class ExecutionPolicy,
82+
in-matrix InMat,
83+
class Triangle,
84+
class DiagonalStorage,
85+
inout-vector InOutVec,
86+
class BinaryDivideOp>
87+
void triangular_matrix_vector_solve(
88+
ExecutionPolicy&& exec,
89+
InMat A,
90+
Triangle t,
91+
DiagonalStorage d,
92+
InOutVec b,
93+
BinaryDivideOp divide); // (6)
94+
95+
template<in-matrix InMat,
96+
class Triangle,
97+
class DiagonalStorage,
98+
inout-vector InOutVec>
99+
void triangular_matrix_vector_solve(
100+
InMat A,
101+
Triangle t,
102+
DiagonalStorage d,
103+
InOutVec b); // (7)
104+
105+
template<class ExecutionPolicy,
106+
in-matrix InMat,
107+
class Triangle,
108+
class DiagonalStorage,
109+
inout-vector InOutVec>
110+
void triangular_matrix_vector_solve(
111+
ExecutionPolicy&& exec,
112+
InMat A,
113+
Triangle t,
114+
DiagonalStorage d,
115+
InOutVec b); // (8)
116+
}
117+
```
118+
119+
120+
## 概要
121+
三角行列に対して、連立一次方程式を解く。
122+
引数`t`は対称行列の成分が上三角にあるのか、それとも下三角にあるのかを示す。
123+
引数`d`には対称行列の対角成分を暗黙に乗法における単位元とみなすかどうかを指定する。
124+
引数`divide`には値の割り算を指定する。この引数は非可換な掛け算を持つ値型をサポートするためにある。
125+
126+
- (1): 連立一次方程式 $Ay = b$ を解き、`y`を`x`に代入する。もし解が存在しないなら、`x`は有効だが未規定。
127+
- (2): (1)を指定された実行ポリシーで実行する。
128+
- (3): 割り算に[`std::divedes`](/reference/functional/divides.md)`<void>`を用いて、(1)を行う。
129+
- (4): (3)を指定された実行ポリシーで実行する。
130+
- (5): `x`に`b`を使って、in-placeに(1)を行う。
131+
- (6): (5)を指定された実行ポリシーで実行する。
132+
- (7): 割り算に[`std::divedes`](/reference/functional/divides.md)`<void>`を用いて、(5)を行う。
133+
- (8): (7)を指定された実行ポリシーで実行する。
134+
135+
136+
## 適格要件
137+
- 共通:
138+
+ `Triangle`は[`upper_triangle_t`](upper_triangle_t.md)または[`lower_triangle_t`](lower_triangle_t.md)
139+
+ `DiagonalStorage`は[`implicit_unit_diagonal_t`](implicit_unit_diagonal_t.md)または[`explicit_diagonal_t`](explicit_diagonal_t.md)
140+
+ `InMat`が[`layout_blas_packed`](layout_blas_packed.md)を持つなら、レイアウトの`Triangle`テンプレート引数とこの関数の`Triangle`テンプレート引数が同じ型
141+
+ [`compatible-static-extents`](compatible-static-extents.md)`<decltype(A), decltype(A)>(0, 1)`が`true` (つまり`A`が正方行列であること)
142+
+ [`compatible-static-extents`](compatible-static-extents.md)`<decltype(A), decltype(b)>(0, 0)`が`true` (つまり`A`の次元と`b`の次元が同じであること)
143+
- (1), (2), (3), (4): [`compatible-static-extents`](compatible-static-extents.md)`<decltype(A), decltype(x)>(0, 0)`が`true` (つまり`A`の次元と`b`の次元が同じであること)
144+
145+
146+
## 事前条件
147+
- 共通:
148+
+ `A.extent(0) == A.extent(1)` (つまり`A`が正方行列であること)
149+
+ `A.extent(0) == b.extent(0)` (つまり`A`の次元と`b`の次元が同じであること)
150+
- (1), (2), (3), (4): `A.extent(0) == x.extent(0)` (つまり`A`の次元と`x`の次元が同じであること)
151+
152+
153+
## 効果
154+
対称行列の成分の位置を示す`t`と対角成分へアクセスするかどうかを示す`d`を考慮して、連立一次方程式の解を求める。
155+
156+
- (1), (2): 連立一次方程式 $Ay = b$ を解き、`y`を`x`に代入する。もし解が存在しないなら、`x`は有効だが未規定。
157+
- (3): `triangular_matrix_vector_solve(A, t, d, b, x, divides<void>{})`と同じ。
158+
- (4): `triangular_matrix_vector_solve(std::forward<ExecutionPolicy>(exec), A, t, d, b, x, divides<void>{})`と同じ。
159+
- (5), (6): `x`に`b`を使って、in-placeに(1)を行う。
160+
- (7): `triangular_matrix_vector_solve(A, t, d, b, divides<void>{})`と同じ。
161+
- (8): `triangular_matrix_vector_solve(std::forward<ExecutionPolicy>(exec), A, t, d, b, divides<void>{})`と同じ。
162+
163+
164+
## 戻り値
165+
なし
166+
167+
168+
## 計算量
169+
$O(\verb|A.extent(1)|\times \verb|x.extent(0)|)$
170+
171+
172+
## 備考
173+
- (6), (8): in-placeアルゴリズムなので並列実行を妨げるが、ベクトル化といった`ExecutionPolicy`特有の最適化はできる。
174+
175+
176+
## 例
177+
**[注意] 処理系にあるコンパイラで確認していないため、間違っているかもしれません。**
178+
179+
```cpp example
180+
#include <array>
181+
#include <functional>
182+
#include <iostream>
183+
#include <linalg>
184+
#include <mdspan>
185+
#include <vector>
186+
187+
template <class Vector>
188+
void print(const Vector& v, const std::string& name) {
189+
for (int i = 0; i < v.extent(0); ++i) {
190+
std::cout << name << "[" << i << "]" << " = " << v[i] << '\n';
191+
}
192+
}
193+
194+
template <class Vector>
195+
void init(Vector& v) {
196+
for (int i = 0; i < v.extent(0); ++i) {
197+
v[i] = i;
198+
}
199+
}
200+
201+
int main()
202+
{
203+
constexpr size_t N = 4;
204+
205+
std::vector<double> A_vec(N * N);
206+
std::vector<double> x_vec(N);
207+
std::array<double, N> b_vec;
208+
209+
std::mdspan<
210+
double,
211+
std::extents<size_t, N, N>,
212+
std::linalg::layout_blas_packed<
213+
std::linalg::upper_triangle_t,
214+
std::linalg::row_major_t>
215+
> A(A_vec.data());
216+
std::mdspan x(x_vec.data(), N);
217+
std::mdspan b(b_vec.data(), N);
218+
219+
for(int i = 0; i < A.extent(0); ++i) {
220+
for(int j = i + 1; j < A.extent(1); ++j) {
221+
A[i,j] = A.extent(1) * i + j;
222+
}
223+
}
224+
225+
init(b);
226+
227+
// (1)
228+
std::cout << "(1)\n";
229+
std::linalg::triangular_matrix_vector_solve(
230+
A,
231+
std::linalg::upper_triangle,
232+
std::linalg::implicit_unit_diagonal,
233+
b,
234+
x,
235+
std::divides<void>{});
236+
print(x, "x");
237+
238+
// (2)
239+
std::cout << "(2)\n";
240+
std::linalg::triangular_matrix_vector_solve(
241+
std::execution::par,
242+
A,
243+
std::linalg::upper_triangle,
244+
std::linalg::implicit_unit_diagonal,
245+
b,
246+
x,
247+
std::divides<void>{});
248+
print(x, "x");
249+
250+
// (3)
251+
std::cout << "(3)\n";
252+
std::linalg::triangular_matrix_vector_solve(
253+
A,
254+
std::linalg::upper_triangle,
255+
std::linalg::implicit_unit_diagonal,
256+
b,
257+
x);
258+
print(x, "x");
259+
260+
// (4)
261+
std::cout << "(4)\n";
262+
std::linalg::triangular_matrix_vector_solve(
263+
std::execution::par,
264+
A,
265+
std::linalg::upper_triangle,
266+
std::linalg::implicit_unit_diagonal,
267+
b,
268+
x);
269+
print(x, "x");
270+
271+
// (5)
272+
std::cout << "(5)\n";
273+
std::linalg::triangular_matrix_vector_solve(
274+
A,
275+
std::linalg::upper_triangle,
276+
std::linalg::implicit_unit_diagonal,
277+
b,
278+
std::divides<void>{});
279+
print(b, "b");
280+
281+
init(b);
282+
283+
// (6)
284+
std::cout << "(6)\n";
285+
std::linalg::triangular_matrix_vector_solve(
286+
std::execution::par,
287+
A,
288+
std::linalg::upper_triangle,
289+
std::linalg::implicit_unit_diagonal,
290+
b,
291+
std::divides<void>{});
292+
print(b, "b");
293+
294+
init(b);
295+
296+
// (7)
297+
std::cout << "(7)\n";
298+
std::linalg::triangular_matrix_vector_solve(
299+
A,
300+
std::linalg::upper_triangle,
301+
std::linalg::implicit_unit_diagonal,
302+
b);
303+
print(b, "b");
304+
305+
init(b);
306+
307+
// (8)
308+
std::cout << "(8)\n";
309+
std::linalg::triangular_matrix_vector_solve(
310+
std::execution::par,
311+
A,
312+
std::linalg::upper_triangle,
313+
std::linalg::implicit_unit_diagonal,
314+
b);
315+
print(b, "b");
316+
317+
return 0;
318+
}
319+
```
320+
321+
322+
### 出力
323+
```
324+
(1)
325+
x[0] = -3
326+
x[1] = -4
327+
x[2] = -1
328+
x[3] = 3
329+
(2)
330+
x[0] = -3
331+
x[1] = -4
332+
x[2] = -1
333+
x[3] = 3
334+
(3)
335+
x[0] = -3
336+
x[1] = -4
337+
x[2] = -1
338+
x[3] = 3
339+
(4)
340+
x[0] = -3
341+
x[1] = -4
342+
x[2] = -1
343+
x[3] = 3
344+
(5)
345+
b[0] = -3
346+
b[1] = -4
347+
b[2] = -1
348+
b[3] = 3
349+
(6)
350+
b[0] = -3
351+
b[1] = -4
352+
b[2] = -1
353+
b[3] = 3
354+
(7)
355+
b[0] = -3
356+
b[1] = -4
357+
b[2] = -1
358+
b[3] = 3
359+
(8)
360+
b[0] = -3
361+
b[1] = -4
362+
b[2] = -1
363+
b[3] = 3
364+
```
365+
366+
367+
## バージョン
368+
### 言語
369+
- C++26
370+
371+
### 処理系
372+
- [Clang](/implementation.md#clang): ??
373+
- [GCC](/implementation.md#gcc): ??
374+
- [ICC](/implementation.md#icc): ??
375+
- [Visual C++](/implementation.md#visual_cpp): ??
376+
377+
378+
## 関連項目
379+
- [`execution`](/reference/execution.md)
380+
- [`mdspan`](/reference/mdspan.md)
381+
- [`upper_triangle_t`](upper_triangle_t.md)
382+
- [`lower_triangle_t`](lower_triangle_t.md)
383+
- [`implicit_unit_diagonal`](implicit_unit_diagonal_t.md)
384+
- [`explicit_diagonal`](explicit_diagonal_t.md)
385+
386+
387+
## 参照
388+
- [P1673R13 A free function linear algebra interface based on the BLAS](https://www.open-std.org/jtc1/sc22/wg21/docs/papers/2023/p1673r13.html)
389+
- [LAPACK: trmv](https://netlib.org/lapack/explore-html/dd/dc3/group__trsv.html)

0 commit comments

Comments
 (0)