## Part 1: total risk level of the low points ##

In [1]:
# Worked example grid, one string per row (splitlines() yields the same
# five rows as splitting on '\n', since there is no trailing newline).
testdata = """2199943210
3987894921
9856789892
8767896789
9899965678""".splitlines()
testdata

['2199943210', '3987894921', '9856789892', '8767896789', '9899965678']

In [2]:
# Read the puzzle input, one grid row per element. Stripping before
# splitlines() drops the trailing newline (and any trailing blank lines)
# without assuming one is present; the previous `split('\n')[:-1]` silently
# discarded the last real row when the file lacked a final newline.
with open('day9.txt') as fp:
    puzzledata = fp.read().strip().splitlines()
puzzledata[-1]

'7857987631234567934569219865466976521345678967891236987656567943456789999901298798920989323678998989'

In [3]:
def heightmap(data):
    """Parse a grid of digit strings into a height lookup table.

    Args:
        data: sequence of strings, one per grid row; every character is
            converted with ``int``, so each is expected to be a digit.

    Returns:
        ``(nrows, ncols, hmap)`` where ``hmap`` maps ``(row, col)`` to the
        integer height. Empty ``data`` yields ``(0, 0, {})`` instead of
        raising IndexError.

    Raises:
        ValueError: if a character is not a valid integer literal.
    """
    nrows = len(data)
    # Column count comes from the first row; rows are assumed rectangular.
    ncols = len(data[0]) if nrows else 0
    hmap = {
        (i, j): int(ch)
        for i, row in enumerate(data)
        for j, ch in enumerate(row)
    }
    return nrows, ncols, hmap

In [4]:
def isnbor(pos, nrows, ncols):
    """Return True when the (row, col) pair ``pos`` lies inside an
    ``nrows`` x ``ncols`` grid.

    Despite the name, this is purely a bounds check; it does not test
    adjacency to anything.
    """
    row, col = pos
    if not 0 <= row < nrows:
        return False
    return 0 <= col < ncols

In [5]:
def islow(pos, hmap, nrows, ncols):
    """Return True if ``pos`` is strictly lower than every in-bounds
    orthogonal neighbour (equal-height neighbours disqualify it).

    Neighbours are checked in up, down, left, right order; bounds are
    tested with ``isnbor``.
    """
    height = hmap[pos]
    row, col = pos
    for d_row, d_col in ((-1, 0), (1, 0), (0, -1), (0, 1)):
        neighbour = (row + d_row, col + d_col)
        if isnbor(neighbour, nrows, ncols) and hmap[neighbour] <= height:
            return False
    return True

            

In [6]:
def lowpts(hmap, nrows, ncols):
    """Collect every position of ``hmap`` accepted by ``islow``, in the
    dict's iteration order."""
    found = []
    for candidate in hmap:
        if islow(candidate, hmap, nrows, ncols):
            found.append(candidate)
    return found

In [7]:
def totrisk(pts, hmap):
    """Sum the risk levels (height + 1) over the positions in ``pts``.

    ``pts`` may be any iterable of keys into ``hmap``; an empty one gives 0.
    """
    total = 0
    for point in pts:
        total += hmap[point] + 1
    return total

In [8]:
# Parse the example grid once; both parts reuse these three values.
(testrows, testcols, testhmap) = heightmap(testdata)

In [9]:
test_risk = totrisk(lowpts(testhmap, testrows, testcols), testhmap)
# 15 is the recorded result for the example grid; fail loudly on a regression
# instead of relying on someone eyeballing the output after Run All.
assert test_risk == 15, test_risk
test_risk

15

In [10]:
# Parse the real puzzle grid once; both parts reuse these three values.
(puzzlerows, puzzlecols, puzzlehmap) = heightmap(puzzledata)

In [11]:
# Part 1 answer: total risk over all low points of the puzzle grid.
totrisk(pts=lowpts(puzzlehmap, puzzlerows, puzzlecols), hmap=puzzlehmap)

548

## Part 2: product of the three largest basin sizes ##

In [12]:
def walk(pos, hmap, basin, seen=None):
    """Extend ``basin`` with every position reachable from ``pos`` by
    strictly uphill orthogonal steps that never enter a height-9 cell.

    Args:
        pos: (row, col) position to expand from; must be a key of ``hmap``.
        hmap: dict mapping (row, col) -> int height. Off-grid positions are
            simply absent from it.
        basin: list of positions already in the basin; extended in place.
        seen: optional set mirroring ``basin`` for O(1) membership tests.
            Built from ``basin`` when omitted; callers normally leave it out.

    Returns:
        The same ``basin`` list object, with new positions appended in
        depth-first discovery order.
    """
    if seen is None:
        # A membership test on the list was O(len(basin)) per neighbour,
        # making basin growth quadratic; a parallel set keeps it O(1).
        seen = set(basin)
    r, c = pos
    h = hmap[pos]
    for nb in ((r - 1, c), (r + 1, c), (r, c - 1), (r, c + 1)):
        if nb not in hmap or nb in seen:
            continue
        nbh = hmap[nb]
        # Height-9 cells belong to no basin; only strictly higher cells are
        # followed, so equal-height neighbours are not absorbed.
        if nbh == 9 or nbh <= h:
            continue
        basin.append(nb)
        seen.add(nb)
        # Each step strictly increases height, so with single-digit heights
        # (as produced by parsing one character per cell) recursion stays
        # only a handful of levels deep.
        walk(nb, hmap, basin, seen)
    return basin

In [13]:
def getbasin(pt, hmap):
    """Return the basin of low point ``pt`` as a list of positions.

    The list starts with ``pt`` itself and is extended by ``walk``.
    """
    seed = [pt]
    return walk(pt, hmap, seed)

In [14]:
import math

In [15]:
def solve(hmap, nrows, ncols):
    """Multiply together the sizes of the (up to) three largest basins.

    One basin is grown from each low point; with no low points the empty
    product 1 is returned.
    """
    sizes_desc = sorted(
        (len(getbasin(low, hmap)) for low in lowpts(hmap, nrows, ncols)),
        reverse=True,
    )
    return math.prod(sizes_desc[:3])

In [16]:
test_basin_product = solve(testhmap, testrows, testcols)
# 1134 is the recorded result for the example grid; guard against regressions
# in the basin helpers when the notebook is re-run from scratch.
assert test_basin_product == 1134, test_basin_product
test_basin_product

1134

In [17]:
# Part 2 answer: product of the three largest basin sizes in the puzzle grid.
solve(hmap=puzzlehmap, nrows=puzzlerows, ncols=puzzlecols)

786048