In [1]:
import re

In [2]:
# Read the whole puzzle input as one string; it is split into sections below.
with open(file="data/day17.txt", mode="r", encoding="utf-8") as f:
    data = f.read()

In [4]:
# The input has two sections separated by a blank line: the initial register
# values, then the comma-separated program.
register_text, program_text = data.split("\n\n", 1)

# Regex patterns are raw strings: "\d" inside a plain string literal is an
# invalid escape sequence (SyntaxWarning since Python 3.12).
orig_registers = {
    name: int(value)
    for name, value in re.findall(r"Register ([A-C]): (\d+)", register_text)
}

program_match = re.search(r"Program: (.+)", program_text)
if program_match is None:
    raise ValueError("Input has no 'Program: ...' line")
program = [int(token) for token in program_match.group(1).split(",")]

def combo(operand):
    """Decode a combo operand.

    Operands 0-3 are literal values; 4, 5 and 6 read registers A, B and C
    from the global `registers`. Any other operand raises ValueError.
    """
    if operand in (0, 1, 2, 3):
        return operand
    for code, register_name in ((4, "A"), (5, "B"), (6, "C")):
        if operand == code:
            return registers[register_name]
    raise ValueError(f"Invalid operand value: {operand}")
    
def adv(operand):
    """Opcode 0: A <- A >> combo(operand), i.e. A // 2**combo(operand)."""
    shift = combo(operand)
    registers["A"] >>= shift
    return (False, None)
    
def bxl(operand):
    """Opcode 1: B <- B XOR the literal operand."""
    registers["B"] ^= operand
    return (False, None)
    
def bst(operand):
    """Opcode 2: B <- combo(operand) mod 8."""
    value = combo(operand)
    registers["B"] = value % 8
    return (False, None)

def jnz(operand):
    """Opcode 3: jump to the literal operand when register A is non-zero.

    Returns (False, target) to request a jump, or (False, None) to fall
    through to the next instruction.
    """
    target = operand if registers["A"] != 0 else None
    return (False, target)
    
def bxc(operand):
    """Opcode 4: B <- B XOR C. The operand is accepted but ignored."""
    registers["B"] ^= registers["C"]
    return (False, None)
    
def out(operand):
    """Opcode 5: emit combo(operand) mod 8.

    Returns (True, value); the True flag marks `value` as program output.
    """
    emitted = combo(operand) % 8
    return (True, emitted)

def bdv(operand):
    """Opcode 6: B <- A >> combo(operand); register A is left unchanged."""
    shift = combo(operand)
    registers["B"] = registers["A"] >> shift
    return (False, None)
    
def cdv(operand):
    """Opcode 7: C <- A >> combo(operand); register A is left unchanged."""
    shift = combo(operand)
    registers["C"] = registers["A"] >> shift
    return (False, None)

# Opcode -> handler. The list position is the opcode (0-7), so keep the order.
instructions = dict(enumerate([adv, bxl, bst, jnz, bxc, out, bdv, cdv]))

# Problem 1: run the program on the initial registers

In [5]:
# Run on a shallow copy so the parsed starting registers stay intact for re-runs.
registers = dict(orig_registers)

def output(program):
    """Run `program` on the current global `registers` and return its output.

    `program` is a flat list of ints read as (opcode, operand) pairs. Each
    opcode is dispatched through the global `instructions` table, and the
    handler returns `(is_output, value)`:

    * `(True, v)`     -> append `v` to the output, advance to the next pair;
    * `(False, t)`    -> jump: set the instruction pointer to `t`;
    * `(False, None)` -> advance to the next pair.

    Execution halts once the instruction pointer can no longer read a full
    opcode/operand pair. A trailing opcode without an operand therefore ends
    the run instead of raising IndexError.

    Handlers mutate the global `registers` in place, so reset it before each
    run.
    """
    pointer = 0
    res = []
    # Stop as soon as a complete (opcode, operand) pair is no longer available.
    while pointer + 1 < len(program):
        opcode, operand = program[pointer], program[pointer + 1]
        # Named `is_output` (not `output`) so it does not shadow this function.
        is_output, value = instructions[opcode](operand)
        if is_output:
            res.append(value)
        elif value is not None:
            # Jump. Compare with None explicitly, because target 0 is valid.
            pointer = value
            continue
        pointer += 2
    return res
    
",".join(map(str, output(program)))

'2,0,1,3,4,0,2,1,7'

# Problem 2: find the smallest A for which the program outputs itself

In [6]:
def similar_output(a, program):
    """Check whether `program`, run with A=a and B=C=0, outputs itself.

    Rebinds the global `registers` to that fresh starting state before the
    run, so the result does not depend on any earlier run.
    """
    global registers
    registers = dict(A=a, B=0, C=0)
    produced = output(program)
    return produced == program

def is_correct_output(a, result):
    """Return True when the global `program`, run with A=a and B=C=0, emits
    `result` as its first output value.

    Rebinds the global `registers` as a side effect. Returns False instead of
    raising IndexError when the run produces no output at all.
    """
    global registers
    registers = {"A": a, "B": 0, "C": 0}
    produced = output(program)
    # Slicing tolerates an empty output; `produced[0]` would raise.
    return produced[:1] == [result]

def _run_with_a(a):
    """Run the global `program` from a fresh state (A=a, B=C=0) and return its output."""
    global registers
    registers = {"A": a, "B": 0, "C": 0}
    return output(program)


# Reverse search for the smallest A whose output is the program itself.
# Assumes (not checked here) the puzzle program's loop shape: each iteration
# emits one value computed from A's low three bits and the bits above them,
# then shifts A right by 3 and jumps back while A != 0. Under that
# assumption, the top k octal digits of A alone reproduce the last k values,
# so A can be built one octal digit at a time from the most significant end.
#
# Seed with [0], not 0..7: a nonzero seed would give A an extra leading octal
# digit, i.e. one output too many. Checking the whole suffix, not just the
# first value, also prunes prefixes with leading zero digits, which produce
# too few outputs.
solutions = [0]
for k in range(1, len(program) + 1):
    target = program[-k:]
    solutions = [
        new_a
        for a in solutions
        for new_a in range(8 * a, 8 * a + 8)  # == (a << 3) | i for i in 0..7
        if _run_with_a(new_a) == target
    ]

if not solutions:
    raise ValueError("No value of register A makes the program output itself")

solution = min(solutions)
similar_output(solution, program), solution

(True, 236580836040301)