In [33]:
from collections import namedtuple
from functools import lru_cache
from itertools import product
from dataclasses import dataclass


def life(a, b, c, d, E, f, g, h, i):
    """Apply Conway's rule (B3/S23) to one cell given its 3x3 neighbourhood.

    The nine arguments are the cells of the neighbourhood in row-major order;
    ``E`` is the centre cell and the other eight are its neighbours.  Each
    argument only needs an ``n`` attribute (1 = live, 0 = dead).

    Returns:
        ``Node.on`` if the centre cell is alive in the next generation,
        otherwise ``Node.off``.  The Node class's canonical leaves are used
        (not bare ``on``/``off`` globals, which this cell never defines) so the
        result can be passed straight to ``Node.join``.
    """
    outer = sum(t.n for t in (a, b, c, d, f, g, h, i))
    # Born/survives with exactly three neighbours; survives with two if alive.
    alive = outer == 3 or (E.n and outer == 2)
    return Node.on if alive else Node.off


class Node:
    """Canonical (hash-consed) quadtree node used by the Hashlife engine.

    A node of level ``k`` covers a ``2**k`` x ``2**k`` square of cells.  Level-0
    nodes are single cells created by ``base`` (``Node.on`` / ``Node.off``);
    every other node is created by ``join`` from four level ``k - 1``
    quadrants ``a`` (top-left), ``b`` (top-right), ``c`` (bottom-left) and
    ``d`` (bottom-right).  ``join`` returns the existing node when called
    again with the same four children, so identical sub-patterns share one
    object and identity comparison is sufficient.

    Instance attributes:
        k: level of the node.
        n: number of live cells inside the square.
        hash: 64-bit structural hash, also returned by ``__hash__``.
        a, b, c, d: quadrant children (absent on level-0 nodes).

    The class-level tables ``Node.cache`` / ``Node.successors`` and the leaves
    ``Node.on`` / ``Node.off`` are assigned after the class body.
    """

    @classmethod
    def base(cls, n):
        """Create and register a level-0 leaf holding ``n`` live cells (0 or 1)."""
        node = cls()
        node.k = 0
        node.n = n
        node.hash = n
        cls.cache[node.hash] = node
        return node

    @classmethod
    def join(cls, a, b, c, d):
        """Return the canonical level ``k + 1`` node with quadrants ``a, b, c, d``.

        Calling it again with the same (identical) four children returns the
        same object.

        Raises:
            ValueError: if the four children are not all on the same level.
        """
        # Key on the children themselves rather than on the 64-bit hash, so a
        # hash collision can never hand back a structurally different node.
        key = (a, b, c, d)
        node = cls.cache.get(key)
        if node is None:
            if not (a.k == b.k == c.k == d.k):
                raise ValueError(
                    f"join() needs children on one level, got {a.k}, {b.k}, {c.k}, {d.k}"
                )
            node = cls()
            node.k = a.k + 1
            node.n = a.n + b.n + c.n + d.n
            node.a, node.b, node.c, node.d = a, b, c, d
            node.hash = (
                node.k
                + 313183 * a.hash
                + 3158351985019 * b.hash
                + 897311087 * c.hash
                + 43184901804 * d.hash
            ) & ((1 << 64) - 1)
            cls.cache[key] = node
        return node

    @classmethod
    def get_zero(cls, k):
        """Return the canonical empty node of level ``k``.

        Raises:
            ValueError: if ``k`` is negative.
        """
        if k < 0:
            raise ValueError(f"level must be non-negative, got {k}")
        # Build bottom-up: one join per level instead of 4**k recursive calls.
        node = cls.off
        for _ in range(k):
            node = cls.join(node, node, node, node)
        return node

    def centre(self):
        """Return a level ``k + 1`` node with this node in its middle, padded with empty cells."""
        z = Node.get_zero(self.k - 1)
        return Node.join(
            Node.join(z, z, z, self.a),
            Node.join(z, z, self.b, z),
            Node.join(z, self.c, z, z),
            Node.join(self.d, z, z, z),
        )

    def subnode(self):
        """Return the level ``k - 1`` node covering the central square of this node."""
        return Node.join(self.a.d, self.b.c, self.c.b, self.d.a)

    def life_4x4(self):
        """Return the central 2x2 (level 1) of this level-2 node, one generation later."""
        assert self.k == 2
        m = self
        # One life() call per central cell (a.d, b.c, c.b, d.a), each given its
        # 3x3 neighbourhood in row-major order.
        na = life(m.a.a, m.a.b, m.b.a, m.a.c, m.a.d, m.b.c, m.c.a, m.c.b, m.d.a)  # AD
        nb = life(m.a.b, m.b.a, m.b.b, m.a.d, m.b.c, m.b.d, m.c.b, m.d.a, m.d.b)  # BC
        nc = life(m.a.c, m.a.d, m.b.c, m.c.a, m.c.b, m.d.a, m.c.c, m.c.d, m.d.c)  # CB
        nd = life(m.a.d, m.b.c, m.b.d, m.c.b, m.d.a, m.d.b, m.c.d, m.d.c, m.d.d)  # DA
        return Node.join(na, nb, nc, nd)

    def successor(self):
        """Return the central level ``k - 1`` square, ``2**(k - 2)`` generations later.

        Results for non-empty nodes are memoised in ``Node.successors``.

        Raises:
            ValueError: for a non-empty node below level 2 (too small to step).
        """
        m = self
        if m in Node.successors:
            return Node.successors[m]
        if m.n == 0:  # an empty square stays empty; m.a is the empty node one level down
            return m.a
        if m.k < 2:
            raise ValueError(f"successor() needs a level >= 2 node, got level {m.k}")
        if m.k == 2:  # base case: apply the Life rule to the 4x4 block directly
            s = m.life_4x4()
        else:
            # Nine overlapping level k-1 squares, each stepped 2**(k-3) generations.
            # The corner squares are the quadrants themselves (join is canonical).
            c1 = m.a.successor()
            c2 = Node.join(m.a.b, m.b.a, m.a.d, m.b.c).successor()
            c3 = m.b.successor()
            c4 = Node.join(m.a.c, m.a.d, m.c.a, m.c.b).successor()
            c5 = Node.join(m.a.d, m.b.c, m.c.b, m.d.a).successor()
            c6 = Node.join(m.b.c, m.b.d, m.d.a, m.d.b).successor()
            c7 = m.c.successor()
            c8 = Node.join(m.c.b, m.d.a, m.c.d, m.d.c).successor()
            c9 = m.d.successor()

            # Each 2x2 group of those results is stepped another 2**(k-3)
            # generations, giving the four quadrants of the answer.
            s = Node.join(
                Node.join(c1, c2, c4, c5).successor(),
                Node.join(c2, c3, c5, c6).successor(),
                Node.join(c4, c5, c7, c8).successor(),
                Node.join(c5, c6, c8, c9).successor(),
            )
        Node.successors[m] = s
        return s

    def __hash__(self):
        # Structural hash. Equality stays identity-based, which is correct
        # because join() never creates two nodes with the same children.
        return self.hash


