In [1]:
import copy

class Grid:
    """Seat-layout cellular automaton read from a text file.

    Cells are single characters: '#' is an occupied seat and 'L' an empty
    seat. Each generation is computed simultaneously from the previous one:

    * an 'L' with no occupied neighbours becomes '#';
    * a '#' with at least ``neighbour_threshold`` occupied neighbours
      becomes 'L';
    * every other character (e.g. floor) never changes.

    What counts as a "neighbour" is defined by ``count_neighbours``, which
    subclasses may override.
    """

    def __init__(self, filename, neighbour_threshold=4):
        """Load the layout from ``filename``, one grid row per line.

        Args:
            filename: path to the text file holding the layout.
            neighbour_threshold: minimum number of occupied neighbours that
                makes an occupied seat become empty.

        Blank lines (such as a trailing empty line) are skipped so they do
        not become zero-width rows; an empty file yields a 0x0 grid.
        """
        with open(filename) as f:
            stripped = (line.strip() for line in f)
            self.grid = [list(row) for row in stripped if row]
        self.height = len(self.grid)
        self.width = len(self.grid[0]) if self.grid else 0
        self.neighbour_threshold = neighbour_threshold

    def __len__(self):
        """Return the number of occupied seats ('#'), not the grid size."""
        return sum(row.count('#') for row in self.grid)

    def __str__(self):
        """Render the grid as text, one newline-terminated line per row."""
        return ''.join(''.join(row) + '\n' for row in self.grid)

    def step(self):
        """Advance the automaton by one generation.

        All cells are evaluated against the current generation before any
        are written, so updates are simultaneous.

        Returns:
            True if at least one cell changed, False if the grid was
            already stable.
        """
        new_grid = [row[:] for row in self.grid]
        changed = False
        for y, row in enumerate(self.grid):
            for x, cell in enumerate(row):
                if cell not in ('L', '#'):
                    continue  # non-seat cells never change; skip the scan
                num_neighbours = self.count_neighbours(x, y)
                if cell == 'L' and num_neighbours == 0:
                    new_grid[y][x] = '#'
                    changed = True
                elif cell == '#' and num_neighbours >= self.neighbour_threshold:
                    new_grid[y][x] = 'L'
                    changed = True
        self.grid = new_grid
        return changed

    def __next__(self):
        """Advance one generation and return ``self``.

        Never raises StopIteration: looping over a Grid runs forever unless
        the caller breaks out. Prefer ``step()`` / ``iterate_until_stable()``.
        """
        self.step()
        return self

    def __iter__(self):
        """Advance one generation and return ``self``.

        Kept for backward compatibility: calling ``iter()`` on a Grid has
        always stepped it once. Prefer ``step()`` in new code.
        """
        self.step()
        return self

    def count_neighbours(self, x, y):
        """Count occupied seats among the up-to-eight cells adjacent to (x, y).

        Cells outside the grid are ignored.
        """
        return sum(
            self.grid[ny][nx] == '#'
            for ny in range(y - 1, y + 2)
            for nx in range(x - 1, x + 2)
            if (nx, ny) != (x, y)
            and 0 <= nx < self.width and 0 <= ny < self.height
        )

    def iterate_until_stable(self):
        """Step the automaton until a generation produces no change.

        Stability is judged by cell states, not by the occupied-seat count:
        two consecutive generations can hold the same number of occupied
        seats while individual seats are still flipping.
        """
        while self.step():
            pass

In [2]:
# Part 1: adjacent-neighbour rule with Grid's default threshold of 4.
grid = Grid('input')
grid.iterate_until_stable()

print("Part 1:", len(grid), sep="\n")

Part 1:
2483


In [3]:
class Grid2(Grid):
    """Grid variant whose neighbours are found by line of sight.

    From a cell, look outward along each of the eight directions; that
    direction counts as an occupied neighbour when the first seat seen is
    '#'. Any character other than '#' or 'L' is looked through, while an
    'L' blocks the view.
    """

    def count_neighbours(self, x, y):
        """Count directions from (x, y) whose first visible seat is occupied."""
        def first_seat(step_x, step_y):
            # Walk along one ray; return the first '#'/'L' hit, or None at the edge.
            cx, cy = x + step_x, y + step_y
            while 0 <= cx < self.width and 0 <= cy < self.height:
                cell = self.grid[cy][cx]
                if cell in ('#', 'L'):
                    return cell
                cx += step_x
                cy += step_y
            return None

        return sum(
            first_seat(step_x, step_y) == '#'
            for step_x in (-1, 0, 1)
            for step_y in (-1, 0, 1)
            if step_x or step_y  # skip the (0, 0) "direction"
        )

In [4]:
# Part 2: line-of-sight neighbours, and seats empty at 5+ occupied neighbours.
grid2 = Grid2('input', neighbour_threshold=5)
grid2.iterate_until_stable()
print("Part 2:", len(grid2), sep="\n")

Part 2:
2285
