Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 25 additions & 14 deletions catch/utils/set_cover.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,23 +299,34 @@ def approx_multiuniverse(sets,
# Create the universes from given sets
if use_intervalsets:
# Store the elements of each universe in an IntervalSet
universes = defaultdict(lambda: interval.IntervalSet([]))
# First, collect a list of intervals for each universe
universes_unmerged = defaultdict(list)
for sets_by_universe in sets.values():
for universe_id, s in sets_by_universe.items():
if isinstance(s, tuple):
# s is a single interval
universes_unmerged[universe_id].append(s)
else:
# s is an IntervalSet
for i in s.intervals:
universes_unmerged[universe_id].append(i)
# Now, for each universe, create one IntervalSet from its list
# of intervals; doing so will merge overlapping intervals
# (effectively taking the union of all the intervals)
universes = {}
for universe_id, intervals in universes_unmerged.items():
universes[universe_id] = interval.IntervalSet(intervals)
else:
# Store the elements of each universe in a set
universes = defaultdict(set)
for sets_by_universe in sets.values():
for universe_id, s in sets_by_universe.items():
if use_intervalsets:
if isinstance(s, tuple):
# s is a single interval
s = interval.IntervalSet([s])
universes[universe_id] = universes[universe_id].union(s)
elif use_arrays:
for v in s:
universes[universe_id].add(v)
else:
universes[universe_id].update(s)
universes = dict(universes)
for sets_by_universe in sets.values():
for universe_id, s in sets_by_universe.items():
if use_arrays:
for v in s:
universes[universe_id].add(v)
else:
universes[universe_id].update(s)
universes = dict(universes)

if universe_p is None:
# Give each universe a coverage fraction of 1.0 (i.e., cover
Expand Down
7 changes: 7 additions & 0 deletions catch/utils/tests/test_interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,16 @@ def compare_union(self, input_a, input_b, desired_output):
a = IntervalSet(input_a)
b = IntervalSet(input_b)
o = IntervalSet(desired_output)

# Union is commutative, so check both orders
self.assertEqual(a.union(b), o)
self.assertEqual(b.union(a), o)

# Making a new IntervalSet should merge overlapping
# intervals, effectively taking the union
ab = IntervalSet(input_a + input_b)
self.assertEqual(ab, o)

def test_union(self):
self.compare_union([], [], [])
self.compare_union([], [(1, 3)], [(1, 3)])
Expand All @@ -77,6 +83,7 @@ def test_union(self):
self.compare_union([(1, 10)], [(3, 7)], [(1, 10)])
self.compare_union([(2, 100)], [(0, 50)], [(0, 100)])
self.compare_union([(0, 7)], [(4, 10)], [(0, 10)])
self.compare_union([(4, 10)], [(0, 7)], [(0, 10)])
self.compare_union([(1, 5), (10, 15)], [(1, 5), (15, 20)], [(1, 5),
(10, 20)])
self.compare_union([(1, 5), (10, 15)], [(3, 12)], [(1, 15)])
Expand Down