In [19]:
import heapq
import collections

In [2]:
# Small example maze, one row string per list element ('#' wall, '.' open,
# 'S' start, 'E' end). The asserts further down expect part1 == 7036 and
# part2 == 45 for it (presumably the first example in the puzzle statement).
testlines = '''###############
#.......#....E#
#.#.###.#.###.#
#.....#.#...#.#
#.###.#####.#.#
#.#.#.......#.#
#.#.#####.###.#
#...........#.#
###.#.#####.#.#
#...#.....#.#.#
#.#.#.###.#.#.#
#.....#...#.#.#
#.###.#.#.#.#.#
#S..#.....#...#
###############'''.splitlines()

In [6]:
# Second, larger example maze in the same format as `testlines`. The asserts
# further down expect part1 == 11048 and part2 == 64 for it.
biggertestlines = '''#################
#...#...#...#..E#
#.#.#.#.#.#.#.#.#
#.#.#.#...#...#.#
#.#.#.#.###.#.#.#
#...#.#.#.....#.#
#.#.#.#.#.#####.#
#.#...#.#.#.....#
#.#.#####.#.###.#
#.#.#.......#...#
#.#.###.#####.###
#.#.#...#.....#.#
#.#.#.#####.###.#
#.#.#.........#.#
#.#.#.#########.#
#S#.............#
#################'''.splitlines()

In [3]:
# Full puzzle input, read relative to the notebook's working directory;
# each list element is one maze row.
with open('day16input.txt') as puzzle_input:
    data = puzzle_input.read().splitlines()

## Part 1 ##

In [7]:
def get_walls_start_end(lines):
    """Parse a maze drawn with '#', '.', 'S' and 'E' characters.

    Args:
        lines: iterable of strings, one per maze row.

    Returns:
        (walls, start, end) where `walls` is the set of (row, col) positions
        of every '#', and `start` / `end` are the (row, col) positions of the
        single 'S' and 'E' tiles.

    Raises:
        ValueError: on any character other than '#', '.', 'S' or 'E', if 'S'
            or 'E' appears more than once, or if either is missing.
    """
    walls = set()
    # Initialise explicitly so a maze lacking 'S' or 'E' gets a clear error
    # below instead of an UnboundLocalError at the return statement.
    start = end = None
    for row, line in enumerate(lines):
        for col, c in enumerate(line):
            if '.' == c:
                continue
            elif '#' == c:
                walls.add((row, col))
            elif 'S' == c:
                if start is not None:
                    raise ValueError(f'Duplicate S at ({row},{col}); first S at {start}')
                start = (row, col)
            elif 'E' == c:
                if end is not None:
                    raise ValueError(f'Duplicate E at ({row},{col}); first E at {end}')
                end = (row, col)
            else:
                raise ValueError(f'Bad character {c} at ({row},{col})')
    if start is None:
        raise ValueError("Maze has no start tile 'S'")
    if end is None:
        raise ValueError("Maze has no end tile 'E'")
    return walls, start, end

In [9]:
walls, start, end = get_walls_start_end(lines=testlines)  # parse the small example for inspection

In [10]:
(start, end)  # display the parsed start and end coordinates

((13, 1), (1, 13))

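Dijkstra's-algorithm code borrowed from https://www.reddit.com/r/adventofcode/comments/1hfboft/comment/m2b0pw2/ — the search state is (position, direction), where a step forward costs 1 and a 90° turn in place costs 1000.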

In [13]:
# (row, col) step for each heading, in clockwise order so that +1 / -1
# (mod 4) on an index is a 90-degree right / left turn.
DIRS = [
    (0, +1),   # E
    (+1, 0),   # S
    (0, -1),   # W
    (-1, 0),   # N
]
COSTS = {'forward': 1, 'turn': 1000}
def traverse(source, target, walls):
    """Return the cheapest cost of reaching `target` from `source`.

    Dijkstra over (position, heading) states. The search starts at `source`
    facing East (heading index 0). A step forward costs COSTS['forward'] and a
    90-degree turn in place costs COSTS['turn'].

    Args:
        source: (row, col) start tile.
        target: (row, col) goal tile; any arrival heading is accepted.
        walls: set of (row, col) tiles that cannot be entered.

    Raises:
        ValueError: if every reachable state is exhausted without reaching
            `target`.
    """
    frontier = [(0, source, 0)]
    settled = set()
    while frontier:
        cost, pos, heading = heapq.heappop(frontier)
        state = (pos, heading)
        if state in settled:
            continue
        if pos == target:
            return cost
        settled.add(state)

        drow, dcol = DIRS[heading]
        ahead = (pos[0] + drow, pos[1] + dcol)
        if ahead not in walls:
            heapq.heappush(frontier, (cost + COSTS['forward'], ahead, heading))
        # clockwise first, then counter-clockwise
        for turned in ((heading + 1) % 4, (heading - 1) % 4):
            heapq.heappush(frontier, (cost + COSTS['turn'], pos, turned))
    raise ValueError('Heap queue exhausted without finding the target')

In [14]:
def part1(lines):
    """Return the lowest possible score from S to E for the maze in `lines`."""
    maze_walls, maze_start, maze_end = get_walls_start_end(lines)
    return traverse(maze_start, maze_end, maze_walls)

In [15]:
# `assert` is a statement, not a function: the call-like form assert(...) invites
# the always-true `assert (cond, msg)` tuple bug and gives no failure message.
assert part1(testlines) == 7036, 'part1 should score 7036 on the small example'

In [17]:
# Statement form with a message (see the first example check above for why not assert(...)).
assert part1(biggertestlines) == 11048, 'part1 should score 11048 on the larger example'

In [18]:
part1(lines=data)  # Part 1 answer for the full puzzle input

73404

## Part 2 ##

Again, this blatantly borrows the solution linked above. It modifies the Dijkstra search to record every optimal predecessor of each state, instead of stopping at the first arrival. The optimal cost already found in Part 1 serves as the cutoff.

In [67]:
def find_valid_links(source, target, target_cost, walls):
    """Record every best-cost predecessor of each (position, direction) state.

    Runs the same Dijkstra search as `traverse`, starting at `source` facing
    East (direction index 0). It does not stop at `target`: it keeps popping
    until the cheapest queued entry costs more than `target_cost`. Whenever a
    state is reached at its best known cost, the state it came from is added
    to that state's predecessor set. Equal-cost alternatives are therefore
    kept, not discarded.

    Args:
        source: (row, col) start tile.
        target: (row, col) end tile; not used by the search itself.
        target_cost: optimal cost to reach the end (e.g. from `traverse`);
            used as the cutoff.
        walls: set of (row, col) tiles that cannot be entered.

    Returns:
        collections.defaultdict(set) mapping (pos, d) -> set of predecessor
        states. The start state's set contains None.

    Raises:
        ValueError: if the queue empties before any entry exceeds
            `target_cost`.
    """
    best = {}
    predecessors = collections.defaultdict(set)
    heap = [(0, source, 0, None)]
    while heap:
        cost, pos, heading, came_from = heapq.heappop(heap)
        if cost > target_cost:
            # Costs pop in non-decreasing order, so no further optimal link can appear.
            return predecessors

        state = (pos, heading)
        known = best.get(state)
        if known is not None:
            # Already settled: only an equal-cost arrival is another optimal link.
            if cost == known:
                predecessors[state].add(came_from)
            continue

        # First (and therefore cheapest) arrival at this state.
        best[state] = cost
        predecessors[state].add(came_from)

        drow, dcol = DIRS[heading]
        ahead = (pos[0] + drow, pos[1] + dcol)
        if ahead not in walls:
            heapq.heappush(heap, (cost + COSTS['forward'], ahead, heading, state))
        for turned in ((heading + 1) % 4, (heading - 1) % 4):
            heapq.heappush(heap, (cost + COSTS['turn'], pos, turned, state))
    raise ValueError('Heap exhausted')

In [70]:
def count_valid_tiles(links, target):
    """Count the tiles that lie on at least one optimal path ending at `target`.

    Walks the predecessor graph (as built by `find_valid_links`) backwards
    from every facing at `target`, visiting each (pos, d) state once.

    Args:
        links: mapping (pos, d) -> iterable of predecessor states. A None
            predecessor marks the start state. States absent from the mapping
            are treated as having no predecessors, so a plain dict works.
            The mapping is not modified.
        target: (row, col) end tile.

    Returns:
        Number of distinct (row, col) positions among the visited states. The
        target tile itself is always counted.
    """
    visited = set()
    tiles = set()
    # Explicit stack instead of recursion: one state is pushed per step and
    # per turn, so optimal paths in a full-size maze can exceed Python's
    # default recursion limit (~1000 frames).
    stack = [(target, d) for d in range(4)]
    while stack:
        state = stack.pop()
        if state is None or state in visited:
            continue
        visited.add(state)
        tiles.add(state[0])
        # .get rather than [] so a caller's defaultdict is not padded with
        # empty entries.
        stack.extend(links.get(state, ()))
    return len(tiles)

In [72]:
def part2(lines):
    """Count the tiles that lie on at least one lowest-score path in `lines`.

    First finds the optimal score with `traverse`. It then collects every
    optimal predecessor link up to that score and walks them back from the
    end tile.
    """
    maze_walls, maze_start, maze_end = get_walls_start_end(lines)
    best_score = traverse(maze_start, maze_end, maze_walls)
    optimal_links = find_valid_links(maze_start, maze_end, best_score, maze_walls)
    return count_valid_tiles(optimal_links, maze_end)

In [74]:
# Statement form with a message; assert(...) looks like a call and invites the tuple bug.
assert part2(testlines) == 45, 'part2 should find 45 tiles on the small example'

In [75]:
# Statement form with a message; assert(...) looks like a call and invites the tuple bug.
assert part2(biggertestlines) == 64, 'part2 should find 64 tiles on the larger example'

In [76]:
part2(lines=data)  # Part 2 answer for the full puzzle input

449