In [1]:
# Example machine from the puzzle statement: initial registers and program.
sample_register = dict(A=729, B=4, C=0)

sample_program = [0, 1, 5, 4, 3, 0]

In [2]:
# Puzzle input: initial register values and the program (opcode/operand pairs).
registers = dict(A=65804993, B=4, C=0)

program = [2, 4, 1, 1, 7, 5, 1, 4, 0, 3, 4, 5, 5, 5, 3, 0]

In [3]:
def combo(n):
    """Resolve a combo operand to the value it stands for.

    Operands up to 3 are returned as literal values; 4, 5 and 6 select
    registers A, B and C from the global ``REGISTER`` dict. Any operand
    above 6 is rejected with ``ValueError``.
    """
    if n > 6:
        raise ValueError(f"Bad Combo:{n}")
    if n <= 3:
        return n
    register_name = 'ABC'[n - 4]
    return REGISTER[register_name]

In [4]:
# Machine state shared by the opcode handlers below; init() resets all three.
# REGISTER: dict of register values, INST: instruction pointer, OUTPUT: emitted values.
REGISTER, INST, OUTPUT = None, 0, []

def adv(n):
    """Opcode 0: floor-divide A by 2**combo(n), storing the quotient back in A."""
    REGISTER['A'] //= 2 ** combo(n)

def bxl(n):
    """Opcode 1: XOR register B with the literal operand n."""
    REGISTER['B'] ^= n

def bst(n):
    """Opcode 2: store combo(n) modulo 8 (its lowest three bits) in B."""
    low_bits = combo(n) % 8
    REGISTER['B'] = low_bits

def jnz(n):
    """Opcode 3: jump to instruction index n unless register A is zero.

    run() has already advanced INST past this instruction's operand before
    calling the handler, so a jump only needs to overwrite INST; when A is
    zero INST is left untouched and execution falls through.
    """
    global INST
    if REGISTER['A'] == 0:
        return
    INST = n

def bxc(n):
    """Opcode 4: XOR register B with register C; the operand n is ignored."""
    REGISTER['B'] ^= REGISTER['C']

def out(n):
    """Opcode 5: emit combo(n) modulo 8 by appending it to the global OUTPUT list."""
    # Only mutates the list, so no `global` declaration is needed.
    OUTPUT.append(combo(n) % 8)

def bdv(n):
    """Opcode 6: same division as adv (A // 2**combo(n)), but the result goes to B."""
    quotient = REGISTER['A'] // 2 ** combo(n)
    REGISTER['B'] = quotient

def cdv(n):
    """Opcode 7: same division as adv (A // 2**combo(n)), but the result goes to C."""
    quotient = REGISTER['A'] // 2 ** combo(n)
    REGISTER['C'] = quotient

def init(registers):
    """Reset the machine before a run.

    REGISTER becomes a shallow copy of `registers`, so the caller's dict is
    never mutated by the program. The instruction pointer is rewound to 0
    and the output buffer is replaced with a fresh empty list.
    """
    global REGISTER, INST, OUTPUT
    REGISTER = registers.copy()
    INST = 0
    OUTPUT = []
    
# Dispatch table: list index == opcode (0 adv, 1 bxl, 2 bst, 3 jnz,
# 4 bxc, 5 out, 6 bdv, 7 cdv).
opcodes = [adv, bxl, bst, jnz, bxc, out, bdv, cdv]

