diff --git a/catch/utils/set_cover.py b/catch/utils/set_cover.py index 0e11b9424..7f937b24b 100644 --- a/catch/utils/set_cover.py +++ b/catch/utils/set_cover.py @@ -299,23 +299,34 @@ def approx_multiuniverse(sets, # Create the universes from given sets if use_intervalsets: # Store the elements of each universe in an IntervalSet - universes = defaultdict(lambda: interval.IntervalSet([])) + # First, collect a list of intervals for each universe + universes_unmerged = defaultdict(list) + for sets_by_universe in sets.values(): + for universe_id, s in sets_by_universe.items(): + if isinstance(s, tuple): + # s is a single interval + universes_unmerged[universe_id].append(s) + else: + # s is an IntervalSet + for i in s.intervals: + universes_unmerged[universe_id].append(i) + # Now, for each universe, create one IntervalSet from its list + # of intervals; doing so will merge overlapping intervals + # (effectively taking the union of all the intervals) + universes = {} + for universe_id, intervals in universes_unmerged.items(): + universes[universe_id] = interval.IntervalSet(intervals) else: # Store the elements of each universe in a set universes = defaultdict(set) - for sets_by_universe in sets.values(): - for universe_id, s in sets_by_universe.items(): - if use_intervalsets: - if isinstance(s, tuple): - # s is a single interval - s = interval.IntervalSet([s]) - universes[universe_id] = universes[universe_id].union(s) - elif use_arrays: - for v in s: - universes[universe_id].add(v) - else: - universes[universe_id].update(s) - universes = dict(universes) + for sets_by_universe in sets.values(): + for universe_id, s in sets_by_universe.items(): + if use_arrays: + for v in s: + universes[universe_id].add(v) + else: + universes[universe_id].update(s) + universes = dict(universes) if universe_p is None: # Give each universe a coverage fraction of 1.0 (i.e., cover diff --git a/catch/utils/tests/test_interval.py b/catch/utils/tests/test_interval.py index b90522bab..a3722f315 100644 --- a/catch/utils/tests/test_interval.py +++ b/catch/utils/tests/test_interval.py @@ -62,10 +62,16 @@ def compare_union(self, input_a, input_b, desired_output): a = IntervalSet(input_a) b = IntervalSet(input_b) o = IntervalSet(desired_output) + # Union is commutative, so check both orders self.assertEqual(a.union(b), o) self.assertEqual(b.union(a), o) + # Making a new IntervalSet should merge overlapping + # intervals, effectively taking the union + ab = IntervalSet(input_a + input_b) + self.assertEqual(ab, o) + def test_union(self): self.compare_union([], [], []) self.compare_union([], [(1, 3)], [(1, 3)]) @@ -77,6 +83,7 @@ def test_union(self): self.compare_union([(1, 10)], [(3, 7)], [(1, 10)]) self.compare_union([(2, 100)], [(0, 50)], [(0, 100)]) self.compare_union([(0, 7)], [(4, 10)], [(0, 10)]) + self.compare_union([(4, 10)], [(0, 7)], [(0, 10)]) self.compare_union([(1, 5), (10, 15)], [(1, 5), (15, 20)], [(1, 5), (10, 20)]) self.compare_union([(1, 5), (10, 15)], [(3, 12)], [(1, 15)])