Skip to content

Support batched SVD #4628

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 34 commits into from
Feb 18, 2021
Merged

Support batched SVD #4628

merged 34 commits into from
Feb 18, 2021

Conversation

leofang
Copy link
Member

@leofang leofang commented Feb 5, 2021

Close #3470. Part of #3062. Another shot based on #3247 to avoid manual looping.
UPDATE: See the below discussions in this PR.

Prepare for adopting Array API (data-apis/array-api#114).

Code path divergence:

  • On CUDA: Call gesvdj if m,n <= 32, else call gesvd in a C loop
  • On ROCm: Call gesvd as it has no limitation in the input size

TODO:

  • Support integers
  • Support zero-size (empty) input
  • Docstring
  • Test on HIP
  • Benchmark

@leofang
Copy link
Member Author

leofang commented Feb 5, 2021

...*gesvdjBatched() only support a stack of matrices with dimension < (32, 32) on either side... @toslunar Looks like we have to do a manual looping anyway for general cases?

@leofang
Copy link
Member Author

leofang commented Feb 5, 2021

I should have read carefully the discussion in #2337 first 😅 Looks like I was duplicating the work...

@leofang
Copy link
Member Author

leofang commented Feb 5, 2021

...*gesvdjBatched() only support a stack of matrices with dimension < (32, 32) on either side...

I am testing rocsolver_<t>gesvd_batched and it doesn't seem to have this limitation, so I'll try to test & support it in this PR.

u.data.ptr, ldu, v.data.ptr, ldv, params, batch_size)
work = _cupy.empty(lwork, dtype=a.dtype)
info = _cupy.empty(1, dtype=_numpy.int32)
solver(handle, jobz, m, n, a.data.ptr, lda, s.data.ptr,
info = _cupy.empty(batch_size, dtype=_numpy.int32)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a bug: both cuSOLVER and rocSOLVER need the info array to be of the batch size. I thought we fixed it earlier...?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Backported this fix to v8 (#4747).

@leofang

This comment has been minimized.

@leofang

This comment has been minimized.

@leofang
Copy link
Member Author

leofang commented Feb 8, 2021

On CUDA: the manual looping follows what's done in JAX: https://github.com/google/jax/blob/8bf3f032989caddf05c702acc1ff5353abbe72d2/jaxlib/cusolver.cc#L937-L947
It is made possible here by cythonizing cupy/cusolver.py so that we can speed up the loop using Cython's constructs. But the caveat here is that by calling the <t>gesvd() wrapper we release and acquire the gil in every iteration, which isn't optimal. UPDATE: this is fixed now, see #4628 (comment).

@toslunar toslunar self-assigned this Feb 8, 2021
@leofang

This comment has been minimized.

@leofang
Copy link
Member Author

leofang commented Feb 17, 2021

Jenkins, test this please

@emcastillo
Copy link
Member

@anaruse can you please take a look at this? 😇

@chainer-ci
Copy link
Member

Jenkins CI test (for commit 3515fb9, target branch master) succeeded!

Comment on lines +66 to +69
A += m * n;
S += k;
U += m * m;
VT += n * n;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it assumed that (lda, ldu, ldvt) == (m, m, n)?
Keeping the signature of gesvd is nice, but I'd like to see the wrapper is trivially correct.

#4628 (comment)

Copy link
Member Author

@leofang leofang Feb 17, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, when the wrapper is called this is ensured. How about removing lda etc from the signature and just using plain m, n?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about removing lda etc from the signature and just using plain m, n?

Done in ce32d03.

a_gpu_usv = cupy.matmul(u_gpu[..., :k] * s_gpu[..., None, :],
vh_gpu[..., :k, :])
else:
a_gpu_usv = cupy.matmul(u_gpu*s_gpu[..., None, :], vh_gpu)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm fine with calling matmul regardless of shape.

# copy (via possible type casting) is done in _gesvd_batched
out = _gesvd_batched(a, a_dtype, full_matrices, compute_uv, False)
if compute_uv:
u, s, v = out
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When I read the code first, I missed _gesvd_batched returns v instead of vt. Adding a comment to the docstring of _gesvd_batched seems sufficient for now.

Copy link
Member Author

@leofang leofang Feb 17, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch, @toslunar! There was actually a bug related to this, which I fixed in 62e11ac. Yes, it returns V instead of V^H (transpose conjugate), in order to align with _gesvdj_batched.

@leofang
Copy link
Member Author

leofang commented Feb 17, 2021

Check again

Jenkins, test this please

@chainer-ci
Copy link
Member

Jenkins CI test (for commit 3515fb9, target branch master) succeeded!

@leofang
Copy link
Member Author

leofang commented Feb 17, 2021

Jenkins, test this please

@chainer-ci
Copy link
Member

Jenkins CI test (for commit 5a0376a, target branch master) succeeded!

@leofang
Copy link
Member Author

leofang commented Feb 18, 2021

Jenkins, test this please

Copy link
Member

@toslunar toslunar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@toslunar toslunar added this to the v9.0.0b3 milestone Feb 18, 2021
@chainer-ci
Copy link
Member

Jenkins CI test (for commit 5a0376a, target branch master) failed with status FAILURE.

@toslunar
Copy link
Member

The failed test on Jenkins is known to be flaky (#4673).

@toslunar toslunar merged commit bdc6072 into cupy:master Feb 18, 2021
@leofang leofang deleted the batched_svd branch February 18, 2021 13:16
@leofang
Copy link
Member Author

leofang commented Feb 18, 2021

Thanks, @toslunar!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Support batched SVD (cupy.linalg.svd)
5 participants