In [1]:
# LeetCode link: https://leetcode.com/explore/featured/card/graph/618/disjoint-set/3846/

![Graph Valid Tree](./graph_valid_tree.png)

In [9]:
# Approach 3 (optimized): disjoint set with path compression and union by rank
# UnionFind class
class UnionFind:
    """Disjoint-set (union-find) over the integers 0 .. size - 1.

    `find` uses path compression and `union` uses union by rank.
    """

    def __init__(self, size):
        # root[i] is the parent of i; every element starts as its own root.
        self.root = [i for i in range(size)]
        # rank[r] is an upper bound on the height of the tree rooted at r, i.e., the
        # "rank" of each vertex. Path compression in `find` can shorten a tree
        # without lowering its rank, hence "upper bound".
        # The initial rank of each vertex is 1, because each one starts as a
        # standalone vertex with no connection to other vertices.
        self.rank = [1] * size

    def find(self, x):
        """Return the root of x's set, re-pointing every node on the path at that root."""
        if self.root[x] == x:
            return x
        # Path compression: after the recursive call, x hangs directly off the root.
        self.root[x] = self.find(self.root[x])
        return self.root[x]

    def union(self, x, y):
        """Merge the sets containing x and y using union by rank.

        Returns:
            True if x and y were in different sets and those sets were merged;
            False if they were already in the same set (callers use this to
            detect an edge that would close a cycle).
        """
        rootX = self.find(x)
        rootY = self.find(y)
        if rootX == rootY:
            # Already connected: nothing to merge. Return an explicit False so the
            # documented bool contract holds (falling through would return None).
            return False
        # Attach the lower-rank tree under the higher-rank root to keep trees shallow.
        if self.rank[rootX] > self.rank[rootY]:
            self.root[rootY] = rootX
        elif self.rank[rootX] < self.rank[rootY]:
            self.root[rootX] = rootY
        else:
            # Equal ranks: pick rootX as the new root; its tree grows by one level.
            self.root[rootY] = rootX
            self.rank[rootX] += 1
        return True
 

In [12]:
class Solution:
    def validTree(self, n: int, edges) -> bool:
        """Decide whether `edges` connect nodes 0 .. n - 1 into a single tree.

        Args:
            n: number of nodes, labelled 0 .. n - 1.
            edges: sequence of [a, b] pairs, each an undirected edge.

        Returns:
            True if the graph is a tree, otherwise False.
        """
        # A tree on n nodes always has exactly n - 1 edges; any other count rules it out.
        edge_count_ok = len(edges) == n - 1
        if not edge_count_ok:
            return False

        # With exactly n - 1 edges, the graph is a tree iff no edge closes a cycle.
        # A falsy union result means both endpoints already shared a root, i.e. a
        # cycle; all() stops at the first such edge.
        components = UnionFind(n)
        return all(components.union(a, b) for a, b in edges)
                

In [14]:
# Instantiate the solver once; the example cell below reuses it.
solution = Solution()

In [19]:
# Example: 5 nodes, 4 edges, connected and acyclic -> a valid tree.
n = 5
edges = [[0,1],[0,2],[0,3],[1,4]]
assert solution.validTree(n, edges) is True

# Counter-example: n - 1 edges, but 0-1-2 forms a cycle and node 3 is isolated -> not a tree.
assert solution.validTree(4, [[0, 1], [1, 2], [2, 0]]) is False

solution.validTree(n, edges)

True