## 🚀 Compare speed of KNN Functions 🚀

In [1]:
%matplotlib inline
import matplotlib
matplotlib.use("module://matplotlib_inline.backend_inline")

In [2]:
from task import (
    our_knn_L2_CUDA,
    our_knn_L2_CUPY,
    our_knn_cosine_CUPY,
    our_knn_dot_CUPY,
    our_knn_L1_CUPY,
    our_knn_l2_triton,
    our_knn_cosine_triton,
    our_knn_dot_triton,
    our_knn_l1_triton,
    our_knn_l2_cpu,
    our_knn_l1_cpu,
    our_knn_cosine_cpu,
    our_knn_dot_cpu,
    test_knn_wrapper
    
)
import numpy as np
import seaborn as sns
import torch, cupy

In [3]:
np.random.seed(1967)
N = 4_000_000
D = 512
A = np.random.rand(N,D).astype(np.float32)
X = np.random.rand(D,)
K = 10

In [4]:
functions = {
    # "L2": [
    #     our_knn_L2_CUDA,
    #     our_knn_L2_CUPY,
    #     our_knn_l2_triton,
    #     our_knn_l2_cpu,
    # ],
    # "L1": [
    #     our_knn_L1_CUPY,
    #     our_knn_l1_triton,
    #     our_knn_l1_cpu
    # ],
    # "Cosine": [
    #     our_knn_cosine_CUPY,
    #     our_knn_cosine_triton,
    #     our_knn_cosine_cpu
    # ],
    "Dot Product": [
        our_knn_dot_CUPY,
        our_knn_dot_triton,
        our_knn_dot_cpu,
    ],
}

In [5]:
results_list = {}
for function_type, function_list in functions.items():
    results_list[function_type]= []
    for function in function_list:
        inner_results_item = []
        # ✅ Clear GPU memory
        torch.cuda.synchronize()
        torch.cuda.empty_cache()
        cupy.get_default_memory_pool().free_all_blocks()
        torch.cuda.synchronize()
        #clear GPU memory
        result = test_knn_wrapper(function, N, D, A, X, K, repeat=10)
        inner_results_item.append(result[2])
        results_item = {function.__name__: inner_results_item}
        results_list[function_type].append(results_item)
for function_type, function_list in results_list.items():
    print(function_type)
    for function in function_list:
        for function_name, result in function.items():
            print(f"{function_name}: {result}")
    print()


Running our_knn_dot_CUPY with 4000000 vectors of dimension 512 and K=10 for 10 times.
our_knn_dot_CUPY - Result: [1534894 1390303 3413792 3622658  933929 2306566 1498634   51503 1215858
 3298632], Number of Vectors: 4000000, Dimension: 512, K: 10, 
Time: 1764.895201 milliseconds.

Running our_knn_dot_triton with 4000000 vectors of dimension 512 and K=10 for 10 times.
our_knn_dot_triton - Result: [1534894 1390303 3413792 3622658  933929 2306566 1498634   51503 1215858
 3298632], Number of Vectors: 4000000, Dimension: 512, K: 10, 
Time: 1791.846490 milliseconds.

Running our_knn_dot_cpu with 4000000 vectors of dimension 512 and K=10 for 10 times.
our_knn_dot_cpu - Result: [1534894 1390303 3413792 3622658  933929 2306566 1498634   51503 1215858
 3298632], Number of Vectors: 4000000, Dimension: 512, K: 10, 
Time: 5745.416498 milliseconds.

Dot Product
our_knn_dot_CUPY: [1764.8952007293701]
our_knn_dot_triton: [1791.846489906311]
our_knn_dot_cpu: [5745.416498184204]



In [6]:
# #Print four graphs: Each graph should have 3 lines CUPY, Triton, CPU
# #Graph should be in log-log scale but the x-axis should be power of 2
# #Plotting the results
# import matplotlib.pyplot as plt
# def plot_results(results_list):
#     for function_type, function_list in results_list.items():
#         plt.figure()
#         for function in function_list:
#             for function_name, result in function.items():
#                 plt.plot(vector_sizes, result, label=function_name)
#         plt.xscale('log', base=2)
#         plt.yscale('log')
#         plt.xlabel('Vector Size')
#         plt.ylabel('Time (s)')
#         plt.title(function_type)
#         plt.legend()
#         plt.show()
# plot_results(results_list)
# #Plotting the results


In [7]:
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import seaborn as sns

# Set a modern seaborn style
sns.set_theme(style="whitegrid", font_scale=1.2)

# Optional: use a specific color palette
colors = sns.color_palette("colorblind")  # good for accessibility

def plot_results(results_list):
    for idx, (function_type, function_list) in enumerate(results_list.items()):
        plt.figure(figsize=(8, 7.5))  # bigger, cleaner layout

        color_idx = 0
        for function in function_list:
            for function_name, result in function.items():
                plt.plot(
                    vector_sizes,
                    result,
                    label=function_name.replace('_', ' ').upper(),
                    color=colors[color_idx % len(colors)],
                    linewidth=2.5,
                    marker='o',
                    markersize=5,
                )
                color_idx += 1

        plt.xscale('log', base=2)
        plt.yscale('log')

        # Log ticks with base-2 labels
        plt.gca().xaxis.set_major_formatter(ticker.FuncFormatter(lambda x, _: f"$2^{{{int(np.log2(x))}}}$"))

        plt.xlabel("Vector Size (log scale)", labelpad=10)
        plt.ylabel("Time (s) (log scale, descending)", labelpad=10)
        plt.gca().invert_yaxis() 
        plt.title(f"Average time to compute {function_type} distance between two random vectors", fontsize=16, weight="bold")

        plt.legend(title="Implementation", loc="best", frameon=True)
        plt.tight_layout()
        plt.grid(True, which='both', linestyle='--', linewidth=0.5)
                # Add caption below the plot
        plt.figtext(
            0.5, -0.12,
            "In the case of the GPU accelerated libraries, these timings are inclusive of the memory transfer in the GPU. "
            "As we can see here, CPU performance is better at lower dimensions and scales similarly with CuPy and Triton as the memory increases.\n\n"
            "We note here that this is because there is only one distance calculation being carried out and, despite parallelising across segments within the vectors "
            "and reducing these partial sums, the memory overhead involved means that there is no significant benefit from utilizing the GPU for a single distance calculation.",
            wrap=True,
            ha="center",
            fontsize=10
        )
        plt.show()
plot_results(results_list)

NameError: name 'vector_sizes' is not defined

<Figure size 800x750 with 0 Axes>