diff --git a/becquerel/core/utils.py b/becquerel/core/utils.py index 69d072f3..9ceb23f0 100644 --- a/becquerel/core/utils.py +++ b/becquerel/core/utils.py @@ -124,3 +124,20 @@ def bin_centers_from_edges(edges_kev): edges_kev = np.array(edges_kev) centers_kev = (edges_kev[:-1] + edges_kev[1:]) / 2 return centers_kev + + +def sqrt_bins(bin_edge_min, bin_edge_max, nbins): + """ + Square root binning + + Args: + bin_edge_min (float): Minimum bin edge (must be >= 0) + bin_edge_max (float): Maximum bin edge (must be greater than bin_min) + nbins (int): Number of bins + + Returns: + np.array of bin edges (length = nbins + 1) + """ + assert bin_edge_min >= 0 + assert bin_edge_max > bin_edge_min + return np.linspace(np.sqrt(bin_edge_min), np.sqrt(bin_edge_max), nbins + 1) ** 2 diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 00000000..ab4b7ba2 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,28 @@ +import pytest +import numpy as np +import becquerel as bq + + +# ---------------------------------------------- +# Test utils +# ---------------------------------------------- + + +def test_sqrt_bins(): + """Test basic functionality of utils.sqrt_bins.""" + edge_min = 0 + edge_max = 3000 + n_bins = 128 + be = bq.utils.sqrt_bins(edge_min, edge_max, n_bins) + bc = (be[1:] + be[:-1]) / 2 + bw = np.diff(be) + # compute slope of line + m = np.diff(bw ** 2) / np.diff(bc) + # assert that the square of the bin + assert np.allclose(m[0], m) + # negative edge_min + with pytest.raises(AssertionError): + be = bq.utils.sqrt_bins(-10, edge_max, n_bins) + # edge_max < edge_min + with pytest.raises(AssertionError): + be = bq.utils.sqrt_bins(100, 50, n_bins)