Skip to content

Commit

Permalink
Return earlier when n==0 and mode=='complete'
Browse files (browse the repository at this point in the history)
  • Loading branch information
IvanYashchuk committed Aug 19, 2019
1 parent a864aa9 commit 1874f05
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions chainerx_cc/chainerx/cuda/cuda_device/linalg.cu
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,12 @@ void QrImpl(const Array& a, const Array& q, const Array& r, const Array& tau, Qr
int64_t k = std::min(m, n);
int64_t lda = std::max(int64_t{1}, m);

// cuSOLVER does not return the correct result when n == 0 in complete mode, so Q is set to the identity directly
if (mode == QrMode::kComplete && n == 0) {
device.backend().CallKernel<IdentityKernel>(q);
return;
}

Array r_temp = a.Transpose().Copy(); // QR decomposition is done in-place

cuda_internal::DeviceInternals& device_internals = cuda_internal::GetDeviceInternals(static_cast<CudaDevice&>(device));
Expand Down Expand Up @@ -385,12 +391,6 @@ void QrImpl(const Array& a, const Array& q, const Array& r, const Array& tau, Qr
}
Array q_temp = Empty(q_shape, dtype, device);

// cuSOLVER does not return the correct result when n == 0 in complete mode, so Q is set to the identity directly
if (mode == QrMode::kComplete && n == 0) {
device.backend().CallKernel<IdentityKernel>(q);
return;
}

device.backend().CallKernel<CopyKernel>(r_temp, q_temp.At(std::vector<ArrayIndex>{Slice{0, n}, Slice{}})); // Q[0:n, :] = R
auto q_ptr = static_cast<T*>(internal::GetRawOffsetData(q_temp));

Expand Down

0 comments on commit 1874f05

Please sign in to comment.