In [1]:
# will be shared between solutions
class Util:
    """Helpers shared by the solution classes below."""

    def validTree(self, n, edges):
        """Edge-count pre-check only.

        Returns ``False`` when the number of edges is not ``n - 1``;
        otherwise returns ``None``. Subclasses override this method with
        the full check.
        """
        if len(edges) == n - 1:
            return None
        return False

    def arr_to_alist(self, edges):
        """Turn a list of ``[a, b]`` pairs into an undirected adjacency map.

        Each edge is recorded in both directions, so the result maps every
        endpoint to the set of its neighbours. Nodes that appear in no edge
        get no entry.
        """
        adjacency = {}
        for src, dst in edges:
            adjacency.setdefault(src, set()).add(dst)
            adjacency.setdefault(dst, set()).add(src)
        return adjacency

# Detecting cycles in an undirected graph 
### Method 1 (deleting connections)

On an undirected graph, like the one we're working with here, a naive cycle check will detect trivial "cycles". For example, if there's an undirected edge between node A and node B, a detected cycle will include A → B → A. This is because an undirected edge is stored as two directed entries in the adjacency list (A → B and B → A), and so forms a trivial cycle.

There are several strategies for detecting whether or not an undirected graph contains cycles while excluding the trivial ones. The first is to delete the opposite-direction edges from the adjacency list. In other words, when we follow an edge A → B, we look up B's adjacency list and delete A from it, effectively removing the opposite edge B → A.

In [2]:
class Solution(Util):
    def validTree(self, n, edges):
        """Return True if the undirected graph with ``n`` nodes is a tree.

        Strategy: iterative DFS that deletes the reverse direction of every
        edge it follows, so stepping straight back to the node we came from
        is never mistaken for a cycle.

        n     -- number of nodes in the graph
        edges -- list of ``[a, b]`` pairs, each an undirected edge
        """
        # A tree on n nodes has exactly n - 1 edges.
        if len(edges) != n - 1:
            return False

        def dfs(alist):
            """Walk the component of the first node in ``alist``.

            Returns False as soon as an already-visited node is reached
            again, and otherwise True only if all ``n`` nodes were reached.
            Mutates ``alist`` while walking.
            """
            start = next(iter(alist))
            stack = [start]
            # The start node is marked as visited up front; otherwise an
            # edge leading back to it would not be reported as a cycle.
            visited = {start}
            while stack:
                node = stack.pop()
                for neighbour in alist[node]:
                    if neighbour in visited:
                        return False
                    visited.add(neighbour)
                    # Remove the reverse entry neighbour -> node so it is not
                    # seen as a cycle when neighbour is processed later.
                    alist[neighbour].remove(node)
                    stack.append(neighbour)
            # An acyclic start component is not enough: with n - 1 edges a
            # graph can still consist of a small tree plus a separate cyclic
            # component, so every node must have been reached.
            return len(visited) == n

        alist = self.arr_to_alist(edges)
        # No edges at this point means n == 1: a single node is a tree.
        if not alist:
            return True
        return dfs(alist)
        
# Smoke checks for this cell; the recorded output is False, False, True.
print(Solution().validTree(5, [[0,1],[0,4],[2,3]])) # no cycle, but disconnected (3 edges for 5 nodes) -> False
print(Solution().validTree(5, [[0,1],[0,4],[1,4],[2,3]])) # has a cycle and disconnected components -> False
print(Solution().validTree(5, [[0,1],[0,2],[0,3],[1,4]])) # connected, no cycle -> True

False
False
True


### Method 2 (storing parent data)

The second strategy is, instead of using a seen set, to use a seen map that also keeps track of the "parent" node from which we reached each node. We'll call this map parent. Then, when we iterate through the neighbours of a node, we ignore the "parent" node, as otherwise it would be detected as a trivial cycle (and we know that the parent node has already been visited by this point anyway). The starting node (the first node in the adjacency list in this implementation) has no "parent", so we record its parent as -1.

In [3]:
class Solution(Util):
    def validTree(self, n, edges):
        """Return True if the undirected graph with ``n`` nodes is a tree.

        Strategy: iterative DFS that remembers, for every reached node, the
        node it was reached from. The edge back to that parent is skipped, so
        it is not mistaken for a cycle.

        n     -- number of nodes in the graph
        edges -- list of ``[a, b]`` pairs, each an undirected edge
        """
        # A tree on n nodes has exactly n - 1 edges.
        if len(edges) != n - 1:
            return False

        def dfs(alist):
            """Walk the component of the first node in ``alist``.

            Returns False as soon as a non-parent neighbour has already been
            reached, and otherwise True only if all ``n`` nodes were reached.
            """
            starting_node = next(iter(alist))
            # parent doubles as the "seen" map: a node is a key exactly when
            # it has been reached. The starting node has no parent, so -1.
            parent = {starting_node: -1}
            stack = [starting_node]
            while stack:
                node = stack.pop()
                for neighbour in alist[node]:
                    # Going straight back where we came from is not a cycle.
                    if neighbour == parent[node]:
                        continue
                    if neighbour in parent:
                        return False
                    parent[neighbour] = node
                    stack.append(neighbour)
            # An acyclic start component is not enough: with n - 1 edges a
            # graph can still consist of a small tree plus a separate cyclic
            # component, so every node must have been reached.
            return len(parent) == n

        alist = self.arr_to_alist(edges)
        # No edges at this point means n == 1: a single node is a tree.
        if not alist:
            return True
        return dfs(alist)
        
