In [None]:
import re
from pathlib import Path

In [None]:
# Read the input data: "Register X: <int>" lines followed by a "Program: a,b,c,..." line.
data = Path("day17_input.txt").read_text().strip()
registers = {
    key: int(value) for key, value in re.findall(r"Register (\w+): (\d+)", data)
}
# Fail here with a clear message rather than with a KeyError/NameError in a later cell.
if missing_registers := {"A", "B", "C"} - registers.keys():
    raise ValueError(f"Registers {sorted(missing_registers)} not found in day17_input.txt")
if not (match := re.search(r"Program: ([\w,]+)", data)):
    raise ValueError("No 'Program: ...' line found in day17_input.txt")
program = match.group(1)

In [None]:
class Computer:
    """The computer that runs the program.

    The program is a list of (opcode, operand) pairs operating on registers A, B and C.
    Values emitted by `out` instructions are collected and exposed via `output`.
    """

    # Opcodes whose operand is interpreted as a combo operand (adv, bst, out, bdv, cdv).
    # The remaining opcodes (bxl, jnz, bxc) use the literal operand or ignore it.
    _COMBO_OPCODES = frozenset({0, 2, 5, 6, 7})

    def __init__(self, registers: dict[str, int], program: str):
        """Initialize the computer.

        Args:
            registers: Initial register values keyed by name ("A", "B", "C").
                The dict is copied, so running the program never mutates it.
            program: Comma-separated integers, e.g. "0,1,5,4,3,0".
        """
        self.registers = registers.copy()
        self.program = [int(num) for num in program.split(",")]
        self.ptr = 0  # instruction pointer: index of the current opcode in self.program
        self._output: list[int] = []

    @property
    def output(self) -> str:
        """Return the values emitted by `out` instructions, joined by commas."""
        return ",".join(str(val) for val in self._output)

    def combo_operand(self, operand: int) -> int:
        """Calculate the combo operand.

        Operands 0-3 are literal values; 4, 5 and 6 read registers A, B and C.

        Raises:
            ValueError: If the operand is not a valid combo operand (e.g. the
                reserved value 7).
        """
        operand2register = {
            4: "A",
            5: "B",
            6: "C",
        }
        if 0 <= operand <= 3:
            return operand
        if operand in operand2register:
            return self.registers[operand2register[operand]]
        raise ValueError(f"Invalid combo operand: {operand}")

    def run(self) -> None:
        """Run the program until the instruction pointer leaves the program.

        The computer also halts if an opcode has no operand after it.

        Raises:
            ValueError: On an unknown opcode or an invalid combo operand.
        """
        # Stop before ptr + 1 runs past the end: an opcode with no operand halts.
        while 0 <= self.ptr < len(self.program) - 1:
            opcode = self.program[self.ptr]
            literal_operand = self.program[self.ptr + 1]
            # Only evaluate the combo operand for instructions that use it: a
            # literal operand (e.g. `bxl 7`) may be 7, which is invalid as a combo operand.
            combo_operand = (
                self.combo_operand(literal_operand)
                if opcode in self._COMBO_OPCODES
                else None
            )

            advance_pointer = True
            if opcode == 0:
                # adv
                self.registers["A"] = self.registers["A"] // 2**combo_operand
            elif opcode == 1:
                # bxl
                self.registers["B"] = self.registers["B"] ^ literal_operand
            elif opcode == 2:
                # bst
                self.registers["B"] = combo_operand % 8
            elif opcode == 3:
                # jnz
                if self.registers["A"] != 0:
                    self.ptr = literal_operand
                    advance_pointer = False
            elif opcode == 4:
                # bxc (the operand is read but ignored)
                self.registers["B"] = self.registers["B"] ^ self.registers["C"]
            elif opcode == 5:
                # out
                self._output.append(combo_operand % 8)
            elif opcode == 6:
                # bdv
                self.registers["B"] = self.registers["A"] // 2**combo_operand
            elif opcode == 7:
                # cdv
                self.registers["C"] = self.registers["A"] // 2**combo_operand
            else:
                raise ValueError(f"Invalid opcode: {opcode}")

            if advance_pointer:
                self.ptr += 2

# Part 1


In [None]:
# Run the program once on the initial registers; the output string is the answer.
part1_computer = Computer(registers, program)
part1_computer.run()
part1_computer.output

# Part 2

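Each group of 3 bits (one octal digit) of A produces one number in the output: the last
(lowest) 3 bits mainly determine the first number output, the next 3 bits (from the right)
mainly determine the second number output, and so on. Higher bits can also influence a
number, which is why the search below tries two octal digits at a time.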

Thus, starting from the left (most significant) end of A and the right end of the desired
output, we can search for one octal digit of A at a time. After each new digit, we check
that the output matches the end of the program.


In [None]:
# For the example, inspect how the output is generated: print every A below 8**6
# whose output matches the end of the program.
# For the example, the solution is oct(117440) = 0o345300
for a in range(0, 8**6):
    # Pass a fresh register dict instead of mutating the shared `registers`, so
    # re-running the Part 1 cell afterwards still uses the original value of A.
    computer = Computer({**registers, "A": a}, program)
    computer.run()
    if program.endswith(computer.output):
        print(f"{oct(a)}: {computer.output}")

In [None]:
print("Desired output:", program)

# Octal digits of A found so far, most significant first. Each round tries every
# 2-digit extension and keeps only its first digit; the second digit is re-searched
# next round with one more digit of lookahead.
octals = []
while True:
    # Value of the digits found so far, shifted left to make room for two new digits
    found_already = sum(
        octal * 8**i for i, octal in enumerate(reversed(octals), start=2)
    )
    # Loop over all 2-digit octal numbers
    for a in range(8**2):
        # Fresh register dict: don't mutate the shared `registers` used by Part 1
        computer = Computer({**registers, "A": found_already + a}, program)
        computer.run()
        if (
            program.endswith(computer.output)
            and len(computer.output.split(",")) == len(octals) + 2
        ):
            # The last len(octals) + 2 output numbers match the end of the program
            break
    else:
        # Without this, `a` would silently stay at 63, a wrong digit would be appended,
        # and the outer `while True` loop could run forever.
        raise RuntimeError(
            f"No 2-digit extension of octals {octals} reproduces the end of the "
            "program; this greedy search may need backtracking for this input."
        )

    octal = a // 8  # Get the first octal of the two-digit number
    octals.append(octal)
    print(f"Found octals {octals}, output is now: {computer.output}")

    if computer.output == program:
        print("DONE!")
        solution = found_already + a
        print(f"A: {solution} (in oct: {oct(solution)})")
        break