In [None]:
import itertools
import pickle
import re
from math import ceil, log

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.gridspec import SubplotSpec

In [None]:
# One results pickle per kNN backend, loaded in the same order as `methods`
# so that files[i] holds the results of methods[i].
# NOTE: pickle.load can execute arbitrary code stored in the file -- only
# load result files you produced yourself.
methods = ['O3DKnn', 'O3DKnnCPU']
files = []
for method_name in methods:
    with open(f'{method_name}.pkl', 'rb') as handle:
        files.append(pickle.load(handle))

In [None]:
# Experiment names are taken from the first backend's results; this assumes
# every backend was run on the same set of experiments (TODO confirm).
experiments = list(files[0].keys())


def _has_k(name: str, k: int) -> bool:
    """Return True if `name` contains the tag ``k=<k>`` not followed by another digit.

    A plain substring test would also treat e.g. "k=10" or "k=16" as "k=1".
    """
    return re.search(rf"k={k}(?!\d)", name) is not None


exp_k1 = [e for e in experiments if _has_k(e, 1)]
exp_k37 = [e for e in experiments if _has_k(e, 37)]
exp_k64 = [e for e in experiments if _has_k(e, 64)]

In [None]:
def create_subtitle(fig: plt.Figure, grid: SubplotSpec, title: str):
    """Add a bold heading that spans a region of subplots in ``fig``.

    An extra axes is created over ``grid`` and stripped of its frame and
    axis decorations, so the only thing it shows is ``title``.

    Args:
        fig: figure to draw the heading into.
        grid: subplot-spec region (e.g. a whole row) the heading spans.
        title: heading text; a trailing newline is appended, which pushes
            the visible text up by one line.
    """
    header_ax = fig.add_subplot(grid)
    header_ax.set_title(f'{title}\n', fontweight='semibold')
    header_ax.set_frame_on(False)
    header_ax.axis('off')

# Axis labels shared by both benchmark figures.
xaxis_label = "num_points"
yaxis_label = "log(sec)"

datasets = ['small_tower', 'kitti_1', 'kitti_2', 'fluid_1000', 's3dis_1', 's3dis_2']
knns = [1, 37, 64]
combinations = list(itertools.product(datasets, knns))


def _thousands(x, _pos):
    """Tick formatter: show an x-axis value in whole thousands, e.g. 25000 -> '25K'."""
    return f"{int(x/1000)}K"


def _log_median_timings(results, names):
    """Median kNN timings of one backend over `names`, ordered by point count.

    Args:
        results: one backend's results mapping; ``results[name]`` must provide
            'num_points', 'knn_setup' and 'knn_search' entries.
        names: experiment names to read from ``results``.

    Returns:
        ``(xs, log_setup, log_search, log_total)`` as numpy arrays sorted by
        ascending num_points. Times are natural logs of the per-experiment
        medians; total is log(median setup + median search).
    """
    xs = np.array([results[n]['num_points'] for n in names])
    med_setup = np.array([np.median(results[n]['knn_setup']) for n in names])
    med_search = np.array([np.median(results[n]['knn_search']) for n in names])
    # Sort by point count so each line is drawn left to right instead of
    # zig-zagging in whatever order the experiments were stored.
    order = np.argsort(xs)
    return (xs[order],
            np.log(med_setup)[order],
            np.log(med_search)[order],
            np.log(med_setup + med_search)[order])


rows = len(combinations)
cols = 3
fig, axs = plt.subplots(rows, cols, figsize=(18, 4 * rows))
grid = plt.GridSpec(rows, cols)
panel_titles = ("setup time", "search time", "total time")

for i, (d, k) in enumerate(combinations):
    row_axes = axs[i]
    # NOTE: the dataset is matched as a plain substring (so 'kitti_1' would
    # also match a hypothetical 'kitti_10'); the k tag must not be followed
    # by another digit, so k=1 does not also pick up k=10..k=19.
    k_tag = re.compile(rf"k={k}(?!\d)")
    exp = [x for x in experiments if d in x and k_tag.search(x)]
    if not exp:
        # Nothing recorded for this (dataset, k): label and blank the row
        # instead of failing on exp[0] below.
        create_subtitle(fig, grid[i, ::], f"{d} k={k} (no data)")
        for ax in row_axes:
            ax.axis('off')
        continue

    # Row heading = 1st and 3rd space-separated tokens of the first name
    # (requires at least three tokens).
    title_tokens = exp[0].split(" ")
    title = " ".join([title_tokens[0], title_tokens[2]])
    create_subtitle(fig, grid[i, ::], title)

    # Panel decoration is per axes, not per backend: set it once.
    for ax, panel_title in zip(row_axes, panel_titles):
        ax.xaxis.set_major_formatter(plt.FuncFormatter(_thousands))
        ax.set_title(panel_title)
        ax.set_xlabel(xaxis_label)
        ax.set_ylabel(yaxis_label)

    for file, method in zip(files, methods):
        xs, *log_times = _log_median_timings(file, exp)
        for ax, ys in zip(row_axes, log_times):
            ax.plot(xs, ys, label=method)

    # One legend per row suffices: every panel plots the backends in the
    # same order, so line colours match across the three panels.
    row_axes[1].legend()

fig.tight_layout()
fig.set_facecolor('w')
fig.savefig("nns_benchmark")

In [None]:
# Second figure: one row per k, pooling the experiments of every dataset
# that were run with that k.
rows = 3
cols = 3
fig, axs = plt.subplots(rows, cols, figsize=(18, 4 * rows))
grid = plt.GridSpec(rows, cols)
row_specs = [(exp_k1, "knn=1"), (exp_k37, "knn=37"), (exp_k64, "knn=64")]
for row_idx, (exp, title) in enumerate(row_specs):
    setup_ax, search_ax, total_ax = axs[row_idx]
    create_subtitle(fig, grid[row_idx, ::], title)
    for results, method in zip(files, methods):
        records = [results[name] for name in exp]
        num_points = [rec['num_points'] for rec in records]
        med_setup = [np.median(rec['knn_setup']) for rec in records]
        med_search = [np.median(rec['knn_search']) for rec in records]
        med_total = [a + b for a, b in zip(med_setup, med_search)]

        # Plot in ascending point count; experiments from different datasets
        # may be interleaved in `exp`.
        order = np.argsort(num_points)
        xs = np.array(num_points)[order]

        setup_ax.set_title("setup time")
        setup_ax.plot(xs, np.log(med_setup)[order], label=method)
        setup_ax.set_xlabel(xaxis_label)
        setup_ax.set_ylabel(yaxis_label)

        search_ax.set_title("search time")
        search_ax.set_xlabel(xaxis_label)
        search_ax.set_ylabel(yaxis_label)
        search_ax.plot(xs, np.log(med_search)[order], label=method)
        search_ax.legend()

        total_ax.set_title("total time")
        total_ax.set_xlabel(xaxis_label)
        total_ax.set_ylabel(yaxis_label)
        total_ax.plot(xs, np.log(med_total)[order], label=method)

fig.tight_layout()
fig.set_facecolor('w')