In [1]:
import collections
from typing import Dict, List, Set


class CaveGraph:
    """Undirected graph of caves, used to count paths from `start` to `end`.

    `graph` maps each cave name to the list of its neighbors. Caves whose names
    are lowercase (`str.islower`) are "small" and have a visit limit; all other
    caves are "big" and may be visited any number of times.
    """

    @classmethod
    def from_input(cls, input_filename: str) -> "CaveGraph":
        """Build a graph from a file with one `cave1-cave2` edge per line.

        Every edge is recorded in both directions. Blank lines (e.g. an extra
        trailing newline) are skipped. A non-blank line that does not split on
        "-" into exactly two names raises ValueError.
        """
        # Map caves to their neighbors
        graph: Dict[str, List[str]] = collections.defaultdict(list)

        with open(input_filename) as input_file:
            for raw_line in input_file:
                line = raw_line.strip()
                if not line:
                    # An empty line would otherwise fail the two-name unpacking below.
                    continue
                c1, c2 = line.split("-")
                graph[c1].append(c2)
                graph[c2].append(c1)

        return cls(graph)

    def __init__(self, graph: Dict[str, List[str]]):
        self.graph = graph
        # Path counter; reset and filled in by `find_paths`.
        self.num_paths = 0

    def find_paths(self, revisit_small_caves: bool = False) -> int:
        """Count the distinct paths from `start` to `end` (recursive depth-first search).

        Args:
            revisit_small_caves: if False, each small cave is visited at most once
                per path. If True, a single small cave other than `start` may be
                visited twice on each path.

        Returns:
            The number of paths. The count is also left in `self.num_paths`.
        """
        self.num_paths = 0  # reset so repeated calls don't accumulate

        if revisit_small_caves:
            self.travel2(set(), set(), "start")
        else:
            self.travel(set(), "start")

        return self.num_paths

    def travel(self, visited: Set[str], cave: str) -> None:
        """Walk depth-first from `cave`, incrementing `self.num_paths` at each `end`.

        `visited` holds the small caves already on the current path; they are not
        entered again. Big caves are never added to it. A new set is built rather
        than mutating the caller's set, so sibling branches are unaffected.
        """
        if cave == "end":
            self.num_paths += 1
            return

        if cave.islower():
            visited = visited | {cave}

        # `.get` avoids inserting empty entries into a defaultdict graph for caves
        # with no edges (e.g. when the graph has no `start` cave at all).
        for next_cave in self.graph.get(cave, []):
            if next_cave not in visited:
                self.travel(visited, next_cave)

    def travel2(self, visited: Set[str], small_visited: Set[str], cave: str) -> None:
        """Walk depth-first from `cave`, allowing one small cave to be visited twice.

        Every path that reaches `end` increments `self.num_paths`.

        - `visited`: caves that may not be entered again. When the walk begins at
          `start` (as `find_paths` does), it holds only `start` until a small cave
          is revisited. From the revisit onward, it also holds every small cave
          seen on the path.
        - `small_visited`: small caves seen once while the single revisit is still
          unused. It is emptied once the revisit happens.
        """
        if cave == "end":
            self.num_paths += 1
            return

        if cave.islower():
            if cave == "start":
                visited = visited | {cave}

            elif len(visited) >= 2:
                # The revisit has already been used. Apart from `start`, `visited`
                # only gains entries at the moment of the revisit, so a size >= 2
                # means it happened; every further small cave is single-visit.
                visited = visited | {cave}

            elif cave in small_visited:
                # First revisit of a small cave: no small cave seen so far on this
                # path (including this one) may be entered again.
                visited = visited | small_visited
                small_visited = set()

            else:
                # Revisit still unused: remember this cave so a second visit is detected.
                small_visited = small_visited | {cave}

        for next_cave in self.graph.get(cave, []):
            if next_cave not in visited:
                self.travel2(visited, small_visited, next_cave)

In [2]:
# Small example graphs checked by the test cells below, then the main input.
example_graph_1, example_graph_2, example_graph_3 = (
    CaveGraph.from_input(example_filename)
    for example_filename in (
        "input-example-1.txt",
        "input-example-2.txt",
        "input-example-3.txt",
    )
)

cave_graph = CaveGraph.from_input("input.txt")

# Part 1: count paths that visit each small cave at most once

In [3]:
# Test Cases: path counts when each small cave is visited at most once.
for label, example_graph, expected_paths in (
    ("example 1", example_graph_1, 10),
    ("example 2", example_graph_2, 19),
    ("example 3", example_graph_3, 226),
):
    actual_paths = example_graph.find_paths()
    # Report which example failed and the computed count, not a bare AssertionError.
    assert actual_paths == expected_paths, (
        f"{label}: expected {expected_paths} paths, got {actual_paths}"
    )

In [4]:
# Part 1 answer: each small cave visited at most once per path.
num_paths = cave_graph.find_paths(revisit_small_caves=False)
print("Number of paths:", num_paths)

Number of paths: 4885


# Part 2: count paths where one small cave (other than `start`) may be visited twice

In [5]:
# Test Cases: path counts when one small cave may be visited twice.
for label, example_graph, expected_paths in (
    ("example 1", example_graph_1, 36),
    ("example 2", example_graph_2, 103),
    ("example 3", example_graph_3, 3509),
):
    actual_paths = example_graph.find_paths(revisit_small_caves=True)
    # Report which example failed and the computed count, not a bare AssertionError.
    assert actual_paths == expected_paths, (
        f"{label}: expected {expected_paths} paths, got {actual_paths}"
    )

In [6]:
# Part 2 answer: a single small cave may be visited twice per path.
num_paths = cave_graph.find_paths(
    revisit_small_caves=True,
)
print("Number of paths:", num_paths)

Number of paths: 117095
