In [1]:
import os
from pathlib import Path
from collections import defaultdict

FOLDER = Path(os.path.dirname(os.path.realpath("__file__"))) / 'data'

in_file = 'day12.txt'

with open(FOLDER / in_file) as f:
    data = f.read().splitlines()  

In [2]:
'''
It's a non-directed graph, need to add 
edge for both direction. But don't bother
going back to start or starting at end.
'''

graph = defaultdict(set)

for line in data:
    source, dest = line.split('-')
    
    if dest != 'start' and source != 'end':
        graph[source].add(dest)

    if source !='start' and dest!='end':
        graph[dest].add(source)

### Problem One

In [3]:
def problem_one(graph, symbol, seen=None):
    if seen is None:
        seen = set()
    if symbol == 'end':
        yield 1
        return
    
    for node in graph[symbol] - seen:
        local_seen = set() if node.isupper() else set([node])
        yield sum(problem_one(graph, node, seen | local_seen))
        
paths = problem_one(graph, 'start')  
print(f"solution 1: {sum(paths)}")


solution 1: 3410


### Problem Two

In [18]:
def problem_two(graph, symbol, seen=None, twice=False):
    if seen is None:
        seen = set()
    if symbol == 'end':
        yield 1
        return

    for node in graph[symbol]:
        local_twice = twice
        local_seen = set()
        if node.islower():
            local_seen.add(node)
            if node in seen:
                if local_twice:
                    continue
                local_twice = True
            
        yield sum(problem_two(graph, node, seen | local_seen, local_twice))
        
paths = problem_two(graph, 'start')
print(f"solution 2: {sum(paths)}")

solution 2: 98796


In [7]:
%timeit sum(problem_two(graph, 'start')) #... hmm let's try a cache

527 ms ± 2.93 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


### Problem Two with lru_cache

In [38]:
from functools import lru_cache

# graph can't be in the args since 
# a dict isn't hashable
graph = graph

@lru_cache
def problem_two_cache(symbol, seen=None, twice=False):
    if seen is None:
        seen = frozenset()
    if symbol == 'end':
        return 1

    total = 0
    for node in graph[symbol]:
        local_twice = twice
        local_seen = set()
        if node.islower():
            local_seen.add(node)
            if node in seen:
                if local_twice:
                    continue
                local_twice = True
            
        total += problem_two_cache(node, frozenset(seen | local_seen), local_twice)
    return total

In [41]:
%timeit problem_two_cache('start') # ...better...but too good. lru_cache is probably not cleared betweeen runs

72.7 ns ± 0.607 ns per loop (mean ± std. dev. of 7 runs, 10000000 loops each)


In [39]:
problem_two_cache('start')

98796

In [42]:
problem_two_cache.cache_info() # all those hits are from the timeit test

CacheInfo(hits=81116644, misses=4674, maxsize=128, currsize=128)

In [43]:
def run_and_clear():
    problem_two_cache.cache_clear()
    return problem_two_cache('start')


98796

In [44]:
%timeit run_and_clear() # maybe closer to the truth?

10.3 ms ± 114 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
