Skip to content

Commit

Permalink
[RF][PyROOT] Fix RooDataSet.from_numpy() for contiguous input arrays
Browse files Browse the repository at this point in the history
  • Loading branch information
guitargeek committed Sep 5, 2023
1 parent d016feb commit e8d262a
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 25 deletions.
Expand Up @@ -99,7 +99,6 @@ def log_warning(s):
range_mask = np.ones_like(list(data.values())[0], dtype=bool)

def in_range(arr, variable):

# For categories, we need to check whether the elements of the
# array are in the set of category state indices
if variable.isCategory():
Expand All @@ -117,30 +116,27 @@ def select_range_and_change_type(arr, dtype):
if range_mask is not None:
arr = arr[range_mask]
arr = arr if arr.dtype == dtype else np.array(arr, dtype=dtype)
return arr

for real in dataset.store().realStoreList():
vec = real.data()
arg = real.bufArg()
arr = select_range_and_change_type(data[arg.GetName()], np.float64)

vec.resize(len(arr))

beg = arr.ctypes.data_as(ctypes.POINTER(ctypes.c_double))
void_p = ctypes.cast(beg, ctypes.c_voidp).value + 8 * len(arr)
end = ctypes.cast(void_p, ctypes.POINTER(ctypes.c_double))
ROOT.std.copy(beg, end, vec.begin())

for cat in dataset.store().catStoreList():
vec = cat.data()
arg = cat.bufArg()
arr = select_range_and_change_type(data[arg.GetName()], np.int32)
vec.resize(len(arr))

beg = arr.ctypes.data_as(ctypes.POINTER(ctypes.c_int))
void_p = ctypes.cast(beg, ctypes.c_voidp).value + 4 * len(arr)
end = ctypes.cast(void_p, ctypes.POINTER(ctypes.c_int))
ROOT.std.copy(beg, end, vec.begin())
# Make sure that the array is contiguous to we can std::copy() it.
# In the implementation of ascontiguousarray(), no copy is done if
# the array is already contiguous, which is exactly what we want.
return np.ascontiguousarray(arr)

def copy_to_dataset(store_list, np_type, c_type, type_size_in_bytes):
for real in store_list:
vec = real.data()
arg = real.bufArg()
arr = select_range_and_change_type(data[arg.GetName()], np_type)

vec.resize(len(arr))

beg = arr.ctypes.data_as(ctypes.POINTER(c_type))
n_bytes = type_size_in_bytes * len(arr)
void_p = ctypes.cast(beg, ctypes.c_voidp).value + n_bytes
end = ctypes.cast(void_p, ctypes.POINTER(c_type))
ROOT.std.copy(beg, end, vec.begin())

copy_to_dataset(dataset.store().realStoreList(), np.float64, ctypes.c_double, 8)
copy_to_dataset(dataset.store().catStoreList(), np.int32, ctypes.c_int, 4)

dataset.store().recomputeSumWeight()

Expand Down
35 changes: 35 additions & 0 deletions bindings/pyroot/pythonizations/test/roofit/roodataset_numpy.py
Expand Up @@ -143,6 +143,41 @@ def test_ignoring_out_of_range(self):

self.assertEqual(dataset_numpy.numEntries(), n_in_range)

def test_non_contiguous_arrays(self):
"""Test whether the import also works with non-continguous arrays.
Covers GitHub issue #1360.
"""

import itertools

obs_1 = ROOT.RooRealVar("obs_1", "obs_1", 70, 70, 190)
obs_1.setBins(30)
obs_2 = ROOT.RooRealVar("obs_2", "obs_2", 100, 100, 180)
obs_2.setBins(80)

val_obs_1 = []
val_obs_2 = []
for i in range(obs_1.numBins()):
obs_1.setBin(i)
val_obs_1.append(obs_1.getVal())
for i in range(obs_2.numBins()):
obs_2.setBin(i)
val_obs_2.append(obs_2.getVal())

# so that all combination of values are in the dataset
val_cart_product = np.array(list(itertools.product(val_obs_1, val_obs_2)))
data = {"obs_1": val_cart_product[:, 0], "obs_2": val_cart_product[:, 1]}

# To make sure the array is really not C-contiguous
assert data["obs_1"].flags["C_CONTIGUOUS"] == False

dataset = ROOT.RooDataSet.from_numpy(data, ROOT.RooArgSet(obs_1, obs_2))

data_roundtripped = dataset.to_numpy()

np.testing.assert_equal(data["obs_1"], data_roundtripped["obs_1"])
np.testing.assert_equal(data["obs_2"], data_roundtripped["obs_2"])


if __name__ == "__main__":
unittest.main()

0 comments on commit e8d262a

Please sign in to comment.