In [None]:
import itertools
import os
from enum import Enum

import matplotlib.pyplot as plt
from faiss.contrib.evaluation import OperatingPoints

from bench_fw.benchmark_io import BenchmarkIO as BIO

In [None]:
# Directory holding the benchmark output. Set BENCH_FW_ROOT to point the
# notebook at another run without editing it; defaults to "/checkpoint".
root = os.environ.get("BENCH_FW_ROOT", "/checkpoint")
results = BIO(root).read_json("result.json")
results.keys()

In [None]:
# Per-index metadata; filter_results' default space metric reads 'code_size' from here.
results["indices"]

In [None]:
class Cost:
    def __init__(self, values):
        self.values = values

    def __le__(self, other):
        return all(v1 <= v2 for v1, v2 in zip(self.values, other.values, strict=True))

    def __lt__(self, other):
        return all(v1 < v2 for v1, v2 in zip(self.values, other.values, strict=True))

class ParetoMode(Enum):
    """Which experiments filter_results keeps after threshold filtering."""

    # Keep every experiment; no Pareto filtering.
    DISABLE = 1
    # Keep the Pareto-optimal experiments of each index separately.
    INDEX = 2
    # Keep only the experiments that are Pareto-optimal across all indices.
    GLOBAL = 3


class ParetoMetric(Enum):
    """Cost axis that filter_results trades off against accuracy."""

    # Cost is the experiment's time.
    TIME = 0
    # Cost is the experiment's space (code size by default).
    SPACE = 1
    # Cost is the (time, space) pair, compared component-wise via Cost.
    TIME_SPACE = 2

def range_search_recall_at_precision(experiment, precision):
    """Return the best range-search recall among points whose precision exceeds ``precision``.

    ``experiment['range_search_pr']`` holds parallel ``'recall'`` and
    ``'precision'`` lists (presumably the points of a precision/recall curve).

    Args:
        experiment: experiment dict containing ``'range_search_pr'``.
        precision: threshold; only points with precision strictly greater
            than this value are considered.

    Returns:
        The largest qualifying recall rounded to 6 decimals, or 0.0 when no
        point exceeds the threshold. This avoids ``max()`` raising on an empty
        sequence and aborting the caller.
    """
    pr = experiment['range_search_pr']
    best = max(
        (r for r, p in zip(pr['recall'], pr['precision']) if p > precision),
        default=0.0,
    )
    return round(best, 6)

