Skip to content

Commit

Permalink
Merge pull request #4511 from beam2d/fp16-serialize
Browse files Browse the repository at this point in the history
Support deserializing into array of different dtype
  • Loading branch information
hvy committed Mar 26, 2018
2 parents 7aca074 + e79febd commit 9744c49
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 2 deletions.
2 changes: 1 addition & 1 deletion chainer/serializers/hdf5.py
Expand Up @@ -138,7 +138,7 @@ def __call__(self, key, value):
if isinstance(value, numpy.ndarray):
dataset.read_direct(value)
elif isinstance(value, cuda.ndarray):
value.set(numpy.asarray(dataset))
value.set(numpy.asarray(dataset, dtype=value.dtype))
else:
value = type(value)(numpy.asarray(dataset))
return value
Expand Down
2 changes: 1 addition & 1 deletion chainer/serializers/npz.py
Expand Up @@ -148,7 +148,7 @@ def __call__(self, key, value):
elif isinstance(value, numpy.ndarray):
numpy.copyto(value, dataset)
elif isinstance(value, cuda.ndarray):
value.set(numpy.asarray(dataset))
value.set(numpy.asarray(dataset, dtype=value.dtype))
else:
value = type(value)(numpy.asarray(dataset))
return value
Expand Down
14 changes: 14 additions & 0 deletions tests/chainer_tests/serializers_tests/test_hdf5.py
Expand Up @@ -153,6 +153,20 @@ def test_deserialize_none_value_gpu(self):
y = numpy.empty((2, 3), dtype=numpy.float32)
self.check_deserialize_none_value(cuda.to_gpu(y))

def test_deserialize_different_dtype_cpu(self):
y = numpy.empty((2, 3), dtype=numpy.float16)
ret = self.deserializer('y', y)
numpy.testing.assert_array_equal(y, self.data.astype(numpy.float16))
self.assertIs(ret, y)

@attr.gpu
def test_deserialize_different_dtype_gpu(self):
y = cuda.cupy.empty((2, 3), dtype=numpy.float16)
ret = self.deserializer('y', y)
numpy.testing.assert_array_equal(
y.get(), self.data.astype(numpy.float16))
self.assertIs(ret, y)

def test_deserialize_scalar(self):
z = 5
ret = self.deserializer('z', z)
Expand Down
14 changes: 14 additions & 0 deletions tests/chainer_tests/serializers_tests/test_npz.py
Expand Up @@ -151,6 +151,20 @@ def test_deserialize_gpu_strip_slashes(self):
y = numpy.empty((2, 3), dtype=numpy.float32)
self.check_deserialize(cuda.to_gpu(y), '/y')

def test_deserialize_different_dtype_cpu(self):
y = numpy.empty((2, 3), dtype=numpy.float16)
ret = self.deserializer('y', y)
numpy.testing.assert_array_equal(y, self.data.astype(numpy.float16))
self.assertIs(ret, y)

@attr.gpu
def test_deserialize_different_dtype_gpu(self):
y = cuda.cupy.empty((2, 3), dtype=numpy.float16)
ret = self.deserializer('y', y)
numpy.testing.assert_array_equal(
y.get(), self.data.astype(numpy.float16))
self.assertIs(ret, y)

def test_deserialize_scalar(self):
z = 5
ret = self.deserializer('z', z)
Expand Down

0 comments on commit 9744c49

Please sign in to comment.