In [1]:
from copy import deepcopy
from collections import defaultdict

In [116]:
sum((True, True))  # bools are ints, so this evaluates to 2

2

In [367]:


def is_adjacent(coord1, coord2):
    """Return True when the two grid coordinates differ by exactly one step
    along a single axis (orthogonal neighbours; diagonals and identical
    coordinates are not adjacent)."""
    row_gap = abs(coord1[0] - coord2[0])
    col_gap = abs(coord1[1] - coord2[1])
    return (row_gap, col_gap) in ((1, 0), (0, 1))

def look_around(coord):
    """Return the set of the four orthogonal neighbours of *coord*.

    No bounds checking is done, so neighbours may fall outside the grid.
    """
    row, col = coord[0], coord[1]
    steps = ((1, 0), (-1, 0), (0, 1), (0, -1))
    return {(row + d_row, col + d_col) for d_row, d_col in steps}

def extrapolate_vertices(coord:tuple):
    """Return the four corner vertices of cell *coord*.

    Vertex (i, j) is the top-left corner of cell (i, j), with rows growing
    downward as in the input file. The result is ordered
    [(i, j), (i + 1, j), (i, j + 1), (i + 1, j + 1)].
    """
    row, col = coord[0], coord[1]
    return [(row + d_row, col + d_col) for d_col in (0, 1) for d_row in (0, 1)]

def find_vertices_in_group(group:list):
    """Count how many cells of *group* touch each grid vertex.

    Returns a ``defaultdict(int)`` mapping vertex -> number of cells in the
    group that have that vertex as one of their four corners (1 to 4).
    """
    touch_counts = defaultdict(int)
    every_corner = (corner for cell in group for corner in extrapolate_vertices(cell))
    for corner in every_corner:
        touch_counts[corner] += 1
    return touch_counts



def calc_perimeter(group:list):
    """Return the perimeter of a region made of unit grid cells.

    Each cell contributes 4 edges, minus one for every orthogonally adjacent
    cell that is also in *group*.

    Neighbours are looked up in a set, so the cost is linear in the region
    size. Comparing every pair of cells would be quadratic, which is slow
    for large regions.
    """
    members = set(group)
    neighbour_steps = ((1, 0), (-1, 0), (0, 1), (0, -1))
    total_perimeter = 0
    for item in group:
        i, j = item[0], item[1]
        shared_edges = sum((i + di, j + dj) in members for di, dj in neighbour_steps)
        total_perimeter += 4 - shared_edges
    return total_perimeter

