In [None]:
import heapq
import sys

import numpy as np
from IPython.display import display, clear_output
from matplotlib import pyplot as plt
from tqdm.auto import tqdm

# Headroom for deep recursion (e.g. a recursive flood fill that descends once
# per open grid cell); CPython's default recursion limit is 1000.
sys.setrecursionlimit(10_000)

In [None]:
def is_valid_index(array, index):
    """Check whether ``index`` lies inside the bounds of ``array``.

    Every component must be non-negative and strictly smaller than the
    matching entry of ``array.shape``.
    """
    non_negative = np.all(index >= 0)
    if not non_negative:
        # Short-circuit exactly like ``a and b``: hand back the falsy value.
        return non_negative
    return np.all(index < array.shape)

In [None]:
def add_to_region(grid, visited, idx, region):
    """Flood-fill the 4-connected region of truthy ``grid`` cells containing ``idx``.

    Every reached cell is marked ``True`` in ``visited`` and added to
    ``region`` as a ``(row, col)`` tuple of Python ints. Cells that are out of
    bounds, falsy in ``grid`` or already marked in ``visited`` are not entered,
    so nothing happens if ``idx`` itself is one of those.

    Args:
        grid: 2-D array; truthy cells are open.
        visited: boolean array with the same shape as ``grid``; updated in place.
        idx: starting ``(row, col)`` index (numpy array or any 2-sequence).
        region: set that receives the reached cells; updated in place.

    Uses an explicit stack instead of recursion, so the region size is not
    limited by Python's recursion limit.
    """
    n_rows, n_cols = grid.shape
    stack = [(int(idx[0]), int(idx[1]))]
    while stack:
        row, col = stack.pop()
        if not (0 <= row < n_rows and 0 <= col < n_cols):
            continue
        if visited[row, col] or not grid[row, col]:
            continue
        visited[row, col] = True
        region.add((row, col))
        stack.extend(
            ((row, col + 1), (row, col - 1), (row + 1, col), (row - 1, col))
        )


def is_path(byte_coords, grid_size, start, end):
    """Report whether ``end`` can be reached from ``start``.

    The grid is ``grid_size`` x ``grid_size``; each coordinate ``c`` in
    ``byte_coords`` blocks cell ``[c[1], c[0]]``. Movement is 4-connected
    through the remaining open cells, found via ``add_to_region``.
    """
    open_cells = np.ones((grid_size, grid_size), dtype=bool)
    for coord in byte_coords:
        open_cells[coord[1], coord[0]] = False

    seen = np.zeros_like(open_cells)
    reachable = set()
    add_to_region(open_cells, seen, start, reachable)

    return tuple(end) in reachable

In [None]:
# Integer code for each map character, so a character map can be rendered with imshow.
char_to_int = {ch: code for code, ch in enumerate(["#", ".", "E", "S", "*", "?"])}
char_to_img = np.vectorize(char_to_int.get)

In [None]:
grid_size = 71  # side length of the square grid
n_bytes = 1024  # how many leading byte coordinates are placed on the Part 1 map

# Raw puzzle input: one comma-separated integer pair per line.
with open("data/day18/input.txt", "r") as input_file:
    byte_coords_raw = input_file.read()

In [None]:
# Parse each "x,y" line into a tuple of ints. Blank lines (e.g. produced by the
# file's trailing newline) are skipped; with split("\n") they would reach
# int("") and raise ValueError. splitlines() also copes with "\r\n" endings.
byte_coords = [
    tuple(int(v) for v in line.split(","))
    for line in byte_coords_raw.splitlines()
    if line.strip()
]

## Part 1

In [None]:
# Search endpoints: the two opposite corners of the grid.
node_loc_start = np.zeros(2, dtype=int)
node_loc_end = np.full(2, grid_size - 1)

In [None]:
# Character map for Part 1: "#" for the first n_bytes byte coordinates,
# "." elsewhere, "S"/"E" at the endpoints. Coordinates c are written to [c[1], c[0]].
# NOTE(review): later cells use node_loc_* directly as [row, col] keys, not
# [c[1], c[0]]; the two only agree because both endpoints have equal components.
map_p1 = np.full((grid_size, grid_size), ".")
for coord in byte_coords[:n_bytes]:
    col, row = coord[0], coord[1]
    map_p1[row, col] = "#"
for loc, marker in ((node_loc_start, "S"), (node_loc_end, "E")):
    map_p1[loc[1], loc[0]] = marker

In [None]:
# (row, col) index of every cell that is not blocked; one row per cell.
node_locs = np.transpose(np.nonzero(map_p1 != "#"))

In [None]:
# Manhattan distance from each open cell to the end node. It is the A* heuristic
# ("d" below), and d / d_max (d_max = start-to-end distance) drives redraws.
d_max = np.abs(node_loc_end - node_loc_start).sum()
d_loc = {}
for loc in node_locs:
    d_loc[tuple(loc)] = np.abs(node_loc_end - loc).sum()

