In [1]:
from z3 import *
import z3
import pandas as pd
import time

In [2]:
def parse_queries_from_klee_smt2_dump(path: str):
    queries = []
    with open(path) as f:
        current = ""
        for line in f:
            if line == "(check-sat)":
                continue
            if line.strip() == "(reset)":
                queries.append(z3.parse_smt2_string(current))
                current = ""
            current += line
    return queries

queries = parse_queries_from_klee_smt2_dump("query_dump.smt2")
print(f"{len(queries)} queries were loaded.")            

61 queries were loaded.


In [3]:
def calc_distance_keeping_constraint_order(first: list, second: list):
    """
    Calculates a distance between two queries while not trying to change the order
    of constraints inside them.
    The returned distance mimics the cost of pushing and popping when the second query
    is solved right after the first query.
    """
    i = 0
    while i < len(first) and i < len(second) and first[i].eq(second[i]):
        i += 1

    return len(first) - i + len(second) - i
    # If you believe that there's no cost in popping, you can use the following
    # formula. But personally, I see each pop as a negative thing and losing a
    # progress.
    # return len(second) - i

distances = [[0] * len(queries)] * len(queries)
for i in range(len(queries)):
    for j in range(len(queries)):
        distances[i][j] = calc_distance_keeping_constraint_order(queries[i], queries[j])

In [4]:
def get_tsp_circuit(distances):
    import ortools.constraint_solver.pywrapcp as pywrapcp
    from ortools.constraint_solver import routing_enums_pb2

    manager = pywrapcp.RoutingIndexManager(len(distances), 1, 0)
    routing = pywrapcp.RoutingModel(manager)

    def distance_callback(from_index, to_index):
        """Returns the distance between the two nodes."""
        # Convert from routing variable Index to distance matrix NodeIndex.
        from_node = manager.IndexToNode(from_index)
        to_node = manager.IndexToNode(to_index)
        return distances[from_node][to_node]

    transit_callback_index = routing.RegisterTransitCallback(distance_callback)
    routing.SetArcCostEvaluatorOfAllVehicles(transit_callback_index)
    search_parameters = pywrapcp.DefaultRoutingSearchParameters()
    search_parameters.first_solution_strategy = (
        routing_enums_pb2.FirstSolutionStrategy.AUTOMATIC)

    def print_solution(manager, routing, solution):
        """Prints solution on console."""
        print('Objective: {} miles'.format(solution.ObjectiveValue()))
        index = routing.Start(0)
        plan_output = 'Route for vehicle 0:\n'
        route_distance = 0
        while not routing.IsEnd(index):
            plan_output += ' {} ->'.format(manager.IndexToNode(index))
            previous_index = index
            index = solution.Value(routing.NextVar(index))
            route_distance += routing.GetArcCostForVehicle(previous_index, index, 0)
        plan_output += ' {}\n'.format(manager.IndexToNode(index))
        print(plan_output)
        plan_output += 'Route distance: {}miles\n'.format(route_distance)

    def get_vector(manager, routing, solution):
        index = routing.Start(0)
        route = []
        while not routing.IsEnd(index):
            route.append(manager.IndexToNode(index))
            index = solution.Value(routing.NextVar(index))

        return route

    solution = routing.SolveWithParameters(search_parameters)
    if solution:
        print_solution(manager, routing, solution)

    return get_vector(manager, routing, solution)

best_case = get_tsp_circuit(distances)

Objective: 629 miles
Route for vehicle 0:
 0 -> 60 -> 23 -> 2 -> 1 -> 5 -> 3 -> 46 -> 16 -> 12 -> 6 -> 4 -> 56 -> 31 -> 29 -> 28 -> 27 -> 26 -> 25 -> 22 -> 19 -> 15 -> 14 -> 13 -> 10 -> 7 -> 54 -> 52 -> 50 -> 49 -> 44 -> 43 -> 41 -> 38 -> 36 -> 35 -> 34 -> 32 -> 30 -> 20 -> 18 -> 17 -> 11 -> 8 -> 59 -> 58 -> 57 -> 55 -> 53 -> 51 -> 48 -> 47 -> 45 -> 42 -> 40 -> 39 -> 37 -> 33 -> 24 -> 21 -> 9 -> 0



In [5]:
from typing import List


def check_by_resetting(queries, solver=None):
    solver = solver if solver is not None else Solver()
    results = {}
    
    total_time = 0.0
    
    for i, query in enumerate(queries):
        start_time = time.perf_counter()
        result = solver.check(query)
        end_time = time.perf_counter()
        total_time += end_time - start_time
        results[i] = result
    
    return (results, solver.statistics(), total_time)

def check_incrementally(queries, ordering: List[int], solver=None):
    solver = solver if solver is not None else Solver()
    results = {}

    last_query = []

    total_time = 0.0
    total_pops = 0
    total_pushes = 0

    for index in ordering:
        if index in results:
            print("Warning: Skipping repeated query. Query index =", index)
            continue
        query = queries[index]
        
        i = 0
        while i < len(last_query) and i < len(query) and last_query[i].eq(query[i]):
            i += 1

        solver.pop(len(last_query) - i)
        total_pops += len(last_query) - i

        for i in range(i, len(query)):
            solver.push()
            total_pushes += 1
            solver.add(query[i])
        
        start_time = time.perf_counter()
        result = solver.check()
        end_time = time.perf_counter()
        total_time += end_time - start_time
        results[index] = result

        last_query = query
    
    return (results, solver.statistics(), total_time, (total_pushes, total_pops, sum(len(q) for q in queries)))

In [6]:
# Base case
base_results, statistics, total_time = check_by_resetting(queries)
print("Spent time:", total_time)

Spent time: 0.06405430100858212


In [7]:
# Best case
best_results, statistics, total_time, (total_pushes, total_pops, total_constraints) = check_incrementally(queries, best_case)
assert(base_results == best_results)
print(total_time)
print(total_pops, total_pushes, total_constraints)

0.07369697460671887
240 246 287


In [9]:
REPEAT_COUNT = 100
total_time = 0.0
solver = Solver()
for i in range(REPEAT_COUNT):
    _, _, spent_time = check_by_resetting(queries, solver=solver)
    total_time += spent_time
    solver.reset()
print("Average time for resetting:", total_time / REPEAT_COUNT)

total_time = 0.0
for i in range(REPEAT_COUNT):
    _, _, spent_time, _ = check_incrementally(queries, best_case, solver=solver)
    total_time += spent_time
    solver.reset()
print("Average time for incremental:", total_time / REPEAT_COUNT)

Average time for resetting: 0.0601999520737445
Average time for incremental: 0.07666392293991521
