In [1]:
from string import ascii_lowercase, ascii_uppercase
import collections
import heapq

In [2]:
def get_grid(filename, part=1):
    """Parse a maze file into a position map and a table of key locations.

    Every character of every (whitespace-stripped) line is stored in the
    returned grid under its ``(row, col)`` coordinate.  Lower-case letters
    are recorded as keys.  The '@' start marker is handled per ``part``:

    * ``part == 1``: '@' is recorded in the key table under the name '@'.
    * ``part == 2``: each '@' is renamed to a running number ('0', '1', ...)
      both in the grid and in the key table.

    Returns the pair ``(grid, keys)``.
    """
    grid = {}
    keys = {}
    robot_count = 0

    with open(filename) as handle:
        rows = [raw.strip() for raw in handle]

    for row, text in enumerate(rows):
        for col, char in enumerate(text):
            position = (row, col)

            if char == '@' and part == 2:
                # Numbered start marker replaces the '@' in the grid itself
                label = str(robot_count)
                robot_count += 1
                keys[label] = position
                grid[position] = label
                continue

            grid[position] = char
            if (char == '@' and part == 1) or char in ascii_lowercase:
                keys[char] = position

    return grid, keys

In [3]:
def find_neighbours(pos):
    """Return the four positions adjacent to ``pos``.

    Only the first two coordinates are stepped (by +/-1 on one of them);
    any further components of ``pos`` are carried over untouched."""
    first, second, *tail = pos
    steps = ((0, 1), (0, -1), (1, 0), (-1, 0))

    adjacent = []
    for d_first, d_second in steps:
        adjacent.append((first + d_first, second + d_second, *tail))
    return adjacent

In [4]:
def find_distances_bfs(start_pos):
    """Breadth-first search from one key to every other reachable key.

    Reads the notebook-level ``grid`` (position -> character) and ``keys``
    (key name -> position) and uses ``find_neighbours`` to generate moves.

    Returns a dict mapping every key character found (the start key itself
    included, at distance 0) to a pair ``(steps, keys_needed)``.
    ``keys_needed`` is a frozenset containing the start key, every key that
    lies on the path before the destination, and the lower-cased letter of
    every door on that path.

    Walls ('#') and positions missing from ``grid`` are never entered, so a
    maze without a complete wall border is handled too.  Each cell is
    expanded exactly once; when several equally short paths lead to a cell,
    the requirements of the first path discovered are the ones recorded.
    """
    dist = {start_pos: 0}
    distances = dict()
    # Cells are marked when they are enqueued (not when dequeued) so that a
    # cell can never sit in the queue more than once.
    seen = {start_pos}
    queue = collections.deque([(start_pos, frozenset(grid[start_pos]))])

    while queue:
        pos, keys_needed = queue.popleft()
        char = grid[pos]

        # We found a key!  Anything beyond it requires having picked it up.
        if char in keys:
            distances[char] = (dist[pos], keys_needed)
            keys_needed = keys_needed.union(char)

        # Need a key to pass this door
        if char in ascii_uppercase:
            keys_needed = keys_needed.union(char.lower())

        # Enqueue the unseen, walkable neighbours of this point.  The check
        # is made on the neighbour itself; off-grid cells count as walls.
        for n_pos in find_neighbours(pos):
            if n_pos in seen or grid.get(n_pos, '#') == '#':
                continue
            seen.add(n_pos)
            dist[n_pos] = dist[pos] + 1
            queue.append((n_pos, keys_needed))

    return distances

In [5]:
def all_distances():
    """Build the key-to-key distance table.

    For every entry of the notebook-level ``keys`` mapping, run
    ``find_distances_bfs`` from that key's position and store the result
    under the key's name."""
    return {name: find_distances_bfs(location)
            for name, location in keys.items()}

In [6]:
def find_reachable(from_key, have_keys):
    """List the not-yet-collected keys that can be reached from ``from_key``.

    Looks up ``from_key`` in the notebook-level ``distances`` table and keeps
    each destination that is not already in ``have_keys`` and whose required
    keys are all contained in ``have_keys``.  Returns ``(key, distance)``
    pairs in table order."""
    return [
        (target, steps)
        for target, (steps, required) in distances[from_key].items()
        if target not in have_keys and required.issubset(have_keys)
    ]

