Skip to content

Commit d837379

Browse files
committed
linalg : symmetric_matrix_productを追加 (#1233)
Signed-off-by: Yuya Asano <64895419+sukeya@users.noreply.github.com>
1 parent 641d6a7 commit d837379

File tree

2 files changed

+366
-1
lines changed

2 files changed

+366
-1
lines changed

reference/linalg.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ BLAS 1, 2, 3のアルゴリズムでテンプレートパラメータが特に
9999
| 名前 | 説明 | 対応バージョン |
100100
|------|------|----------------|
101101
| [`matrix_product`](linalg/matrix_product.md) | xGEMM: 2つの一般行列の積を求める (function template) | C++26 |
102-
| `symmetric_matrix_product` | xSYMM: 対称行列と行列の積を求める (function template) | C++26 |
102+
| [`symmetric_matrix_product`](linalg/symmetric_matrix_product.md) | xSYMM: 対称行列と行列の積を求める (function template) | C++26 |
103103
| `hermitian_matrix_product` | xHEMM: ハミルトニアン行列と行列の積を求める (function template) | C++26 |
104104
| `triangular_matrix_product` | xTRMM: 三角行列と行列の積を求める (function template) | C++26 |
105105
| `triangular_matrix_left_product` | xTRMM: In-placeに三角行列と行列の積を求める (function template) | C++26 |
Lines changed: 365 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,365 @@
1+
# symmetric_matrix_product
2+
3+
4+
* [mathjax enable]
5+
* linalg[meta header]
6+
* function template[meta id-type]
7+
* std::linalg[meta namespace]
8+
* cpp26[meta cpp]
9+
10+
11+
```cpp
12+
namespace std::linalg {
13+
template<in-matrix InMat1,
14+
class Triangle,
15+
in-matrix InMat2,
16+
out-matrix OutMat>
17+
void symmetric_matrix_product(
18+
InMat1 A,
19+
Triangle t,
20+
InMat2 B,
21+
OutMat C); // (1)
22+
23+
template<class ExecutionPolicy,
24+
in-matrix InMat1,
25+
class Triangle,
26+
in-matrix InMat2,
27+
out-matrix OutMat>
28+
void symmetric_matrix_product(
29+
ExecutionPolicy&& exec,
30+
InMat1 A,
31+
Triangle t,
32+
InMat2 B,
33+
OutMat C); // (2)
34+
35+
template<in-matrix InMat1,
36+
in-matrix InMat2,
37+
class Triangle,
38+
out-matrix OutMat>
39+
void symmetric_matrix_product(
40+
InMat1 A,
41+
InMat2 B,
42+
Triangle t,
43+
OutMat C); // (3)
44+
45+
template<class ExecutionPolicy,
46+
in-matrix InMat1,
47+
in-matrix InMat2,
48+
class Triangle,
49+
out-matrix OutMat>
50+
void symmetric_matrix_product(
51+
ExecutionPolicy&& exec,
52+
InMat1 A,
53+
InMat2 B,
54+
Triangle t,
55+
OutMat C); // (4)
56+
57+
template<in-matrix InMat1,
58+
class Triangle,
59+
in-matrix InMat2,
60+
in-matrix InMat3,
61+
out-matrix OutMat>
62+
void symmetric_matrix_product(
63+
InMat1 A,
64+
Triangle t,
65+
InMat2 B,
66+
InMat3 E,
67+
OutMat C); // (5)
68+
69+
template<class ExecutionPolicy,
70+
in-matrix InMat1,
71+
class Triangle,
72+
in-matrix InMat2,
73+
in-matrix InMat3,
74+
out-matrix OutMat>
75+
void symmetric_matrix_product(
76+
ExecutionPolicy&& exec,
77+
InMat1 A,
78+
Triangle t,
79+
InMat2 B,
80+
InMat3 E,
81+
OutMat C); // (6)
82+
83+
template<in-matrix InMat1,
84+
in-matrix InMat2,
85+
class Triangle,
86+
in-matrix InMat3,
87+
out-matrix OutMat>
88+
void symmetric_matrix_product(
89+
InMat1 A,
90+
InMat2 B,
91+
Triangle t,
92+
InMat3 E,
93+
OutMat C); // (7)
94+
95+
template<class ExecutionPolicy,
96+
in-matrix InMat1,
97+
in-matrix InMat2,
98+
class Triangle,
99+
in-matrix InMat3,
100+
out-matrix OutMat>
101+
void symmetric_matrix_product(
102+
ExecutionPolicy&& exec,
103+
InMat1 A,
104+
InMat2 B,
105+
Triangle t,
106+
InMat3 E,
107+
OutMat C); // (8)
108+
}
109+
```
110+
111+
112+
## 概要
113+
行列同士の積を計算する。
114+
115+
- (1): 三角行列`A`と行列`B`に対し、$C \leftarrow AB$
116+
- (2): (1)を指定された実行ポリシーで実行する。
117+
- (3): 行列`A`と三角行列`B`に対し、$C \leftarrow AB$
118+
- (4): (3)を指定された実行ポリシーで実行する。
119+
- (5): 三角行列`A`と行列`B`に対し、$C \leftarrow E + AB$
120+
- (6): (5)を指定された実行ポリシーで実行する。
121+
- (7): 行列`A`と三角行列`B`に対し、$C \leftarrow E + AB$
122+
- (8): (7)を指定された実行ポリシーで実行する。
123+
124+
125+
## 適格要件
126+
- 共通
127+
+ `Triangle`は[`upper_triangle_t`](upper_triangle_t.md)または[`lower_triangle_t`](lower_triangle_t.md)
128+
+ [`possibly-multipliable`](possibly-multipliable.md)`<decltype(A), decltype(B), decltype(C)>()`が`true`
129+
- (1), (2), (5), (6): `InMat1`(`A`の型)が[`layout_blas_packed`](layout_blas_packed.md)を持つなら、レイアウトの`Triangle`テンプレート引数とこの関数の`Triangle`テンプレート引数が同じ型
130+
- (1), (2), (5), (6): [`compatible-static-extents`](compatible-static-extents.md)`<decltype(A), decltype(A)>(0, 1)`が`true` (つまり`A`が正方行列であること)
131+
- (3), (4), (7), (8): `InMat2`(`B`の型)が[`layout_blas_packed`](layout_blas_packed.md)を持つなら、レイアウトの`Triangle`テンプレート引数とこの関数の`Triangle`テンプレート引数が同じ型
132+
- (3), (4), (7), (8): [`compatible-static-extents`](compatible-static-extents.md)`<decltype(B), decltype(B)>(0, 1)`が`true` (つまり`B`が正方行列であること)
133+
- (5), (6), (7), (8): [`possibly-addable`](possibly-addable.md)`<decltype(E),decltype(E),decltype(C)>()`が`true`
134+
135+
136+
## 事前条件
137+
- 共通
138+
+ [`multipliable`](multipliable.md)`(A, B, C) == true`
139+
- (1), (2), (5), (6): [`A.extent(0) == A.extent(1)`]
140+
- (3), (4), (7), (8): [`B.extent(0) == B.extent(1)`]
141+
- (5), (6), (7), (8): [`addable`](addable.md)`(E, E, C) == true`
142+
143+
144+
## 効果
145+
- (1), (2): 三角行列`A`と行列`B`に対し、$C \leftarrow AB$
146+
- (3), (4): 行列`A`と三角行列`B`に対し、$C \leftarrow AB$
147+
- (5), (6): 三角行列`A`と行列`B`に対し、$C \leftarrow E + AB$
148+
- (7), (8): 行列`A`と三角行列`B`に対し、$C \leftarrow E + AB$
149+
150+
151+
## 戻り値
152+
なし
153+
154+
155+
## 計算量
156+
$O(\verb|A.extent(0)| \times \verb|A.extent(1)| \times \verb|B.extent(1)|)$
157+
158+
159+
## 備考
160+
- (5), (6), (7), (8): `C`に`E`を入れても良い。
161+
162+
163+
## 例
164+
**[注意] 処理系にあるコンパイラで確認していないため、間違っているかもしれません。**
165+
166+
```cpp example
167+
#include <array>
168+
#include <iostream>
169+
#include <linalg>
170+
#include <mdspan>
171+
#include <vector>
172+
173+
template <class Matrix>
174+
void print_mat(const Matrix& A) {
175+
for(int i = 0; i < A.extent(0); ++i) {
176+
for(int j = 0; j < A.extent(1) - 1; ++j) {
177+
std::cout << A[i, j] << ' ';
178+
}
179+
std::cout << A[i, A.extent(1) - 1] << '\n';
180+
}
181+
}
182+
183+
template <class Matrix>
184+
void init_mat(Matrix& A, typename Matrix::value_type geta = 1) {
185+
for(int i = 0; i < A.extent(0); ++i) {
186+
for(int j = 0; j < A.extent(1); ++j) {
187+
A[i, j] = i * A.extent(1) + j + geta;
188+
}
189+
}
190+
}
191+
192+
template <class Matrix>
193+
void init_symm_mat(Matrix& A) {
194+
for(int i = 0; i < A.extent(0); ++i) {
195+
for(int j = i; j < A.extent(1); ++j) {
196+
A[i, j] = i * A.extent(1) + j;
197+
}
198+
}
199+
}
200+
201+
int main()
202+
{
203+
constexpr size_t N = 2;
204+
205+
std::vector<double> A_vec(N * N);
206+
std::vector<double> B_vec(N * N);
207+
std::vector<double> C_vec(N * N);
208+
std::vector<double> E_vec(N * N);
209+
210+
std::mdspan C(C_vec.data(), N, N);
211+
std::mdspan E(E_vec.data(), N, N);
212+
213+
init_mat(E, N * N);
214+
215+
{
216+
std::mdspan<
217+
double,
218+
std::extents<size_t, N, N>,
219+
std::linalg::layout_blas_packed<
220+
std::linalg::upper_triangle_t,
221+
std::linalg::row_major_t>
222+
> A(A_vec.data());
223+
std::mdspan B(B_vec.data(), N, N);
224+
225+
init_symm_mat(A);
226+
init_mat(B);
227+
228+
// (1)
229+
std::cout << "(1)\n";
230+
std::linalg::symmetric_matrix_product(A, std::linalg::upper_triangle, B, C);
231+
print_mat(C);
232+
233+
// (2)
234+
std::cout << "(2)\n";
235+
std::linalg::symmetric_matrix_product(std::execution::par, A, std::linalg::upper_triangle, B, C);
236+
print_mat(C);
237+
}
238+
239+
{
240+
std::mdspan A(A_vec.data(), N, N);
241+
std::mdspan<
242+
double,
243+
std::extents<size_t, N, N>,
244+
std::linalg::layout_blas_packed<
245+
std::linalg::upper_triangle_t,
246+
std::linalg::row_major_t>
247+
> B(B_vec.data());
248+
249+
init_mat(A);
250+
init_symm_mat(B);
251+
252+
// (3)
253+
std::cout << "(3)\n";
254+
std::linalg::symmetric_matrix_product(A, B, std::linalg::upper_triangle, C);
255+
print_mat(C);
256+
257+
// (4)
258+
std::cout << "(4)\n";
259+
std::linalg::symmetric_matrix_product(std::execution::par, A, B, std::linalg::upper_triangle, C);
260+
print_mat(C);
261+
}
262+
263+
{
264+
std::mdspan<
265+
double,
266+
std::extents<size_t, N, N>,
267+
std::linalg::layout_blas_packed<
268+
std::linalg::upper_triangle_t,
269+
std::linalg::row_major_t>
270+
> A(A_vec.data());
271+
std::mdspan B(B_vec.data(), N, N);
272+
273+
init_symm_mat(A);
274+
init_mat(B);
275+
276+
// (5)
277+
std::cout << "(5)\n";
278+
std::linalg::symmetric_matrix_product(A, std::linalg::upper_triangle, B, E, C);
279+
print_mat(C);
280+
281+
// (6)
282+
std::cout << "(6)\n";
283+
std::linalg::symmetric_matrix_product(std::execution::par, A, std::linalg::upper_triangle, B, E, C);
284+
print_mat(C);
285+
}
286+
287+
{
288+
std::mdspan A(A_vec.data(), N, N);
289+
std::mdspan<
290+
double,
291+
std::extents<size_t, N, N>,
292+
std::linalg::layout_blas_packed<
293+
std::linalg::upper_triangle_t,
294+
std::linalg::row_major_t>
295+
> B(B_vec.data());
296+
297+
init_mat(A);
298+
init_symm_mat(B);
299+
300+
// (7)
301+
std::cout << "(7)\n";
302+
std::linalg::symmetric_matrix_product(A, B, std::linalg::upper_triangle, E, C);
303+
print_mat(C);
304+
305+
// (8)
306+
std::cout << "(8)\n";
307+
std::linalg::symmetric_matrix_product(std::execution::par, A, B, std::linalg::upper_triangle, E, C);
308+
print_mat(C);
309+
}
310+
311+
return 0;
312+
}
313+
```
314+
315+
316+
### 出力
317+
```
318+
(1)
319+
7 8
320+
11 16
321+
(2)
322+
7 8
323+
11 16
324+
(3)
325+
5 8
326+
11 18
327+
(4)
328+
5 8
329+
11 18
330+
(5)
331+
11 13
332+
17 23
333+
(6)
334+
11 13
335+
17 23
336+
(7)
337+
9 13
338+
17 25
339+
(8)
340+
9 13
341+
17 25
342+
```
343+
344+
345+
## バージョン
346+
### 言語
347+
- C++26
348+
349+
### 処理系
350+
- [Clang](/implementation.md#clang): ??
351+
- [GCC](/implementation.md#gcc): ??
352+
- [ICC](/implementation.md#icc): ??
353+
- [Visual C++](/implementation.md#visual_cpp): ??
354+
355+
356+
## 関連項目
357+
- [`execution`](/reference/execution.md)
358+
- [`mdspan`](/reference/mdspan.md)
359+
- [`upper_triangle_t`](upper_triangle_t.md)
360+
- [`lower_triangle_t`](lower_triangle_t.md)
361+
362+
363+
## 参照
364+
- [P1673R13 A free function linear algebra interface based on the BLAS](https://www.open-std.org/jtc1/sc22/wg21/docs/papers/2023/p1673r13.html)
365+
- [LAPACK: {he,sy}mm: Hermitian/symmetric matrix-matrix multiply](https://netlib.org/lapack/explore-html/d0/d16/group__hemm.html)

0 commit comments

Comments
 (0)