Skip to content

Commit

Permalink
Merge pull request #25919 from mhvk/logspace-broadcast-only-arrays
Browse files Browse the repository at this point in the history
BUG: Ensure non-array logspace base does not influence dtype of output.
  • Loading branch information
rgommers committed Mar 4, 2024
2 parents b9c4c21 + 136b9ed commit 9d69a62
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 6 deletions.
15 changes: 9 additions & 6 deletions numpy/_core/function_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,13 +298,16 @@ def logspace(start, stop, num=50, endpoint=True, base=10.0, dtype=None,
>>> plt.show()
"""
ndmax = np.broadcast(start, stop, base).ndim
start, stop, base = (
np.array(a, copy=None, subok=True, ndmin=ndmax)
for a in (start, stop, base)
)
if not isinstance(base, (float, int)) and np.ndim(base):
# If base is non-scalar, broadcast it with the others, since it
# may influence how axis is interpreted.
ndmax = np.broadcast(start, stop, base).ndim
start, stop, base = (
np.array(a, copy=None, subok=True, ndmin=ndmax)
for a in (start, stop, base)
)
base = np.expand_dims(base, axis=axis)
y = linspace(start, stop, num=num, endpoint=endpoint, axis=axis)
base = np.expand_dims(base, axis=axis)
if dtype is None:
return _nx.power(base, y)
return _nx.power(base, y).astype(dtype, copy=False)
Expand Down
16 changes: 16 additions & 0 deletions numpy/_core/tests/test_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -2598,3 +2598,19 @@ def test_replace_regression(self):
expected = np.char.chararray((2,), itemsize=25)
expected[:] = [s.replace(b"E", b"D") for s in test_strings]
assert_array_equal(out, expected)

def test_logspace_base_does_not_determine_dtype(self):
# gh-24957 and cupy/cupy/issues/7946
start = np.array([0, 2], dtype=np.float16)
stop = np.array([2, 0], dtype=np.float16)
out = np.logspace(start, stop, num=5, axis=1, dtype=np.float32)
expected = np.array([[1., 3.1621094, 10., 31.625, 100.],
[100., 31.625, 10., 3.1621094, 1.]],
dtype=np.float32)
assert_almost_equal(out, expected)
# Check test fails if the calculation is done in float64, as happened
# before when a python float base incorrectly influenced the dtype.
out2 = np.logspace(start, stop, num=5, axis=1, dtype=np.float32,
base=np.array([10.0]))
with pytest.raises(AssertionError, match="not almost equal"):
assert_almost_equal(out2, expected)

0 comments on commit 9d69a62

Please sign in to comment.