In [1]:
from z3 import *
import z3
import pandas as pd
import time
from tqdm import tqdm
from concurrent.futures import ProcessPoolExecutor
from itertools import islice

In [2]:
def parse_queries_from_klee_smt2_dump(path: str):
    queries = []
    with open(path) as f:
        current = ""
        for line in f:
            if line == "(check-sat)":
                continue
            if line.strip() == "(reset)":
                queries.append(z3.parse_smt2_string(current))
                current = ""
            current += line
    return queries        

In [12]:
queries = parse_queries_from_klee_smt2_dump("echo-cache.smt2")
print(f"{len(queries)} queries were loaded.")

4302 queries were loaded.


In [3]:
def calc_distance_keeping_constraint_order(first: list, second: list, pop_cost, push_cost):
    """
    Calculates a distance between two queries while not trying to change the order
    of constraints inside them.
    The returned distance mimics the cost of pushing and popping when the second query
    is solved right after the first query.
    """
    i = 0
    while i < len(first) and i < len(second) and first[i].eq(second[i]):
        i += 1

    return (i, pop_cost(len(first) - i) + push_cost(len(second) - i))
    # If you believe that there's no cost in popping, you can use the following
    # formula. But personally, I see each pop as a negative thing and losing a
    # progress.
    # return (i, len(second) - i)

def calc_distances_for(query_index: int, queries, calc_distance, pop_cost, push_cost):
    query = queries[query_index]
    distances = [0] * len(queries)
    for j in range(len(queries)):
        distances[j] = calc_distance(query, queries[j], pop_cost, push_cost)
    return distances

pop_cost = lambda x: x
push_cost = lambda x: x

def calc_distances_for_map(query_index: int):
    return calc_distances_for(query_index, queries, calc_distance_keeping_constraint_order, pop_cost, push_cost)

In [13]:
with ProcessPoolExecutor() as executor:
    distances = list(tqdm(executor.map(calc_distances_for_map, range(len(queries))), total=len(queries)))

common_prefix_lens = [[col[0] for col in row] for row in distances]
distances = [[col[1] for col in row] for row in distances]

100%|██████████| 4302/4302 [00:28<00:00, 149.23it/s]


In [5]:
def get_tsp_circuit(distances):
    import ortools.constraint_solver.pywrapcp as pywrapcp
    from ortools.constraint_solver import routing_enums_pb2

    manager = pywrapcp.RoutingIndexManager(len(distances), 1, 0)
    routing = pywrapcp.RoutingModel(manager)

    def distance_callback(from_index, to_index):
        """Returns the distance between the two nodes."""
        # Convert from routing variable Index to distance matrix NodeIndex.
        from_node = manager.IndexToNode(from_index)
        to_node = manager.IndexToNode(to_index)
        return distances[from_node][to_node]

    transit_callback_index = routing.RegisterTransitCallback(distance_callback)
    routing.SetArcCostEvaluatorOfAllVehicles(transit_callback_index)
    search_parameters = pywrapcp.DefaultRoutingSearchParameters()
    search_parameters.first_solution_strategy = (
        routing_enums_pb2.FirstSolutionStrategy.AUTOMATIC)

    def print_solution(manager, routing, solution):
        """Prints solution on console."""
        print('Objective: {} miles'.format(solution.ObjectiveValue()))
        index = routing.Start(0)
        plan_output = 'Route for vehicle 0:\n'
        route_distance = 0
        while not routing.IsEnd(index):
            plan_output += ' {} ->'.format(manager.IndexToNode(index))
            previous_index = index
            index = solution.Value(routing.NextVar(index))
            route_distance += routing.GetArcCostForVehicle(previous_index, index, 0)
        plan_output += ' {}\n'.format(manager.IndexToNode(index))
        print(plan_output)
        plan_output += 'Route distance: {}miles\n'.format(route_distance)

    def get_vector(manager, routing, solution):
        index = routing.Start(0)
        route = []
        while not routing.IsEnd(index):
            route.append(manager.IndexToNode(index))
            index = solution.Value(routing.NextVar(index))

        return route

    solution = routing.SolveWithParameters(search_parameters)
    # if solution:
    #     print_solution(manager, routing, solution)

    return get_vector(manager, routing, solution)

