-
-
Notifications
You must be signed in to change notification settings - Fork 944
Support batched SVD #4628
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support batched SVD #4628
Conversation
|
... |
|
I should have read carefully the discussion in #2337 first 😅 Looks like I was duplicating the work... |
I am testing |
Though the result is still incorrect...
| u.data.ptr, ldu, v.data.ptr, ldv, params, batch_size) | ||
| work = _cupy.empty(lwork, dtype=a.dtype) | ||
| info = _cupy.empty(1, dtype=_numpy.int32) | ||
| solver(handle, jobz, m, n, a.data.ptr, lda, s.data.ptr, | ||
| info = _cupy.empty(batch_size, dtype=_numpy.int32) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is a bug: both cuSOLVER and rocSOLVER need the info array to be of the batch size. I thought we fixed it earlier...?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Backported this fix to v8 (#4747).
...but runs almost as fast as a pure Python loop...
avoid slicing overhead, check info in the end, etc
This comment has been minimized.
This comment has been minimized.
This comment has been minimized.
This comment has been minimized.
|
On CUDA: the manual looping follows what's done in JAX: https://github.com/google/jax/blob/8bf3f032989caddf05c702acc1ff5353abbe72d2/jaxlib/cusolver.cc#L937-L947 |
This comment has been minimized.
This comment has been minimized.
|
Jenkins, test this please |
|
@anaruse can you please take a look at this? 😇 |
|
Jenkins CI test (for commit 3515fb9, target branch master) succeeded! |
| A += m * n; | ||
| S += k; | ||
| U += m * m; | ||
| VT += n * n; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is it assumed that (lda, ldu, ldvt) == (m, m, n)?
Keeping the signature of gesvd is nice, but I'd like to see the wrapper is trivially correct.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, when the wrapper is called this is ensured. How about removing lda etc from the signature and just using plain m, n?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How about removing
ldaetc from the signature and just using plainm,n?
Done in ce32d03.
| a_gpu_usv = cupy.matmul(u_gpu[..., :k] * s_gpu[..., None, :], | ||
| vh_gpu[..., :k, :]) | ||
| else: | ||
| a_gpu_usv = cupy.matmul(u_gpu*s_gpu[..., None, :], vh_gpu) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm fine with calling matmul regardless of shape.
| # copy (via possible type casting) is done in _gesvd_batched | ||
| out = _gesvd_batched(a, a_dtype, full_matrices, compute_uv, False) | ||
| if compute_uv: | ||
| u, s, v = out |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
When I read the code first, I missed _gesvd_batched returns v instead of vt. Adding a comment to the docstring of _gesvd_batched seems sufficient for now.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
|
Check again Jenkins, test this please |
|
Jenkins CI test (for commit 3515fb9, target branch master) succeeded! |
Co-authored-by: Toshiki Kataoka <kataoka@preferred.jp>
|
Jenkins, test this please |
|
Jenkins CI test (for commit 5a0376a, target branch master) succeeded! |
|
Jenkins, test this please |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
|
Jenkins CI test (for commit 5a0376a, target branch master) failed with status FAILURE. |
|
The failed test on Jenkins is known to be flaky (#4673). |
|
Thanks, @toslunar! |
Close #3470. Part of #3062. Another shot based on #3247
to avoid manual looping.UPDATE: See the below discussions in this PR.
Prepare for adopting Array API (data-apis/array-api#114).
Code path divergence:
TODO: