# [Day 17](https://adventofcode.com/2023/day/17)


## Model

In [1]:
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Generic, TypeVar
from enum import Enum


class Choice(Enum):
    """Relative steering decision taken at each step.

    Members are declared in the order the search tries them
    (``for choice in Choice``): turn left, keep going, turn right.
    """
    LEFT, STRAIGHT, RIGHT = range(3)

type Pos = tuple[int, int]


class Dir(Enum):
    NORTH = 0
    EAST = 1
    SOUTH = 2
    WEST = 3

    def reversed(self) -> 'Dir':
        return Dir((self.value + 2) % 4)

    def turn(self, choice: Choice) -> 'Dir':
        if choice == Choice.LEFT:
            return Dir((self.value - 1) % 4)
        elif choice == Choice.RIGHT:
            return Dir((self.value + 1) % 4)
        else:
            return self
        
    def of(self, pos: Pos):
        match self:
            case Dir.NORTH:
                return (pos[0], pos[1]-1)
            case Dir.EAST:
                return (pos[0]+1, pos[1])
            case Dir.SOUTH:
                return (pos[0], pos[1]+1)
            case Dir.WEST:
                return (pos[0]-1, pos[1])
        

@dataclass(frozen=True)
class Edge:
    """Immutable directed connection between two adjacent grid nodes.

    ``cost`` is the heat loss paid for entering ``to``; readInput stores the
    destination node's cost here. shortestPath also creates a sentinel start
    edge ``Edge(None, None, start, 0)``, so ``dir`` and ``fromNode`` may be None.
    Being frozen, edges are hashable and can be used inside search-state keys.
    """
    dir: 'Dir | None'
    fromNode: 'Node | None'
    to: 'Node'
    cost: int
    
@dataclass(unsafe_hash=True)
class Node:
    """One grid cell: its coordinates, its heat-loss cost and its outgoing edges.

    Equality and hashing use only ``(x, y, cost)``. ``connected`` is excluded
    from both, so nodes stay usable as dict keys / set members while their
    edges are filled in after construction.
    """
    x: int
    y: int
    cost: int

    # Outgoing edges keyed by the heading they leave this node in (filled by readInput).
    connected: 'dict[Dir, Edge]' = field(default_factory=dict, init=False, compare=False, hash=False)

    def next(self, dir: 'Dir', choice: 'Choice') -> 'Edge | None':
        """Return the edge reached by applying ``choice`` to heading ``dir``.

        Returns None when this node has no neighbour in the resulting heading.
        """
        return self.connected.get(dir.turn(choice))

    def pos(self) -> tuple[int, int]:
        """Return the ``(x, y)`` coordinates of this node."""
        return (self.x, self.y)

    def __repr__(self) -> str:
        return f'{self.pos()}={self.cost}'

def readInput(filename: str) -> tuple[dict[Pos, Node], list[list[str]]]:
    """Parse a grid of single-digit costs into linked Nodes.

    Every character of a row must be a decimal digit. Blank lines (e.g. a
    trailing empty line) are skipped, so the returned grid's row count equals
    the real number of rows and ``(len(grid[0]) - 1, len(grid) - 1)`` is
    always the bottom-right node.

    Args:
        filename: path of the puzzle input.

    Returns:
        ``(nodes, grid)``: nodes keyed by ``(x, y)``, each connected to its
        orthogonal neighbours in both directions, and the raw grid as a list
        of character rows.
    """
    nodes: dict[Pos, Node] = dict()
    grid: list[list[str]] = []
    with open(filename, 'r', encoding='utf-8') as f:
        rows = (line.strip() for line in f)
        for y, line in enumerate(row for row in rows if row):
            grid.append(list(line))

            for x, c in enumerate(line):
                node = Node(x, y, int(c))
                nodes[node.pos()] = node

                # Only already-read cells (north and west) can be found here, so
                # link both ways: the neighbour also gets the edge back to us.
                for dir in Dir:
                    if (pos := dir.of(node.pos())) in nodes:
                        neighbor = nodes[pos]
                        node.connected[dir] = Edge(dir, node, neighbor, neighbor.cost)
                        rev = dir.reversed()
                        neighbor.connected[rev] = Edge(rev, neighbor, node, node.cost)
    return nodes, grid

