In [1]:
# Imports & read file
import time
import itertools

def read_file(filename, dims=3):
    """Parse a puzzle grid into the set of initially active cell coordinates.

    Every '#' found at column x of (whitespace-stripped) line y becomes an
    active cell (x, y, 0, ..., 0), zero-padded to `dims` coordinates.

    Args:
        filename: path to a text file made of '#'/'.' rows.
        dims: number of coordinates per cell; must be at least 2. The default
            of 3 keeps the original (x, y, 0) output.

    Returns:
        A set of `dims`-tuples of ints, one per '#' character.

    Raises:
        ValueError: if `dims` is smaller than 2.
    """
    if dims < 2:
        raise ValueError("dims must be at least 2, got {}".format(dims))
    # Higher dimensions of the initial slice are all zero.
    padding = (0,) * (dims - 2)
    with open(filename) as infile:
        lines = [line.strip() for line in infile]
    return {(x, y) + padding
            for y, row in enumerate(lines)
            for x, char in enumerate(row)
            if char == '#'}

In [2]:
# Part 1
def neighbors(x, y, z):
    """Return the set of the 26 cells adjacent to (x, y, z), excluding (x, y, z) itself."""
    offsets = itertools.product((-1, 0, 1), repeat=3)
    return {(x + dx, y + dy, z + dz)
            for dx, dy, dz in offsets
            if (dx, dy, dz) != (0, 0, 0)}

def cycle(space, cycles=6, neighbor_fn=None):
    """Advance the Conway-Cubes automaton for `cycles` generations.

    A cell is active in the next generation if it has exactly 3 active
    neighbours, or if it is already active and has exactly 2.

    Args:
        space: set of active cell coordinate tuples.
        cycles: number of generations to simulate.
        neighbor_fn: callable taking a cell's coordinates as positional
            arguments and returning the set of its neighbouring cells.
            Defaults to the 3-D `neighbors`; pass e.g. `neighbors_4d` to run
            the same rules in four dimensions.

    Returns:
        The set of active cells after the last generation. An empty input
        yields an empty set; with cycles <= 0 the input object is returned.
    """
    if neighbor_fn is None:
        neighbor_fn = neighbors
    # Neighbour sets are cached across generations; surviving cells reuse them.
    cache = {}
    for _ in range(cycles):
        # Count, for every cell touching an active cell, how many active
        # neighbours it has. Cells absent from `counts` have zero and so
        # cannot be active next generation.
        counts = {}
        for cell in space:
            if cell not in cache:
                cache[cell] = neighbor_fn(*cell)
            for nb in cache[cell]:
                counts[nb] = counts.get(nb, 0) + 1
        space = {cell for cell, n in counts.items()
                 if n == 3 or (n == 2 and cell in space)}
    return space

def print_space(space):
    """Print a 3-D cell set as one '#'/'.' grid per z-layer.

    The grid spans the bounding box of `space`. Layers go from lowest to
    highest z, each headed by "z=<z>" and followed by a blank line. Rows run
    over y and columns over x. An empty `space` prints nothing, since it has
    no bounding box (min/max of an empty set would raise).

    Args:
        space: set of (x, y, z) tuples of active cells.
    """
    if not space:
        return
    xs, ys, zs = zip(*space)
    # Bounds are loop-invariant; compute them once.
    x_range = range(min(xs), max(xs) + 1)
    y_range = range(min(ys), max(ys) + 1)
    for z in range(min(zs), max(zs) + 1):
        print("z=" + str(z))
        for y in y_range:
            print(''.join('#' if (x, y, z) in space else '.' for x in x_range))
        print()

In [3]:
# Test part 1: the worked example must end with 112 active cubes after 6 cycles.
start = time.time()
test_active = len(cycle(read_file("test01.txt")))
print(test_active == 112)
# Fail loudly instead of only printing False on a regression.
assert test_active == 112, "expected 112 active cubes, got " + str(test_active)
time.time() - start

True


0.016699552536010742

