In [1]:
import collections
from string import ascii_uppercase

In [2]:
def find_neighbours(pos):
    """List the four orthogonally adjacent coordinates of ``pos``.

    Only ``pos[0]`` and ``pos[1]`` are used.  The neighbours come back in a
    fixed order, applying the offsets (0, +1), (0, -1), (+1, 0), (-1, 0).
    """
    first, second = pos[0], pos[1]
    offsets = ((0, 1), (0, -1), (1, 0), (-1, 0))
    return [(first + d_first, second + d_second) for d_first, d_second in offsets]

In [3]:
def parse_input(filename):
    """Parse a donut-maze file into a character grid plus portal information.

    A portal label is two upper-case letters in a line (horizontal or
    vertical).  The letter touching an open ``'.'`` tile is the inner one, and
    that open tile is the portal's entry point.  Labels are read in reading
    order (top-to-bottom / left-to-right), so e.g. "XY" and "YX" are
    different portals.

    Parameters
    ----------
    filename : str or path-like
        Text file containing the maze.

    Returns
    -------
    grid : dict[(int, int), str]
        Every character of the file keyed by ``(row, col)``.
    portals : collections.defaultdict[str, list[(int, int)]]
        Label -> entry tiles for that label, in row-major discovery order.
    teleport : dict[(int, int), (int, int)]
        Entry tile -> entry tile of the same-named partner, in both
        directions (``AA`` and ``ZZ`` are not included).
    start, end : (int, int)
        Entry tiles of ``AA`` and ``ZZ``.

    Raises
    ------
    ValueError
        If ``AA`` or ``ZZ`` is missing, or another label does not have
        exactly two entry tiles.
    """
    grid = dict()
    with open(filename) as file:
        for row, line in enumerate(file):
            for col, char in enumerate(line.rstrip('\n')):
                grid[row, col] = char

    portals = collections.defaultdict(list)
    # Iterating the grid itself (row-major insertion order) covers every cell,
    # even when lines have different lengths.
    for (row, col), symbol in grid.items():
        if symbol not in ascii_uppercase:
            continue
        for d_row, d_col in ((0, 1), (0, -1), (1, 0), (-1, 0)):
            entry = (row + d_row, col + d_col)
            if grid.get(entry) != '.':
                continue
            # This letter touches an open tile, so the label's other letter is
            # on the opposite side, pointing away from the maze.
            other = (row - d_row, col - d_col)
            other_symbol = grid.get(other, ' ')
            if other_symbol not in ascii_uppercase:
                continue
            # Order the two letters by position = reading order of the label.
            (_, first), (_, second) = sorted([((row, col), symbol),
                                              (other, other_symbol)])
            portals[first + second].append(entry)

    if not portals.get('AA') or not portals.get('ZZ'):
        raise ValueError(f"{filename}: maze needs both an AA and a ZZ label")
    start = portals['AA'][0]
    end = portals['ZZ'][0]

    teleport = dict()
    for name, points in portals.items():
        if name in ('AA', 'ZZ'):
            continue
        if len(points) != 2:
            raise ValueError(
                f"{filename}: portal {name!r} has {len(points)} entry tiles, expected 2")
        a, b = points
        teleport[a] = b
        teleport[b] = a

    return grid, portals, teleport, start, end

In [4]:
def draw_grid(grid, fill=' '):
    """Print a coordinate-keyed character grid as text.

    Parameters
    ----------
    grid : dict[tuple, str]
        Characters keyed by coordinates whose first element selects the
        printed line and whose second selects the column within that line
        (``parse_input`` uses ``(row, col)``).
    fill : str, default ' '
        Printed for positions inside the bounding box that have no entry,
        e.g. when the input file had lines of different lengths.

    An empty grid prints nothing.
    """
    if not grid:
        return
    firsts = [k[0] for k in grid]
    seconds = [k[1] for k in grid]
    columns = range(min(seconds), max(seconds) + 1)

    for x in range(min(firsts), max(firsts) + 1):
        # .get with a default: a bare .get would yield None and break join.
        print(''.join(grid.get((x, y), fill) for y in columns))

In [5]:
def find_dist(maze=None, links=None, origin=None):
    """Breadth-first step counts from a start tile over open tiles and portals.

    Called with no arguments, this reads the notebook globals ``grid``,
    ``teleport`` and ``start``, i.e. whatever ``parse_input`` loaded last.
    Pass them explicitly to avoid that hidden dependency.

    Parameters
    ----------
    maze : dict[(int, int), str], optional
        Character grid.  Only ``'.'`` tiles are walkable.  Defaults to ``grid``.
    links : dict[(int, int), (int, int)], optional
        Portal entry tile -> partner tile.  A jump costs one step.
        Defaults to ``teleport``.
    origin : (int, int), optional
        Tile to measure from.  Defaults to ``start``.

    Returns
    -------
    dict[(int, int), int]
        Fewest steps from ``origin`` to every reachable tile.
    """
    if maze is None:
        maze = grid
    if links is None:
        links = teleport
    if origin is None:
        origin = start

    dist = {origin: 0}
    queue = collections.deque([origin])
    while queue:
        pos = queue.popleft()

        # Neighbours: adjacent open tiles, plus the partner tile if pos is a portal.
        neighbours = [n for n in find_neighbours(pos) if maze.get(n) == '.']
        if pos in links:
            neighbours.append(links[pos])

        for n in neighbours:
            # All edges cost 1 and the queue is FIFO, so the first time a tile
            # is reached is along a shortest path.  Each tile is recorded and
            # queued exactly once.
            if n not in dist:
                dist[n] = dist[pos] + 1
                queue.append(n)

    return dist