nodes, m = readInput('test.txt')

# Echo the grid row by row to check it was parsed as expected.
print(*(''.join(row) for row in m), sep='\n')

2413432311323
3215453535623
3255245654254
3446585845452
4546657867536
1438598798454
4457876987766
3637877979653
4654967986887
4564679986453
1224686865563
2546548887735
4322674655533



## Draw function

In [2]:
from graphviz import Digraph
from IPython.display import Image, display

imageCount = 1
def drawGraph(nodes, path: set[Edge] | None = None, costs: dict[Node, int] | None = None, file: str | None = None, scale: int = 1):
    """Render the grid (and optionally a path) with graphviz and show it inline.

    Args:
        nodes: iterable of Nodes; each is drawn as a square coloured by its cost.
        path: edges to highlight in green. An edge without ``fromNode`` (the
            start sentinel) only highlights its target node.
        costs: if non-empty, node labels show ``costs[node]`` (blank when the
            node is missing) instead of the node's own cost.
        file: output PNG file name; defaults to ``image_NNN.png`` taken from
            the module-level ``imageCount`` counter.
        scale: multiplier applied to the 5x5 inch drawing size.
    """
    global imageCount
    path = path if path is not None else set()
    costs = costs if costs is not None else {}
    if not file:
        file = f'image_{imageCount:03d}.png'
        imageCount += 1

    nodes = list(nodes)
    # Grid height, used to flip y so that row 0 is drawn at the top.
    height = max((n.y for n in nodes), default=-1) + 1

    g = Digraph('G', engine='neato', strict=False)
    for n in nodes:
        label = str(costs.get(n, '')) if costs else str(n.cost)
        g.node(f'{n.x} {n.y}', label=label, shape='square', colorscheme='oranges9',
               style='filled', fillcolor=f'{n.cost}', color=f'{n.cost}',
               pos=f'{n.x/1.4},{(height+1-n.y)/1.4}!')
    for e in path:
        g.node(f'{e.to.x} {e.to.y}', color='#00ff00ff', penwidth='3')
        if e.fromNode is not None:
            # Colour the arrow by this edge's own cost.
            g.edge(f'{e.fromNode.x} {e.fromNode.y}', f'{e.to.x} {e.to.y}', colorscheme='oranges9',
                   fillcolor=f'{e.cost}', color='#00ff00ff', penwidth='2')
    g.format = 'png'

    # Graphviz expects "W,H!" (the '!' scales the drawing up to that size).
    g.graph_attr['size'] = f'{5*scale},{5*scale}!'
    g.render(outfile=file, cleanup=True)
    display(Image(url=f'./{file}'))

drawGraph(nodes.values())

## Shortest path

Dijkstra-style search over (edge, consecutive-straight-steps) states. It could be faster, but it works.

In [3]:
from queue import PriorityQueue
from typing import Generic
from typing import TypeVar

# Stand-in for "no cost recorded yet" in cost lookups.
INF = 999999999999

