# Applications of Disjoint Sets

If you have a set of N elements partitioned into disjoint subsets and need to keep track of which subset each element belongs to, or whether two elements (or two subsets) are connected with each other, you can manage that connectivity easily with disjoint sets and the union find operations.

[William Fiset's Union Find Video](https://www.youtube.com/watch?v=0jNmHPfA_yE)
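
To make the idea concrete, here is a minimal sketch of a disjoint set over the elements `0..n-1`. The class name `DisjointSet` and its method names are chosen here for illustration and are not taken from the video; it uses path compression and union by size, which is one common way to keep both operations fast.

```python
class DisjointSet:
    """Illustrative union find over elements 0..n-1 (names are assumptions)."""

    def __init__(self, n):
        self.parent = list(range(n))  # every element starts as its own root
        self.size = [1] * n           # number of elements in each root's tree

    def find(self, x):
        # First pass: walk up to the root of x's subset.
        root = x
        while self.parent[root] != root:
            root = self.parent[root]
        # Second pass: point every node on the path directly at the root.
        while self.parent[x] != root:
            self.parent[x], x = root, self.parent[x]
        return root

    def union(self, a, b):
        # Merge the subsets containing a and b; return False if already merged.
        ra, rb = self.find(a), self.find(b)
        if ra == rb:
            return False
        if self.size[ra] < self.size[rb]:
            ra, rb = rb, ra           # attach the smaller tree under the larger
        self.parent[rb] = ra
        self.size[ra] += self.size[rb]
        return True

    def connected(self, a, b):
        return self.find(a) == self.find(b)


ds = DisjointSet(5)
ds.union(0, 1)
ds.union(3, 4)
print(ds.connected(0, 1), ds.connected(1, 3))
```

The examples below use smaller, problem-specific variations of the same `find` and `union` operations.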

## Examples

### Redundant Connection

#### Problem Statement

In this problem, a tree is an undirected graph that is connected and has no cycles.

You are given a graph that started as a tree with `n` nodes labeled from 1 to `n`, with one additional edge added. The added edge has two different vertices chosen from 1 to `n`, and was not an edge that already existed. The graph is represented as an array edges of length `n` where `edges[i] = [ai, bi]` indicates that there is an edge between nodes ai and bi in the graph.

Return an edge that can be removed so that the resulting graph is a tree of n nodes. If there are multiple answers, return the answer that occurs last in the input.

#### Solution

Near-linear runtime (path compression keeps each find cheap), O(n) space. Apply union find over the edges in order. If both endpoints of the current edge are already in the same set, that edge closes the cycle, so return it as the redundant edge.

In [7]:
def find_redundant_connection(edges):
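    # Each node starts as its own root. Using 0 as a "no parent" sentinel
    # would clash with node index 0, so roots point to themselves instead.
    parent = list(range(len(edges)))
    def find(x):
        # Follow parent links to the root, compressing the path on the way back.
        if parent[x] != x:
            parent[x] = find(parent[x])
        return parent[x]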
    def union(x, y):
        root_x = find(x)
        root_y = find(y)
        if root_x == root_y:
            return False
        parent[root_x] = root_y
        return True
    for x, y in edges:
        if not union(x - 1, y - 1):
            return [x, y]
    return [-1, -1]

In [9]:
edges, ans = [[1,2],[1,3],[2,3]], [2,3]
result = find_redundant_connection(edges)
print(result, result == ans)

[2, 3] True


### Find the Duplicate Number

#### Problem Statement

You are given an array of integers `nums` containing `n + 1` integers where each integer is in the range `[1, n]` inclusive.

There is only one repeated number in `nums`; return this repeated number.

You must solve the problem without modifying the array `nums` and using only constant extra space.

[LeetCode #287 (medium)](https://leetcode.com/problems/find-the-duplicate-number/)

#### Solution

O(n) runtime, O(1) space. This solution marks each value `v` as visited by negating the entry at index `v - 1`. If that entry is already negative, `v` has been seen before and is the duplicate. This is a flattened way of performing union find.

However, it modifies the input. If the input can't be modified, use Floyd's cycle detection algorithm (tortoise and hare) instead.

In [12]:
def find_duplicate(nums):
    for num in nums:
        if nums[abs(num) - 1] < 0:
            return abs(num)
        nums[abs(num) - 1] *= -1
    return -1

In [13]:
nums, ans = [3,1,3,4,2], 3
result = find_duplicate(nums)
print(result, result == ans)

3 True
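
When the input must stay unmodified, Floyd's tortoise and hare cycle detection is one option. The sketch below treats each index `i` as a node pointing to `nums[i]`; because one value repeats, two indices point to the same node, and that node is the entry of a cycle. The function name `find_duplicate_floyd` is chosen here for illustration, and the sketch assumes the input satisfies the problem constraints (values in `[1, n]`, exactly one repeated value).

```python
def find_duplicate_floyd(nums):
    # Phase 1: advance slow by one step and fast by two until they meet
    # somewhere inside the cycle.
    slow = fast = nums[0]
    while True:
        slow = nums[slow]
        fast = nums[nums[fast]]
        if slow == fast:
            break
    # Phase 2: restart slow from the beginning; moving both pointers one
    # step at a time, they meet at the cycle entry, which is the duplicate.
    slow = nums[0]
    while slow != fast:
        slow = nums[slow]
        fast = nums[fast]
    return slow


nums = [3, 1, 3, 4, 2]
print(find_duplicate_floyd(nums), nums)
```

This keeps the O(n) runtime and O(1) space of the negation approach while leaving `nums` untouched.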


### Number of Islands

#### Problem Statement

You are given an empty 2D binary grid `grid` of size `m x n`. The grid represents a map where 0's represent water and 1's represent land. Initially, all the cells of `grid` are water cells (i.e., all the cells are 0's).

We may perform an add land operation which turns the water at a given position into land. You are given an array `positions` where `positions[i] = [ri, ci]` is the position `(ri, ci)` at which we should perform the ith operation.

Return an array of integers answer where `answer[i]` is the number of islands after turning the cell `(ri, ci)` into a land.

An island is surrounded by water and is formed by connecting adjacent lands horizontally or vertically. You may assume all four edges of the grid are all surrounded by water.

[LeetCode #305 (hard)](https://leetcode.com/problems/number-of-islands-ii/)

#### Solution

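For `k` positions on an `m x n` grid, roughly O(m * n + k) runtime (O(m * n) to initialize the arrays, then near-constant amortized time per union find operation) and O(m * n) space. This approach starts by implementing a `UnionFind` class which performs all the basic operations of disjoint sets. It uses arrays sized to the whole grid to flatten the matrix into a disjoint set. For large, sparse matrices, it is possible to use a dictionary instead, so that space grows with the number of land cells rather than the grid size. The solution below uses the array-based version.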

In [15]:
class UnionFind:
    def __init__(self, n):
        self.count = 0
        self.parent = [-1]*n
        self.rank = [0]*n
    def is_valid(self, i):
        return self.parent[i] >= 0
    def set_parent(self, i):
        if self.parent[i] == -1:
            self.parent[i] = i
            self.count += 1
    def find(self, i):
        if self.parent[i] != i:
            self.parent[i] = self.find(self.parent[i])
        return self.parent[i]
    def union(self, x, y):
        root_x = self.find(x)
        root_y = self.find(y)
        if root_x != root_y:
            if self.rank[root_x] > self.rank[root_y]:
                self.parent[root_y] = root_x
            elif self.rank[root_x] < self.rank[root_y]:
                self.parent[root_x] = root_y
            else:
                self.parent[root_y] = root_x
                self.rank[root_x] += 1
            self.count -= 1
    def get_count(self):
        return self.count

Once you have the `UnionFind` class down, instantiate it with size `m * n`, which flattens the matrix into a disjoint set where cell `(r, c)` maps to index `r * n + c`. Iterate over the provided positions and, for each one, collect the in-bounds neighbors that already belong to a disjoint set (that is, neighbors that are already land). Then apply the `set_parent` operation to the flattened position in the matrix, and apply the `union` operation between that position and each overlapping neighbor.

In [17]:
def num_islands(m, n, positions):
    ans = []
    uf = UnionFind(m * n)
    for pos in positions:
        r, c = pos[0], pos[1]
        overlap = []
        if r - 1 >= 0 and uf.is_valid((r - 1) * n + c):
            overlap.append((r - 1) * n + c)
        if r + 1 < m and uf.is_valid((r+1) * n + c):
            overlap.append((r + 1) * n + c)
        if c - 1 >= 0 and uf.is_valid(r * n + c - 1):
            overlap.append(r * n + c - 1)
        if c + 1 < n and uf.is_valid(r * n + c + 1):
            overlap.append(r * n + c + 1)
        idx = r * n + c
        uf.set_parent(idx)
        for i in overlap:
            uf.union(i, idx)
        ans.append(uf.get_count())
    return ans

In [18]:
m, n, positions, ans = 3, 3, [[0,0],[0,1],[1,2],[2,1]], [1,1,2,3]
result = num_islands(m, n, positions)
print(result, result == ans)

[1, 1, 2, 3] True
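
The dictionary-based variant mentioned in the solution notes could look roughly like the sketch below. `SparseUnionFind` and `num_islands_sparse` are names chosen here for illustration. As a simplification, this version tests key membership rather than calling `dict.get` with a default value, so it is one possible reading of that idea rather than the only one. Only cells that have become land get an entry.

```python
class SparseUnionFind:
    """Illustrative dictionary-backed union find; stores land cells only."""

    def __init__(self):
        self.parent = {}
        self.rank = {}
        self.count = 0

    def is_valid(self, i):
        return i in self.parent          # a cell is land once it has an entry

    def set_parent(self, i):
        if i not in self.parent:         # ignore repeated positions
            self.parent[i] = i
            self.rank[i] = 0
            self.count += 1

    def find(self, i):
        if self.parent[i] != i:
            self.parent[i] = self.find(self.parent[i])  # path compression
        return self.parent[i]

    def union(self, x, y):
        root_x, root_y = self.find(x), self.find(y)
        if root_x == root_y:
            return
        if self.rank[root_x] < self.rank[root_y]:
            root_x, root_y = root_y, root_x  # keep the higher rank as root
        self.parent[root_y] = root_x
        if self.rank[root_x] == self.rank[root_y]:
            self.rank[root_x] += 1
        self.count -= 1


def num_islands_sparse(m, n, positions):
    uf, ans = SparseUnionFind(), []
    for r, c in positions:
        idx = r * n + c
        uf.set_parent(idx)
        # Merge with each in-bounds neighbor that is already land.
        for dr, dc in ((-1, 0), (1, 0), (0, -1), (0, 1)):
            nr, nc = r + dr, c + dc
            if 0 <= nr < m and 0 <= nc < n and uf.is_valid(nr * n + nc):
                uf.union(nr * n + nc, idx)
        ans.append(uf.count)
    return ans


print(num_islands_sparse(3, 3, [[0, 0], [0, 1], [1, 2], [2, 1]]))
```

The trade-off is a little extra overhead per lookup in exchange for not allocating two arrays of length `m * n` up front.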


### Skyline Silhouette

#### Problem Statement

A city's skyline is the outer contour of the silhouette formed by all the buildings in that city when viewed from a distance. Given the locations and heights of all the buildings, return the skyline formed by these buildings collectively.

The geometric information of each building is given in the array buildings where `buildings[i] = [lefti, righti, heighti]`:

- `lefti` is the x coordinate of the left edge of the ith building.
- `righti` is the x coordinate of the right edge of the ith building.
- `heighti` is the height of the ith building.

You may assume all buildings are perfect rectangles grounded on an absolutely flat surface at height 0.

The skyline should be represented as a list of "key points" sorted by their x-coordinate in the form `[[x1,y1],[x2,y2],...]`. Each key point is the left endpoint of some horizontal segment in the skyline except the last point in the list, which always has a y-coordinate 0 and is used to mark the skyline's termination where the rightmost building ends. Any ground between the leftmost and rightmost buildings should be part of the skyline's contour.

Note: There must be no consecutive horizontal lines of equal height in the output skyline. For instance, `[...,[2 3],[4 5],[7 5],[11 5],[12 7],...]` is not acceptable; the three lines of height 5 should be merged into one in the final output as such: `[...,[2 3],[4 5],[12 7],...]`.

[LeetCode #218 (hard)](https://leetcode.com/problems/the-skyline-problem/)

#### Solution

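O(n log n) runtime, O(n) space. Both the sort over the edge positions and the heap operations contribute O(n log n): each building is pushed onto and popped from the heap at most once, at O(log n) apiece. The code assumes `buildings` is already sorted by left edge, as the problem guarantees.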

[Visualization/Explanation](https://briangordon.github.io/2014/08/the-skyline-problem.html)

In [19]:
from heapq import heappop, heappush

def get_skyline(buildings):
    def add_height(pos, hei):
        if heights[-1][1] != hei:
            heights.append([pos, hei])
    positions = set([b[0] for b in buildings] + [b[1] for b in buildings])
    i, curr, heights = 0, [], [[-1, 0]]
    for t in sorted(positions):
        while i < len(buildings) and buildings[i][0] <= t:
            heappush(curr, (-buildings[i][2], buildings[i][1]))
            i += 1
        while curr and curr[0][1] <= t:
            heappop(curr)
        new_height = -curr[0][0] if curr else 0
        add_height(t, new_height)
    return heights[1:]

In [20]:
buildings, ans = [[2,9,10],[3,7,15],[5,12,12],[15,20,10],[19,24,8]], [[2,10],[3,15],[7,12],[12,0],[15,10],[20,8],[24,0]]
result = get_skyline(buildings)
print(result, result == ans)

[[2, 10], [3, 15], [7, 12], [12, 0], [15, 10], [20, 8], [24, 0]] True