In [4]:
# Solve part 1: number of active cubes in the 3-D puzzle input after 6 cycles.
start = time.time()
answer_part1 = len(cycle(read_file("input.txt")))
print(answer_part1)
time.time() - start

362


0.03318643569946289

In [5]:
# Part 2
def read_file_4d(filename):
    """Parse a puzzle grid into active 4-D cell coordinates (x, y, 0, 0).

    Delegates the grid parsing to `read_file`, which yields (x, y, 0) for each
    '#', and appends w = 0. The file format is therefore handled in a single
    place instead of being duplicated.

    Args:
        filename: path to a text file made of '#'/'.' rows.

    Returns:
        A set of (x, y, 0, 0) tuples.
    """
    return {cell + (0,) for cell in read_file(filename)}

# Every offset in {-1, 0, 1}^4 except the all-zero one (80 offsets in total).
adj_4d = {delta for delta in itertools.product((-1, 0, 1), repeat=4) if any(delta)}
def neighbors_4d(x, y, z, w):
    """Return the set of the 80 cells adjacent to (x, y, z, w), excluding the cell itself."""
    return {(x + dx, y + dy, z + dz, w + dw) for dx, dy, dz, dw in adj_4d}

def cycle_4d(space, cycles=6):
    """Advance the 4-D Conway-Cubes automaton for `cycles` generations.

    A cell is active in the next generation if it has exactly 3 active
    neighbours, or if it is already active and has exactly 2. Neighbour sets
    from `neighbors_4d` are cached across generations.

    Args:
        space: set of active (x, y, z, w) tuples.
        cycles: number of generations to simulate.

    Returns:
        The set of active cells after the last generation.
    """
    neighbor_cache = {cell: neighbors_4d(*cell) for cell in space}
    for _ in range(cycles):
        # How many active cells border each candidate cell.
        tally = {}
        for cell in space:
            adjacent = neighbor_cache.get(cell)
            if adjacent is None:
                adjacent = neighbors_4d(*cell)
                neighbor_cache[cell] = adjacent
            for other in adjacent:
                tally[other] = tally.get(other, 0) + 1
        survivors = set()
        for cell, count in tally.items():
            if count == 3 or (count == 2 and cell in space):
                survivors.add(cell)
        space = survivors
    return space

def print_space_4d(space):
    """Print a 4-D cell set as one '#'/'.' grid per (z, w) slice.

    The grid spans the bounding box of `space`. The outer loop runs over w
    and the inner loop over z, both from lowest to highest. Each slice is
    headed by "z=<z>, w=<w>" and followed by a blank line. Rows run over y
    and columns over x. An empty `space` prints nothing, since it has no
    bounding box (min/max of an empty set would raise).

    Args:
        space: set of (x, y, z, w) tuples of active cells.
    """
    if not space:
        return
    xs, ys, zs, ws = zip(*space)
    # Bounds are loop-invariant; compute them once.
    x_range = range(min(xs), max(xs) + 1)
    y_range = range(min(ys), max(ys) + 1)
    z_range = range(min(zs), max(zs) + 1)
    for w in range(min(ws), max(ws) + 1):
        for z in z_range:
            print("z=" + str(z) + ", w=" + str(w))
            for y in y_range:
                print(''.join('#' if (x, y, z, w) in space else '.' for x in x_range))
            print()

In [6]:
# Test part 2: the worked example must end with 848 active hypercubes after 6 cycles.
start = time.time()
test_active_4d = len(cycle_4d(read_file_4d("test01.txt")))
print(test_active_4d == 848)
# Fail loudly instead of only printing False on a regression.
assert test_active_4d == 848, "expected 848 active hypercubes, got " + str(test_active_4d)
time.time() - start

True


0.11667847633361816

In [7]:
# Solve part 2: number of active hypercubes in the 4-D puzzle input after 6 cycles.
start = time.time()
answer_part2 = len(cycle_4d(read_file_4d("input.txt")))
print(answer_part2)
time.time() - start

1980


0.16673493385314941