In [56]:
# Read the raw part-1 puzzle map as a single string; part_1_solution parses it.
with open("inputs/Day_18_part_1.txt") as f:
    puzzle_input_part_1 = f.read()

In [57]:
def part_1_solution(raw_input):
    """Return the fewest steps needed to collect every key in the map.

    Args:
        raw_input: the map text ('#' walls, '.' passages, '@' entrance,
            lowercase keys, uppercase doors).

    Returns:
        Number of moves on the shortest key-collecting walk.
    """
    grid, keys, start = parse_grid(raw_input)
    predecessors, goal_state = explore_grid(grid, keys, start)
    start_state = (start, ())
    walk = get_path_to(goal_state, start_state, predecessors)
    # The reconstructed walk lists the start cell too, which is not a move.
    return len(walk) - 1
    
    
def parse_grid(raw_input):
    """Parse the map text into a lookup of enterable cells.

    Walls ('#') are omitted entirely. Passages ('.') and the entrance ('@')
    are stored as "OPEN_PASSAGE". Every other character (keys, doors) is
    stored as-is.

    Args:
        raw_input: the map text, one row per line (LF or CRLF endings).

    Returns:
        (grid, keys, start_position):
            grid: dict mapping (x, y) to the cell content.
            keys: set of lowercase key letters present on the map.
            start_position: (x, y) of the '@' entrance. If several '@' are
                present, the last one in reading order wins.

    Raises:
        ValueError: if the map contains no '@' entrance.
    """
    grid = dict()
    keys = set()
    start_position = None

    # splitlines() also strips '\r' from CRLF input; split("\n") would leave
    # a stray '\r' cell at the end of every row.
    for y, row in enumerate(raw_input.splitlines()):
        for x, obj in enumerate(row):
            position = (x, y)

            if obj == "#":
                continue
            elif obj == ".":
                grid[position] = "OPEN_PASSAGE"
            elif obj == "@":
                start_position = position
                grid[position] = "OPEN_PASSAGE"
            else:
                grid[position] = obj

                if obj.islower():
                    keys.add(obj)

    if start_position is None:
        # Previously this surfaced as an UnboundLocalError on the return line.
        raise ValueError("map has no '@' entrance")

    return grid, keys, start_position


def explore_grid(grid, all_keys, start_position):
    """Breadth-first search for the shortest walk that collects every key.

    The search state is (position, collected keys as a sorted tuple), so the
    same cell may be revisited once more keys are held.

    Args:
        grid: dict mapping (x, y) to cell content, as built by parse_grid.
        all_keys: collection of every key letter on the map.
        start_position: (x, y) where the walk begins.

    Returns:
        (path_to, end_state):
            path_to: dict mapping each reached state to the state it was
                first reached from.
            end_state: the first (position, keys) state found that holds
                all keys. When all_keys is empty, end_state is
                (start_position, ()) and path_to is empty.

    Raises:
        Exception: if no reachable state holds all of the keys.
    """
    # deque gives O(1) pops from the left; list.pop(0) is O(n) per pop.
    from collections import deque

    start_state = (start_position, ())
    if len(all_keys) == 0:
        # Nothing to collect: the start is already the goal (zero steps).
        return {}, start_state

    # A separate visited set for every distinct set of collected keys.
    # The start is marked up front so BFS never walks back into it.
    visited = {(): {start_position}}
    path_to = dict()  # (position, collected_keys) -> predecessor state
    queue = deque([start_state])

    while queue:
        current_position, current_collected_keys = queue.popleft()

        for neighbor, neighbor_keys in get_neighbors(current_position, current_collected_keys, grid):
            neighbor_keys_tuple = tuple(sorted(neighbor_keys))
            seen_with_these_keys = visited.setdefault(neighbor_keys_tuple, set())
            if neighbor in seen_with_these_keys:
                continue

            seen_with_these_keys.add(neighbor)
            path_to[(neighbor, neighbor_keys_tuple)] = (current_position, current_collected_keys)

            if len(neighbor_keys_tuple) == len(all_keys):
                # BFS order means this is the shortest walk collecting every key.
                return path_to, (neighbor, neighbor_keys_tuple)

            queue.append((neighbor, neighbor_keys_tuple))

    # Could not find a walk that collects all of the keys.
    raise Exception("Path to all keys not found")
                