T = TypeVar('T')
class EdgePriorityQueue(Generic[T]):
    """Min-priority queue of items ordered by one or more integer costs.

    Entries are stored as ``(*costs, insertion_index, item)``. The strictly
    increasing index breaks ties, so the items themselves are never compared.
    ``refCount`` keeps, per item, the cost lists of its entries still queued,
    so callers can ask whether (and how cheaply) an item is pending.
    """
    idx: int
    queue: PriorityQueue[tuple]
    refCount: dict[T, list[list[int]]]

    def __init__(self):
        self.idx = 0
        self.queue = PriorityQueue()
        self.refCount = defaultdict(list)

    def contains(self, item: T) -> bool:
        """Return True if ``item`` has at least one entry still queued."""
        # .get() so that merely asking does not insert an empty list.
        return bool(self.refCount.get(item))

    def getCosts(self, item: T) -> list[int]:
        """Return the cost list of the cheapest queued entry for ``item`` ([] if none)."""
        pending = self.refCount.get(item)
        return list(min(pending)) if pending else []

    def put(self, item: T, *costs: int) -> None:
        """Queue ``item`` with priority ``costs`` (compared lexicographically)."""
        self.idx += 1
        self.queue.put((*costs, self.idx, item))
        self.refCount[item].append(list(costs))

    def get(self) -> tuple[T, list[int]]:
        """Remove and return ``(item, costs)`` for the lowest-priority entry."""
        *costs, _, item = self.queue.get()
        pending = self.refCount[item]
        pending.remove(costs)
        if not pending:
            # Drop exhausted keys so the bookkeeping does not grow without bound.
            del self.refCount[item]
        return (item, costs)

    def empty(self) -> bool:
        return self.queue.empty()

    def size(self) -> int:
        return self.queue.qsize()

type Item = tuple[Edge, int]
def addChoices(minStraight: int, maxStraight: int, queue: EdgePriorityQueue[tuple[Node, Node]], item: Item, previousItems: dict[Item, Item], edgeCosts: dict[Item, int], visited: set[Item] = set()):
    (edge, straight) = item
    if edge.fromNode is None:
        # add all children of start node
        for child in edge.to.connected.values():
            cost = edge.to.cost + child.cost
            queue.put((item, (child, 1)), cost)
    else:
        for choice in Choice:
            nextEdge = edge.to.next(edge.dir, choice)
            if not nextEdge:
                continue
            if choice != Choice.STRAIGHT and straight < minStraight:
                continue
            if choice == Choice.STRAIGHT and straight == maxStraight:
                continue
            nextItem = (nextEdge, (straight + 1) if choice == Choice.STRAIGHT else 1)
        
            if nextItem in visited:
                continue

            nextCost = edgeCosts[item] + nextEdge.cost
            if queue.contains(nextItem):
                costs = queue.getCosts(nextItem)
                if nextCost < min(costs):
                    queue.put((item, nextItem), nextCost)
            else:
                queue.put((item, nextItem), nextCost)



def shortestPath(start: Node, goal: Node, minStraight: int = 0, maxStraight: int = 0) -> 'tuple[Item | None, dict[Item, Item], dict[Item, int], dict[Node, int]]':
    """Dijkstra-style search over ``(edge, straight-run)`` states.

    Args:
        start: node the crucible starts on (its own cost is not paid).
        goal: node to reach.
        minStraight: consecutive steps required before turning or stopping
            on a node.
        maxStraight: maximum consecutive steps in one heading; 0 (the
            default) or less means no upper limit.

    Returns:
        ``(goalItem, previousItems, itemCosts, nodeCosts)``:
        goalItem is the cheapest item ending on ``goal`` with a valid run
        length, or None if the goal was never reached. previousItems maps
        every finalized item to its predecessor (the start item maps to None).
        itemCosts holds the accumulated cost of every finalized item.
        nodeCosts holds the cheapest valid cost seen for each node.
    """
    itemCosts: dict[Item, int] = dict()
    previousItems: dict[Item, Item] = dict()
    queue: EdgePriorityQueue[tuple[Item, Item]] = EdgePriorityQueue()
    nodeCosts: dict[Node, int] = dict()

    visited: set[Item] = set()

    def runAllowsStop(straight: int) -> bool:
        # maxStraight <= 0 means "unlimited"; addChoices likewise never blocks
        # a straight move in that case.
        return straight >= minStraight and (maxStraight <= 0 or straight <= maxStraight)

    # Sentinel start edge: no heading and no predecessor; costs nothing.
    queue.put((None, (Edge(None, None, start, 0), 1)), 0)
    previousItems[None] = None
    goalItem = None
    while not queue.empty():
        (fromItem, item), _ = queue.get()
        (edge, straight) = item
        if item in visited:
            continue  # stale duplicate of an already-finalized state

        cost = itemCosts.get(fromItem, 0) + edge.cost

        if cost < itemCosts.get(item, INF):
            visited.add(item)
            itemCosts[item] = cost
            previousItems[item] = fromItem
            node = edge.to
            if runAllowsStop(straight):
                if cost < nodeCosts.get(node, INF):
                    nodeCosts[node] = cost
                if node == goal:
                    if cost < itemCosts.get(goalItem, INF):
                        goalItem = item
        addChoices(minStraight, maxStraight, queue, item, previousItems, itemCosts, visited)

    return goalItem, previousItems, itemCosts, nodeCosts

