Skip to content

Commit dae7d3d

Browse files
committed
linalg : triangular_matrix_matrix_left_solveを追加 (#1233)
Signed-off-by: Yuya Asano <64895419+sukeya@users.noreply.github.com>
1 parent ddd8063 commit dae7d3d

File tree

2 files changed

+280
-1
lines changed

2 files changed

+280
-1
lines changed

reference/linalg.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ BLAS 1, 2, 3のアルゴリズムでテンプレートパラメータが特に
108108
| [`hermitian_matrix_rank_k_update`](linalg/hermitian_matrix_rank_k_update.md) | xHERK: ハミルトニアン行列のRank-k更新 (function template) | C++26 |
109109
| [`symmetric_matrix_rank_2k_update`](linalg/symmetric_matrix_rank_2k_update.md) | xSYR2K: 対称行列のRank-2k更新 (function template) | C++26 |
110110
| [`hermitian_matrix_rank_2k_update`](linalg/hermitian_matrix_rank_2k_update.md) | xHER2K: ハミルトニアン行列のRank-2k更新 (function template) | C++26 |
111-
| `triangular_matrix_matrix_left_solve` | xTRSM: 三角行列の連立一次方程式を解く (function template) | C++26 |
111+
| [`triangular_matrix_matrix_left_solve`](linalg/triangular_matrix_matrix_left_solve.md) | xTRSM: 三角行列の連立一次方程式を解く (function template) | C++26 |
112112
| `triangular_matrix_matrix_right_solve` | xTRSM: 三角行列の連立一次方程式を解く (function template) | C++26 |
113113

114114

Lines changed: 279 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,279 @@
1+
# triangular_matrix_matrix_left_solve
2+
3+
4+
* [mathjax enable]
5+
* linalg[meta header]
6+
* function template[meta id-type]
7+
* std::linalg[meta namespace]
8+
* cpp26[meta cpp]
9+
10+
11+
```cpp
12+
namespace std::linalg {
13+
template<in-matrix InMat1,
14+
class Triangle,
15+
class DiagonalStorage,
16+
in-matrix InMat2,
17+
out-matrix OutMat,
18+
class BinaryDivideOp>
19+
void triangular_matrix_matrix_left_solve(
20+
InMat1 A,
21+
Triangle t,
22+
DiagonalStorage d,
23+
InMat2 B,
24+
OutMat X,
25+
BinaryDivideOp divide); // (1)
26+
27+
template<class ExecutionPolicy,
28+
in-matrix InMat1,
29+
class Triangle,
30+
class DiagonalStorage,
31+
in-matrix InMat2,
32+
out-matrix OutMat,
33+
class BinaryDivideOp>
34+
void triangular_matrix_matrix_left_solve(
35+
ExecutionPolicy&& exec,
36+
InMat1 A,
37+
Triangle t,
38+
DiagonalStorage d,
39+
InMat2 B,
40+
OutMat X,
41+
BinaryDivideOp divide); // (2)
42+
43+
template<in-matrix InMat1,
44+
class Triangle,
45+
class DiagonalStorage,
46+
in-matrix InMat2,
47+
out-matrix OutMat>
48+
void triangular_matrix_matrix_left_solve(
49+
InMat1 A,
50+
Triangle t,
51+
DiagonalStorage d,
52+
InMat2 B,
53+
OutMat X); // (3)
54+
55+
template<class ExecutionPolicy,
56+
in-matrix InMat1,
57+
class Triangle,
58+
class DiagonalStorage,
59+
in-matrix InMat2,
60+
out-matrix OutMat>
61+
void triangular_matrix_matrix_left_solve(
62+
ExecutionPolicy&& exec,
63+
InMat1 A,
64+
Triangle t,
65+
DiagonalStorage d,
66+
InMat2 B,
67+
OutMat X); // (4)
68+
}
69+
```
70+
71+
72+
## 概要
73+
三角行列に対して、連立一次方程式を解く。
74+
引数`t`は対称行列の成分が上三角にあるのか、それとも下三角にあるのかを示す。
75+
引数`d`には対称行列の対角成分を暗黙に乗法における単位元とみなすかどうかを指定する。
76+
引数`divide`には値の割り算を指定する。この引数は非可換な掛け算を持つ値型をサポートするためにある。
77+
78+
- (1): 連立一次方程式 $AY = B$ を解き、`Y`を`X`に代入する。もし解が存在しないなら、`X`は有効だが未規定。
79+
- (2): (1)を指定された実行ポリシーで実行する。
80+
- (3): 割り算に[`std::divides`](/reference/functional/divides.md)`<void>`を用いて、(1)を行う。
81+
- (4): (3)を指定された実行ポリシーで実行する。
82+
83+
84+
## 適格要件
85+
- 共通:
86+
+ `Triangle`は[`upper_triangle_t`](upper_triangle_t.md)または[`lower_triangle_t`](lower_triangle_t.md)
87+
+ `DiagonalStorage`は[`implicit_unit_diagonal_t`](implicit_unit_diagonal_t.md)または[`explicit_diagonal_t`](explicit_diagonal_t.md)
88+
+ `InMat1`(`A`の型)が[`layout_blas_packed`](layout_blas_packed.md)を持つなら、レイアウトの`Triangle`テンプレート引数とこの関数の`Triangle`テンプレート引数が同じ型
89+
+ [`possibly-multipliable`](possibly-multipliable.md)`<decltype(A), decltype(X), decltype(B)>()`が`true`
90+
+ [`compatible-static-extents`](compatible-static-extents.md)`<decltype(A), decltype(A)>(0, 1)`が`true` (つまり`A`が正方行列であること)
91+
- (2), (4): [`is_execution_policy`](/reference/execution/is_execution_policy.md)`<ExecutionPolicy>::value`が`true`
92+
93+
94+
## 事前条件
95+
- [`multipliable`](multipliable.md)`(A, X, B)`が`true`
96+
- `A.extent(0) == A.extent(1)` (つまり`A`が正方行列であること)
97+
98+
99+
## 効果
100+
対称行列の成分の位置を示す`t`と対角成分へアクセスするかどうかを示す`d`を考慮して、連立一次方程式の解を求める。
101+
102+
- (1), (2): 連立一次方程式 $AY = B$ を解き、`Y`を`X`に代入する。もし解が存在しないなら、`X`は有効だが未規定。
103+
- (3): `triangular_matrix_matrix_left_solve(A, t, d, B, X, divides<void>{})`と同じ。
104+
- (4): `triangular_matrix_matrix_left_solve(std::forward<ExecutionPolicy>(exec), A, t, d, B, X, divides<void>{})`と同じ。
105+
106+
107+
## 戻り値
108+
なし
109+
110+
111+
## 計算量
112+
$O(\verb|A.extent(0)| \times (\verb|X.extent(0)|)^2)$
113+
114+
115+
## 備考
116+
- 三角行列が左側にあるので、非可換な掛け算の場合の`divide`の望ましい実装は数学では$y^{-1}x$と同等と思われる。ここで`x`は最初の引数で`y`は2番目の引数、$y^{-1}$は`y`の掛け算での逆元である。
117+
118+
119+
## 例
120+
**[注意] 処理系にあるコンパイラで確認していないため、間違っているかもしれません。**
121+
122+
```cpp example
123+
#include <array>
124+
#include <functional>
125+
#include <iostream>
126+
#include <linalg>
127+
#include <mdspan>
128+
#include <vector>
129+
130+
template <class Matrix>
131+
void print_mat(const Matrix& A) {
132+
for(int i = 0; i < A.extent(0); ++i) {
133+
for(int j = 0; j < A.extent(1) - 1; ++j) {
134+
std::cout << A[i, j] << ' ';
135+
}
136+
std::cout << A[i, A.extent(1) - 1] << '\n';
137+
}
138+
}
139+
140+
template <class Matrix>
141+
void init_mat(Matrix& A, typename Matrix::value_type geta = 1) {
142+
for(int i = 0; i < A.extent(0); ++i) {
143+
for(int j = 0; j < A.extent(1); ++j) {
144+
A[i, j] = i * A.extent(1) + j + geta;
145+
}
146+
}
147+
}
148+
149+
template <class Matrix>
150+
void init_tria_mat(Matrix& A) {
151+
for(int i = 0; i < A.extent(0); ++i) {
152+
for(int j = i + 1; j < A.extent(1); ++j) {
153+
A[i, j] = i * A.extent(1) + j;
154+
}
155+
}
156+
}
157+
158+
int main()
159+
{
160+
constexpr size_t N = 4;
161+
162+
std::vector<double> A_vec(N * N);
163+
std::vector<double> X_vec(N * N);
164+
std::vector<double> B_vec(N * N);
165+
166+
std::mdspan<
167+
double,
168+
std::extents<size_t, N, N>,
169+
std::linalg::layout_blas_packed<
170+
std::linalg::upper_triangle_t,
171+
std::linalg::row_major_t>
172+
> A(A_vec.data());
173+
std::mdspan X(X_vec.data(), N, N);
174+
std::mdspan B(B_vec.data(), N, N);
175+
176+
init_mat(A)
177+
init_mat(B);
178+
179+
// (1)
180+
std::cout << "(1)\n";
181+
std::linalg::triangular_matrix_matrix_left_solve(
182+
A,
183+
std::linalg::upper_triangle,
184+
std::linalg::implicit_unit_diagonal,
185+
B,
186+
X,
187+
std::divides<void>{});
188+
print_mat(X);
189+
190+
// (2)
191+
std::cout << "(2)\n";
192+
std::linalg::triangular_matrix_matrix_left_solve(
193+
std::execution::par,
194+
A,
195+
std::linalg::upper_triangle,
196+
std::linalg::implicit_unit_diagonal,
197+
B,
198+
X,
199+
std::divides<void>{});
200+
print(X);
201+
202+
// (3)
203+
std::cout << "(3)\n";
204+
std::linalg::triangular_matrix_matrix_left_solve(
205+
A,
206+
std::linalg::upper_triangle,
207+
std::linalg::implicit_unit_diagonal,
208+
B,
209+
X);
210+
print(X);
211+
212+
// (4)
213+
std::cout << "(4)\n";
214+
std::linalg::triangular_matrix_matrix_left_solve(
215+
std::execution::par,
216+
A,
217+
std::linalg::upper_triangle,
218+
std::linalg::implicit_unit_diagonal,
219+
B,
220+
X);
221+
print(X);
222+
223+
return 0;
224+
}
225+
```
226+
* A.extent[link /reference/mdspan/extents/extent.md]
227+
* v.extent[link /reference/mdspan/extents/extent.md]
228+
* std::mdspan[link /reference/mdspan/mdspan.md]
229+
* std::extents[link /reference/mdspan/extents.md]
230+
* std::linalg::layout_blas_packed[link /reference/linalg/layout_blas_packed.md]
231+
* std::linalg::upper_triangle_t[link /reference/linalg/upper_triangle_t.md]
232+
* std::linalg::row_major_t[link /reference/linalg/row_major_t.md]
233+
* std::linalg::upper_triangle[link /reference/linalg/upper_triangle_t.md]
234+
* std::linalg::implicit_unit_diagonal[link /reference/linalg/implicit_unit_diagonal_t.md]
235+
* std::execution::par[link /reference/execution/execution/execution_policy.md]
236+
* std::linalg::triangular_matrix_matrix_left_solve[color ff0000]
237+
238+
239+
### 出力
240+
```
241+
(1)
242+
-2 -2
243+
2 3
244+
(2)
245+
-2 -2
246+
2 3
247+
(3)
248+
-2 -2
249+
2 3
250+
(4)
251+
-2 -2
252+
2 3
253+
```
254+
255+
256+
## バージョン
257+
### 言語
258+
- C++26
259+
260+
### 処理系
261+
- [Clang](/implementation.md#clang): ??
262+
- [GCC](/implementation.md#gcc): ??
263+
- [ICC](/implementation.md#icc): ??
264+
- [Visual C++](/implementation.md#visual_cpp): ??
265+
266+
267+
## 関連項目
268+
- [`execution`](/reference/execution.md)
269+
- [`mdspan`](/reference/mdspan.md)
270+
- [`upper_triangle_t`](upper_triangle_t.md)
271+
- [`lower_triangle_t`](lower_triangle_t.md)
272+
- [`implicit_unit_diagonal`](implicit_unit_diagonal_t.md)
273+
- [`explicit_diagonal`](explicit_diagonal_t.md)
274+
275+
276+
## 参照
277+
- [P1673R13 A free function linear algebra interface based on the BLAS](https://www.open-std.org/jtc1/sc22/wg21/docs/papers/2023/p1673r13.html)
278+
- [LAPACK: trsm](https://netlib.org/lapack/explore-html/d9/de5/group__trsm.html)
279+

0 commit comments

Comments
 (0)