def get_neighbors(position, collected_keys, grid):
    """Yield (cell, keys) for each orthogonally adjacent cell that can be entered.

    Neighbours are tried in the order +y, +x, -y, -x. Cells absent from
    grid are skipped. What is yielded depends on the cell content:

    * "OPEN_PASSAGE": the cell, with the very same collected_keys object.
    * a lowercase key: the cell, with a new set of collected_keys plus
      that key.
    * anything else (a door): the cell, with collected_keys unchanged,
      but only if the door's lowercase letter is already in
      collected_keys.
    """
    for step_x, step_y in ((0, 1), (1, 0), (0, -1), (-1, 0)):
        candidate = (position[0] + step_x, position[1] + step_y)
        if candidate not in grid:
            continue  # not an enterable cell (walls are never stored)

        content = grid[candidate]
        if content == "OPEN_PASSAGE":
            yield candidate, collected_keys
        elif content.islower():
            yield candidate, set(collected_keys) | {content}
        elif content.lower() in collected_keys:
            yield candidate, collected_keys
                    
def get_path_to(end_key, start_key, reachable_keys):
    """Rebuild the positions visited between two search states.

    Args:
        end_key: (position, keys) state where the walk ends.
        start_key: (position, keys) state where the walk begins.
        reachable_keys: dict mapping each state to its predecessor state.
            Following it from end_key must eventually reach start_key.

    Returns:
        List of positions ordered from start_key's position to end_key's,
        inclusive at both ends.
    """
    backwards = []
    state = end_key
    while state != start_key:
        position, _keys = state
        backwards.append(position)
        state = reachable_keys[state]

    first_position, _ = start_key
    backwards.append(first_position)
    return backwards[::-1]

In [53]:
# Small example maps paired with their expected minimum step counts.
part_1_examples = [
    ("""#########
#b.A.@.a#
#########""", 8),
    ("""########################
#f.D.E.e.C.b.A.@.a.B.c.#
######################.#
#d.....................#
########################""", 86),
    ("""########################
#...............b.C.D.f#
#.######################
#.....@.a.B.c.d.A.e.F.g#
########################""", 132),
    ("""#################
#i.G..c...e..H.p#
########.########
#j.A..b...f..D.o#
########@########
#k.E..a...g..B.n#
########.########
#l.F..d...h..C.m#
#################""", 136),
    ("""########################
#@..............ac.GI.b#
###d#e#f################
###A#B#C################
###g#h#i################
########################""", 81),
]

for example_map, expected_steps in part_1_examples:
    actual_steps = part_1_solution(example_map)
    # Report the mismatch instead of raising a bare AssertionError.
    assert actual_steps == expected_steps, f"expected {expected_steps} steps, got {actual_steps}"
print("Tests passed")

Tests passed


In [58]:
%%time
# Solve part 1 on the real puzzle input; %%time reports CPU and wall time.
print(f"Part 1 solution: {part_1_solution(puzzle_input_part_1)}")

Part 1 solution: 4250
CPU times: user 19.9 s, sys: 1.75 s, total: 21.6 s
Wall time: 23.5 s


In [18]:
%%time

import string
import collections
import math
import re
import sys

import sortedcollections

