# Union Find
[Disjoint Set Union (DSU)/Union-Find - A Complete Guide](https://leetcode.com/discuss/general-discussion/1072418/Disjoint-Set-Union-(DSU)Union-Find-A-Complete-Guide)

Disjoint sets: the intersection of any two sets is empty.

A disjoint set structure supports two operations:
1. Combine two given sets.
2. Report the connectivity of two elements, i.e. whether they are in the same set.


In [None]:
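# find: follow parent pointers up to the root (representative) of u's set
# assumes a list `parent` where parent[i] == i for every root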
def find(u: int) -> int:
    if (u == parent[u]):
        return u
    else:
        return find(parent[u])

# combine: merge the sets containing u and v by linking one root to the other
def combine(u:int, v:int):
    u = find(u)
    v = find(v)
    if (u==v):
        # already in the same set
        return
    else:
        parent[v] = u

In [1]:
# optimized union-find
# path compression: point every node visited by find directly at the root
def find(u: int) -> int:
    if (u != parent[u]):
        parent[u] = find(parent[u])
    return parent[u]

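# union by size: attach the root of the smaller set under the root of the larger one
# assumes a list `size` where size[root] is the number of elements in that set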
def combine(u:int, v:int):
    u = find(u)
    v = find(v)
    if (u == v):
        return
    else:
        if size[u] > size[v]:
            parent[v] = u
            size[u] += size[v]
        else:
            parent[u] = v
            size[v] += size[u]
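
The cells above rely on `parent` and `size` already existing. As a minimal sketch, assuming elements are labeled `0..n-1` (the value `n = 5` and the sequence of `combine` calls are made up for illustration), the initialization and a short run could look like this:

```python
n = 5                    # hypothetical number of elements, labeled 0..n-1
parent = list(range(n))  # every element starts as the root of its own set
size = [1] * n           # every set starts with exactly one element

def find(u: int) -> int:
    # path compression: re-point u (and its ancestors) directly at the root
    if u != parent[u]:
        parent[u] = find(parent[u])
    return parent[u]

def combine(u: int, v: int) -> None:
    u, v = find(u), find(v)
    if u == v:
        return  # already in the same set
    # union by size: the smaller set is attached under the larger root
    if size[u] > size[v]:
        parent[v] = u
        size[u] += size[v]
    else:
        parent[u] = v
        size[v] += size[u]

combine(0, 1)
combine(2, 3)
combine(1, 3)
print(find(0) == find(2))  # True: 0, 1, 2, 3 are now in one set
print(find(4) == find(0))  # False: 4 was never combined
print(size[find(0)])       # 4
```

Only `size[root]` is kept up to date, so a set's size should always be read through `find`.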


## [684\. Redundant Connection](https://leetcode.com/problems/redundant-connection/)

Return an edge that can be removed so that the resulting graph is a tree of n nodes.

If there are multiple answers, return the one that occurs last in the input.

### Solution 1. DFS

In [2]:
import collections
class Solution:
    def findRedundantConnection(self, edges):
        graph = collections.defaultdict(set)
        
        # check if we can connect from u to v
        def dfs(source, target):
            if source not in seen:
                seen.add(source)
                # reached the target: u and v are already connected, so the new edge is redundant
                if source == target:
                    return True

                for neighbor in graph[source]:
                    if dfs(neighbor, target):
                        return True
                return False
        
        for u, v in edges:
            # set of seen nodes
            seen = set()
            # both endpoints are already in the graph and connected: this edge closes a cycle
            if u in graph and v in graph and dfs(u,v):
                return [u, v]
            # add to graph
            graph[u].add(v)
            graph[v].add(u)
            

### Solution 2. Union Find

Union Find is a data structure that keeps track of elements split into one or more disjoint sets. It has two primary operations: "find" and "union".

[Video](https://youtu.be/wU6udHRIkcc)

#### Find

Find which set an element belongs to.

If two elements belong to different sets, perform a union.

If they belong to the same set, an edge between them closes a cycle.

#### Union

Merge two sets into one.

#### Union Find used in

* Kruskal's minimum spanning tree algorithm
* Grid percolation
* Network connectivity, detect cycle
* Lowest common ancestor in trees
* Image processing
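
To make the first application concrete, here is a rough sketch of Kruskal's algorithm built on union find. The function name `kruskal`, the `(weight, u, v)` edge format, and the sample graph are assumptions made for illustration. For brevity it uses path compression only, without union by size.

```python
from typing import List, Tuple

Edge = Tuple[int, int, int]  # (weight, u, v), an assumed format for this sketch

def kruskal(n: int, edges: List[Edge]) -> Tuple[int, List[Edge]]:
    """Return (total weight, chosen edges) of a minimum spanning forest
    for nodes labeled 0..n-1."""
    parent = list(range(n))

    def find(x: int) -> int:
        if x != parent[x]:
            parent[x] = find(parent[x])
        return parent[x]

    total, chosen = 0, []
    for w, u, v in sorted(edges):  # lightest edges first
        ru, rv = find(u), find(v)
        if ru != rv:
            # different sets: the edge joins two components, keep it
            parent[rv] = ru
            total += w
            chosen.append((w, u, v))
        # same set: the edge would close a cycle, skip it
    return total, chosen

edges = [(1, 0, 1), (4, 0, 2), (2, 1, 2), (7, 2, 3), (3, 1, 3)]
total, chosen = kruskal(4, edges)
print(total)   # 6
print(chosen)  # [(1, 0, 1), (2, 1, 2), (3, 1, 3)]
```

The `find(u) != find(v)` test is the same cycle check used for problem 684, applied to edges in weight order.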


#### Time Complexity

* Construction: $O(n)$
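* Union: $O(\alpha(n))$
* Find: $O(\alpha(n))$
* Get component size: $O(\alpha(n))$
* Check if connected: $O(\alpha(n))$
* Count components: $O(1)$

$\alpha(n)$ is the inverse Ackermann function. It grows so slowly that these bounds are effectively [amortized constant time](https://stackoverflow.com/questions/200384/constant-amortized-time), assuming path compression combined with union by size or rank.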
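
The operations in the list can be gathered into one small class. This is only a sketch: the class name `UnionFind` and the method names `connected` and `component_size` are illustrative choices, not taken from any library. Keeping a running `count` is what makes counting components $O(1)$.

```python
class UnionFind:
    def __init__(self, n: int) -> None:
        # construction is O(n): n singleton sets
        self.parent = list(range(n))
        self.size = [1] * n
        self.count = n  # number of components, maintained on every union

    def find(self, x: int) -> int:
        # path compression
        if x != self.parent[x]:
            self.parent[x] = self.find(self.parent[x])
        return self.parent[x]

    def union(self, x: int, y: int) -> bool:
        rx, ry = self.find(x), self.find(y)
        if rx == ry:
            return False  # already connected, nothing merged
        if self.size[rx] < self.size[ry]:
            rx, ry = ry, rx  # make rx the root of the larger set
        self.parent[ry] = rx
        self.size[rx] += self.size[ry]
        self.count -= 1
        return True

    def connected(self, x: int, y: int) -> bool:
        return self.find(x) == self.find(y)

    def component_size(self, x: int) -> int:
        return self.size[self.find(x)]

uf = UnionFind(6)
uf.union(0, 1)
uf.union(1, 2)
uf.union(3, 4)
print(uf.connected(0, 2))    # True
print(uf.connected(0, 3))    # False
print(uf.component_size(1))  # 3
print(uf.count)              # 3 components: {0, 1, 2}, {3, 4}, {5}
```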

In [None]:
from typing import List

def findRedundantConnection(edges: List[List[int]]) -> List[int]:
    # n nodes labeled 1..n and exactly n edges, so index 0 is unused
    parent = [i for i in range(len(edges) + 1)]

    def find(x: int) -> int:
        # recursively trace back to the ultimate parent, compressing the path
        if x != parent[x]:
            parent[x] = find(parent[x])
        return parent[x]
    
    def union(x: int, y: int) -> None:
        # merge two sets by assigning them the same ultimate parent
        parent[find(y)] = find(x)
    
    for a, b in edges:
        if find(a) == find(b):
            # two nodes in the same set, cycle!
            return [a, b]
        else:
            # two nodes not in the same set, union
            union(a, b)

In [None]:
# 128. Longest Consecutive Sequence
import collections
from typing import List

def longestConsecutive(nums: List[int]) -> int:
    def find(i):
        if i != parent[i]:
            parent[i] = find(parent[i])
        return parent[i]
    def union(i, j):
        pi, pj = find(i), find(j)
        if pi != pj:
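            if rank[pi] < rank[pj]:
                pi, pj = pj, pi  # make pi the root with the larger rank
            parent[pj] = pi
            if rank[pi] == rank[pj]:
                rank[pi] += 1  # equal ranks: the merged tree grows by one level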

    if not nums:
        return 0

    # initialize parent and rank
    parent, rank, nums = {}, {}, set(nums)
    for num in nums:
        parent[num] = num
        rank[num] = 0
    # union each number with its consecutive neighbors
    for num in nums:
        if num - 1 in nums:
            union(num-1, num)
        if num + 1 in nums:
            union(num+1, num)

    d = collections.defaultdict(list)
    for num in nums:
        d[find(num)].append(num)
    return max([len(l) for l in d.values()])

In [None]:
# 200. Number of Islands
from typing import List

def numIslands(grid: List[List[str]]) -> int:
    if len(grid) == 0:
        return 0
    row = len(grid)
    col = len(grid[0])
    count = sum(grid[i][j] == '1' for i in range(row) for j in range(col))
    parent = [i for i in range(row*col)] # turn grid to 1-d union

    def find(x):
        if parent[x] != x:
            parent[x] = find(parent[x])
        return parent[x]
    
    def union(x, y):
        nonlocal count  # update the island counter of the enclosing function
        xRoot, yRoot = find(x), find(y)
        if xRoot == yRoot:
            return
        parent[xRoot] = yRoot
        count -= 1  # merging two land components leaves one fewer island
    
    for i in range(row):
        for j in range(col):
            if grid[i][j] == '0':
                continue
            index = i*col + j
            if j < col - 1 and grid[i][j+1] == '1':
                union(index, index+1) # union with the land cell to the right
            if i < row - 1 and grid[i+1][j] == '1':
                union(index, index+col) # union with the land cell below
    return count
