In [327]:
input_file = "input_files/day_17.txt"

# Read the whole input and parse it in a single pass: every line of the file
# becomes one grid row, and every character on that line one integer cell.
with open(input_file) as lines:
    data = [[int(c) for c in line] for line in lines.read().splitlines()]


In [329]:
from typing import NamedTuple
from heapq import heappop, heappush

class Point(NamedTuple):
    """Search state: a grid position plus the heading and straight-run counter."""

    # Row index of the position.
    r: int
    # Column index of the position.
    c: int
    # Row component of the direction of movement.
    d_r: int
    # Column component of the direction of movement.
    d_c: int
    # Counter of consecutive straight moves.
    s: int

def get_neigbors_regular(n, W, H):
    """Yield the states reachable from ``n`` in one move.

    ``n`` unpacks to (row, column, row step, column step, run counter);
    ``W`` and ``H`` are the grid width and height used for bounds checks.
    The straight-ahead state (if any) is yielded first, followed by the two
    perpendicular headings, each of which restarts the run counter at 0.
    """
    row, col, step_r, step_c, run = n

    def inside(y, x):
        # True when (y, x) is a valid cell of an H x W grid.
        return 0 <= y < H and 0 <= x < W

    # Continuing straight is refused once the run counter has reached 2.
    ahead_r, ahead_c = row + step_r, col + step_c
    if run < 2 and inside(ahead_r, ahead_c):
        yield Point(ahead_r, ahead_c, step_r, step_c, run + 1)

    # Swapping the step components (negated first, then as-is) gives the two
    # headings perpendicular to the current one; turning is always allowed.
    for turn_r, turn_c in ((-step_c, -step_r), (step_c, step_r)):
        next_r, next_c = row + turn_r, col + turn_c
        if inside(next_r, next_c):
            yield Point(next_r, next_c, turn_r, turn_c, 0)

def get_neigbors_ultra(n, W, H):
    """Yield the states reachable from ``n`` in one move under the "ultra" rules.

    ``n`` unpacks to (row, column, row step, column step, run counter);
    ``W`` and ``H`` are the grid width and height used for bounds checks.

    A turn yields a state whose counter is 0 after one step in the new
    heading, so a counter of ``s`` corresponds to ``s + 1`` consecutive steps.
    Going straight is allowed while ``s < 9``; turning onto either
    perpendicular heading is allowed only when ``s > 2``.

    The straight-ahead state is yielded first, then the (-d_c, -d_r) turn,
    then the (d_c, d_r) turn.
    """
    r, c, d_r, d_c, s = n

    # straight - refused once the run counter has reached 9
    if (0 <= r + d_r < H) and (0 <= c + d_c < W) and s < 9:
        yield Point(r + d_r, c + d_c, d_r, d_c, s+1)

    # turns - refused until the run counter exceeds 2; swapping the step
    # components (negated, then as-is) gives the two perpendicular headings
    if s > 2:
        for t_r, t_c in ((-d_c, -d_r), (d_c, d_r)):
            if (0 <= r + t_r < H) and (0 <= c + t_c < W):
                yield Point(r + t_r, c + t_c, t_r, t_c, 0)


def path(data, part_two=False):
    '''
    Dijkstra search for the cheapest route from the top-left to the
    bottom-right cell of ``data``.

    Search states are ``Point`` tuples (position, heading, run counter), and
    entering a cell costs that cell's value; the start cell itself is free.

    Parameters:
        data: non-empty grid (list of rows) of integer cell costs.
        part_two: when True, neighbors come from ``get_neigbors_ultra`` and
            the target only counts when reached with a run counter of at
            least 3 (a straight run of at least 4); otherwise
            ``get_neigbors_regular`` is used.

    Returns:
        ``(cost, distances, end)`` where ``cost`` is the minimum total cost,
        ``distances`` maps every discovered state to ``(cost, previous)``
        (``previous`` is None for the start states) and ``end`` is the final
        state at the target.

    Raises:
        ValueError: if the grid is empty or no admissible route exists.
    '''
    if not data or not data[0]:
        raise ValueError("grid is empty")

    neighbor_function = get_neigbors_ultra if part_two else get_neigbors_regular

    H = len(data)
    W = len(data[0])

    target = (H-1, W-1)

    # Nothing has been travelled yet at the start, so the run counter begins
    # at -1: the first straight step then yields 0, exactly like the first
    # step after a turn. Both initial headings are seeded because the ultra
    # rules forbid turning before a minimum run has been completed.
    starts = (Point(0, 0, 0, 1, -1), Point(0, 0, 1, 0, -1))

    distances = {}  # {point: (cost, previous)}
    h = []          # min heap of (cost, point)
    for p in starts:
        distances[p] = (0, None)
        heappush(h, (0, p))

    while h:
        cost, current_node = heappop(h)

        if cost > distances[current_node][0]:
            # we've already found a cheaper way here
            continue

        # found target (part two must end on a straight run of at least 4)
        if (current_node.r, current_node.c) == target:
            if not part_two or current_node.s >= 3:
                return cost, distances, current_node

        for neighbor in neighbor_function(current_node, W, H):
            next_cost = cost + data[neighbor.r][neighbor.c]

            if next_cost < distances.get(neighbor, (float('inf'), None))[0]:
                distances[neighbor] = (next_cost, current_node)
                heappush(h, (next_cost, neighbor))

    raise ValueError("no admissible route to the bottom-right cell")
        