def reachablekeys(grid, start, havekeys):
    """BFS from `start` to every not-yet-held key that is reachable right now.

    Walls ('#') are impassable. Doors ('A'-'Z') are impassable unless their
    lowercase key is in `havekeys`. The search stops at a new key instead of
    walking through it, so each returned key is reachable without first
    picking up another new key.

    Args:
        grid: list of row strings, indexed grid[row][col]. Rows may differ
            in length; cells past the end of a row count as outside the map.
        start: (row, col) starting position.
        havekeys: collection (e.g. a string) of key letters already held.

    Returns:
        dict mapping key letter -> (distance, (row, col)).
    """
    bfs = collections.deque([start])
    distance = {start: 0}
    keys = {}
    while bfs:
        h = bfs.popleft()
        for pt in (
            (h[0] + 1, h[1]),
            (h[0] - 1, h[1]),
            (h[0], h[1] + 1),
            (h[0], h[1] - 1),
        ):
            row, col = pt
            # Bound-check against this row's own length (not grid[0]'s) so
            # ragged grids don't raise IndexError.
            if not (0 <= row < len(grid) and 0 <= col < len(grid[row])):
                continue
            ch = grid[row][col]
            if ch == '#' or pt in distance:
                continue
            distance[pt] = distance[h] + 1
            if 'A' <= ch <= 'Z' and ch.lower() not in havekeys:
                continue  # locked door: cannot pass yet
            if 'a' <= ch <= 'z' and ch not in havekeys:
                keys[ch] = distance[pt], pt  # new key: record it, don't walk past it
            else:
                bfs.append(pt)
    return keys


def reachable4(grid, starts, havekeys):
    """Combine reachablekeys() results for every robot start position.

    Returns:
        dict mapping key letter -> (distance, (row, col), robot index), where
        robot index is the position in `starts` the key was reached from. If
        a key is reachable from several starts, the later start's entry wins.
    """
    found = {}
    for robot, origin in enumerate(starts):
        reachable = reachablekeys(grid, origin, havekeys)
        found.update(
            (letter, (steps, where, robot))
            for letter, (steps, where) in reachable.items()
        )
    return found


# Former shared memo, kept only so existing references to `seen` still
# resolve; minwalk no longer reads or writes it unless passed explicitly.
seen = {}


def minwalk(grid, starts, havekeys, memo=None):
    """Return the fewest total steps for the robots to collect all remaining keys.

    Tries each currently reachable key (via reachable4), advances the robot
    that reached it, and recurses. Results are memoised on
    (robot positions, sorted held keys).

    Args:
        grid: list of map rows, indexed grid[row][col].
        starts: tuple of (row, col) robot positions (must be hashable).
        havekeys: string of key letters already collected.
        memo: optional dict cache. Defaults to a fresh dict per top-level
            call, because cached answers are only valid for one grid. A
            shared module-level cache returned stale results when minwalk
            was later called on a different grid in the same kernel.

    Returns:
        Minimum number of steps; 0 when no further key is reachable.
    """
    if memo is None:
        memo = {}
    state = (starts, ''.join(sorted(havekeys)))
    if state in memo:
        return memo[state]

    keys = reachable4(grid, starts, havekeys)
    if not keys:
        ans = 0  # no un-held key is reachable: done
    else:
        ans = min(
            dist + minwalk(
                grid,
                # Only the robot that fetched `ch` moves; the others stay put.
                tuple(pt if i == robot else p for i, p in enumerate(starts)),
                havekeys + ch,
                memo,
            )
            for ch, (dist, pt, robot) in keys.items()
        )
    memo[state] = ans
    return ans


# Load the part-2 map as a list of row strings (indexed grid[row][col]).
# The file is closed before the long search below starts.
with open("inputs/Day_18_part_2.txt") as f:
    grid = [line.rstrip('\n') for line in f]

# Every '@' is one robot's start, as (row, col) in row-major order.
# Iterating each row's own cells avoids IndexError on ragged rows.
starts = [
    (row, col)
    for row, cells in enumerate(grid)
    for col, ch in enumerate(cells)
    if ch == '@'
]

print(f"Part 2 solution: {minwalk(grid, tuple(starts), '')}")

Part 2 solution: 1640
CPU times: user 1min 7s, sys: 109 ms, total: 1min 7s
Wall time: 1min 13s
