In [1]:
# Small example: register values and a program, in the text format read by parse_input.
input_small = """Register A: 729
Register B: 0
Register C: 0

Program: 0,1,5,4,3,0"""

# Full input; it is parsed further down and passed to find_min_a for part B.
input_large = """Register A: 64196994
Register B: 0
Register C: 0

Program: 2,4,1,1,7,5,1,5,4,0,0,3,5,5,3,0"""


def parse_input(input_str: str):
    """Parse ``Register X: n`` lines and a ``Program: a,b,...`` line.

    The register block and the program line are separated by a blank line.
    Windows line endings and surrounding whitespace are tolerated.

    Returns:
        ``(registers, program)``. ``registers`` maps each register's last
        character (e.g. ``"A"``) to its int value. ``program`` is a list of ints.

    Raises:
        ValueError: if the blank-line separator is missing or a value is not an int.
    """
    # Normalise CRLF so the blank-line separator is always "\n\n".
    text = input_str.replace("\r\n", "\n").strip()
    register_block, program_block = text.split("\n\n", 1)

    registers = {}
    for line in register_block.splitlines():
        line = line.strip()
        if not line:
            continue
        name, value = line.split(":", 1)
        # "Register A" -> "A"
        registers[name.strip()[-1]] = int(value)

    program = [int(x) for x in program_block.split(":", 1)[1].split(",")]
    return registers, program


# Sanity check: parse the small example; the cell displays the parsed result.
parse_input(input_str=input_small)

({'A': 729, 'B': 0, 'C': 0}, [0, 1, 5, 4, 3, 0])

