# KNN Optimisation

We will explore how to optimise performance using [JAX](https://docs.jax.dev) and [Numba](https://numba.pydata.org/).
Our goal will be to optimise the performance of a k-nearest neighbours (kNN) search.
We will not focus on the algorithm itself, or the data, but rather on the performance of the code.

In [2]:
import numpy as np
import pandas as pd

import plotly.express as px
import plotly.graph_objects as go


## kNN Search Algorithm

We will use the kNN search (and later the kNN regressor) as a running example.

### kNN Search

The kNN search is a simple algorithm that finds the k nearest neighbours of a query point in a dataset.
It can be described as follows:

1. Calculate the distance between the query point and all points in the dataset.
2. Sort the distances and find the k smallest distances.
3. Return the indices of the k smallest distances.
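
The three steps above can be sketched for a single query point as follows. This is a minimal illustration only: the name `knn_single_query` and the toy data are hypothetical, and a full sort via `np.argsort` is used here for simplicity.

```python
import numpy as np


def knn_single_query(query_point: np.ndarray, dataset: np.ndarray, k: int) -> np.ndarray:
    # Step 1: distance from the query point to every dataset point
    distances = np.sqrt(np.sum((dataset - query_point) ** 2, axis=1))
    # Steps 2 and 3: sort the distances and return the indices of the k smallest
    return np.argsort(distances)[:k]


# Hypothetical toy data: four points in 2D
toy_dataset = np.array([[0.0, 0.0], [1.0, 1.0], [5.0, 5.0], [0.5, 0.0]])
nearest = knn_single_query(np.array([0.1, 0.0]), toy_dataset, k=2)
print(nearest)
```

A full sort does more work than necessary when only the k smallest distances are needed, which is one reason to consider a partial sort instead.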

### Numpy implementation

In [3]:
def euclidean_distances_numpy(query_points: np.ndarray, dataset: np.ndarray) -> np.ndarray:
    """Calculate Euclidean distances between all query points and all dataset points using NumPy.

    We consider the input arrays to be in the shape of (n_points, n_dimensions).

    Args:
        query_points (np.ndarray): Query points (2D array).
        dataset (np.ndarray): Dataset of reference points (2D array).

    Returns:
        np.ndarray: Euclidean distances of shape (n_dataset_points, n_query_points).
    """
    return np.sqrt(np.sum((dataset[:, np.newaxis, :] - query_points) ** 2, axis=-1))


def knn_search_numpy(
    query_points: np.ndarray,
    dataset: np.ndarray,
    k: int,
) -> np.ndarray:
    """
    Finds the k nearest neighbors of each query point using NumPy.

    Args:
        query_points (np.ndarray): Query points (2D array).
        dataset (np.ndarray): Dataset of reference points (2D array).
        k (int): The number of neighbors to find.

    Returns:
        np.ndarray: Indices of the k nearest neighbors in the dataset, shape (n_query_points, k).
    """
    distances = euclidean_distances_numpy(query_points, dataset)

    # Find the indices of the k smallest distances
    nearest_indices = np.argpartition(distances, k, axis=0)[:k].T

    return nearest_indices

### Helper plotting function

In [4]:
def visualise_knn(query_points: np.ndarray, dataset: np.ndarray, neighbours: np.ndarray) -> go.Figure:
    # Plot all data points
    fig = go.Figure()
    fig.add_trace(
        go.Scatter(
            x=dataset[:, 0],
            y=dataset[:, 1],
            mode="markers",
            marker=dict(size=8, color="lightgrey"),
            name="Dataset Points",
        )
    )

    show_legend = True
    for query_point, point_neighbours in zip(query_points, neighbours):
        # Plot the query point
        fig.add_trace(
            go.Scatter(
                x=[query_point[0]],
                y=[query_point[1]],
                mode="markers",
                marker=dict(size=14, color="red", symbol="x"),
                name="Query Point",
                showlegend=show_legend,
            )
        )

        # Draw lines from query point to each neighbour for clarity
        for neighbour in point_neighbours:
            fig.add_trace(
                go.Scatter(
                    x=[query_point[0], neighbour[0]],
                    y=[query_point[1], neighbour[1]],
                    mode="lines",
                    line=dict(color="royalblue", dash="dot"),
                    showlegend=False,
                )
            )

        # Plot the k nearest neighbours
        fig.add_trace(
            go.Scatter(
                x=point_neighbours[:, 0],
                y=point_neighbours[:, 1],
                mode="markers",
                marker=dict(size=12, color="royalblue", symbol="circle-open"),
                name="Nearest Neighbours",
                showlegend=show_legend,
            )
        )

        show_legend = False

    fig.update_layout(
        title="kNN Search Visualisation",
        xaxis_title="Sepal Width",
        yaxis_title="Sepal Length",
        legend=dict(yanchor="top", y=0.99, xanchor="left", x=0.01),
        width=700,
        height=500,
        # Ensure axes have identical scales for accurate spatial interpretation
        xaxis=dict(
            scaleanchor="y",  # Link x-axis scale to y-axis
            scaleratio=1,     # 1:1 aspect ratio
        ),
        yaxis=dict(
            constrain="domain",  # Prevent stretching of y-axis
        ),
    )

    return fig


### Demonstration

Let's demonstrate the kNN search algorithm on a simple 2D dataset. We will use the popular Iris dataset, and select just the first two features.

In [5]:
# Load a simple 2D dataset (Iris, first two features)
dataset = px.data.iris()[["sepal_width", "sepal_length"]].to_numpy()

# Select query points (arbitrary points, not necessarily in the dataset)
query_points = np.array([[2.5, 6], [3.5, 7.0]])

# Find k nearest neighbours using the previously defined function
neighbour_indices = knn_search_numpy(query_points, dataset, k=3)
neighbours = dataset[neighbour_indices]

print(f"Nearest neighbour indices: {neighbour_indices}")
print(f"Nearest neighbours of {query_points}:\n{neighbours}")

Nearest neighbour indices: [[134  83  92]
 [109  50 120]]
Nearest neighbours of [[2.5 6. ]
 [3.5 7. ]]:
[[[2.6 6.1]
  [2.7 6. ]
  [2.6 5.8]]

 [[3.6 7.2]
  [3.2 7. ]
  [3.2 6.9]]]


In [6]:
visualise_knn(query_points, dataset, neighbours)

### Performance

Let's measure the performance (run time) of our algorithm. We will use the `%timeit` magic command to measure the time it takes to run the algorithm.

In [7]:
n_dataset_points: int = 10_000
n_query_points: int = 100
n_dim: int = 3
k: int = 5


def create_random_data(
    n_points: int, n_dim: int, *, seed: int = 42
) -> np.ndarray:
    np.random.seed(seed)
    return np.random.sample((n_points, n_dim)).astype(np.float32)

dataset = create_random_data(n_dataset_points, n_dim, seed=420)


In [8]:
execution_times = []

In [9]:
for n_query_points in [100, 1000, 10000]:
    query_points = create_random_data(n_query_points, n_dim)
    execution_time = %timeit -o knn_search_numpy(query_points, dataset, k=k)
    execution_times.append(
        {
            "n_query_points": n_query_points,
            "n_dataset_points": n_dataset_points,
            "n_dim": n_dim,
            "k": k,
            "execution_time": execution_time.average,
            "function": "knn_search_numpy",
        }
    )


29.4 ms ± 102 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)
313 ms ± 1.32 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
3.51 s ± 43.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [10]:
for n_query_points in [100, 1000, 10000]:
    query_points = create_random_data(n_query_points, n_dim)
    execution_time = %timeit -o euclidean_distances_numpy(query_points, dataset)
    execution_times.append(
        {
            "n_query_points": n_query_points,
            "n_dataset_points": n_dataset_points,
            "n_dim": n_dim,
            "execution_time": execution_time.average,
            "function": "euclidean_distances_numpy",
        }
    )


19.8 ms ± 204 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
202 ms ± 2.79 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
2.03 s ± 8.73 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [11]:
px.line(
    execution_times,
    x="n_query_points",
    y="execution_time",
    title="Execution Time: kNN Search",
    labels={"n_query_points": "Number of Query Points", "execution_time": "Execution Time (s)"},
    log_x=True,
    log_y=True,
    markers=True,
    color="function",
)

This is what we are interested in:
- Measure the run time of a specific function.
- See how it scales with the size of the input.
- Compare to other functions. Here, we also time the sub-function `euclidean_distances_numpy`, which accounts for more than half of the total run time. We would usually use a profiler, such as py-spy, for this, but the simple approach is sufficient and illustrative here.
- Later, we will compare to other implementations.
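
Outside a notebook, where the `%timeit` magic is not available, a similar measurement can be made with the standard-library `timeit` module. The following is a minimal sketch; the helper name `mean_run_time` and its default arguments are assumptions made for illustration.

```python
import timeit

import numpy as np


def mean_run_time(func, *args, repeat: int = 5, number: int = 10) -> float:
    """Mean seconds per call, averaged over `repeat` runs of `number` calls each."""
    timer = timeit.Timer(lambda: func(*args))
    run_totals = timer.repeat(repeat=repeat, number=number)
    return float(np.mean(run_totals)) / number


# Hypothetical workload: norm of 1,000 random 3D points
points = np.random.default_rng(0).random((1_000, 3))
seconds = mean_run_time(np.linalg.norm, points)
print(f"{seconds * 1e6:.1f} µs per call")
```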

## JAX Numpy and JIT

[JAX Quickstart](https://docs.jax.dev/en/latest/quickstart.html):
> JAX is a library for array-oriented numerical computation (à la NumPy), with automatic differentiation and JIT compilation to enable high-performance machine learning research.



JAX provides a NumPy-compatible API (`jax.numpy`, often imported as `jnp`) that allows users to write array-based scientific code using familiar NumPy syntax. Unlike standard NumPy, JAX operations are designed to run efficiently on CPUs, GPUs, and TPUs, enabling hardware acceleration for numerical computations.

A key feature of JAX is its Just-In-Time (JIT) compilation, accessed via the `jax.jit` decorator or function. JIT compilation automatically transforms Python functions into highly optimised machine code, fusing operations and reducing Python overhead. This results in substantial performance improvements, especially for large-scale or repeated computations.

By combining the JAX NumPy API with JIT compilation, we can write clear, concise scientific code that is automatically optimised for modern hardware.


### Practical example

Let's demonstrate this with a simple example from the JAX documentation:

In [57]:
import jax
import jax.numpy as jnp
from jax.typing import ArrayLike

def selu(x: ArrayLike, alpha: float = 1.67, lmbda: float = 1.05) -> jax.Array:
    return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)


x = jnp.arange(5.0)
print(selu(x))

[0.        1.05      2.1       3.1499999 4.2      ]


Note that we use `jnp.arange` instead of `np.arange` in the JAX version. Here, we created the array from scratch as a JAX array. Often, we would convert from a NumPy array instead.

Conversion from NumPy to JAX arrays can be efficient: if the target JAX device is the CPU and the NumPy array has a compatible dtype and memory layout, JAX may be able to wrap the existing memory without copying. If the array is not compatible (e.g., unsupported dtype, not C-contiguous, or a different target device), JAX makes a copy. Zero-copy conversion is therefore possible but not guaranteed; keep arrays C-contiguous and of a supported dtype for the best efficiency.
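
As a small illustration of the conversion in both directions, the sketch below uses `jnp.asarray` and `np.asarray`. The array contents are arbitrary, and whether a copy is made depends on the backend and memory layout, so this does not demonstrate zero-copy behaviour itself.

```python
import jax
import jax.numpy as jnp
import numpy as np

# A C-contiguous float32 NumPy array (arbitrary example data)
points_np = np.arange(6, dtype=np.float32).reshape(3, 2)

points_jax = jnp.asarray(points_np)  # NumPy -> JAX (may or may not copy)
round_trip = np.asarray(points_jax)  # JAX -> NumPy

print(type(points_jax).__name__, points_jax.dtype, points_jax.shape)
print(np.array_equal(points_np, round_trip))
```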











In [13]:
key = jax.random.key(1701)
x = jax.random.normal(key, (1_000_000,))

%timeit selu(x)

1.34 ms ± 35 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [14]:
selu_jit = jax.jit(selu)

print(selu_jit(x)[:3])

[-0.83556366  0.33142313 -0.9244633 ]


Two important things happened above:

1. With `jax.jit`, we requested Just-In-Time (JIT) compilation of the function, which happens when the function is first called.
2. The function *was* compiled during the `print` call, for the *concrete input type and shape*.
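
One way to observe when JAX traces (and therefore compiles) a function is to add a Python side effect, which runs only during tracing. The sketch below is an illustration under that assumption; `scaled_sum` and `trace_log` are hypothetical names.

```python
import jax
import jax.numpy as jnp

trace_log = []  # filled only while JAX traces the function


def scaled_sum(x):
    trace_log.append(x.shape)  # Python side effect: not part of the compiled code
    return jnp.sum(x * 2.0)


scaled_sum_jit = jax.jit(scaled_sum)

scaled_sum_jit(jnp.ones(4))  # first call: traced and compiled for shape (4,)
scaled_sum_jit(jnp.ones(4))  # same shape and dtype: the cached version is reused
scaled_sum_jit(jnp.ones(5))  # new shape: traced and compiled again

print(trace_log)
```

The log contains one entry per distinct input shape, not one per call.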







In [15]:
%timeit selu_jit(x).block_until_ready()

407 μs ± 4.95 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


The compiled function is significantly faster than the uncompiled one: about 407 μs versus 1.34 ms in the measurements above.

Note the use of `block_until_ready`. *JAX is asynchronous by default*: operations return immediately with something like a future, while the computation may still be running. To time the whole calculation, we need to wait for the result to be ready.
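
The effect of asynchronous dispatch can be sketched with manual timing. This is an illustration only: the matrix size is arbitrary, the first call also includes compilation, and the difference between the two timings depends on the backend.

```python
import time

import jax.numpy as jnp

x = jnp.ones((2_000, 2_000))

start = time.perf_counter()
y = jnp.dot(x, x)  # may return before the computation has finished
dispatch_time = time.perf_counter() - start

y.block_until_ready()  # wait until the result is actually available
total_time = time.perf_counter() - start

print(f"dispatch: {dispatch_time:.6f} s, total: {total_time:.6f} s")
```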

We can also empirically verify that the compilation must be done again for a different input shape, leading to an increase in the execution time.

In [16]:
%timeit -n 1 -r 1 selu_jit(x[:-1]).block_until_ready()

87.7 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)


### Exercise: JIT compiling Euclidean distance

1. Create a `jax.numpy` version of the Euclidean distance function `euclidean_distances_numpy`.
2. Create a JIT-compiled version of the new function.
3. Compare the performance of the numpy version, and the uncompiled and compiled jax version.
4. Check that the JAX version yields the same result as the numpy version.

Optionally:

5. Compare the scaling of the performance of all the versions with respect to the number of query points or the number of dimensions.

In [4]:
# %%writefile euclidean_distances_no_jit.py

import jax
import jax.numpy as jnp


def euclidean_distances_jax(
    query_points: jnp.ndarray, dataset: jnp.ndarray
) -> jnp.ndarray:
    """
    Calculates the Euclidean distance between a set of query points and a dataset of points.

    Args:
        query_points (jnp.ndarray): Array of shape (n_queries, n_features).
        dataset (jnp.ndarray): Array of shape (n_samples, n_features).

    Returns:
        jnp.ndarray: Euclidean distances of shape (n_samples, n_queries).
    """
    # Broadcasting: dataset[:, jnp.newaxis, :] has shape (n_samples, 1, n_features), so subtracting
    # query_points of shape (n_queries, n_features) yields (n_samples, n_queries, n_features)
    return jnp.sqrt(jnp.sum((dataset[:, jnp.newaxis, :] - query_points) ** 2, axis=-1))


euclidean_distances_jax_jit = jax.jit(euclidean_distances_jax)

In [18]:
n_dataset_points: int = 10_000
n_query_points: int = 100
n_dim: int = 3
k: int = 5


dataset = create_random_data(n_dataset_points, n_dim, seed=420)
query_points = create_random_data(n_query_points, n_dim, seed=421)

dataset_jax = jnp.array(dataset)
query_points_jax = jnp.array(query_points)

np.testing.assert_allclose(euclidean_distances_numpy(query_points, dataset), euclidean_distances_jax(query_points_jax, dataset_jax))
np.testing.assert_allclose(euclidean_distances_numpy(query_points, dataset), euclidean_distances_jax_jit(query_points_jax, dataset_jax), rtol=1e-6)

In [19]:
%timeit euclidean_distances_numpy(query_points, dataset)
%timeit euclidean_distances_jax(query_points_jax, dataset_jax).block_until_ready()
%timeit euclidean_distances_jax_jit(query_points_jax, dataset_jax).block_until_ready()

19.8 ms ± 178 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)
1.7 ms ± 141 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
246 μs ± 2.22 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [20]:
euclidean_execution_times = []
euclidean_distances_jax_jit.__name__ = "euclidean_distances_jax_jit"

for function in [euclidean_distances_numpy, euclidean_distances_jax, euclidean_distances_jax_jit]:
    for n_dim in [2, 4]:
        dataset = create_random_data(n_dataset_points, n_dim)
        for n_query_points in [1, 100, 1_000, 10_000]:
            query_points = create_random_data(n_query_points, n_dim)
            if "numpy" in function.__name__:
                execution_time = %timeit -o function(query_points, dataset)
            else:
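                # Warm-up call: triggers JIT compilation so that it is excluded from the timing.
                # Note that the NumPy inputs are converted to JAX arrays on every call.
                function(query_points, dataset).block_until_ready()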
                execution_time = %timeit -o function(query_points, dataset).block_until_ready()
            euclidean_execution_times.append(
                {
                    "n_query_points": n_query_points,
                    "n_dataset_points": n_dataset_points,
                    "n_dim": n_dim,
                    "execution_time": execution_time.average,
                    "function": function.__name__,
                }
            )


119 μs ± 6.17 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
14.5 ms ± 160 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
136 ms ± 1.03 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
1.39 s ± 4.81 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
177 μs ± 7.34 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
17.4 ms ± 138 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
172 ms ± 1.42 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
1.76 s ± 11.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
93 μs ± 549 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
7.67 ms ± 56.5 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
72.6 ms ± 362 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)
769 ms ± 11.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
150 μs ± 2.95 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
10.9 ms ± 30.8 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
102 ms

In [21]:
px.line(
    euclidean_execution_times,
    x="n_query_points",
    y="execution_time",
    title="Execution Time: Euclidean Distances",
    labels={"n_query_points": "Number of Query Points", "execution_time": "Execution Time (s)"},
    log_x=True,
    log_y=True,
    markers=True,
    color="function",
    facet_row="n_dim",
)

## Numba JIT compilation

### What is Numba?

The [Numba documentation](https://numba.pydata.org) says:
> Numba is an open source JIT compiler that translates a subset of Python and NumPy code into fast machine code.

- Numba is a powerful just-in-time (JIT) compiler for Python that specialises in accelerating numerical and scientific computations.
- It compiles Python functions to optimised machine code at runtime, dramatically speeding up array operations and mathematical algorithms.
- This often allows us to achieve near C-level speeds for data science tasks, while maintaining the readability and flexibility of Python.

### Differences between Numba and JAX

- Numba does not need to recompile for a different input shape.
- Numba works with numpy arrays directly.
- Numba supports ahead-of-time compilation.
- Numba often requires the code to be written in a specific way, typically using loops.
- Numba does not provide automatic differentiation.
- GPU programming in Numba is more granular and lower-level.


### Simple example

In [22]:
import numba

x = np.arange(100).reshape(10, 10)

def go_fast(a: np.ndarray) -> np.ndarray:
    trace = 0.0
    for i in range(a.shape[0]):   # Numba likes loops
        trace += np.tanh(a[i, i]) # Numba likes NumPy functions
    return a + trace              # Numba likes NumPy broadcasting

# The function is compiled to machine code when the JIT-wrapped version is first called
go_fast_jit = numba.jit(go_fast)

print(go_fast_jit(x))

[[  9.  10.  11.  12.  13.  14.  15.  16.  17.  18.]
 [ 19.  20.  21.  22.  23.  24.  25.  26.  27.  28.]
 [ 29.  30.  31.  32.  33.  34.  35.  36.  37.  38.]
 [ 39.  40.  41.  42.  43.  44.  45.  46.  47.  48.]
 [ 49.  50.  51.  52.  53.  54.  55.  56.  57.  58.]
 [ 59.  60.  61.  62.  63.  64.  65.  66.  67.  68.]
 [ 69.  70.  71.  72.  73.  74.  75.  76.  77.  78.]
 [ 79.  80.  81.  82.  83.  84.  85.  86.  87.  88.]
 [ 89.  90.  91.  92.  93.  94.  95.  96.  97.  98.]
 [ 99. 100. 101. 102. 103. 104. 105. 106. 107. 108.]]


In [23]:
%timeit go_fast(x)
%timeit go_fast_jit(x)

7.52 μs ± 73.2 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
545 ns ± 6.67 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)


Important in our context:

- Numpy arrays don't need to be converted to a different type.
- Numba "likes loops" - we may need to rewrite the code to use loops instead of vectorised operations.


### Exercise: Numba JIT compilation

1. Create a Numba JIT-compiled version of the Euclidean distance function `euclidean_distances_numpy`.
2. Verify that the JIT-compiled version yields the same result as the numpy version.
3. Compare the performance of the numpy version, and the uncompiled and compiled numba version.
4. Check that JIT does not depend on the input shape by comparing the performance on a `[:-1]` slice of the input.


In [24]:
euclidean_distances_numba_jit = numba.jit(euclidean_distances_numpy)

np.testing.assert_allclose(euclidean_distances_numpy(query_points, dataset), euclidean_distances_numba_jit(query_points, dataset))

In [25]:
%timeit euclidean_distances_numpy(query_points, dataset)
%timeit euclidean_distances_numba_jit(query_points, dataset)

1.8 s ± 48.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
1.12 s ± 9.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [30]:
%timeit euclidean_distances_numba_jit(query_points[:-1], dataset)

1.12 s ± 7.88 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


### Numba performance tuning

The following are important options and techniques for optimising Numba code, as summarised from the official [Numba performance tips](https://numba.readthedocs.io/en/stable/user/performance-tips.html):

1. Use `@numba.njit` (no Python mode):
   - Prefer `@numba.njit` (or `@numba.jit(nopython=True)`) to ensure code is compiled in "nopython" mode, avoiding the Python interpreter for maximum speed.

2. Enable parallel execution:
   - Use `parallel=True` in the decorator (e.g., `@numba.njit(parallel=True)`) to allow Numba to automatically parallelise supported loops using `numba.prange`.
   - Replace `range` with `numba.prange` in outer loops to enable parallelism.

3. Enable fast math optimisations:
   - Use `fastmath=True` to allow the compiler to apply aggressive floating-point optimisations, potentially sacrificing some numerical precision for speed.
   - Example: `@numba.njit(fastmath=True)`

4. Prefer simple, explicit loops:
   - Numba excels with explicit for-loops and simple array operations.
   - Avoid complex Python features, object arrays, or unsupported NumPy functions.

5. Minimise Python object usage:
   - Use NumPy arrays and primitive types; avoid lists of objects or dictionaries inside JIT-compiled functions.

6. Preallocate arrays:
   - Allocate output arrays before entering loops to avoid dynamic resizing, which is slow and not supported in nopython mode.

7. Use supported NumPy functions:
   - Stick to NumPy functions and methods that are supported by Numba for best performance.



### Exercise: Numba performance tuning

1. First, try `njit`, `fastmath=True`, and `parallel=True` on the `euclidean_distances_numpy` function. Measure the performance.
2. Try the same options on the `euclidean_distances_numba_optimised` defined below. Compare the performance to the previous versions.
3. Try to optimise the `euclidean_distances_numba_optimised` for `parallel=True` compilation.


In [26]:
euclidean_distances_numba_perf = numba.njit(fastmath=True, parallel=False)(euclidean_distances_numpy)
# parallel=True does not work
# there seems to be only a very small speed up from fastmath

np.testing.assert_allclose(euclidean_distances_numpy(query_points, dataset), euclidean_distances_numba_perf(query_points, dataset))

%timeit euclidean_distances_numba_perf(query_points, dataset)

1.09 s ± 7.82 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [27]:
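def euclidean_distances_numba_optimised(query_points: np.ndarray, dataset: np.ndarray) -> np.ndarray:
    """Loop-based Euclidean distances; returns an array of shape (n_dataset_points, n_query_points)."""
    M = query_points.shape[0]  # number of query points
    N = dataset.shape[0]       # number of dataset points
    D = query_points.shape[1]  # number of dimensions
    # Preallocate the output so that nothing is allocated inside the loops
    distances = np.empty((N, M), dtype=dataset.dtype)
    for i in range(N):
        for j in range(M):
            d = 0.0
            for k in range(D):
                tmp = dataset[i, k] - query_points[j, k]
                d += tmp * tmp
            distances[i, j] = np.sqrt(d)
    return distances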


In [28]:
euclidean_distances_numba_optimised_jit = numba.njit(fastmath=True, parallel=False)(euclidean_distances_numba_optimised)

np.testing.assert_allclose(euclidean_distances_numpy(query_points, dataset), euclidean_distances_numba_optimised_jit(query_points, dataset), rtol=1e-6)

%timeit euclidean_distances_numba_optimised_jit(query_points, dataset)

248 ms ± 323 μs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [29]:
dataset_jax = jnp.array(dataset)
query_points_jax = jnp.array(query_points)

%timeit euclidean_distances_jax_jit(query_points_jax, dataset_jax).block_until_ready()

20.5 ms ± 157 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)


## JAX on GPU

Open the [03-jax-gpu.ipynb](https://colab.research.google.com/github/coobas/europython-25/blob/main/03-jax-gpu.ipynb) notebook.

## Further optimisation of the kNN search function

So far, we have focused on the distance computation, because it is the most time-consuming part of the kNN search. However, let's also try to optimise the nearest neighbour search itself, i.e. `knn_search_numpy`.

- We can still get some speed-up by using JIT compilation.
- We can avoid possibly memory-costly data copies between NumPy and JAX arrays.
- And we can possibly run everything on GPU.

### Exercise: JIT compilation of `knn_search_numpy`

1. Create a JIT-compiled version of `knn_search_numpy`.
2. Compare the performance of the JIT-compiled version to the original version.
3. Verify the outputs of the JIT-compiled version match the outputs of `knn_search_numpy`. Note that the order of the indices does not matter.
4. Try to run it on GPU and compare the performance.

*Hints:* You will most likely run into issues. Look into how the `static_argnames` parameter of `jax.jit` works. You will also need to replace `np.argpartition`: there are alternatives in either `jax.numpy` or `jax.lax`.
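
To illustrate the two hints in isolation, the sketch below marks `k` as static and uses `jax.lax.top_k` on negated values to pick the smallest entries of a 1D array. The function `smallest_k` is a hypothetical example, not the exercise solution.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames=["k"])
def smallest_k(values, k):
    # top_k returns the k largest entries, so negate the input to get the k smallest
    neg_values, indices = jax.lax.top_k(-values, k)
    return -neg_values, indices


values = jnp.array([3.0, 1.0, 4.0, 1.5, 9.0])
smallest, indices = smallest_k(values, k=2)
print(smallest, indices)
```

Because `k` determines the output shape, it has to be known at compile time. Marking it as static means the function is recompiled for each distinct value of `k`.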

In [34]:
def knn_search_jax(
    query_points: jnp.ndarray,
    dataset: jnp.ndarray,
    k: int,
) -> jnp.ndarray:
    """
    Finds the k nearest neighbors of each query point using JAX.

    Args:
        query_points (jnp.ndarray): Array of shape (n_queries, n_features).
        dataset (jnp.ndarray): Array of shape (n_samples, n_features).
        k (int): The number of neighbors to find.

    Returns:
        jnp.ndarray: Array of shape (n_queries, k) containing the indices of the k nearest neighbors in the dataset.
    """
    # This should work for multiple query points - could be an exercise for the workshop
    distances = euclidean_distances_jax_jit(query_points, dataset)

    # Find the indices of the k smallest distances.
    # jax.lax.top_k returns the k largest entries along the last axis, hence the negation and transpose.
    _, nearest_indices = jax.lax.top_k(-distances.T, k)
    # nearest_indices = jnp.argpartition(distances, k, axis=0)[:k].T
    return nearest_indices


knn_search_jax_jit = jax.jit(knn_search_jax, static_argnames=["k"])


In [38]:
np.testing.assert_array_equal(
    np.sort(knn_search_jax_jit(query_points, dataset, k=5), axis=1),
    np.sort(knn_search_numpy(query_points, dataset, k=5), axis=1)
)


## JAX vmap

`vmap` (vectorising map) is JAX's automatic vectorisation transformation that allows you to apply functions designed for single inputs to batches of inputs efficiently.

- Transforms functions that work on single values to work on batches without manual loop writing or vectorisation.
- Generates efficient vectorised code that can leverage SIMD instructions and GPU parallelism.
- Automatically handles broadcasting and dimension management across batch dimensions.
- Can be combined with other JAX transformations like `jit`.
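
As an illustration of the last point, `vmap` and `jit` can be composed directly. In the sketch below, `norm` is a hypothetical function written for a single vector.

```python
import jax
import jax.numpy as jnp


def norm(vector):
    # Written for a single 1D vector
    return jnp.sqrt(jnp.sum(vector**2))


# Vectorise over the leading (batch) axis, then JIT-compile the batched function
batched_norm = jax.jit(jax.vmap(norm))

batch = jnp.array([[3.0, 4.0], [6.0, 8.0]])
norms = batched_norm(batch)
print(norms)
```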


### `vmap` example

Let's take a simple function `sum_of_squares`, defined as $f(x) = \sum_{i=1}^n x_i^2$.

First, we implement the function in a way that it expects a single vector as an input.

In [60]:
def sum_of_squares(vector: ArrayLike) -> jax.Array:
  # This function expects a 1D array (vector)
  print(f"Running sum_of_squares for a vector of shape: {vector.shape}")
  return jnp.sum(vector**2)

# Example single vector
single_vector = jnp.array([1., 2., 3.])
result_single = sum_of_squares(single_vector)
print(f"Result for single vector: {result_single}")

Running sum_of_squares for a vector of shape: (3,)
Result for single vector: 14.0


What if we now want to compute the sum of squares for a batch of vectors? Let's try to just execute the function for a batch of vectors:

In [62]:
batch_of_vectors = jnp.array([
    [1., 2., 3.],
    [4., 5., 6.],
    [7., 8., 9.],
    [0., 1., 0.]
])

sum_of_squares(batch_of_vectors)

Running sum_of_squares for a vector of shape: (4, 3)


Array(286., dtype=float32)

This is not what we wanted! We actually needed to calculate the result for each vector in the batch. We could just loop but that would most likely be slow.

Luckily, JAX provides a way to do this automatically. We can use the `vmap` function to vectorise the function.

In [63]:
vectorized_sum_of_squares = jax.vmap(sum_of_squares)

vectorized_sum_of_squares(batch_of_vectors)

Running sum_of_squares for a vector of shape: (3,)


Array([ 14.,  77., 194.,   1.], dtype=float32)

### Exercise: Use `vmap` to vectorise the Euclidean distance calculation

We defined `euclidean_distances_numpy` and `euclidean_distances_jax` already in a vectorised way. This was possible thanks to the broadcasting of numpy arrays.

In this exercise, let's start from a simple `distance_scalar` function, which works on two vectors.
The goal is to vectorise it using `vmap`, so that `x` and `y` can be `(m, n_dim)` and `(n, n_dim)` arrays.
The result should be an `(m, n)` array. This is the same behaviour as in `euclidean_distances_numpy` and `euclidean_distances_jax`, with `x` in the role of `dataset` and `y` in the role of `query_points`.

*Hints:* You will need to use the `in_axes` parameter of `vmap`. You may need to use `vmap` twice.

Optionally, compare the performance of the vmap version to the `euclidean_distances_jax` function.

In [46]:
def distance_scalar(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
    # return jnp.sqrt(jnp.sum((x - y)**2, axis=-1))
    return jnp.sqrt(jnp.sum((x - y)**2))


distance_vmap = jax.vmap(distance_scalar, in_axes=(None, 0))
distance_vmap_2 = jax.vmap(distance_vmap, in_axes=(0, None))


x_test = jnp.array([[1, 2, 3], [4, 5, 6]])
y_test = jnp.array([[4, 5, 5], [1, 2, 3], [7, 8, 9], [10, 11, 12]])

distance_vmap(x_test, y_test)

Array([ 4.7958317,  5.196152 , 11.61895  , 18.734995 ], dtype=float32)

In [47]:
x_test.shape

(2, 3)

In [48]:
distance_vmap_2(x_test, y_test).shape

(2, 4)

In [29]:
distance_scalar(x_test[0], y_test)

Array([4.690416, 0.      ], dtype=float32)

In [30]:
distance_vmap(x_test, y_test)

Array([[4.690416, 0.      ],
       [1.      , 5.196152]], dtype=float32)

In [27]:
jnp.sum((x_test[0] - y_test)**2, axis=-1)

Array([22,  0], dtype=int32)

In [45]:
def euclidean_distance(point_a, point_b):
  """Calculates the Euclidean distance between two points."""
  # Assumes points are 1D vectors
  return jnp.sqrt(jnp.sum((point_a - point_b)**2))

# Batch A: 3 points in 2D space
points_a = jnp.array([[1., 0.],
                      [0., 1.],
                      [-1., 0.]]) # Shape (3, 2)

# Batch B: 4 points in 2D space
points_b = jnp.array([[2., 2.],
                      [-2., 2.],
                      [2., -2.],
                      [-2., -2.]]) # Shape (4, 2)

# Goal: Compute a 3x4 matrix where entry (i, j) is distance(points_a[i], points_b[j])

# Inner map: Compute distance between ONE point from A and ALL points in B
# Map over points_b (axis 0), keep point_a fixed (None)
inner_map = jax.vmap(euclidean_distance, in_axes=(None, 0))
# inner_map(points_a[0], points_b) would compute distances from points_a[0] to all points_b

# Outer map: Apply inner_map to EACH point in A
# Map over points_a (axis 0), pass the whole points_b batch (None axis) to the inner map
pairwise_distance = jax.vmap(inner_map, in_axes=(0, None))

# Equivalent direct definition:
pairwise_distance_direct = jax.vmap(
    jax.vmap(euclidean_distance, in_axes=(None, 0)), # Map over B for fixed A
    in_axes=(0, None)                                # Map over A, passing full B batch
)

distance_matrix = pairwise_distance(points_a, points_b)
distance_matrix_direct = pairwise_distance_direct(points_a, points_b)

print("Shape of distance matrix:", distance_matrix.shape)
# Expected output: Shape of distance matrix: (3, 4)

print("Distance Matrix:\n", distance_matrix)

# Verify one element manually: distance(points_a[0], points_b[0])
manual_dist_00 = euclidean_distance(points_a[0], points_b[0]) # sqrt((1-2)^2 + (0-2)^2) = sqrt(5) ≈ 2.236
print("Manual dist[0, 0]:", manual_dist_00)
print("Matrix element [0, 0]:", distance_matrix[0, 0])
# Expected output: Manual dist[0, 0]: 2.236068, Matrix element [0, 0]: 2.236068

Shape of distance matrix: (3, 4)
Distance Matrix:
 [[2.236068  3.6055512 2.236068  3.6055512]
 [2.236068  2.236068  3.6055512 3.6055512]
 [3.6055512 2.236068  3.6055512 2.236068 ]]
Manual dist[0, 0]: 2.236068
Matrix element [0, 0]: 2.236068
