In [1]:
from typing import List, TypeVar, Generic

T = TypeVar('T')

class QuickUnion(Generic[T]):
    """Weighted quick-union (disjoint-set) over the integer indices 0..length-1.

    `collection[i]` is the parent of element i; an element is a root when it
    is its own parent. `size[r]` is the number of elements in the tree rooted
    at r. Only root entries of `size` are kept current: once a root is
    attached under another root, its own `size` entry is no longer updated.

    The type parameter T is not used by the implementation; elements are
    always plain int indices.
    """

    def __init__(self, length: int):
        """Create `length` singleton sets, one per index 0..length-1."""
        self.collection = list(range(length))
        self.size = [1] * length

    def get_root(self, index: int) -> int:
        """Return the root of the set containing `index`.

        Compresses the path while walking up (path halving).
        """
        while index != self.collection[index]:
            # Path halving: point this node at its grandparent so later
            # lookups walk a shorter chain.
            self.collection[index] = self.collection[self.collection[index]]
            index = self.collection[index]

        return index

    def is_connected(self, a: int, b: int) -> bool:
        """Return True if `a` and `b` belong to the same set."""
        return self.get_root(a) == self.get_root(b)

    def union(self, a: int, b: int) -> None:
        """Merge the sets containing `a` and `b` (no-op if already merged).

        The smaller tree is attached under the larger tree's root, and its
        size is added to that root. On a tie, `b`'s root goes under `a`'s
        root.
        """
        a_root, b_root = self.get_root(a), self.get_root(b)

        if a_root != b_root:
            if self.size[a_root] < self.size[b_root]:
                self.size[b_root] += self.size[a_root]
                self.collection[a_root] = b_root
            else:
                self.size[a_root] += self.size[b_root]
                self.collection[b_root] = a_root

    def largest_component_size(self) -> int:
        """Return the number of elements in the largest set (0 when empty)."""
        # Non-root `size` entries are stale after `union`, so only roots count.
        return max(
            (self.size[i] for i, parent in enumerate(self.collection) if i == parent),
            default=0,
        )

    def __str__(self) -> str:
        return f"collection: {self.collection}\nsize: {self.size}"

In [2]:
# Neighbour offsets as [row delta, column delta]. Scanning every cell and
# looking only right and down reaches each horizontally or vertically
# adjacent pair of cells exactly once.
DIRECTIONS = [
    [0, 1],  # same row, next column
    [1, 0],  # next row, same column
]

In [3]:
def translate_to_index(grid: List[List[int]], x: int, y: int) -> int:
    """Flatten a (row, column) position into a row-major linear index.

    Row x is preceded by x full rows of `len(grid[0])` cells each, and y is
    the offset within that row.
    """
    return y + x * len(grid[0])

In [4]:
def is_in_grid(grid: List[List[int]], x: int, y: int) -> bool:
    """Return True if (x, y) is a valid (row, column) position in `grid`.

    x indexes rows (`grid[x]`) and y indexes columns (`grid[x][y]`). The
    column count is taken from the first row. An empty grid contains no
    positions.
    """
    # If the row check fails (always the case for an empty grid), `and`
    # short-circuits, so `grid[0]` is never evaluated on an empty grid.
    return 0 <= x < len(grid) and 0 <= y < len(grid[0])

In [5]:
def find_max_connected_cell_in_grid(grid: List[List[int]]) -> "QuickUnion[int]":
    """Group horizontally/vertically adjacent cells with equal values.

    Cell (x, y) corresponds to element `translate_to_index(grid, x, y)` of a
    QuickUnion with rows * columns elements; the column count is taken from
    the first row. Each cell is unioned with its right and down neighbours
    (DIRECTIONS) whenever the two values are equal, which covers every
    horizontally or vertically adjacent pair.

    Despite its name, this returns the union-find structure itself, not a
    number. The largest region's size is the largest `size` entry among root
    elements (elements that are their own parent in `collection`).

    An empty grid (`[]`) yields an empty QuickUnion instead of raising
    IndexError.
    """
    num_rows = len(grid)
    # Guard `grid[0]` so an empty grid produces an empty structure.
    num_cols = len(grid[0]) if grid else 0
    quick_union: QuickUnion[int] = QuickUnion(num_rows * num_cols)

    for x in range(num_rows):
        for y in range(num_cols):
            for dx, dy in DIRECTIONS:
                new_x, new_y = x + dx, y + dy

                if is_in_grid(grid, new_x, new_y) and grid[x][y] == grid[new_x][new_y]:
                    quick_union.union(
                        translate_to_index(grid, x, y),
                        translate_to_index(grid, new_x, new_y),
                    )

    return quick_union

In [6]:
# Sample 3x5 grid: touching cells (horizontally or vertically) that hold the
# same value end up in the same set.
quick_union = find_max_connected_cell_in_grid(
    [
        [0, 0, 2, 1, 1],
        [0, 2, 0, 1, 0],
        [0, 2, 0, 0, 0],
    ]
)

In [7]:
# Show the raw parent array (`collection`) and the `size` array.
print(str(quick_union))

collection: [0, 0, 2, 3, 3, 0, 6, 7, 3, 7, 0, 6, 7, 7, 9]
size: [4, 1, 1, 3, 1, 1, 2, 5, 1, 2, 1, 1, 1, 1, 1]
