In [1]:
import z3

In [2]:
# Starting values for the B and C registers (presumably from the puzzle input).
init_b, init_c = 0, 0

# Flat instruction stream: alternating opcode / operand values, one pair per line.
program = [
    2, 4,
    1, 1,
    7, 5,
    1, 4,
    0, 3,
    4, 5,
    5, 5,
    3, 0,
]

In [3]:
# Group the flat stream into (opcode, operand) pairs; a dangling trailing
# value (odd-length stream) is dropped.
program_opcodes_operands = [
    (program[pos], program[pos + 1])
    for pos in range(0, len(program) - 1, 2)
]
program_opcodes_operands

[(2, 4), (1, 1), (7, 5), (1, 4), (0, 3), (4, 5), (5, 5), (3, 0)]

In [4]:
# The unrolled encoding below treats opcode 3 purely as the end-of-pass test,
# which is only valid when the program's final instruction is (3, 0).
if not program_opcodes_operands[-1] == (3, 0):
    raise ValueError('last op should be (3, 0)')

In [5]:
# Each pass over the program must emit exactly one value, so that pass k can
# be tied to program[k] when the constraints are built.
if [opcode for opcode, _ in program_opcodes_operands].count(5) != 1:
    raise ValueError('we should have a single output')

In [6]:
# Unroll the loop: one full copy of the program body per value it has to output.
unrolled_instructions = [
    instruction
    for _ in range(len(program))
    for instruction in program_opcodes_operands
]

In [7]:
# One symbolic 64-bit register file per program point: state_before[i] is the
# state just before unrolled instruction i, and the final entry is the state
# after the last instruction.
state_before = [
    {register: z3.BitVec(f'{register}_{i}', 64) for register in 'abc'}
    for i in range(len(unrolled_instructions) + 1)
]
    

In [8]:
def combo_value(operand, state):
    """Resolve a combo operand against a register state.

    Operands 0-3 stand for themselves; 4, 5 and 6 read registers
    'a', 'b' and 'c' from ``state``. Any other operand raises ValueError.
    """
    if operand in (0, 1, 2, 3):
        return operand
    for code, register in ((4, 'a'), (5, 'b'), (6, 'c')):
        if operand == code:
            return state[register]
    raise ValueError()

In [9]:

constraints = [
    state_before[0]['b'] == 0,
    state_before[0]['c'] == 0,
]
for i, (opcode, operand) in enumerate(unrolled_instructions):
    prev_state = state_before[i]
    next_state = state_before[i + 1]

    def combo():
        return combo_value(operand, prev_state)
    def forward(registers):
        constraints.append(z3.And(*[next_state[r] == prev_state[r] for r in registers]))
    def dv():
        return prev_state['a'] / (1 << combo())
    match opcode:
        case 0:
            constraints.append(next_state['a'] == dv())
            forward('bc')
        case 1:
            constraints.append(next_state['b'] == prev_state['b'] ^ operand)
            forward('ac')
        case 2:
            constraints.append(next_state['b'] == combo() % 8)
            forward('ac')
        case 3:
            if i // len(program_opcodes_operands) < len(program) - 1:
                constraints.append(prev_state['a'] != 0)
            else:
                constraints.append(prev_state['a'] == 0)
            forward('abc')
        case 4:
            constraints.append(next_state['b'] == prev_state['b'] ^ prev_state['c'])
            forward('ac')
        case 5:
            constraints.append(combo() % 8 == (program[i // len(program_opcodes_operands)]))
            forward('abc')
        case 6:
            constraints.append(next_state['b'] == dv())
            forward('ac')
        case 7:
            constraints.append(next_state['c'] == dv())
            forward('ab')
        case _:
            raise ValueError()


In [10]:
# Find the smallest initial value of register a that satisfies every
# constraint. The minimum is requested directly as an optimisation objective
# (z3 treats bit-vector objectives as unsigned) instead of tightening a signed
# `<` bound for a fixed number of rounds, and the check result is inspected so
# an unsatisfiable model is reported clearly rather than failing in model().
initial_a = state_before[0]['a']
solver = z3.Optimize()
solver.add(*constraints)
solver.minimize(initial_a)
if solver.check() != z3.sat:
    raise ValueError('no initial value of register a satisfies the constraints')
# model_completion guarantees a concrete value even if z3 left a unassigned.
current_a = solver.model().eval(initial_a, model_completion=True).as_long()
current_a

202322936867370