Support batched QR solver #5583
Conversation
Jenkins, test this please
Jenkins CI test (for commit 777bd74, target branch master) succeeded!
I looked at the perf numbers and compared the execution times of both libraries.

Script:

```python
import numpy as np
import cupy as cp
from cupy._core.internal import prod
from cupyx.time import repeat

s = [(256, 256), (4, 256, 256), (16, 256, 256), (64, 256, 256), (256, 256, 256)]
t = [np.float32, np.complex64]

def np_timing(a):
    # NumPy has no batched QR here, so loop over the leading (batch) axes
    out = []
    batch_size = prod(a.shape[:-2])
    a = a.reshape(batch_size, *a.shape[-2:])
    for i in range(batch_size):
        out.append(np.linalg.qr(a[i]))
    return out

for dtype in t:
    for shape in s:
        a = np.random.random(shape).astype(dtype)
        perfs = []
        for xp in (np, cp):
            a = xp.asarray(a)
            if xp is np:
                func = np_timing
            else:
                func = cp.linalg.qr
            perf = repeat(func, (a,), n_repeat=20,
                          name=f"{'numpy' if xp is np else 'cupy'} QR")
            perfs.append(perf)
        print(f'shape={shape}, dtype={dtype}, '
              f'speedup (np/cp)={perfs[0].gpu_times.mean()/perfs[1].gpu_times.mean()}')
```

Output (CUDA 11.2 + 2080 Ti):
In fact, the speedup increases slightly as the batch size increases, but presumably that's because this PR loops in C while I looped in Python.
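The benchmark above only measures speed, not correctness. A quick correctness check for batched QR can be sketched in pure NumPy by looping over the leading axes and verifying the usual QR properties; the `loop_qr` helper below is illustrative and not part of the PR:

```python
import numpy as np

def loop_qr(a):
    """Apply np.linalg.qr to each matrix in a stacked array (leading axes = batch)."""
    batch_shape = a.shape[:-2]
    a2 = a.reshape(-1, *a.shape[-2:])
    qs, rs = [], []
    for mat in a2:
        q, r = np.linalg.qr(mat)  # reduced mode: q is (m, k), r is (k, n)
        qs.append(q)
        rs.append(r)
    q = np.stack(qs).reshape(*batch_shape, *qs[0].shape)
    r = np.stack(rs).reshape(*batch_shape, *rs[0].shape)
    return q, r

rng = np.random.default_rng(0)
a = rng.standard_normal((4, 6, 3))
q, r = loop_qr(a)
# Q @ R should reconstruct A, and Q should have orthonormal columns
assert np.allclose(q @ r, a)
eye = np.broadcast_to(np.eye(3), (4, 3, 3))
assert np.allclose(np.swapaxes(q, -1, -2) @ q, eye)
```

The same assertions can be pointed at a batched `cp.linalg.qr` result to cross-check it against the loop-based reference.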
Apparently rocSOLVER performs QR decomposition very badly even with a single matrix... way slower than NumPy does, lol.

Output (ROCm 4.2.0 + Radeon VII):
Output (ROCm 4.2.0 + MI50):
cc: @amathews-amd
The CPU side took quite some time, yes, but the GPU spent a much longer time:
Jenkins, test this please
Jenkins CI test (for commit bb4bdbb, target branch master) succeeded!
The ROCm 4.2 CI will be fixed once the GPG key is updated.
LGTM
`_util._assert_rank2` in `qr` was moved by cupy#5583.
Close #4986.
To prepare for the upcoming NumPy 1.22 (numpy/numpy#19151) and the Array API standard.
UPDATE: I made some educated guesses about zero-size input, but until NumPy 1.22 is out so that we can actually test all corner cases, their validity remains to be seen. As a result, I require those tests to run on NumPy 1.22+, and added an experimental warning when batching is in use.
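For zero-size inputs, the guesses largely come down to a shape convention. A small sketch of the reduced-mode shapes one would expect (with k = min(m, n), following NumPy's documented convention); the `expected_qr_shapes` helper here is illustrative, not the PR's implementation:

```python
import numpy as np

def expected_qr_shapes(shape):
    """Expected reduced-mode Q/R shapes for a stacked input of the given shape."""
    *batch, m, n = shape
    k = min(m, n)  # reduced mode: Q is (..., m, k), R is (..., k, n)
    return (*batch, m, k), (*batch, k, n)

# An empty batch of 5x3 matrices: zero matrices, but well-defined Q/R shapes.
q_shape, r_shape = expected_qr_shapes((0, 5, 3))
# → ((0, 5, 3), (0, 3, 3))
```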
TODO: