Skip to content

Commit

Permalink
cherry-pick _supported_float_type from scikit-imagegh-5219
Browse files Browse the repository at this point in the history
  • Loading branch information
grlee77 committed Apr 21, 2021
1 parent 5b47e84 commit 59e1cce
Showing 1 changed file with 45 additions and 0 deletions.
45 changes: 45 additions & 0 deletions skimage/_shared/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import numbers
import sys
import warnings
from collections.abc import Iterable

import numpy as np
from numpy.lib import NumpyVersion
Expand Down Expand Up @@ -579,3 +580,47 @@ def _fix_ndimage_mode(mode):
if NumpyVersion(scipy.__version__) >= '1.6.0':
mode = grid_modes.get(mode, mode)
return mode


new_float_type = {
# preserved types
np.float32().dtype.char: np.float32,
np.float64().dtype.char: np.float64,
np.complex64().dtype.char: np.complex64,
np.complex128().dtype.char: np.complex128,
# altered types
np.float16().dtype.char: np.float32,
'g': np.float64, # np.float128 ; doesn't exist on windows
'G': np.complex128, # np.complex256 ; doesn't exist on windows
}


def _supported_float_type(input_dtype, allow_complex=False):
"""Return an appropriate floating-point dtype for a given dtype.
float32, float64, complex64, complex128 are preserved.
float16 is promoted to float32.
complex256 is demoted to complex128.
Other types are cast to float64.
Paramters
---------
input_dtype : np.dtype or Iterable of np.dtype
The input dtype. If a sequence of multiple dtypes is provided, each
dtype is first converted to a supported floating point type and the
final dtype is then determined by applying `np.result_type` on the
sequence of supported floating point types.
allow_complex : bool, optional
If False, raise a ValueError on complex-valued inputs.
Retruns
-------
float_type : dtype
Floating-point dtype for the image.
"""
if isinstance(input_dtype, Iterable) and not isinstance(input_dtype, str):
return np.result_type(*(_supported_float_type(d) for d in input_dtype))
input_dtype = np.dtype(input_dtype)
if not allow_complex and input_dtype.kind == 'c':
raise ValueError("complex valued input is not supported")
return new_float_type.get(input_dtype.char, np.float64)

0 comments on commit 59e1cce

Please sign in to comment.