In [None]:
from collections import namedtuple

class Node(
    namedtuple("Node", ["k", "a", "b", "c", "d", "succ", "value"], defaults=(None, None))
):
    """A canonical quadtree node.

    Fields:
      k      -- level; leaves are level 0 and a level-k node is built from
                four level-(k-1) quadrants.
      a..d   -- the four quadrants (None for leaves).
      succ   -- successor stored by ``join`` when the node is created
                (None where no successor is computed).
      value  -- cell state for leaves (1 live, 0 dead); None for inner nodes.

    In this module every non-leaf node is created through ``join``, which
    returns the cached instance for a given (a, b, c, d). Structurally equal
    nodes are therefore the same object, so identity is used for equality
    and hashing. The default tuple hash would recurse through every field
    (including ``succ``), costing time exponential in ``k``.
    """

    __slots__ = ()

    def __eq__(self, other):
        return self is other

    def __ne__(self, other):
        return self is not other

    def __hash__(self):
        return id(self)


# Memo table used by join(): (a, b, c, d) -> canonical Node.
cache = {}

# The two leaf cells; life() reads their ``value`` as the cell state.
on = Node(0, None, None, None, None, None, 1)
off = Node(0, None, None, None, None, None, 0)

def get_zero(k):
    """Return the canonical all-dead node of level ``k``.

    Level 0 is the ``off`` leaf; each further level is four copies of the
    level below, combined with ``join``.

    Raises:
        ValueError: if ``k`` is negative.
    """
    if k < 0:
        raise ValueError(f"level must be non-negative, got {k}")
    node = off
    # Build upward with a single join per level; the naive recursive form
    # made four recursive calls per level (4**k calls in total).
    for _ in range(k):
        node = join(node, node, node, node)
    return node

def centre(m):
    """Return a node one level above ``m`` with ``m`` placed in its middle.

    Each quadrant of ``m`` goes into the inner corner of the matching
    quadrant of the new node. Every other slot is an all-dead node of the
    same level as ``m``'s quadrants.
    """
    pad = get_zero(m.k - 1)
    top_left = join(pad, pad, pad, m.a)
    top_right = join(pad, pad, m.b, pad)
    bottom_left = join(pad, m.c, pad, pad)
    bottom_right = join(m.d, pad, pad, pad)
    return join(top_left, top_right, bottom_left, bottom_right)


def construct(pattern):
    """Build a quadtree node for a pattern of live cells.

    Args:
        pattern: either an iterable of (x, y) integer coordinates of live
            cells (e.g. a set), or a dict mapping (x, y) to a level-0 node.
            The argument is not modified. Coordinates are shifted so that
            the smallest x and y become 0.

    Returns:
        The node obtained by repeatedly merging 2x2 groups of cells until
        one node remains (at least level 1), then padding it with an empty
        border via two calls to ``centre``. An empty pattern yields the
        all-dead node of the same size a single live cell would produce.
    """
    if isinstance(pattern, dict):
        cells = dict(pattern)  # a mapping supplies its own leaf node per cell
    else:
        cells = dict.fromkeys(pattern, on)

    if not cells:
        return centre(centre(get_zero(1)))

    min_x = min(x for x, _ in cells)
    min_y = min(y for _, y in cells)
    level = {(x - min_x, y - min_y): node for (x, y), node in cells.items()}

    merges = 0
    # Keep merging until a single node remains. Always merge at least once
    # so the root is level >= 1; centre() needs quadrants to work with.
    while len(level) > 1 or merges == 0:
        quads = {}
        for (x, y), node in level.items():
            quad = quads.setdefault((x >> 1, y >> 1), [off, off, off, off])
            # Slot order a, b, c, d: (x, y), (x+1, y), (x, y+1), (x+1, y+1).
            quad[(x & 1) | ((y & 1) << 1)] = node
        level = {pos: join(*quad) for pos, quad in quads.items()}
        merges += 1

    (root,) = level.values()
    return centre(centre(root))
    