def filter_results(
    results,
    evaluation,
    accuracy_metric,  # str or func
    time_metric=None,  # func or None -> use default
    space_metric=None,  # func or None -> use default
    min_accuracy=0,
    max_space=0,
    max_time=0,
    scaling_factor=1.0,
    pareto_mode=ParetoMode.DISABLE,
    pareto_metric=ParetoMetric.TIME,
):
    """Select the experiments of one evaluation, optionally keeping only Pareto optima.

    Args:
        results: benchmark results dict with an ``'experiments'`` mapping
            (experiment key -> experiment dict) and, for the default space
            metric, an ``'indices'`` mapping.
        evaluation: only experiments whose key contains ``"." + evaluation``
            are considered.
        accuracy_metric: key into the experiment dict, or a callable taking
            the experiment dict and returning its accuracy.
        time_metric: callable experiment -> time. The default is
            ``v['time'] * scaling_factor`` plus ``v['quantizer']['time']``
            when the experiment has a ``'quantizer'`` entry.
        space_metric: callable experiment -> space. The default is the
            ``'code_size'`` of ``results['indices'][v['codec']]``.
        min_accuracy, max_space, max_time: thresholds. A value <= 0 disables
            the corresponding filter.
        scaling_factor: multiplier applied by the default time metric to
            ``v['time']`` only, not to the quantizer time.
        pareto_mode: ParetoMode member. DISABLE keeps everything that passes
            the thresholds. INDEX keeps the Pareto front of each
            ``v['index']``. GLOBAL keeps one front over all experiments.
        pareto_metric: ParetoMetric member choosing the cost that is traded
            off against accuracy: time, space, or ``Cost([time, space])``.

    Returns:
        A list of ``(accuracy, space, time, key, experiment)`` tuples sorted by
        ``(accuracy, space, time, key)``.

    Raises:
        ValueError: if ``pareto_mode`` or ``pareto_metric`` is not a member of
            its enum.
    """
    # Fail fast on a bad mode instead of hitting an unrelated error mid-loop.
    if not isinstance(pareto_mode, ParetoMode):
        raise ValueError(f"unknown pareto_mode: {pareto_mode!r}")
    if not isinstance(pareto_metric, ParetoMetric):
        raise ValueError(f"unknown pareto_metric: {pareto_metric!r}")

    if isinstance(accuracy_metric, str):
        accuracy_key = accuracy_metric
        accuracy_metric = lambda v: v[accuracy_key]

    if time_metric is None:
        time_metric = lambda v: v['time'] * scaling_factor + (v['quantizer']['time'] if 'quantizer' in v else 0)

    if space_metric is None:
        space_metric = lambda v: results['indices'][v['codec']]['code_size']

    selected = []
    # Pareto fronts: a single "global" front, or one per index name.
    fronts = {}
    suffix = f".{evaluation}"
    for k, v in results['experiments'].items():
        if suffix not in k:
            continue
        accuracy = accuracy_metric(v)
        if min_accuracy > 0 and accuracy < min_accuracy:
            continue
        space = space_metric(v)
        if max_space > 0 and space > max_space:
            continue
        time = time_metric(v)
        if max_time > 0 and time > max_time:
            continue
        experiment = (accuracy, space, time, k, v)
        if pareto_mode == ParetoMode.DISABLE:
            selected.append(experiment)
            continue

        front_key = "global" if pareto_mode == ParetoMode.GLOBAL else v['index']
        if front_key not in fronts:
            fronts[front_key] = OperatingPoints()

        if pareto_metric == ParetoMetric.TIME:
            cost = time
        elif pareto_metric == ParetoMetric.SPACE:
            cost = space
        else:
            cost = Cost([time, space])
        fronts[front_key].add_operating_point(experiment, accuracy, cost)

    # Each operating point is (key, perf, cost); the key is our experiment tuple.
    for front in fronts.values():
        selected.extend(exp for exp, _, _ in front.operating_points)

    # Sort on the scalar fields only; the trailing experiment dict is not orderable.
    selected.sort(key=lambda e: e[:4])
    return selected

In [None]:
def plot_metric(experiments, accuracy_title, cost_title, plot_space=False):
    """Plot cost against accuracy, one line per index, on the current pyplot axes.

    Args:
        experiments: ``(accuracy, space, time, key, experiment)`` tuples, as
            returned by ``filter_results``. Series are grouped by
            ``experiment['index']`` in order of first appearance.
        accuracy_title: used as both the plot title and the x-axis label.
        cost_title: y-axis label.
        plot_space: plot ``space`` on the (log-scaled) y axis instead of ``time``.
    """
    accuracies = {}
    costs = {}
    for accuracy, space, time, _key, exp in experiments:
        name = exp['index']
        accuracies.setdefault(name, []).append(accuracy)
        costs.setdefault(name, []).append(space if plot_space else time)

    plt.yscale("log")
    plt.title(accuracy_title)
    plt.xlabel(accuracy_title)
    plt.ylabel(cost_title)
    # Cycle through markers so series stay distinguishable beyond 12 indices.
    markers = itertools.cycle(("o", "v", "^", "<", ">", "s", "p", "P", "*", "h", "X", "D"))
    for name, marker in zip(accuracies, markers):
        plt.plot(accuracies[name], costs[name], marker=marker, label=name)
    plt.legend(bbox_to_anchor=(1, 1), loc='upper left')

