---
# --- Day 17: Clumsy Crucible ---
---

In [1]:
import numpy as np
from typing import Tuple, List, Set

from heapq import heappop, heappush

In [2]:
def V(*x):
    """Pack the positional arguments into a NumPy array, e.g. ``V(1, 2) -> array([1, 2])``."""
    return np.array(x)

## Load data

In [3]:
full_puzzle_data: bool = True  # False -> load the small "_test" example input instead

In [4]:
file_suffix = "" if full_puzzle_data else "_test"
# Each non-blank line of the input is one row of single-digit heat-loss values.
# Blank or whitespace-padded lines (e.g. a trailing empty line) are skipped or
# stripped so they cannot produce a ragged array or an int(" ") error.
with open(f"data/day17_input{file_suffix}.txt", "r", encoding="utf-8") as f:
    city = np.array(
        [[int(digit) for digit in line.strip()] for line in f if line.strip()]
    )

## --- Part One ---

In [5]:
MAX_VALUE = 1e10
# Direction index -> (dx, dy) step on (row, column): 0 = left, 1 = down,
# 2 = right, 3 = up. Adding or subtracting 1 (mod ND) turns by 90 degrees.
DIRECTIONS = {0: (0, -1), 1: (1, 0), 2: (0, 1), 3: (-1, 0)}
ND = len(DIRECTIONS)


def find_shortest_distance(city: np.ndarray, max_straight: int = 3):
    """Return the minimal heat loss from the top-left to the bottom-right block.

    Dijkstra search over states ``(row, col, direction index, consecutive
    straight moves)``. From each state the crucible may continue straight
    (while it has made fewer than ``max_straight`` consecutive moves in that
    direction) or turn 90 degrees left/right; it never reverses. The heat loss
    of the start block is not counted.

    Args:
        city: 2-D array of per-block heat loss.
        max_straight: maximum number of consecutive blocks in one direction
            (3 for the standard crucible).

    Returns:
        The minimal total heat loss, or ``None`` if the target is unreachable.
    """
    m, n = city.shape
    # Heap entries: (cost, x, y, direction index, consecutive straight moves).
    # Seed both "down" and "right" with zero moves made so that either first
    # move out of the top-left corner is available (a lone seed facing left,
    # for example, could never move right first: that would be a reversal).
    queue = [(0, 0, 0, 1, 0), (0, 0, 0, 2, 0)]
    seen = set()
    costs = {}
    while queue:
        cost, x, y, ori, n_moves = heappop(queue)
        if x == m - 1 and y == n - 1:
            return cost
        state = (x, y, ori, n_moves)
        if state in seen:
            continue
        seen.add(state)
        neighbors = [(ori, n_moves + 1)] if n_moves < max_straight else []
        neighbors += [((ori + 1) % ND, 1), ((ori + ND - 1) % ND, 1)]
        for new_ori, new_moves in neighbors:
            dx, dy = DIRECTIONS[new_ori]
            nx, ny = x + dx, y + dy
            if 0 <= nx < m and 0 <= ny < n:
                new_cost = cost + city[nx, ny]
                v = (nx, ny, new_ori, new_moves)
                # Only enqueue strictly better routes to this state.
                if new_cost < costs.get(v, MAX_VALUE):
                    costs[v] = new_cost
                    heappush(queue, (new_cost, nx, ny, new_ori, new_moves))
    return None

In [6]:
# Part One answer: minimal heat loss for the standard crucible.
print(find_shortest_distance(city=city))

1039


#### [OLD] First attempt: pruned beam search (heuristic, not guaranteed optimal)

In [7]:
def key(a):
    """Beam-search priority of a status: heat loss so far plus 3x the distance left.

    ``a`` is a status tuple whose last two fields are (heat loss, distance).
    """
    return a[-2] + 3 * a[-1]


def prune_queue(q: List[Tuple[Tuple[int, int], Tuple[int, int], int, int, int]],
                n_to_keep: int) -> List[Tuple[Tuple[int, int], Tuple[int, int], int, int, int]]:
    """Return the ``n_to_keep`` statuses with the lowest ``key`` (ties keep input order).

    Each status is (position, direction, consecutive straight moves, heat loss,
    distance); positions and directions are plain ``(int, int)`` tuples.
    """
    return sorted(q, key=key)[:n_to_keep]