# Smoke checks for this cell; the recorded output is False, False, True.
print(Solution().validTree(5, [[0,1],[0,4],[2,3]])) # no cycle, but disconnected (3 edges for 5 nodes) -> False
print(Solution().validTree(5, [[0,1],[0,4],[1,4],[2,3]])) # has a cycle and disconnected components -> False
print(Solution().validTree(5, [[0,1],[0,2],[0,3],[1,4]])) # connected, no cycle -> True

False
False
True


### Method 3 (Union Find - Not optimised)

In [4]:
class Solution(Util):
    def validTree(self, n, edges):
        """Return True if the undirected graph on nodes ``0..n-1`` is a tree.

        Plain union-find without path compression or balancing: every edge
        must join two different sets, otherwise it closes a cycle.
        """
        # With exactly n - 1 edges, "no cycle" also implies "connected".
        if len(edges) != n - 1:
            return False

        # roots[i] is the parent of i; a node that is its own parent is the
        # representative of its set. Every node starts in its own set.
        roots = list(range(n))

        def root_of(node):
            """Follow parent links until the set representative is reached."""
            while roots[node] != node:
                node = roots[node]
            return node

        for u, v in edges:
            rep_u, rep_v = root_of(u), root_of(v)
            # Both endpoints already in the same set: this edge is a cycle.
            if rep_u == rep_v:
                return False
            roots[rep_u] = rep_v
        return True

# Smoke checks for this cell; the recorded output is False, False, True.
print(Solution().validTree(5, [[0,1],[0,4],[2,3]])) # no cycle, but disconnected (3 edges for 5 nodes) -> False
print(Solution().validTree(5, [[0,1],[0,4],[1,4],[2,3]])) # has a cycle and disconnected components -> False
print(Solution().validTree(5, [[0,1],[0,2],[0,3],[1,4]])) # connected, no cycle -> True

False
False
True


### Method 4 (Union Find - Optimised)

In [5]:
class Solution(Util):
    def validTree(self, n, edges):
        """Return True if the undirected graph on nodes ``0..n-1`` is a tree.

        Union-find with path compression and union by size: every edge must
        join two different sets, otherwise it closes a cycle.

        n     -- number of nodes in the graph
        edges -- list of ``[a, b]`` pairs, each an undirected edge
        """
        # With exactly n - 1 edges, "no cycle" also implies "connected".
        if len(edges) != n - 1:
            return False

        def make_set(n):
            """Create n singleton sets: each node is its own root, size 1."""
            parent = list(range(n))
            size = [1] * n
            return parent, size

        # optimisation 1 - path compression
        # Step 1: Find the root.
        # Step 2: Do a second traversal, this time setting each node to point
        # directly at the root as we go.
        def find(parent, a):
            root = a
            while root != parent[root]:
                root = parent[root]
            while a != root:
                old_parent = parent[a]
                parent[a] = root
                a = old_parent
            return root

        # optimisation 2 - union by size
        # attach the root of the smaller set under the root of the bigger set
        def union(a, b, parent, size):
            """Merge the sets of a and b; return False if already merged."""
            root_a = find(parent, a)
            root_b = find(parent, b)
            if root_a == root_b:
                return False
            if size[root_a] > size[root_b]:
                parent[root_b] = root_a
                size[root_a] += size[root_b]
            else:
                parent[root_a] = root_b
                # Add the absorbed set's size (not the root's label), so
                # later size comparisons stay meaningful.
                size[root_b] += size[root_a]
            return True

        parent, size = make_set(n)
        for a, b in edges:
            # A failed union means a and b were already connected: a cycle.
            if not union(a, b, parent, size):
                return False
        return True

# Smoke checks for this cell; the recorded output is False, False, True.
print(Solution().validTree(5, [[0,1],[0,4],[2,3]])) # no cycle, but disconnected (3 edges for 5 nodes) -> False
print(Solution().validTree(5, [[0,1],[0,4],[1,4],[2,3]])) # has a cycle and disconnected components -> False
print(Solution().validTree(5, [[0,1],[0,2],[0,3],[1,4]])) # connected, no cycle -> True

False
False
True
