In [None]:
import os
import sys

sys.path.insert(0, os.path.abspath("../utils"))
from aoc_utils import load_data, check

In [None]:
from collections import defaultdict
from itertools import permutations, combinations

In [None]:
# Raw puzzle input for 2024 day 8, loaded via the project's `aoc_utils` helper.
# Presumably a single string with one grid row per line (parse_map splits it
# with `splitlines`) — confirm against `load_data`.
data = load_data(2024, 8)

In [None]:
# Worked examples for `check`: each entry is (input grid, expected Part 1
# answer, expected Part 2 answer). A `None` answer presumably means "no
# expectation for that part" — confirm against `check` in aoc_utils.
tests = [
    # Example grid with two frequencies, "0" and "A".
    (
        """............
........0...
.....0......
.......0....
....0.......
......A.....
............
............
........A...
.........A..
............
............
""",
        14,
        34,
    ),
    # Two "X" antennas two rows apart in column 0. For Part 2 the step along
    # the line must be reduced by its gcd so every cell of column 0 (12 of
    # them) counts, not just every second one. No Part 1 expectation is given.
    (
        """............
............
............
............
............
............
............
X...........
............
X...........
............
............
""",
        None,
        12,
    ),
]

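# Part 1

For every ordered pair of same-frequency antennas `a`, `b`, an antinode sits on the line through them, one antenna-spacing beyond `b` (at `2*b - a`). Count the distinct antinode cells that fall inside the map.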

In [None]:
def parse_map(data):
    """Parse the puzzle grid into antenna positions grouped by frequency.

    Every character other than ``"."`` is treated as an antenna whose
    frequency is the character itself. Coordinates are ``(x, y)``, with ``x``
    the 0-based column and ``y`` the 0-based row.

    Parameters
    ----------
    data : str
        Grid text, one row per line.

    Returns
    -------
    antenas : collections.defaultdict(set)
        Maps each frequency character to the set of its ``(x, y)`` positions.
    width, height : int
        Grid dimensions; ``(0, 0)`` for empty input.
    """
    rows = data.splitlines()
    # Trailing blank lines (e.g. an extra "\n" at the end of the input) are
    # not grid rows and must not inflate the height.
    while rows and not rows[-1].strip():
        rows.pop()

    antenas = defaultdict(set)
    for y, row in enumerate(rows):
        for x, c in enumerate(row):
            if c != ".":
                antenas[c].add((x, y))

    # Width comes from the longest row rather than from leftover loop
    # variables, so empty input is handled instead of raising NameError.
    # For a rectangular grid this equals the length of any row.
    width = max(map(len, rows), default=0)
    return antenas, width, len(rows)

In [None]:
def get_antinodes(antenas, width, height):
    """Return the in-bounds Part 1 antinodes produced by one frequency.

    For each ordered pair ``(a, b)`` of antenna positions, the antinode is the
    reflection of ``a`` through ``b``, i.e. ``b + (b - a)``. Iterating over
    ordered pairs covers both ends of every line. Points outside
    ``0 <= x < width`` and ``0 <= y < height`` are dropped.
    """
    def inside(point):
        px, py = point
        return 0 <= px < width and 0 <= py < height

    reflections = (
        (2 * bx - ax, 2 * by - ay)
        for (ax, ay), (bx, by) in permutations(antenas, 2)
    )
    return set(filter(inside, reflections))

In [None]:
def count_antinodes(data, antinode_locations=get_antinodes):
    """Return how many distinct grid cells hold at least one antinode.

    ``antinode_locations(positions, width, height)`` is called once per
    frequency with that frequency's antenna positions. The per-frequency sets
    are merged, so a cell produced by several frequencies counts once.
    """
    antenas, width, height = parse_map(data)
    every_antinode = set().union(
        *(antinode_locations(positions, width, height) for positions in antenas.values())
    )
    return len(every_antinode)

In [None]:
# Validate against the worked examples in `tests` (presumably `check` defaults
# to the Part 1 column), then solve the real input. The bare last expression
# is the cell's displayed answer.
check(count_antinodes, tests)
count_antinodes(data)

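# Part 2

Now every grid cell exactly in line with at least two same-frequency antennas is an antinode, including the antennas themselves. The line is walked in steps of the pair offset divided by its gcd, so no lattice point between or beyond the antennas is skipped.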

In [None]:
import math

In [None]:
def _k_range(z, dz, size):
    """Inclusive integer bounds ``(lo, hi)`` on k with ``0 <= z + k * dz < size``.

    Uses exact integer floor division, so the bounds are correct for any
    coordinate magnitude. A non-moving axis (``dz == 0``) imposes no bound
    (``-inf, inf``) when ``z`` is already in range, and an empty range
    (``lo > hi``) otherwise.
    """
    if dz > 0:
        # k >= ceil(-z / dz)  and  k <= floor((size - 1 - z) / dz)
        return -(z // dz), (size - 1 - z) // dz
    if dz < 0:
        # Dividing by a negative step flips both inequalities.
        return -((z - size + 1) // dz), z // -dz
    return (-math.inf, math.inf) if 0 <= z < size else (1, 0)


def get_antinode_lines(antenas, width, height):
    """Return every in-bounds grid cell collinear with two antennas of one frequency.

    For each unordered pair of antennas, the line through them is sampled at
    every lattice point. The step is the pair offset divided by its gcd, so
    points between and beyond the antennas, including the antennas
    themselves, are all included when inside the grid.

    Parameters
    ----------
    antenas : iterable of (int, int)
        ``(x, y)`` positions of antennas sharing one frequency.
    width, height : int
        Grid size; valid cells satisfy ``0 <= x < width`` and ``0 <= y < height``.

    Returns
    -------
    set of (int, int)
    """
    antinodes = set()
    for (x1, y1), (x2, y2) in combinations(antenas, 2):
        dx, dy = x2 - x1, y2 - y1
        g = math.gcd(dx, dy)
        if g == 0:
            # Coincident positions (only possible if the input repeats a
            # point) define no line, and dividing by gcd(0, 0) == 0 would fail.
            continue
        # Smallest integer step along the line, so no grid point is skipped.
        dx //= g
        dy //= g
        x_lo, x_hi = _k_range(x1, dx, width)
        y_lo, y_hi = _k_range(y1, dy, height)
        # At least one of dx, dy is non-zero, so both combined bounds are ints.
        kmin, kmax = max(x_lo, y_lo), min(x_hi, y_hi)
        antinodes.update((x1 + k * dx, y1 + k * dy) for k in range(kmin, kmax + 1))
    return antinodes

In [None]:
# Same pipeline with the line-based antinode rule. The `2` presumably selects
# the Part 2 column of `tests` inside `check`, and the extra keyword is
# presumably forwarded to `count_antinodes` (confirm in aoc_utils). The bare
# last expression is the cell's displayed answer.
check(count_antinodes, tests, 2, antinode_locations=get_antinode_lines)
count_antinodes(data, antinode_locations=get_antinode_lines)