In [3]:
#!/usr/bin/env python3
# Implementing Algorithm 1 from the tech report: https://www2.eecs.berkeley.edu/Pubs/TechRpts/2021/EECS-2021-10.pdf
# paper (seems very different from tech report): https://arxiv.org/pdf/2107.13477
# Advanced tutorial (Z3): https://ece.uwaterloo.ca/~agurfink/stqam/z3py-advanced
# Reference manual (Z3): https://z3prover.github.io/api/html/namespacez3py.html#ab8fb082f1c350596ec57ea4a4cd852fd
from z3 import *

def smt(e): 
    """Run expression e in an SMT Solver and return the answer"""
    s = Solver()
    s.add([e])
    return s

# https://stackoverflow.com/a/12600208/5305365 
def z3_to_py(v):
    """Convert Z3 value to corresponding python value"""
    if is_bool(v):
        return is_true(v)
    if is_int_value(v):
        return v.as_long()
    raise RuntimeError("unknown z3 value to be coerced |%s|" % (v, ))

#thetai: uninterpreted function
#Ii: oracle function
#rho: expression
#model: SMT model
def consistent(thetai, Ii, rho, model): 
    """
    Check if all applocations uninterpreted function thetai (representing oracle Ii) agrees with Ii
    in all invocations of thetai in constraint rho with respect to the current model
    """
    alpha_total = True 
    # for app in applications of thetai in rho:
    for app in rho.children():
        if is_app(app) and app.decl() == thetai: 
                model_val = model.eval(app)
                arg_vals_in_model = [z3_to_py(model.eval(arg)) for arg in app.children()]
                oracle_val = Ii(*arg_vals_in_model)
                model_app_val = model.eval(thetai(*app.children()))

                # current constraint
                concrete_app = thetai(*arg_vals_in_model)
                alpha_new = concrete_app == oracle_val
                # print("oracle_val: %s | appval: %s ~= %s | not-equal? %s" % (oracle_val, model_app_val, z3_to_py(model_app_val), oracle_val != z3_to_py(model_app_val)))
                if oracle_val != z3_to_py(model_app_val):
                    # TODO: consider And(alpha_total, alpha_new)
                    # print("early returning")
                    return False, alpha_new
                alpha_total = And(alpha_total, alpha_new)
                # alphacur = 
                # print("app: %s | modelval: %s(%s) = %s | oracleval: %s " % (app, appval , arg_vals_in_model, model_val, oracle_val))
        is_consistent, alpha_new = consistent(thetai, Ii, app, model)
        if not is_consistent:
            return is_consistent, alpha_new
        else:
            alpha_total = And(alpha_total, alpha_new)
    return True, alpha_total

def check(xs, thetas, rho, Is):
    """create a model given variables xs, uninterpreted functions thetas, constraint rho, oracles Is"""
    alpha = True
    while True:
        success = True
        s = smt(And(rho,  alpha)) 
        if s.check() == unsat:
            return (False, alpha)
        else:
            model = s.model()
            for i in range(len(thetas)):
                is_consistent, alphanew = consistent(thetas[i], Is[i], rho, model)
                print("is_consistent: %s | alphanew: %s" % (is_consistent, alphanew))
                alpha = And(alpha, alphanew)
                if not is_consistent:
                    success = False
                    break   
            if success:
                return (True, model)


def isPrime(x):
    """Function that is to be modeled"""
    if x < 2: return False
    if x == 2: return True
    for d in range(2, x//2 + 1):
        if x % d == 0:
            return False
    return True
# U = uninterpreted
isPrimeU = Function('isPrime', IntSort(), BoolSort())
x = Int('x')
y = Int('y')
z = Int('z')
query = And(x * y * z == 76, isPrimeU(x) == True, isPrimeU(y) == True, isPrimeU(z) == True) 

# variables: [x, y, z]
# prime uninterpreted function: isPrimeU
# query: find prime factorization of 76 into three primes
# isPrime: real oracle.
success, out = check([x, y, z], [isPrimeU], query, [isPrime])
print("####OUTPUT####\n")
print("success: %s\nout:%s" % (success, out))


queryFalse = And(x * y * z == 100, isPrimeU(x) == True, isPrimeU(y) == True, isPrimeU(z) == True) 
success, out = check([x, y, z], [isPrimeU], queryFalse, [isPrime])
print("****OUTPUT TWO****\n")
print("success: %s\nout:%s" % (success, out))


is_consistent: False | alphanew: isPrime(76) == False
is_consistent: False | alphanew: isPrime(-38) == False
is_consistent: False | alphanew: isPrime(-76) == False
is_consistent: False | alphanew: isPrime(38) == False
is_consistent: True | alphanew: And(And(And(And(True,
                And(And(True,
                        And(And(True,
                                And(And(True, True), True)),
                            True)),
                    True)),
            And(And(And(True, isPrime(19) == True),
                    And(True, True)),
                True)),
        And(And(And(True, isPrime(2) == True),
                And(True, True)),
            True)),
    And(And(And(True, isPrime(2) == True), And(True, True)),
        True))
####OUTPUT####

success: True
out:[x = 19,
 z = 2,
 y = 2,
 isPrime = [else -> Or(Var(0) == 2, Var(0) == 19)]]
is_consistent: False | alphanew: isPrime(100) == False
is_consistent: False | alphanew: isPrime(-50) == False
is_consistent: False | 