In [1]:
import numpy as np
import networkx as nx

In [2]:
# Load the whole puzzle input as one string; parsing happens in the next cell.
with open("data/day20.txt", encoding="utf-8") as input_file:
    data = input_file.read()

In [3]:
# Parse the grid into a 2-D array of single characters. splitlines() plus the
# blank-line filter copes with a trailing newline or CRLF line endings, either
# of which would otherwise produce an empty/ragged last row that np.array rejects.
racetrack = np.array([list(line) for line in data.splitlines() if line])
rows, cols = racetrack.shape


def _locate(symbol):
    """Return the first ``(row, col)`` (row-major order) whose cell is ``symbol``.

    Raises
    ------
    ValueError
        If ``symbol`` does not occur anywhere in ``racetrack``.
    """
    hits = np.argwhere(racetrack == symbol)
    if hits.size == 0:
        raise ValueError(f"{symbol!r} not found in racetrack")
    # Plain ints, so the tuple compares/hashes like the (i, j) keys built from range().
    return tuple(int(v) for v in hits[0])


start_position = _locate("S")
end_position = _locate("E")
# Right, down, left, up as (row delta, col delta).
directions = [(0, 1), (1, 0), (0, -1), (-1, 0)]

def is_valid(i, j):
    """Return True when ``(i, j)`` lies inside the ``rows`` x ``cols`` grid."""
    inside_rows = 0 <= i < rows
    return inside_rows and 0 <= j < cols

# Problem 1

Count the cheats of at most 2 steps that save at least 100 steps over the normal race.

In [4]:
def shortest_path_length(position):
    """Return step distances from ``position`` to every reachable track cell.

    Builds an undirected graph whose nodes are the non-wall cells of
    ``racetrack`` (anything other than ``"#"``) and whose edges join
    orthogonally adjacent non-wall cells. It then asks networkx for the
    single-source shortest path lengths.

    Parameters
    ----------
    position : tuple[int, int]
        ``(row, col)`` of the source cell; expected to be a non-wall cell.

    Returns
    -------
    dict[tuple[int, int], int]
        Mapping from each cell reachable from ``position`` to its distance.
    """
    G = nx.Graph()
    for i in range(rows):
        for j in range(cols):
            if racetrack[i, j] == "#":
                continue
            # Register the cell even if it has no open neighbour, so an
            # isolated source yields {position: 0} instead of raising.
            G.add_node((i, j))
            for di, dj in directions:
                ni, nj = i + di, j + dj
                # Bounds check: without it a track cell on the grid edge would
                # wrap around via negative indexing or raise IndexError.
                if is_valid(ni, nj) and racetrack[ni, nj] != "#":
                    G.add_edge((i, j), (ni, nj))

    return nx.shortest_path_length(G, source=position)

# Distances along the track measured from each end of the race. A cheat's race
# time is then: distance to the cheat entry + cheat steps + distance from the
# cheat exit to the end.
paths_length_from_start, paths_length_from_end = (
    shortest_path_length(start_position),
    shortest_path_length(end_position),
)
# Length of the honest (no-cheat) race.
normal_length = paths_length_from_start[end_position]

# A cheat is identified by its (entry, exit) pair of track cells: the racer
# leaves the track at the entry, moves at most CHEAT_STEPS_PART1 steps ignoring
# walls, and rejoins the track at the exit. Counting such pairs (rather than wall
# cells with at most one hit each) also counts the several distinct cheats a
# wall bordering three or four track cells can host. This is the same pair-based
# count used for Problem 2.
CHEAT_STEPS_PART1 = 2
MIN_SAVING = 100

# Every non-zero (row, col) offset within Manhattan distance CHEAT_STEPS_PART1.
part1_offsets = [
    (di, dj)
    for di in range(-CHEAT_STEPS_PART1, CHEAT_STEPS_PART1 + 1)
    for dj in range(abs(di) - CHEAT_STEPS_PART1, CHEAT_STEPS_PART1 - abs(di) + 1)
    if (di, dj) != (0, 0)
]

total = 0
for (i, j), from_start in paths_length_from_start.items():
    for di, dj in part1_offsets:
        # The exit must be a track cell that can reach the end; .get() also
        # rejects walls and out-of-grid coordinates without a KeyError.
        to_end = paths_length_from_end.get((i + di, j + dj))
        if to_end is None:
            continue
        new_length = from_start + abs(di) + abs(dj) + to_end
        if normal_length - new_length >= MIN_SAVING:
            total += 1
total

1296

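# Problem 2

Count the cheats of at most 20 steps (Manhattan distance) that save at least 100 steps over the normal race.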

In [5]:
# A cheat may now last up to CHEAT_STEPS_PART2 steps. Instead of comparing every
# track cell with every grid cell (O((rows*cols)^2) Python iterations), only the
# offsets inside the Manhattan diamond of that radius are visited.
CHEAT_STEPS_PART2 = 20
MIN_SAVING = 100

# Every non-zero (row, col) offset within Manhattan distance CHEAT_STEPS_PART2.
part2_offsets = [
    (di, dj)
    for di in range(-CHEAT_STEPS_PART2, CHEAT_STEPS_PART2 + 1)
    for dj in range(abs(di) - CHEAT_STEPS_PART2, CHEAT_STEPS_PART2 - abs(di) + 1)
    if (di, dj) != (0, 0)
]

total = 0
for (i, j), from_start in paths_length_from_start.items():
    for di, dj in part2_offsets:
        # The exit must be a track cell that can reach the end; .get() also
        # rejects walls and out-of-grid coordinates without a KeyError.
        to_end = paths_length_from_end.get((i + di, j + dj))
        if to_end is None:
            continue
        # No separate "exit is further along than entry" check is needed. The
        # path start -> exit -> end is at least normal_length long, so an exit
        # no further from the start than the entry gives
        # new_length >= normal_length + cheat steps, which is never a positive saving.
        new_length = from_start + abs(di) + abs(dj) + to_end
        if normal_length - new_length >= MIN_SAVING:
            total += 1

total

977665