# Hash-consing table and successor memo shared by every Node.
Node.cache, Node.successors = {}, {}
# Canonical level-0 leaves: a live cell and a dead cell.
Node.on, Node.off = Node.base(1), Node.base(0)

In [9]:
def advance(m, g):
    """Return the central 2**(k-1) x 2**(k-1) square of level-k node ``m``, moved forward in time.

    ``g`` encodes *twice* the number of generations (``warp`` calls
    ``advance(node, n << 1)``).  Level k reads bit ``k - 2`` of ``g``
    (``k_bit``) and, when it is set, applies ``successor`` to four level k-1
    squares.  The lower bits are handled by the recursive calls on the nine
    level k-1 sub-squares.  Assuming ``successor`` advances a level-j node
    ``2**(j-2)`` generations, an even ``g`` therefore advances
    ``(g >> 1) % 2**(k-2)`` generations.

    Requires ``m.k >= 2``; smaller levels raise on the negative shift below.

    NOTE(review): a later cell redefines ``advance(node, n)`` with a different
    contract, which shadows this function once that cell has run.
    NOTE(review): this function is not memoised. Each level makes nine
    recursive calls, pruned only by the empty-node and no-bits-left checks.
    """
    # Bit of g consumed by this level's second stage.
    k_bit = (g >> (m.k - 2)) & 1
    # No set bits at or below k_bit: nothing left to advance, just crop.
    if g & ((1 << (m.k - 1)) - 1) == 0:
        return subnode(m)
    if m.n == 0:  # empty square stays empty
        return m.a
    elif m.k == 2:  # base case
        s = successor(m) if k_bit else subnode(m)
    else:
        # Nine overlapping level k-1 sub-squares, advanced by the lower bits of g.
        c1 = advance(join(m.a.a, m.a.b, m.a.c, m.a.d), g)
        c2 = advance(join(m.a.b, m.b.a, m.a.d, m.b.c), g)
        c3 = advance(join(m.b.a, m.b.b, m.b.c, m.b.d), g)
        c4 = advance(join(m.a.c, m.a.d, m.c.a, m.c.b), g)
        c5 = advance(join(m.a.d, m.b.c, m.c.b, m.d.a), g)
        c6 = advance(join(m.b.c, m.b.d, m.d.a, m.d.b), g)
        c7 = advance(join(m.c.a, m.c.b, m.c.c, m.c.d), g)
        c8 = advance(join(m.c.b, m.d.a, m.c.d, m.d.c), g)
        c9 = advance(join(m.d.a, m.d.b, m.d.c, m.d.d), g)

        if k_bit:
            # This bit is set: step each 2x2 group of results with successor().
            s = join(
                successor(join(c1, c2, c4, c5)),
                successor(join(c2, c3, c5, c6)),
                successor(join(c4, c5, c7, c8)),
                successor(join(c5, c6, c8, c9)),
            )
        else:
            # No extra step: take the centre of each 2x2 group.
            s = join(
                join(c1.d, c2.c, c4.b, c5.a),
                join(c2.d, c3.c, c5.b, c6.a),
                join(c4.d, c5.c, c7.b, c8.a),
                join(c5.d, c6.c, c8.b, c9.a),
            )

    return s

