In [1]:
import heapq
import re
from queue import PriorityQueue

import matplotlib.pyplot as plt
from tqdm import tqdm

In [2]:
# Raw puzzle text, read once; it is parsed into a Graph in the next cell.
INPUT_PATH = "./data_inputs/day16_input.txt"

with open(INPUT_PATH) as input_file:
    input_raw = input_file.read()

In [3]:
class Node:
    """A valve in the cave: its id, flow rate, adjacent valve ids and open state."""

    def __init__(self, id, flow=0, adjacents=None):
        self.id = id
        self.flow = flow
        # Give each node its own list: a `[]` default would be a single list
        # shared by every Node created without explicit adjacents.
        self.adjacents = [] if adjacents is None else adjacents
        self.is_valve_open = False

    def __repr__(self):
        return f"{self.id}: flow = {self.flow}, adj = {self.adjacents}"

    def open_valve(self):
        """Open the valve.

        Marks the valve open, zeroes its remaining flow so it cannot be counted
        twice, and returns the flow rate it had before opening.
        """
        flow = self.flow
        self.flow = 0
        self.is_valve_open = True
        return flow


class Graph:
    """Collection of ``Node`` objects keyed by node id."""

    def __init__(self):
        self.nodes = {}

    def __repr__(self):
        # One "<node repr>\n" line per node, in insertion order.
        return "".join(f"{node}\n" for node in self.nodes.values())

    def add_node(self, id, flow=0, adjacents=None):
        """Create a ``Node`` and register it under ``id``, replacing any node with that id."""
        if adjacents is None:
            # Fresh list per call; a `[]` default would be shared by every such node.
            adjacents = []
        self.nodes[id] = Node(id, flow, adjacents)

  
g = Graph()

# One valve per non-blank line. The pattern captures "Valve <id>", the first
# "=<digits>" (flow rate) and the comma-separated ids after "valve "/"valves ";
# presumably AoC 2022 day 16 lines such as
#   Valve AA has flow rate=0; tunnels lead to valves DD, II, BB
VALVE_RE = re.compile(r"^Valve (\w+).*?=(\d+).*?valves? (.+)$")

for line in input_raw.splitlines():
    if not line.strip():
        continue  # e.g. a trailing newline; splitting on "\n" alone would crash here
    match = VALVE_RE.match(line)
    if match is None:
        raise ValueError(f"Unrecognised valve description: {line!r}")
    node, flow_text, adj_text = match.groups()
    flow = int(flow_text)
    adj_nodes = [adj.strip() for adj in adj_text.split(",")]

    g.add_node(node, flow, adj_nodes)

In [4]:
# Dijkstra's algorithm: shortest tunnel distances between the valves worth opening (plus the start valve AA)

class D_Graph:
    """Weighted directed graph stored as nested dicts: ``edges[u][v] == weight``.

    ``v`` holds the number of nodes given at construction, and ``visited``
    starts as an empty list for search bookkeeping.
    """

    def __init__(self, nodes):
        self.v = len(nodes)
        edges = {}
        for name in nodes:
            edges[name] = {}  # independent outgoing-edge dict per node
        self.edges = edges
        self.visited = []

    def add_edge(self, u, v, weight):
        """Record (or overwrite) the edge ``u -> v``; ``u`` must already be a node."""
        outgoing = self.edges[u]
        outgoing[v] = weight


def dijkstra(graph, start_node):
    """Return shortest-path distances from ``start_node`` to every node of ``graph``.

    Args:
        graph: object with an ``edges`` mapping ``{node: {neighbor: weight}}``
            (e.g. a ``D_Graph``); weights are assumed non-negative.
        start_node: key of ``graph.edges`` to measure from.

    Returns:
        dict mapping every key of ``graph.edges`` to its distance from
        ``start_node`` (``float('inf')`` when unreachable).

    The visited set is local to each call, so the same graph can be queried
    repeatedly; ``graph.visited`` is neither read nor modified.
    """
    D = {n: float('inf') for n in graph.edges}
    D[start_node] = 0

    visited = set()
    heap = [(0, start_node)]
    while heap:
        _, current_node = heapq.heappop(heap)
        if current_node in visited:
            continue  # stale entry: this node was already settled via a shorter path
        visited.add(current_node)

        # Only look at real neighbours instead of scanning every node.
        for neighbor, weight in graph.edges[current_node].items():
            # Neighbours that are not keys of graph.edges are skipped.
            if neighbor not in D or neighbor in visited:
                continue
            new_cost = D[current_node] + weight
            if new_cost < D[neighbor]:
                D[neighbor] = new_cost
                heapq.heappush(heap, (new_cost, neighbor))
    return D


def build_d_graph(graph):
    """Build a ``D_Graph`` from ``graph`` with a weight-1 edge per adjacency.

    ``graph.nodes`` maps node id -> node object exposing ``adjacents``
    (e.g. the ``Graph`` parsed above).
    """
    distance_graph = D_Graph(graph.nodes)
    for node_id, node in graph.nodes.items():
        for neighbor_id in node.adjacents:
            # Every tunnel costs one minute of travel.
            distance_graph.add_edge(node_id, neighbor_id, 1)
    return distance_graph


# Pairwise shortest distances between the valves worth opening (flow > 0)
# plus the start valve "AA"; the search below only moves between these.
flow_nodes = {name: {} for name, valve in g.nodes.items() if valve.flow > 0}
flow_nodes["AA"] = {}

