In [1]:
# Load the puzzle input file as one string with surrounding whitespace removed
with open('input/07.txt', 'r') as infile:
    rawinput = infile.read().strip()

In [2]:
# Part 1
class Program(object):
    """A step-wise Intcode interpreter that can be chained to other programs.

    Every instance keeps a private working copy of the instruction list
    (``instr``), an instruction pointer (``ptr``), a queue of pending inputs
    (``input_val``) and a list of emitted values (``output``).

    When ``prev`` is supplied, ``prev.output`` is rebound to this instance's
    input queue, so whatever the previous program emits is appended directly
    to this program's pending inputs.

    Execution advances one instruction per ``do_step`` call, which lets a
    caller interleave several programs.
    """

    def __init__(self, instr, input_val, prev=None):
        """Set up the interpreter and bring it to its initial state.

        Args:
            instr: Sequence of integers making up the program. It is copied
                by ``reset`` and never written to.
            input_val: Initial input queue; either a list of values or a
                single value, which is wrapped into a one-element list.
            prev: Optional upstream ``Program`` whose output should feed this
                program's input queue.
        """
        # Opcode -> handler dispatch table used by do_step.
        self.ops = {
            1: self.do_add,
            2: self.do_mult,
            3: self.do_input,
            4: self.do_output,
            5: self.do_jump_if_true,
            6: self.do_jump_if_false,
            7: self.do_less_than,
            8: self.do_equals,
        }
        self.instr_orig = instr
        # A scalar input is wrapped in a list. (The original referenced the
        # undefined name ``input_val_orig`` here and raised NameError.)
        self.input_val_orig = input_val if isinstance(input_val, list) else [input_val]
        self.input_val = []
        self.output = []
        if prev is not None:
            # Share one list object: the upstream program's output IS our
            # input queue, so all later mutations must happen in place.
            prev.output = self.input_val
        self.reset()

    def reset(self):
        """Restore the program, input queue, output list and pointer.

        The input and output lists are cleared and refilled in place so that
        any aliasing established between chained programs is preserved.
        """
        self.instr = [*self.instr_orig]
        self.input_val.clear()
        self.input_val.extend(self.input_val_orig)
        self.output.clear()
        self.ptr = 0

    def actuals(self, params):
        """Resolve ``[value, mode]`` pairs to the values to operate on.

        Mode ``'0'`` reads ``self.instr[value]``; any other mode yields
        ``value`` itself.
        """
        return [self.instr[value] if mode == '0' else value
                for value, mode in params]

    def do_add(self, params):
        """Opcode 1: store the sum of the first two operands at the third."""
        self.instr[params[2][0]] = sum(self.actuals(params[:2]))
        self.ptr += 4

    def do_mult(self, params):
        """Opcode 2: store the product of the first two operands at the third."""
        vals = self.actuals(params[:2])
        self.instr[params[2][0]] = vals[0] * vals[1]
        self.ptr += 4

    def do_input(self, params):
        """Opcode 3: pop the oldest pending input and store it.

        When the input queue is empty nothing happens and the pointer is not
        advanced, so the same instruction is retried on the next step.
        """
        if self.input_val:
            self.instr[params[0][0]] = self.input_val.pop(0)
            self.ptr += 2

    def do_output(self, params):
        """Opcode 4: append the first operand's value to the output list."""
        # extend() mutates in place, keeping the list shared with any
        # downstream program's input queue.
        self.output.extend(self.actuals(params[:1]))
        self.ptr += 2

    def do_jump_if_true(self, params):
        """Opcode 5: jump to the second operand if the first is non-zero."""
        vals = self.actuals(params[:2])
        if vals[0]:
            self.ptr = vals[1]
        else:
            self.ptr += 3

    def do_jump_if_false(self, params):
        """Opcode 6: jump to the second operand if the first equals zero."""
        vals = self.actuals(params[:2])
        if vals[0] == 0:
            self.ptr = vals[1]
        else:
            self.ptr += 3

    def do_less_than(self, params):
        """Opcode 7: store 1 if first operand < second operand, else 0."""
        vals = self.actuals(params[:2])
        self.instr[params[2][0]] = int(vals[0] < vals[1])
        self.ptr += 4

    def do_equals(self, params):
        """Opcode 8: store 1 if the first two operands are equal, else 0."""
        vals = self.actuals(params[:2])
        self.instr[params[2][0]] = int(vals[0] == vals[1])
        self.ptr += 4

    def do_step(self):
        """Execute the instruction at ``ptr``; do nothing on opcode 99.

        Raises:
            KeyError: if the opcode has no entry in ``self.ops``.
        """
        opmode = self.instr[self.ptr]
        opcode = opmode % 100
        if opcode == 99:
            return
        # The digits above the two opcode digits are the parameter modes,
        # read right-to-left; missing digits default to '0'.
        modes = (str(opmode)[-3::-1] + '000')[:3]
        # Only the next three cells can be operands, so slice just those
        # instead of copying the whole remainder of the program each step.
        operands = self.instr[self.ptr + 1:self.ptr + 4]
        params = [[value, mode] for value, mode in zip(operands, modes)]
        self.ops[opcode](params)

