In [1]:
from math import floor

In [2]:
def parse_opcode(opcode):
    if opcode >= 100: # two-digit opcode
        op    = opcode % 100
        mode1 = floor(opcode % 1000 / 100)
        mode2 = floor(opcode % 10000 / 1000)
        mode3 = floor(opcode % 100000 / 10000)
    else:
        op = opcode
        mode1, mode2, mode3 = 0, 0, 0 # position mode
    return op, mode1, mode2, mode3

In [136]:
def fetch_value(data, arg, mode):
    """Resolve an instruction parameter to the value it denotes.

    Args:
        data: Program memory (indexable sequence).
        arg: Raw parameter taken from the instruction.
        mode: 0 for position mode (``arg`` is an index into ``data``),
            1 for immediate mode (``arg`` is the value itself).

    Returns:
        ``data[arg]`` in position mode, ``arg`` in immediate mode.

    Raises:
        ValueError: If ``mode`` is neither 0 nor 1. Previously this case
            fell through and silently returned ``None``.
    """
    if mode == 0:  # position
        return data[arg]
    if mode == 1:  # immediate
        return arg
    raise ValueError("unsupported parameter mode: {!r}".format(mode))

In [147]:
from collections import deque

In [203]:
def interpret(data, pos=0, input_queue=None, output_queue=None):
    """Execute the opcode program stored in ``data`` until it halts.

    The program runs in place: store instructions write into ``data`` and the
    same list is returned once opcode 99 is reached. Execution is a loop
    rather than one recursive call per instruction, so long-running programs
    cannot exhaust Python's recursion limit.

    Supported opcodes: 1 add, 2 multiply, 3 input, 4 output, 5 jump-if-true,
    6 jump-if-false, 7 less-than, 8 equals, 99 halt.

    Args:
        data: Mutable list holding the program and its memory.
        pos: Index of the first instruction to execute.
        input_queue: Optional queue supplying values for opcode 3, consumed
            with ``popleft()``. When ``None``, values are read via ``input()``.
        output_queue: Optional container receiving opcode 4 values via
            ``append()``. When ``None``, values are printed.

    Returns:
        ``data`` after the halt instruction has been reached.

    Raises:
        AssertionError: If a destination parameter is not in position mode,
            or if ``input_queue`` was supplied but is empty when opcode 3
            needs a value.
        ValueError: If an unknown opcode is encountered (previously the
            function silently returned ``None``).
    """
    while True:
        opcode, mode1, mode2, mode3 = parse_opcode(data[pos])

        if opcode == 99:  # end
            return data

        if opcode in (1, 2, 7, 8):  # add / multiply / less than / equals
            arg1 = data[pos + 1]
            arg2 = data[pos + 2]
            arg3 = data[pos + 3]

            val1 = fetch_value(data, arg1, mode1)
            val2 = fetch_value(data, arg2, mode2)
            assert mode3 == 0, "destination parameter must be in position mode"

            if opcode == 1:
                data[arg3] = val1 + val2
            elif opcode == 2:
                data[arg3] = val1 * val2
            elif opcode == 7:
                data[arg3] = 1 if val1 < val2 else 0
            else:
                data[arg3] = 1 if val1 == val2 else 0
            pos += 4

        elif opcode == 3:  # input
            arg1 = data[pos + 1]
            assert mode1 == 0, "destination parameter must be in position mode"
            # Test against None, not truthiness: an empty queue is still a
            # queue and must not silently fall back to interactive input.
            if input_queue is not None:
                assert len(input_queue) > 0, "input queue is empty"
                data[arg1] = input_queue.popleft()
            else:
                data[arg1] = int(input())
            pos += 2

        elif opcode == 4:  # output
            arg1 = data[pos + 1]
            val1 = fetch_value(data, arg1, mode1)
            if output_queue is not None:
                output_queue.append(val1)
            else:
                print(val1)
            pos += 2

        elif opcode in (5, 6):  # jump-if-true / jump-if-false
            arg1 = data[pos + 1]
            arg2 = data[pos + 2]

            val1 = fetch_value(data, arg1, mode1)
            val2 = fetch_value(data, arg2, mode2)

            if opcode == 5:
                take_jump = val1 != 0
            else:
                take_jump = val1 == 0
            pos = val2 if take_jump else pos + 3

        else:
            raise ValueError(
                "unknown opcode {!r} at position {}".format(opcode, pos)
            )

In [5]:
def load_input(path="input"):
    """Read a comma-separated program from a text file.

    Args:
        path: File to read. Defaults to ``"input"`` in the current working
            directory, which is what the function always read before this
            parameter existed.

    Returns:
        List of ints, one per comma-separated field. Trailing whitespace
        (for example the final newline) is removed before splitting.

    Raises:
        OSError: If the file cannot be opened.
        ValueError: If a field is not a valid integer (an empty file counts).
    """
    with open(path) as infile:
        text = infile.read().rstrip()
    return [int(field) for field in text.split(",")]

In [14]:
# Example program 1: reads two inputs into cells 15 and 16, then outputs
# 10 * (second input) + (first input).
ex1prog = [3,15,3,16,1002,16,10,16,1,16,15,15,4,15,99,0,0]

In [15]:
ex1phase = [4,3,2,1,0]

In [16]:
A, B, C, D, E = ex1prog.copy(), ex1prog.copy(), ex1prog.copy(), ex1prog.copy(), ex1prog.copy()

In [18]:
interpret(A)

4
0
4


[3, 15, 3, 16, 1002, 16, 10, 16, 1, 16, 15, 15, 4, 15, 99, 4, 0]

In [19]:
interpret(B)

3
4
43


[3, 15, 3, 16, 1002, 16, 10, 16, 1, 16, 15, 15, 4, 15, 99, 43, 40]

