In [158]:
from queue import PriorityQueue as PQ
from copy import deepcopy


import networkx as nx

In [66]:
INPUT_TEST_ = """###############
#.......#....E#
#.#.###.#.###.#
#.....#.#...#.#
#.###.#####.#.#
#.#.#.......#.#
#.#.#####.###.#
#...........#.#
###.#.#####.#.#
#...#.....#.#.#
#.#.#.###.#.#.#
#.....#...#.#.#
#.###.#.#.#.#.#
#S..#.....#...#
###############"""
# Example maze 1 as a grid of single characters (row-major).
MAP_TEST = list(map(list, INPUT_TEST_.splitlines()))

INPUT_TEST_1_ = """#################
#...#...#...#..E#
#.#.#.#.#.#.#.#.#
#.#.#.#...#...#.#
#.#.#.#.###.#.#.#
#...#.#.#.....#.#
#.#.#.#.#.#####.#
#.#...#.#.#.....#
#.#.#####.#.###.#
#.#.#.......#...#
#.#.###.#####.###
#.#.#...#.....#.#
#.#.#.#####.###.#
#.#.#.........#.#
#.#.#.#########.#
#S#.............#
#################"""
# Example maze 2 as a grid of single characters (row-major).
MAP_TEST_1 = list(map(list, INPUT_TEST_1_.splitlines()))

# Load the real puzzle input. splitlines() (rather than split('\n')) ignores a
# trailing newline, so a file ending in '\n' does not add an empty last row;
# an empty row breaks helpers that index map_[i][j] over range(len(map_[0])).
with open('d16_in.txt', 'r') as f:
    INPUT_ = f.read()
MAP = [list(s) for s in INPUT_.splitlines()]

In [162]:
BLACK = '\033[30m'
RED = '\033[31m'
GREEN = '\033[32m'
YELLOW = '\033[33m' # orange on some systems
BLUE = '\033[34m'
MAGENTA = '\033[35m'
CYAN = '\033[36m'
LIGHT_GRAY = '\033[37m'
DARK_GRAY = '\033[90m'
BRIGHT_RED = '\033[91m'
BRIGHT_GREEN = '\033[92m'
BRIGHT_YELLOW = '\033[93m'
BRIGHT_BLUE = '\033[94m'
BRIGHT_MAGENTA = '\033[95m'
BRIGHT_CYAN = '\033[96m'
WHITE = '\033[97m'
RESET = '\033[0m' # return to standard terminal text color

def color_map(map_):
    """Return a colorized copy of a character grid for terminal display.

    Every cell becomes an ANSI color escape followed by the original character:
    '#' -> DARK_GRAY, '.' -> BRIGHT_BLUE, any of '<>^vO' -> BRIGHT_YELLOW.
    Any other character (e.g. 'S', 'E') is prefixed with RESET so it is shown
    in the default color instead of inheriting the previous cell's color.
    The last cell of each row additionally ends with RESET so the color does
    not leak past the row when it is printed.

    Rows may have different lengths (including empty rows). The input grid is
    not modified; a new list of new lists is returned.
    """
    palette = {'#': DARK_GRAY, '.': BRIGHT_BLUE}
    palette.update(dict.fromkeys('<>^vO', BRIGHT_YELLOW))

    colored = []
    for row in map_:
        new_row = [palette.get(c, RESET) + c for c in row]
        if new_row:
            # Terminate the row's color so it doesn't bleed into later output.
            new_row[-1] += RESET
        colored.append(new_row)
    return colored

In [133]:
def get_start(map_):
    """Return (row, col) of the first 'S' in row-major order, or (-1, -1).

    Each row is scanned over its own length, so ragged or empty rows are
    handled instead of raising IndexError or being only partly scanned.
    """
    for i, row in enumerate(map_):
        for j, c in enumerate(row):
            if c == 'S':
                return (i, j)
    return (-1, -1)

def get_end(map_):
    """Return (row, col) of the first 'E' in row-major order, or (-1, -1).

    Each row is scanned over its own length, so ragged or empty rows are
    handled instead of raising IndexError or being only partly scanned.
    """
    for i, row in enumerate(map_):
        for j, c in enumerate(row):
            if c == 'E':
                return (i, j)
    return (-1, -1)

