In [None]:
import os
import sys

sys.path.insert(0, os.path.abspath("../utils"))
from aoc_utils import load_data, check

In [None]:
from collections import defaultdict, deque
from itertools import count
import math

In [None]:
# Puzzle input (presumably Advent of Code 2023, day 20) fetched through the shared aoc_utils loader.
data = load_data(2023, 20)

In [None]:
# Worked examples as (puzzle input, expected part 1 answer, expected part 2 answer).
# None marks a part an example does not provide an answer for (presumably skipped by `check`).
# The part-2 examples all wire their last module into "rx", the module whose first
# low pulse part 2 searches for.
tests = [
    (
        """broadcaster -> a, b, c
%a -> b
%b -> c
%c -> inv
&inv -> a""",
        32000000,
        None,
    ),
    (
        """broadcaster -> a
%a -> inv, con
&inv -> b
%b -> con
&con -> output""",
        11687500,
        None,
    ),
    (
        """broadcaster -> a, d
%a -> b
%b -> c
%c -> f
%d -> e
%e -> f
&f -> rx""",
        None,
        6,
    ),
    (
        """broadcaster -> a, d
%a -> b
%b -> c
%c -> f
%d -> e
%e -> f
&f -> g
%g -> rx""",
        None,
        14,
    ),
    (
        """broadcaster -> a, d
%a -> b
%b -> c
%c -> f
%d -> e
%e -> f
&f -> g
%g -> h
&h -> rx""",
        None,
        6,
    ),
    (
        """broadcaster -> a
%a -> b
%b -> c, d, f
&c -> a, f
%d -> e, f
%e -> f, h
&f -> g
%g -> g, h
&h -> rx""",
        None,
        11,
    ),
    (
        """broadcaster -> a
%a -> b, c
&b -> a, d, e
%c -> b, e
&d -> f
%e -> g
&f -> rx
%g -> b, h
%h -> b, i
%i -> j, b
%j -> k, b
%k -> l, b
%l -> b, m
%m -> n, b
%n -> o, b
%o -> b""",
        None,
        4091,
    ),
]

# Part 1

In [None]:
def parse(data):
    """Parse the module wiring into two lookup tables.

    Each line has the form ``<name> -> <t1>, <t2>, ...`` where ``<name>`` is
    either ``broadcaster`` or a one-character type prefix (``%`` or ``&``)
    followed by the module label.

    Returns
    -------
    operations : dict
        label -> (op, targets); op is "broadcast", "%" or "&", targets is the
        list of downstream labels in file order.
    memory : dict
        label of every "&" module -> {input label: "low"} for each module that
        lists it as a target.
    """
    operations = {}
    memory = {}
    for line in data.splitlines():
        name, outputs = line.split(" -> ")
        if name == "broadcaster":
            label, op = name, "broadcast"
        else:
            op, label = name[0], name[1:]
            if op == "&":
                memory[label] = {}
        operations[label] = (op, outputs.split(", "))

    # Second pass: now that every "&" module is known, register each of its
    # inputs with an initial remembered pulse of "low".
    for source, (_, targets) in operations.items():
        for target in targets:
            if target in memory:
                memory[target][source] = "low"
    return operations, memory

In [None]:
def button_mashing(operations, memory, target_module=None, times=range(1000)):
    """Simulate button presses on the module network.

    Parameters
    ----------
    operations : dict
        label -> (op, targets), as returned by ``parse``.
    memory : dict
        "&" label -> {input label: last remembered pulse}. Used only as the
        initial state: the simulation runs on a private copy, so the caller's
        dict (and any inner dicts shared with other subproblems) is unchanged.
    target_module : str, optional
        If truthy, stop as soon as this module receives a low pulse and return
        ``presses + 1`` for the current value drawn from ``times``.
    times : iterable of int
        Press indices to simulate, e.g. ``range(1000)`` for part 1 or
        ``itertools.count()`` to keep pressing until ``target_module`` is hit.

    Returns
    -------
    int
        ``presses + 1`` when ``target_module`` receives a low pulse; otherwise
        (no target, or ``times`` exhausted first) the product of the number of
        high and low pulses sent.
    """
    # Copy the conjunction memory; mutating the caller's dicts would make a
    # second call on the same parsed input start from stale remembered pulses.
    memory = {label: dict(inputs) for label, inputs in memory.items()}
    signals = {"low": 0, "high": 0}
    # Flip-flop ("%") states; every module starts out "low".
    states = defaultdict(lambda: "low")
    for presses in times:
        # One press = one low pulse from the button into the broadcaster;
        # pulses are then handled in FIFO order.
        pulses = deque([("broadcaster", "low", "button")])
        while pulses:
            label, signal, source = pulses.popleft()
            if target_module and signal == "low" and label == target_module:
                return presses + 1
            signals[signal] += 1
            if label not in operations:
                # Module with no outgoing wiring: the pulse is counted, nothing more.
                continue
            op, targets = operations[label]
            if op == "broadcast":
                for t in targets:
                    pulses.append((t, signal, label))
            elif op == "%":
                # High pulses are ignored; a low pulse toggles the state and
                # sends the new state downstream.
                if signal == "low":
                    states[label] = "low" if states[label] == "high" else "high"
                    for target in targets:
                        pulses.append((target, states[label], label))
            elif op == "&":
                # Remember the latest pulse per input; emit low only when
                # every remembered input is high.
                memory[label][source] = signal
                if all(v == "high" for v in memory[label].values()):
                    emit = "low"
                else:
                    emit = "high"
                for target in targets:
                    pulses.append((target, emit, label))
    return signals["high"] * signals["low"]

