In [None]:
import collections
import heapq

from string import ascii_uppercase

In [None]:
def find_neighbours(pos):
    """Return the four orthogonal neighbours of ``pos``.

    Only the first two coordinates are shifted by one step; any further
    coordinates (e.g. a recursion level) are copied unchanged. Neighbours
    come back in the order (x, y+1), (x, y-1), (x+1, y), (x-1, y).
    """
    first, second, *remaining = pos
    steps = [(0, 1), (0, -1), (1, 0), (-1, 0)]
    neighbours = []
    for step_first, step_second in steps:
        neighbours.append((first + step_first, second + step_second, *remaining))
    return neighbours

In [None]:
def parse_input(filename):
    """Parse a donut maze from ``filename``.

    Returns ``(grid, portals, teleport, start, end)`` where

    * ``grid`` maps (row, col) to the character at that position. Lines may
      have different lengths (e.g. trailing spaces stripped); positions past
      the end of a line are simply absent.
    * ``portals`` maps each two-letter label, read top-to-bottom or
      left-to-right as in the puzzle, to the list of open ('.') tiles next to
      that label, in row-major order of discovery.
    * ``teleport`` maps each portal tile to its partner tile (AA and ZZ are
      excluded).
    * ``start`` / ``end`` are the open tiles next to AA / ZZ.

    Raises ValueError if a label other than AA/ZZ does not appear exactly
    twice.
    """
    grid = dict()
    with open(filename) as file:
        for row, line in enumerate(file):
            for col, char in enumerate(line.rstrip('\n')):
                grid[row, col] = char

    portals = collections.defaultdict(list)
    # Scan every position instead of a bounding box derived from the last
    # line, which is wrong when lines have different lengths.
    for pos, symbol in grid.items():
        if symbol not in ascii_uppercase:
            continue
        neighbours = find_neighbours(pos)
        open_tiles = [n for n in neighbours if grid.get(n) == '.']
        letters = [n for n in neighbours if grid.get(n, ' ') in ascii_uppercase]
        # Only the letter touching the maze has exactly one open neighbour
        # and one letter neighbour (the other half of its label).
        if len(open_tiles) != 1 or len(letters) != 1:
            continue
        other = letters[0]
        # Keep reading order: the letter above / to the left comes first.
        # Sorting alphabetically would merge e.g. 'XY' and 'YX'.
        if other < pos:
            name = grid[other] + symbol
        else:
            name = symbol + grid[other]
        portals[name].append(open_tiles[0])

    start = portals['AA'][0]
    end = portals['ZZ'][0]

    teleport = dict()
    for name, points in portals.items():
        if name not in ('AA', 'ZZ'):
            # Every other label must appear exactly twice (ValueError otherwise)
            first, second = points
            teleport[first] = second
            teleport[second] = first

    return grid, portals, teleport, start, end

In [None]:
def draw_grid(grid):
    """Print a grid dict keyed by (row, col) as text, one row per line.

    Positions missing from ``grid`` (e.g. past the end of a short input
    line) are drawn as spaces. An empty grid prints nothing.
    """
    if not grid:
        return
    rows = [pos[0] for pos in grid]
    cols = [pos[1] for pos in grid]
    min_x, max_x = min(rows), max(rows)
    min_y, max_y = min(cols), max(cols)

    for x in range(min_x, max_x + 1):
        # Default to ' ' so ragged grids don't feed None into join()
        row = ''.join(grid.get((x, y), ' ') for y in range(min_y, max_y + 1))
        print(row)

In [None]:
def find_dist():
    """Breadth-first search from the global ``start`` over the global ``grid``.

    Open tiles ('.') connect to their orthogonal open neighbours, and every
    tile in the global ``teleport`` map also connects to its partner tile.
    Every move, including a teleport, costs one step.

    Returns:
        dict mapping every reachable tile to its step count from ``start``.
    """
    dist = {start: 0}
    queue = collections.deque([start])

    while queue:
        pos = queue.popleft()

        # Find neighbours, including via teleports
        neighbours = [n for n in find_neighbours(pos) if grid.get(n) == '.']
        if pos in teleport:
            neighbours.append(teleport[pos])

        for n in neighbours:
            # With unit step costs, BFS reaches each tile first along a
            # shortest path, so each tile is recorded and queued exactly once.
            if n not in dist:
                dist[n] = dist[pos] + 1
                queue.append(n)

    return dist

# Part 1

In [None]:
# Tests: shortest AA -> ZZ distance for the two example mazes
for example_file, expected_steps in (("day20-test1.input", 23), ("day20-test2.input", 58)):
    grid, portals, teleport, start, end = parse_input(example_file)
    assert find_dist()[end] == expected_steps

In [None]:
# Part 1 answer: steps from AA to ZZ in the real input
grid, portals, teleport, start, end = parse_input("day20.input")
part1_dist = find_dist()
part1_dist[end]

