In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import itertools
import logging

from intcode import Machine

In [3]:
def serial_connection(machines):
    """Chain the amplifiers in series (Part 1 wiring).

    Every machine after the first is given its predecessor's ``outputs``
    object as its own ``inputs`` (the same object, not a copy). Afterwards
    the first machine is handed the starting input value 0 via
    ``add_input``.

    Returns the same ``machines`` list, now wired.
    """
    for upstream, downstream in zip(machines, machines[1:]):
        # Share the queue object itself so whatever the upstream machine
        # emits is visible to the downstream machine as input.
        downstream.inputs = upstream.outputs

    # Seed the head of the chain with the starting signal.
    machines[0].add_input(0)
    return machines

In [4]:
def network_output(prog, phases, connect_func):
    """Run one amplifier network until every machine finishes; return its signal.

    Parameters
    ----------
    prog : str
        Intcode program text; each amplifier gets its own ``Machine`` built
        from it.
    phases : sequence
        One phase setting per amplifier; ``phases[k]`` is passed to the k-th
        ``Machine``. The network has ``len(phases)`` amplifiers.
    connect_func : callable
        Receives the list of machines, wires their input/output queues and
        returns the wired list (e.g. ``serial_connection`` or
        ``feedback_connection``).

    Returns
    -------
    The first value queued on the last machine's ``outputs``, removed with
    ``popleft()``.
    """
    # One amplifier per phase; machine_id is simply its position in the network.
    machines = [
        Machine(prog, phase, machine_id=i, loglevel=logging.WARNING)
        for i, phase in enumerate(phases)
    ]

    # Wire inputs and outputs (serial chain or feedback loop).
    machines = connect_func(machines)

    # Round-robin over every machine until all report finished. The bookkeeping
    # is sized from the actual network rather than a hard-coded 5, so networks
    # of any size are run completely.
    # NOTE(review): assumes Machine.run() returns (presumably when it needs
    # input that is not queued yet) so feedback loops can make progress —
    # confirm against intcode.Machine.
    finished = [False] * len(machines)
    while not all(finished):
        for i, machine in enumerate(machines):
            machine.run()
            finished[i] = machine.finished

    # Once everything has halted, the network's result is the last machine's output.
    return machines[-1].outputs.popleft()

In [5]:
def find_max_val(prog, connect_func, phase_values):
    """Try every ordering of ``phase_values`` and keep the largest network output.

    Parameters
    ----------
    prog : str
        Intcode program text, passed through to ``network_output``.
    connect_func : callable
        Wiring strategy passed through to ``network_output`` (e.g.
        ``serial_connection`` or ``feedback_connection``).
    phase_values : iterable
        Phase settings; each permutation assigns one setting per amplifier.

    Returns
    -------
    tuple
        ``(max_val, best_phases)``: the largest output produced and the phase
        permutation (a tuple) that produced it. On ties, the first permutation
        in ``itertools.permutations`` order wins.
    """
    # Compare real outputs against each other instead of against a seed of 0:
    # with a 0 seed, a network whose best output is zero or negative would be
    # reported as (0, None) rather than its actual maximum and phases.
    results = (
        (network_output(prog, phases, connect_func), phases)
        for phases in itertools.permutations(phase_values)
    )
    # max() returns the first maximal item, matching the old strict '>' update.
    max_val, best_phases = max(results, key=lambda pair: pair[0])
    return max_val, best_phases

# Part 1

In [6]:
# Tests: Part 1 example programs paired with the expected
# (max output, best phase ordering) for a serial chain of 5 amplifiers.
part1_examples = [
    ("3,15,3,16,1002,16,10,16,1,16,15,15,4,15,99,0,0",
     (43210, (4, 3, 2, 1, 0))),
    ("3,23,3,24,1002,24,10,24,1002,23,-1,23,101,5,23,23,1,24,23,23,4,23,99,0,0",
     (54321, (0, 1, 2, 3, 4))),
    ("3,31,3,32,1002,32,10,32,1001,31,-2,31,1007,31,0,33,1002,33,7,33,1,33,31,31,1,32,31,31,4,31,99,0,0,0",
     (65210, (1, 0, 4, 3, 2))),
]
for prog, expected in part1_examples:
    assert find_max_val(prog, serial_connection, range(5)) == expected

In [7]:
# Load the puzzle input: the program is read from the first line of the file.
# `prog` must be re-read here because the test cell above rebinds it to an
# example program.
with open("day07.input") as file:
    prog = file.readline().strip()

# Part 1 answer (amplifiers wired in series, phases 0-4): the highest network
# output and the phase ordering that produced it, shown as the cell's output.
find_max_val(prog, serial_connection, range(5))

(24405, (2, 3, 0, 4, 1))

# Part 2

In [8]:
def feedback_connection(machines):
    """Wire the amplifiers into a closed loop (Part 2 wiring).

    Each machine's ``outputs`` object becomes the next machine's ``inputs``
    (the same object, not a copy), and the last machine wraps around to feed
    the first. The ring is closed *before* the starting 0 is injected, so
    that 0 goes into the queue the first machine shares with the last
    machine's outputs.

    Returns the same ``machines`` list, now wired.
    """
    count = len(machines)
    for position, source in enumerate(machines):
        # Modulo wraps the final machine back around to index 0.
        machines[(position + 1) % count].inputs = source.outputs

    # Order matters: machines[0].inputs was rebound by the loop above.
    machines[0].add_input(0)
    return machines

In [9]:
# Tests: Part 2 example programs paired with the expected
# (max output, best phase ordering) for a feedback loop with phases 5-9.
part2_examples = [
    ("3,26,1001,26,-4,26,3,27,1002,27,2,27,1,27,26,27,4,27,1001,28,-1,28,1005,28,6,99,0,0,5",
     (139629729, (9, 8, 7, 6, 5))),
    ("3,52,1001,52,-5,52,3,53,1,52,56,54,1007,54,5,55,1005,55,26,1001,54,-5,54,1105,1,12,1,53,54,53,1008,54,0,55,1001,55,1,55,2,53,55,53,4,53,1001,56,-1,56,1005,56,6,99,0,0,0,0,10",
     (18216, (9, 7, 8, 5, 6))),
]
for prog, expected in part2_examples:
    assert find_max_val(prog, feedback_connection, range(5, 10)) == expected

In [10]:
# Reload the puzzle input: the test cell above rebinds `prog` to an example
# program, so the real program is read again from the first line of the file.
with open("day07.input") as file:
    prog = file.readline().strip()

# Part 2 answer (amplifiers wired in a feedback loop, phases 5-9): the highest
# network output and the phase ordering that produced it, shown as the cell's output.
find_max_val(prog, feedback_connection, range(5, 10))

(8271623, (5, 7, 9, 8, 6))