# Solve both parts on the puzzle input. After the loop, `cost`, `distances`
# and `end` hold the part-two result.
for label, use_ultra in (("Part One: ", False), ("Part Two: ", True)):
    cost, distances, end = path(data, part_two=use_ultra)
    print(label, cost)



Part One:  942
Part Two:  1082


## Printing example paths

In [324]:
# Example grids, parsed straight into rows of integer cells
# (one cell per character).
data_ex = [
    list(map(int, row))
    for row in (
        '2413432311323',
        '3215453535623',
        '3255245654254',
        '3446585845452',
        '4546657867536',
        '1438598798454',
        '4457876987766',
        '3637877979653',
        '4654967986887',
        '4564679986453',
        '1224686865563',
        '2546548887735',
        '4322674655533',
    )
]

data_small = [
    list(map(int, row))
    for row in (
        '111111111111',
        '999999999991',
        '999999999991',
        '999999999991',
        '999999999991',
    )
]


In [326]:
def print_path(data, distances, end):
    """Print the grid with the cells of a route replaced by heading arrows.

    Parameters:
        data: grid (list of rows) whose cell values are printed as-is.
        distances: mapping ``state -> (cost, previous_state)``; the start of
            the route is the state whose ``previous_state`` is None.
        end: final state of the route; states expose ``r``, ``c``, ``d_r``
            and ``d_c`` attributes.

    Each row is preceded by a newline and no newline is written after the
    last row. Route cells show ``>``, ``v``, ``<`` or ``^`` according to the
    heading stored in the state recorded for that cell.
    """
    symbols = {
        (0, 1): ">",
        (1, 0): "v",
        (0, -1): "<",
        (-1, 0): "^"
    }

    # Follow the predecessor links back from the end. A cell visited more
    # than once keeps the state closest to the start, as it is stored last.
    paths = {}
    current = end
    while current is not None:
        paths[(current.r, current.c)] = current
        current = distances[current][1]

    for row, line in enumerate(data):
        print()
        for col, n in enumerate(line):
            p = paths.get((row, col))
            if p is None:
                print(n, end="")
            else:
                print(symbols[(p.d_r, p.d_c)], end="")


                
# Solve each example, report its minimum and draw the route. Consecutive
# reports are separated by a blank gap and a rule of 20 '=' characters.
# After the loop, `cost`, `distances` and `end` hold the last example's result.
examples = (
    ("Part one minimum:", data_ex, False),
    ("Part two example minimum:", data_ex, True),
    ("Part two small example minimum:", data_small, True),
)
for position, (title, grid, ultra) in enumerate(examples):
    if position > 0:
        print('\n')
        print("=" * 20)
    cost, distances, end = path(grid, ultra)
    print(title, cost)
    print_path(grid, distances, end)


Part one minimum: 102

>>>34^>>>1323
32v>>>35v>623
325524565v>54
3446585845v52
4546657867v>6
14385987984v4
44578769877v6
36378779796v>
465496798688v
456467998645v
12246868655<v
25465488877v5
43226746555v>

Part two example minimum: 94

>>>>>>>>>1323
32154535v5623
32552456v4254
34465858v5452
45466578v>>>>
143859879845v
445787698776v
363787797965v
465496798688v
456467998645v
122468686556v
254654888773v
432267465553v

Part two small example minimum: 71

>>>>>>>>1111
9999999v9991
9999999v9991
9999999v9991
9999999v>>>>