Skip to content

Commit

Permalink
Fixed issue by specifying public interface
Browse files Browse the repository at this point in the history
- Improved also robustness of data structure
  • Loading branch information
nbro committed Mar 8, 2017
1 parent 8ec1b9c commit 2bb1dc8
Show file tree
Hide file tree
Showing 3 changed files with 279 additions and 249 deletions.
171 changes: 117 additions & 54 deletions ands/ds/DSForests.py → ands/ds/DisjointSetsForest.py
Expand Up @@ -8,7 +8,7 @@
Created: 21/02/2016
Updated: 03/01/2016
Updated: 08/03/2017
# Description
Expand All @@ -23,19 +23,32 @@
3. union(x, y): unions the sets where x and y are (if they do not belong already to the same set).
`DSForests` uses two heuristics that improve the performance with respect to a naive implementation.
`DisjointSetsForest` uses two heuristics that improve the performance with respect to a naive implementation.
1. Union by rank: attach the smaller tree to the root of the larger tree
2. Path compression: is a way of flattening the structure of the tree whenever find is used on it.
These two techniques complement each other: applied together, the amortized time per operation is only O( α (n)).
## Public interface
- make_set(x): add x to the DisjointSetsForest
- find(x): returns the root or representative of x
- union(x, y): unites the sets where x and y reside
- print_set(x): prints the set where x resides to the standard output
- contains(x): check if x is in the DisjointSetsForest
- size: returns the number of elements added to the data structure (using make_set)
- sets: returns the number of disjoint sets currently in the data structure
All other methods or fields are considered private and are NOT intended to be used by clients!
# TODO
- Deletion operation (OPTIONAL, since it's usually not part of the interface of a disjoint-set data structure)
- Pretty-print(x), for some element x in the disjoint-set data structure.
- Implement the version explained [here](http://algs4.cs.princeton.edu/15uf/)
- Add complexity analysis for print_set
# References
Expand All @@ -45,25 +58,32 @@
- [http://stackoverflow.com/a/22945492/3924118](http://stackoverflow.com/a/22945492/3924118)
- [http://stackoverflow.com/q/23055236/3924118](http://stackoverflow.com/q/23055236/3924118)
- [https://www.cs.usfca.edu/~galles/JavascriptVisual/DisjointSets.html](https://www.cs.usfca.edu/~galles/JavascriptVisual/DisjointSets.html)
to visualize how disjoint-sets work.
to visualize how disjoint-sets work.
"""

__all__ = ["DisjointSetsForest"]


class DSFNode:
"""DSFNode is the node used internally by `DisjointSetsForest`
to represent nodes in the disjoint trees (or sets).
Clients should NOT need to use this class."""

class DSNode:
def __init__(self, x, rank=0):
# This attribute can contain any hashable value.
self.value = x

# The rank of node x only changes in one specific union(x, y) case:
# when x is the representative of its set
# and the representative of the set where y resides has the same rank as x.
# In the DSForests implementation below, if a situation as just described occurs,
# In the DisjointSetsForest implementation below, if a situation as just described occurs,
# then the x.rank is increased by 1.
self.rank = rank

# Reference to the representative of the set where this node resides
# Since DSForests actually implements a tree,
# Since DisjointSetsForest actually implements a tree,
# self.parent is also the root of that tree.
self.parent = self

Expand All @@ -73,9 +93,9 @@ def __init__(self, x, rank=0):
self.next = self

def is_root(self):
"""A DSNode x is a root or representative of a set
"""A DSFNode x is a root or representative of a set
whenever its parent pointer points to itself.
Of course this is only true if x is already in a DSForests object."""
Of course this is only true if x is already in a DisjointSetsForest object."""
return self.parent == self

def __str__(self):
Expand All @@ -88,18 +108,48 @@ def __repr__(self):
return "(value: {0}, rank: {1}, parent: {2})".format(self.value, self.rank, self.parent)


class DSForests:
def __init__(self):
# keeps tracks of the DSNodes in this disjoint-set forests.
self.sets = {}
class DisjointSetsForest:
"""A disjoint-sets forest is a collection of disjoint sets.
Two sets A and B are disjoint if they have no element in common,
or, in other words, their intersection is the empty set.
It's called a forest because of the way the disjoint-set data structure is implemented here:
each set is represented by a tree, so the whole collection is a forest of trees.
A disjoint-set data structure can be implemented differently.
def make_set(self, x) -> DSNode:
"""Creates a set object for `x`."""
assert x not in self.sets
self.sets[x] = DSNode(x)
return self.sets[x]
This data structure does not allow duplicates."""

def find(self, x: DSNode) -> DSNode:
def __init__(self):
# Maps each value added through make_set to the DSFNode
# that represents it in this disjoint-sets forest.
self._sets = {}
# Number of disjoint sets currently in the forest:
# incremented by make_set and decremented by every union that merges two sets.
self._n = 0

def make_set(self, x: object) -> None:
"""Creates a set object for `x`.
`x` is used as a dictionary key, so it must be hashable.
If `x` is already in self, then `LookupError` is raised."""
if x in self._sets:
raise LookupError("x is already in self")
self._sets[x] = DSFNode(x)
# A newly added element forms a set of its own.
self._n += 1
# Sanity check: there can never be more sets than elements.
assert 0 <= self.sets <= self.size

@property
def size(self) -> int:
"""Returns the number of elements in this DisjointSetsForest,
i.e. the number of entries stored in `self._sets`."""
return len(self._sets)

@property
def sets(self) -> int:
"""Returns the number of disjoint sets in `self`,
as tracked by the internal counter `self._n`."""
return self._n

def contains(self, x: object) -> bool:
"""Returns True if x is in self, False otherwise.
Membership is checked against the keys of `self._sets`."""
return x in self._sets

def _find(self, x: DSFNode) -> DSFNode:
"""Finds and returns the representative (or root) of `x`.
It follows parent nodes until it reaches
the root of the tree (set) to which `x` belongs.
Expand Down Expand Up @@ -128,24 +178,22 @@ def find(self, x: DSNode) -> DSNode:
&alpha; (n) is less than 5 for all remotely practical values of n.
Thus, the amortized running time per operation
is effectively a small constant."""
assert x
assert x is not None
if x.parent != x:
x.parent = self.find(x.parent)
x.parent = self._find(x.parent)
return x.parent

def find_iteratively(self, x: DSNode) -> DSNode:
@staticmethod
def _find_iteratively(x: DSFNode) -> DSFNode:
"""This version is just an iterative alternative to the find method."""
assert x
assert x is not None

y = x

# find the representative of the set where x resides
while y != y.parent:
y = y.parent

# post-condition
assert y == self.find(x)

# now y is the representative of x,
# but we also want to do a path compression,
# i.e. connect all nodes in the path from x to y directly to y.
Expand All @@ -156,24 +204,35 @@ def find_iteratively(self, x: DSNode) -> DSNode:

return y

def union(self, x, y) -> DSNode:
""""Union by rank" 2 trees (sets) into one by attaching
def find(self, x: object) -> object:
"""Finds and returns the representative (or root) of `x`.
The value stored in the root node is returned, not the node itself.
Raises a `LookupError` if `x` does not belong to this `DisjointSetsForest`.
**Time Complexity:** O*(&alpha; (n))."""
if x not in self._sets:
raise LookupError("x is not in self")
x_root = self._find(self._sets[x]).value
# Sanity check: the iterative variant must agree with the recursive one.
assert x_root == DisjointSetsForest._find_iteratively(self._sets[x]).value
return x_root

def union(self, x: object, y: object) -> object:
""""Union by rank" 2 sets into one by attaching
the root of one to the root of the other.
Returns the `DSNode` object representing the representative of
Returns the root object representing the representative of
the set resulted from the union of the sets containing `x` and `y`.
It returns None if `x` and `y` are already in the same set.
"Union by rank" consists of attaching the smaller tree
to the root of the larger tree.
Since it is the depth of the tree that affects the running time,
the tree with smaller depth gets added
under the root of the deeper tree,
the tree with smaller depth gets added under the root of the deeper tree,
which only increases the depth if the depths were equal.
In the context of this algorithm,
the term _rank_ is used instead of depth,
since it stops being equal to the depth
if path compression is also used.
In the context of this algorithm, the term _rank_ is used instead of depth,
since it stops being equal to the depth if path compression is also used.
The rank is an upper bound on the height of the node.
Expand All @@ -188,48 +247,52 @@ def union(self, x, y) -> DSNode:
&alpha; (n) is less than 5 for all remotely practical values of n.
Thus, the amortized running time per operation
is effectively a small constant."""
assert x in self.sets and y in self.sets
if x not in self._sets:
raise LookupError("x is not in self")
if y not in self._sets:
raise LookupError("y is not in self")

# What this algorithm actually needs in several places are the DSFNodes
# corresponding to the values x and y, so we look them up once
# and keep them in x_node and y_node respectively.
x = self.sets[x]
y = self.sets[y]
x_node = self._sets[x]
y_node = self._sets[y]

x_root = self.find(x)
y_root = self.find(y)
x_root = self._find(x_node)
y_root = self._find(y_node)

# x and y are already joined.
if x_root == y_root:
return

# Exchanging the next pointers of x and y.
# Exchanging the next pointers of x_node and y_node.
# This is needed in order to print the elements of a set in O(m) time,
# where m is the size of the same set.
# where m is the size of the same set, in self.print_set.
# Check here: http://stackoverflow.com/a/22945492/3924118.
x.next, y.next = y.next, x.next
x_node.next, y_node.next = y_node.next, x_node.next

self._n -= 1
assert 0 <= self.sets <= self.size

# x and y are not in the same set, therefore we merge them.
if x_root.rank < y_root.rank:
x_root.parent = y_root
return y_root
return y_root.value
else:
y_root.parent = x_root
if x_root.rank == y_root.rank:
x_root.rank += 1
return x_root
return x_root.value

def print_set(self, x) -> None:
assert x in self.sets
def print_set(self, x: object) -> None:
if x not in self._sets:
raise LookupError("x is not in self")

x = self.sets[x]
y = x
x_node = self._sets[x]
y = x_node

print("{0} -> {{{1}".format(x, x), end="")
while y.next != x:
print("{0} -> {{{1}".format(x_node, x_node), end="")
while y.next != x_node:
print(",", y.next, end="")
y = y.next
print("}")

def __str__(self):
return str(self.sets)
return str(self._sets)

0 comments on commit 2bb1dc8

Please sign in to comment.