In [72]:
class Interpreter:
    """Interpreter for a 3-bit machine with registers A, B and C.

    ``program`` is a flat sequence of ints read as ``opcode, operand`` pairs.
    Execution starts at ``pointer = 0`` and stops once the pointer can no
    longer read a complete pair.
    """

    def __init__(self, registers, program):
        self.program = program
        # Copy so that running the machine never mutates the caller's dict.
        self.registers = registers.copy()
        self.pointer = 0
        self.outputs = []

    def _get_combo(self, combo):
        """Resolve a combo operand: 0-3 are literal, 4/5/6 read registers A/B/C.

        Raises:
            ValueError: for operands greater than 6.
        """
        if combo < 4:
            return combo
        elif combo == 4:
            return self.registers["A"]
        elif combo == 5:
            return self.registers["B"]
        elif combo == 6:
            return self.registers["C"]
        else:
            raise ValueError(f"Invalid combo operand {combo}")

    def _get_literal(self, literal):
        """Resolve a literal operand; its value is the operand itself."""
        return literal

    @staticmethod
    def _bitwise_xor(a: int, b: int):
        return a ^ b

    def _adv_op(self, operand):
        """Return ``A // 2**combo(operand)``; shared by adv, bdv and cdv."""
        return self.registers["A"] // (2 ** self._get_combo(operand))

    def _adv_op_reverse(self, operand):
        """Return ``A * 2**combo(operand)``, the reverse direction of adv."""
        return self.registers["A"] * (2 ** self._get_combo(operand))

    def _is_halted(self):
        """Return True when no complete ``opcode, operand`` pair remains.

        A lone trailing opcode has no operand to read. It is treated as the
        end of the program instead of failing to unpack the instruction.
        """
        return self.pointer + 1 >= len(self.program)

    def run_instruction(self, opcode, operand):
        """Execute one instruction and advance the instruction pointer.

        The pointer moves forward by 2, except when ``jnz`` (opcode 3) jumps.

        Raises:
            ValueError: for an unknown opcode or an invalid combo operand.
        """
        # adv: division of A by 2^combo. stores in A
        if opcode == 0:
            self.registers["A"] = self._adv_op(operand)
        # bxl: bitwise XOR of register B and the instruction's literal operand
        # stores in register B
        elif opcode == 1:
            self.registers["B"] = self._bitwise_xor(
                self.registers["B"], self._get_literal(operand)
            )
        # bst: combo operand modulo 8 (thereby keeping only its lowest 3 bits), B register.
        elif opcode == 2:
            self.registers["B"] = self._get_combo(operand) % 8
        # jnz: does nothing if the A register is 0.
        # else: it jumps by setting the instruction pointer to the value of its literal operand;
        # if this instruction jumps, the instruction pointer is not increased by 2 after this instruction.
        elif opcode == 3:
            if self.registers["A"] != 0:
                self.pointer = self._get_literal(operand)
                return None
        # 4, bxc: bitwise XOR of register B and register C, in register B
        elif opcode == 4:
            self.registers["B"] = self._bitwise_xor(
                self.registers["B"], self.registers["C"]
            )
        # 5, out: combo operand modulo 8, then outputs that value
        elif opcode == 5:
            self.outputs.append(self._get_combo(operand) % 8)
        # 6, bdv: adv instruction except that the result is stored in the B register
        elif opcode == 6:
            self.registers["B"] = self._adv_op(operand)
        # 7, cdv:  adv instruction except that the result is stored in the C register
        elif opcode == 7:
            self.registers["C"] = self._adv_op(operand)
        else:
            raise ValueError(f"Invalid opcode {opcode}")
        self.pointer += 2
        return None

    def _run_step(self):
        """Fetch the instruction at the current pointer and execute it."""
        opcode, operand = self.program[self.pointer : self.pointer + 2]
        self.run_instruction(opcode, operand)

    def check_self_program(self):
        """Return True if the program outputs exactly its own instructions.

        Execution stops early as soon as the outputs diverge from the program.
        Both sides are compared as lists, so a tuple ``program`` works too.
        """
        expected = list(self.program)
        while not self._is_halted():
            self._run_step()
            if self.outputs != expected[: len(self.outputs)]:
                return False
        return self.outputs == expected

    def run_program(self, verbose=False):
        """Run to completion and return the outputs as a comma-separated string.

        With ``verbose=True`` the pointer, registers and outputs are printed
        after every instruction.
        """
        if verbose:
            print(f"Initial Registers: {self.registers}")
        while not self._is_halted():
            og_pointer = self.pointer
            opcode, operand = self.program[og_pointer : og_pointer + 2]
            self.run_instruction(opcode, operand)
            if verbose:
                print(f"Pointer: {og_pointer}, Opcode: {opcode}, Operand: {operand}")
                print(f"Registers: {self.registers}")
                print(f"Outputs: {self.outputs}")
                print(f"New Pointer: {self.pointer}")

        return ",".join(str(x) for x in self.outputs)

In [99]:
expected_small_part2 = """Register A: 117440
Register B: 0
Register C: 0

Program: 0,3,5,4,3,0"""

registers, program = parse_input(expected_small_part2)

print(program)
interpreter = Interpreter(registers, program)
interpreter.run_program(verbose=False)
print(interpreter.outputs)
# The part-2 example must reproduce its own program; fail loudly instead of
# relying on a visual comparison of the two printed lists.
assert interpreter.outputs == program, (interpreter.outputs, program)

[0, 3, 5, 4, 3, 0]
[0, 3, 5, 4, 3, 0]


## Part B

The program consists of a loop body followed by `3,0`.
`3,0` (jnz 0) jumps back to the start, which creates the loop.

The first instructions set B and C as functions of A alone.
The program ends when it reaches `3,0` with A = 0, because jnz then does not jump.

Therefore, each iteration of the loop takes A alone as input and produces two results:
* the new output value
* the new value of A

We can focus on just this structure by defining a function with that signature.

We can solve it backwards. The final iteration must leave A = 0 and output 0, which is the last value of the program because it ends in `3,0`.

Also, each iteration updates $A \to \lfloor A / 8 \rfloor$. Going backwards, the previous value of A must therefore lie in the half-open range
$[8A,\ 8(A+1))$.

In [110]:
from collections import deque
from dataclasses import dataclass


