Merge 15b7eb6 into e7f8d8b
IvanYashchuk committed Aug 26, 2019
2 parents e7f8d8b + 15b7eb6 commit 8eb830f
Showing 5 changed files with 76 additions and 22 deletions.
25 changes: 17 additions & 8 deletions chainerx_cc/chainerx/cuda/cuda_device/linalg.cu
@@ -288,6 +288,7 @@ void SolveImpl(const Array& a, const Array& b, const Array& out) {
auto lu_ptr = static_cast<T*>(internal::GetRawOffsetData(lu_matrix));

int64_t m = a.shape()[0];
+ int64_t lda = std::max(int64_t{1}, m);
int64_t nrhs = 1;
if (b.ndim() == 2) {
nrhs = b.shape()[1];
@@ -297,14 +298,14 @@ void SolveImpl(const Array& a, const Array& b, const Array& out) {
auto ipiv_ptr = static_cast<int*>(internal::GetRawOffsetData(ipiv));

int buffersize = 0;
- device_internals.cusolverdn_handle().Call(GetrfBuffersize<T>, m, m, lu_ptr, m, &buffersize);
+ device_internals.cusolverdn_handle().Call(GetrfBuffersize<T>, m, m, lu_ptr, lda, &buffersize);

Array work = Empty(Shape{buffersize}, dtype, device);
auto work_ptr = static_cast<T*>(internal::GetRawOffsetData(work));

std::shared_ptr<void> devinfo = device.Allocate(sizeof(int));

- device_internals.cusolverdn_handle().Call(Getrf<T>, m, m, lu_ptr, m, work_ptr, ipiv_ptr, static_cast<int*>(devinfo.get()));
+ device_internals.cusolverdn_handle().Call(Getrf<T>, m, m, lu_ptr, lda, work_ptr, ipiv_ptr, static_cast<int*>(devinfo.get()));

int devinfo_h = 0;
Device& native_device = GetDefaultContext().GetDevice({"native", 0});
@@ -317,7 +318,7 @@ void SolveImpl(const Array& a, const Array& b, const Array& out) {
auto out_ptr = static_cast<T*>(internal::GetRawOffsetData(out_transposed));

device_internals.cusolverdn_handle().Call(
- Getrs<T>, CUBLAS_OP_N, m, nrhs, lu_ptr, m, ipiv_ptr, out_ptr, m, static_cast<int*>(devinfo.get()));
+ Getrs<T>, CUBLAS_OP_N, m, nrhs, lu_ptr, lda, ipiv_ptr, out_ptr, lda, static_cast<int*>(devinfo.get()));

device.MemoryCopyTo(&devinfo_h, devinfo.get(), sizeof(int), native_device);
if (devinfo_h != 0) {
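
The lda clamps in these hunks follow the LAPACK/cuSOLVER convention that a leading dimension must be at least max(1, m), even when the matrix has a zero dimension, which is what allows Solve to accept empty systems. A minimal sketch of the user-visible behavior this enables, using NumPy (the reference implementation the tests below compare against) rather than the ChainerX API:

# Sketch: solving an empty linear system the way recent NumPy does;
# ChainerX's Solve is expected to match this once lda is clamped.
import numpy as np

a = np.empty((0, 0), dtype=np.float64)  # 0x0 coefficient matrix
b = np.empty((0,), dtype=np.float64)    # empty right-hand side
x = np.linalg.solve(a, b)
print(x.shape)  # (0,) -- an empty solution rather than an error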
@@ -461,14 +462,21 @@ CHAINERX_CUDA_REGISTER_KERNEL(InverseKernel, CudaInverseKernel);

class CudaSvdKernel : public SvdKernel {
public:
- void Call(const Array& a, const Array& u, const Array& s, const Array& vt, bool full_matrices) override {
+ void Call(const Array& a, const Array& u, const Array& s, const Array& vt, bool full_matrices, bool compute_uv) override {
Device& device = a.device();
Dtype dtype = a.dtype();
CudaSetDeviceScope scope{device.index()};

CHAINERX_ASSERT(a.ndim() == 2);

- bool compute_uv = u.shape()[0] != 0 && vt.shape()[0] != 0;
+ if (a.shape().GetTotalSize() == 0) {
+ if (full_matrices && compute_uv) {
+ device.backend().CallKernel<IdentityKernel>(u);
+ device.backend().CallKernel<IdentityKernel>(vt);
+ }
+ // Note: this kernel would also handle zero-sized input correctly without this early return.
+ return;
+ }

// cuSOLVER assumes arrays are in column-major order.
// In order to avoid transposing the input matrix, matrix dimensions are swapped.
@@ -516,8 +524,9 @@ public:
vt_temp = Empty(vt_shape, dtype, device);
}

- int64_t ldu = m;
- int64_t ldvt = full_matrices ? n : k;
+ int64_t lda = std::max(int64_t{1}, m);
+ int64_t ldu = std::max(int64_t{1}, m);
+ int64_t ldvt = full_matrices ? std::max(int64_t{1}, n) : std::max(int64_t{1}, k);

auto svd_impl = [&](auto pt) {
using T = typename decltype(pt)::type;
@@ -555,7 +564,7 @@ public:
m,
n,
x_ptr,
- m,
+ lda,
s_ptr,
vt_ptr,
ldu,
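A note on the column-major convention handled in both SVD kernels: cuSOLVER and LAPACK read buffers in column-major order, so a row-major buffer holding A is seen by them as A^T. The kernels therefore pass swapped dimensions and exchange the u/vt output pointers rather than materializing a transpose. The identity this relies on, as a short LaTeX sketch (A = U Sigma V^T being the SVD of A):

A = U \Sigma V^{T}
\quad\Longrightarrow\quad
A^{T} = V \Sigma^{T} U^{T}

Reading the factors of A^T back out of the swapped buffers in row-major order then yields exactly vt and u for A.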
2 changes: 1 addition & 1 deletion chainerx_cc/chainerx/kernels/linalg.h
@@ -37,7 +37,7 @@ class SvdKernel : public Kernel {
public:
static const char* name() { return "Svd"; }

- virtual void Call(const Array& a, const Array& u, const Array& s, const Array& vt, bool full_matrices) = 0;
+ virtual void Call(const Array& a, const Array& u, const Array& s, const Array& vt, bool full_matrices, bool compute_uv) = 0;
};

class QrKernel : public Kernel {
31 changes: 24 additions & 7 deletions chainerx_cc/chainerx/native/native_device/linalg.cc
@@ -226,6 +226,7 @@ void SolveImpl(const Array& a, const Array& b, const Array& out) {
auto lu_ptr = static_cast<T*>(internal::GetRawOffsetData(lu_matrix));

int64_t n = a.shape()[0];
+ int64_t lda = std::max(int64_t{1}, n);
int64_t nrhs = 1;
if (b.ndim() == 2) {
nrhs = b.shape()[1];
@@ -238,7 +239,7 @@ void SolveImpl(const Array& a, const Array& b, const Array& out) {
auto out_ptr = static_cast<T*>(internal::GetRawOffsetData(out_transposed));

int info;
- Gesv(n, nrhs, lu_ptr, n, ipiv_ptr, out_ptr, n, &info);
+ Gesv(n, nrhs, lu_ptr, lda, ipiv_ptr, out_ptr, lda, &info);

if (info != 0) {
throw ChainerxError{"Unsuccessful gesv (Solve) execution. Info = ", info};
@@ -252,6 +253,13 @@ void InverseImpl(const Array& a, const Array& out) {
Device& device = a.device();
Dtype dtype = a.dtype();

+ // 'getri' segfaults on a 0-sized array; its documentation says the minimum dimension is 1.
+ // NumPy handles 0-sized arrays because it uses the 'gesv' routine,
+ // obtaining the inverse by solving a linear system.
+ if (a.shape().GetTotalSize() == 0) {
+ return;
+ }

device.backend().CallKernel<CopyKernel>(a, out);
auto out_ptr = static_cast<T*>(internal::GetRawOffsetData(out));

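The early return added above makes inversion of an empty matrix a no-op on the preallocated 0x0 output. A minimal NumPy sketch of the contract being preserved (NumPy reaches the same result through gesv rather than getri):

# Sketch: NumPy accepts a 0x0 matrix in inv, which is the behavior
# the getri guard above mirrors.
import numpy as np

a = np.empty((0, 0), dtype=np.float64)
print(np.linalg.inv(a).shape)  # (0, 0)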
@@ -406,14 +414,21 @@ CHAINERX_NATIVE_REGISTER_KERNEL(InverseKernel, NativeInverseKernel);

class NativeSvdKernel : public SvdKernel {
public:
- void Call(const Array& a, const Array& u, const Array& s, const Array& vt, bool full_matrices) override {
+ void Call(const Array& a, const Array& u, const Array& s, const Array& vt, bool full_matrices, bool compute_uv) override {
#if CHAINERX_ENABLE_LAPACK
Device& device = a.device();
Dtype dtype = a.dtype();

CHAINERX_ASSERT(a.ndim() == 2);

- bool compute_uv = u.shape()[0] != 0 && vt.shape()[0] != 0;
+ if (a.shape().GetTotalSize() == 0) {
+ if (full_matrices && compute_uv) {
+ device.backend().CallKernel<IdentityKernel>(u);
+ device.backend().CallKernel<IdentityKernel>(vt);
+ }
+ // Note: this kernel would also handle zero-sized input correctly without this early return.
+ return;
+ }

// LAPACK assumes arrays are in column-major order.
// In order to avoid transposing the input matrix, matrix dimensions are swapped.
@@ -422,8 +437,9 @@ class NativeSvdKernel : public SvdKernel {
int64_t n = a.shape()[0];
int64_t m = a.shape()[1];
int64_t k = std::min(m, n);
- int64_t ldu = m;
- int64_t ldvt = full_matrices ? n : k;
+ int64_t lda = std::max(int64_t{1}, m);
+ int64_t ldu = std::max(int64_t{1}, m);
+ int64_t ldvt = full_matrices ? std::max(int64_t{1}, n) : std::max(int64_t{1}, k);

Array x = EmptyLike(a, device);
device.backend().CallKernel<CopyKernel>(a, x);
@@ -450,13 +466,13 @@ class NativeSvdKernel : public SvdKernel {
int buffersize = -1;
T work_size;
// When calling Gesdd, the u and vt pointers are swapped instead of transposing the input matrix.
- Gesdd(job, m, n, x_ptr, m, s_ptr, vt_ptr, ldu, u_ptr, ldvt, &work_size, buffersize, iwork_ptr, &info);
+ Gesdd(job, m, n, x_ptr, lda, s_ptr, vt_ptr, ldu, u_ptr, ldvt, &work_size, buffersize, iwork_ptr, &info);
buffersize = static_cast<int>(work_size);

Array work = Empty(Shape{buffersize}, dtype, device);
auto work_ptr = static_cast<T*>(internal::GetRawOffsetData(work));

- Gesdd(job, m, n, x_ptr, m, s_ptr, vt_ptr, ldu, u_ptr, ldvt, work_ptr, buffersize, iwork_ptr, &info);
+ Gesdd(job, m, n, x_ptr, lda, s_ptr, vt_ptr, ldu, u_ptr, ldvt, work_ptr, buffersize, iwork_ptr, &info);

if (info != 0) {
throw ChainerxError{"Unsuccessful gesdd (SVD) execution. Info = ", info};
@@ -470,6 +486,7 @@ class NativeSvdKernel : public SvdKernel {
(void)s; // unused
(void)vt; // unused
(void)full_matrices; // unused
+ (void)compute_uv; // unused
throw ChainerxError{"LAPACK is not linked to ChainerX."};
#endif // CHAINERX_LAPACK_AVAILABLE
}
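Both SVD kernels write identity matrices into u and vt for empty input when full_matrices and compute_uv are set. A hedged NumPy sketch of the convention this mirrors (empty-input SVD needs NumPy >= 1.16, per the version gate in the tests below; the identity choice is what the IdentityKernel calls implement and is worth verifying against your NumPy version):

# Sketch: SVD of an empty (0, 3) matrix with full_matrices=True.
import numpy as np

u, s, vt = np.linalg.svd(np.empty((0, 3)), full_matrices=True)
print(u.shape, s.shape, vt.shape)  # (0, 0) (0,) (3, 3)
# An empty matrix places no constraints on the full orthogonal factors,
# so the identity is the canonical choice for u and vt.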
9 changes: 8 additions & 1 deletion chainerx_cc/chainerx/routines/linalg.cc
@@ -216,7 +216,7 @@ std::tuple<Array, Array, Array> Svd(const Array& a, bool full_matrices, bool compute_uv) {

{
NoBackpropModeScope scope{};
- a.device().backend().CallKernel<SvdKernel>(a, u, s, vt, full_matrices);
+ a.device().backend().CallKernel<SvdKernel>(a, u, s, vt, full_matrices, compute_uv);
}

// Reference:
@@ -302,6 +302,13 @@ Array PseudoInverse(const Array& a, float rcond) {

std::tie(u, s, vt) = Svd(a, /*full_matrices=*/false, /*compute_uv=*/true);

+ // Computing the maximum along a zero-sized axis is not supported,
+ // so return early.
+ if (a.shape().GetTotalSize() == 0) {
+ // A copy (rather than a new empty array) is returned so that backward does not raise errors.
+ return Copy(a.Transpose());
+ }

Array cutoff = rcond * s.Max();
Array cutoff_indices = s <= cutoff;

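For the empty-input path added to PseudoInverse: the pseudo-inverse of an (m, n) matrix with a zero dimension is its empty (n, m) transpose, since the Moore-Penrose conditions hold trivially. A minimal NumPy sketch (empty-input pinv needs NumPy >= 1.13, per the version gate in the tests below):

# Sketch: pinv of an empty (0, 3) matrix is an empty (3, 0) matrix.
import numpy as np

a = np.empty((0, 3), dtype=np.float64)
print(np.linalg.pinv(a).shape)  # (3, 0)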
31 changes: 26 additions & 5 deletions tests/chainerx_tests/unit_tests/routines_tests/test_linalg.py
@@ -120,10 +120,17 @@ def setup(self):
self.check_double_backward_options.update({'rtol': 5e-3})


+ _numpy_does_not_support_0d_input113 = \
+     numpy.lib.NumpyVersion(numpy.__version__) < '1.13.0'
+
+ _numpy_does_not_support_0d_input116 = \
+     numpy.lib.NumpyVersion(numpy.__version__) < '1.16.0'


@op_utils.op_test(['native:0', 'cuda:0'])
@chainer.testing.parameterize(*(
chainer.testing.product({
- 'shape': [(1, 1), (3, 3), (6, 6)],
+ 'shape': [(0, 0), (1, 1), (3, 3), (6, 6)],
'b_columns': [(), (1,), (3,), (4,)],
'dtypes': [
('float32', 'float32'),
@@ -191,12 +198,15 @@ def forward_xp(self, inputs, xp):
@op_utils.op_test(['native:0', 'cuda:0'])
@chainer.testing.parameterize(*(
chainer.testing.product({
- 'shape': [(1, 1), (3, 3), (6, 6)],
+ 'shape': [(0, 0), (1, 1), (3, 3), (6, 6)],
'dtype': ['float32', 'float64']
})
))
class TestInverse(NumpyLinalgOpTest):

+ # For zero-sized input the strides differ from NumPy's.
+ check_numpy_strides_compliance = False

def generate_inputs(self):
a = numpy.random.random(self.shape).astype(self.dtype)
return a,
@@ -250,12 +260,12 @@ def forward_xp(self, inputs, xp):
@op_utils.op_test(['native:0', 'cuda:0'])
@chainer.testing.parameterize(*(
chainer.testing.product({
- 'shape': [(1, 1), (2, 3), (3, 2), (6, 6)],
+ 'shape': [(0, 0), (0, 3), (3, 0), (1, 1), (2, 3), (3, 2), (6, 6)],
'dtype': ['float32', 'float64'],
'full_matrices': [False],
'compute_uv': [True]
}) + chainer.testing.product({
- 'shape': [(1, 1), (2, 3), (3, 2), (6, 6)],
+ 'shape': [(0, 0), (0, 3), (3, 0), (1, 1), (2, 3), (3, 2), (6, 6)],
'dtype': ['float32', 'float64'],
'full_matrices': [True],
'compute_uv': [False],
@@ -271,6 +281,10 @@ def generate_inputs(self):

def forward_xp(self, inputs, xp):
a, = inputs

+ if (_numpy_does_not_support_0d_input116 and a.size == 0):
+     pytest.skip('Older NumPy versions do not work with empty arrays')

out = xp.linalg.svd(a,
full_matrices=self.full_matrices,
compute_uv=self.compute_uv)
@@ -306,19 +320,26 @@ def forward_xp(self, inputs, xp):
@op_utils.op_test(['native:0', 'cuda:0'])
@chainer.testing.parameterize(*(
chainer.testing.product({
- 'shape': [(1, 1), (2, 3), (3, 2), (6, 6)],
+ 'shape': [(0, 0), (0, 3), (3, 0), (1, 1), (2, 3), (3, 2), (6, 6)],
'rcond': [1e-15, 0.3, 0.5, 0.6],
'dtype': ['float32', 'float64']
})
))
class TestPseudoInverse(NumpyLinalgOpTest):

+ # For zero-sized input the strides differ from NumPy's.
+ check_numpy_strides_compliance = False

def generate_inputs(self):
a = numpy.random.random(self.shape).astype(self.dtype)
return a,

def forward_xp(self, inputs, xp):
a, = inputs

+ if (_numpy_does_not_support_0d_input113 and a.size == 0):
+     pytest.skip('Older NumPy versions do not work with empty arrays')

out = xp.linalg.pinv(a, rcond=self.rcond)
return out,