# Parse the comma-separated puzzle text into the list of program integers
instr = list(map(int, rawinput.split(',')))

def test_phase(phase):
    """Run a chain of programs for one phase ordering and return the result.

    One ``Program`` is created per phase setting, each linked to the one
    before it so that a program's output feeds the next program's input
    queue. The first program is additionally seeded with the value 0.
    All programs are stepped round-robin, one instruction at a time, until
    every one of them is sitting on opcode 99.

    Returns the first value in the last program's output list.
    """
    amps = []
    upstream = None
    for position, setting in enumerate(phase):
        # Only the first program in the chain gets the extra 0 input.
        initial_inputs = [setting] if position else [setting, 0]
        upstream = Program(instr, initial_inputs, upstream)
        amps.append(upstream)

    def unfinished():
        # Programs whose current instruction is not the halt opcode.
        return [amp for amp in amps if amp.instr[amp.ptr] != 99]

    while unfinished():
        for amp in amps:
            amp.do_step()
    return amps[-1].output[0]

# Try every ordering of the phase settings 0-4 and keep the largest result.
# itertools.permutations replaces the five hand-rolled nested loops over
# shrinking sets; it yields each ordering exactly once.
from itertools import permutations

max(test_phase(list(phase)) for phase in permutations(range(5)))

368584

In [3]:
# Part 2
def test_phase(phase):
    """Run a looped chain of programs for one phase ordering.

    One ``Program`` is created per phase setting, each linked to the one
    before it so that a program's output feeds the next program's input
    queue; the first program is additionally seeded with the value 0.
    The last program's output is then rebound to the first program's input
    queue, closing the loop. All programs are stepped round-robin, one
    instruction at a time, until every one of them is sitting on opcode 99.

    Returns the first value in the last program's output list (which is the
    same list object as the first program's input queue).
    """
    amps = []
    upstream = None
    for position, setting in enumerate(phase):
        # Only the first program in the chain gets the extra 0 input.
        initial_inputs = [setting] if position else [setting, 0]
        upstream = Program(instr, initial_inputs, upstream)
        amps.append(upstream)

    # Feedback wire: the last program now writes into the first one's queue.
    first, last = amps[0], amps[-1]
    last.output = first.input_val

    def unfinished():
        # Programs whose current instruction is not the halt opcode.
        return [amp for amp in amps if amp.instr[amp.ptr] != 99]

    while unfinished():
        for amp in amps:
            amp.do_step()
    return amps[-1].output[0]

# Try every ordering of the phase settings 5-9 and keep the largest result.
# itertools.permutations replaces the five hand-rolled nested loops over
# shrinking sets; it yields each ordering exactly once.
from itertools import permutations

max(test_phase(list(phase)) for phase in permutations(range(5, 10)))

35993240