for src_f_node, distances_from_src in flow_nodes.items():
    # A freshly built graph per source keeps each search independent of earlier runs.
    shortest = dijkstra(build_d_graph(g), src_f_node)
    for dest_f_node in flow_nodes:
        if dest_f_node == src_f_node:
            continue
        distances_from_src[dest_f_node] = shortest[dest_f_node]

In [5]:
# ---- Part 1 ----

def best_branch_pressure(prev_node, curr_node, minutes, visited, total_flow, total_pressure,
                         total_minutes, max_time, graph=None, distances=None):
    """Return the best total pressure achievable by ``max_time`` from this search state.

    Depth-first search over the order in which valves are opened. The state
    describes a walk that is about to spend ``minutes`` travelling to
    ``curr_node``. The valve there is then opened, and every not-yet-visited
    valve in ``distances[curr_node]`` is tried next.

    Args:
        prev_node: node the walk comes from. When falsy (the start), opening
            ``curr_node`` costs no minute but its flow is still added, so the
            start valve presumably has flow 0 — TODO confirm for other inputs.
        curr_node: node being travelled to.
        minutes: travel time to ``curr_node``.
        visited: dict whose keys are valves that must not be opened again
            (values are ignored). It is copied, never mutated.
        total_flow: pressure released per minute by the valves opened so far.
        total_pressure: pressure released up to ``total_minutes``.
        total_minutes: minutes elapsed so far.
        max_time: minute at which the simulation ends.
        graph: object whose ``nodes[id].flow`` is each valve's flow rate;
            defaults to the notebook's global ``g``.
        distances: ``{node: {other_node: travel_minutes}}``; defaults to the
            notebook's global ``flow_nodes``.

    Returns:
        The maximum pressure released by ``max_time`` over all opening orders.
    """
    if graph is None:
        graph = g
    if distances is None:
        distances = flow_nodes

    # Cannot reach curr_node in time: the open valves just run until the end.
    if total_minutes + minutes >= max_time:
        return total_pressure + (max_time - total_minutes) * total_flow

    # Travel to curr_node while the already-open valves keep releasing pressure.
    total_minutes += minutes
    total_pressure += minutes * total_flow

    # Open curr_node's valve. Its flow only counts once the opening minute is over.
    next_total_flow = total_flow + graph.nodes[curr_node].flow
    next_visited = visited.copy()
    next_visited[curr_node] = (minutes + 1, total_flow)
    if prev_node:
        total_minutes += 1
        total_pressure += total_flow

    # Baseline: stop moving and let the open valves run until max_time.
    # Every explored branch releases at least this much.
    best = total_pressure + (max_time - total_minutes) * next_total_flow

    for next_node, next_minutes in distances[curr_node].items():
        if next_node in next_visited:
            continue
        candidate = best_branch_pressure(curr_node,
                                         next_node,
                                         next_minutes,
                                         next_visited,
                                         next_total_flow,
                                         total_pressure,
                                         total_minutes,
                                         max_time,
                                         graph,
                                         distances)
        best = max(best, candidate)

    return best


MAX_TIME = 30

# Initial search state: standing at "AA" at minute 0 with no valve open yet.
start_state = {
    "prev_node": None,
    "curr_node": "AA",
    "minutes": 0,
    "visited": {},
    "total_flow": 0,
    "total_pressure": 0,
    "total_minutes": 0,
}
best_pressure = best_branch_pressure(max_time=MAX_TIME, **start_state)

print("Result 1:", best_pressure)

Result 1: 1940


In [6]:
# ---- Part 2 ----

# The elephant and I split the valves worth opening between us. Each of us runs
# the Part 1 search on our own share for MAX_TIME_2 minutes. Only half of the
# 2**n bitmasks are enumerated: the top bit (the last valve) is always 0, so that
# valve always goes to the elephant. This tries every {mine, elephant's} split
# exactly once, instead of once per ordering.

MAX_TIME_2 = 26

# Valves that can be assigned to either of us; "AA" is the shared start, not a valve to split.
valves_to_split = [node for node in flow_nodes if node != "AA"]
n = len(valves_to_split)
n_combinations = 2 ** n


def _best_pressure_excluding(excluded_valves, max_time):
    """Best pressure reachable from "AA" within ``max_time`` without opening ``excluded_valves``."""
    return best_branch_pressure(prev_node=None,
                                curr_node="AA",
                                minutes=0,
                                visited={node: (0, 0) for node in excluded_valves},
                                total_flow=0,
                                total_pressure=0,
                                total_minutes=0,
                                max_time=max_time)


max_combination_pressure = 0

for mask in tqdm(range(n_combinations // 2)):
    # Bit i of the mask decides who opens valves_to_split[i] (1 = me, 0 = elephant).
    my_valves, elephant_valves = [], []
    for i, valve in enumerate(valves_to_split):
        (my_valves if (mask >> i) & 1 else elephant_valves).append(valve)

    # Each of us treats the other's valves as already visited.
    best_pressure_me = _best_pressure_excluding(elephant_valves, MAX_TIME_2)
    best_pressure_elephant = _best_pressure_excluding(my_valves, MAX_TIME_2)

    max_combination_pressure = max(max_combination_pressure,
                                   best_pressure_me + best_pressure_elephant)

print("Result 2:", max_combination_pressure)

100%|██████████| 16384/16384 [01:33<00:00, 175.64it/s]

Result 2: 2469



