In [1]:
def load(dataset):
    """
    Read ``./<dataset>.txt`` and convert each line of digits into a row of ints.

    Surrounding whitespace of the whole file is stripped first, then the
    text is split on newlines; every character of a line becomes one int.

    Returns
    -------
    matrix
        list of rows, each row a list of elevation numbers
    """
    path = f"./{dataset}.txt"
    with open(path, "r") as handle:
        lines = handle.read().strip().split("\n")

    matrix = []
    for line in lines:
        matrix.append(list(map(int, line)))
    return matrix

def find(grid):
    """
    Find all the low points in a grid, where a low point is defined as
    being strictly lower than each of its (up to four) orthogonally
    adjacent points.

    Neighbours that fall outside the grid are simply ignored, so edge and
    corner cells only need to beat the neighbours they actually have. This
    works for any integer heights, not just single digits.

    Parameters
    ----------
    grid : list of list of int
        Matrix of heights indexed as grid[y][x]; the width is taken from
        the first row. An empty grid yields nothing.

    Yields
    ------
    x, y, height
        Low point in the grid, in row-major order
    """
    if not grid:
        return

    h, w = len(grid), len(grid[0])

    for y in range(h):
        for x in range(w):
            height = grid[y][x]
            # Only in-bounds neighbours take part in the comparison; this
            # replaces a fixed out-of-bounds sentinel that broke for heights >= 10.
            neighbours = (
                grid[ny][nx]
                for nx, ny in ((x - 1, y), (x + 1, y), (x, y - 1), (x, y + 1))
                if 0 <= nx < w and 0 <= ny < h
            )
            if all(height < other for other in neighbours):
                yield (x, y, height)
            
def main():
    """Solve part one: sum the risk level (height + 1) of every low point in ./input.txt."""
    heights = load("input")
    risk_levels = [height + 1 for _x, _y, height in find(heights)]
    return sum(risk_levels)

main()

425

In [2]:
def explore(position, grid):
    """
    Flood-fill the basin that contains a starting position.

    Starting from ``position``, repeatedly step to orthogonally adjacent
    cells that lie inside the grid and whose height is not 9, collecting
    every cell reachable this way. Height-9 cells act as walls between
    basins and are never entered; the start itself is always included.

    Parameters
    ----------
    position : tuple of int
        Starting point as an (x, y) pair
    grid : list of list of int
        Matrix of heights indexed as grid[y][x]

    Returns
    -------
    set
        set of points in the basin (as (x, y) tuples), including the start
    """
    x0, y0 = position
    start = (x0, y0)

    pending = [start]
    visited = {start}

    while pending:
        x, y = pending.pop()

        for dx, dy in ((-1, 0), (1, 0), (0, -1), (0, 1)):
            nx, ny = x + dx, y + dy

            # Explicit bounds check: a negative index would silently wrap
            # to the other side of the grid in Python.
            if not (0 <= ny < len(grid) and 0 <= nx < len(grid[ny])):
                continue

            if grid[ny][nx] == 9 or (nx, ny) in visited:
                continue

            # Mark on push so a cell is never queued twice.
            visited.add((nx, ny))
            pending.append((nx, ny))

    return visited

def main():
    """
    Solve part two: multiply together the sizes of the three largest basins.

    Every low point reported by ``find`` seeds one basin; its size is the
    number of cells ``explore`` reaches from it. If there are fewer than
    three basins, the product of whatever exists is returned (1 if none).
    """
    import heapq

    grid = load("input")

    # Only the sizes matter, so the basin origins are not kept.
    sizes = (len(explore((x, y), grid)) for x, y, _height in find(grid))

    product = 1
    for size in heapq.nlargest(3, sizes):
        product *= size

    return product

main()

1135260