In [1]:
import math
import time
import numpy as np

from itertools import permutations
from numba import njit

In [2]:
def read_input(filename):
    
    with open(filename) as f:
        lines = f.readlines()
        
    big_dict = {}

    for line in lines:
        line_parts = line.split()
        valve_name = line_parts[1]
        flow_rate = int(line_parts[4].split("=")[1].strip(';'))
        neighbors = [part.strip(',') for part in line_parts[9:]]
        d = {
            'flow_rate': flow_rate,
            'neighbors': neighbors
        }
        big_dict[valve_name] = d
            
    return big_dict


def find_shortest_path(start, end, big_dict):
    
    paths_list = [[start]]
    
    if start == end:
        return 0

    i = 0
    
    while True:
        
        i = i+1
        new_paths_list = []
        
        for path in paths_list:
            
            for neighbor in big_dict[path[-1]]["neighbors"]:
                if not neighbor in path:
                    
                    new_path = path.copy()
                    new_path.append(neighbor)
                    new_paths_list.append(new_path)
                    
                    if neighbor == end:
                        return i
                    
        paths_list = new_paths_list
        
        
def get_small_network(big_dict):

    valves_nonzero = [valve for valve in big_dict.keys() if big_dict[valve]['flow_rate']>0]
    relevant_valves = valves_nonzero.copy()
    relevant_valves.insert(0, 'AA')

    shortest_paths = np.array([[find_shortest_path(start, end, big_dict) for end in relevant_valves] for start in relevant_valves])

    shortest_times = shortest_paths.copy()
    for j in range(1, 6+1):
        shortest_times[:,j] += 1

    flow_rates = np.array([big_dict[valve]['flow_rate'] for valve in relevant_valves])
    
    return relevant_valves, shortest_times, flow_rates


@njit
def get_pressure(valve_order, flow_rates, shortest_times):
    time = 30
    pressure = 0
    
    valve_2 = valve_order[0]
    time += - shortest_times[0, valve_2]
    pressure += flow_rates[valve_2]*time
    
    for valve_1, valve_2 in zip(valve_order[:-1], valve_order[1:]):
        time += - shortest_times[valve_1, valve_2]
        if time > 0:
            pressure += flow_rates[valve_2]*time
        else:
            break
    return pressure

In [3]:
big_dict_test = read_input("16_test.txt")

rel_valves_test, shortest_times_test, flow_rates_test = get_small_network(big_dict_test)

n_nonzero_test = len(rel_valves_test[1:])
results_test = np.zeros(math.factorial(n_nonzero_test), dtype="int")
for i, valve_order in enumerate(permutations(range(1,n_nonzero_test+1))):
    valve_order = list(valve_order)
    valve_order.insert(0, 0)
    valve_order = np.array(valve_order)
    results_test[i] = get_pressure(valve_order, flow_rates_test, shortest_times_test)
    
print(np.max(results_test))

valve_order_max = list(permutations(range(1,n_nonzero_test+1)))[np.argmax(results_test)]
print([rel_valves_test[i] for i in valve_order_max])

1651
['DD', 'BB', 'JJ', 'HH', 'EE', 'CC']


In [4]:
big_dict = read_input("16.txt")

rel_valves, shortest_times, flow_rates = get_small_network(big_dict)

In [5]:
max_pressure = 0

t0 = time.time()

n_nonzero = len(flow_rates[1:])
for i, valve_order in enumerate(permutations(range(1,n_nonzero+1))):
    valve_order = np.array(valve_order)
    pressure = get_pressure(valve_order, flow_rates, shortest_times)
    if pressure > max_pressure:
        max_pressure = pressure
    if i > 10**6:
        break
        
print(time.time() - t0)

3.316317558288574


This is how many hours it would take to solve the optimization problem for the input data using the algorithm above:

In [6]:
(time.time() - t0)/60/60*math.factorial(15)/10**6

1207.8900687002563

In [7]:
flow_rates

array([ 0,  9, 20, 13, 16, 19,  8, 21, 11, 23, 14, 18, 25,  3, 22,  5])

In [8]:
shortest_times

array([[ 0,  5,  7,  8, 10,  4,  3,  6,  3,  9,  6,  6,  9,  2,  4,  3],
       [ 4,  1,  3,  4,  6,  6,  5,  2,  6, 12,  9,  9,  5,  2,  6,  4],
       [ 6,  3,  1,  3,  4,  8,  7,  4,  8, 14, 11, 11,  3,  4,  8,  6],
       [ 7,  4,  3,  1,  3,  9,  8,  2,  9, 15, 12, 12,  3,  5,  9,  7],
       [ 9,  6,  4,  3,  1, 11, 10,  3, 11, 17, 14, 14,  5,  7, 11,  9],
       [ 3,  6,  8,  9, 11,  1,  6,  7,  5, 11,  8,  8, 10,  3,  3,  2],
       [ 2,  5,  7,  8, 10,  6,  1,  6,  2,  8,  5,  5,  9,  2,  2,  3],
       [ 6,  3,  5,  3,  4,  8,  7,  0,  8, 14, 11, 11,  5,  4,  8,  6],
       [ 3,  7,  9, 10, 12,  6,  3,  8,  0,  6,  3,  3, 11,  4,  4,  3],
       [ 9, 13, 15, 16, 18, 12,  9, 14,  6,  0,  3,  9, 17, 10, 10,  9],
       [ 6, 10, 12, 13, 15,  9,  6, 11,  3,  3,  0,  6, 14,  7,  7,  6],
       [ 6, 10, 12, 13, 15,  9,  6, 11,  3,  9,  6,  0, 14,  7,  7,  6],
       [ 9,  6,  4,  4,  6, 11, 10,  5, 11, 17, 14, 14,  0,  7, 11,  9],
       [ 2,  3,  5,  6,  8,  4,  3,  4,  4, 10,  7,