In [20]:
interpret(C)

2
43
432


[3, 15, 3, 16, 1002, 16, 10, 16, 1, 16, 15, 15, 4, 15, 99, 432, 430]

In [21]:
interpret(D)

1
432
4321


[3, 15, 3, 16, 1002, 16, 10, 16, 1, 16, 15, 15, 4, 15, 99, 4321, 4320]

In [22]:
interpret(E)

0
4321
43210


[3, 15, 3, 16, 1002, 16, 10, 16, 1, 16, 15, 15, 4, 15, 99, 43210, 43210]

In [23]:
# Example program 2: reads two inputs into cells 23 and 24, then outputs
# 10 * (second input) + 5 - (first input).
ex2prog = [3,23,3,24,1002,24,10,24,1002,23,-1,23,101,5,23,23,1,24,23,23,4,23,99,0,0]

In [24]:
ex2phase = [0,1,2,3,4]

In [26]:
A, B, C, D, E = ex2prog.copy(), ex2prog.copy(), ex2prog.copy(), ex2prog.copy(), ex2prog.copy()

In [28]:
output = interpret(A)

0
0
5


In [29]:
output = interpret(B)

1
5
54


In [30]:
output = interpret(C)

2
54
543


In [31]:
output = interpret(D)

3
543
5432


In [32]:
output = interpret(E)

4
5432
54321


In [33]:
# Example program 3: reads two inputs into cells 31 and 32, then outputs
# 10 * (second input) + (first input - 2), plus 7 when (first input - 2) < 0.
ex3prog = [3,31,3,32,1002,32,10,32,1001,31,-2,31,1007,31,0,33,1002,33,7,33,1,33,31,31,1,32,31,31,4,31,99,0,0,0]

In [34]:
ex3phase = [1,0,4,3,2]

In [35]:
A, B, C, D, E = ex3prog.copy(), ex3prog.copy(), ex3prog.copy(), ex3prog.copy(), ex3prog.copy()

In [37]:
output = interpret(A)

1
0
6


In [38]:
output = interpret(B)

0
6
65


In [39]:
output = interpret(C)

4
65
652


In [40]:
output = interpret(D)

3
652
6521


In [41]:
output = interpret(E)

2
6521
65210


In [255]:
def allPermutations(phases, fromPos=0):
    """Yield every ordering of ``phases[fromPos:]``, keeping the prefix fixed.

    Permutations are produced by in-place swaps on a private copy of the
    input, so the caller's sequence is never modified, not even while the
    generator is partially consumed. Every yielded value is a fresh list.

    Args:
        phases: Sequence of items to permute.
        fromPos: Index of the first element allowed to move; elements before
            it stay where they are.

    Yields:
        Lists of the same length as ``phases``. The order is: the current
        arrangement first, then each later element swapped into ``fromPos``
        in turn, applied recursively to the remaining suffix.
    """
    work = list(phases)  # private scratch copy; the caller's data stays intact
    size = len(work)

    def _permute(start):
        # Zero or one movable element left: the arrangement is complete.
        if start + 1 >= size:
            yield work.copy()
            return
        # Keep the current element at `start`, permute the rest...
        yield from _permute(start + 1)
        # ...then try every later element in position `start`.
        for i in range(start + 1, size):
            work[start], work[i] = work[i], work[start]
            yield from _permute(start + 1)
            work[start], work[i] = work[i], work[start]  # undo the swap

    yield from _permute(fromPos)

In [256]:
list(allPermutations(["A", "B", "C"]))

[['A', 'B', 'C'],
 ['A', 'C', 'B'],
 ['B', 'A', 'C'],
 ['B', 'C', 'A'],
 ['C', 'B', 'A'],
 ['C', 'A', 'B']]

In [257]:
prog = load_input()

In [261]:
amps = [0, 1, 2, 3, 4]

In [262]:
def findBest(prog, amps):
    """Find the phase ordering that yields the highest final output signal.

    For every permutation of ``amps`` a chain of program runs is executed:
    each run gets a fresh copy of ``prog`` and two inputs, its phase value
    followed by the previous run's output (0 for the first run). The output
    of the last run in the chain is the signal for that permutation.

    Args:
        prog: Program to run; copied for every run, never modified.
        amps: Phase values to permute, one run per value.

    Returns:
        Tuple ``(bestPhases, highestOutput)``. On ties the first permutation
        encountered is kept. Both are ``None`` if ``amps`` is empty.

    Raises:
        AssertionError: If a run does not consume exactly its two inputs or
            leaves more than one value in its output queue.
    """
    bestPhases = None
    highestOutput = None

    for phases in allPermutations(amps):
        if not phases:
            # No amplifiers to run, so there is no signal to compare.
            continue

        signal = 0
        for phase in phases:
            ampInput = deque([phase, signal])
            ampOutput = deque()
            interpret(prog.copy(), input_queue=ampInput, output_queue=ampOutput)
            assert len(ampInput) == 0
            signal = ampOutput.pop()
            assert len(ampOutput) == 0

        # Compare against None explicitly: a best signal of 0 is a valid
        # result and must not be treated as "nothing recorded yet".
        if highestOutput is None or signal > highestOutput:
            highestOutput = signal
            bestPhases = phases.copy()

    return bestPhases, highestOutput

In [263]:
findBest(ex1prog, amps)

([4, 3, 2, 1, 0], 43210)

In [264]:
findBest(ex2prog, amps)

([0, 1, 2, 3, 4], 54321)

In [265]:
findBest(ex3prog, amps)

([1, 0, 4, 3, 2], 65210)

In [266]:
findBest(prog, amps)

([2, 0, 1, 4, 3], 880726)