In [20]:
from z3 import *
import z3
import pandas as pd
import time
from tqdm import tqdm
from concurrent.futures import ProcessPoolExecutor
from itertools import islice

In [54]:
def parse_queries_from_klee_smt2_dump(path: str):
    queries = []
    with open(path) as f:
        current = ""
        for line in f:
            if line == "(check-sat)":
                continue
            if line.strip() == "(reset)":
                queries.append(z3.parse_smt2_string(current))
                current = ""
            current += line
    return queries

queries = parse_queries_from_klee_smt2_dump("who-cache.smt2")
print(f"{len(queries)} queries were loaded.")            

3090 queries were loaded.


In [55]:
def calc_distance_keeping_constraint_order(first: list, second: list):
    """
    Calculates a distance between two queries while not trying to change the order
    of constraints inside them.
    The returned distance mimics the cost of pushing and popping when the second query
    is solved right after the first query.
    """
    i = 0
    while i < len(first) and i < len(second) and first[i].eq(second[i]):
        i += 1

    return (i, len(first) - i + len(second) - i)
    # If you believe that there's no cost in popping, you can use the following
    # formula. But personally, I see each pop as a negative thing and losing a
    # progress.
    # return (i, len(second) - i)

def calc_distances_for(query_index: int, queries, calc_distance):
    query = queries[query_index]
    distances = [0] * len(queries)
    for j in range(len(queries)):
        distances[j] = calc_distance(query, queries[j])
    return distances

def calc_distances_for_map(query_index: int):
    return calc_distances_for(query_index, queries, calc_distance_keeping_constraint_order)

with ProcessPoolExecutor() as executor:
    distances = list(tqdm(executor.map(calc_distances_for_map, range(len(queries))), total=len(queries)))

common_prefix_lens = [[col[0] for col in row] for row in distances]
distances = [[col[1] for col in row] for row in distances]

100%|██████████| 3090/3090 [00:14<00:00, 212.53it/s]


In [56]:
def get_tsp_circuit(distances):
    import ortools.constraint_solver.pywrapcp as pywrapcp
    from ortools.constraint_solver import routing_enums_pb2

    manager = pywrapcp.RoutingIndexManager(len(distances), 1, 0)
    routing = pywrapcp.RoutingModel(manager)

    def distance_callback(from_index, to_index):
        """Returns the distance between the two nodes."""
        # Convert from routing variable Index to distance matrix NodeIndex.
        from_node = manager.IndexToNode(from_index)
        to_node = manager.IndexToNode(to_index)
        return distances[from_node][to_node]

    transit_callback_index = routing.RegisterTransitCallback(distance_callback)
    routing.SetArcCostEvaluatorOfAllVehicles(transit_callback_index)
    search_parameters = pywrapcp.DefaultRoutingSearchParameters()
    search_parameters.first_solution_strategy = (
        routing_enums_pb2.FirstSolutionStrategy.AUTOMATIC)

    def print_solution(manager, routing, solution):
        """Prints solution on console."""
        print('Objective: {} miles'.format(solution.ObjectiveValue()))
        index = routing.Start(0)
        plan_output = 'Route for vehicle 0:\n'
        route_distance = 0
        while not routing.IsEnd(index):
            plan_output += ' {} ->'.format(manager.IndexToNode(index))
            previous_index = index
            index = solution.Value(routing.NextVar(index))
            route_distance += routing.GetArcCostForVehicle(previous_index, index, 0)
        plan_output += ' {}\n'.format(manager.IndexToNode(index))
        print(plan_output)
        plan_output += 'Route distance: {}miles\n'.format(route_distance)

    def get_vector(manager, routing, solution):
        index = routing.Start(0)
        route = []
        while not routing.IsEnd(index):
            route.append(manager.IndexToNode(index))
            index = solution.Value(routing.NextVar(index))

        return route

    solution = routing.SolveWithParameters(search_parameters)
    # if solution:
    #     print_solution(manager, routing, solution)

    return get_vector(manager, routing, solution)

best_case = get_tsp_circuit(distances)

In [57]:
from typing import List