# Part 1

In [6]:
# Tests
# Check part 1 against the small test inputs.
# find_dist() with no arguments reads the globals grid/teleport/start, so each
# parse_input call below must run before the find_dist() that relies on it.
grid, portals, teleport, start, end = parse_input("day20-test1.input")
assert find_dist()[end] == 23

grid, portals, teleport, start, end = parse_input("day20-test2.input")
assert find_dist()[end] == 58

In [7]:
# Part 1 answer: fewest steps from the AA entry tile to the ZZ entry tile on the
# real input (find_dist() reads the globals rebound here).
grid, portals, teleport, start, end = parse_input("day20.input")
find_dist()[end]

600

# Part 2

In [8]:
def part2(grid, portal_map=None, origin=None, target=None, max_levels=30):
    """Fewest steps from ``origin`` to ``target`` in the recursive donut maze.

    Each linked portal pair joins an inner-edge tile and an outer-edge tile.
    Stepping through an inner portal moves one level deeper (level + 1).
    Stepping through an outer portal moves one level back out (level - 1).
    Levels are limited to ``0 .. max_levels - 1``, so outer portals do nothing
    on level 0.  The search starts at ``origin`` on level 0 and only succeeds
    on reaching ``target`` on level 0.  Every step, including a portal jump,
    costs 1.

    Parameters
    ----------
    grid : dict[(int, int), str]
        Character grid keyed by ``(row, col)``.  Only ``'.'`` tiles are walkable.
    portal_map : dict[str, list[(int, int)]], optional
        Label -> entry tiles.  Labels with exactly two tiles are linked.
        Others (``AA``/``ZZ``) are ignored.  Defaults to the global ``portals``.
    origin, target : (int, int), optional
        Start and end tiles.  Default to the globals ``start`` and ``end``.
    max_levels : int, default 30
        Number of recursion levels searched.  Raise it for mazes that need
        deeper recursion.

    Returns
    -------
    int
        Length of the shortest path.

    Raises
    ------
    ValueError
        If ``target`` cannot be reached within ``max_levels`` levels.
    """
    if portal_map is None:
        portal_map = portals
    if origin is None:
        origin = start
    if target is None:
        target = end

    max_row = max(pos[0] for pos in grid)
    max_col = max(pos[1] for pos in grid)

    def is_outer(pos):
        # Outer-edge entry tiles sit two cells in from the file border; the two
        # label letters fill the outermost cells.  (Assumes the maze is framed
        # exactly by its labels, as in the inputs used here.)
        return pos[0] in (2, max_row - 2) or pos[1] in (2, max_col - 2)

    # Portal tile -> (partner tile, level change when stepping through it).
    jumps = {}
    for pads in portal_map.values():
        if len(pads) != 2:
            continue
        outer, inner = pads if is_outer(pads[0]) else pads[::-1]
        jumps[inner] = (outer, 1)
        jumps[outer] = (inner, -1)

    # Plain BFS over (row, col, level) states; states are generated lazily
    # rather than materialising every level up front.
    source = (*origin, 0)
    goal = (*target, 0)
    seen = {source}
    queue = collections.deque([(source, 0)])
    while queue:
        state, steps = queue.popleft()
        if state == goal:
            return steps
        row, col, level = state

        moves = [(*n, level) for n in find_neighbours((row, col)) if grid.get(n) == '.']
        if (row, col) in jumps:
            partner, delta = jumps[row, col]
            moves.append((*partner, level + delta))

        for move in moves:
            if 0 <= move[2] < max_levels and move not in seen:
                seen.add(move)
                queue.append((move, steps + 1))

    raise ValueError(f"no path from {origin} to {target} within {max_levels} levels")

In [9]:
# Tests
# Check part 2 (recursive levels) against the test inputs.
# part2(grid) also reads the globals portals, start and end, so each
# parse_input call below must run before the part2 call that relies on it.
grid, portals, teleport, start, end = parse_input("day20-test1.input")
assert part2(grid) == 26

grid, portals, teleport, start, end = parse_input("day20-test3.input")
assert part2(grid) == 396

In [10]:
# Part 2 answer on the real input (part2(grid) also reads the globals
# portals/start/end rebound here).
grid, portals, teleport, start, end = parse_input("day20.input")
part2(grid)

6666