In [14]:
import numpy as np
import six
from itertools import product

In [124]:
# Input grid for the solver: one string per row, where '#' is an active cube
# and '.' an inactive one; rows are joined into one newline-separated block.
DATA = "\n".join([
    "#.#..#.#",
    "#.......",
    "####..#.",
    ".#.#.##.",
    "..#..#..",
    "###..##.",
    ".#..##.#",
    ".....#..",
])

In [119]:
# Small 3x3 example grid, kept under its own name. Binding it to ``DATA``
# made the notebook order-dependent: on Restart & Run All this cell would
# overwrite the 8x8 input defined above, and the solver would run on the
# example instead. Pass it explicitly, e.g. ``Solver(EXAMPLE_DATA)``, to
# sanity-check the solver.
EXAMPLE_DATA = """
.#.
..#
###
""".strip()

In [125]:
class Solver:
    """Conway-cube simulation on an unbounded 4-D grid.

    ``layout`` maps ``(x, y, z, w)`` coordinates to ``True`` (active) or
    ``False`` (inactive). Any coordinate absent from ``layout`` is treated
    as inactive.
    """

    # Offsets to the 80 neighbours of a cell in 4-D: every -1/0/+1
    # combination per axis except the all-zero offset (the cell itself).
    # Built once here instead of on every ``neighbours`` call.
    _OFFSETS = tuple(
        delta for delta in product((-1, 0, +1), repeat=4)
        if delta != (0, 0, 0, 0)
    )

    def __init__(self, layer):
        """Seed the grid from a 2-D text slice placed at z = 0, w = 0.

        :param layer: newline-separated rows; ``'#'`` marks an active cube,
            any other character an inactive one. Row index is ``y``,
            column index is ``x``.
        """
        self.layout = {}
        for y, row in enumerate(layer.splitlines()):
            for x, value in enumerate(row):
                self.layout[(x, y, 0, 0)] = value == '#'

    def read(self, x, y, z, w):
        """Return whether the cube at (x, y, z, w) is active (absent => False)."""
        return self.layout.get((x, y, z, w), False)

    def flip(self, x, y, z, w):
        """Toggle the cube at (x, y, z, w).

        Goes through ``read`` so a coordinate not yet tracked in ``layout``
        is treated as inactive (and becomes active) instead of raising
        ``KeyError``.
        """
        self.layout[(x, y, z, w)] = not self.read(x, y, z, w)

    def neighbours(self, x, y, z, w):
        """Yield ``((nx, ny, nz, nw), is_active)`` for each of the 80 neighbours."""
        for dx, dy, dz, dw in self._OFFSETS:
            nx, ny, nz, nw = x + dx, y + dy, z + dz, w + dw
            yield (nx, ny, nz, nw), self.read(nx, ny, nz, nw)

    def step(self):
        """Advance the simulation by one cycle, in place.

        Rules: an active cube stays active with 2 or 3 active neighbours,
        otherwise it deactivates; an inactive cube activates with exactly
        3 active neighbours.
        """
        # Track (as inactive) every untracked neighbour of an active cube.
        # Activation needs 3 active neighbours, so no other untracked cell
        # can change state this cycle.
        untracked = {
            coords
            for cell, active in self.layout.items()
            if active
            for coords, _ in self.neighbours(*cell)
            if coords not in self.layout
        }
        for coords in untracked:
            self.layout[coords] = False

        # Decide every flip against the current state before applying any,
        # so the whole grid updates simultaneously.
        to_flip = set()
        for cell, active in self.layout.items():
            active_neighbours = sum(
                is_active for _, is_active in self.neighbours(*cell)
            )
            if active and active_neighbours not in (2, 3):
                to_flip.add(cell)
            elif not active and active_neighbours == 3:
                to_flip.add(cell)

        for cell in to_flip:
            self.flip(*cell)

    def count(self):
        """Return the number of active cubes."""
        return sum(self.layout.values())

In [126]:
# Seed the 4-D grid from DATA, run the fixed number of cycles, and print how
# many cubes are active at the end. The cycle count is named rather than
# hard-coded so it is easy to find and change.
BOOT_CYCLES = 6
z = Solver(DATA)
for _ in range(BOOT_CYCLES):
    z.step()
print(z.count())

1836