In [None]:
# k-NN search: per-index Pareto front of knn_intersection vs. time.
accuracy_metric = "knn_intersection"
fr = filter_results(
    results,
    evaluation="knn",
    accuracy_metric=accuracy_metric,
    pareto_mode=ParetoMode.INDEX,
    pareto_metric=ParetoMetric.TIME,
    scaling_factor=1,
)
plot_metric(
    fr,
    accuracy_title="knn intersection",
    cost_title="time (seconds, 16 cores)",
)

In [None]:
# Range search, index-local optima: recall at precision > 0.8 vs. time.
precision = 0.8


def accuracy_metric(exp):
    # Reads the notebook-level `precision` when called, as the lambda did.
    return range_search_recall_at_precision(exp, precision)


fr = filter_results(
    results,
    evaluation="weighted",
    accuracy_metric=accuracy_metric,
    pareto_mode=ParetoMode.INDEX,
    pareto_metric=ParetoMetric.TIME,
    scaling_factor=1,
)
plot_metric(
    fr,
    accuracy_title=f"range recall @ precision {precision}",
    cost_title="time (seconds, 16 cores)",
)

In [None]:
# Range search, index-local optima: recall at precision > 0.2 vs. time.
precision = 0.2


def accuracy_metric(exp):
    # Reads the notebook-level `precision` when called, as the lambda did.
    return range_search_recall_at_precision(exp, precision)


fr = filter_results(
    results,
    evaluation="weighted",
    accuracy_metric=accuracy_metric,
    pareto_mode=ParetoMode.INDEX,
    pareto_metric=ParetoMetric.TIME,
    scaling_factor=1,
)
plot_metric(
    fr,
    accuracy_title=f"range recall @ precision {precision}",
    cost_title="time (seconds, 16 cores)",
)

In [None]:
# Range search, global optima across all indices: recall at precision > 0.8 vs. time.
precision = 0.8


def accuracy_metric(exp):
    # Reads the notebook-level `precision` when called, as the lambda did.
    return range_search_recall_at_precision(exp, precision)


fr = filter_results(
    results,
    evaluation="weighted",
    accuracy_metric=accuracy_metric,
    pareto_mode=ParetoMode.GLOBAL,
    pareto_metric=ParetoMetric.TIME,
    scaling_factor=1,
)
plot_metric(
    fr,
    accuracy_title=f"range recall @ precision {precision}",
    cost_title="time (seconds, 16 cores)",
)

In [None]:
def plot_range_search_pr_curves(experiments, evaluation="weighted"):
    """Plot the range-search precision/recall points of each given experiment.

    Args:
        experiments: ``(accuracy, space, time, key, experiment)`` tuples, as
            returned by ``filter_results``. Only the key and the experiment
            dict are used.
        evaluation: only experiments whose key contains ``"." + evaluation``
            are plotted (default ``"weighted"``).

    Each selected experiment dict must contain ``range_search_pr`` with
    ``recall`` and ``precision`` lists. They are drawn with '.' markers on the
    current pyplot axes, labelled by the experiment key.
    """
    # Iterate the argument rather than a notebook global, so any list can be plotted.
    curves = {}
    suffix = f".{evaluation}"
    for _, _, _, key, exp in experiments:
        if suffix in key:
            pr = exp['range_search_pr']
            curves[key] = (pr['recall'], pr['precision'])

    plt.title("range search recall")
    plt.xlabel("recall")
    plt.ylabel("precision")
    for key, (recall, prec) in curves.items():
        plt.plot(recall, prec, '.', label=key)
    plt.legend(bbox_to_anchor=(1.0, 1.0), loc='upper left')

In [None]:
# Global Pareto front over (time, space) at precision > 0.8, then draw the PR curves of the survivors.
precision = 0.8


def accuracy_metric(exp):
    # Reads the notebook-level `precision` when called, as the lambda did.
    return range_search_recall_at_precision(exp, precision)


fr = filter_results(
    results,
    evaluation="weighted",
    accuracy_metric=accuracy_metric,
    pareto_mode=ParetoMode.GLOBAL,
    pareto_metric=ParetoMetric.TIME_SPACE,
    scaling_factor=1,
)
plot_range_search_pr_curves(fr)