In [None]:
# Search state per open cell, keyed by (row, col) as Python ints:
#   "c": best known cost from the start (np.inf until reached)
#   "d": heuristic distance to the end, taken from d_loc
#   "n": set of cells on the best known path (None until reached)
#   "v": True once the node has been expanded
nodes = {}
for row, col in node_locs:
    nodes[(int(row), int(col))] = {
        "c": np.inf,
        "d": int(d_loc[(row, col)]),
        "n": None,
        "v": False,
    }

In [None]:
# Seed the search: the start costs nothing and its path is just itself.
start_state = nodes[tuple(node_loc_start)]
start_state["c"] = 0
start_state["n"] = {(int(node_loc_start[0]), int(node_loc_start[1]))}

In [None]:
def _draw_search_state(ax, base_map, nodes, path, iteration, progress):
    """Draw one frame of the search on ``ax`` and push it to the notebook output.

    A copy of ``base_map`` gets "?" on every expanded node (``"v"`` set) and
    "*" on every cell of ``path``. It is rendered through ``char_to_img``, with
    ``iteration`` and ``progress`` in the title.
    """
    frame = base_map.copy()
    for cell, state in nodes.items():
        if state["v"]:
            frame[cell] = "?"
    for cell in path:
        frame[cell] = "*"
    # Clear first; otherwise every call stacks another image onto the axes.
    ax.clear()
    ax.imshow(char_to_img(frame))
    ax.set_title(f"Iteration: {iteration} ({progress:.2f})")
    display(ax.figure)
    clear_output(wait=True)


fig, ax = plt.subplots()
n_iter = 0  # number of expanded nodes (not `iter`, which would shadow the builtin)
d_best = 1  # smallest normalised distance-to-end seen so far; throttles redraws

idx_start = (int(node_loc_start[0]), int(node_loc_start[1]))
idx_end = (int(node_loc_end[0]), int(node_loc_end[1]))

# A* frontier: min-heap of (cost so far + heuristic, node). Ties are broken by the
# smallest (row, col), matching the original stable sort in row-major node order.
# A node may be pushed several times; entries for expanded nodes are skipped.
frontier = [(nodes[idx_start]["c"] + nodes[idx_start]["d"], idx_start)]

while True:
    if not frontier:
        raise RuntimeError("The end node is not reachable from the start node.")
    _, idx_visit = heapq.heappop(frontier)
    node = nodes[idx_visit]
    if node["v"]:
        continue  # stale heap entry
    if idx_visit == idx_end:
        break

    for offset in [(0, 1), (1, 0), (0, -1), (-1, 0)]:
        idx_cons = (idx_visit[0] + offset[0], idx_visit[1] + offset[1])
        if idx_cons not in nodes:
            continue  # blocked or outside the grid
        neighbour = nodes[idx_cons]
        if node["c"] + 1 < neighbour["c"]:
            neighbour["c"] = node["c"] + 1
            neighbour["n"] = node["n"] | {idx_cons}
            heapq.heappush(frontier, (neighbour["c"] + neighbour["d"], idx_cons))

        # Redraw when a neighbour is closer to the end than the last drawn one
        # (at 1% granularity), or is the end itself.
        d_curr = d_loc[idx_cons] / d_max
        if d_curr < np.floor(d_best * 100) / 100 or d_curr == 0:
            d_best = d_curr
            _draw_search_state(ax, map_p1, nodes, neighbour["n"], n_iter, 1 - d_best)

    node["v"] = True
    n_iter += 1

_draw_search_state(ax, map_p1, nodes, nodes[idx_end]["n"], n_iter, 1 - d_best)

In [None]:
# Part 1 answer: cost of the shortest path found from start to end.
part1_steps = nodes[tuple(node_loc_end)]["c"]
part1_steps

## Part 2

In [None]:
# Binary search for the first byte whose arrival cuts the start off from the end.
# Invariant: the first n_lower bytes leave a path open; the first n_upper block it.
n_lower = 0
n_upper = len(byte_coords)

# The upper half of the invariant must actually hold. Otherwise the search would
# silently report the last byte (or raise IndexError on an empty list).
if is_path(byte_coords, grid_size, node_loc_start, node_loc_end):
    raise ValueError("A path still exists after all bytes; no byte blocks it.")

while n_upper - n_lower > 1:
    n_mid = (n_upper + n_lower) // 2
    if is_path(byte_coords[:n_mid], grid_size, node_loc_start, node_loc_end):
        n_lower = n_mid
    else:
        n_upper = n_mid

# Byte number n_upper (index n_upper - 1) is the first one that blocks the path;
# show it in the same "x,y" form as the input lines.
blocking_byte = byte_coords[n_upper - 1]
",".join(str(v) for v in blocking_byte)