def add_possibilities(status: Tuple[Tuple[int, int], Tuple[int, int], int, int, int], city: np.ndarray,
                      m: int, n: int) -> List[Tuple[Tuple[int, int], Tuple[int, int], int, int, int]]:
    """Expand one beam-search status into its in-bounds successor statuses.

    A status is (position, direction, consecutive straight moves, heat loss,
    Manhattan distance to the bottom-right block). Candidate moves, in order:
    go straight (only while fewer than 3 straight moves were made), then turn
    90 degrees one way and the other. Moves leaving the ``m`` x ``n`` grid are dropped.

    Returns:
        A list of successor statuses (possibly empty).
    """
    (x, y), (dx, dy), nmoves, loss, _ = status
    moves_to_check = [((dx, dy), nmoves + 1)] if nmoves < 3 else []
    # Swapping the components (with or without negating both) rotates the
    # direction by 90 degrees, one way or the other.
    moves_to_check += [((dy, dx), 1), ((-dy, -dx), 1)]
    new_status = []
    for (ox, oy), move_count in moves_to_check:
        nx, ny = x + ox, y + oy
        if 0 <= nx < m and 0 <= ny < n:
            new_loss = loss + city[nx, ny]
            new_dist = abs(m - 1 - nx) + abs(n - 1 - ny)
            new_status.append(((nx, ny), (ox, oy), move_count, new_loss, new_dist))
    return new_status

def find_best_way(city: np.ndarray, status_size: int = 100):
    """Approximate the minimal heat loss with a pruned breadth-first (beam) search.

    Each round expands every kept status by one block (``add_possibilities``),
    drops duplicate statuses, and keeps only ``status_size`` of them
    (``prune_queue``). The search stops at the first round in which some kept
    status sits on the bottom-right block. Because of the pruning, the result
    is only an upper bound on the true optimum.

    Returns:
        The smallest heat loss among the statuses that reached the target, or
        ``None`` if the frontier runs empty first (e.g. no legal move remains).
    """
    m, n = city.shape
    target = (m - 1, n - 1)
    # position, direction, # consecutive straight moves, heat loss, distance
    status = [
        ((0, 0), (0, 1), 0, 0, m + n - 2),
        ((0, 0), (1, 0), 0, 0, m + n - 2),
    ]
    while status:
        # Check arrival before expanding so a 1x1 grid terminates immediately.
        arrived = [s[-2] for s in status if s[0] == target]
        if arrived:
            return min(arrived)
        expanded = []
        for s in status:
            expanded += add_possibilities(s, city, m, n)
        # Different paths can yield identical statuses; keep one copy of each.
        status = prune_queue(list(set(expanded)), status_size)
    # The frontier ran empty without reaching the target; stop instead of looping forever.
    return None

In [8]:
# find_best_way(city, status_size=500)

## --- Part Two ---

In [9]:
def find_ultra_crucible_way(city: np.ndarray):
    """Return the minimal heat loss for an ultra crucible (Part Two).

    Dijkstra search over states (row, col, direction index, blocks moved
    straight). The crucible must move at least 4 blocks in a direction before
    turning and at most 10. Every 4-block minimum is taken as a single jump,
    so every state after the start already satisfies it. The heat loss of the
    start block is not counted.

    Returns:
        The minimal total heat loss, or ``None`` if the bottom-right block is
        never reached.
    """
    rows, cols = city.shape
    target = (rows - 1, cols - 1)
    # Heap entries: (heat loss, row, col, direction index, blocks moved straight).
    frontier = [(0, 0, 0, 1, 0), (0, 0, 0, 2, 0)]
    visited = set()
    best = {}
    while frontier:
        loss, r, c, heading, run = heappop(frontier)
        if (r, c) == target:
            return loss
        state = (r, c, heading, run)
        if state in visited:
            continue
        visited.add(state)
        # Options are (new heading, new run length, blocks to advance in one jump).
        if run >= 4:
            options = []
            if run < 10:
                options.append((heading, run + 1, 1))
            options.append(((heading + 1) % ND, 4, 4))
            options.append(((heading + ND - 1) % ND, 4, 4))
        else:
            options = [(heading, 4, 4 - run)]
        for new_heading, new_run, steps in options:
            dr, dc = DIRECTIONS[new_heading]
            nr = r + dr * steps
            nc = c + dc * steps
            if not (0 <= nr < rows and 0 <= nc < cols):
                continue
            # Sum the blocks entered during the jump, excluding the block left behind.
            if dc == 0:
                entered = city[min(r + 1, nr):max(r, nr + 1), c]
            else:
                entered = city[r, min(c + 1, nc):max(c, nc + 1)]
            new_loss = loss + entered.sum()
            successor = (nr, nc, new_heading, new_run)
            if new_loss < best.get(successor, MAX_VALUE):
                best[successor] = new_loss
                heappush(frontier, (new_loss, nr, nc, new_heading, new_run))

In [10]:
# Part Two answer: minimal heat loss for the ultra crucible.
print(find_ultra_crucible_way(city=city))

1201