In [None]:
def count_pulses(data):
    """Part 1: product of high and low pulse counts after 1000 button presses."""
    return button_mashing(*parse(data))

In [None]:
# Run the examples through aoc_utils.check before solving the actual puzzle input.
check(count_pulses, tests)
count_pulses(data)

# Part 2

In [None]:
def simplify(operations, memory, target_module="rx"):
    """Identify easier subproblems to solve.

    If:
    1. rx <- &a <- (&b, &c, ... &z), i.e. ``target_module`` has exactly one
       input, a "&" module, whose inputs are all "&" modules, and
    2. the subgraphs feeding (&b, &c, ..., &z) are independent (their sets of
       ancestor modules share nothing except the broadcaster), and
    3. (&b, &c, ... &z) output high pulses periodically (which means they
       receive a low pulse at that time).

    Then &a will output a low pulse at the LCM of all periods.
    This is very specific. Hypotheses 1 and 2 are checked; 3 is not.

    Returns
    -------
    list of dict
        Keyword arguments for ``button_mashing``: one dict per subgoal
        (restricted ``operations``/``memory`` and ``target_module=subgoal``)
        when hypotheses 1 and 2 hold, otherwise a single dict describing the
        original, unsplit problem.
    """
    sorry_but_no = [{"operations": operations, "memory": memory, "target_module": target_module}]

    # Reverse adjacency: module -> set of modules that send pulses to it.
    direct_dependencies = defaultdict(set)
    for source, (_, targets) in operations.items():
        for target in targets:
            direct_dependencies[target].add(source)

    def _ancestors(cell):
        """Return ``cell`` together with every module that can (transitively) pulse it."""
        seen = {cell}
        stack = [cell]
        while stack:
            for source in direct_dependencies[stack.pop()]:
                if source not in seen:
                    seen.add(source)
                    stack.append(source)
        return seen

    # Hypothesis #1
    top_inputs = direct_dependencies[target_module]
    if len(top_inputs) != 1:
        return sorry_but_no
    (top_cell,) = top_inputs
    if operations[top_cell][0] != "&":
        return sorry_but_no

    # Hypothesis #2
    subgraphs = []
    subproblems = []
    for subgoal in direct_dependencies[top_cell]:
        if operations[subgoal][0] != "&":
            return sorry_but_no
        subcells = _ancestors(subgoal)
        # Any shared module other than the broadcaster couples the subgraphs,
        # so their periods could not be computed separately.
        if any((subcells & other) - {"broadcaster"} for other in subgraphs):
            return sorry_but_no
        subgraphs.append(subcells)
        subproblems.append({
            "operations": {cell: op for cell, op in operations.items() if cell in subcells},
            "memory": {cell: inputs for cell, inputs in memory.items() if cell in subcells},
            "target_module": subgoal,
        })
    if not subproblems:
        # Nothing feeds top_cell: an LCM over zero periods would be meaningless.
        return sorry_but_no
    # Hypothesis #3
    # Unchecked...
    return subproblems

In [None]:
def maybe_solve(data):
    """Part 2: fewest button presses before "rx" receives a low pulse.

    ``simplify`` either splits the network into subproblems whose press counts
    are combined with an LCM, or returns the whole problem as a single
    subproblem (the LCM of one count is that count). Each subproblem is pressed
    via ``itertools.count()``, so this never returns if a target module never
    receives a low pulse.
    """
    operations, memory = parse(data)
    periods = [
        button_mashing(**subproblem, times=count())
        for subproblem in simplify(operations, memory)
    ]
    return math.lcm(*periods)

In [None]:
# Part 2 run: the extra argument 2 presumably tells `check` to compare against the
# third tuple field (expected part 2 answer) — TODO confirm against aoc_utils.
check(maybe_solve, tests, 2)
maybe_solve(data)