From d11097cea6bca6d0f4f5ceaadac968014bef40c1 Mon Sep 17 00:00:00 2001 From: Michael Lenzen Date: Tue, 27 Nov 2018 22:17:40 -0600 Subject: [PATCH] Refactor bag operations for efficiency --- collections_extended/bags.py | 64 +++++++++++++++++++++++------------- tests/test_bags.py | 7 ++-- 2 files changed, 47 insertions(+), 24 deletions(-) diff --git a/collections_extended/bags.py b/collections_extended/bags.py index 98a976a..2240b92 100644 --- a/collections_extended/bags.py +++ b/collections_extended/bags.py @@ -28,9 +28,8 @@ def __init__(self, iterable=None): self._size = 0 if iterable: if isinstance(iterable, _basebag): - for elem, count in iterable._dict.items(): - self._dict[elem] = count - self._size += count + self._dict = iterable._dict.copy() + self._size = iterable._size else: for value in iterable: self._dict[value] = self._dict.get(value, 0) + 1 @@ -253,8 +252,6 @@ def isdisjoint(self, other): """Return if this bag is disjoint with the passed collection. This runs in O(len(other)) - - TODO move isdisjoint somewhere more appropriate """ for value in other: if value in self: @@ -290,9 +287,14 @@ def __add__(self, other): other (Iterable): elements to add to self """ out = self.copy() - for value in other: - out._dict[value] = out._dict.get(value, 0) + 1 - out._size += 1 + if isinstance(other, _basebag): + for elem, count in other._dict.items(): + out._dict[elem] = out._dict.get(elem, 0) + count + out._size += count + else: + for elem in other: + out._dict[elem] = out._dict.get(elem, 0) + 1 + out._size += 1 return out def __sub__(self, other): @@ -307,16 +309,17 @@ def __sub__(self, other): Args: other (Iterable): elements to remove """ - out = self.copy() - for value in other: - old_count = out._dict.get(value, 0) - if old_count == 1: - del out._dict[value] - out._size -= 1 - elif old_count > 1: - out._dict[value] = old_count - 1 - out._size -= 1 - return out + if isinstance(other, _basebag): + values = dict() + for elem, self_count in self._dict.items(): + values[elem] = max(self_count - other.count(elem), 0) + else: + values = self._dict.copy() + for elem in other: + old_count = values.get(elem, 0) + if old_count: + values[elem] -= 1 + return self.from_mapping(values) def __mul__(self, other): """Cartesian product of the two sets. @@ -355,7 +358,19 @@ def __xor__(self, other): m = len(self) n = len(other) """ - return (self - other) | (other - self) + if isinstance(other, _basebag): + values = dict() + for elem in self._dict.keys() | other._dict.keys(): + values[elem] = abs(self.count(elem) - other.count(elem)) + else: + values = self._dict.copy() + for elem in other: + old_count = values.get(elem, 0) + if old_count: + values[elem] -= 1 + else: + values[elem] += 1 + return self.from_mapping(values) class bag(_basebag, MutableSet): @@ -480,9 +495,14 @@ def __ixor__(self, other): """ if not isinstance(other, _basebag): other = self._from_iterable(other) - other_minus_self = other - self - self.discard_all(other) - self |= other_minus_self + for elem, other_count in other._dict.items(): + self_count = self.count(elem) + new_count = abs(self_count - other_count) + if new_count: + self._dict[elem] = new_count + else: + del self._dict[elem] + self._size -= (self_count - new_count) return self def __isub__(self, other): diff --git a/tests/test_bags.py b/tests/test_bags.py index 9a66d36..da954af 100644 --- a/tests/test_bags.py +++ b/tests/test_bags.py @@ -315,9 +315,12 @@ def test_iand(): def test_ixor(): """Test __ixor__.""" - b = bag('abbc') - b ^= bag('bg') + b = bag('abbbccd') + b ^= bag('bbcdg') assert b == bag('abcg') + b = bag('bbcdg') + b ^= bag('abbbccd') + assert b == bag('acbg') b = bag('abbc') b ^= set('bg') assert b == bag('abcg')