# Advent of Code 2024 - [Day 17](https://adventofcode.com/2024/day/17)

### Parsing

In [15]:
import re
from collections import namedtuple
from functools import cache

# Immutable snapshot of the machine state that is threaded through the interpreter.
#   halt    - flag; parse() creates it as False, halt() builds a copy with True
#   output  - tuple of the values emitted so far
#   A, B, C - register values
#   pointer - index into `program` of the next opcode to execute
#   program - tuple of ints holding the opcode/operand stream
Instance = namedtuple('Instance', ['halt','output','A', 'B', 'C', 'pointer', 'program'])

In [16]:
def parse(file='input.txt'):
    """Read a puzzle input file and return the initial machine state.

    The expected layout is three ``Register X: <int>`` lines (A, B and C),
    one separator line that is ignored, and a ``Program: n,n,...`` line.
    Any lines after the fifth are ignored.

    Args:
        file: Path of the input file to read.

    Returns:
        An ``Instance`` with ``halt=False``, an empty output tuple, the parsed
        register values, ``pointer=0`` and the program as a tuple of ints.

    Raises:
        ValueError: If a register or program line is malformed or missing.
    """
    registers = {}
    program = None

    # The context manager guarantees the handle is closed, even on a parse error.
    with open(file, 'r') as handle:
        for ix, line in enumerate(handle):
            if ix < 3:
                # Register lines.
                found = re.match(r'^Register ([ABC]): (\d+)$', line)
                if found is None:
                    raise ValueError(f"Malformed register line {ix + 1}: {line!r}")
                registers[found[1]] = int(found[2])
            elif ix == 4:
                # Program line (line index 3 is the separator and is skipped).
                found = re.match(r'^Program: ([\d,]+)$', line)
                if found is None:
                    raise ValueError(f"Malformed program line {ix + 1}: {line!r}")
                program = tuple(int(x) for x in found[1].split(','))

    missing = [name for name in 'ABC' if name not in registers]
    if missing or program is None:
        raise ValueError(f"Incomplete input in {file!r}: registers and program are required")

    return Instance(False, (), registers['A'], registers['B'], registers['C'], 0, program)

### Compute

In [17]:
def bitwise_xor(a, b):
    """Return the bitwise XOR of two non-negative integers.

    Args:
        a: First operand, a non-negative integer.
        b: Second operand, a non-negative integer.

    Returns:
        The integer ``a ^ b``.

    Raises:
        ValueError: If either operand is negative.
    """
    if a < 0 or b < 0:
        raise ValueError("Only non-negative integers are allowed.")

    # The built-in operator gives the same result as a manual bit-by-bit
    # string comparison, without the binary-string round trip.
    return a ^ b

# Example usage / quick sanity check:
# 10 (0b1010) XOR 4 (0b0100) = 14 (0b1110)
print(bitwise_xor(10, 4))  # Should print 14

14


In [18]:
def halt(state, verbose):
    """Stop execution and return the output tuple accumulated in ``state``.

    When ``verbose`` is truthy the state at the time of halting is printed.
    """
    if verbose:
        print("halt!", state)
    # Build the halted copy of the state; only its output is handed back.
    halted = state._replace(halt=True)
    return halted.output

def adv(state, literal_operand, combo_operand, verbose):
    """Opcode 0: floor-divide A by ``2**combo_operand`` and store it in A.

    Advances the pointer by two and continues execution through
    ``run_program``, returning whatever that call returns.
    """
    if verbose:
        print("adv", "state:", state, "operand:", literal_operand, "combo operand:", combo_operand)
    quotient = state.A // (2 ** combo_operand)
    next_state = state._replace(A=quotient, pointer=state.pointer + 2)
    return run_program(next_state, verbose)

def bxl(state, literal_operand, combo_operand, verbose):
    """Opcode 1: XOR register B with the literal operand and store it in B.

    Advances the pointer by two and continues execution through
    ``run_program``, returning whatever that call returns.
    """
    if verbose:
        print("bxl", "state:", state, "operand:", literal_operand, "combo operand:", combo_operand)
    new_b = bitwise_xor(state.B, literal_operand)
    next_state = state._replace(B=new_b, pointer=state.pointer + 2)
    return run_program(next_state, verbose)

def bst(state, literal_operand, combo_operand, verbose):
    """Opcode 2: store the combo operand modulo 8 in register B.

    Advances the pointer by two and continues execution through
    ``run_program``, returning whatever that call returns.
    """
    if verbose:
        print("bst", "state:", state, "operand:", literal_operand, "combo operand:", combo_operand)
    low_bits = combo_operand % 8
    next_state = state._replace(B=low_bits, pointer=state.pointer + 2)
    return run_program(next_state, verbose)

def jnz(state, literal_operand, combo_operand, verbose):
    """Opcode 3: jump to the literal operand unless register A is zero.

    When A is zero the pointer simply advances by two. Either way execution
    continues through ``run_program`` and its result is returned.
    """
    if verbose:
        print("jnz", "state:", state, "operand:", literal_operand, "combo operand:", combo_operand)
    if state.A == 0:
        target = state.pointer + 2
    else:
        target = literal_operand
    return run_program(state._replace(pointer=target), verbose)

