In [None]:
import os
import sys

sys.path.insert(0, os.path.abspath("../utils"))
from aoc_utils import load_data, check

In [None]:
import re

In [None]:
# Raw puzzle input text for this day, loaded via aoc_utils.load_data
# (its arguments are presumably (year, day) — confirm in ../utils/aoc_utils).
data = load_data(2024, 17)

In [None]:
# Worked examples, each a tuple of
#   (puzzle input text, expected part-1 output string, expected part-2 value).
# None marks an answer the example does not provide; presumably `check`
# skips those entries — confirm in aoc_utils.
# The programs "1,7" and "4,0" contain no output instruction, so their
# expected part-1 output is the empty string.
tests = [
    (
        """Register A: 729
Register B: 0
Register C: 0

Program: 0,1,5,4,3,0
""",
        "4,6,3,5,6,3,5,2,1,0",
        None,
    ),
    (
        """Register A: 10
Register B: 0
Register C: 0

Program: 5,0,5,1,5,4
""",
        "0,1,2",
        None,
    ),
    (
        """Register A: 2024
Register B: 0
Register C: 0

Program: 0,1,5,4,3,0
""",
        "4,2,5,6,7,7,7,7,3,1,0",
        None,
    ),
    (
        """Register A: 0
Register B: 29
Register C: 0

Program: 1,7
""",
        "",
        None,
    ),
    (
        """Register A: 0
Register B: 2024
Register C: 43690

Program: 4,0
""",
        "",
        None,
    ),
    (
        """Register A: 2024
Register B: 0
Register C: 0

Program: 0,3,5,4,3,0
""",
        None,
        117440,
    ),
]

# Part 1

In [None]:
def parse_program(data):
    """Split the puzzle input into initial registers and the program.

    The input has a register section and a program section separated by a
    blank line.

    Returns:
        (registers, operations) where ``registers`` maps 4, 5, 6 to the
        initial values of registers A, B, C (these keys match the combo-operand
        codes, so ``execute`` can resolve operands with a dict lookup), and
        ``operations`` is the flat list of program integers.
    """
    register_text, program_text = data.split("\n\n")
    number = re.compile(r"(-?\d+)")
    initial_values = map(int, number.findall(register_text))
    registers = {code: value for code, value in zip(range(4, 7), initial_values)}
    operations = list(map(int, number.findall(program_text)))
    return registers, operations

In [None]:
def execute(registers, program):
    pt = 0
    div_targets = {0: 4, 6: 5, 7: 6}
    while pt < len(program):
        match program[pt:pt + 2]:
            case [op, reg] if op in div_targets:
                denominator = registers.get(reg, reg)
                registers[div_targets[op]] = registers[4] >> denominator
            case [1, lit]:
                registers[5] ^= lit
            case [2, reg]:
                registers[5] = registers.get(reg, reg) % 8
            case [3, lit]:
                if registers[4]:
                    pt = lit - 2
            case [4, _]:
                registers[5] ^= registers[6]
            case [5, reg]:
                yield registers.get(reg, reg) % 8
            case _:
                raise AssertionError
        pt += 2

In [None]:
def parse_and_execture(data):
    """Parse the puzzle input, run its program, and return the output.

    The output values are joined with commas, e.g. "4,6,3,5,6,3,5,2,1,0";
    a program that never outputs yields the empty string.
    """
    initial_registers, operations = parse_program(data)
    rendered = map(str, execute(initial_registers, operations))
    return ",".join(rendered)

In [None]:
# Run the worked examples through `check`, then solve the puzzle input
# (the last expression is displayed as the cell output).
check(parse_and_execture, tests)
parse_and_execture(data)

# Part 2

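This solution assumes that the program has the following properties, where an *iteration* is one pass through the program before it jumps back to the start:

- register A is right-shifted by 3 bits during each iteration (it loses one octal digit)
- exactly one value is output during each iteration
- counting from the last output backwards, the $i$-th output depends only on the $3i$ most significant bits of A (its $i$ leading octal digits)

Under these assumptions A can be built one octal digit at a time, most significant first, by matching the program's values from the end.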

In [None]:
def matches(registers, program, a_values):
    """Count how many trailing outputs already agree with the program.

    ``a_values`` are the octal digits of register A, most significant first.
    The program is run on a copy of ``registers`` with A set to that value.
    Its outputs are then compared with ``program`` from the last element
    backwards, stopping at the first mismatch.

    Returns:
        The number of matching trailing values. It is 0 when the run emits
        a different number of values than the program has: comparing from
        the end would then pair outputs with the wrong program positions.
        Under the notebook's stated assumptions this happens when the
        leading digit is 0, because A is shorter and fewer values are output.
    """
    a = sum(v * 8**i for i, v in enumerate(a_values[::-1]))
    reg = registers.copy()
    reg[4] = a
    outputs = list(execute(reg, program))
    if len(outputs) != len(program):
        # Previously zip() aligned the sequences from the start, which gave
        # spurious counts for short runs and false full matches for long ones.
        return 0
    cnt = 0
    for v1, v2 in zip(reversed(outputs), reversed(program)):
        if v1 != v2:
            break
        cnt += 1
    return cnt

In [None]:
def fix_input(data):
    """Search for a register-A value that makes the program output itself.

    A is built one octal digit at a time, most significant first. The search
    relies on the assumptions that A loses 3 bits per iteration, that each
    iteration emits exactly one value, and that the i-th output from the end
    depends only on A's i leading octal digits. ``matches`` reports how many
    trailing outputs already agree; the digit right after that prefix is
    advanced. A digit that passes 7 is reset and carries into the digit to
    its left, like an odometer. Digits are tried in increasing order, so the
    first hit is intended to be the smallest such A.

    Raises
    ------
    ValueError
        If the carry overflows the most significant digit, i.e. no A with
        ``len(program)`` octal digits reproduces the program.
    """
    registers, program = parse_program(data)
    a = [0] * len(program)
    while (m := matches(registers, program, a)) < len(program):
        a[m] += 1
        # Carry overflowing digits towards the most significant position.
        while a[m] > 7:
            if m == 0:
                # Without this guard a[m - 1] would wrap around to a[-1] (the
                # least significant digit) and the search would keep running.
                raise ValueError("no value of register A reproduces the program")
            a[m] = 0
            a[m - 1] += 1
            m -= 1
    return sum(v * 8**i for i, v in enumerate(a[::-1]))

In [None]:
# Run the worked examples through `check` (the 2 presumably selects the
# part-2 expected value in each `tests` tuple — confirm in aoc_utils), then
# solve the puzzle input.
check(fix_input, tests, 2)
fix_input(data)