Skip to content

Commit

Permalink
Refactor bag operations for efficiency
Browse files: browse the repository at this point in the history
  • Loading branch information
mlenzen committed Nov 28, 2018
1 parent f8f42bd commit d11097c
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 24 deletions.
64 changes: 42 additions & 22 deletions collections_extended/bags.py
Expand Up @@ -28,9 +28,8 @@ def __init__(self, iterable=None):
self._size = 0
if iterable:
if isinstance(iterable, _basebag):
for elem, count in iterable._dict.items():
self._dict[elem] = count
self._size += count
self._dict = iterable._dict.copy()
self._size = iterable._size
else:
for value in iterable:
self._dict[value] = self._dict.get(value, 0) + 1
Expand Down Expand Up @@ -253,8 +252,6 @@ def isdisjoint(self, other):
"""Return if this bag is disjoint with the passed collection.
This runs in O(len(other))
TODO move isdisjoint somewhere more appropriate
"""
for value in other:
if value in self:
Expand Down Expand Up @@ -290,9 +287,14 @@ def __add__(self, other):
other (Iterable): elements to add to self
"""
out = self.copy()
for value in other:
out._dict[value] = out._dict.get(value, 0) + 1
out._size += 1
if isinstance(other, _basebag):
for elem, count in other._dict.items():
out._dict[elem] = out._dict.get(elem, 0) + count
out._size += count
else:
for elem in other:
out._dict[elem] = out._dict.get(elem, 0) + 1
out._size += 1
return out

def __sub__(self, other):
Expand All @@ -307,16 +309,17 @@ def __sub__(self, other):
Args:
other (Iterable): elements to remove
"""
out = self.copy()
for value in other:
old_count = out._dict.get(value, 0)
if old_count == 1:
del out._dict[value]
out._size -= 1
elif old_count > 1:
out._dict[value] = old_count - 1
out._size -= 1
return out
if isinstance(other, _basebag):
values = dict()
for elem, self_count in self._dict.items():
values[elem] = max(self_count - other.count(elem), 0)
else:
values = self._dict.copy()
for elem in other:
old_count = values.get(elem, 0)
if old_count:
values[elem] -= 1
return self.from_mapping(values)

def __mul__(self, other):
"""Cartesian product of the two sets.
Expand Down Expand Up @@ -355,7 +358,19 @@ def __xor__(self, other):
m = len(self)
n = len(other)
"""
return (self - other) | (other - self)
if isinstance(other, _basebag):
values = dict()
for elem in self._dict.keys() | other._dict.keys():
values[elem] = abs(self.count(elem) - other.count(elem))
else:
values = self._dict.copy()
for elem in other:
old_count = values.get(elem, 0)
if old_count:
values[elem] -= 1
else:
values[elem] += 1
return self.from_mapping(values)


class bag(_basebag, MutableSet):
Expand Down Expand Up @@ -480,9 +495,14 @@ def __ixor__(self, other):
"""
if not isinstance(other, _basebag):
other = self._from_iterable(other)
other_minus_self = other - self
self.discard_all(other)
self |= other_minus_self
for elem, other_count in other._dict.items():
self_count = self.count(elem)
new_count = abs(self_count - other_count)
if new_count:
self._dict[elem] = new_count
else:
del self._dict[elem]
self._size -= (self_count - new_count)
return self

def __isub__(self, other):
Expand Down
7 changes: 5 additions & 2 deletions tests/test_bags.py
Expand Up @@ -315,9 +315,12 @@ def test_iand():

def test_ixor():
"""Test __ixor__."""
b = bag('abbc')
b ^= bag('bg')
b = bag('abbbccd')
b ^= bag('bbcdg')
assert b == bag('abcg')
b = bag('bbcdg')
b ^= bag('abbbccd')
assert b == bag('acbg')
b = bag('abbc')
b ^= set('bg')
assert b == bag('abcg')
Expand Down

0 comments on commit d11097c

Please sign in to comment.