In PyTorch, most linear algebra routines can be applied to batches of matrices. In Heat, it would be nice to have a similar rule for how to "define" linear algebra (`matmul`, `qr`, `cholesky`, `svd`, etc.) for `DNDarray`s with more than 2 dimensions.
To do so, follow the PyTorch convention (see, e.g., https://pytorch.org/docs/stable/generated/torch.linalg.solve.html#torch.linalg.solve or the code in #893): for arrays with `ndim >= 3`, we regard the last two dimensions (for matrices) or the last dimension (for vectors) as the "linear algebra dimensions", and the indices of the first `ndim - 2` (matrices) or `ndim - 1` (vectors) dimensions as indices into a "batch of linear algebra objects".
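The convention can be illustrated with NumPy, whose `np.matmul` treats all leading dimensions as batch dimensions in the same way as `torch.matmul`. This is only a sketch: NumPy is used so that it runs without PyTorch or Heat, and `loop_matmul` and `batched_matvec` are hypothetical helpers. The batched call is compared against an explicit loop over the batch indices.

```python
import itertools

import numpy as np

rng = np.random.default_rng(0)

A = rng.standard_normal((4, 3, 5, 2))  # batch dims (4, 3), matrix dims (5, 2)
B = rng.standard_normal((4, 3, 2, 6))  # batch dims (4, 3), matrix dims (2, 6)
v = rng.standard_normal((4, 3, 2))     # batch dims (4, 3), vector dim 2


def loop_matmul(A, B):
    """Reference: multiply the matrices at each batch index with an explicit loop."""
    batch_shape = A.shape[:-2]
    out = np.empty(batch_shape + (A.shape[-2], B.shape[-1]))
    for idx in itertools.product(*(range(n) for n in batch_shape)):
        out[idx] = A[idx] @ B[idx]
    return out


def batched_matvec(A, v):
    """Matrix-vector product per batch index; the last axis of v is the vector axis."""
    # np.matmul would read a 3-D v as a stack of 4 matrices of shape (3, 2),
    # so append a unit column axis and drop it again afterwards
    return np.matmul(A, v[..., None])[..., 0]


C = np.matmul(A, B)       # shape (4, 3, 5, 6): one (5, 6) product per batch index
w = batched_matvec(A, v)  # shape (4, 3, 5): one length-5 vector per batch index
```

Note that a batch of vectors needs an explicit trailing unit dimension before calling `np.matmul`; otherwise the vector axis is silently interpreted as a matrix axis.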
Particular example: see `matmul` (#890).
Suggestion:
- "minimal solution" (data/batch-parallel higher-dimensional `matmul` only): introduce batched `matmul` for `ndim >= 3` in the case that the `split` dimension is a batch dimension, i.e. no parallel matrix-matrix multiplication algorithm is required.
- "full solution": go through the existing `matmul` code and add batch dimensions to it, so that it also handles the case omitted in the "minimal solution", i.e. `split` equal to one of the last two dimensions.
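To make the difference between the two proposals concrete, here is a single-process NumPy simulation (an illustrative sketch, not Heat's API): a list of chunks stands in for the local tensors of `nprocs` processes, and all function names are hypothetical. In the batch-split case every process works purely locally. When `split` is one of the last two dimensions, either the other operand must be available in full (row split) or partial products must be reduced across processes (inner-dimension split). These are two simple textbook layouts, not a description of how Heat's existing `matmul` distributes the work.

```python
import numpy as np

# Single-process simulation: a list of chunks stands in for the local tensors
# that each of `nprocs` processes would hold. No Heat or MPI calls are made.


def batch_split_matmul(A, B, split, nprocs):
    """'Minimal solution': split is a batch axis, so every process multiplies its own chunk.

    Assumes A and B have the same batch shape and are distributed identically along `split`.
    """
    if A.ndim < 3 or not 0 <= split < A.ndim - 2:
        raise ValueError("split must be one of the first ndim - 2 (batch) axes")
    A_chunks = np.array_split(A, nprocs, axis=split)
    B_chunks = np.array_split(B, nprocs, axis=split)
    # Purely local work: no data is exchanged between processes
    return [np.matmul(a, b) for a, b in zip(A_chunks, B_chunks)]


def row_split_matmul(A, B, nprocs):
    """'Full solution', case split == ndim - 2: A split along its rows, B replicated.

    Each process computes its row block of C; the result stays split along axis -2.
    """
    A_chunks = np.array_split(A, nprocs, axis=-2)
    return [np.matmul(a, B) for a in A_chunks]


def inner_split_matmul(A, B, nprocs):
    """'Full solution', case split == ndim - 1: A split along columns, B along rows.

    Each process forms a partial product over its slice of the inner dimension;
    summing the partials plays the role of an MPI Allreduce.
    """
    A_chunks = np.array_split(A, nprocs, axis=-1)
    B_chunks = np.array_split(B, nprocs, axis=-2)
    return sum(np.matmul(a, b) for a, b in zip(A_chunks, B_chunks))


rng = np.random.default_rng(1)
A = rng.standard_normal((6, 4, 5, 3))  # batch dims (6, 4), matrices (5, 3)
B = rng.standard_normal((6, 4, 3, 2))  # batch dims (6, 4), matrices (3, 2)
C = np.matmul(A, B)                    # non-distributed reference result

C_batch = np.concatenate(batch_split_matmul(A, B, split=0, nprocs=4), axis=0)
C_rows = np.concatenate(row_split_matmul(A, B, nprocs=2), axis=-2)
C_inner = inner_split_matmul(A, B, nprocs=3)
```

All three layouts reproduce the non-distributed result; only the batch-split case avoids any communication or replication, which is why it is the natural first step.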
Note: qr, svd, cholesky have been split off to #1320
mrfh92 changed the title from "Discuss and implement consistent linear algebra for arrays with dimension > 2" to "Implement consistent linear algebra for arrays with dimension > 2, in particular matmul" on Aug 17, 2023