def bxc(state, literal_operand, combo_operand, verbose):
    """Opcode 4: XOR register B with register C and store it in B.

    Both operands are ignored apart from being logged. Advances the pointer
    by two and continues execution through ``run_program``.
    """
    if verbose:
        print("bxc", "state:", state, "operand:", literal_operand, "combo operand:", combo_operand)
    new_b = bitwise_xor(state.B, state.C)
    next_state = state._replace(B=new_b, pointer=state.pointer + 2)
    return run_program(next_state, verbose)

def out(state, literal_operand, combo_operand, verbose):
    """Opcode 5: append the combo operand modulo 8 to the output tuple.

    Advances the pointer by two and continues execution through
    ``run_program``, returning whatever that call returns.
    """
    if verbose:
        print("out", "state:", state, "operand:", literal_operand, "combo operand:", combo_operand)
    emitted = state.output + (combo_operand % 8,)
    next_state = state._replace(output=emitted, pointer=state.pointer + 2)
    return run_program(next_state, verbose)

def bdv(state, literal_operand, combo_operand, verbose):
    """Opcode 6: floor-divide A by ``2**combo_operand`` and store it in B.

    Register A is left unchanged. Advances the pointer by two and continues
    execution through ``run_program``, returning whatever that call returns.
    """
    if verbose:
        # Trace label matches this instruction (it previously printed "adv").
        print("bdv", "state:", state, "operand:", literal_operand, "combo operand:", combo_operand)
    return run_program(Instance(state.halt, state.output, state.A, state.A//(2**combo_operand), state.C, state.pointer+2, state.program), verbose)

def cdv(state, literal_operand, combo_operand, verbose):
    """Opcode 7: floor-divide A by ``2**combo_operand`` and store it in C.

    Register A is left unchanged. Advances the pointer by two and continues
    execution through ``run_program``, returning whatever that call returns.
    """
    if verbose:
        # Trace label matches this instruction (it previously printed "adv").
        print("cdv", "state:", state, "operand:", literal_operand, "combo operand:", combo_operand)
    return run_program(Instance(state.halt, state.output, state.A, state.B, state.A//(2**combo_operand), state.pointer+2, state.program), verbose)

def run_program(state, verbose=False):
    """Execute the program from ``state`` and return its output tuple.

    Decodes the opcode/operand pair at ``state.pointer`` and dispatches to the
    matching handler, which in turn calls back into this function. Execution
    halts once there is no complete opcode/operand pair left at the pointer.

    Args:
        state: The machine state (``Instance``) to run from.
        verbose: When truthy, print a trace line for every step.

    Returns:
        The output tuple produced by the time the program halts.
    """
    if verbose:
        print("run_program:", state)

    # An instruction needs both an opcode and an operand; stop when the
    # pointer is at or beyond the last program slot.
    if state.pointer >= len(state.program) - 1:
        return halt(state, verbose)

    opcode = state.program[state.pointer]
    literal_operand = state.program[state.pointer + 1]

    # Combo operands 0-3 are literal values; 4, 5 and 6 read registers A, B
    # and C. Anything else resolves to the -1 placeholder.
    if literal_operand in [0, 1, 2, 3]:
        combo_operand = literal_operand
    else:
        register_values = {4: state.A, 5: state.B, 6: state.C}
        combo_operand = register_values.get(literal_operand, -1)

    handlers = (adv, bxl, bst, jnz, bxc, out, bdv, cdv)
    return handlers[opcode](state, literal_operand, combo_operand, verbose)

In [28]:
# Part 1: run the program on the parsed input and print its comma-joined output.
filename = "input.txt"
#filename = "test.txt"
init = parse(filename)

print("part 1: ", ','.join(map(str, run_program(init, False))))


part 1:  4,1,5,3,1,5,3,5,7


In [None]:
# Part 2: search for a value of register A that makes the program print itself.
#
# NOTE(review): the search assumes the program consumes one octal digit of A
# per emitted value, so the high digits of A control the last outputs — confirm
# this holds for the input in use. There is no upper bound on the loop.
init = parse(filename)

# Start at the smallest value with as many octal digits as the program has
# entries, instead of hard-coding the exponents for a 16-entry program.
a = 8 ** (len(init.program) - 1)
power = max(len(init.program) - 2, 0)
matched = init.program[-1:]

while True:
    state = init._replace(A=a)

    output = run_program(state, False)
    if output == state.program:
        break

    # Once the current tail of the program is reproduced, refine the next
    # lower octal digit and require one more trailing value to match.
    if output[-len(matched):] == matched:
        power = max(power - 1, 0)
        matched = init.program[-(len(matched) + 1):]

    a += 8 ** power

print("part 2:", a)

part 2: 164542125272765
