# Advent of Code 2023 Day 8 - Bonus Round

December 8 is a holiday in Catholic Italy -- the Immaculate Conception if you're wondering -- and I'm in bed with a nasty fever. What better occasion to do some Advent of Code?

The text of the problem is this: [Day 8 - Haunted Wasteland](https://adventofcode.com/2023/day/8). The interesting part is the second one. A summary of the problem statement is as follows. You are given as input a directed graph whose edges each carry one of two labels, $G = (V, E, L)$ with $L: E \rightarrow \{0, 1\}$. Every node has exactly two out-edges, one for each label. A subset of the nodes $S \subseteq V$ are starting nodes, and a subset of the nodes $A \subseteq V$ are accepting nodes. You are furthermore given a finite string $l \in \{0, 1\}^*$. Let $l^*$ be the concatenation of infinitely many copies of $l$.
For each node $s$, let $f^n(s)$ be the node reached by taking the edge from $s$ labeled with the first character of $l^*$, then the edge outgoing from that node labeled with the second character of $l^*$, and so on for a total of $n$ steps.
More formally, let $f^0(s) = s$ and $f^{n+1}(s)$ be the unique $s'$ such that $(f^n(s), s') \in E$ and $L((f^n(s), s')) = l^*[n+1]$.
Determine the least $n$ such that for all $s \in S$ it holds that $f^n(s) \in A$.

In more intuitive terms, imagine placing a token on each starting node and, for each character of $l^*$ in order, simultaneously moving every token along the edge labeled with that character. The problem asks after how many moves all tokens will be on accepting nodes for the first time.
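
The token picture translates directly into a brute-force simulation. The sketch below is only meant to pin down the semantics: the function name `brute_force`, the `max_steps` budget and the toy graph are all made up for illustration, and on the real input this approach would be far too slow.

```python
from itertools import cycle


def brute_force(graph, moves, starts, accepting, max_steps=10**6):
    """Move all tokens in lockstep; return the first step where all are accepting."""
    tokens = list(starts)
    for n, move in enumerate(cycle(moves), 1):
        if n > max_steps:
            return None  # give up: no answer within the step budget
        tokens = [graph[t][move] for t in tokens]
        if all(t in accepting for t in tokens):
            return n


# Hypothetical toy graph: the P token alternates with period 2,
# the Q token rotates with period 3.
toy_graph = {
    'PA': {'L': 'PZ', 'R': 'PZ'},
    'PZ': {'L': 'PA', 'R': 'PA'},
    'QA': {'L': 'QB', 'R': 'QB'},
    'QB': {'L': 'QZ', 'R': 'QZ'},
    'QZ': {'L': 'QA', 'R': 'QA'},
}

print(brute_force(toy_graph, 'LR', ['PA', 'QA'], {'PZ', 'QZ'}))  # 5
```

The P token is accepting on odd steps and the Q token on steps 2, 5, 8 and so on, so the first common step is 5.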

The problem is interesting but _I dislike problem statements that don't tell you the hypotheses, all of them and all at once_. As we will see, the intended solution uses the fact that the input is more constrained than what the problem text describes. So, here comes the _bonus round_. Let's try to solve the problem in full generality!

The crucial observation is this: imagine following the transitions from a certain state. They must eventually cycle, possibly after a non-repeating prefix that I'll call the anticycle. To efficiently find the lengths of the anticycle and cycle, construct a "long pointer" from each state to the state it would reach by following the whole input string, in order. These pointers can be computed on demand and cached, so that nothing is computed twice.

The anticycle and cycle from each node on the "long pointer" graph induce an anticycle and cycle in the original state machine.
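
As a minimal sketch of the cycle-finding step, assume the long pointers are already available as a plain function from state to state. The name `tail_and_cycle` and the four-node pointer table are hypothetical, chosen so that the anticycle is not empty.

```python
def tail_and_cycle(step, start):
    """Return (anticycle length, cycle length) of `start` under repeated `step`."""
    first_seen = {}
    state, n = start, 0
    while state not in first_seen:
        first_seen[state] = n
        state = step(state)
        n += 1
    # `state` is the first repeated node: it was first seen after the anticycle.
    return first_seen[state], n - first_seen[state]


# Hypothetical long pointers: s -> u -> v -> w -> u -> ...
long_ptr = {'s': 'u', 'u': 'v', 'v': 'w', 'w': 'u'}

print(tail_and_cycle(long_ptr.__getitem__, 's'))  # (1, 3)
print(tail_and_cycle(long_ptr.__getitem__, 'u'))  # (0, 3)
```

Both lengths are measured in long-pointer steps, so they have to be multiplied by the length of the string to get lengths in the original state machine.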

At this point the problem becomes number theory. Overloading the notation, let $l$ now be the length of the given string, let $a$ and $c$ be the lengths of the anticycle and cycle on the long pointer graph respectively, and let $f(s)$ be the node that $s$ points to in the long pointer graph, so that $f^i(s)$ is the node reached from $s$ after $i$ full passes over the string.

Furthermore, let $R(s)$ be the set containing all indices $1 \le i \le l$ for which following $i$ transitions from $s$ in the original graph (using the length-$i$ prefix of the given string to choose labels) reaches an accepting state.

With this notation, it is easy to see that starting from a node $s$, we will be in an accepting state if and only if the number of steps made is in the following set:

$$
X_s =
\left(
\bigcup_{i=0}^{a-1} \{li + j: j \in R(f^{i}(s)) \}
\right)
\cup
\left(
\bigcup_{i=0}^{\infty} \{l(a + i) + j: j \in R(f^{a+i}(s)) \}
\right)
$$

Since $f^{a+i+c}(s) = f^{a+i}(s)$, the second union repeats with period $lc$.

The final answer is thus the least element of the intersection of $X_s$ over all starting states $s$.
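
To make the set $X_s$ concrete, here is a small membership test. It relies on the fact that a step count $n \ge 1$ decomposes uniquely as $n = li + j$ with $1 \le j \le l$, and that the long-pointer node visited on pass $i$ repeats with period $c$ once $i \ge a$. The function name `in_accepting_steps` and the list `R`, where `R[i]` stands for $R(f^i(s))$ for $0 \le i < a + c$, are assumptions made for this sketch.

```python
def in_accepting_steps(n, l, a, c, R):
    """Return True if n belongs to X_s, given R[i] = R(f^i(s)) for i < a + c."""
    if n < 1:
        return False
    i, j = divmod(n - 1, l)
    j += 1                       # now n == l * i + j with 1 <= j <= l
    if i >= a:
        i = a + (i - a) % c      # fold the pass index back into the cycle
    return j in R[i]


# Hypothetical data: string of length 2, anticycle of 1 pass, cycle of 2 passes.
R = [{1}, set(), {2}]
print([n for n in range(1, 15) if in_accepting_steps(n, 2, 1, 2, R)])
# [1, 6, 10, 14]
```

Step 1 comes from the anticycle; 6, 10, 14 come from the second node of the cycle and repeat every $lc = 4$ steps.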

First, some imports:

In [1]:
from functools import cache, reduce
from collections import defaultdict, Counter
from dataclasses import dataclass
from itertools import product, count
from math import lcm

Let's get to work!

In [2]:
# strip() guards against a trailing newline producing an empty last line
input_8 = open('input/8.txt').read().strip().split('\n')

moves, _, *trans = input_8


class StateAutomaton:
    def __init__(self, trans):
        self.trans = defaultdict(list)
        self.all_states = set()
        for pre, post, label in trans:
            self.trans[pre].append((post, label))
            self.all_states |= {pre, post}
            
    def transition(self, pre, label):
        for post, tlabel in self.trans[pre]:
            if label == tlabel:
                return post
            
    @classmethod
    def from_text(cls, ls):
        def gen_trans(ls):
            for l in ls:
                a, b, c = l.translate(
                    str.maketrans('(),=', '    ')
                ).split()
                yield a, b, 'L'
                yield a, c, 'R'
        return cls(gen_trans(ls))


@cache
def fast_ptr(automaton, state, labels):
    return reduce(automaton.transition, labels, state)


@cache
def cycle_length(automaton, state, labels):
    visited = {}
    for n in count():
        if state in visited:
            return visited[state], n - visited[state]
        visited[state] = n
        state = fast_ptr(automaton, state, labels) 


@cache
def accepted_points(automaton, state, labels, cond):
    def _trans_automaton(state):
        for n, l in enumerate(labels, 1):
            state = automaton.transition(state, l)
            if cond(state):
                yield n, state
    return list(_trans_automaton(state))


@dataclass
class AcceptedPoints:
    anticycle_accepted: list[int]
    anticycle_length: int
    cycle_accepted: list[int]
    cycle_length: int
        
    def gen(self):
        yield from iter(self.anticycle_accepted)
        for k in count():
            yield from (
                k * self.cycle_length + self.anticycle_length + x
                for x in self.cycle_accepted
            )


def accepted_points_repeating(automaton, state, labels, cond):
    anticycle, cycle = cycle_length(automaton, state, labels)
    
    anticycle_accepted = []
    for i in range(anticycle):
        anticycle_accepted += [
            i * len(labels) + n for n, _ in
            accepted_points(automaton, state, labels, cond)
        ]
        state = fast_ptr(automaton, state, labels)
    
    cycle_accepted = []
    for j in range(cycle):
        cycle_accepted += [
            j * len(labels) + n for n, _ in
            accepted_points(automaton, state, labels, cond)
        ]
        state = fast_ptr(automaton, state, labels)
        
    return AcceptedPoints(
        anticycle_accepted,
        anticycle * len(labels),
        cycle_accepted,
        cycle * len(labels)
    )


S = StateAutomaton.from_text(trans)

cond = lambda u: u == 'ZZZ'
accepted = accepted_points_repeating(S, 'AAA', moves, cond)
A = next(accepted.gen())
assert A == 16579


cond = lambda u: u.endswith('Z')
starting_states = {
    a.split()[0] for a in trans
    if a.split()[0].endswith('A')
}

accepted = {
    state: accepted_points_repeating(S, state, moves, cond)
    for state in starting_states
}
    
    
A = lcm(16579, 17141, 14893, 22199, 12083, 19951)
assert A == 12927600769609

Taking the least common multiple does not work in general! The input is in fact a special case where the following happens:

In [3]:
accepted

{'LJA': AcceptedPoints(anticycle_accepted=[], anticycle_length=281, cycle_accepted=[21918], cycle_length=22199),
 'KTA': AcceptedPoints(anticycle_accepted=[], anticycle_length=281, cycle_accepted=[14612], cycle_length=14893),
 'NFA': AcceptedPoints(anticycle_accepted=[], anticycle_length=281, cycle_accepted=[11802], cycle_length=12083),
 'JXA': AcceptedPoints(anticycle_accepted=[], anticycle_length=281, cycle_accepted=[16860], cycle_length=17141),
 'AAA': AcceptedPoints(anticycle_accepted=[], anticycle_length=281, cycle_accepted=[16298], cycle_length=16579),
 'PLA': AcceptedPoints(anticycle_accepted=[], anticycle_length=281, cycle_accepted=[19670], cycle_length=19951)}

By what we have shown before, we would have:

$$
X_{JXA} = 281 + (16860 + 17141j) = 17141(j+1) \\
X_{KTA} = 281 + (14612 + 14893j) = 14893(j+1) \\
\dots
$$

and similarly for all other cases, which explains why *in this particular case* taking the least common multiple works.
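
A tiny experiment shows both sides of this. Each starting state is modelled as an arithmetic progression `(first, period)`; the helper `first_common` and the numbers are made up for illustration. When the first hit equals the period, the least common multiple is the answer; as soon as it does not, the two disagree.

```python
from math import lcm


def first_common(progressions, bound):
    """Least n <= bound lying in every progression {first + k * period : k >= 0}."""
    for n in range(1, bound + 1):
        if all(n >= first and (n - first) % period == 0
               for first, period in progressions):
            return n
    return None


# Special case, as in the puzzle input: the first hit equals the period.
special = [(4, 4), (6, 6)]
print(first_common(special, 100), lcm(4, 6))  # 12 12

# General case: the first hit is offset from the period.
general = [(1, 2), (2, 3)]
print(first_common(general, 100), lcm(2, 3))  # 5 6
```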

### Day 8 - Bonus Round

But what if we were to _really_ solve the problem? What would it take? It would take a lemma.

### Lemma

Given a system of congruences in the form

$$
f(x) = \left\{
    \begin{array}{ll}
        x \equiv a_1 & \mod m_1 \\
        x \equiv a_2 & \mod m_2 \\
        \dots \\
        x \equiv a_n & \mod m_n
    \end{array}
\right.
$$

for _arbitrary_ $m_1, m_2, \dots m_n$ we can compute the set of solutions modulo some integer $M$, and we can do so efficiently. Unlike in the coprime case, there might be no solutions; when there are, they form a single residue class modulo $M = \mathrm{lcm}(m_1, m_2, \dots, m_n)$.

Consider the following example:

In [4]:
system = {(4, 12), (4, 14), (1, 55), (34, 121)}
system

{(1, 55), (4, 12), (4, 14), (34, 121)}

representing:

$$
f(x) = \left\{
    \begin{array}{ll}
        x \equiv 1 & \mod 55 \\
        x \equiv 4 & \mod 12 \\
        x \equiv 34 & \mod 121 \\
        x \equiv 4 & \mod 14
    \end{array}
\right.
$$
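
For a system this small, an exhaustive search gives a reference answer to check any cleverer method against. This is only a sketch (the name `brute_force_system` is made up), and it scans every residue below the least common multiple of the moduli, so it is useless for moduli of realistic size.

```python
from math import lcm


def brute_force_system(system):
    """Return (x, M) with M the lcm of the moduli, or None if there is no solution."""
    M = lcm(*(m for _, m in system))
    for x in range(M):
        if all((x - a) % m == 0 for a, m in system):
            return x, M
    return None


print(brute_force_system([(4, 12), (4, 14), (1, 55), (34, 121)]))  # (41416, 50820)
print(brute_force_system([(0, 2), (1, 4)]))                        # None
```

The second call shows a contradictory system: no number is both even and congruent to 1 modulo 4.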


#### Step 1: Divide
Consider the equations one at a time. For each equation, if the modulus $m_i$ is a prime or a power of a prime, do not do anything. If $m_i$ contains two or more distinct primes in its factorization, use the Chinese remainder theorem _in reverse_ to replace it with two or more equations whose moduli are powers of primes and whose set of solutions is identical.

For example, consider:

$$
x \equiv a \mod p_1^{\alpha_1}p_2^{\alpha_2} \dots p_n^{\alpha_n}
$$

By definition, the congruence above says that $p_1^{\alpha_1}p_2^{\alpha_2} \dots p_n^{\alpha_n}$ divides $x - a$. Every divisor of the modulus then divides $x - a$ as well, so if the above equation holds, then the following also all hold:

$$
f(x) = \left\{
    \begin{array}{ll}
        x \equiv a \mod p_1^{\alpha_1} \\
        x \equiv a \mod p_2^{\alpha_2} \\
        \dots \\
        x \equiv a \mod p_n^{\alpha_n} \\
    \end{array}
\right.
$$

Additionally, by the Chinese remainder theorem, the intersection of all those equations has the same solutions as the original equation (powers of distinct primes are trivially coprime). This means that we can effectively _replace_ the former equation with the latter system, yielding a system of equations with the same set of solutions.

Applying this reasoning to all equations, we are left with a system of equations where all moduli are powers of primes.

In [5]:
def primes():
    ps = defaultdict(list)
    for i in count(2):
        if i in ps:
            for n in ps[i]:
                ps[i + (n if n == 2 else 2*n)].append(n)
            del ps[i]
        else:
            yield i
            ps[i**2].append(i)

            
def factorize(n):
    fs = Counter()
    gen = primes()
    while n > 1:
        p = next(gen)
        while n % p == 0:
            n //= p
            fs[p] += 1

    return fs


def split_equation(a, m):
    for prime, power in factorize(m).items():
        yield a % (prime ** power), prime, power
        

system_split = {
    x
    for a, m in system
    for x in split_equation(a, m)
}

system_split

{(0, 2, 1),
 (0, 2, 2),
 (1, 3, 1),
 (1, 5, 1),
 (1, 11, 1),
 (4, 7, 1),
 (34, 11, 2)}

representing:
    
$$
f(x) = \left\{
    \begin{array}{ll}
        x \equiv 0 & \mod 2^1 \\
        x \equiv 0 & \mod 2^2 \\
        x \equiv 1 & \mod 3^1 \\
        x \equiv 1 & \mod 5^1 \\
        x \equiv 1 & \mod 11^1 \\
        x \equiv 4 & \mod 7^1 \\
        x \equiv 34 & \mod 11^2
    \end{array}
\right.
$$  
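
The claim that splitting preserves the solution set is easy to spot-check numerically. The sketch below compares one equation of the example, $x \equiv 4 \mod 12$, with its two prime-power pieces over a finite range; the helper `solutions` is a hypothetical name and the check is an illustration, not a proof.

```python
def solutions(system, bound):
    """All 0 <= x < bound satisfying every congruence (a, m) in the system."""
    return {x for x in range(bound) if all((x - a) % m == 0 for a, m in system)}


original = [(4, 12)]
split = [(0, 4), (1, 3)]   # 4 % 4 == 0 and 4 % 3 == 1

print(sorted(solutions(original, 40)))                       # [4, 16, 28]
print(solutions(original, 120) == solutions(split, 120))     # True
```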
    
#### Step 2: Merge

Now, consider the equations two at a time. Without loss of generality, take:

$$
f(x) = \left\{
    \begin{array}{ll}
        x \equiv a_1 \mod p_1^{\alpha_1} \\
        x \equiv a_2 \mod p_2^{\alpha_2} \\
    \end{array}
\right.
$$

There are two cases. If $p_1 \neq p_2$, then they are coprime, and the two equations may be replaced with their solution modulo $p_1^{\alpha_1}p_2^{\alpha_2}$.

In [6]:
def solve_crt(a1, m1, a2, m2):
    k = (a2 - a1) * pow(m1, -1, m2)
    return m1 * k + a1

The function `solve_crt` solves the following system:

$$
f(x) = \left\{
    \begin{array}{ll}
        x \equiv a_1 & \mod m_1 \\
        x \equiv a_2 & \mod m_2 \\
    \end{array}
\right.
$$

where $m_1$, $m_2$ are coprime.

If we rewrite the first equation as $x = m_1 k + a_1$, the second becomes $m_1 k + a_1 \equiv a_2 \mod m_2$. Subtracting $a_1$ from both sides and multiplying both sides by $m_1^{-1}$ yields the answer, which exists and is unique.
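
Here is the same derivation carried out on concrete numbers, $x \equiv 2 \mod 3$ and $x \equiv 3 \mod 5$. The function name `crt_pair` is made up for this sketch; it also reduces $k$ modulo $m_2$ so that the result is the smallest non-negative solution. The three-argument `pow` with exponent `-1` needs Python 3.8 or later.

```python
def crt_pair(a1, m1, a2, m2):
    """Solve x = a1 (mod m1), x = a2 (mod m2) for coprime m1, m2."""
    inv = pow(m1, -1, m2)            # modular inverse of m1 modulo m2
    k = (a2 - a1) * inv % m2         # from m1 * k + a1 = a2 (mod m2)
    return m1 * k + a1, m1 * m2


# inv = 2 because 3 * 2 = 6 = 1 (mod 5); k = (3 - 2) * 2 % 5 = 2; x = 3 * 2 + 2 = 8
x, M = crt_pair(2, 3, 3, 5)
print(x, M)          # 8 15
print(x % 3, x % 5)  # 2 3
```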

If instead $p_1 = p_2$, there is either a contradiction or one of the two equations is superfluous and may be removed.

In [7]:
class Contradiction(Exception):
    ...
    

def handle_same_prime(a1, a2, p, alpha1, alpha2):
    if alpha1 == alpha2:
        if a1 == a2:
            return (a1, p, alpha1)
        else:
            raise Contradiction()
    elif alpha1 > alpha2:
        if (a1 % p ** alpha2) == (a2):
            return (a1, p, alpha1)
        else:
            raise Contradiction()
    else:
        return handle_same_prime(a2, a1, p, alpha2, alpha1)
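
Two small cases illustrate the two outcomes. The sketch uses an illustrative variant of the function above, with a hypothetical name and signature, that returns `None` instead of raising; it is not meant to replace `handle_same_prime`.

```python
def same_prime_merge(a1, alpha1, a2, alpha2, p):
    """Merge x = a1 (mod p**alpha1) with x = a2 (mod p**alpha2); None if contradictory."""
    if alpha1 < alpha2:
        # Put the equation with the larger power first.
        a1, alpha1, a2, alpha2 = a2, alpha2, a1, alpha1
    if (a1 - a2) % p ** alpha2 == 0:
        # The stronger equation already implies the weaker one.
        return a1 % p ** alpha1, alpha1
    return None


# x = 3 (mod 4) implies x = 1 (mod 2): the weaker equation is superfluous.
print(same_prime_merge(1, 1, 3, 2, 2))  # (3, 2)

# x = 3 (mod 4) forces x to be odd, so x = 0 (mod 2) is a contradiction.
print(same_prime_merge(0, 1, 3, 2, 2))  # None
```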

Putting it all together:

In [8]:
def solve_system(system):
    system_split = {
        x
        for a, m in system
        for x in split_equation(a, m)
    }
    
    eqn = defaultdict(list)
    for a, p, alpha in system_split:
        eqn[p].append((a, p, alpha))
        
    reduced = []
    try:
        for l in eqn.values():
            acc = l[0]
            for a, p, alpha in l[1:]:
                acc = handle_same_prime(a, acc[0], p, alpha, acc[2])
            reduced.append((acc[0], acc[1] ** acc[2]))
    except Contradiction:
        return []
    
    acc = reduced[0]
    for a, m in reduced[1:]:
        acc = solve_crt(acc[0], acc[1], a, m) % (m * acc[1]), m * acc[1]
    
    return acc


solve_system(system)

(41416, 50820)

This represents the unique solution of the original system, $x \equiv 41416 \mod 50820$.
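
As an independent cross-check, the same result can be obtained without factoring anything, by merging the equations pairwise with the gcd-based form of the Chinese remainder theorem: two congruences are compatible exactly when $\gcd(m_1, m_2)$ divides $a_2 - a_1$, and the merged modulus is their least common multiple. This is a different technique from the split-and-merge method described above, and `merge_pair` is a name made up for this sketch.

```python
from functools import reduce
from math import gcd


def merge_pair(eq1, eq2):
    """Merge two congruences (a, m); return None if they contradict each other."""
    if eq1 is None or eq2 is None:
        return None
    (a1, m1), (a2, m2) = eq1, eq2
    g = gcd(m1, m2)
    if (a2 - a1) % g:
        return None
    n1, n2 = m1 // g, m2 // g      # n1 and n2 are coprime
    # Solve n1 * k = (a2 - a1) / g (mod n2); any k works when n2 == 1.
    k = 0 if n2 == 1 else (a2 - a1) // g * pow(n1, -1, n2) % n2
    M = m1 * n2                    # lcm(m1, m2)
    return (a1 + m1 * k) % M, M


print(reduce(merge_pair, [(4, 12), (4, 14), (1, 55), (34, 121)]))  # (41416, 50820)
print(reduce(merge_pair, [(0, 2), (1, 4)]))                        # None
```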

### Back to the original problem

We now have a way to solve systems of linear congruences:

In [9]:
solve_system

<function __main__.solve_system(system)>

...and a representation of accepted points for each initial state:

In [10]:
accepted

{'LJA': AcceptedPoints(anticycle_accepted=[], anticycle_length=281, cycle_accepted=[21918], cycle_length=22199),
 'KTA': AcceptedPoints(anticycle_accepted=[], anticycle_length=281, cycle_accepted=[14612], cycle_length=14893),
 'NFA': AcceptedPoints(anticycle_accepted=[], anticycle_length=281, cycle_accepted=[11802], cycle_length=12083),
 'JXA': AcceptedPoints(anticycle_accepted=[], anticycle_length=281, cycle_accepted=[16860], cycle_length=17141),
 'AAA': AcceptedPoints(anticycle_accepted=[], anticycle_length=281, cycle_accepted=[16298], cycle_length=16579),
 'PLA': AcceptedPoints(anticycle_accepted=[], anticycle_length=281, cycle_accepted=[19670], cycle_length=19951)}

How do we find all intersections?

They may happen at two places: before the anticycle of some starting state is over, or when all initial states are cycling. The first may be found by merge-joining the lists as iterators, the second by solving a system of linear congruences with the above method.

Let's define a more interesting example:

In [11]:
accepted_bonus = {
    'AAA': AcceptedPoints([1, 4, 11], 32, [30, 31], 40),
    'BBB': AcceptedPoints([1, 2, 4, 11, 20], 32, [30, 31], 40),
    'CCC': AcceptedPoints([3, 4, 7, 11, 15], 58, [4], 40)
}

accepted_bonus

{'AAA': AcceptedPoints(anticycle_accepted=[1, 4, 11], anticycle_length=32, cycle_accepted=[30, 31], cycle_length=40),
 'BBB': AcceptedPoints(anticycle_accepted=[1, 2, 4, 11, 20], anticycle_length=32, cycle_accepted=[30, 31], cycle_length=40),
 'CCC': AcceptedPoints(anticycle_accepted=[3, 4, 7, 11, 15], anticycle_length=58, cycle_accepted=[4], cycle_length=40)}
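
Before writing the general routine, a brute-force enumeration up to a fixed bound gives a reference answer for this example. The sketch is self-contained and uses plain tuples instead of `AcceptedPoints`; the helper `points_up_to` and the bound of 150 are arbitrary choices made for illustration.

```python
from itertools import count


def points_up_to(anti_accepted, anti_length, cycle_accepted, cycle_length, bound):
    """All accepted step counts <= bound for one starting state."""
    pts = {x for x in anti_accepted if x <= bound}
    for k in count():
        base = anti_length + k * cycle_length
        if base > bound:
            break
        pts |= {base + x for x in cycle_accepted if base + x <= bound}
    return pts


# Same numbers as the example above, as plain tuples.
bonus = {
    'AAA': ([1, 4, 11], 32, [30, 31], 40),
    'BBB': ([1, 2, 4, 11, 20], 32, [30, 31], 40),
    'CCC': ([3, 4, 7, 11, 15], 58, [4], 40),
}

common = set.intersection(*(points_up_to(*spec, 150) for spec in bonus.values()))
print(sorted(common))  # [4, 11, 62, 102, 142]
```

So any general method should report the isolated points 4 and 11, followed by 62 and then every 40th step after it.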

In [12]:
@dataclass
class PointSolution:
    # Represents a single solution
    x: int
    
    __hash__ = lambda s: s.x
    
@dataclass
class PositiveLinearSpanSolution:
    # Represents a set of solutions
    # in the form a + kb for k >= 0
    a: int
    b: int
        
    __hash__ = lambda s: hash((s.a, s.b))


def all_intersections(accepted):
    intersections = set()
    limit = max(c.anticycle_length for c in accepted.values())
    
    gen = {
        s: [0, c.gen()]
        for s, c in accepted.items()
    }
    
    while True:
        # Current value of each iterator; 0 means "not advanced yet".
        s = list(c[0] for c in gen.values())
        if any(s) and all(x == s[0] for x in s):
            intersections.add(PointSolution(s[0]))

        lo = min(gen, key=lambda u: gen[u][0])
        val = next(gen[lo][1])
        if val <= limit:
            gen[lo][0] = val
        else:
            break

    
    # Collect, for each starting state, the data describing its cycle.
    anti_length = []
    cycle_accepted = []
    cycle_length = []
    for v in accepted.values():
        anti_length.append(v.anticycle_length)
        cycle_accepted.append(v.cycle_accepted)
        cycle_length.append(v.cycle_length)
        
    for its in product(*cycle_accepted):
        eqn = [
            (-(its[i] + anti_length[i]) % cycle_length[i], cycle_length[i])
            for i in range(len(its))
        ]
                            
        if t := solve_system(eqn):
            res_mod = -t[0] % t[1]
            if res_mod <= limit:
                res_mod += t[1] * (1 + (limit - res_mod) // t[1])
            intersections.add(PositiveLinearSpanSolution(
                res_mod, t[1]
            ))
            
    return intersections

    
all_intersections(accepted_bonus)

{PointSolution(x=11),
 PointSolution(x=4),
 PositiveLinearSpanSolution(a=62, b=40)}

With, naturally, the same solution on the original input:

In [13]:
all_intersections(accepted)

{PositiveLinearSpanSolution(a=12927600769609, b=12927600769609)}

These observations and pieces of code should, together, be a complete solution of the fully general case of the problem: the answer to the puzzle is the smallest element among the returned solutions.