In [7]:
def dist_all_keys(start):
    """Dijkstra search for the shortest walk that collects every key.

    A search state is the pair (key we are standing on, frozenset of keys
    collected so far).  The search ends as soon as a state whose collected
    set is as large as the notebook-level ``keys`` table is popped, and that
    state's distance is returned.  If the queue runs dry first, the function
    falls through and returns None."""

    origin = (start, frozenset(start))
    best = {origin: 0}
    settled = set()
    frontier = [(best[origin], origin)]

    while frontier:
        travelled, state = heapq.heappop(frontier)

        # A state popped a second time was already settled more cheaply
        if state in settled:
            continue
        settled.add(state)

        current_key, collected = state
        if len(collected) == len(keys):
            # Every key is in hand
            return travelled

        for next_key, step in find_reachable(current_key, collected):
            candidate = (next_key, collected.union(next_key))
            if candidate in settled:
                continue

            total = travelled + step
            if total >= best.get(candidate, float('inf')):
                # No improvement over what we already know
                continue

            best[candidate] = total
            heapq.heappush(frontier, (total, candidate))

## Tests 

In [8]:
# Each sample maze is paired with its expected shortest total distance.
# grid, keys and distances are rebound on every pass because the solver
# functions read them from the notebook namespace.
for sample_file, expected_steps in (
    ("day18-test1.input", 86),
    ("day18-test2.input", 132),
    ("day18-test3.input", 136),
    ("day18-test4.input", 81),
):
    grid, keys = get_grid(sample_file)
    distances = all_distances()
    assert dist_all_keys('@') == expected_steps

# Part 1

In [9]:
# Load the full puzzle input and rebuild the key-to-key distance table.
# grid, keys and distances are notebook-level names that the solver
# functions read directly.
grid, keys = get_grid("day18.input")
distances = all_distances()
# Bare last expression: the shortest distance is shown as the cell output.
dist_all_keys('@')

5808

# Part 2

In [10]:
def patch_grid_part2(grid, keys):
    """Rewrite the 3x3 area around '@' into four separate robot starts.

    The nine cells centred on ``keys['@']`` are overwritten, row by row,
    with the pattern

        0#1
        ###
        2#3

    The '@' entry is removed from ``keys`` and the four corner positions
    are registered under the names '0', '1', '2' and '3'.  Both mappings
    are modified in place and also returned as ``(grid, keys)``.
    """
    centre = keys['@']
    centre_row, centre_col = centre[0], centre[1]

    offsets = [(d_row, d_col) for d_row in (-1, 0, 1) for d_col in (-1, 0, 1)]
    for (d_row, d_col), tile in zip(offsets, '0#1###2#3'):
        grid[centre_row + d_row, centre_col + d_col] = tile

    del keys['@']
    corners = [(-1, -1), (-1, 1), (1, -1), (1, 1)]
    for index, (d_row, d_col) in enumerate(corners):
        keys[str(index)] = (centre_row + d_row, centre_col + d_col)

    return grid, keys

In [11]:
def dist_all_keys_pt2(start_points):
    """Dijkstra search for the shortest combined walk of several robots.

    A search state is the pair (frozenset of the keys the robots currently
    stand on, frozenset of all keys collected so far).  In each step one
    robot moves to one key it can reach.  The distance of the first popped
    state whose collected set is as large as the notebook-level ``keys``
    table is returned; if the queue empties first the function falls
    through and returns None."""

    robots = frozenset(start_points)
    origin = (robots, robots)
    best = {origin: 0}
    settled = set()
    frontier = [(best[origin], origin)]

    while frontier:
        travelled, state = heapq.heappop(frontier)
        if state in settled:
            continue
        settled.add(state)

        positions, collected = state
        if len(collected) == len(keys):
            return travelled

        for robot_at in positions:
            for next_key, step in find_reachable(robot_at, collected):
                # Swap the moving robot's old location for its new one
                moved = positions.symmetric_difference((robot_at, next_key))
                candidate = (moved, collected.union(next_key))
                if candidate in settled:
                    continue

                total = travelled + step
                if total >= best.get(candidate, float('inf')):
                    continue

                best[candidate] = total
                heapq.heappush(frontier, (total, candidate))

In [12]:
# Tests: each four-robot sample maze is paired with its expected total.
# grid, keys and distances are rebound on every pass because the solver
# functions read them from the notebook namespace.
for sample_file, expected_steps in (
    ("day18-part2-test1.input", 32),
    ("day18-part2-test2.input", 72),
):
    grid, keys = get_grid(sample_file, part=2)
    distances = all_distances()
    assert dist_all_keys_pt2(('0', '1', '2', '3')) == expected_steps

In [13]:
# Read the puzzle input in its default (part=1) form so that keys['@'] exists,
# then let patch_grid_part2 turn the 3x3 area around '@' into the four
# robot starts '0'..'3'.
grid, keys = patch_grid_part2(*get_grid("day18.input"))
# The distance table must be rebuilt for the patched grid.
distances = all_distances()
# Bare last expression: the shortest combined distance is the cell output.
dist_all_keys_pt2(('0', '1', '2', '3'))

1992