In [1]:
# Sample grid used to check both parts: '@' marks a paper cell; any other
# character is treated as empty by the parsing code below.
test_input = """..@@.@@@@.
@@@.@.@.@@
@@@@@.@.@@
@.@@@@..@.
@@.@@@@.@@
.@@@@@@@.@
.@.@.@.@@@
@.@@@.@@@@
.@@@@@@@@.
@.@.@@@.@."""

## Numpy convolutions ftw

A convolution kernel of 

```
1 1 1
1 0 1
1 1 1
```

will quickly count neighbors of each point in a vectorized way.

Numpy can then filter out non-paper cells and those with too many neighbors.

`scipy.signal.convolve2d` can do this, but it's an order of magnitude slower for this simple input. Instead we can just shift and add, which is fully vectorized in NumPy.

In [3]:
import numpy as np
  
# Parse the sample grid into integers: 1 where the character is '@', 0 elsewhere.
data = np.array([[ch == '@' for ch in row] for row in test_input.split('\n')], dtype=int)

# Surround the grid with a one-cell border of zeros so that edge cells can be
# handled exactly like interior cells.
padded = np.pad(data, 1, mode='constant')

# Add up the eight views of the padded grid shifted by one cell in each
# direction (the centre offset is skipped, so a cell never counts itself).
neighbors = sum(
    padded[1 + dr:1 + dr + data.shape[0], 1 + dc:1 + dc + data.shape[1]]
    for dr in (-1, 0, 1)
    for dc in (-1, 0, 1)
    if (dr, dc) != (0, 0)
)

# Flag the occupied cells that have fewer than four occupied neighbours.
mask = np.logical_and(data != 0, neighbors < 4).astype(int)
mask

array([[0, 0, 1, 1, 0, 1, 1, 0, 1, 0],
       [1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [1, 0, 0, 0, 0, 0, 0, 0, 0, 1],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [1, 0, 1, 0, 0, 0, 0, 0, 1, 0]])

The count of ones is the direct answer to part one:


In [4]:
# Number of cells flagged in `mask` (occupied, with fewer than four occupied neighbours).
np.count_nonzero(mask)

13

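For part two, repeat the removal: zero out the counted cells, recount the neighbors, and keep going until no more cells can be removed. The answer is the total number of cells removed.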


In [5]:
# A paper cell is removable when it has fewer than this many paper neighbours.
_NEIGHBOR_LIMIT = 4


def _parse_grid(s):
    """Convert the puzzle text into a 2-D ``int8`` array (1 for '@', 0 otherwise).

    Leading/trailing line breaks are dropped and ``str.splitlines`` is used so
    that input ending in a newline (as most text files do) does not yield an
    empty, ragged last row. Empty input gives a ``(0, 0)`` array.
    """
    rows = [[c == '@' for c in line] for line in s.strip('\r\n').splitlines()]
    if not rows:
        # Keep the result 2-D so the neighbour slicing still works.
        return np.zeros((0, 0), dtype=np.int8)
    return np.array(rows, dtype=np.int8)


def sum_neighbors(data):
    """Return, for every cell of a 2-D array, the sum of its 8 surrounding cells.

    Cells outside the grid count as 0. The result has the same shape as
    ``data``.
    """
    # A one-cell border of zeros lets edge cells be handled like interior ones.
    p = np.pad(data, 1, mode='constant')

    # Sum all 8 neighbors by shifting the padded array.
    # This is just a convolution with a 3x3 kernel of ones (zero centre),
    # written as eight shifted views cropped back to the original shape.
    neighbors = (p[:-2, :-2] + p[:-2, 1:-1] + p[:-2, 2:]  +
                 p[1:-1, :-2]               + p[1:-1, 2:] +
                 p[2:, :-2]  + p[2:, 1:-1]  + p[2:, 2:])
    return neighbors


def _removable(data):
    """Boolean mask of paper cells with fewer than ``_NEIGHBOR_LIMIT`` paper neighbours."""
    return (data == 1) & (sum_neighbors(data) < _NEIGHBOR_LIMIT)


def part_one(s):
    """Count the paper cells ('@') in ``s`` that are removable right away."""
    return np.count_nonzero(_removable(_parse_grid(s)))


def part_two(s):
    """Repeatedly remove removable paper cells; return how many were removed in total.

    Each round recomputes the neighbour counts on the current grid, clears
    every removable cell at once, and stops when a round removes nothing.
    """
    data = _parse_grid(s)
    removed = 0

    while True:
        to_remove = _removable(data)

        if not to_remove.any():
            break

        removed += int(np.count_nonzero(to_remove))
        data[to_remove] = 0

    return removed

# Run both parts on the sample grid as a quick check.
part_one(test_input), part_two(test_input)

(13, 43)

In [6]:
# Read the full puzzle input as a single string. Note that this rebinds the
# global name `data` to that string; the timing cell below relies on it.
with open('input_files/day_04.txt') as f:
    data = f.read()
    
# Solve both parts on the full input.
part_one(data), part_two(data)

(1441, 9050)

In [7]:
# Time part two on the full input string held in `data`.
%timeit part_two(data)

3.24 ms ± 20.8 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
