In [1]:
import numpy as np
import networkx as nx
from itertools import product

In [2]:
def height_map(height_char) -> int:
    """Translate one grid character into its elevation, ``a`` = 1 ... ``z`` = 26.

    The start marker ``S`` is treated as elevation ``a`` and the goal marker
    ``E`` as elevation ``z``; any other character is mapped by its code point
    relative to ``a``.
    """
    letter = {"S": "a", "E": "z"}.get(height_char, height_char)
    return ord(letter) - ord("a") + 1

def is_starting_square(char, use_second_task: bool):
    """Tell whether ``char`` marks a square the climb may begin from.

    ``S`` always qualifies; in the second task every ``a`` square does too.
    """
    if char == "S":
        return True
    # Short-circuits to ``use_second_task`` itself when it is falsy.
    return use_second_task and char == "a"

def parse_input(input_stream, use_second_task: bool = False):
    """Parse the puzzle grid into a directed graph of allowed single steps.

    Every character of every right-stripped line becomes a node keyed by
    ``(row_idx, col_idx)`` with the attributes

    * ``height`` -- elevation from ``height_map`` (``S`` -> ``a``, ``E`` -> ``z``),
    * ``start``  -- result of ``is_starting_square(char, use_second_task)``,
    * ``end``    -- ``char == "E"``.

    For each pair of horizontally or vertically adjacent cells, an edge
    ``src -> dst`` is added when ``height(dst) <= height(src) + 1``: a step may
    climb at most one level but descend any amount.

    Args:
        input_stream: Iterable of text lines, e.g. an open text file.
        use_second_task: If True, every ``a`` square is a start node as well
            as ``S``.

    Returns:
        ``(graph, start_nodes, end_nodes)`` where both node lists are in
        row-major (input) order. Empty input yields an empty graph and empty
        lists.
    """
    graph = nx.DiGraph()
    for row_idx, row in enumerate(input_stream):
        row = row.rstrip()
        graph.add_nodes_from(
            (
                (row_idx, col_idx),
                {
                    "height": height_map(char),
                    "start": is_starting_square(char, use_second_task),
                    "end": char == "E",
                },
            )
            for col_idx, char in enumerate(row)
        )

    # Visit each adjacent pair exactly once by looking only right and down from
    # every cell. Testing neighbour membership (instead of deriving the grid
    # size from the largest index) keeps empty input and ragged/blank lines
    # from raising ValueError/KeyError.
    for row_idx, col_idx in list(graph.nodes):
        node = (row_idx, col_idx)
        for neighbor in ((row_idx, col_idx + 1), (row_idx + 1, col_idx)):
            if neighbor not in graph:
                continue
            # Check both directions: the climb rule is asymmetric.
            for src, dst in ((node, neighbor), (neighbor, node)):
                if graph.nodes[dst]["height"] <= graph.nodes[src]["height"] + 1:
                    graph.add_edge(src, dst)

    start_nodes = [node for node, is_start in graph.nodes(data="start") if is_start]
    end_nodes = [node for node, is_end in graph.nodes(data="end") if is_end]
    return graph, start_nodes, end_nodes

In [3]:
def calc_shortest_path_length(input_file_name, use_second_task: bool = False):
    """Return the fewest steps from any start square to any end square.

    Args:
        input_file_name: Path of the puzzle input file.
        use_second_task: If True, every ``a`` square counts as a start square
            (see ``parse_input``).

    Returns:
        The minimum shortest-path length (number of edges) over all
        start/end pairs that are connected.

    Raises:
        ValueError: If no end square is reachable from any start square.
    """
    # Close the file as soon as parsing is done.
    with open(input_file_name, "r") as input_stream:
        graph, start_nodes, end_nodes = parse_input(input_stream, use_second_task=use_second_task)

    # One BFS per end node over the reversed graph yields the distance from
    # every node to that end, instead of one BFS per (start, end) pair --
    # important for the second task, where there can be many start squares.
    reversed_graph = graph.reverse(copy=False)
    start_set = set(start_nodes)
    path_lengths = [
        length
        for end in end_nodes
        for node, length in nx.single_source_shortest_path_length(reversed_graph, end).items()
        if node in start_set
    ]
    if not path_lengths:
        raise ValueError(f"no path from any start square to any end square in {input_file_name!r}")
    return min(path_lengths)

In [4]:
# First task (only `S` is a start square) on test-input.txt; recorded output: 31
calc_shortest_path_length("test-input.txt")

31

In [5]:
# First task (only `S` is a start square) on input.txt; recorded output: 534
calc_shortest_path_length("input.txt")

534

In [6]:
# Second task (every `a` square is also a start) on test-input.txt; recorded output: 29
calc_shortest_path_length("test-input.txt", use_second_task=True)

29

In [7]:
# Second task (every `a` square is also a start) on input.txt; recorded output: 525
calc_shortest_path_length("input.txt", use_second_task=True)

525