In [None]:
import sys
from itertools import combinations

import numpy as np
from matplotlib import pyplot as plt
from tqdm.auto import tqdm

# Raise the interpreter's maximum recursion depth to 10,000 frames so that any
# deeply recursive code run in this notebook does not stop with RecursionError.
sys.setrecursionlimit(10_000)

In [None]:
def is_valid_index(array, index):
    """Return True when ``index`` addresses an element inside ``array``.

    Parameters
    ----------
    array : numpy.ndarray
        Array whose ``shape`` supplies the exclusive upper bound for each axis.
    index : array-like of int
        One coordinate per axis. NumPy arrays, tuples and lists are accepted.

    Returns
    -------
    bool
        True if every coordinate ``c`` satisfies ``0 <= c < array.shape[axis]``.
        Negative (wrap-around) coordinates are treated as invalid.
    """
    # Coerce first: comparing a plain tuple/list against 0 raises TypeError.
    coords = np.asarray(index)
    return bool((coords >= 0).all() and (coords < array.shape).all())

In [None]:
def add_to_region(grid, visited, idx, region):
    """Flood-fill the 4-connected region of truthy cells that contains ``idx``.

    The walk is depth-first and visits cells in exactly the order a recursive
    implementation trying the offsets (0, +1), (0, -1), (+1, 0), (-1, 0) would,
    but it keeps its own stack instead of recursing, so the size of the region
    is not bounded by the interpreter's recursion limit.

    Parameters
    ----------
    grid : 2-D numpy.ndarray
        Cells with a truthy value belong to a region; falsy cells are skipped.
    visited : 2-D numpy.ndarray of bool
        Same shape as ``grid``. Updated in place: every cell added to
        ``region`` is set to True. Cells already True are not revisited.
    idx : sequence of two ints
        ``(row, col)`` of the seed cell (NumPy array, tuple or list).
    region : list
        Updated in place: receives ``(row, col)`` tuples of plain ints in
        visiting order.

    Returns
    -------
    None
    """
    n_rows, n_cols = grid.shape
    # Neighbour offsets, listed in the order they must be explored.
    offsets = ((0, 1), (0, -1), (1, 0), (-1, 0))

    pending = [(int(idx[0]), int(idx[1]))]
    while pending:
        row, col = pending.pop()
        # Out-of-bounds coordinates (including negatives) are never indexed.
        if not (0 <= row < n_rows and 0 <= col < n_cols):
            continue
        if visited[row, col] or not grid[row, col]:
            continue
        region.append((row, col))
        visited[row, col] = True
        # Push in reverse so the first offset ends up on top of the stack and
        # is explored first, matching the recursive depth-first order.
        for d_row, d_col in reversed(offsets):
            pending.append((row + d_row, col + d_col))


def get_path(grid, start):
    """Return the cells of the truthy region of ``grid`` that contains ``start``.

    A fresh all-False mask with the shape of ``grid`` and an empty list are
    handed to ``add_to_region``; the list it fills (``(row, col)`` tuples in
    visiting order) is returned.
    """
    cells = []
    seen = np.zeros(grid.shape, dtype=bool)
    add_to_region(grid, seen, start, cells)
    return cells

In [None]:
def manhatten_dist(a, b):
    """Return the Manhattan (L1) distance between the points ``a`` and ``b``.

    Only the first two components of each point are used.
    """
    delta_0 = b[0] - a[0]
    delta_1 = b[1] - a[1]
    return abs(delta_0) + abs(delta_1)

In [None]:
# Integer code for each map glyph, numbered in the order the glyphs are listed.
char_to_int = {glyph: code for code, glyph in enumerate("#.ES*?")}
# Element-wise glyph -> code lookup for arrays of single characters; a glyph
# missing from the table makes dict.get return None.
char_to_img = np.vectorize(char_to_int.get)

In [None]:
# Read the whole puzzle input into a single string.
with open("data/day20/input.txt") as fh:
    map_init_raw = fh.read()

# Small sample map; uncomment to use it instead of the input file.
# map_init_raw = """###############
# #...#...#.....#
# #.#.#.#.#.###.#
# #S#...#.#.#...#
# #######.#.#.###
# #######.#.#...#
# #######.#.###.#
# ###..E#...#...#
# ###.#######.###
# #...###...#...#
# #.#####.#.###.#
# #.#...#.#.#...#
# #.#.#.#.#.#.###
# #...#...#...###
# ###############"""

In [None]:
# One row of single characters per text line. splitlines() is used instead of
# split("\n") because a trailing newline in the input would otherwise add an
# empty final row, leaving ragged rows that do not form a 2-D character grid.
# Empty lines are skipped for the same reason.
map_init = np.array([list(row) for row in map_init_raw.splitlines() if row])

In [None]:
# Convert the glyph grid to integer codes and show it as a colour-mapped image.
map_img = char_to_img(map_init)
fig, ax = plt.subplots()
ax.imshow(map_img)

In [None]:
# (row, col) of the first "S" cell found in the map.
node_loc_start = np.argwhere(map_init == "S")[0]
# Cells reachable from the start through every cell that is not "#",
# listed in the order get_path visits them.
ref_path = get_path(map_init != "#", start=node_loc_start)
# Position of each cell within ref_path.
ref_path_dist = dict(zip(ref_path, range(len(ref_path))))

## Part 1

In [None]:
# Largest Manhattan distance allowed between the two track cells of a cheat
# pair in Part 1.
cheat_length = 2

In [None]:
# Every unordered pair of track cells whose Manhattan distance d satisfies
# 1 < d <= cheat_length, stored as (cell_a, cell_b, d). The distance is
# computed once per pair and reused for both the filter and the stored value.
node_cheat_pairs = [
    (cell_a, cell_b, gap)
    for cell_a, cell_b in combinations(ref_path, r=2)
    if 1 < (gap := manhatten_dist(cell_a, cell_b)) <= cheat_length
]

In [None]:
# Saving of each candidate cheat, keyed by its cell pair: the difference of the
# two cells' positions along ref_path minus the Manhattan length of the shortcut.
cheat_saving = {
    (cell_a, cell_b): abs(ref_path_dist[cell_a] - ref_path_dist[cell_b]) - gap
    for cell_a, cell_b, gap in node_cheat_pairs
}

In [None]:
# Number of cheats whose saving is at least 100.
sum(1 for saved in cheat_saving.values() if saved >= 100)

## Part 2

In [None]:
# Largest Manhattan distance allowed between the two track cells of a cheat
# pair in Part 2.
cheat_length = 20

In [None]:
# Every unordered pair of track cells whose Manhattan distance d satisfies
# 1 < d <= cheat_length, stored as (cell_a, cell_b, d). The distance is
# computed once per pair and reused for both the filter and the stored value.
node_cheat_pairs = [
    (cell_a, cell_b, gap)
    for cell_a, cell_b in combinations(ref_path, r=2)
    if 1 < (gap := manhatten_dist(cell_a, cell_b)) <= cheat_length
]

In [None]:
# Saving of each candidate cheat, keyed by its cell pair: the difference of the
# two cells' positions along ref_path minus the Manhattan length of the shortcut.
cheat_saving = {
    (cell_a, cell_b): abs(ref_path_dist[cell_a] - ref_path_dist[cell_b]) - gap
    for cell_a, cell_b, gap in node_cheat_pairs
}

In [None]:
# Number of cheats whose saving is at least 100.
sum(1 for saved in cheat_saving.values() if saved >= 100)