diff --git a/test/test_transforms.py b/test/test_transforms.py index e153b03f8aa..537294976c7 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -169,6 +169,12 @@ def test_ndarray_to_pil_image(self): l, = img.split() assert np.allclose(l, img_data[:, :, 0]) + def test_ndarray16_to_pil_image(self): + trans = transforms.ToPILImage() + img_data = np.random.randint(0, 65535, [4, 4, 1], np.uint16) + img = trans(img_data) + assert img.mode == 'I;16' + assert np.allclose(img, img_data[:, :, 0]) if __name__ == '__main__': unittest.main() diff --git a/torchvision/transforms.py b/torchvision/transforms.py index 09161d506de..770252c39eb 100644 --- a/torchvision/transforms.py +++ b/torchvision/transforms.py @@ -55,21 +55,32 @@ def __call__(self, pic): class ToPILImage(object): - """Converts a torch.*Tensor of range [0, 1] and shape C x H x W - or numpy ndarray of dtype=uint8, range[0, 255] and shape H x W x C - to a PIL.Image of range [0, 255] + """Converts a torch.*Tensor of shape C x H x W or a numpy ndarray of shape + H x W x C to a PIL.Image while preserving value range. """ def __call__(self, pic): npimg = pic mode = None - if not isinstance(npimg, np.ndarray): - npimg = pic.mul(255).byte().numpy() - npimg = np.transpose(npimg, (1, 2, 0)) + if isinstance(pic, torch.FloatTensor): + pic = pic.mul(255).byte() + if torch.is_tensor(pic): + npimg = np.transpose(pic.numpy(), (1, 2, 0)) + assert isinstance(npimg, np.ndarray), 'pic should be Tensor or ndarray' if npimg.shape[2] == 1: npimg = npimg[:, :, 0] - mode = "L" + + if npimg.dtype == np.uint8: + mode = 'L' + if npimg.dtype == np.uint16: + mode = 'I;16' + elif npimg.dtype == np.float32: + mode = 'F' + else: + if npimg.dtype == np.uint8: + mode = 'RGB' + assert mode is not None, '{} is not supported'.format(npimg.dtype) return Image.fromarray(npimg, mode=mode)