def run(program, registers, as_string=False):
    """Execute `program` starting from a copy of `registers`.

    Args:
        program: flat list of ints alternating opcode, operand; each opcode
            indexes the `opcodes` dispatch table.
        registers: dict with keys 'A', 'B', 'C'; init() copies it, so the
            caller's dict is not mutated.
        as_string: when True, return the output joined with commas
            (e.g. "4,6,3,5") instead of as a list.

    Returns:
        The values emitted by `out` instructions (each already reduced mod 8)
        as a list, or as a comma-separated string when `as_string` is True.
    """
    global INST
    init(registers)
    # Halt as soon as the instruction pointer runs past the end of the program.
    while INST < len(program):
        op = opcodes[program[INST]]
        arg = program[INST + 1]
        # Advance past opcode and operand *before* dispatching, so jnz can
        # jump simply by overwriting INST.
        INST += 2
        op(arg)
    if as_string:
        # OUTPUT holds ints, so they must be converted before joining.
        return ','.join(map(str, OUTPUT))
    return OUTPUT


In [5]:
sample_output = run(sample_program, sample_register)
# Regression check: the example program must print 4,6,3,5,6,3,5,2,1,0.
assert sample_output == [4, 6, 3, 5, 6, 3, 5, 2, 1, 0], sample_output
sample_output

[4, 6, 3, 5, 6, 3, 5, 2, 1, 0]

In [6]:
# Part one: run the puzzle program on the puzzle registers.
part_one_output = run(program, registers)
part_one_output

[5, 1, 4, 0, 5, 1, 0, 2, 6]

### Part Two
Not sure what the *right* way to do this is, but it's clear that the program's output grows longer as the input (register A) grows. It also emits base-8 digits (values 0–7) derived from the base-8 digits of the input.

And it's relatively predictable:

```
175921860444160 (5 * 8 ** 15) -> first 0 ending = [5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 1, 0]
```

It's a little manual, but starting with the most significant digit we can work backward, gradually refining the number until we have the whole thing. The search space shrinks exponentially this way.

For example, just by iterating over 64 numbers you can identify:

`[5, 6]` as the two most significant digits in the solution

`5 * (8 ** 15) + 6 * (8 ** 14) = 202310139510784`

Plugging this into register A outputs:


```
with the correct right-most digits      v  v  v  
[5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 7, 5, 3, 0]
``` 

Continuing, you find `[5,6,0,0,1]` as the first five digits of a 16-digit base-8 number, for example `202322219106304` (octal `5600132000000000`).

Plugging this into register A gives:

```
with more correct digits       v  v  v  v  v  v      
[5, 5, 5, 5, 5, 5, 5, 5, 5, 4, 4, 5, 5, 5, 3, 0]

```

Iterating over one octal digit at a time did not really work, because it doesn't always find the next digit. What did work was iterating over a few digits at once and keeping only the most significant new digit from the first match.

In [7]:
# Some helper functions when exploring the problem space

def to_base_8(n):
    """Return the base-8 (octal) digits of integer `n` as a string, with no '0o' prefix.

    Zero is rendered as "0", and negative numbers get a leading "-".
    """
    # format(..., 'o') handles 0 and negatives; a manual divmod loop would
    # return "" for 0 and never terminate for n < 0 (n // 8 sticks at -1).
    return format(n, 'o')
    
def from_base_8_l(l):
    """Convert a list of octal digits, least-significant digit first, to an int.

    The list is reversed before parsing, so l[0] contributes 8**0, l[1]
    contributes 8**1, and so on. An empty list yields 0.
    """
    if not l:
        # int('', 8) would raise; an empty digit list is the empty sum.
        return 0
    digits = ''.join(map(str, l[::-1]))
    # int(..., 8) also rejects digits outside 0-7.
    return int(digits, 8)
    
def from_base_8(n):
    """Parse the octal digit string `n` (for example '373015301') into an int."""
    return int(n, base=8)

# Initial register A of the puzzle input, shown in base 8.
to_base_8(registers["A"])


'373015301'

In [8]:
# Part two searches for a value of A that makes the program output itself,
# so the target output is the program. Copy it rather than retyping it, so
# the two cannot drift apart.
to_match = list(program)

