## Problem 1

In [2]:
def read_input(filename, input_dir='../inputs'):
    """Read a puzzle input grid as a list of rows of single characters.

    Args:
        filename: base name of the input file, without the ``.txt`` extension.
        input_dir: directory holding the input files. Defaults to
            ``'../inputs'`` (relative to the current working directory), which
            reproduces the previously hard-coded path.

    Returns:
        One list per line of the file, holding the characters of that line
        after surrounding whitespace (including the newline) is stripped.
    """
    # The context manager closes the file even if reading raises.
    with open(f'{input_dir}/{filename}.txt', 'r') as f:
        return [list(line.strip()) for line in f]

In [121]:
def list_connected_comp_rec(graph, node, seen):
    """Collect the nodes reachable from ``node`` that are not already in ``seen``.

    Despite the ``_rec`` suffix (kept for compatibility), this walks the graph
    iteratively with an explicit stack of neighbour iterators. Large
    components therefore no longer hit Python's recursion limit. The visit
    order is the same depth-first preorder the recursive version produced.

    Args:
        graph: adjacency mapping ``node -> iterable of neighbours``. Nodes
            absent from the mapping are treated as having no neighbours.
        node: start node. It is visited unconditionally.
        seen: dict used as a visited set. Every visited node is stored with
            the value ``True`` (mutated in place).

    Returns:
        The newly visited nodes in depth-first preorder, starting with ``node``.
    """
    seen[node] = True
    order = [node]
    # Each stack entry is the still-unconsumed part of a node's neighbour list,
    # which mirrors the suspended loop of a recursive call frame.
    stack = [iter(graph.get(node, ()))]
    while stack:
        for neighbor in stack[-1]:
            if neighbor not in seen:
                seen[neighbor] = True
                order.append(neighbor)
                stack.append(iter(graph.get(neighbor, ())))
                break  # descend first; resume this node's neighbours later
        else:
            stack.pop()  # all neighbours handled: "return" to the parent
    return order
        
def list_connected_comp(graph, node):
    """Return every node reachable from ``node`` in ``graph`` (``node`` first).

    Entry point for ``list_connected_comp_rec``. It supplies a fresh
    visited-dict so separate calls never share traversal state.
    """
    return list_connected_comp_rec(graph, node, seen={})

In [122]:
class Garden:
    """Grid of garden plots where every cell knows which region it belongs to.

    Two cells are in the same region when they hold the same letter and are
    connected through left/right/up/down neighbours of that letter. A region
    is identified by its "first plot": the smallest ``(row, col)`` position
    among its cells, i.e. its first cell in row-major order.
    """

    def __init__(self, mtx):
        """Build the garden and label every cell with its region's first plot.

        Args:
            mtx: list of rows of plot letters. Every row is assumed to be as
                long as the first one. An empty matrix yields an empty garden.
        """
        self.plot_mtx = mtx
        self.n = len(mtx)
        # Guard: mtx[0] would raise IndexError for an empty matrix.
        self.m = len(mtx[0]) if mtx else 0
        self.__first_plots = []
        self.__find_first_plots_of_regions()

    def get_plot(self, i, j):
        """Return the letter at ``(i, j)``, or ``None`` when it lies off the grid."""
        if i < 0 or i >= self.n or j < 0 or j >= self.m:
            return None
        return self.plot_mtx[i][j]

    def __add_to_first_plot_graph(self, first_plot_graph, pos1, pos2):
        """Add an undirected edge between two provisional labels (self-loops skipped)."""
        if pos1 == pos2:
            return
        first_plot_graph.setdefault(pos1, []).append(pos2)
        first_plot_graph.setdefault(pos2, []).append(pos1)

    def __find_component(self, first_plot_graph, pos, components_already_found):
        """Return the smallest label connected to ``pos`` in ``first_plot_graph``.

        The answer is memoised in ``components_already_found`` for EVERY label
        of the component, including the smallest one. Previously the smallest
        label was not cached, so each cell carrying it re-traversed the whole
        component.
        """
        if pos in components_already_found:
            return components_already_found[pos]

        component = list_connected_comp(first_plot_graph, pos)
        first = min(component)
        for node in component:
            components_already_found[node] = first
        return first

    def __find_first_plots_of_regions(self):
        """Compute the first plot of the region containing each cell.

        First pass (row-major): a cell inherits the provisional label of a
        same-letter neighbour to its left, else of one above, else it starts a
        new label at its own position. When both the left and the upper
        neighbours match, their labels are linked in a graph, because they
        belong to the same region. Second pass: each provisional label is
        replaced by the smallest label in its connected component.
        """
        temp_first_plots = []
        first_plot_graph = {}
        for i in range(self.n):
            row_labels = []
            temp_first_plots.append(row_labels)
            for j in range(self.m):
                current_plot = self.get_plot(i, j)
                same_left = self.get_plot(i, j - 1) == current_plot
                same_up = self.get_plot(i - 1, j) == current_plot
                if same_left:
                    row_labels.append(row_labels[j - 1])
                    if same_up:
                        # Left and upper cells are in one region: merge their labels.
                        self.__add_to_first_plot_graph(
                            first_plot_graph, row_labels[j], temp_first_plots[i - 1][j]
                        )
                elif same_up:
                    row_labels.append(temp_first_plots[i - 1][j])
                else:
                    row_labels.append((i, j))

        components_already_found = {}
        self.__first_plots = [
            [
                self.__find_component(first_plot_graph, label, components_already_found)
                for label in row
            ]
            for row in temp_first_plots
        ]

    def first_plot_of_region(self, i, j):
        """Return the first plot ``(row, col)`` of the region containing ``(i, j)``."""
        return self.__first_plots[i][j]