def check_by_resetting(queries, solver=None):
    solver = solver if solver is not None else Solver()
    results = {}
    
    total_time = 0.0
    
    for i, query in enumerate(tqdm(queries)):
        start_time = time.perf_counter()
        result = solver.check(query)
        end_time = time.perf_counter()
        total_time += end_time - start_time
        results[i] = result
    
    return (results, solver.statistics(), total_time)

def check_incrementally(queries, ordering: List[int], solver=None):
    solver = solver if solver is not None else Solver()
    results = {}

    queries = queries + [queries[ordering[-1]]]
    ordering = ordering + [ordering[-1]]
    
    # last_index = len(queries) - 1
    current_stack_count = 0

    total_time = 0.0
    total_pops = 0
    total_pushes = 0

    for index, next_index in tqdm(zip(ordering[:len(ordering) - 1], ordering[1:]), total=len(queries) - 1):
        if index in results:
            print("Warning: Skipping repeated query. Query index =", index)
            continue
        # last_query = queries[last_index]
        query = queries[index]
        
        next_prefix_len = common_prefix_lens[index][next_index]

        start_time = time.perf_counter()

        if current_stack_count < next_prefix_len:
            for i in range(current_stack_count, next_prefix_len):
                solver.push()
                solver.add(query[i])
            total_pushes += next_prefix_len - current_stack_count
            current_stack_count = next_prefix_len

        solver.push()
        total_pushes += 1
        solver.add(query[current_stack_count:])
        
        result = solver.check()

        solver.pop()
        total_pops +=1
        
        if current_stack_count > next_prefix_len:
            solver.pop(current_stack_count - next_prefix_len)
            total_pops += current_stack_count - next_prefix_len
            current_stack_count = next_prefix_len

        end_time = time.perf_counter()
        total_time += end_time - start_time
        results[index] = result

        # last_index = index
    
    return (results, solver.statistics(), total_time, (total_pushes, total_pops, sum(len(q) for q in queries)))

In [58]:
# Base case
base_results, statistics, total_time = check_by_resetting(queries)
print("Spent time:", total_time)

100%|██████████| 3090/3090 [00:13<00:00, 222.32it/s]

Spent time: 13.829318084637634





In [59]:
# Best case
best_results, statistics, total_time, (total_pushes, total_pops, total_constraints) = check_incrementally(queries, best_case)
assert(base_results == best_results)
print(total_time)
print(total_pops, total_pushes, total_constraints)

100%|██████████| 3090/3090 [00:09<00:00, 341.15it/s]

9.01407857844606
155242 155524 680042





In [60]:
REPEAT_COUNT = 5
total_time = 0.0
solver = Solver()
for i in range(REPEAT_COUNT):
    _, _, spent_time = check_by_resetting(queries, solver=solver)
    total_time += spent_time
    solver.reset()
print("Average time for resetting:", total_time / REPEAT_COUNT)

total_time = 0.0
for i in range(REPEAT_COUNT):
    _, _, spent_time, _ = check_incrementally(queries, best_case, solver=solver)
    total_time += spent_time
    solver.reset()
print("Average time for incremental:", total_time / REPEAT_COUNT)

100%|██████████| 3090/3090 [00:14<00:00, 214.53it/s]
100%|██████████| 3090/3090 [00:14<00:00, 217.69it/s]
100%|██████████| 3090/3090 [00:14<00:00, 211.92it/s]
100%|██████████| 3090/3090 [00:14<00:00, 215.19it/s]
100%|██████████| 3090/3090 [00:14<00:00, 217.02it/s]


Average time for resetting: 14.285437431070022


100%|██████████| 3090/3090 [00:08<00:00, 345.28it/s]
100%|██████████| 3090/3090 [00:09<00:00, 341.59it/s]
100%|██████████| 3090/3090 [00:09<00:00, 331.55it/s]
100%|██████████| 3090/3090 [00:12<00:00, 255.47it/s]
100%|██████████| 3090/3090 [00:11<00:00, 273.73it/s]

Average time for incremental: 10.081214405503124



