In [180]:
input_file = "input_files/day_23.txt"

# Read the whole puzzle map; `data` ends up as one string per grid row
# with the line endings removed.
with open(input_file, mode="r") as puzzle_file:
    data = puzzle_file.read().splitlines()

In [181]:
from typing import NamedTuple
from collections import defaultdict, deque

class Point(NamedTuple):
    '''
    A (row, col) grid coordinate.

    `+` is overridden to add two points component-wise (it does NOT
    concatenate like a plain tuple would).
    '''
    row: int
    col: int

    def __add__(self, other):
        '''Return a new Point shifted by `other.row` / `other.col`.'''
        shifted_row = self.row + other.row
        shifted_col = self.col + other.col
        return Point(shifted_row, shifted_col)
    
def get_valid(point, data, seen, part_two=False):
    '''
    Yield the neighbours of `point` that can be stepped onto.

    Neighbours are tried in the order down, up, right, left. A neighbour
    is skipped when it is already in `seen`, lies outside the grid
    (width is taken from the first row), or is a '#' tile.

    With `part_two` false, only '.' tiles and slope tiles whose arrow
    matches the direction of travel are yielded; with `part_two` true
    every remaining tile is yielded.
    '''
    moves = (
        ('v', Point(1, 0)),
        ('^', Point(-1, 0)),
        ('>', Point(0, 1)),
        ('<', Point(0, -1)),
    )
    for slope, offset in moves:
        candidate = point + offset
        if candidate in seen:
            continue

        # Bounds are checked lazily, row first, so the first row is only
        # consulted once the row index is known to be valid.
        inside = 0 <= candidate.row < len(data) and 0 <= candidate.col < len(data[0])
        if not inside:
            continue

        tile = data[candidate.row][candidate.col]
        if tile == '#':
            continue

        if part_two or tile in ('.', slope):
            yield candidate


In [182]:
def make_node_graph(data, start, end, part_two=False):
    '''
    Find all the places on the map where the path splits.
    Make a weighted graph where these splits are nodes
    and the costs are the distance.

    Parameters:
        data: the grid, indexable as data[row][col].
        start: point the walk begins from.
        end: the exit point.
        part_two: forwarded to get_valid (slopes walkable in any direction).

    Returns a defaultdict(list) mapping a node to a list of
    (neighbour, cost) pairs. The exit is stored under the label 'end'
    rather than its coordinates.

    Note: the step counter begins at 1 on the start cell itself, so every
    edge leaving `start` costs one more than the real distance; edges
    leaving any other junction cost exactly the distance.

    NOTE(review): corridors are pruned on (first cell, parent junction)
    while `seen` depends on the path that reached the junction, so the
    first visit decides which edges of that corridor get recorded.
    Confirm this cannot drop an edge on inputs with loops.
    '''
    G = defaultdict(list)

    # Each queue entry: (first cell of a corridor,
    #                    cells already visited on the way there,
    #                    junction the corridor leaves from)
    q = deque([(start, {start}, start)])
    parents = set()
    while q:
        location, seen, parent = q.popleft()

        if (location, parent) in parents:
            # allow other ways of getting to this
            # intersection. Only prune paths from
            # the same parent
            continue

        parents.add((location, parent))

        cost = 1
        while True:
            if location == end:
                # Hey, the exit! Checked on the current cell so the exit
                # is also found when it sits right next to a junction
                # (or is the start itself), not only mid-corridor.
                G[parent].append(('end', cost))
                break

            nexts = list(get_valid(location, data, seen, part_two))
            if not nexts:
                # dead end
                break

            if len(nexts) == 1:
                # skip past any steps that don't involve
                # a branch. Just note the cost
                cost += 1
                location = nexts[0]
                seen.add(location)
                continue

            # A branch: record the junction, then queue every way out.
            G[parent].append((location, cost))
            for n in nexts:
                seen.add(n)
                # visited set is only for this path
                # each branch gets a new visited set
                # with previous visited, but should
                # be independent of others after it
                q.append((n, seen.copy(), location))
            break

    return G
            
# Entrance: second cell of the first row.
# Exit: second-to-last cell of the last row.
start = Point(row=0, col=1)
end = Point(row=len(data) - 1, col=len(data[0]) - 2)
G = make_node_graph(data=data, start=start, end=end)


In [183]:
def find_longest(G, location, end):
    '''
    Find all paths in node graph.

    Depth-first walk over every path from `location` that never revisits
    a node, yielding the total cost of each path that reaches `end`.
    Despite the name this yields EVERY such cost; callers take max().

    Parameters:
        G: mapping of node -> iterable of (neighbour, cost) pairs. Nodes
           without outgoing edges may simply be absent; G is never
           modified (no keys are created, even on a defaultdict).
        location: node to start from.
        end: node (or label) that terminates a path.
    '''
    seen = {location}
    q = deque([(location, 0, seen)])  # node, cost so far, nodes on this path

    while q:
        location, cost, seen = q.pop()

        if location == end:
            yield cost
            # `end` is already on this path, so nothing reachable from
            # here could ever yield again - stop expanding.
            continue

        # .get() instead of G[...]: no KeyError on a plain dict and no
        # silent key insertion on a defaultdict.
        for next_node, node_cost in G.get(location, ()):
            if next_node in seen:
                continue
            q.append((next_node, node_cost + cost, seen | {next_node}))
            
        
# Edges leaving `start` are one step too long (make_node_graph starts
# counting at 1 on the start cell), so take one off the best total.
it = find_longest(G, location=start, end='end')
max(it) - 1

2050

In [184]:
# Part two: rebuild the graph with part_two=True, which lets get_valid
# step onto any non-wall tile regardless of slope direction.
G = make_node_graph(data=data, start=start, end=end, part_two=True)
it = find_longest(G, location=start, end='end')
max(it) - 1

6262