In [None]:
import math
import numpy as np
import pyparsing as pp

In [None]:
# Raw puzzle input; every later cell works from the text stored in `data`.
filename = "inputs/12-08.txt"
with open(filename, mode="r") as f:
    data = f.read()

In [None]:
# Worked examples with their expected answers.
# Each entry is (data, part_1, part_2): the puzzle text, followed by the
# expected result of Part 1 and of Part 2. An expected value of `None` means
# "nothing to check", and `check` skips that part (the third example has no
# "AAA" node, so Part 1 does not apply to it).
tests = [
    (
        """RL

AAA = (BBB, CCC)
BBB = (DDD, EEE)
CCC = (ZZZ, GGG)
DDD = (DDD, DDD)
EEE = (EEE, EEE)
GGG = (GGG, GGG)
ZZZ = (ZZZ, ZZZ)""",
        2,
        2,
    ),
    (
        """LLR

AAA = (BBB, BBB)
BBB = (AAA, ZZZ)
ZZZ = (ZZZ, ZZZ)""",
        6,
        6,
    ),
    (
        """LR

11A = (11B, XXX)
11B = (XXX, 11Z)
11Z = (11B, XXX)
22A = (22B, XXX)
22B = (22C, 22C)
22C = (22Z, 22Z)
22Z = (22B, 22B)
XXX = (XXX, XXX)""",
        None,
        6,
    ),
    # Two ghosts (AAA, BAA) whose loops have different lengths.
    (
        """LLLRR

AAA = (AAA, ABB)
ABB = (ACA, ABB)
ACA = (ACA, ZZZ)
ZZZ = (ACA, ZZZ)
BAA = (BBZ, BBZ)
BBZ = (BAA, BAA)""",
        9,
        9,
    ),
    # Several ghosts; some (e.g. DAA) reach an end node before entering
    # their loop.
    (
        """LLLRR

AAA = (AAA, ABB)
ABB = (ACA, ABB)
ACA = (ACA, ZZZ)
ZZZ = (ACA, ZZZ)
BAA = (BBZ, BBZ)
BBZ = (BAA, BAA)
CAA = (CAA, CBZ)
CBZ = (CDD, CDD)
CDD = (CEE, CEE)
CEE = (CBZ, CBZ)
DAA = (DBZ, DCC)
DBZ = (DAA, DCC)
DBB = (DBB, DCC)
DCC = (DDD, DCC)
DDD = (DDD, DEZ)
DEZ = (DBB, DBB)
FAA = (FBZ, FBZ)
FBZ = (FCZ, FCZ)
FCZ = (FDD, FCZ)
FDD = (FCZ, FCZ)
GAA = (GAA, GBB)
GBB = (GCC, GBB)
GCC = (GCC, GDD)
GDD = (GEE, GDD)
GEE = (GEE, GFF)
GFF = (GGG, GFF)
GGG = (GGG, GHZ)
GHZ = (GII, GHZ)
GII = (GII, GII)""",
        9,
        19,
    ),
]

In [None]:
def check(function, part=1, tests=tests):
    """Assert that `function` reproduces the expected answers of `tests`.

    Parameters
    ----------
    function : callable
        Solver, called with the raw puzzle text of each test case.
    part : int, default 1
        Which expected answer to compare against (1 or 2).
    tests : list of tuple
        ``(data, part_1, part_2)`` entries. A case whose expected answer for
        `part` is None is skipped.
        NOTE: this default is bound when this cell runs; re-run this cell
        after editing `tests`, otherwise the old list is used.

    Raises
    ------
    AssertionError
        On the first result that differs from its expected answer.
    """
    # `puzzle` instead of `input` so the builtin is not shadowed.
    for puzzle, *expected_answers in tests:
        expected = expected_answers[part - 1]
        if expected is None:
            continue
        result = function(puzzle)
        assert result == expected, f"{result} != {expected}"

# Part 1

In [None]:
# Grammar of the puzzle input: an instruction word, then one or more lines
# of the form "NODE = (LEFT, RIGHT)". Punctuation is matched but dropped, so
# each node parses to a group of three names.
word = pp.Word(pp.alphanums)
node_expr = pp.Group(
    word
    + pp.Literal("=").suppress()
    + pp.Literal("(").suppress()
    + word
    + pp.Literal(",").suppress()
    + word
    + pp.Literal(")").suppress()
)
data_expr = word + pp.Group(pp.OneOrMore(node_expr))

In [None]:
def parse_map(data):
    """Parse puzzle text into turn instructions and the node network.

    Returns
    -------
    instructions : list of int
        One entry per character of the instruction word: 1 for "R", else 0.
    nodes : dict[str, (str, str)]
        Maps each node name to its (left node, right node) pair.
    """
    turns, node_rows = data_expr.parse_string(data)
    network = {name: (left, right) for name, left, right in node_rows}
    instructions = [1 if turn == "R" else 0 for turn in turns]
    return instructions, network

In [None]:
def navigate(data, start="AAA", end="ZZZ"):
    """Count the steps needed to walk from `start` to `end`.

    The instructions are followed cyclically: once the last one has been
    used, the walk continues with the first one again.

    NOTE: there is no cycle detection, so this never returns if `end` cannot
    be reached from `start`.
    """
    instructions, network = parse_map(data)
    n_instructions = len(instructions)
    node = start
    count = 0
    while node != end:
        turn = instructions[count % n_instructions]
        node = network[node][turn]
        count += 1
    return count

In [None]:
# Validate on the examples first, then solve the real input (shown as output).
check(navigate, part=1)
navigate(data)

# Part 2

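The intended solution probably relied on the assumption that each ghost loops back to its first node right after reaching its end node.  
This holds for my input and for the provided example, but it need not be true in general.

The following implementation is designed (but not thoroughly tested) to work without this assumption. It detects where each ghost's path starts to loop. It also accounts for end nodes reached before the loop and for several end nodes within one loop.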

In [None]:
def navigate(lr, nodes, start):
    """Walk from `start` until a (node, instruction index) state repeats.

    The instruction list is cyclic, so the pair (current node, position in
    `lr`) fully determines the rest of the walk. The first time such a pair
    is seen again, the walk has entered its loop.

    NOTE(review): this redefines the Part 1 ``navigate`` with a different
    signature; re-running the Part 1 cells after this one will fail.

    Parameters
    ----------
    lr : list of (0, 1) int
        Cyclic instructions: 0 for left, 1 for right.
    nodes : dict[str, (str, str)]
        Mapping dictionary, node -> (left node, right node).
    start : str
        Starting node.

    Returns
    -------
    loop_length : int
        Number of steps in one period of the loop.
    loop_begin : int
        Number of steps taken before the walk first enters the loop.
    ends : list of int
        Step counts (from 1 up to ``loop_begin + loop_length``) at which the
        walk stands on a node whose name ends with "Z".
    """
    n_instr = len(lr)
    seen_at = {}
    end_steps = []
    node, step = start, 0
    state = (node, step % n_instr)
    while state not in seen_at:
        seen_at[state] = step
        node = nodes[node][lr[state[1]]]
        step += 1
        if node[-1] == "Z":
            end_steps.append(step)
        state = (node, step % n_instr)
    loop_begin = seen_at[state]
    return step - loop_begin, loop_begin, end_steps

In [None]:
def offset_lcm(a, a_offset, b, b_offset):
    """Merge two (period, offset) pairs into a single (period, offset).

    Each pair describes the times ``offset + k * period`` (k >= 0). The
    result describes the times that belong to both sequences: solutions of
    ``x = a_offset (mod a)`` and ``x = b_offset (mod b)`` that are not
    smaller than either offset (nor than 0).

    Parameters
    ----------
    a, b : int
        Periods (assumed positive).
    a_offset, b_offset : int
        First time of each sequence.

    Returns
    -------
    (int, int) or None
        ``(lcm(a, b), first common time)``, or None when the sequences never
        meet (the offsets differ by a non-multiple of gcd(a, b)).
    """
    gcd, s, _ = egcd(a, b)
    q, r = divmod(a_offset - b_offset, gcd)
    if r:
        return None
    lcm = a // gcd * b
    # s * a + t * b == gcd, so x = a_offset - s * q * a is congruent to
    # a_offset mod a and to b_offset mod b; adding multiples of lcm keeps both.
    x = a_offset - s * q * a
    # Smallest solution >= lower, in exact integer arithmetic (a float
    # quotient would lose precision once values exceed 2**53).
    lower = max(0, a_offset, b_offset)
    return lcm, lower + (x - lower) % lcm

def egcd(a, b):
    """Extended Euclidean algorithm.

    Returns
    -------
    (g, s, t) : tuple of int
        ``g`` is the gcd of ``a`` and ``b`` (for non-negative inputs), and
        ``s * a + t * b == g``.
    """
    # Invariant: (s0, t0) and (s1, t1) express the current `a` and `b` as
    # combinations of the original inputs.
    s0, s1, t0, t1 = 1, 0, 0, 1
    while b:
        q, r = divmod(a, b)
        a, b = b, r
        s0, s1 = s1, s0 - q * s1
        t0, t1 = t1, t0 - q * t1
    return a, s0, t0

In [None]:
def navigate_ghosts(data):
    """Find the time of the first common exit.

    Every ghost starts on a node whose name ends with "A". The answer is the
    smallest number of steps after which all ghosts simultaneously stand on
    an end node.

    Ghosts are merged one at a time, and the valid times for the ghosts merged
    so far are kept in two parts:

    * ``lcm_offsets``: periodic families ``offset + k * period`` (k >= 0);
    * ``other_ends``: a finite set of extra times, coming from end nodes
      reached before some ghost entered its loop.

    Raises
    ------
    ValueError
        If the ghosts never all stand on end nodes at the same time.
    """
    lr, nodes = parse_map(data)
    # With no ghost merged yet, every time t >= 0 is valid.
    lcm_offsets = [(1, 0)]
    other_ends = set()
    for nd in nodes:
        if nd[-1] != "A":
            continue
        # Part 2 `navigate`: returns (loop_length, loop_begin, ends).
        loop_length, loop_begin, ends = navigate(lr, nodes, nd)
        # Ends after loop_begin repeat every loop_length steps; earlier ends
        # (reached before entering the loop) happen only once.
        loop_ends = {end for end in ends if end > loop_begin}
        prefix_ends = {end for end in ends if end <= loop_begin}
        # Intersect every previous periodic family with this ghost's loop.
        merged_lcm_offsets = []
        for l1, o1 in lcm_offsets:
            for o2 in loop_ends:
                lo = offset_lcm(l1, o1, loop_length, o2)
                if lo is not None:
                    merged_lcm_offsets.append(lo)
        # Finite times never exceed min_offset, so matches involving them
        # can only happen up to that bound: expand the periodic families of
        # both sides up to it and intersect the resulting finite sets.
        min_offset = max(prefix_ends | other_ends | {0})
        for period, offset in lcm_offsets:
            while offset <= min_offset:
                other_ends.add(offset)
                offset += period
        for offset in loop_ends:
            while offset <= min_offset:
                prefix_ends.add(offset)
                offset += loop_length
        other_ends &= prefix_ends
        lcm_offsets = merged_lcm_offsets
    candidates = list(other_ends) + [o for _, o in lcm_offsets]
    if not candidates:
        raise ValueError("the ghosts never reach end nodes at the same time")
    return min(candidates)

In [None]:
# Validate on the examples first, then solve the real input (shown as output).
check(navigate_ghosts, part=2)
navigate_ghosts(data)