In [None]:
from functools import reduce
from itertools import permutations

In [1]:
program = [3,8,1001,8,10,8,105,1,0,0,21,46,55,76,89,106,187,268,349,430,99999,3,9,101,4,9,9,1002,9,2,9,101,5,9,9,1002,9,2,9,101,2,9,9,4,9,99,3,9,1002,9,5,9,4,9,99,3,9,1001,9,2,9,1002,9,4,9,101,2,9,9,1002,9,3,9,4,9,99,3,9,1001,9,3,9,1002,9,2,9,4,9,99,3,9,1002,9,4,9,1001,9,4,9,102,5,9,9,4,9,99,3,9,101,1,9,9,4,9,3,9,102,2,9,9,4,9,3,9,1001,9,2,9,4,9,3,9,101,2,9,9,4,9,3,9,1001,9,1,9,4,9,3,9,101,1,9,9,4,9,3,9,102,2,9,9,4,9,3,9,102,2,9,9,4,9,3,9,1002,9,2,9,4,9,3,9,101,1,9,9,4,9,99,3,9,102,2,9,9,4,9,3,9,1002,9,2,9,4,9,3,9,101,1,9,9,4,9,3,9,101,2,9,9,4,9,3,9,102,2,9,9,4,9,3,9,102,2,9,9,4,9,3,9,102,2,9,9,4,9,3,9,1001,9,1,9,4,9,3,9,101,2,9,9,4,9,3,9,1002,9,2,9,4,9,99,3,9,101,1,9,9,4,9,3,9,101,1,9,9,4,9,3,9,101,2,9,9,4,9,3,9,1002,9,2,9,4,9,3,9,1001,9,2,9,4,9,3,9,1001,9,1,9,4,9,3,9,1001,9,2,9,4,9,3,9,102,2,9,9,4,9,3,9,102,2,9,9,4,9,3,9,102,2,9,9,4,9,99,3,9,101,1,9,9,4,9,3,9,102,2,9,9,4,9,3,9,101,2,9,9,4,9,3,9,101,1,9,9,4,9,3,9,102,2,9,9,4,9,3,9,1002,9,2,9,4,9,3,9,102,2,9,9,4,9,3,9,1001,9,2,9,4,9,3,9,102,2,9,9,4,9,3,9,101,1,9,9,4,9,99,3,9,1001,9,1,9,4,9,3,9,1001,9,1,9,4,9,3,9,102,2,9,9,4,9,3,9,102,2,9,9,4,9,3,9,1001,9,1,9,4,9,3,9,1001,9,1,9,4,9,3,9,1001,9,1,9,4,9,3,9,1002,9,2,9,4,9,3,9,101,2,9,9,4,9,3,9,101,1,9,9,4,9,99]

In [2]:
ADD = 1
MUL = 2
SAV = 3
OUT = 4
JT = 5
JF = 6
LT = 7
EQ = 8
HLT = 99

