In [93]:
from collections import defaultdict
from itertools import combinations
from math import gcd


def read_file(path='day08_input.txt'):
    """Read the antenna map as a grid of single characters.

    Parameters
    ----------
    path : str or os.PathLike, optional
        File to read. Defaults to ``'day08_input.txt'``, resolved relative to
        the current working directory (the value previously hard-coded).

    Returns
    -------
    list[list[str]]
        One list of characters per non-blank line, with surrounding
        whitespace stripped from each line.
    """
    with open(path, 'r') as file:
        # Skip blank lines (e.g. a trailing empty line at end of file): an
        # empty row would otherwise break code that indexes every row.
        return [list(text) for text in (line.strip() for line in file) if text]

def parse_grid_for_nodes(grid):
    """Group antenna positions by their frequency character.

    Every cell other than ``'.'`` is treated as an antenna whose frequency is
    the cell's character.

    Parameters
    ----------
    grid : list[list[str]]
        Rows of single-character cells.

    Returns
    -------
    defaultdict[str, set[tuple[int, int]]]
        Maps each frequency character to the set of ``(row, col)`` positions
        where it appears.
    """
    nodes = defaultdict(set)
    # Walk each row's own cells instead of range(len(grid[0])): this works for
    # an empty grid and for rows of differing length.
    for row, cells in enumerate(grid):
        for col, cell in enumerate(cells):
            if cell != '.':
                nodes[cell].add((row, col))
    return nodes

def calculate_deltas(coords, bounds):
    """Return the part-1 antinodes produced by one frequency's antennas.

    For every pair of antennas p1, p2 there are two antinodes on the line
    through them, one antenna-spacing beyond each antenna:
    ``p1 + (p1 - p2)`` and ``p2 - (p1 - p2)``.

    Parameters
    ----------
    coords : iterable of (int, int)
        Antenna positions ``(row, col)`` that share one frequency.
    bounds : sequence of two ints
        ``(rows, cols)``; only positions with ``0 <= row < rows`` and
        ``0 <= col < cols`` are kept.

    Returns
    -------
    set[tuple[int, int]]
        In-bounds antinode positions for this frequency.
    """
    rows, cols = bounds[0], bounds[1]
    results = set()

    for (x1, y1), (x2, y2) in combinations(coords, 2):
        delta_x, delta_y = x1 - x2, y1 - y2
        if delta_x == 0 and delta_y == 0:
            # Identical positions define no line, hence no antinodes.
            continue

        candidates = ((x1 + delta_x, y1 + delta_y), (x2 - delta_x, y2 - delta_y))
        results.update((x, y) for x, y in candidates if 0 <= x < rows and 0 <= y < cols)

    return results

def calculate_nodes(coords, bounds):
    """Return the part-2 antinodes produced by one frequency's antennas.

    Every in-bounds grid position exactly in line with at least two antennas
    of this frequency is an antinode. This includes the antennas themselves
    whenever the frequency has two or more antennas.

    The direction between a pair is divided by the gcd of its components, so
    the walk visits every lattice point on the line, not only whole multiples
    of the antenna spacing. For example, with antennas at (0, 0) and (2, 4)
    the point (1, 2) is also exactly in line.

    Parameters
    ----------
    coords : iterable of (int, int)
        Antenna positions ``(row, col)`` that share one frequency; assumed to
        lie inside ``bounds``.
    bounds : sequence of two ints
        ``(rows, cols)``; only positions with ``0 <= row < rows`` and
        ``0 <= col < cols`` are reported.

    Returns
    -------
    set[tuple[int, int]]
        In-bounds antinode positions for this frequency.
    """
    rows, cols = bounds[0], bounds[1]
    results = set()

    def in_bounds(x, y):
        return 0 <= x < rows and 0 <= y < cols

    for (x1, y1), (x2, y2) in combinations(coords, 2):
        delta_x, delta_y = x2 - x1, y2 - y1
        divisor = gcd(delta_x, delta_y)
        if divisor == 0:
            # Identical positions define no line; walking a zero step would
            # never leave the grid.
            continue
        # Exact division: divisor divides both components.
        step_x, step_y = delta_x // divisor, delta_y // divisor

        # One walk in each direction from p1 covers the whole line, p1 and p2
        # included, so there is no need to also walk from p2.
        for sign in (1, -1):
            x, y = x1, y1
            while in_bounds(x, y):
                results.add((x, y))
                x += sign * step_x
                y += sign * step_y

    return results

def _count_antinodes(locate_antinodes):
    """Count distinct antinode positions over all frequencies in the input.

    Parameters
    ----------
    locate_antinodes : callable
        Called as ``locate_antinodes(coords, bounds)`` once per frequency,
        with that frequency's antenna positions and the grid's
        ``(rows, cols)``; must return an iterable of ``(row, col)`` positions.

    Returns
    -------
    int
        Number of distinct positions returned across all frequencies.
    """
    grid = read_file()
    # Grid size is the same for every frequency, so compute it once.
    bounds = (len(grid), len(grid[0]) if grid else 0)
    node_locations = set()
    for coords in parse_grid_for_nodes(grid).values():
        node_locations.update(locate_antinodes(coords, bounds))
    return len(node_locations)

def part1():
    """Print the part-1 answer: antinodes one antenna-spacing beyond each pair."""
    print(_count_antinodes(calculate_deltas))

def part2():
    """Print the part-2 answer: antinodes anywhere exactly in line with a pair."""
    print(_count_antinodes(calculate_nodes))


In [94]:
# Print both answers for day08_input.txt: part 1 first, then part 2.
part1()
part2()

376
1352
