From e9c9945e4064b977bdb5da71443499013692fe5b Mon Sep 17 00:00:00 2001 From: Seiya Tokui Date: Wed, 21 Mar 2018 10:14:40 +0900 Subject: [PATCH] Support deserializing into array of different dtype --- chainer/serializers/hdf5.py | 2 +- chainer/serializers/npz.py | 2 +- tests/chainer_tests/serializers_tests/test_hdf5.py | 9 +++++++++ tests/chainer_tests/serializers_tests/test_npz.py | 14 ++++++++++++++ 4 files changed, 25 insertions(+), 2 deletions(-) diff --git a/chainer/serializers/hdf5.py b/chainer/serializers/hdf5.py index 2eceeee39766..2f210cac6596 100644 --- a/chainer/serializers/hdf5.py +++ b/chainer/serializers/hdf5.py @@ -138,7 +138,7 @@ def __call__(self, key, value): if isinstance(value, numpy.ndarray): dataset.read_direct(value) elif isinstance(value, cuda.ndarray): - value.set(numpy.asarray(dataset)) + value.set(numpy.asarray(dataset, dtype=value.dtype)) else: value = type(value)(numpy.asarray(dataset)) return value diff --git a/chainer/serializers/npz.py b/chainer/serializers/npz.py index 1472b9c2f0f9..23a112f637bc 100644 --- a/chainer/serializers/npz.py +++ b/chainer/serializers/npz.py @@ -148,7 +148,7 @@ def __call__(self, key, value): elif isinstance(value, numpy.ndarray): numpy.copyto(value, dataset) elif isinstance(value, cuda.ndarray): - value.set(numpy.asarray(dataset)) + value.set(numpy.asarray(dataset, dtype=value.dtype)) else: value = type(value)(numpy.asarray(dataset)) return value diff --git a/tests/chainer_tests/serializers_tests/test_hdf5.py b/tests/chainer_tests/serializers_tests/test_hdf5.py index 90d7ad699e64..0349073831b6 100644 --- a/tests/chainer_tests/serializers_tests/test_hdf5.py +++ b/tests/chainer_tests/serializers_tests/test_hdf5.py @@ -153,6 +153,15 @@ def test_deserialize_none_value_gpu(self): y = numpy.empty((2, 3), dtype=numpy.float32) self.check_deserialize_none_value(cuda.to_gpu(y)) + def test_deserialize_different_dtype_cpu(self): + y = numpy.empty((2, 3), dtype=numpy.float16) + self.check_deserialize(y) + + @attr.gpu + def test_deserialize_different_dtype_gpu(self): + y = numpy.empty((2, 3), dtype=numpy.float16) + self.check_deserialize(cuda.to_gpu(y)) + def test_deserialize_scalar(self): z = 5 ret = self.deserializer('z', z) diff --git a/tests/chainer_tests/serializers_tests/test_npz.py b/tests/chainer_tests/serializers_tests/test_npz.py index 09bf97de7458..145cd52cb8ab 100644 --- a/tests/chainer_tests/serializers_tests/test_npz.py +++ b/tests/chainer_tests/serializers_tests/test_npz.py @@ -151,6 +151,20 @@ def test_deserialize_gpu_strip_slashes(self): y = numpy.empty((2, 3), dtype=numpy.float32) self.check_deserialize(cuda.to_gpu(y), '/y') + def test_deserialize_different_dtype_cpu(self): + y = numpy.empty((2, 3), dtype=numpy.float16) + ret = self.deserializer('y', y) + numpy.testing.assert_array_equal(y, self.data.astype(numpy.float16)) + self.assertIs(ret, y) + + @attr.gpu + def test_deserialize_different_dtype_gpu(self): + y = cuda.cupy.empty((2, 3), dtype=numpy.float16) + ret = self.deserializer('y', y) + numpy.testing.assert_array_equal( + y.get(), self.data.astype(numpy.float16)) + self.assertIs(ret, y) + def test_deserialize_scalar(self): z = 5 ret = self.deserializer('z', z)