In [None]:
import os
import sys

sys.path.insert(0, os.path.abspath("../utils"))
from aoc_utils import check, load_data

In [None]:
from collections.abc import Hashable, Mapping
from functools import cache
from itertools import permutations

In [None]:
# Puzzle input, obtained through the `load_data` helper imported from aoc_utils.
# NOTE(review): the arguments are presumably (year, day) -- confirm against aoc_utils.
data = load_data(2024, 21)

In [None]:
# Worked examples passed to `check`. Each entry is a tuple of
# (puzzle input, expected part 1 answer, expected part 2 answer).
tests = [
    (
        """029A
980A
179A
456A
379A
""",
        126384,
        154115708116294,
    ),
]

# Part 1

In [None]:
class KeyPad(Hashable, Mapping):
    """A hashable, read-only wrapper around a dictionary representing a keypad.

    The mapping goes from a key label to the position of that key.  Instances
    are hashable so they can be used as arguments of memoized functions.

    Parameters
    ----------
    keys: Mapping
        Maps every key label to its (hashable) position.

    Raises
    ------
    ValueError
        If two keys share the same position.
    """

    def __init__(self, keys):
        # Keep a private copy: the hash, length and values are computed once,
        # so later mutation of the caller's dictionary must not leak in.
        self._dct = dict(keys)
        # Hash an unordered view of the items.  Equality is dict equality,
        # which ignores insertion order, so the hash must ignore it as well;
        # otherwise two equal pads could hash differently.
        self._hash = hash(frozenset(self._dct.items()))
        self._len = len(self._dct)
        self._values = frozenset(self._dct.values())
        if len(self._values) != self._len:
            raise ValueError("Key positions must be unique")

    def __len__(self):
        """Return the number of keys on the pad."""
        return self._len

    def __hash__(self):
        """Return the hash computed at construction time."""
        return self._hash

    def __eq__(self, other):
        """Compare with another KeyPad; defer to `other` for foreign types."""
        if not isinstance(other, KeyPad):
            # Let Python try the reflected comparison instead of raising
            # AttributeError on objects that have no `_dct`.
            return NotImplemented
        return self._dct == other._dct

    def __getitem__(self, item):
        """Return the position of the key labelled `item`."""
        return self._dct[item]

    def __iter__(self):
        """Iterate over the key labels."""
        return iter(self._dct)

    def values(self):
        """Return the frozenset of all key positions (fast membership tests)."""
        return self._values

def parse_pad(lines):
    """Build a KeyPad from the rows of a textual keypad layout.

    Every non-space character is a key label, stored with its
    (column index, row index) position.  Spaces mark gaps without a key.
    """
    positions = {
        label: (col, row)
        for row, text in enumerate(lines)
        for col, label in enumerate(text)
        if label != " "
    }
    return KeyPad(positions)

In [None]:
DIRECTIONS = {"<": (-1, 0), ">": (1, 0), "^": (0, -1), "v": (0, 1)}

def is_valid(start, sequence, valid):
    """Tell whether following `sequence` from `start` stays on valid positions.

    `start` is a (column, row) pair, `sequence` an iterable of direction
    symbols found in DIRECTIONS and `valid` a container of allowed positions.
    Only the positions reached after each move are checked, not `start`.
    """
    col, row = start
    for move in sequence:
        step_col, step_row = DIRECTIONS[move]
        col, row = col + step_col, row + step_row
        if (col, row) in valid:
            continue
        # Stepped onto a position that is not allowed.
        return False
    return True

def gen_paths(a, b, pad, push="A"):
    """Yield every shortest valid instruction sequence going from `a` to `b`.

    Each yielded string is one distinct ordering of the required moves that
    never leaves the positions of `pad`, followed by `push` so that the key
    `b` gets pressed once the arm is above it.
    """
    start = pad[a]
    col_a, row_a = start
    col_b, row_b = pad[b]
    d_col, d_row = col_b - col_a, row_b - row_a
    # One symbol per unit of displacement along each axis.
    horizontal = ("<" if d_col < 0 else ">") * abs(d_col)
    vertical = ("^" if d_row < 0 else "v") * abs(d_row)
    reachable = pad.values()
    # The set removes the duplicates caused by repeated symbols.
    for candidate in set(map("".join, permutations(horizontal + vertical))):
        if is_valid(start, candidate, reachable):
            yield candidate + push

In [None]:
@cache
def min_cost(a, b, pads, start):
    """Return the fewest instructions needed to push `b` coming from `a`.

    Parameters
    ----------
    a: str
        Key the arm is currently above on the first keypad of `pads`.
    b: str
        Key that must be pushed on that keypad.
    pads: tuple(KeyPad)
        Keypad layouts, ordered from the one attached to the door up to the
        one closest to the human (i.e. reversed w.r.t. the chain of robots).
        Must be hashable, since results are memoized.
    start: str
        Key prepended to every candidate path before pairing consecutive
        keys; it is forwarded unchanged to all recursive calls.

    Returns
    -------
    int
        The minimum number of instructions.
    """
    if not pads:
        # No keypad left: the human pushes the key with a single instruction.
        return 1
    keypad, upstream = pads[0], pads[1:]

    def typing_cost(path):
        # Pair every key of the path with the key pressed just before it.
        previous_keys = start + path
        return sum(
            min_cost(prev, key, upstream, start)
            for prev, key in zip(previous_keys, path)
        )

    return min(map(typing_cost, gen_paths(a, b, keypad)))

In [None]:
def find_shortest(code, pads, push="A"):
    """Return the minimum number of instructions needed to type `code`.

    The arm is taken to rest on `push` before the first key; the costs of
    the successive key-to-key moves given by `min_cost` are summed.
    """
    total = 0
    for previous, key in zip(push + code, code):
        total += min_cost(previous, key, pads, push)
    return total

In [None]:
def type_codes(data, depth=2):
    """Return the sum of the complexities of all codes listed in `data`.

    `data` holds one code per line.  The chain of keypads is one numeric
    keypad followed by `depth` directional keypads.  A code's complexity is
    its integer value (with every "A" removed) multiplied by the length of
    the shortest instruction sequence found by `find_shortest`.
    """
    numeric_layout = ["789", "456", "123", " 0A"]
    directional_layout = [" ^A", "<v>"]
    numeric_pad = parse_pad(numeric_layout)
    directional_pad = parse_pad(directional_layout)
    chain = (numeric_pad,) + (directional_pad,) * depth
    return sum(
        int(code.replace("A", "")) * find_shortest(code, chain)
        for code in data.splitlines()
    )

In [None]:
# Run the solver on the worked examples, then on the real input, with the
# default chain of two directional keypads (depth=2).
# NOTE(review): `check` comes from aoc_utils; presumably it compares the result
# with the part 1 expectations in `tests` by default -- confirm.
check(type_codes, tests)
type_codes(data)

# Part 2

In [None]:
# Same solver, now with a chain of 25 directional keypads.
# NOTE(review): the positional `2` presumably selects the part 2 expectations
# in `tests` -- confirm against aoc_utils.check.
check(type_codes, tests, 2, depth=25)
type_codes(data, depth=25)