class Day12:
    """Price the regions of a garden grid by perimeter or by number of sides.

    The input file is a grid of single-character plant labels, one row per
    line. Orthogonally connected cells with the same label form a region.
    ``calculate_total()`` sums area * perimeter over all regions.
    ``calculate_total(discount=True)`` sums area * side count instead; the
    side count is the number of corners, because a rectilinear polygon has
    as many sides as corners.

    Attributes built by ``__init__``:
        input              -- list of rows, each a list of single characters
        dimensions         -- (row_count, column_count)
        plant_dict         -- label -> list of (row, col) cells with that label
        grouped_plants     -- label -> {starter cell: list of cells in that region}
        plant_vertex_dict  -- label -> {starter: {vertex: cells of the region touching it}}
        corner_vertex_dict -- label -> {starter: set of vertices touched an odd number of times}
        exceptions         -- flat list alternating vertex, label for each diagonal "pinch" vertex
        flat_count_dict    -- label -> {starter: corner (= side) count}
    """

    def __init__(self, fname):
        """Load the grid from *fname* and precompute every per-region table.

        Blank lines are skipped. A trailing empty line would otherwise add a
        zero-width row, which makes the bounds checks in ``is_exception``
        index past the end of that row.
        """
        with open(fname) as f:
            self.input = [list(row.strip()) for row in f if row.strip()]
        # Guard the column count so an empty file gives (0, 0) rather than an IndexError.
        self.dimensions = (len(self.input), len(self.input[0]) if self.input else 0)
        self.create_plant_dict()
        self.group_plants()
        self.create_plant_vertex_dict()
        self.create_corner_dict()
        # Must exist before calc_all_flats, which appends to it through calc_flats.
        self.exceptions = []
        self.calc_all_flats()

    def check_edge(self, group):
        """Return True if any coordinate in *group* is on row/column 0 or at/after the far edge.

        NOTE(review): this is not called anywhere in the visible code. Cell
        coordinates are always < ``dimensions``, so the second test never
        fires for cells, and bottom/right edge cells are not detected. It
        could only trigger for vertex coordinates. Confirm the intended input
        before relying on it.
        """
        for coord in group:
            if coord[0] == 0 or coord[1] == 0:
                return True
            if coord[0] >= self.dimensions[0] or coord[1] >= self.dimensions[1]:
                return True
        return False

    def create_plant_dict(self):
        """Build ``self.plant_dict``: label -> list of (row, col) cells, in reading order."""
        plant_dict = defaultdict(list)
        for i, row in enumerate(self.input):
            for j, item in enumerate(row):
                plant_dict[item].append((i, j))
        self.plant_dict = plant_dict

    def create_plant_vertex_dict(self):
        """Build ``self.plant_vertex_dict``: for every region, how many of its cells touch each vertex."""
        self.plant_vertex_dict = {
            letter: {starter: find_vertices_in_group(group) for starter, group in groups.items()}
            for letter, groups in self.grouped_plants.items()
        }

    def create_corner_dict(self):
        """Build ``self.corner_vertex_dict``: for every region, the vertices touched by 1 or 3 of its cells."""
        self.corner_vertex_dict = {
            letter: {
                starter: {vertex for vertex, ct in vertex_counts.items() if ct % 2 == 1}
                for starter, vertex_counts in groups.items()
            }
            for letter, groups in self.plant_vertex_dict.items()
        }

    def is_exception(self, vertex):
        """Return True if a diagonal pair of cells around *vertex* share a label.

        Vertex (i, j) is surrounded by cells (i-1, j-1), (i-1, j), (i, j-1)
        and (i, j). Out-of-grid cells read as None.

        This is only called when exactly two cells of a region touch the
        vertex. If those two cells are orthogonal neighbours, the other two
        cells cannot carry the same label; otherwise they would belong to the
        region and raise the count to 3. So a match here means the two region
        cells are diagonal. That "pinch" is two corners, which breaks the
        rule that two cells sharing a vertex form a straight edge.
        """
        i, j = vertex
        rows, cols = self.dimensions
        up_left = self.input[i-1][j-1] if i-1 >= 0 and j-1 >= 0 else None
        up_right = self.input[i-1][j] if i-1 >= 0 and j < cols else None
        down_left = self.input[i][j-1] if i < rows and j-1 >= 0 else None
        down_right = self.input[i][j] if i < rows and j < cols else None
        return up_left == down_right or up_right == down_left

    def calc_flats(self, all_vertices, letter):
        """Count the corners (= sides) of one region.

        *all_vertices* maps vertex -> number of region cells touching it.
        An odd count (1 or 3) is one corner. A count of 2 is normally a
        straight edge, but a diagonal pinch (see ``is_exception``) is two
        corners. Each pinch is also recorded in ``self.exceptions`` as the
        vertex followed by *letter*.
        """
        total_corners = 0
        for vertex, ct in all_vertices.items():
            if ct % 2 == 1:
                total_corners += 1
            elif ct == 2 and self.is_exception(vertex):
                total_corners += 2
                self.exceptions.append(vertex)
                self.exceptions.append(letter)
        return total_corners

    def calc_all_flats(self):
        """Build ``self.flat_count_dict``: label -> {starter: corner count of that region}."""
        self.flat_count_dict = {
            plant: {starter: self.calc_flats(vertices, plant) for starter, vertices in groups.items()}
            for plant, groups in self.plant_vertex_dict.items()
        }

    def group_plant(self, plant):
        """Split all cells labelled *plant* into orthogonally connected regions.

        Returns a ``defaultdict(list)`` mapping each region's starter cell to
        the list of cells in that region. Uses breadth-first flood fill with
        set-membership lookups. The set copy is enough because it holds
        immutable tuples, and ``self.plant_dict`` itself is never mutated.
        """
        remaining_coords = set(self.plant_dict[plant])
        gardens = defaultdict(list)
        while remaining_coords:
            starter = next(iter(remaining_coords))
            remaining_coords.remove(starter)
            gardens[starter] = [starter]
            frontier = [starter]
            while frontier:
                next_frontier = []
                for coord in frontier:
                    for neighbour in look_around(coord):
                        if neighbour in remaining_coords:
                            remaining_coords.remove(neighbour)
                            gardens[starter].append(neighbour)
                            next_frontier.append(neighbour)
                frontier = next_frontier
        return gardens

    def group_plants(self):
        """Build ``self.grouped_plants``: label -> regions as returned by ``group_plant``."""
        self.grouped_plants = dict()
        for plant in self.plant_dict:
            self.grouped_plants[plant] = self.group_plant(plant)

    def calculate_total(self, discount=False):
        """Return the summed price of every region.

        Without *discount*, each region costs area * perimeter. With
        *discount*, each region costs area * number of sides (its corner
        count).
        """
        total = 0
        for plant, groups in self.grouped_plants.items():
            for starter, group in groups.items():
                if discount:
                    multiplier = self.flat_count_dict[plant][starter]
                else:
                    multiplier = calc_perimeter(group)
                total += multiplier * len(group)
        return total

        

In [368]:
# Run both pricing modes on the test grid file: first area * perimeter,
# then area * side count (discount=True). The printed results appear below.
day12_test = Day12('data/test.txt')
print(day12_test.calculate_total())
print(day12_test.calculate_total(discount = True))

1930
1206


In [369]:
# Same two computations on the full input grid: first area * perimeter,
# then area * side count (discount=True).
day12 = Day12('data/input.txt')
print(day12.calculate_total())
print(day12.calculate_total(discount = True))

1363682
787680