class Interpreter2(Interpreter):
    """Interpreter that executes a single pass of the program's main loop.

    Assumes the program is a loop body followed by a trailing jnz (``3, x``)
    instruction. It also assumes the body's result depends on register A
    alone: ``loop_step`` only overwrites A, so B and C carry over between calls.
    """

    def loop_step(self, A: int):
        """Run the loop body once, starting with register A set to ``A``.

        Returns:
            ``(new_a, new_output)``: register A after the pass, and the last
            value output during the pass.

        Raises:
            ValueError: if the program does not end with a jnz (opcode 3)
                instruction, or if the loop body outputs nothing.
        """
        original_program = self.program
        if len(original_program) < 2 or original_program[-2] != 3:
            raise ValueError("Program must end with a jnz (3) instruction")
        # Strip the trailing jump so run_program executes exactly one pass.
        self.program = original_program[:-2]
        self.registers["A"] = A
        self.pointer = 0
        self.outputs = []
        try:
            self.run_program()
        finally:
            # Always restore the full program, even if the body raised.
            self.program = original_program
        if not self.outputs:
            raise ValueError("Loop body produced no output")
        return self.registers["A"], self.outputs[-1]


def check_a(program, a, expected_new_a, expected_new_output):
    """Report whether one loop pass from A=``a`` gives the expected (A, output).

    A fresh ``Interpreter2`` with all registers zeroed is built on each call.
    """
    machine = Interpreter2({"A": 0, "B": 0, "C": 0}, program)
    observed = machine.loop_step(a)
    return observed == (expected_new_a, expected_new_output)


@dataclass
class State:
    """A node of the backward search performed by ``find_min_a``."""

    # Value of register A at this point of the search. The root uses 0, the
    # value A must have after the final loop iteration.
    a: int
    # How many trailing program values have been matched so far; also the
    # index into the reversed program of the next value to match.
    step: int


def find_min_a(program, *, verbose=True):
    """Return the smallest initial A for which ``program`` outputs itself.

    The search works backwards from the end of the program. Starting from
    A = 0 after the last loop pass, each level tries the 8 candidates in
    ``[8*a, 8*(a+1))`` and keeps those whose single loop pass
    (``check_a``) yields the parent ``a`` and the matching program value.

    Args:
        program: the program as a sequence of ints.
        verbose: when True, print every complete solution as it is found.

    Raises:
        ValueError: if ``program`` is empty or no initial A reproduces it.
    """
    if not program:
        raise ValueError("program must not be empty")
    FACTOR = 2**3
    program_tuple = tuple(program)  # built once instead of on every check
    rev_outputs = list(reversed(program))

    queue = deque([State(a=0, step=0)])
    good_states = []
    while queue:
        state = queue.popleft()
        # A candidate of 0 is never valid. Every A entering a loop pass of a
        # multi-output run is non-zero, because jnz halts the program as
        # soon as A becomes 0. Only the root (a == 0) would otherwise offer 0.
        lowest = max(state.a * FACTOR, 1)
        for a in range(lowest, (state.a + 1) * FACTOR):
            if not check_a(program_tuple, a, state.a, rev_outputs[state.step]):
                continue
            new_state = State(a=a, step=state.step + 1)
            if new_state.step == len(program):
                if verbose:
                    print(f"Found new state: {new_state}")
                good_states.append(new_state)
            else:
                queue.append(new_state)

    if not good_states:
        raise ValueError("No initial value of A makes the program output itself")
    return min(state.a for state in good_states)


# Part B: search for the smallest initial A that makes the program output itself.
registers, program = parse_input(input_str=input_large)
find_min_a(program=program)

Found new state: State(a=164541160582845, step=16)
Found new state: State(a=164541160583101, step=16)
Found new state: State(a=164545589767869, step=16)
Found new state: State(a=164545589768125, step=16)
Found new state: State(a=164546529291965, step=16)
Found new state: State(a=164546529292221, step=16)


164541160582845