Skip to content

Commit

Permalink
Merge 5bbb615 into d00b3cb
Browse files Browse the repository at this point in the history
  • Loading branch information
hvy committed Mar 12, 2020
2 parents d00b3cb + 5bbb615 commit c6d9258
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 5 deletions.
3 changes: 2 additions & 1 deletion chainer/backend.py
Expand Up @@ -67,7 +67,8 @@ def copyto(dst, src):
elif isinstance(dst, cuda.ndarray):
if isinstance(src, chainer.get_cpu_array_types()):
src = numpy.asarray(src)
if dst.flags.c_contiguous or dst.flags.f_contiguous:
if (src.dtype == dst.dtype
and (dst.flags.c_contiguous or dst.flags.f_contiguous)):
dst.set(src)
else:
cuda.cupy.copyto(dst, cuda.to_gpu(src, device=dst.device))
Expand Down
20 changes: 16 additions & 4 deletions tests/chainer_tests/test_backend.py
Expand Up @@ -65,15 +65,21 @@ def test_from_chx_cuda(self):
numpy.testing.assert_array_equal(self._to_cpu(dst), self.src_data)


@testing.parameterize(*testing.product({
'dtype': [numpy.float16, numpy.float32],
}))
class TestCopyToCPU(_TestCopyToBase, unittest.TestCase):
def _get_dst(self):
return self.dst_data
return self.dst_data.astype(self.dtype, copy=False)


@testing.parameterize(*testing.product({
'dtype': [numpy.float16, numpy.float32],
}))
@attr.gpu
class TestCopyToGPU(_TestCopyToBase, unittest.TestCase):
def _get_dst(self):
return cuda.cupy.array(self.dst_data)
return cuda.cupy.array(self.dst_data, self.dtype)

@attr.multi_gpu(2)
def test_gpu_to_another_gpu(self):
Expand All @@ -92,17 +98,23 @@ def _get_dst(self):
return dst


@testing.parameterize(*testing.product({
'dtype': [numpy.float16, numpy.float32],
}))
@attr.chainerx
class TestCopyToChxNative(_TestCopyToBase, unittest.TestCase):
def _get_dst(self):
return chainerx.array(self.dst_data, device='native')
return chainerx.array(self.dst_data, dtype=self.dtype, device='native')


@testing.parameterize(*testing.product({
'dtype': [numpy.float16, numpy.float32],
}))
@attr.chainerx
@attr.gpu
class TestCopyToChxCuda(_TestCopyToBase, unittest.TestCase):
def _get_dst(self):
return chainerx.array(self.dst_data, device='cuda:0')
return chainerx.array(self.dst_data, dtype=self.dtype, device='cuda:0')


class TestCopyToError(unittest.TestCase):
Expand Down

0 comments on commit c6d9258

Please sign in to comment.