In [6]:
import z3
from golz3 import *
from typing import *
import itertools

# A value usable wherever an integer is expected: either a symbolic Z3
# arithmetic expression or a concrete Python int.
Inty = ArithRef | int

# Signature of a Game of Life rule sketch: func(count, pre) -> post_state
#   count      -- number of alive neighbours of the cell.
#   pre        -- 1 if the cell was alive on the previous turn, 0 otherwise.
#   post_state -- 1 if the cell should be alive now, 0 otherwise.
# The sketch may contain free Z3 variables ("holes") for a solver to fill in.
FuncSketch = Callable[[Inty, Inty], ArithRef]

def find_rules(func_sketch: FuncSketch, pre: ConcSlice, post: ConcSlice) -> z3.Solver:
    """
    Fill in holes in the provided Game of Life transition function sketch
    that would match the provided pre and post states.

    "pre" and "post" must be rectangular 2D matrices of the same shape whose
    entries are concrete 0/1 values. Cells beyond the board edge count as
    dead (the board does not wrap around).

    Args:
        func_sketch (FuncSketch): Function that takes in a neighbour count and
            a pre-state and returns a Z3 expression which should be in the
            range [0, 1] (representing the new cell state). Its holes are the
            free Z3 variables appearing in that expression.
        pre (ConcSlice): Concrete 0/1 values of the board BEFORE the step.
        post (ConcSlice): Concrete 0/1 values of the board AFTER the step.

    Returns:
        (Solver): solver that's been set up to solve the problem. Just run
            `check` on it. (Or inspect any values.) Empty boards yield a
            solver with no constraints.

    Raises:
        ValueError: If the boards are not rectangular, differ in shape, or
            contain a cell that is not 0 or 1.
    """
    nrows = len(pre)
    ncols = len(pre[0]) if nrows else 0

    # Validate both boards before building any constraints. The boards are
    # concrete, so these are plain Python checks rather than solver queries.
    for board in (pre, post):
        if len(board) != nrows or any(len(row) != ncols for row in board):
            raise ValueError("pre and post must be rectangular boards of the same shape")
        for row in board:
            for cell in row:
                if cell != 0 and cell != 1:
                    raise ValueError("All vars must be 0 or 1")

    solver = z3.Solver()
    neighbour_offsets = [
        (di, dj)
        for di, dj in itertools.product((-1, 0, 1), repeat=2)
        if (di, dj) != (0, 0)  # a cell is not its own neighbour
    ]

    for i, j in itertools.product(range(nrows), range(ncols)):
        # `pre` is concrete, so `count` is a plain int; only the sketch's
        # holes remain symbolic in the constraint below.
        count = sum(
            pre[i + di][j + dj]
            for di, dj in neighbour_offsets
            if 0 <= i + di < nrows and 0 <= j + dj < ncols
        )
        # Require that the sketch maps this cell's PRE state to its POST state.
        solver.add(func_sketch(count, pre[i][j]) == post[i][j])

    return solver


In [11]:
def simple_equals_sketch(count: Inty, pre: Inty,
                         stay_alive_count: Any = Int('stay_alive_count'),
                         come_alive_count: Any = Int('come_alive_count')):
    """
    Sketch of a rule in which a cell's fate depends on an exact neighbour count.

    A cell with pre == 1 survives exactly when `count` equals
    `stay_alive_count`. Any other cell comes alive exactly when `count` equals
    `come_alive_count`.

    The two thresholds are the holes meant to be filled in by a solver. They
    are parameters so concrete values (e.g. taken from a model) can be passed
    back in to print the completed sketch.

    Returns:
        A Z3 `If` expression that evaluates to 1 (alive) or 0 (dead).
    """
    survives = If(count == stay_alive_count, 1, 0)
    is_born = If(count == come_alive_count, 1, 0)
    return If(pre == 1, survives, is_born)

# See what rule it discovers for simple transition.
pre = [
    [1, 1, 1],
    [1, 0, 1],
    [1, 1, 1],
]

post = [
    [0, 0, 0],
    [0, 1, 0],
    [0, 0, 0],
]

# Requirements for this should be that a cell comes alive if count == 8,
# and that stay_alive != 2 (corner cells) and stay_alive != 4 (edge cells).
solver = find_rules(simple_equals_sketch, pre, post) # type: ignore
result = solver.check()
print(result)

if result == z3.sat:
    model = solver.model()
    print(f"Model: {model}")

    # And we should get valid results! Substitute those back into the sketch
    # and we have a filled in synthesis. `model_completion=True` yields a
    # concrete value even for a hole the constraints leave unconstrained
    # (plain `model[...]` would return None in that case).
    stay_alive = model.eval(Int('stay_alive_count'), model_completion=True)
    come_alive = model.eval(Int('come_alive_count'), model_completion=True)
    filled_in_sketch = simple_equals_sketch(Int('count'), Int('pre'), stay_alive, come_alive)
    print(f"Filled in sketch: {filled_in_sketch}")
else:
    # unsat: no instantiation of this sketch explains the transition;
    # unknown: the solver gave up. Either way there is no model to read.
    print("No filled-in sketch: solver did not find a model.")

sat
Model: [stay_alive_count = 0, come_alive_count = 8]
Filled in sketch: If(pre == 1, If(0 == count, 1, 0), If(8 == count, 1, 0))