In [7]:
best_case = get_tsp_circuit(distances)

In [11]:
from typing import List


def check_by_resetting(queries, ordering: List[int], common_prefix_lens, enable_direct_subset_answer=False, solver=None):
    solver = solver if solver is not None else Solver()
    results = {}
    
    total_time = 0.0
    
    for index in tqdm(ordering):
        query = queries[index]
        start_time = time.perf_counter()
        result = solver.check(query)
        end_time = time.perf_counter()
        total_time += end_time - start_time
        results[index] = result
    
    return (results, solver.statistics(), total_time, (0, 0, 0))

def check_dummy_incrementally(queries, ordering: List[int], common_prefix_lens, enable_direct_subset_answer=False, solver=None):
    solver = solver if solver is not None else Solver()
    results = {}

    last_query = []

    total_time = 0.0
    total_pops = 0
    total_pushes = 0

    for index in tqdm(ordering):
        if index in results:
            print("Warning: Skipping repeated query. Query index =", index)
            continue
        query = queries[index]
        
        i = 0
        while i < len(last_query) and i < len(query) and last_query[i].eq(query[i]):
            i += 1

        start_time = time.perf_counter()
        
        solver.pop(len(last_query) - i)
        total_pops += len(last_query) - i

        for i in range(i, len(query)):
            solver.push()
            total_pushes += 1
            solver.add(query[i])
        
        result = solver.check()
        end_time = time.perf_counter()
        total_time += end_time - start_time
        results[index] = result

        last_query = query
    
    return (results, solver.statistics(), total_time, (total_pushes, total_pops, sum(len(q) for q in queries)))

def check_incrementally(queries, ordering: List[int], common_prefix_lens, enable_direct_subset_answer=False, solver=None):
    solver = solver if solver is not None else Solver()
    results = {}

    queries = queries + [queries[ordering[-1]]]
    ordering = ordering + [ordering[-1]]
    
    last_index = len(queries) - 1
    current_stack_count = 0

    total_time = 0.0
    total_pops = 0
    total_pushes = 0

    for index, next_index in tqdm(zip(ordering[:len(ordering) - 1], ordering[1:]), total=len(queries) - 1):
        if index in results:
            print("Warning: Skipping repeated query. Query index =", index)
            continue
        # last_query = queries[last_index]
        query = queries[index]
        
        next_prefix_len = common_prefix_lens[index][next_index]

        start_time = time.perf_counter()
        if enable_direct_subset_answer and current_stack_count == len(query) and results[last_index] == sat:
            results[index] = sat
        else:
            if current_stack_count < next_prefix_len:
                for i in range(current_stack_count, next_prefix_len):
                    solver.push()
                    solver.add(query[i])
                total_pushes += next_prefix_len - current_stack_count
                current_stack_count = next_prefix_len

            solver.push()
            total_pushes += 1
            solver.add(query[current_stack_count:])
            
            result = solver.check()

            solver.pop()
            total_pops +=1
                        
            results[index] = result
        
        if current_stack_count > next_prefix_len:
            solver.pop(current_stack_count - next_prefix_len)
            total_pops += current_stack_count - next_prefix_len
            current_stack_count = next_prefix_len

        end_time = time.perf_counter()
        total_time += end_time - start_time
        last_index = index
    
    return (results, solver.statistics(), total_time, (total_pushes, total_pops, sum(len(q) for q in queries)))

In [14]:
# Base case
base_results, statistics, total_time, _ = check_by_resetting(queries, list(range(len(queries))), common_prefix_lens)
print("Spent time:", total_time)

100%|██████████| 4302/4302 [00:14<00:00, 292.31it/s]

Spent time: 14.654909132630564





