diff --git a/docs/new_sets_doc.md b/docs/new_sets_doc.md index 4f7fde22..18b6507d 100755 --- a/docs/new_sets_doc.md +++ b/docs/new_sets_doc.md @@ -246,6 +246,29 @@ Returns the union of several sets. The set union of all sets in `*args`. + + +## sets.mutable_union + +
+sets.mutable_union(a, b)
+
+ +Modify set `a` adding elements from `b` to it. + +**PARAMETERS** + + +| Name | Description | Default Value | +| :------------- | :------------- | :------------- | +| a | A set, as returned by sets.make(). | none | +| b | A set, as returned by sets.make(). | none | + +**RETURNS** + +The set `a` with all elements appearing in `b` added to it. + + ## sets.difference @@ -269,6 +292,29 @@ Returns the elements in `a` that are not in `b`. A set containing the elements that are in `a` but not in `b`. + + +## sets.mutable_difference + +
+sets.mutable_difference(a, b)
+
+ +Modify set `a` removing elements from `b` from it. + +**PARAMETERS** + + +| Name | Description | Default Value | +| :------------- | :------------- | :------------- | +| a | A set, as returned by sets.make(). | none | +| b | A set, as returned by sets.make(). | none | + +**RETURNS** + +The set `a` with all elements appearing in `b` removed from it. + + ## sets.length diff --git a/lib/new_sets.bzl b/lib/new_sets.bzl index cd90a30e..02cf2cca 100644 --- a/lib/new_sets.bzl +++ b/lib/new_sets.bzl @@ -189,6 +189,19 @@ def _union(*args): """ return struct(_values = dicts.add(*[s._values for s in args])) +def _mutable_union(a, b): + """Modify set `a` adding elements from `b` to it. + + Args: + a: A set, as returned by `sets.make()`. + b: A set, as returned by `sets.make()`. + + Returns: + The set `a` with all elements appearing in `b` added to it. + """ + a._values.update(b._values) + return a + def _difference(a, b): """Returns the elements in `a` that are not in `b`. @@ -201,6 +214,20 @@ def _difference(a, b): """ return struct(_values = {e: None for e in a._values.keys() if e not in b._values}) +def _mutable_difference(a, b): + """Modify set `a` removing elements from `b` from it. + + Args: + a: A set, as returned by `sets.make()`. + b: A set, as returned by `sets.make()`. + + Returns: + The set `a` with all elements appearing in `b` removed from it. + """ + for item in b._values.keys(): + a._values.pop(item) + return a + def _length(s): """Returns the number of elements in a set. @@ -234,7 +261,9 @@ sets = struct( disjoint = _disjoint, intersection = _intersection, union = _union, + mutable_union = _mutable_union, difference = _difference, + mutable_difference = _mutable_difference, length = _length, remove = _remove, repr = _repr, diff --git a/tests/new_sets_tests.bzl b/tests/new_sets_tests.bzl index e73b7d46..9e7faa38 100644 --- a/tests/new_sets_tests.bzl +++ b/tests/new_sets_tests.bzl @@ -98,17 +98,29 @@ def _union_test(ctx): """Unit tests for sets.union.""" env = unittest.begin(ctx) - asserts.new_set_equals(env, sets.make(), sets.union()) - asserts.new_set_equals(env, sets.make([1]), sets.union(sets.make([1]))) - asserts.new_set_equals(env, sets.make(), sets.union(sets.make(), sets.make())) - asserts.new_set_equals(env, sets.make([1]), sets.union(sets.make(), sets.make([1]))) - asserts.new_set_equals(env, sets.make([1]), sets.union(sets.make([1]), sets.make())) - asserts.new_set_equals(env, sets.make([1]), sets.union(sets.make([1]), sets.make([1]))) - asserts.new_set_equals(env, sets.make([1, 2]), sets.union(sets.make([1]), sets.make([1, 2]))) - asserts.new_set_equals(env, sets.make([1, 2]), sets.union(sets.make([1]), sets.make([2]))) + s = sets.make() + s = sets.mutable_union(s, sets.make()) + asserts.new_set_equals(env, sets.make(), s) + s = sets.make() + s = sets.mutable_union(s, sets.make([1])) + asserts.new_set_equals(env, sets.make([1]), s) + s = sets.make([1]) + s = sets.mutable_union(s, sets.make()) + asserts.new_set_equals(env, sets.make([1]), s) + s = sets.make([1]) + s = sets.mutable_union(s, sets.make([1])) + asserts.new_set_equals(env, sets.make([1]), s) + s = sets.make([1]) + s = sets.mutable_union(s, sets.make([1, 2])) + asserts.new_set_equals(env, sets.make([1, 2]), s) + s = sets.make([1]) + s = sets.mutable_union(s, sets.make([2])) + asserts.new_set_equals(env, sets.make([1, 2]), s) # If passing a list, verify that duplicate elements are ignored. - asserts.new_set_equals(env, sets.make([1, 2]), sets.union(sets.make([1, 1]), sets.make([1, 2]))) + s = sets.make([1, 1]) + s = sets.mutable_union(s, sets.make([1, 2])) + asserts.new_set_equals(env, sets.make([1, 2]), s) return unittest.end(env) @@ -132,6 +144,38 @@ def _difference_test(ctx): difference_test = unittest.make(_difference_test) +def _mutable_difference_test(ctx): + """Unit tests for sets.difference.""" + env = unittest.begin(ctx) + + s = sets.make() + s = sets.mutable_difference(s, sets.make()) + asserts.new_set_equals(env, sets.make(), s) + s = sets.make() + s = sets.mutable_difference(s, sets.make([1])) + asserts.new_set_equals(env, sets.make(), s) + s = sets.make([1]) + s = sets.mutable_difference(s, sets.make()) + asserts.new_set_equals(env, sets.make([1]), s) + s = sets.make([1]) + s = sets.mutable_difference(s, sets.make([1])) + asserts.new_set_equals(env, sets.make(), s) + s = sets.make([1]) + s = sets.mutable_difference(s, sets.make([1, 2])) + asserts.new_set_equals(env, sets.make(), s) + s = sets.make([1]) + s = sets.mutable_difference(s, sets.make([2])) + asserts.new_set_equals(env, sets.make([1]), s) + + # If passing a list, verify that duplicate elements are ignored. + s = sets.make([1, 2]) + s = s.mutable_difference(s, sets.make([1, 1])) + asserts.new_set_equals(env, sets.make([2]), s) + + return unittest.end(env) + +mutable_difference_test = unittest.make(_mutable_difference_test) + def _to_list_test(ctx): """Unit tests for sets.to_list.""" env = unittest.begin(ctx)