class IntCode(object):
    def __init__(self, prog: list):
        self.__orig_prog = prog.copy()
        self.restart()
    
    def restart(self):
        self.prog = self.__orig_prog.copy()
        self.pc = 0
        self.halted = False
        self.output = None

    @staticmethod
    def __decode_opcode(op):
        opcode = op % 100
        mode1 = (op // 100) % 10
        mode2 = (op // 1000) % 10
        mode3 = (op // 10000) % 10

        return (opcode, mode1, mode2, mode3)

    @staticmethod
    def __validate_opcode(opcode):
        if not (opcode == ADD or opcode == MUL or
                opcode == SAV or opcode == OUT or
                opcode == JT or opcode == JF or
                opcode == LT or opcode == EQ or
                opcode == HLT):
            raise RuntimeError(f'Unknown opcode {opcode}')

    def __get_arg(self, p, mode):
        arg = self.prog[self.pc + p]
        return self.prog[arg] if mode == 0 else arg
    
    def __execute_opcode(self, input, opcode, mode1, mode2, mode3):
        self.__validate_opcode(opcode)

        pc = self.pc

        # SAV, OUT
        if opcode == SAV:
            self.prog[self.prog[pc + 1]] = input
            return (pc + 2, None)
        elif opcode == OUT:
            return (pc + 2, self.prog[self.prog[pc + 1]])

        arg1 = self.__get_arg(1, mode1)
        arg2 = self.__get_arg(2, mode2)

        # conditions
        if opcode == JT:
            return (arg2 if arg1 else pc + 3, None)
        elif opcode == JF:
            return (arg2 if not arg1 else pc + 3, None)
        elif opcode == LT:
            self.prog[self.prog[pc + 3]] = 1 if arg1 < arg2 else 0
            return (pc + 4, None)
        elif opcode == EQ:
            self.prog[self.prog[pc + 3]] = 1 if arg1 == arg2 else 0
            return (pc + 4, None)

        # ADD, MUL
        if opcode == ADD:
            self.prog[self.prog[pc + 3]] = arg1 + arg2
        elif opcode == MUL:
            self.prog[self.prog[pc + 3]] = arg1 * arg2

        return (pc + 4, None)

    def run(self, inputs, stop_on_out = True):
        if self.halted:
            return self.output

        curr_input = 0

        while self.prog[self.pc] != HLT:
            try:
                opcode = self.prog[self.pc]
                dec = self.__decode_opcode(opcode)

                inp = None
                if dec[0] == SAV:
                    inp = inputs[curr_input] #min(len(inputs) - 1, curr_input)]
                    curr_input += 1

                self.pc, out = self.__execute_opcode(inp, *dec)
#                 print(f'aaa: {self.pc}: {self.prog}')

                if out != None:
                    self.output = out
                    if stop_on_out:
                        return self.output
            except Exception as e:
                print(f'Error: pc = {self.pc}\nprog = {self.prog}\ninput: {inputs}\nexception: {e}')
                raise

        self.halted = True
        return self.output

In [26]:
%%time
# program = [3,15,3,16,1002,16,10,16,1,16,15,15,4,15,99,0,0]
program = [3,31,3,32,1002,32,10,32,1001,31,-2,31,1007,31,0,33,1002,33,7,33,1,33,31,31,1,32,31,31,4,31,99,0,0,0]
# program = [3,23,3,24,1002,24,10,24,1002,23,-1,23,101,5,23,23,1,24,23,23,4,23,99,0,0]

max_res = float('-inf')

for inp in list(permutations(range(0, 5))):
    ics = [IntCode(program) for _ in range(0, 5)]

    # run until all of them are halted
    res = 0
    for s in range(0, 5):
        res = ics[s].run([inp[s], res], False)

    if res > max_res:
        max_res = res
        print(f'{input1}: {res}')

(4, 3, 2, 1, 0): 56012
(4, 3, 2, 1, 0): 56021
(4, 3, 2, 1, 0): 56102
(4, 3, 2, 1, 0): 56120
(4, 3, 2, 1, 0): 56201
(4, 3, 2, 1, 0): 56210
(4, 3, 2, 1, 0): 65012
(4, 3, 2, 1, 0): 65021
(4, 3, 2, 1, 0): 65102
(4, 3, 2, 1, 0): 65120
(4, 3, 2, 1, 0): 65201
(4, 3, 2, 1, 0): 65210
CPU times: user 16.7 ms, sys: 1.99 ms, total: 18.7 ms
Wall time: 18.1 ms


In [27]:
%%time
program = [3,26,1001,26,-4,26,3,27,1002,27,2,27,1,27,26,27,4,27,1001,28,-1,28,1005,28,6,99,0,0,5]
# program = [3,52,1001,52,-5,52,3,53,1,52,56,54,1007,54,5,55,1005,55,26,1001,54,
#            -5,54,1105,1,12,1,53,54,53,1008,54,0,55,1001,55,1,55,2,53,55,53,4,
#            53,1001,56,-1,56,1005,56,6,99,0,0,0,0,10]

max_res = float('-inf')

for inp in list(permutations(range(5, 10))):
    ics = [IntCode(program) for _ in range(0, 5)]

    # run until all of them are halted
    res = 0
    while not reduce(lambda x, y: x and y, [s.halted for s in ics]):
        for s in range(0, 5):
            res = ics[s].run([inp[s], res])

    if res > max_res:
        max_res = res
        print(f'{inp}: {res}')

(5, 6, 7, 8, 9): 23
CPU times: user 54.6 ms, sys: 3.01 ms, total: 57.6 ms
Wall time: 55.5 ms


In [12]:
ics[4].output

6