In [122]:
from collections import Counter
from itertools import product, combinations
import networkx as nx

In [198]:
def _parse_edge_lines(text):
    """Return the non-blank lines of ``text``, stripped, as ``'a-b'`` edge strings.

    Blank and whitespace-only lines are skipped wherever they occur, so the
    result does not depend on how the surrounding triple-quoted literal is
    padded (slicing off a fixed first/last line only works for one layout).
    """
    return [line.strip() for line in text.splitlines() if line.strip()]


def load_small_example():
    """Return the small worked example as a list of ``'a-b'`` edge strings."""
    return _parse_edge_lines("""
start-A
start-b
A-c
A-b
b-d
A-end
b-end
    """)


def load_larger_example():
    """Return the slightly larger worked example as a list of ``'a-b'`` edge strings."""
    return _parse_edge_lines("""
dc-end
HN-start
start-kj
dc-start
dc-HN
LN-dc
HN-end
kj-sa
kj-HN
kj-dc
    """)


def load_largest_example():
    """Return the largest worked example as a list of ``'a-b'`` edge strings."""
    return _parse_edge_lines("""
fs-end
he-DX
fs-he
start-DX
pj-DX
end-zg
zg-sl
zg-pj
pj-he
RW-he
fs-DX
pj-RW
zg-RW
start-pj
he-WI
zg-he
pj-fs
start-RW
    """)


def load_my_input():
    """Return my puzzle input as a list of ``'a-b'`` edge strings."""
    return _parse_edge_lines("""
EO-jc
end-tm
jy-FI
ek-EO
mg-ek
jc-jy
FI-start
jy-mg
mg-FI
jc-tm
end-EO
ds-EO
jy-start
tm-EO
mg-jc
ek-jc
tm-ek
FI-jc
jy-EO
ek-jy
ek-LT
start-mg
    """)


# Sanity check: display one parsed example.
load_largest_example()

['fs-end',
 'he-DX',
 'fs-he',
 'start-DX',
 'pj-DX',
 'end-zg',
 'zg-sl',
 'zg-pj',
 'pj-he',
 'RW-he',
 'fs-DX',
 'pj-RW',
 'zg-RW',
 'start-pj',
 'he-WI',
 'zg-he',
 'pj-fs',
 'start-RW']

In [220]:
# Build the undirected cave graph for the small example and list its edges.
example = load_small_example()
G = nx.Graph([edge.split('-') for edge in example])
G.edges()

EdgeView([('start', 'A'), ('start', 'b'), ('A', 'c'), ('A', 'b'), ('A', 'end'), ('b', 'd'), ('b', 'end')])

In [128]:
# Nodes of the graph currently bound to G (assumes the small-example cell above ran);
# the bare ``G.nodes`` attribute is the same NodeView that ``G.nodes()`` returns.
G.nodes

NodeView(('start', 'A', 'b', 'c', 'd', 'end'))

In [197]:
def is_lower(node):
    """Return True when lower-casing leaves ``node`` unchanged (small caves, 'start', 'end')."""
    lowered = node.lower()
    return lowered == node

def is_upper(node):
    """Return True when lower-casing changes ``node`` (e.g. it has an upper-case letter: a big cave)."""
    return not (node.lower() == node)

# returns a set of visited nodes
def build_path(G, cur_node, path_so_far, visited_nodes, all_paths):
    visited_nodes.add(cur_node)
    path_so_far.append(cur_node)
    
    # path_so_far is a list that we return once we hit the "end" node
    if cur_node == 'end':
        path_str = ','.join(path_so_far)
        all_paths.add(path_str)
    
    else:
        for edge in G.edges(cur_node):
            next_node = edge[1]
            # it's okay to visit the node again if it's uppercase, but not if it's lowercase
            if next_node not in visited_nodes or is_upper(next_node):
                build_path(G, next_node, path_so_far, visited_nodes, all_paths)
    
    path_so_far.pop()
    if cur_node in visited_nodes:
        visited_nodes.remove(cur_node)

        
# Part 1 on the largest worked example (the recorded output of this cell is 226).
example = load_largest_example()
G = nx.Graph([edge.split('-') for edge in example])
visited_nodes, path_so_far, all_paths = set(), [], set()
build_path(G, 'start', path_so_far, visited_nodes, all_paths)
len(all_paths)

226

In [200]:
# Part 1 on my puzzle input.
G = nx.Graph([edge.split('-') for edge in load_my_input()])
visited_nodes, path_so_far, all_paths = set(), [], set()
build_path(G, 'start', path_so_far, visited_nodes, all_paths)
len(all_paths)

5228

In [216]:
# Part 2: one single small cave may be visited twice per path; "start" only once.
def build_path2(G, cur_node, path_so_far, visited_nodes, all_paths):
    """Depth-first search recording every path from ``cur_node`` to ``'end'`` (part-2 rules).

    Big (upper-case) caves may be revisited freely. Small caves may be visited
    once, except that a single small cave per path may be visited twice.
    ``'start'`` is never re-entered once it has been visited.

    Args:
        G: the cave graph (an ``nx.Graph``); ``G.neighbors(node)`` gives adjacent caves.
        cur_node: cave to step into.
        path_so_far: list of caves on the current path; restored before returning.
        visited_nodes: dict mapping cave -> how many times it is on the current
            path; counts are restored before returning (entries may be left at 0).
        all_paths: set that receives each complete path as a comma-joined string.

    Returns:
        ``all_paths`` (it is also mutated in place).
    """
    # Count this visit (the caller's dict may not have the key yet).
    visited_nodes[cur_node] = visited_nodes.get(cur_node, 0) + 1
    path_so_far.append(cur_node)

    if cur_node == 'end':
        all_paths.add(','.join(path_so_far))
    else:
        # Every recursive call below restores all counts before it returns, so
        # this flag is loop-invariant: compute it once per frame instead of
        # rebuilding a filtered dict for every neighbour.
        double_visit_used = any(
            count == 2 for node, count in visited_nodes.items() if is_lower(node)
        )
        for next_node in G.neighbors(cur_node):
            if next_node == 'start' and 'start' in visited_nodes:
                continue  # don't visit "start" again

            times_seen = visited_nodes.get(next_node, 0)
            # Unvisited caves, and big caves at any count, are always allowed.
            # A small cave already seen once is allowed only while no small cave
            # has used the one-time "visit twice" allowance.
            if (times_seen == 0 or is_upper(next_node)
                    or (times_seen == 1 and not double_visit_used)):
                build_path2(G, next_node, path_so_far, visited_nodes, all_paths)

    # Undo this step so the caller sees its path and counts unchanged.
    path_so_far.pop()
    visited_nodes[cur_node] -= 1
    return all_paths
        

# Part 2 on the small worked example (the recorded output of this cell is 36).
example = load_small_example()
G = nx.Graph([edge.split('-') for edge in example])
visited_nodes, path_so_far, all_paths = {}, [], set()
build_path2(G, 'start', path_so_far, visited_nodes, all_paths)
len(all_paths)

36

In [219]:
# Part 2 on my puzzle input.
G = nx.Graph([edge.split('-') for edge in load_my_input()])
visited_nodes, path_so_far, all_paths = {}, [], set()
build_path2(G, 'start', path_so_far, visited_nodes, all_paths)
len(all_paths)

131228