## Part 1

In [1]:
from enum import Enum
from queue import PriorityQueue

In [77]:
TEST_INFILE_1 = "inputs/day_17_test_1.txt"
TEST_INFILE_2 = "inputs/day_17_test_2.txt"
INFILE = "inputs/day_17_1.txt"

#with open(TEST_INFILE_1) as infile:
#with open(TEST_INFILE_2) as infile:
with open(INFILE) as infile:
    lines = infile.read().splitlines()

grid = [[int(el) for el in line] for line in lines]

In [78]:
grid

[[1,
  4,
  3,
  2,
  2,
  1,
  3,
  5,
  2,
  3,
  3,
  4,
  4,
  3,
  2,
  4,
  2,
  1,
  5,
  1,
  3,
  1,
  1,
  5,
  1,
  1,
  5,
  5,
  3,
  1,
  5,
  2,
  6,
  5,
  3,
  3,
  3,
  6,
  2,
  6,
  3,
  2,
  5,
  1,
  3,
  3,
  2,
  4,
  3,
  6,
  4,
  5,
  5,
  4,
  1,
  2,
  7,
  6,
  3,
  1,
  2,
  7,
  1,
  4,
  5,
  2,
  5,
  7,
  6,
  2,
  7,
  2,
  3,
  5,
  2,
  3,
  1,
  3,
  7,
  3,
  4,
  6,
  4,
  2,
  7,
  6,
  2,
  2,
  1,
  7,
  4,
  3,
  7,
  1,
  4,
  3,
  5,
  4,
  1,
  4,
  6,
  3,
  6,
  3,
  5,
  4,
  5,
  2,
  2,
  5,
  4,
  3,
  2,
  2,
  4,
  1,
  5,
  5,
  5,
  2,
  3,
  4,
  3,
  1,
  3,
  2,
  3,
  3,
  2,
  5,
  4,
  1,
  2,
  2,
  5,
  2,
  4,
  4,
  3,
  3,
  4],
 [2,
  3,
  2,
  2,
  2,
  4,
  1,
  2,
  2,
  5,
  1,
  5,
  2,
  1,
  3,
  2,
  3,
  4,
  3,
  4,
  3,
  2,
  1,
  3,
  2,
  3,
  2,
  2,
  2,
  5,
  2,
  6,
  6,
  5,
  3,
  5,
  2,
  2,
  3,
  6,
  4,
  2,
  4,
  5,
  4,
  4,
  3,
  2,
  1,
  5,
  7,
  7,
  6,
  5,
  2,
  2,
  3,
  3,
  6,

In [79]:
class Point:
    def __init__(self, row, col):
        self.row = row
        self.col = col

    def __repr__(self):
        return f"({self.row}, {self.col})"

    def __add__(self, other):
        return Point(self.row + other.row, self.col + other.col)

    def __radd__(self, other):
        return self + other
    
    def __eq__(self, other):
        return self.row == other.row and self.col == other.col
    
    def __hash__(self):
        return hash((self.row, self.col))
    

class Direction(Enum):
    UP = Point(-1, 0)
    DOWN = Point(1, 0)
    LEFT = Point(0, -1)
    RIGHT = Point(0, 1)


assert Point(0, 0) == Point(0, 0)
assert Point(1, 0) + Point(1, 10) == Point(2, 10)
p = Point(1, 0)
p += Point(1, 10)
assert p == Point(2, 10)


class Vector:
    def __init__(self, point, direction):
        self.point = point
        self.direction = direction

    def __repr__(self):
        return f"({self.point.row}, {self.point.col}) => {self.direction}"
    
    def __eq__(self, other):
        return self.point == other.point and self.direction == other.direction
    
    def __lt__(self, other):
        return self.point.row < other.point.row or (self.point.row == other.point.row and self.point.col < other.point.col)
    
    def __hash__(self):
        return hash((self.point.row, self.point.col, self.direction))
    

assert Vector(Point(0, 0), Direction.UP) == Vector(Point(0, 0), Direction.UP)

In [80]:
def get_neighbors(state, grid=grid):
    # state is a (vector, current_run) pair
    current_point = Point(state[0].point.row, state[0].point.col)  
    neighbors = []
    
    # go left
    match state[0].direction:
        case Direction.UP:
            left = Direction.LEFT
        case Direction.DOWN:
            left = Direction.RIGHT
        case Direction.LEFT:
            left = Direction.DOWN
        case Direction.RIGHT:
            left = Direction.UP
    neighbor_point = current_point + left.value
    if neighbor_point.row >= 0 and neighbor_point.row < len(grid):
        if neighbor_point.col >= 0 and neighbor_point.col < len(grid[0]):
            neighbors.append((Vector(neighbor_point, left), 1))

    # go right
    match state[0].direction:
        case Direction.UP:
            right = Direction.RIGHT
        case Direction.DOWN:
            right = Direction.LEFT
        case Direction.LEFT:
            right = Direction.UP
        case Direction.RIGHT:
            right = Direction.DOWN
    neighbor_point = current_point + right.value
    if neighbor_point.row >= 0 and neighbor_point.row < len(grid):
        if neighbor_point.col >= 0 and neighbor_point.col < len(grid[0]):
            neighbors.append((Vector(neighbor_point, right), 1))

    # if we can keep going
    if state[1] < 3:
        neighbor_point = current_point + state[0].direction.value
        if neighbor_point.row >= 0 and neighbor_point.row < len(grid):
            if neighbor_point.col >= 0 and neighbor_point.col < len(grid[0]):
                neighbors.append((Vector(neighbor_point, state[0].direction), state[1] + 1))

    return neighbors

In [81]:
# test_state = Vector(Point(0, 0), Direction.RIGHT)
# assert get_neighbors((test_state, 0)) == [
#     (Vector(Point(1, 0), Direction.DOWN), 0),
#     (Vector(Point(0, 1), Direction.RIGHT), 1)
# ]
# assert get_neighbors((test_state, 2)) == [
#     (Vector(Point(1, 0), Direction.DOWN), 0),
#     (Vector(Point(0, 1), Direction.RIGHT), 3)
# ]

# assert get_neighbors((test_state, 3)) == [
#     (Vector(Point(1, 0), Direction.DOWN), 0)
# ]

In [82]:
start = (Vector(Point(0, 0), Direction.RIGHT), 0)
goals = [Vector(Point(len(grid) - 1, len(grid[0]) - 1), d) for d in Direction]

costs = {}
preds = {}
costs[start] = 0
preds[start] = None

open_q = PriorityQueue()
open_q.put((0, start))

while not open_q.empty():
    cost, state = open_q.get()
    #print(f"Popped state {state} with cost {cost}")

    current_cost = grid[state[0].point.row][state[0].point.col]

    if state[0] in goals:
        print("GOT TO THE END!!")
        break

    for neighbor in get_neighbors(state):
        if neighbor not in costs or (cost + current_cost) < costs[neighbor]:
            new_cost = cost + current_cost
            #print(f"Adding ({new_cost}, {neighbor}) to the queue.")
            costs[neighbor] = new_cost
            preds[neighbor] = state
            
            priority = new_cost
            # priority = new_cost + heuristic(neighbor, goal)
            open_q.put((new_cost, neighbor))

GOT TO THE END!!


In [83]:
[(state, cost) for state, cost in costs.items() if state[0] in goals]

[(((140, 140) => Direction.DOWN, 1), 852),
 (((140, 140) => Direction.DOWN, 2), 852),
 (((140, 140) => Direction.DOWN, 3), 852),
 (((140, 140) => Direction.RIGHT, 1), 850),
 (((140, 140) => Direction.RIGHT, 2), 852)]

In [84]:
# get the minimum cost and subtract off the start because we don't incur it
# and add the end because we do
min_cost = min(cost for state, cost in costs.items() if state[0] in goals)
min_cost -= grid[start[0].point.row][start[0].point.col]
min_cost += grid[goals[0].point.row][goals[0].point.col]
min_cost

851

In [85]:
path = []
state = (Vector(Point(len(grid) - 1, len(grid[0]) - 1), Direction.DOWN), 3)
        
while state != start:
    path.append(state)
    state = preds[state]

path.append(start)

In [86]:
for row_n in range(len(grid)):
    for col_n in range(len(grid[0])):
        if Point(row_n, col_n) in [v[0].point for v in path]:
            direction = [v[0].direction for v in path if v[0].point == Point(row_n, col_n)][0]
            match direction:
                case Direction.UP:
                    symbol = "^"
                case Direction.DOWN:
                    symbol = "v"
                case Direction.LEFT:
                    symbol = "<"
                case Direction.RIGHT:
                    symbol = ">"
            print(symbol, end="")
        else:
            print(grid[row_n][col_n], end="")
    print("")

>4322135^>>4432421513115115531526533362632513324364554127631271452576272352313734642762217437143541463635452254322415552343132332541225244334
v>>>24^>>5v>>>3234343213232225266535223642454432157765223362631471727224725433514675532722143475746311414234363235445634532455243532411134324
434v>>>234244v>543334332142636222216655142344222672761227432126111437515747465667463152236426766314656313121424313621521641112315143254325312
14223354233244v>32524455342216415323231264646775323527175456667615625464425746766771354666467261542113552215211452313626312115142125311223333
334253231423335v>5541154151336514215462164446253115516645377524461331342411413735421513175332556364417746656124461524131363612211554242414513
3244423215322543v>>16315112614422233477263665256176764467633612465117654766546711516623645356422724366243443244534331132523662433244444321335
425315223212514152v34251224646342113377475455547365477243432266131145713635725565565224363637662213125111442654343546461216251352352125351414
551314

## Part 2

In [87]:
def get_neighbors_ultra(state, grid=grid):
    # state is a (vector, current_run) pair
    current_point = Point(state[0].point.row, state[0].point.col)  
    neighbors = []
    
    # need to go at least 4 to turn
    if state[1] >= 4:
        # go left
        match state[0].direction:
            case Direction.UP:
                left = Direction.LEFT
            case Direction.DOWN:
                left = Direction.RIGHT
            case Direction.LEFT:
                left = Direction.DOWN
            case Direction.RIGHT:
                left = Direction.UP
        neighbor_point = current_point + left.value
        if neighbor_point.row >= 0 and neighbor_point.row < len(grid):
            if neighbor_point.col >= 0 and neighbor_point.col < len(grid[0]):
                neighbors.append((Vector(neighbor_point, left), 1))

        # go right
        match state[0].direction:
            case Direction.UP:
                right = Direction.RIGHT
            case Direction.DOWN:
                right = Direction.LEFT
            case Direction.LEFT:
                right = Direction.UP
            case Direction.RIGHT:
                right = Direction.DOWN
        neighbor_point = current_point + right.value
        if neighbor_point.row >= 0 and neighbor_point.row < len(grid):
            if neighbor_point.col >= 0 and neighbor_point.col < len(grid[0]):
                neighbors.append((Vector(neighbor_point, right), 1))

    # if we can, keep going
    if state[1] < 10:
        neighbor_point = current_point + state[0].direction.value
        if neighbor_point.row >= 0 and neighbor_point.row < len(grid):
            if neighbor_point.col >= 0 and neighbor_point.col < len(grid[0]):
                neighbors.append((Vector(neighbor_point, state[0].direction), state[1] + 1))

    return neighbors

In [None]:
start_R = (Vector(Point(0, 0), Direction.RIGHT), 0)
start_D = (Vector(Point(0, 0), Direction.DOWN), 0)
goals = [Vector(Point(len(grid) - 1, len(grid[0]) - 1), d) for d in Direction]

costs = {}
preds = {}
costs[start_R] = 0
costs[start_D] = 0
preds[start_R] = None
preds[start_D] = None

open_q = PriorityQueue()
open_q.put((0, start_R))
open_q.put((0, start_D))

while not open_q.empty():
    cost, state = open_q.get()
    #print(f"Popped state {state} with cost {cost}")

    current_cost = grid[state[0].point.row][state[0].point.col]

    if state[0] in goals and state[1] >= 4:
        print("GOT TO THE END!!")
        print(state)
        break

    for neighbor in get_neighbors_ultra(state):
        if neighbor not in costs or (cost + current_cost) < costs[neighbor]:
            new_cost = cost + current_cost
            #print(f"Adding ({new_cost}, {neighbor}) to the queue.")
            costs[neighbor] = new_cost
            preds[neighbor] = state
            
            priority = new_cost
            # priority = new_cost + heuristic(neighbor, goal)
            open_q.put((new_cost, neighbor))

GOT TO THE END!!
((140, 140) => Direction.RIGHT, 10)


In [102]:
[(state, cost) for state, cost in costs.items() if state[0] in goals and state[1] >= 4]
#[(state, cost) for state, cost in costs.items() if state[0] in goals]

[(((140, 140) => Direction.RIGHT, 10), 981),
 (((140, 140) => Direction.RIGHT, 7), 982),
 (((140, 140) => Direction.RIGHT, 9), 983)]

In [103]:
# get the minimum cost and subtract off the start because we don't incur it
# and add the end because we do
min_cost = min(cost for state, cost in costs.items() if state[0] in goals and state[1] >= 4)
min_cost -= grid[start[0].point.row][start[0].point.col]
min_cost += grid[goals[0].point.row][goals[0].point.col]
min_cost

982

In [106]:
path = []
state = (Vector(Point(len(grid) - 1, len(grid[0]) - 1), Direction.RIGHT), 10)

while state not in [start_R, start_D]:
    path.append(state)
    state = preds[state]

path.append(start)

In [107]:
for row_n in range(len(grid)):
    for col_n in range(len(grid[0])):
        if Point(row_n, col_n) in [v[0].point for v in path]:
            direction = [v[0].direction for v in path if v[0].point == Point(row_n, col_n)][0]
            match direction:
                case Direction.UP:
                    symbol = "^"
                case Direction.DOWN:
                    symbol = "v"
                case Direction.LEFT:
                    symbol = "<"
                case Direction.RIGHT:
                    symbol = ">"
            print(symbol, end="")
        else:
            print(grid[row_n][col_n], end="")
    print("")

>43221352334432421513115115531526533362632513324364554127631271452576272352313734642762217437143541463635452254322415552343132332541225244334
v32224122515213234343213232225266535223642454432157765223362631471727224725433514675532722143475746311414234363235445634532455243532411134324
v34211123424411543334332142636222216655142344222672761227432126111437515747465667463152236426766314656313121424313621521641112315143254325312
v42233542332445332524455342216415323231264646775323527175456667615625464425746766771354666467261542113552215211452313626312115142125311223333
v34253231423335225541154151336514215462164446253115516645377524461331342411413735421513175332556364417746656124461524131363612211554242414513
v24442321532254331116315112614422233477263665256176764467633612465117654766546711516623645356422724366243443244534331132523662433244444321335
v25315223212514152234251224646342113377475455547365477243432266131145713635725565565224363637662213125111442654343546461216251352352125351414
v51314