In [None]:
import os
import sys

import matplotlib.pyplot as plt
import numpy as np
import open3d as o3d

from benchmark_utils import timeout, Timer

# random data loader
def dataloader(num_iter, n, d=3, seed=0, num_queries=1000):
    """Yield ``num_iter`` random (reference, query) point-cloud pairs.

    Args:
        num_iter: number of pairs to generate.
        n: number of reference points per pair.
        d: dimensionality of each point.
        seed: seed for the private ``RandomState`` so runs are reproducible.
        num_queries: number of query points per pair (defaults to the
            previously hard-coded 1000).

    Yields:
        ``(x, y)`` float32 arrays of shapes ``(n, d)`` and
        ``(num_queries, d)`` with values drawn uniformly from [0, 1).
    """
    randg = np.random.RandomState(seed)
    for _ in range(num_iter):
        # draw reference points first, then queries, so the random stream
        # (and thus the data for a given seed) stays stable
        x = randg.rand(n, d).astype(np.float32)
        y = randg.rand(num_queries, d).astype(np.float32)
        yield (x, y)

# timeout after 100 seconds
@timeout(100)
def run_single_benchmark(n, k, radius, num_iter, search_method, device, dtype):
    """Time index construction and querying for one reference-set size.

    Args:
        n: number of reference points per iteration.
        k: number of neighbours requested when ``search_method == 'knn'``.
        radius: search radius used by the fixed-radius path.
        num_iter: number of random datasets to benchmark.
        search_method: ``'knn'`` for k-nearest-neighbour search; any other
            value selects the fixed-radius path.
        device: Open3D device on which the tensors are created.
        dtype: Open3D dtype for the reference and query tensors.

    Returns:
        ``[n, k, prepare_avg, match_avg, total_avg]`` where the averages are
        the ``avg`` attributes of the corresponding ``Timer`` objects.
    """
    dataset = dataloader(num_iter=num_iter, n=n)
    prepare_timer, match_timer, total_timer = Timer(), Timer(), Timer()
    # never ask for more neighbours than there are reference points
    min_k = min(n, k)
    use_knn = search_method == 'knn'

    for i, (x, y) in enumerate(dataset):
        # host -> device conversion is deliberately kept outside the timers
        ref = o3d.core.Tensor(x, device=device, dtype=dtype)
        query = o3d.core.Tensor(y, device=device, dtype=dtype)

        total_timer.tic()

        # prepare: build the search index over the reference points
        prepare_timer.tic()
        nn_index = o3d.core.nns.NearestNeighborSearch(ref)
        if use_knn:
            nn_index.knn_index()
        else:
            nn_index.fixed_radius_index(radius)
        prepare_timer.toc()

        # match: query the index (previously the radius path re-ran
        # fixed_radius_index here instead of searching)
        match_timer.tic()
        if use_knn:
            nn_index.knn_search(query, min_k)
        else:
            nn_index.fixed_radius_search(query, radius)
        match_timer.toc()
        total_timer.toc()

        sys.stdout.write(
            "\r [%d/%d] N %d, K %d, prepare_time %.4f, matching time %.4f, total_time %.4f"
            % (i + 1, num_iter, n, k, prepare_timer.avg, match_timer.avg,
               total_timer.avg))
        sys.stdout.flush()
    print("")
    return [n, k, prepare_timer.avg, match_timer.avg, total_timer.avg]

# run benchmark
def run_benchmark(opt):
    """Benchmark NN search for reference sizes ``10**1 .. 10**opt['max_scale']``.

    Args:
        opt: dict with keys ``'device'``, ``'max_scale'``, ``'knn'``,
            ``'radius'``, ``'num_iter'`` and ``'search_method'``.

    Returns:
        Float array of shape ``(max_scale, 3)``. Row ``i`` holds the average
        prepare/match/total times for ``n = 10**(i + 1)``. Rows for sizes that
        raised, or that were skipped after a timeout, stay at ``-1``.
    """
    dtype = o3d.core.Dtype.Float32
    device = opt['device']
    search_method = opt['search_method']

    # warm-up on the same code path that will be timed, with float32 data
    # like the benchmark itself, so one-off initialisation cost is not
    # charged to the first measured size
    ref = o3d.core.Tensor(np.random.rand(100, 3).astype(np.float32),
                          device=device, dtype=dtype)
    query = o3d.core.Tensor(np.random.rand(100, 3).astype(np.float32),
                            device=device, dtype=dtype)
    index = o3d.core.nns.NearestNeighborSearch(ref)
    if search_method == 'knn':
        index.knn_index()
        index.knn_search(query, 1)
    else:
        index.fixed_radius_index(opt['radius'])
        index.fixed_radius_search(query, opt['radius'])

    n_list = [10 ** s for s in range(1, opt['max_scale'] + 1)]
    results = -np.ones((len(n_list), 3))
    for i, n in enumerate(n_list):
        try:
            result = run_single_benchmark(n, opt['knn'], opt['radius'],
                                          opt['num_iter'], search_method,
                                          device, dtype)
        except TimeoutError:
            # sizes only grow, so later sizes would time out as well
            print(f"timeout, n: {n}")
            break
        except Exception as e:
            # best-effort: report which size failed and keep going
            print(f"failed, n: {n}: {e!r}")
            continue
        results[i] = result[-3:]
    return results

# plot graphs
def plot(ths, times, methods):
    """Plot prepare/match/total time against dataset size, one line per method.

    Both axes use base-10 logarithms, so a line's slope reads directly as the
    scaling exponent of that method.

    Args:
        ths: 1-D NumPy array of dataset sizes (x axis); it is indexed with a
            boolean mask.
        times: array indexed as ``times[method][size][stage]`` where stage
            0/1/2 is prepare/match/total; non-positive entries mark missing
            measurements and are skipped.
        methods: legend labels, one per entry of ``times``.
    """
    fig = plt.figure(figsize=(22, 6))
    titles = ['prepare', 'match', 'total']
    ths = np.asarray(ths)

    for col, title in enumerate(titles):
        ax = fig.add_subplot(1, 3, col + 1)
        ax.set_title(title)
        ax.set_xlabel('log10(num_dataset_points)')
        ax.set_ylabel(f'log10({title} time)(s)')
        for method_times, method in zip(times, methods):
            stage_times = np.asarray(method_times)[:, col]
            valid = stage_times > 0
            if not valid.any():
                # nothing measured for this method; markevery=[-1] on an
                # empty line would be meaningless
                continue
            # mark the largest size that finished for this method
            ax.plot(np.log10(ths[valid]), np.log10(stage_times[valid]),
                    label=method, marker='x',
                    markevery=[int(valid.sum()) - 1])
        ax.legend()

In [None]:
# Load the reference timings and close the .npz archive once read.
with np.load('baseline_results.npz') as baseline_result:
    ths = baseline_result['ths']
    times = baseline_result['times']
    methods = baseline_result['methods']

opt = dict(
    device=o3d.core.Device('CUDA:0'),
    max_scale=9,
    knn=1,
    radius=0.1,
    num_iter=2,
    search_method='knn'
)

# run_benchmark returns one row per size 10**1 .. 10**max_scale and those rows
# are stacked onto the baseline below, so the size grids must line up; check
# before spending time on the benchmark rather than failing afterwards.
if len(ths) != opt['max_scale']:
    raise ValueError(
        f"baseline has {len(ths)} sizes but max_scale is {opt['max_scale']}")
results = run_benchmark(opt)

times = np.concatenate((times, results[None]), axis=0)
methods = [*methods, 'O3D_latest']
plot(ths, times, methods)