def get_nbh_cost(nb_t, map_):
    """List the states reachable in one action from state nb_t, with costs.

    nb_t is ((row, col), heading) with heading one of '^', 'v', '>', '<'.
    Returns a list of ((row, col), heading, cost) tuples, in this order:
      1. step forward one cell in the current heading (cost 1);
      2. rotate in place to each of the two perpendicular headings (cost 1000).
    Candidates whose target cell is '#' in map_ are dropped.

    Raises ValueError for an unknown heading (previously this surfaced as an
    UnboundLocalError from the unmatched `match`).
    """
    (i, j), d = nb_t
    # heading -> (row step, col step, the two perpendicular headings in order)
    moves = {
        '^': (-1, 0, '><'),
        'v': (1, 0, '><'),
        '>': (0, 1, '^v'),
        '<': (0, -1, '^v'),
    }
    try:
        di, dj, turns = moves[d]
    except KeyError:
        raise ValueError(f"unknown heading {d!r}; expected one of '^v><'") from None

    candidates = [((i + di, j + dj), d, 1)]
    candidates += [((i, j), t, 1000) for t in turns]
    return [
        nb for nb in candidates
        if map_[nb[0][0]][nb[0][1]] != '#'
    ]

Part 1

In [98]:
# Sentinel "unreached" distance for Dijkstra: the largest signed 32-bit integer.
BIG_NUM = 2**31 - 1

In [94]:
def dij(map_):
    """Dijkstra over (cell, heading) states, starting at 'S' facing '>'.

    Neighbour states and their step costs come from get_nbh_cost. The search
    stops once every end state ((row, col) of 'E', any of the four headings)
    has been settled, or when the queue is exhausted.

    Returns:
        mins: list of (distance, end_state) for each end state settled, in the
              order they were settled (non-decreasing distance); at most four
              entries, each end state appearing once.
        pred: dict state -> predecessor on one shortest path (None if unset).
        dist: dict state -> best known distance (BIG_NUM if never reached).
              Keys cover every interior cell of map_ with each heading.
    """
    dist = {((i, j), d): BIG_NUM
        for i in range(1, len(map_)-1)
        for j in range(1, len(map_[0])-1)
        for d in "^v><"}
    pred = {((i, j), d): None
        for i in range(1, len(map_)-1)
        for j in range(1, len(map_[0])-1)
        for d in "^v><"}

    Q = PQ()
    s_ = get_start(map_)
    s = (s_, '>')
    dist[s] = 0
    # Lazy-deletion Dijkstra: only the start is queued up front; states are
    # (re)queued when their distance improves and stale entries are skipped.
    Q.put((0, s))

    e_ = get_end(map_)
    ends = {(e_, d) for d in '^v><'}

    mins = []
    while not Q.empty():
        dist_u, u = Q.get()
        if dist_u > dist[u]:
            # Stale entry: u was already settled with a smaller distance.
            continue
        if u in ends:
            mins.append((dist_u, u))
            if len(mins) == len(ends):
                break
        for v_cell, v_dir, cost in get_nbh_cost(u, map_):
            v = (v_cell, v_dir)
            alt = dist_u + cost
            if alt < dist[v]:
                pred[v] = u
                dist[v] = alt
                Q.put((alt, v))
    return mins, pred, dist

In [185]:
# Solve on the real input (swap in MAP_TEST / MAP_TEST_1 to check the examples).
map_ = MAP

mins, pred, dist = dij(map_)
# Lowest score over all end orientations.
min(cost for cost, _state in mins)

82460

Part 2

In [None]:
# Build the (cell, heading) state graph with the same moves and costs that
# get_nbh_cost yields for Part 1.
G = nx.DiGraph()
for i in range(1, len(map_)-1):
    for j in range(1, len(map_[0])-1):
        if map_[i][j] == '#':
            continue
        for d in "><^v":
            u = ((i, j), d)
            for v_cell, v_dir, cost in get_nbh_cost(u, map_):
                G.add_edge(u, (v_cell, v_dir), weight=cost)

st_ = get_start(map_)
start = (st_, '>')

# A state lies on some optimal route iff
#     dist(start -> state) + dist(state -> nearest optimal end) == best.
# Two Dijkstra sweeps (forward from the start, backward from every end
# orientation that achieves the best score) find all such states without
# enumerating individual shortest paths, whose number can grow combinatorially.
# Considering every tying end orientation (not just one) avoids undercounting.
fwd = nx.single_source_dijkstra_path_length(G, start, weight='weight')
end_states = [(get_end(map_), d) for d in '^v><']
reached_ends = [(fwd[e], e) for e in end_states if e in fwd]
best, end = min(reached_ends)
best_ends = [e for c, e in reached_ends if c == best]
bwd = nx.multi_source_dijkstra_path_length(G.reverse(copy=False), best_ends, weight='weight')

# Collect tiles (cells, ignoring heading) of every state on an optimal route.
min_set = {
    state[0]
    for state, c in fwd.items()
    if c + bwd.get(state, float('inf')) == best
}

len(min_set)

590