def join(a, b, c, d):
    """Return the canonical node with quadrants a, b, c, d.

    The quadrants are top-left, top-right, bottom-left and bottom-right.
    Results are memoised in ``cache``, keyed by the four children, so the
    same children always yield the same Node object. The node's successor
    is computed once, when the node is first created.
    """
    level = a.k + 1
    key = (a, b, c, d)
    node = cache.get(key)
    if node is None:
        node = Node(level, a, b, c, d, succ=succ(level, a, b, c, d))
        cache[key] = node
    return node


def succ(k, a, b, c, d):
    """Return the centre of a level-k node, advanced by one generation.

    The node has quadrants ``a`` (top-left), ``b`` (top-right), ``c``
    (bottom-left) and ``d`` (bottom-right), each of level ``k - 1``. The
    result is the level ``k - 1`` node covering the middle half of it, one
    generation later. Returns None for k < 2, where there is no inner
    region to compute.
    """
    if k < 2:
        return None
    if k == 2:
        return life_4x4(a, b, c, d)

    # Nine overlapping level-(k-1) sub-squares on a 3x3 grid, offset by a
    # quarter of the node's width. Each is advanced via its cached succ,
    # giving level-(k-2) results laid out as:
    #   c1 c2 c3
    #   c4 c5 c6
    #   c7 c8 c9
    # The corner sub-squares are the quadrants themselves, already created
    # (with succ) by join.
    c1 = a.succ
    c2 = join(a.b, b.a, a.d, b.c).succ
    c3 = b.succ
    c4 = join(a.c, a.d, c.a, c.b).succ
    c5 = join(a.d, b.c, c.b, d.a).succ
    c6 = join(b.c, b.d, d.a, d.b).succ
    c7 = c.succ
    c8 = join(c.b, d.a, c.d, d.c).succ
    c9 = d.succ

    # Each result quadrant is built from the inward-facing quadrants of the
    # four c-nodes that overlap it.
    return join(
        join(c1.d, c2.c, c4.b, c5.a),
        join(c2.d, c3.c, c5.b, c6.a),
        join(c4.d, c5.c, c7.b, c8.a),
        join(c5.d, c6.c, c8.b, c9.a),
    )

# Conway's Life rule for a 3x3 neighbourhood of cells, where E is the
# centre cell; returns 1 if E is alive in the next generation, else 0.
def life(a, b, c, d, E, f, g, h, i):
    """Return the next state (1 or 0) of the centre cell ``E``.

    The eight surrounding cells are a, b, c, d, f, g, h and i. Every argument
    must expose a numeric ``value``. The neighbours' values are summed as a
    live-neighbour count, so live cells are expected to carry 1 and dead
    cells 0.
    """
    neighbours = sum(cell.value for cell in (a, b, c, d, f, g, h, i))
    centre_alive = E.value
    if neighbours == 3:
        return 1  # birth, or survival with three neighbours
    if neighbours == 2 and centre_alive:
        return 1  # survival with two neighbours
    return 0


# Take a 4x4 block of cells and compute its inner 2x2 block one generation
# later, by applying the life rule to each of the four 3x3 neighbourhoods.
# E.g. AD (cell D of quadrant A) becomes cell A of the result, computed from
# the boxed neighbourhood:
#
# AA   AB   BA | BB
# AC  *AD*  BC | BD
# CA   CB   DA | DB
# -------------+
# CC   CD   DC   DD


def life_4x4(a, b, c, d):
    """Advance the inner 2x2 of a 4x4 block by one generation.

    ``a``..``d`` are the four level-1 quadrants (top-left, top-right,
    bottom-left, bottom-right) of a level-2 node. Each inner cell is computed
    with ``life`` from its 3x3 neighbourhood. Returns the level-1 node built
    from the four new cells.
    """
    na = life(a.a, a.b, b.a, a.c, a.d, b.c, c.a, c.b, d.a)  # AD
    nb = life(a.b, b.a, b.b, a.d, b.c, b.d, c.b, d.a, d.b)  # BC
    nc = life(a.c, a.d, b.c, c.a, c.b, d.a, c.c, c.d, d.c)  # CB
    nd = life(a.d, b.c, b.d, c.b, d.a, d.b, c.d, d.c, d.d)  # DA
    # life() returns plain 0/1 ints, but join() needs nodes: map the results
    # back onto the canonical leaf cells.
    na, nb, nc, nd = (on if alive else off for alive in (na, nb, nc, nd))
    return join(na, nb, nc, nd)