In [1]:
from pathlib import Path
import numpy as np
import re
from math import prod
from collections import defaultdict, deque
from copy import copy

In [2]:
# Raw puzzle input: register lines first, the program on the last line.
input_path = Path('../Data/Day17.txt')
data = input_path.read_text().splitlines()

In [3]:
# Registers A, B, C are the values after the ':' on the first three lines;
# the program is the comma-separated list after ': ' on the last line.
a, b, c = (int(data[row].split(':')[1].strip()) for row in range(3))
program = list(map(int, data[-1].split(': ')[1].split(',')))

In [4]:
def execute_program(a, b, c, program):
    combo = {
        0: 0,
        1: 1,
        2: 2,
        3: 3,
        4: a,
        5: b,
        6: c,
    }

    outputs = []

    pointer = 0

    while pointer < len(program):
        opcode = program[pointer]
        operand = program[pointer+1]

        match opcode:
            case 0:
                a = int(a / 2**combo[operand])
                combo[4] = a
            case 1:
                b = b ^ operand
                combo[5] = b
            case 2:
                b = combo[operand] % 8
                combo[5] = b
            case 3:
                if a != 0:
                    pointer = operand
                    continue
            case 4:
                b = b ^ c
                combo[5] = b
            case 5:
                outputs.append(combo[operand] % 8)
            case 6:
                b = int(a / 2 ** combo[operand])
                combo[5] = b
            case 7:
                c = int(a / 2**combo[operand])
                combo[6] = c
        
        pointer += 2

    return outputs

In [5]:
','.join(map(str, execute_program(a, b, c, program)))

'7,1,5,2,4,0,7,6,1'

In [6]:
# (2, 4) Set B to A % 8
# (1, 2) Set B to B ^ 2 
# (7, 5) Set C to A // 2**B
# (1, 3) Set B to B ^ 3
# (4, 4) Set B to B ^ C
# (5, 5) Output B % 8
# (0, 3) Set A = A // 2**3
# (3, 0) If A != 0, jump to start

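# Recognize that the program is 16 numbers long.
# Each cycle outputs one number and shifts A right by 3 bits, halting once A is 0,
# so 16 outputs require 8**15 <= A < 8**16, i.e. A fits in 48 bits.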

# The values of B and C do not matter at the start of a cycle,
# as B is immediately overwritten from A
# and C is then computed from A and B.
# Thus each output depends only on the value of A at the beginning of the cycle,
# and the next output depends only on A // 8.

# The output B % 8 only depends on the last 3 bits of B
# So we only care about the last 3 bits of B ^ C
# So we only care about the last 3 bits of B and C
# B is derived from the last 3 bits of A
# The last 3 bits of C are the last 3 bits of A after right-shifting it by 0 to 7 bits, where the shift is (A % 8) ^ 2
# Therefore B tells us the last three bits of A at that cycle,
# and C tells us bits s through s+2 of A, where s = (A % 8) ^ 2
# Only 8 * 8 = 64 combinations of B and C to consider

# Some combinations of B and C cannot happen
# E.g. (B, C) cannot start as (2, 5): B = 2 gives a shift of 2 ^ 2 = 0, so C = A, which would imply
# the last 3 bits of A are 010 and 101 simultaneously
# In general, if the shift B ^ 2 < 3, the bits read for B and C overlap and can contradict each other


In [7]:
# For each output digit, collect the (B, C) pairs that produce it, where
# B = A % 8 at the start of a cycle and C holds the low 3 bits of A >> (B ^ 2).
# OUTPUT_TAIL is the input program with the instructions that initialise B and C
# (2,4 and 7,5), shift A (0,3) and jump (3,0) removed; only the ones that turn
# B and C into the output remain.
OUTPUT_TAIL = [1, 2, 1, 3, 4, 4, 5, 5]
bc_map = defaultdict(set)

for b_ in range(8):
    for c_ in range(8):
        shift = b_ ^ 2
        # With a shift < 3, C re-reads some of the same bits of A as B does;
        # skip pairs whose overlapping bits disagree (MSB-first strings).
        if shift < 3 and f"{b_:03b}"[:3 - shift] != f"{c_:03b}"[shift:]:
            continue

        # Index instead of `out, *_ = ...`, which would bind `_` in the
        # notebook namespace and shadow IPython's last-output variable.
        out = execute_program(0, b_, c_, OUTPUT_TAIL)[0]
        bc_map[out].add((b_, c_))

In [8]:
# Inspect which (B, C) pairs can produce each output digit
bc_map

defaultdict(set,
            {1: {(0, 0), (1, 1), (2, 2), (4, 4), (5, 5), (6, 6), (7, 7)},
             3: {(0, 2), (1, 3), (3, 1), (4, 6), (5, 7), (6, 4), (7, 5)},
             5: {(0, 4), (1, 5), (4, 0), (5, 1), (6, 2), (7, 3)},
             7: {(0, 6), (1, 7), (3, 5), (4, 2), (5, 3), (6, 0), (7, 1)},
             0: {(1, 0), (4, 5), (5, 4), (6, 7), (7, 6)},
             2: {(1, 2), (4, 7), (5, 6), (6, 5), (7, 4)},
             4: {(1, 4), (4, 1), (5, 0), (6, 3), (7, 2)},
             6: {(1, 6), (4, 3), (5, 2), (6, 1), (7, 0)}})

In [9]:
# Every ith output gives us info on last 3 bits of A at that cycle
# Plus some more info depending on how much A was right shifted to get C
# Explore possibilities and generate final candidates that satisfy output condition

def _merge_bits(guess, start, value):
    """Overlay the 3-bit `value` onto a copy of `guess`, starting at index `start`.

    `guess` lists the bits of A most-significant first, with -1 marking bits
    that are still unknown. Indices below 0 lie above A's top bit, so the
    corresponding bits of `value` must be 0. Returns the updated copy, or None
    if `value` contradicts a bit that is already fixed.
    """
    merged = list(guess)
    for offset, bit_char in enumerate(f"{value:03b}"):
        bit = int(bit_char)
        index = start + offset
        if index < 0:
            # Bits beyond the top of A are zero.
            if bit:
                return None
        elif merged[index] == -1:
            merged[index] = bit
        elif merged[index] != bit:
            return None
    return merged


def determine_bits(program, pairs_by_output=None, shift_xor=2):
    """Return the smallest register A for which the program outputs `program`.

    Output i depends only on A >> (3 * i). In that cycle B = (A >> 3i) % 8
    fixes three bits of A, and the low 3 bits of C fix three more bits that
    start (B ^ shift_xor) positions higher. A depth-first search tries every
    (B, C) pair yielding the required output digit and keeps only bit
    assignments that agree with the bits already fixed.

    Assumes the program follows this puzzle's loop shape (B = A % 8;
    C = A >> (B ^ shift_xor); one output per cycle; A >>= 3; loop while A != 0).

    Args:
        program: the desired output sequence (the program itself for part 2).
        pairs_by_output: mapping from output digit to the (B, C) pairs that
            produce it. Defaults to the module-level `bc_map`.
        shift_xor: constant XORed into B to obtain C's shift (2 for this input).

    Returns:
        The smallest A whose run outputs exactly `program`.

    Raises:
        ValueError: if `program` is empty or no such A exists.
    """
    if not program:
        raise ValueError("program must not be empty")
    if pairs_by_output is None:
        pairs_by_output = bc_map

    n_outputs = len(program)
    total_bits = 3 * n_outputs
    stack = [(0, [-1] * total_bits)]
    candidates = []

    while stack:
        pointer, guess = stack.pop()

        if pointer == n_outputs:
            candidates.append(guess)
            continue

        # Start index (MSB-first) of the low 3 bits of A >> (3 * pointer).
        b_start = total_bits - 3 * pointer - 3

        # .get avoids inserting empty sets into a defaultdict for unseen digits.
        for b, c in pairs_by_output.get(program[pointer], ()):
            # The final cycle must still see a nonzero A; otherwise the
            # program would already have halted one output earlier.
            if b == 0 and 0 < pointer == n_outputs - 1:
                continue

            with_b = _merge_bits(guess, b_start, b)
            if with_b is None:
                continue

            with_c = _merge_bits(with_b, b_start - (b ^ shift_xor), c)
            if with_c is None:
                continue

            stack.append((pointer + 1, with_c))

    if not candidates:
        raise ValueError("no value of A makes the program output itself")

    # Candidates have equal length and are MSB-first, so list order is numeric order.
    return int(''.join(map(str, min(candidates))), 2)
                




In [10]:
# Part 2: smallest A that makes the program output itself; re-run it as a sanity check.
part2_a = determine_bits(program)
assert execute_program(part2_a, b, c, program) == program
part2_a

37222273957364