In [1]:
from string import ascii_lowercase, ascii_uppercase
import collections
import functools
import heapq

In [2]:
def get_grid(filename):
    """Read a maze file.

    Returns a tuple ``(grid, keys, start)``:

    * ``grid``  -- ``{(row, col): char}`` for every character of every line
      (each line is ``str.strip()``-ed first);
    * ``keys``  -- ``{char: (row, col)}`` for each lowercase key and for the
      entrance ``'@'`` (a later occurrence of the same char overwrites an
      earlier one);
    * ``start`` -- ``(row, col)`` of the last ``'@'`` read, or ``None`` if the
      file contains no ``'@'``.
    """
    with open(filename) as handle:
        rows = [text.strip() for text in handle]

    grid, keys, start = {}, {}, None
    for r, text in enumerate(rows):
        for c, ch in enumerate(text):
            cell = (r, c)
            grid[cell] = ch
            if ch == '@':
                # The entrance doubles as a "key" so searches can start there.
                start = cell
                keys[ch] = cell
            elif ch in ascii_lowercase:
                keys[ch] = cell
    return grid, keys, start

In [3]:
def find_neighbours(pos):
    """Return the four orthogonal neighbours of ``pos`` as tuples.

    Only the first two coordinates are shifted (one of them by +/-1 at a
    time); any further coordinates are copied through unchanged.
    Order: second coord +1, second coord -1, first coord +1, first coord -1.
    """
    row, col, *extra = pos
    offsets = [(0, 1), (0, -1), (1, 0), (-1, 0)]
    neighbours = []
    for d_row, d_col in offsets:
        neighbours.append((row + d_row, col + d_col, *extra))
    return neighbours

In [4]:
def find_distances_bfs(start):
    """Breadth-first search from ``start`` to every key reachable from it.

    Reads the module-level ``grid`` (``(row, col) -> char``) and ``keys``
    (char -> position) produced by ``get_grid``.

    Returns a dict mapping each key char reached (other than the one at
    ``start``) to ``(steps, keys_needed)``.  ``keys_needed`` is a frozenset
    containing:

    * the character at ``start`` itself,
    * the lowercase key for every door passed on the way, and
    * every other key passed on the way (walking over a key picks it up, so a
      route "through" it only makes sense once that key is held).

    When several shortest paths lead to the same cell, the first one
    discovered determines its ``keys_needed``.
    """
    dist = {start: 0}
    visited = {start}
    distances = dict()
    queue = collections.deque([(start, frozenset(grid[start]))])

    while queue:
        pos, keys_needed = queue.popleft()
        char = grid[pos]

        # We found a key!
        if char in keys and pos != start:
            distances[char] = dist[pos], keys_needed
            keys_needed = keys_needed.union(char)

        # Need a key to pass this door
        if char in ascii_uppercase:
            keys_needed = keys_needed.union(char.lower())

        # Enqueue open neighbours.  Test the *neighbour* cell (not `pos`), and
        # treat cells missing from the grid as walls so an unbordered maze
        # cannot make `grid[pos]` above raise KeyError.
        for n in find_neighbours(pos):
            if n in visited or grid.get(n, '#') == '#':
                continue
            # Mark on enqueue so each cell is queued exactly once; marking on
            # pop lets a cell be queued once per shortest path leading to it.
            visited.add(n)
            dist[n] = dist[pos] + 1
            queue.append((n, keys_needed))

    return distances

In [5]:
def all_distances():
    """Run ``find_distances_bfs`` from every entry of the global ``keys``.

    Returns ``{key: {dest: (steps, keys_needed), ...}, ...}``, keyed by each
    key char (including the entrance ``'@'``) in ``keys`` iteration order.
    """
    return {name: find_distances_bfs(where) for name, where in keys.items()}

In [6]:
# Called once per distinct (key, keys-held) state settled by `dist_all_keys`.
# Deliberately not memoised: `distances` is a global that is rebound for each
# puzzle input, so a cache keyed only on these arguments would go stale.
def find_reachable(from_key, have_keys):
    """List keys not yet held that can be walked to from ``from_key``.

    A destination in ``distances[from_key]`` qualifies when it is not already
    in ``have_keys`` and every key in its ``keys_needed`` set is.

    Returns a list of ``(dest, steps)`` pairs, in ``distances[from_key]``
    iteration order.
    """
    return [
        (dest, steps)
        for dest, (steps, needed) in distances[from_key].items()
        if dest not in have_keys and needed.issubset(have_keys)
    ]

In [7]:
def dist_all_keys(start):
    """Fewest steps needed to collect every key, starting from key ``start``.

    Dijkstra's algorithm over states ``(current_key, keys_held)``, with edges
    from ``find_reachable`` (built on the global ``distances``).  Edge weights
    are the walking distances between keys -- not 1 -- but they are
    non-negative, so the first time a state is popped from the heap it carries
    its shortest distance and any later pop of the same state can be skipped.

    ``frozenset(start)`` is the initial set of keys held (``'@'`` -> ``{'@'}``).
    The search stops once as many keys are held as there are entries in the
    global ``keys`` (which includes ``'@'``).

    Returns that step count, or ``None`` if the heap empties first.
    """
    settled = dict()
    heap = [(0, (start, frozenset(start)))]

    while heap:
        steps, state = heapq.heappop(heap)
        if state in settled:
            # Already settled with a distance <= steps.
            continue
        settled[state] = steps

        here, held = state
        if len(held) == len(keys):
            return steps

        # find_reachable only yields keys that are not already held.
        for key, extra in find_reachable(here, held):
            heapq.heappush(heap, (steps + extra, (key, held.union(key))))

    return None

## Tests 

In [8]:
# Example mazes with known answers: (input file, expected fewest steps).
EXAMPLES = [
    ("day18-test1.input", 86),
    ("day18-test2.input", 132),
    ("day18-test3.input", 136),
    ("day18-test4.input", 81),
]

for filename, expected in EXAMPLES:
    # The helper functions read these module-level names, so rebind them.
    grid, keys, start = get_grid(filename)
    distances = all_distances()
    steps = dist_all_keys('@')
    assert steps == expected, f"{filename}: expected {expected}, got {steps}"

# Part 1

In [9]:
# The helper functions read grid/keys/distances as globals, so bind them here.
grid, keys, start = get_grid("day18.input")
distances = all_distances()
part1_steps = dist_all_keys('@')
part1_steps

5808