def reconstructPath(item: 'Item', previousItems: 'dict[Item, Item]', itemCosts: 'dict[Item, int] | None' = None) -> 'tuple[set[Edge], dict[Node, int]]':
    """Follow predecessor links back from ``item`` and collect the path.

    Args:
        item: last item of the path (e.g. shortestPath's goalItem); None
            yields an empty path.
        previousItems: predecessor map from shortestPath, in which the start
            item maps to None.
        itemCosts: accumulated cost per item. When omitted, the module-level
            ``itemCosts`` (as assigned by the notebook cells that call
            shortestPath) is used.

    Returns:
        ``(path, costs)``: the set of edges on the path (including the start
        sentinel edge) and the accumulated cost at each node on it.
    """
    if itemCosts is None and item is not None:
        # Backward compatibility for call sites that pass only two arguments.
        itemCosts = globals()['itemCosts']
    path: set[Edge] = set()
    costs: dict[Node, int] = dict()
    while item is not None:
        edge = item[0]
        path.add(edge)
        costs[edge.to] = itemCosts[item]
        item = previousItems[item]
    return path, costs

nodes, m = readInput('test.txt')
start = nodes[(0, 0)]
goal = nodes[(len(m[0])-1, len(m)-1)]

# Part-1 rules: no minimum run, at most 3 consecutive steps in one heading.
goalItem, previousItems, itemCosts, nodeCosts = shortestPath(start, goal, 0, 3)
print('Test: ', itemCosts[goalItem])
path, pathCosts = reconstructPath(goalItem, previousItems)
nodeCosts.update(pathCosts)
drawGraph(nodes.values(), path, nodeCosts)

Test:  102


## Part 1

In [4]:
nodes, m = readInput('input.txt')
start = nodes[(0, 0)]
goal = nodes[(len(m[0])-1, len(m)-1)]

# Part-1 rules: no minimum run, at most 3 consecutive steps in one heading.
goalItem, previousItems, itemCosts, nodeCosts = shortestPath(start, goal, 0, 3)
print('Part 1: ', itemCosts[goalItem])

path, pathCosts = reconstructPath(goalItem, previousItems)
nodeCosts.update(pathCosts)
drawGraph(nodes.values(), path, nodeCosts)


Part 1:  956


## Part 2

In [5]:

# Part-2 rules: at least 4 and at most 10 consecutive steps in one heading.
# A module-level loop (not a function) keeps nodes/m/itemCosts as notebook
# globals, which drawGraph and reconstructPath read.
for filename, label in (('test.txt', 'Test 1: '), ('test2.txt', 'Test 2: '), ('input.txt', 'Part 2: ')):
    nodes, m = readInput(filename)
    start = nodes[(0, 0)]
    goal = nodes[(len(m[0])-1, len(m)-1)]

    goalItem, previousItems, itemCosts, nodeCosts = shortestPath(start, goal, 4, 10)
    print(label, itemCosts[goalItem])

    path, pathCosts = reconstructPath(goalItem, previousItems)
    nodeCosts.update(pathCosts)
    drawGraph(nodes.values(), path, nodeCosts)

Test 1:  94


Test 2:  71


Part 2:  1106