In [10]:
# pre-generate all 4x4 successors
def product_tree(pieces):
    """Return ``join(a, b, c, d)`` for every ordered choice of four ``pieces``.

    The list has ``len(pieces) ** 4`` entries in ``itertools.product`` order
    (the last quadrant varies fastest).
    """
    quads = product(pieces, repeat=4)
    return [join(*quad) for quad in quads]

# Every 2x2 block: the 2**4 = 16 on/off combinations ...
boot_2x2 = product_tree([on, off])
# ... and every 4x4 block built from them (16**4 = 65536), with its successor.
boot_4x4 = product_tree(boot_2x2)
centres = list(map(successor, boot_4x4))

TypeError: unhashable type: 'Node'

In [11]:
def construct(pt_list):
    """Build a quadtree whose live cells are the given ``(x, y)`` points.

    The points are translated so the smallest x and the smallest y are both 0.
    Cells are then merged bottom-up, 2x2 at a time, with missing neighbours
    filled by ``get_zero(k)``, until a single node remains.

    Args:
        pt_list: iterable of integer ``(x, y)`` coordinates of live cells.
            It is read exactly once, so generators are fine.

    Returns:
        The root node; for a single point this is the ``on`` leaf itself.

    Raises:
        ValueError: if ``pt_list`` is empty.
    """
    points = list(pt_list)
    if not points:
        raise ValueError("construct() needs at least one live cell")

    # Force start at (0,0). Generator arguments also work for a single point,
    # unlike min(*[v]), which calls min() on a bare int and raises TypeError.
    min_x = min(x for x, _ in points)
    min_y = min(y for _, y in points)
    pattern = {(x - min_x, y - min_y): on for x, y in points}

    k = 0
    while len(pattern) != 1:
        # bottom-up construction: group level-k nodes into 2x2 blocks
        next_level = {}
        z = get_zero(k)
        while pattern:
            x, y = next(iter(pattern))
            x_q, y_q = x - (x & 1), y - (y & 1)
            # read all 2x2 neighbours, removing them from those left to work
            # through; at least one of these must exist by definition
            a = pattern.pop((x_q, y_q), z)
            b = pattern.pop((x_q + 1, y_q), z)
            c = pattern.pop((x_q, y_q + 1), z)
            d = pattern.pop((x_q + 1, y_q + 1), z)
            next_level[x_q >> 1, y_q >> 1] = join(a, b, c, d)
        # merge at the next level
        pattern = next_level
        k += 1
    return pattern.popitem()[1]

In [12]:
def expand(node, x=0, y=0, clip=None, level=0):
    """List the non-empty squares of ``node`` as ``(x, y, gray)`` triples.

    Args:
        node: quadtree node whose top-left corner sits at ``(x, y)``.
        x, y: offset of the node's top-left corner.
        clip: optional ``(x_min, x_max, y_min, y_max)``. Subtrees whose square
            lies outside it are skipped. Pruning is per node, so points just
            outside the box can still be returned.
        level: output resolution. Each triple summarises one level-``level``
            node (or a smaller node if the tree is shallower), and ``gray`` is
            its fraction of live cells.

    Returns:
        List of ``(x, y, gray)`` with ``gray > 0``, in a/b/c/d quadrant order.
    """
    # Empty subtrees contribute nothing. Stop here instead of walking every
    # empty descendant down to `level`, which costs 4**k for an empty level-k node.
    if node.n == 0:
        return []
    size = 2 ** node.k
    # bounds check
    if clip is not None:
        if x + size < clip[0] or x > clip[1] or y + size < clip[2] or y > clip[3]:
            return []
    if node.k <= level:
        # base case: return the gray level of this node
        return [(x, y, node.n / (size ** 2))]
    # return all points contained inside this cell
    offset = size >> 1
    return (
        expand(node.a, x=x, y=y, clip=clip, level=level)
        + expand(node.b, x=x + offset, y=y, clip=clip, level=level)
        + expand(node.c, x=x, y=y + offset, clip=clip, level=level)
        + expand(node.d, x=x + offset, y=y + offset, clip=clip, level=level)
    )