def solve(to_match, seed_digits=(5,)):
    """Search for the value of register A that makes the program output `to_match`.

    A's octal digits are fixed most-significant first. Each step passes the
    digits confirmed so far to get_digit, which proposes two more. Only the
    first of those is kept, via the slice on the next step.

    Args:
        to_match: the required program output.
        seed_digits: known leading octal digits of A to start from. The
            default (5,) is the head start found while exploring above.

    Returns:
        The integer A if get_digit finds an exact match along the way.
        Otherwise, the list of all but the last octal digit of A; the caller
        brute-forces the final digit.

    Raises:
        LookupError: if get_digit reports that no candidate digits fit.
    """
    bases = list(seed_digits)
    # Stop two short of the full length: on the last step get_digit pads
    # the candidate to len(to_match) digits and returns len(to_match) - 1.
    for i in range(len(bases), len(to_match) - 2):
        result = get_digit(bases[:i], to_match)
        if result is None:
            raise LookupError(f"no digits extend {bases[:i]}")
        bases, match = result
        # Compare to None explicitly so an exact match of A == 0 is not lost.
        if match is not None:
            return match
    return bases


def get_digit(bases, to_match, prog=None):
    """Extend the known leading octal digits of A by brute-forcing the next few.

    Each candidate A is built from `bases` (most-significant digit first)
    plus up to three more digits, then padded with zeros to len(to_match)
    octal digits. The candidate is run through the program, and its output
    is compared against `to_match`.

    Args:
        bases: list of leading octal digits of A found so far.
        to_match: the output the program must reproduce.
        prog: program to execute; defaults to the global `program`.

    Returns:
        (None, r) if the candidate r makes the program output exactly
        `to_match`. Otherwise (digits, None), where digits is `bases`
        extended by (at most) the first two tried digits of the first
        candidate whose trailing output matches.

    Raises:
        LookupError: if no candidate matches.
    """
    from itertools import product

    if prog is None:
        prog = program
    n_digits = len(to_match)
    # Try three digits at once so a digit whose effect only shows in
    # combination with the following ones is not missed. Near the end,
    # fewer positions may be left to fill.
    width = max(0, min(3, n_digits - len(bases)))
    # product() yields tuples in the same order as nested 0..7 loops.
    for extra in product(range(8), repeat=width):
        candidate = bases + list(extra)
        r = sum(digit * 8 ** (n_digits - 1 - pos)
                for pos, digit in enumerate(candidate))

        count_registers = {
            'A': r,
            'B': 0,
            'C': 0,
        }

        # Compare only the trailing len(candidate) - 1 output values.
        window = -len(candidate) + 1
        try:
            b = run(prog, count_registers)
        except ValueError:
            # combo() raises ValueError for the reserved operand 7.
            continue
        if b == to_match:
            return None, r
        if b[window:] == to_match[window:]:
            return bases + list(extra[:2]), None
    raise LookupError(f"no digits extend {bases} towards {to_match}")



In [9]:
first_bits = solve(to_match)

if isinstance(first_bits, int):
    # solve() already hit an exact match before reaching the last digit.
    print("solution:", first_bits)
else:
    print("All but last bit:", first_bits)

    # solve() stops one octal digit short, so brute-force the last of the eight.
    for last_digit in range(8):
        digits = first_bits + [last_digit]
        r = sum(d * 8 ** (len(to_match) - 1 - pos) for pos, d in enumerate(digits))

        count_registers = {
            'A': r,
            'B': 0,
            'C': 0,
        }

        b = run(program, count_registers)

        if b == to_match:
            print("Solution bits:", digits)
            print("solution:", r)
            break
    else:
        # Only reached when no final digit produced the full program output.
        print("No final digit reproduces the program output.")

202322936867370

All but last bit: [5, 6, 0, 0, 1, 3, 7, 2, 6, 2, 0, 2, 5, 0, 5]
Solution bits: [5, 6, 0, 0, 1, 3, 7, 2, 6, 2, 0, 2, 5, 0, 5, 2]
solution: 202322936867370


202322936867370