In [88]:
def solve1(input_filename):
    """Return the part-1 fencing price: the sum of ``area * perimeter`` over regions.

    Args:
        input_filename: base name of the puzzle input, passed to ``read_input``.
    """
    garden = Garden(read_input(input_filename))

    area = {}
    perim = {}
    for i in range(garden.n):
        for j in range(garden.m):
            plot = garden.get_plot(i, j)
            region = garden.first_plot_of_region(i, j)
            area[region] = area.get(region, 0) + 1
            # Every side facing a different letter, or the grid edge (where
            # get_plot returns None), needs one unit of fence.
            exposed = sum(
                garden.get_plot(i + di, j + dj) != plot
                for di, dj in ((0, -1), (0, 1), (-1, 0), (1, 0))
            )
            perim[region] = perim.get(region, 0) + exposed

    return sum(area[region] * perim[region] for region in area)

## Problem 2

In [134]:
def solve2(input_filename):
    """Return the part-2 fencing price: the sum of ``area * number_of_sides`` over regions.

    A side is a maximal straight run of fence. A fence unit is counted only
    when it starts a new side. Vertical (left/right) fences continue from the
    cell above, and horizontal (top/bottom) fences continue from the cell to
    the left. A unit is a continuation when that neighbour is the same letter
    AND carries the same fence.

    Args:
        input_filename: base name of the puzzle input, passed to ``read_input``.
    """
    garden = Garden(read_input(input_filename))
    get = garden.get_plot

    area = {}
    sides = {}
    for i in range(garden.n):
        for j in range(garden.m):
            plot = get(i, j)
            region = garden.first_plot_of_region(i, j)
            area[region] = area.get(region, 0) + 1

            left_diff = get(i, j - 1) != plot
            right_diff = get(i, j + 1) != plot
            up_diff = get(i - 1, j) != plot
            down_diff = get(i + 1, j) != plot

            new_sides = 0
            # Left/right fence: the cell above (same letter) has the same fence
            # exactly when its diagonal neighbour on that side differs.
            if left_diff and (up_diff or get(i - 1, j - 1) == plot):
                new_sides += 1
            if right_diff and (up_diff or get(i - 1, j + 1) == plot):
                new_sides += 1
            # Top/bottom fence: same reasoning with the cell to the left.
            if up_diff and (left_diff or get(i - 1, j - 1) == plot):
                new_sides += 1
            if down_diff and (left_diff or get(i + 1, j - 1) == plot):
                new_sides += 1
            sides[region] = sides.get(region, 0) + new_sides

    return sum(area[region] * sides[region] for region in area)