In [13]:
def print_points(points):
    """Print ``(x, y, gray)`` points as ``*`` characters on a text grid.

    Rows are increasing ``y`` (top to bottom) and columns are increasing ``x``.
    The gray value is ignored, coordinates are assumed non-negative, and no
    trailing newline is printed.
    """
    px, py = 0, 0  # current cursor column / row
    for x, y, _gray in sorted(points, key=lambda p: (p[1], p[0])):
        while y > py:
            print()
            py += 1
            px = 0
        while x > px:
            print(" ", end="")
            px += 1
        print("*", end="")
        # The '*' fills column x, so the cursor moves one column right; without
        # this, neighbouring cells on a row were printed with a gap between them.
        px += 1
    

In [14]:
# Five-cell test pattern, also used by the profiling cell below.
qtree = construct([(0, 0), (1, 0), (2, 0), (0, 1), (2, 2)])
print_points(expand(qtree))

TypeError: unhashable type: 'Node'

In [14]:
def ffwd(node, n):
    """Take ``n`` successive Hashlife leaps starting from ``node``.

    Before each leap the node is grown with ``centre`` until it is at least
    level 3 and each quadrant's live cells all sit in its innermost corner
    node (``a.d.d``, ``b.c.c``, ``c.b.b``, ``d.a.a``), i.e. the pattern is
    confined to the middle of the node. Then ``successor`` is applied.
    The generations covered by each leap depend on the node level at that
    moment (presumably ``2**(k-2)`` for level k), and the total is not returned.
    """
    for _ in range(n):
        while True:
            padded = (
                node.k >= 3
                and node.a.n == node.a.d.d.n
                and node.b.n == node.b.c.c.n
                and node.c.n == node.c.b.b.n
                and node.d.n == node.d.a.a.n
            )
            if padded:
                break
            node = centre(node)
        node = successor(node)
    return node

def warp(node, n):
    """Advance the pattern in ``node`` by ``n`` generations in a single pass.

    The node is first padded with ``centre`` so the pattern cannot reach the
    edge of the computed region within ``n`` generations. It is then handed
    to ``advance``, and the cropped result is returned.

    NOTE(review): ``advance`` is looked up at call time. This call follows the
    ``advance(m, g)`` defined in an earlier cell, where ``g`` is twice the
    generation count. A later cell redefines ``advance(node, n)`` with a
    different contract, and that version shadows this one once its cell has run.
    """
    # Grow until 2**k >= 2*n. After the two extra centre() calls the root is
    # at level k + 2, so advance() covers n mod 2**k = n generations. Assuming
    # centre() places the node in the middle of a node one level up, the
    # returned square (side 2**(k+1)) also keeps a margin of at least
    # 2**(k-1) >= n cells around the original node. A pattern spreading at
    # most one cell per generation cannot cross that margin.
    # The old test `(2 << k) < n` stopped too early and silently dropped generations.
    while (1 << node.k) < 2 * n:
        node = centre(node)
    node = centre(centre(node))
    return advance(node, n << 1)

In [15]:
%load_ext line_profiler

The line_profiler extension is already loaded. To reload it, use:
  %reload_ext line_profiler


In [25]:
           
# Profile nine ffwd() leaps of the test pattern with cProfile (IPython %prun).
# Disabled alternative benchmark: expand(warp(qtree, 2048))
%prun ffwd(qtree, 9)

 

In [626]:
# NOTE(review): this needs a module-level, lru_cache-wrapped `join` (the saved
# output below is a functools CacheInfo). No visible cell defines one, and
# Node.join has no cache_info(), so this fails on a fresh kernel.
join.cache_info()

CacheInfo(hits=331840, misses=65871, maxsize=1048576, currsize=65871)

