In [1]:
from pathlib import Path

# Read ./input and split it into lines, trimming leading/trailing whitespace first.
data = (
    Path("input")
    .read_text()
    .strip()
    .splitlines()
)

In [2]:
# Each input line looks like "node: succ1 succ2 ..."; map node -> set of successors.
connections = {
    src: set(targets.split())
    for src, targets in (line.split(": ") for line in data)
}

In [3]:
import networkx as nx

# Directed graph with one edge per (node -> successor) pair from `connections`.
G = nx.DiGraph()
# Register every node up front: first the ones with outgoing edges (dict keys),
# then the ones that only ever appear as successors.
G.add_nodes_from(connections)
G.add_nodes_from(set().union(*connections.values()))
G.add_edges_from(
    (src, dst)
    for src, dsts in connections.items()
    for dst in dsts
)

In [4]:
print("Part 1:")
# all_simple_paths is a generator yielding each path as a list; count it lazily
# instead of materialising every path with list() just to take len().
print(sum(1 for _ in nx.all_simple_paths(G, source="you", target="out")))

Part 1:
470


In [5]:
from functools import lru_cache

def count_paths_with_nodes_dp(G, source, target, required_nodes):
    """Count source->target paths that pass through every node in ``required_nodes``.

    Memoised DP over states ``(node, required nodes seen so far)``. A walk stops
    as soon as it reaches ``target`` (edges leaving ``target`` are never
    followed). A path counts only if, at that point, it has visited all of
    ``required_nodes``. ``source`` and ``target`` count as visited when they
    are required.

    The traversal uses an explicit stack rather than recursion, so long paths
    do not hit Python's recursion limit.

    Parameters
    ----------
    G : graph-like
        Any object whose ``successors(node)`` returns an iterable of the nodes
        one edge away (e.g. a ``networkx.DiGraph``). The part of the graph
        reachable from ``source`` must be acyclic.
    source, target : hashable
        Start and end nodes.
    required_nodes : iterable of hashable
        Nodes every counted path must include; empty means "count all paths".

    Returns
    -------
    int
        Number of qualifying paths.

    Raises
    ------
    ValueError
        If a cycle is reachable from ``source`` (the number of walks would be
        unbounded).
    """
    required_nodes = frozenset(required_nodes)

    def _children(state):
        # Successor states of `state`; the walk ends at target, so it has none.
        node, seen = state
        if node == target:
            return []
        return [
            (nbr, (seen | {nbr}) if nbr in required_nodes else seen)
            for nbr in G.successors(node)
        ]

    start_seen = frozenset([source]) if source in required_nodes else frozenset()
    start = (source, start_seen)

    counts = {}          # resolved state -> number of qualifying completions
    on_stack = {start}   # states on the current DFS path (for cycle detection)
    first = _children(start)
    # Frames: (state, its child states, iterator over children still to visit).
    stack = [(start, first, iter(first))]
    while stack:
        state, children, pending = stack[-1]
        for child in pending:
            if child in counts:
                continue
            if child in on_stack:
                raise ValueError(
                    f"cycle reachable from {source!r} (revisits node {child[0]!r}); "
                    "path count is unbounded"
                )
            grandchildren = _children(child)
            on_stack.add(child)
            stack.append((child, grandchildren, iter(grandchildren)))
            break  # descend; this frame's iterator resumes when the child is done
        else:
            # All children resolved: fold them into this state's count.
            stack.pop()
            on_stack.discard(state)
            node, seen = state
            if node == target:
                counts[state] = int(seen == required_nodes)
            else:
                counts[state] = sum(counts[c] for c in children)
    return counts[start]

In [6]:
print("Part 2:")
# Paths from "svr" to "out" that visit both "dac" and "fft" along the way.
print(
    count_paths_with_nodes_dp(
        G,
        source="svr",
        target="out",
        required_nodes=frozenset(("dac", "fft")),
    )
)

Part 2:
384151614084875