In [15]:
best_results, statistics, total_time, (total_pushes, total_pops, total_constraints) = check_dummy_incrementally(queries, list(range(len(queries))), common_prefix_lens)
assert(base_results == best_results)
print(total_time)
print(total_pops, total_pushes, total_constraints)

100%|██████████| 4302/4302 [00:15<00:00, 282.27it/s] 

11.20084131823387
334036 334043 679921





In [16]:
best_results, statistics, total_time, (total_pushes, total_pops, total_constraints) = check_incrementally(queries, list(range(len(queries))), common_prefix_lens)
assert(base_results == best_results)
print(total_time)
print(total_pops, total_pushes, total_constraints)

100%|██████████| 4302/4302 [00:08<00:00, 481.63it/s] 

8.891306128469296
149101 149108 679928





In [26]:
# Best case
best_results, statistics, total_time, (total_pushes, total_pops, total_constraints) = check_incrementally(queries, best_case, common_prefix_lens, True)
assert(base_results == best_results)
print(total_time)
print(total_pops, total_pushes, total_constraints)

100%|██████████| 4302/4302 [00:08<00:00, 492.59it/s] 

8.698012300301343
148177 148458 680202





In [19]:

def calc_distances():
    with ProcessPoolExecutor() as executor:
        distances = list(tqdm(executor.map(calc_distances_for_map, range(len(queries))), total=len(queries)))

    common_prefix_lens = [[col[0] for col in row] for row in distances]
    distances = [[col[1] for col in row] for row in distances]
    return common_prefix_lens, distances

def evaluate(name, enable_direct_subset_answer, output_suffix, repeat_count = 5, input_suffix = ".smt2"):
    # Using global because of pickling
    global queries
    queries = parse_queries_from_klee_smt2_dump(name + input_suffix)
    print(f"Evaluating {name} with {len(queries)} queries.")

    print("Calculating distances ...")
    common_prefix_lens, distances = calc_distances()
    
    original_ordering = list(range(len(queries)))
    print("Finding TSP ordering ...")
    tsp_ordering = get_tsp_circuit(distances)
    
    cases = {
        "res_original": (check_by_resetting, original_ordering),
        "dinc_original": (check_dummy_incrementally, original_ordering),
        "incr_original": (check_incrementally, original_ordering),
        "incr_tsp": (check_incrementally, tsp_ordering)
    }

    results = {}
    check_results = []
    for (key, (check, ordering)) in cases.items():
        print(f"Testing {key} for {repeat_count} times ...")

        times = []
        for i in range(repeat_count):
            check_result, statistics, total_time, (total_pushes, total_pops, total_constraints) = check(queries, ordering, common_prefix_lens, enable_direct_subset_answer)
            times.append(total_time)
        check_results.append(check_result)
        print(total_time)
        
        results[key] = {
            "times": times,
            "statistics": str(statistics),
            "check_result": {key: str(value) for key, value in check_results[-1].items()},
            "ordering": ordering,
            "total_pushes": total_pushes,
            "total_pops": total_pops,
            "total_constraints": total_constraints,
        }
    
    assert all(result == check_results[0] for result in check_results)
    
    import json
    file_name = f"{name}_{output_suffix}"
    while os.path.exists(f"{file_name}.json"):
        file_name += "_new"
    with open(f"{file_name}.json", "a") as f:
        json.dump(results, f)

    print("Result saved to:", file_name)

def zero_cost(x):
    return 0

def identity_cost(x):
    return x


In [20]:
push_cost = identity_cost
pop_cost = zero_cost
evaluate("echo-cache", True, "direct_i_0")
evaluate("echo-nocache", True, "direct_i_0")
evaluate("pwd-nocache", True, "direct_i_0")
evaluate("tail-cache", True, "direct_i_0")
evaluate("tail-nocache", True, "direct_i_0")
evaluate("who-cache", True, "direct_i_0")
evaluate("who-nocache", True, "direct_i_0")

Evaluating echo-cache with 4302 queries.
Calculating distances ...


100%|██████████| 4302/4302 [00:28<00:00, 151.80it/s]


Finding TSP ordering ...