In [333]:
def validate_tree(qtree):
    """Recursively assert that ``qtree`` is a well-formed quadtree.

    For every node above level 0, this checks that:
    - the node and its four children are instances of a class named ``Node``;
    - all four children sit exactly one level below it;
    - its live-cell count ``n`` equals the sum of its children's counts.

    Raises:
        AssertionError: on the first violation found.
    """
    if qtree.k > 0:
        kids = (qtree.a, qtree.b, qtree.c, qtree.d)
        # Classes are matched by name, so any class called Node passes
        # (e.g. one redefined by re-running its cell).
        assert type(qtree).__name__ == "Node"
        for kid in kids:
            assert type(kid).__name__ == "Node"
        # Every child, including d, must be one level down (b was checked twice before).
        assert all(kid.k == qtree.k - 1 for kid in kids)
        assert qtree.n == sum(kid.n for kid in kids)
        for kid in kids:
            validate_tree(kid)

In [334]:
# Sanity check: four cells on a diagonal should print as a staircase.
print_points(
    expand(construct([(1, 1), (2, 2), (3, 3), (4, 4)]))
)

*
 *
  *
   *

In [335]:
def advance_by_bits(node, bits):
    """Unfinished: advance ``node`` by the generation count encoded in ``bits``.

    Args:
        node: root node; ``len(bits)`` must equal ``node.k``.
        bits: one 0/1 flag per level.

    Raises:
        AssertionError: if ``len(bits) != node.k``.
        NotImplementedError: always otherwise, because the per-bit recursion
            has not been written yet.
    """
    assert len(bits) == node.k
    # The previous draft had two empty branches on bits[0] and then compared
    # against, and called successor() on, an undefined name `k`, so every call
    # died with NameError. Fail explicitly until the algorithm exists.
    raise NotImplementedError("advance_by_bits is not implemented yet")

def advance(node, n):
    """Pad ``node`` by one ``centre`` level per bit of ``n`` and hand it to ``advance_by_bits``.

    ``bits`` is least-significant bit first (``bits[i]`` is bit ``i`` of
    ``n``). It is padded with high-order zeros up to ``node.k`` entries, the
    length ``advance_by_bits`` asserts. ``n == 0`` returns ``node`` unchanged.

    NOTE(review): this redefines the ``advance(m, g)`` from an earlier cell,
    whose ``g`` is twice the generation count (as ``warp`` relies on). Once
    this cell has run, ``warp`` calls this version instead.
    NOTE(review): ``advance_by_bits`` is currently a stub, so any ``n > 0`` raises.
    """
    if n == 0:
        return node
    bits = []
    while n > 0:
        bits.append(n & 1)
        n >>= 1
        node = centre(node)
    # Pad the high end; prepending zeros shifted every bit of n upwards.
    bits.extend([0] * (node.k - len(bits)))
    # Return the advanced result instead of discarding it.
    return advance_by_bits(node, bits)

In [482]:
@lru_cache(maxsize=2**20)
def next_gen(m):
    """Return the central 2**(k-1) x 2**(k-1) square of level-k node ``m``, half a step ahead.

    Unlike a full Hashlife successor, this applies ``successor`` only to the
    nine overlapping level k-1 sub-squares and then crops the centre of each
    2x2 group without stepping again. Assuming ``successor`` advances a
    level-j node ``2**(j-2)`` generations, the result is ``2**(k-3)``
    generations ahead for ``k >= 3``. For ``k == 2`` it is whatever
    ``life_4x4`` returns (one generation for the Node-class version).

    Results are cached by ``lru_cache``, so ``m`` must be hashable.
    """
    if m.n == 0:  # empty square stays empty
        return m.a
    if m.k == 2:  # base case: apply the Life rule directly
        return life_4x4(m)

    c1 = successor(join(m.a.a, m.a.b, m.a.c, m.a.d))
    c2 = successor(join(m.a.b, m.b.a, m.a.d, m.b.c))
    c3 = successor(join(m.b.a, m.b.b, m.b.c, m.b.d))
    c4 = successor(join(m.a.c, m.a.d, m.c.a, m.c.b))
    c5 = successor(join(m.a.d, m.b.c, m.c.b, m.d.a))
    c6 = successor(join(m.b.c, m.b.d, m.d.a, m.d.b))
    c7 = successor(join(m.c.a, m.c.b, m.c.c, m.c.d))
    c8 = successor(join(m.c.b, m.d.a, m.c.d, m.d.c))
    c9 = successor(join(m.d.a, m.d.b, m.d.c, m.d.d))

    # Crop: the centre of each 2x2 group of results, with no further step.
    return join(
        join(c1.d, c2.c, c4.b, c5.a),
        join(c2.d, c3.c, c5.b, c6.a),
        join(c4.d, c5.c, c7.b, c8.a),
        join(c5.d, c6.c, c8.b, c9.a),
    )