Merge 0dbea62 into 4381df8

IvanYashchuk committed Aug 21, 2019
2 parents 4381df8 + 0dbea62 commit e955e34
Showing 10 changed files with 570 additions and 0 deletions.
30 changes: 30 additions & 0 deletions chainerx/_docs/routines.py
@@ -665,6 +665,36 @@ def _docs_linalg():
.. seealso:: :func:`numpy.linalg.pinv`
""")

_docs.set_doc(
chainerx.linalg.qr,
"""qr(a, mode='reduced')
Compute the qr factorization of a matrix.
Factor the matrix ``a`` as *qr*, where ``q`` is orthonormal and ``r`` is
upper-triangular.
Args:
a (~chainerx.ndarray): Matrix to be factored.
mode (str): The mode of decomposition.
'reduced' : returns q, r with dimensions (M, K), (K, N) (default)
'complete' : returns q, r with dimensions (M, M), (M, N)
'r' : returns r only with dimensions (K, N)
'raw' : returns h, tau with dimensions (N, M), (K,),
where ``(M, N)`` is the shape of the input matrix and ``K = min(M, N)``
Returns:
q (~chainerx.ndarray): A matrix with orthonormal columns.
r (~chainerx.ndarray): The upper-triangular matrix.
Note:
* The ``dtype`` must be ``float32`` or ``float64`` (``float16`` is not
supported yet.)
* Backpropagation is not implemented for non-square output matrix ``r``.
* Backpropagation is not implemented for 'r' or 'raw' modes.
.. seealso:: :func:`numpy.linalg.qr`
""")


def _docs_logic():
_docs.set_doc(
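Since the docstring above mirrors :func:`numpy.linalg.qr`, the mode semantics can be exercised directly from Python. A minimal usage sketch, assuming a chainerx build that includes this commit (shapes follow the (M, K)/(K, N) convention from the docstring):

import numpy as np
import chainerx as chx

a = chx.array(np.random.rand(3, 2).astype(np.float32))  # (M, N) = (3, 2), K = 2

q, r = chx.linalg.qr(a)                     # 'reduced': q is (3, 2), r is (2, 2)
qc, rc = chx.linalg.qr(a, mode='complete')  # 'complete': q is (3, 3), r is (3, 2)
r_only = chx.linalg.qr(a, mode='r')         # r only: (2, 2)

# q has orthonormal columns, and q @ r reconstructs a.
np.testing.assert_allclose(chx.to_numpy(chx.dot(q, r)), chx.to_numpy(a), rtol=1e-5)
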
3 changes: 3 additions & 0 deletions chainerx/linalg/__init__.pyi
@@ -13,3 +13,6 @@ def solve(a: ndarray, b: ndarray) -> ndarray: ...
def svd(a: ndarray,
full_matrices: bool=...,
compute_uv: bool=...) -> tp.Union[tp.Tuple[ndarray, ndarray, ndarray], ndarray]: ...


def qr(a: ndarray, mode: str=...) -> tp.Union[tp.Tuple[ndarray, ndarray], ndarray]: ...
194 changes: 194 additions & 0 deletions chainerx_cc/chainerx/cuda/cuda_device/linalg.cu
@@ -90,6 +90,46 @@ cusolverStatus_t Gesvd(
throw DtypeError{"Only Arrays of float or double type are supported by gesvd (SVD)"};
}

template <typename T>
cusolverStatus_t GeqrfBufferSize(cusolverDnHandle_t /*handle*/, int /*m*/, int /*n*/, T* /*a*/, int /*lda*/, int* /*lwork*/) {
throw DtypeError{"Only Arrays of float or double type are supported by geqrf (QR)"};
}

template <typename T>
cusolverStatus_t Geqrf(
cusolverDnHandle_t /*handle*/,
int /*m*/,
int /*n*/,
T* /*a*/,
int /*lda*/,
T* /*tau*/,
T* /*workspace*/,
int /*lwork*/,
int* /*devinfo*/) {
throw DtypeError{"Only Arrays of float or double type are supported by geqrf (QR)"};
}

template <typename T>
cusolverStatus_t OrgqrBufferSize(
cusolverDnHandle_t /*handle*/, int /*m*/, int /*n*/, int /*k*/, T* /*a*/, int /*lda*/, T* /*tau*/, int* /*lwork*/) {
throw DtypeError{"Only Arrays of float or double type are supported by orgqr (QR)"};
}

template <typename T>
cusolverStatus_t Orgqr(
cusolverDnHandle_t /*handle*/,
int /*m*/,
int /*n*/,
int /*k*/,
T* /*a*/,
int /*lda*/,
T* /*tau*/,
T* /*work*/,
int /*lwork*/,
int* /*devinfo*/) {
throw DtypeError{"Only Arrays of float or double type are supported by orgqr (QR)"};
}

template <>
cusolverStatus_t GetrfBuffersize<double>(cusolverDnHandle_t handle, int m, int n, double* a, int lda, int* lwork) {
return cusolverDnDgetrf_bufferSize(handle, m, n, a, lda, lwork);
@@ -192,6 +232,50 @@ cusolverStatus_t Gesvd<float>(
return cusolverDnSgesvd(handle, jobu, jobvt, m, n, a, lda, s, u, ldu, vt, ldvt, work, lwork, rwork, devinfo);
}

template <>
cusolverStatus_t GeqrfBufferSize<double>(cusolverDnHandle_t handle, int m, int n, double* a, int lda, int* lwork) {
return cusolverDnDgeqrf_bufferSize(handle, m, n, a, lda, lwork);
}

template <>
cusolverStatus_t GeqrfBufferSize<float>(cusolverDnHandle_t handle, int m, int n, float* a, int lda, int* lwork) {
return cusolverDnSgeqrf_bufferSize(handle, m, n, a, lda, lwork);
}

template <>
cusolverStatus_t Geqrf<double>(
cusolverDnHandle_t handle, int m, int n, double* a, int lda, double* tau, double* workspace, int lwork, int* devinfo) {
return cusolverDnDgeqrf(handle, m, n, a, lda, tau, workspace, lwork, devinfo);
}

template <>
cusolverStatus_t Geqrf<float>(
cusolverDnHandle_t handle, int m, int n, float* a, int lda, float* tau, float* workspace, int lwork, int* devinfo) {
return cusolverDnSgeqrf(handle, m, n, a, lda, tau, workspace, lwork, devinfo);
}

template <>
cusolverStatus_t OrgqrBufferSize<double>(cusolverDnHandle_t handle, int m, int n, int k, double* a, int lda, double* tau, int* lwork) {
return cusolverDnDorgqr_bufferSize(handle, m, n, k, a, lda, tau, lwork);
}

template <>
cusolverStatus_t OrgqrBufferSize<float>(cusolverDnHandle_t handle, int m, int n, int k, float* a, int lda, float* tau, int* lwork) {
return cusolverDnSorgqr_bufferSize(handle, m, n, k, a, lda, tau, lwork);
}

template <>
cusolverStatus_t Orgqr<double>(
cusolverDnHandle_t handle, int m, int n, int k, double* a, int lda, double* tau, double* work, int lwork, int* devinfo) {
return cusolverDnDorgqr(handle, m, n, k, a, lda, tau, work, lwork, devinfo);
}

template <>
cusolverStatus_t Orgqr<float>(
cusolverDnHandle_t handle, int m, int n, int k, float* a, int lda, float* tau, float* work, int lwork, int* devinfo) {
return cusolverDnSorgqr(handle, m, n, k, a, lda, tau, work, lwork, devinfo);
}
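
The wrappers above all follow one pattern: a generic template that throws DtypeError, plus float/double specializations that forward to the s-/d-prefixed cuSOLVER routine. For illustration, a sketch of the same dtype dispatch in Python using SciPy's LAPACK bindings (scipy.linalg.lapack.sgeqrf/dgeqrf are real routines; the geqrf helper is our name):

import numpy as np
from scipy.linalg import lapack

_GEQRF = {np.dtype(np.float32): lapack.sgeqrf,
          np.dtype(np.float64): lapack.dgeqrf}

def geqrf(a):
    # Pick the right-precision routine, like the template specializations above.
    fn = _GEQRF.get(a.dtype)
    if fn is None:
        raise TypeError("Only arrays of float or double type are supported by geqrf (QR)")
    qr, tau, work, info = fn(np.asfortranarray(a))  # LAPACK expects column-major data
    if info != 0:
        raise RuntimeError("Unsuccessful geqrf (QR) execution. Info = %d" % info)
    return qr, tau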

template <typename T>
void SolveImpl(const Array& a, const Array& b, const Array& out) {
Device& device = a.device();
@@ -243,6 +327,98 @@ void SolveImpl(const Array& a, const Array& b, const Array& out) {
device.backend().CallKernel<CopyKernel>(out_transposed.Transpose(), out);
}

template <typename T>
void QrImpl(const Array& a, const Array& q, const Array& r, const Array& tau, QrMode mode) {
Device& device = a.device();
Dtype dtype = a.dtype();

int64_t m = a.shape()[0];
int64_t n = a.shape()[1];
int64_t k = std::min(m, n);
int64_t lda = std::max(int64_t{1}, m);

// cuSOLVER does not return the correct result in this case, and older versions of cuSOLVER (< 10.1)
// might not work well with zero-sized arrays, so it is better to return early
if (a.shape().GetTotalSize() == 0) {
if (mode == QrMode::kComplete) {
device.backend().CallKernel<IdentityKernel>(q);
}
return;
}

Array r_temp = a.Transpose().Copy();  // cuSOLVER expects column-major data; geqrf then factors this copy in-place

cuda_internal::DeviceInternals& device_internals = cuda_internal::GetDeviceInternals(static_cast<CudaDevice&>(device));

auto r_ptr = static_cast<T*>(internal::GetRawOffsetData(r_temp));
auto tau_ptr = static_cast<T*>(internal::GetRawOffsetData(tau));

std::shared_ptr<void> devinfo = device.Allocate(sizeof(int));

int buffersize_geqrf = 0;
device_internals.cusolverdn_handle().Call(GeqrfBufferSize<T>, m, n, r_ptr, lda, &buffersize_geqrf);

Array work = Empty(Shape{buffersize_geqrf}, dtype, device);
auto work_ptr = static_cast<T*>(internal::GetRawOffsetData(work));

device_internals.cusolverdn_handle().Call(
Geqrf<T>, m, n, r_ptr, lda, tau_ptr, work_ptr, buffersize_geqrf, static_cast<int*>(devinfo.get()));

int devinfo_h = 0;
Device& native_device = GetDefaultContext().GetDevice({"native", 0});
device.MemoryCopyTo(&devinfo_h, devinfo.get(), sizeof(int), native_device);
if (devinfo_h != 0) {
throw ChainerxError{"Unsuccessful geqrf (QR) execution. Info = ", devinfo_h};
}

if (mode == QrMode::kR) {
r_temp = r_temp.At(std::vector<ArrayIndex>{Slice{}, Slice{0, k}}).Transpose(); // R = R[:, 0:k].T
r_temp = Triu(r_temp, 0);
device.backend().CallKernel<CopyKernel>(r_temp, r);
return;
}

if (mode == QrMode::kRaw) {
device.backend().CallKernel<CopyKernel>(r_temp, r);
return;
}

int64_t mc;
Shape q_shape{0};
if (mode == QrMode::kComplete && m > n) {
mc = m;
q_shape = Shape{m, m};
} else {
mc = k;
q_shape = Shape{n, m};
}
Array q_temp = Empty(q_shape, dtype, device);

device.backend().CallKernel<CopyKernel>(r_temp, q_temp.At(std::vector<ArrayIndex>{Slice{0, n}, Slice{}})); // Q[0:n, :] = R
auto q_ptr = static_cast<T*>(internal::GetRawOffsetData(q_temp));

int buffersize_orgqr = 0;
device_internals.cusolverdn_handle().Call(OrgqrBufferSize<T>, m, mc, k, q_ptr, lda, tau_ptr, &buffersize_orgqr);

Array work_orgqr = Empty(Shape{buffersize_orgqr}, dtype, device);
auto work_orgqr_ptr = static_cast<T*>(internal::GetRawOffsetData(work_orgqr));

device_internals.cusolverdn_handle().Call(
Orgqr<T>, m, mc, k, q_ptr, lda, tau_ptr, work_orgqr_ptr, buffersize_orgqr, static_cast<int*>(devinfo.get()));

device.MemoryCopyTo(&devinfo_h, devinfo.get(), sizeof(int), native_device);
if (devinfo_h != 0) {
throw ChainerxError{"Unsuccessful orgqr (QR) execution. Info = ", devinfo_h};
}

q_temp = q_temp.At(std::vector<ArrayIndex>{Slice{0, mc}, Slice{}}).Transpose(); // Q = Q[0:mc, :].T
r_temp = r_temp.At(std::vector<ArrayIndex>{Slice{}, Slice{0, mc}}).Transpose(); // R = R[:, 0:mc].T
r_temp = Triu(r_temp, 0);

device.backend().CallKernel<CopyKernel>(q_temp, q);
device.backend().CallKernel<CopyKernel>(r_temp, r);
}
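
For reference, the control flow of QrImpl maps onto CPU LAPACK almost line for line: geqrf packs R and the Householder reflectors into one buffer, the 'r' and 'raw' modes return early, and orgqr expands the reflectors into a Q with mc columns. A minimal sketch using scipy.linalg.lapack in place of cuSOLVER, with row-major slicing standing in for the transposed copies (qr_sketch is our name; this is an illustration, not the kernel itself):

import numpy as np
from scipy.linalg import lapack

def qr_sketch(a, mode="reduced"):
    m, n = a.shape
    k = min(m, n)
    # geqrf factors in place: R in the upper triangle, reflectors below it.
    qr_f, tau, work, info = lapack.dgeqrf(np.asfortranarray(a))
    assert info == 0, "geqrf failed: info=%d" % info
    if mode == "r":
        return np.triu(qr_f[:k, :])           # R only: (K, N)
    if mode == "raw":
        return qr_f.T, tau                    # h: (N, M), tau: (K,)
    # 'complete' wants an (M, M) Q when m > n; otherwise Q has K columns.
    mc = m if (mode == "complete" and m > n) else k
    q_in = np.zeros((m, mc), order="F")
    cols = min(n, mc)
    q_in[:, :cols] = qr_f[:, :cols]           # seed orgqr with the packed factors
    q, work, info = lapack.dorgqr(q_in, tau)
    assert info == 0, "orgqr failed: info=%d" % info
    return q, np.triu(qr_f[:mc, :])           # Q: (M, mc), R: (mc, N)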

} // namespace

class CudaSolveKernel : public SolveKernel {
@@ -409,5 +585,23 @@ public:

CHAINERX_CUDA_REGISTER_KERNEL(SvdKernel, CudaSvdKernel);

class CudaQrKernel : public QrKernel {
public:
void Call(const Array& a, const Array& q, const Array& r, const Array& tau, QrMode mode) override {
Device& device = a.device();
Dtype dtype = a.dtype();
CudaSetDeviceScope scope{device.index()};

CHAINERX_ASSERT(a.ndim() == 2);

VisitFloatingPointDtype(dtype, [&](auto pt) {
using T = typename decltype(pt)::type;
QrImpl<T>(a, q, r, tau, mode);
});
}
};

CHAINERX_CUDA_REGISTER_KERNEL(QrKernel, CudaQrKernel);

} // namespace cuda
} // namespace chainerx
8 changes: 8 additions & 0 deletions chainerx_cc/chainerx/kernels/linalg.h
@@ -4,6 +4,7 @@

#include "chainerx/array.h"
#include "chainerx/kernel.h"
#include "chainerx/routines/linalg.h"

namespace chainerx {

@@ -39,4 +40,11 @@ class SvdKernel : public Kernel {
virtual void Call(const Array& a, const Array& u, const Array& s, const Array& vt, bool full_matrices) = 0;
};

class QrKernel : public Kernel {
public:
static const char* name() { return "Qr"; }

virtual void Call(const Array& a, const Array& q, const Array& r, const Array& tau, QrMode mode) = 0;
};

} // namespace chainerx