# Part 2

In [None]:
# All portals have an inner and an outer entry.
# For each portal, reorder its entries in `portals` to [inner, outer].
# Outer entries lie on the outer edge of the maze, i.e. the bounding box of
# its '#' and '.' tiles (the labels sit outside that box).
def get_teleports():
    """Classify portal tiles into inner and outer entries.

    Reads the globals ``grid`` and ``portals``. As a side effect, every
    ``portals[name]`` list except AA and ZZ is reordered in place to
    ``[inner, outer]``.

    Returns:
        ``(inner_entries, outer_entries)``: ``inner_entries`` maps each inner
        portal tile to its outer partner; ``outer_entries`` maps each outer
        portal tile to its inner partner.
    """
    # Derive the outer edge from the maze tiles themselves rather than from
    # "max key - 2": the latter depends on labels existing on the bottom /
    # right edge and on trailing whitespace being kept in the input.
    maze_tiles = [pos for pos, char in grid.items() if char in ('#', '.')]
    if not maze_tiles:
        return dict(), dict()
    rows = [row for row, _ in maze_tiles]
    cols = [col for _, col in maze_tiles]
    min_row, max_row = min(rows), max(rows)
    min_col, max_col = min(cols), max(cols)

    def is_outer(pos):
        row, col = pos
        return row in (min_row, max_row) or col in (min_col, max_col)

    inner_entries = dict()
    outer_entries = dict()
    for name, entries in portals.items():
        if name in ('AA', 'ZZ'):
            continue
        if is_outer(entries[0]):
            # The first entry is an outer one, switch places (in place)
            entries.reverse()
        inner, outer = entries
        inner_entries[inner] = outer
        outer_entries[outer] = inner

    return inner_entries, outer_entries

In [None]:
def find_dist_part2(start, end, max_recursion=25):
    """Shortest path through the recursive maze using Dijkstra's algorithm.

    Points are represented by (row, col, level); level 0 is the outermost
    maze. A tile in the global ``inner_entries`` leads one level deeper, a
    tile in the global ``outer_entries`` leads one level up, and outer
    entries act as walls on level 0. Every move, including a teleport,
    costs one step.

    Args:
        start: (row, col, level) to start from.
        end: (row, col, level) to reach.
        max_recursion: inner portals are only followed from levels
            ``<= max_recursion``; this bounds the search when ``end`` is
            unreachable.

    Returns:
        The number of steps from ``start`` to ``end`` (0 if they coincide).

    Raises:
        RecursionError: if ``end`` cannot be reached within the level limit.
    """
    dist = {start: 0}
    queue = [(0, start)]
    visited = set()

    while queue:
        dist_pos, pos = heapq.heappop(queue)
        if pos in visited:
            # Stale heap entry: pos was already expanded at a shorter distance
            continue
        visited.add(pos)

        # We have found the end, stop exploring
        if pos == end:
            return dist_pos

        # Find normal neighbours
        rc, level = pos[:2], pos[2]
        neighbours = [n for n in find_neighbours(pos) if grid.get(n[:2]) == '.']

        # If this is a teleport entry, add the other end as a neighbour
        if rc in inner_entries and level <= max_recursion:
            neighbours.append((*inner_entries[rc], level + 1))
        elif rc in outer_entries and level > 0:
            neighbours.append((*outer_entries[rc], level - 1))

        # Only push a neighbour when its distance actually improves, so the
        # heap doesn't fill up with duplicate entries
        for n in neighbours:
            new_dist = dist_pos + 1
            if new_dist < dist.get(n, float('inf')):
                dist[n] = new_dist
                heapq.heappush(queue, (new_dist, n))

    raise RecursionError("Not possible to find a way")

In [None]:
# Tests

# Example 1: the recursive shortest path takes 26 steps
grid, portals, teleport, start, end = parse_input("day20-test1.input")
inner_entries, outer_entries = get_teleports()
assert find_dist_part2((*start, 0), (*end, 0)) == 26

# Example 2 has no solution in the recursive maze, so the search must give up
grid, portals, teleport, start, end = parse_input("day20-test2.input")
inner_entries, outer_entries = get_teleports()
try:
    find_dist_part2((*start, 0), (*end, 0))
except RecursionError:
    pass
else:
    raise AssertionError("day20-test2.input should have no recursive solution")

# Example 3: the larger recursive example from the puzzle text
grid, portals, teleport, start, end = parse_input("day20-test3.input")
inner_entries, outer_entries = get_teleports()
assert find_dist_part2((*start, 0), (*end, 0)) == 396

In [None]:
# Part 2 answer: start and end both on the outermost level (0)
grid, portals, teleport, start, end = parse_input("day20.input")
inner_entries, outer_entries = get_teleports()
start_3d, end_3d = (*start, 0), (*end, 0)
find_dist_part2(start_3d, end_3d)