# Gradients and derivatives with JAX

## JAX can compute derivatives of functions

- JAX can automatically compute gradients of scalar-valued functions using `jax.grad`, enabling efficient optimisation and machine learning workflows.
- JAX uses automatic differentiation, so the derivatives are exact up to floating-point precision rather than finite-difference approximations.
- JAX also provides `jax.jacobian` to compute the full Jacobian matrix, which is essential for sensitivity analysis and advanced optimisation.
- Higher-order derivatives are supported via repeated application of `jax.grad`, or using `jax.hessian`.
- JAX's differentiation is based on composable, functional transformations, making it easy to combine with JIT compilation (`jax.jit`) for high performance.
- JAX's autodiff works seamlessly on CPUs, GPUs, and TPUs, enabling scalable scientific computing.

Automatic differentiation is essential for gradient-based optimisation, as we'll see in the location optimisation example below.
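
To make the points above concrete, here is a minimal sketch of `jax.jacobian`, `jax.hessian`, repeated `jax.grad`, and composition with `jax.jit`. The functions `g`, `h`, and `cube` are hypothetical examples chosen only because their derivatives are easy to verify by hand.

```python
import jax
import jax.numpy as jnp


def g(x):
    # Hypothetical vector-valued function R^2 -> R^2
    return jnp.array([x[0] ** 2 * x[1], jnp.sin(x[1])])


def h(x):
    # Hypothetical scalar-valued function R^2 -> R
    return jnp.sum(x ** 3)


def cube(t):
    return t ** 3


x = jnp.array([1.0, 2.0])

# Transformations compose: differentiate first, then JIT-compile the result
jac_g = jax.jit(jax.jacobian(g))
J = jac_g(x)  # analytically [[2*x0*x1, x0**2], [0, cos(x1)]]

# Full Hessian of a scalar-valued function
H = jax.hessian(h)(x)  # analytically diag(6 * x)

# Higher-order derivative by applying jax.grad twice
s = jax.grad(jax.grad(cube))(2.0)  # analytically 6 * t
```

Comparing `J`, `H`, and `s` with the analytical expressions in the comments is a quick way to convince yourself that the transformations behave as described.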


In [None]:
from IPython import get_ipython

is_colab = "google.colab" in str(get_ipython())

# Download the data which are part of this repo
if is_colab:
    import urllib.request
    url = "https://github.com/coobas/europython-25/raw/refs/heads/main/data.parquet"
    urllib.request.urlretrieve(url, "data.parquet")

In [None]:
import numpy as np
import jax.numpy as jnp
import jax
import pandas as pd
import plotly.express as px

## Basic usage


In [None]:
def f(x):
    return x**2 + 2*x + 1

# Create the gradient function
grad_f = jax.grad(f)

# Compute the gradient at a point
grad_f(3.0)  # Returns 8.0 (derivative of x²+2x+1 is 2x+2)

For functions with multiple arguments, use the `argnums` parameter:

In [None]:
def loss(w, b):
    return w**2 + b**2

# Gradient with respect to first argument (w)
grad_w = jax.grad(loss, argnums=0)

# Gradient with respect to both arguments
grad_both = jax.grad(loss, argnums=(0, 1))

grad_both(1.0, 2.0)  # Returns (2.0, 4.0)

### Exercise: Gradient of a function using jax.numpy

`jax.grad` works with functions built using `jax.numpy`. Let's compute the gradient of a
function involving `jnp.sum` and `jnp.sin`.

Consider $k(v) = \sum_i \sin(v_i)$, where $v$ is a vector. The gradient $\nabla k(v)$ is a vector where
the j-th element is $\frac{\partial k}{\partial v_j} = \cos(v_j)$.

1. Implement the function `k(v)`.
2. Construct the gradient `k_grad(v)` using `jax.grad`.
3. Test the gradient by comparing the result of `k_grad(v)` with the analytical gradient $\cos(v)$ and, optionally, with a numerical (finite-difference) approximation.
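
As a hint for the last step, the sketch below shows one possible way to check a JAX gradient against a central finite-difference approximation. It deliberately uses a different function, so the exercise itself is left open. The names `example_fn` and `finite_difference_grad` and the step size `eps` are assumptions made for this illustration, not part of JAX.

```python
import jax
import jax.numpy as jnp


def example_fn(v):
    # Hypothetical test function; its analytical gradient is 2 * v
    return jnp.sum(v ** 2)


def finite_difference_grad(fn, v, eps=1e-2):
    # Central differences along each coordinate direction.
    # eps is fairly large because JAX defaults to float32.
    v = jnp.asarray(v, dtype=jnp.float32)
    basis = jnp.eye(v.shape[0], dtype=v.dtype)
    return jnp.stack(
        [(fn(v + eps * e) - fn(v - eps * e)) / (2 * eps) for e in basis]
    )


v = jnp.array([1.0, 2.0, 3.0])
auto_grad = jax.grad(example_fn)(v)
numeric_grad = finite_difference_grad(example_fn, v)

# The two should agree up to the finite-difference and float32 rounding error
print(jnp.max(jnp.abs(auto_grad - numeric_grad)))
```

The tolerance you can expect depends on `eps` and on the floating-point precision, so compare with a loose tolerance rather than for exact equality.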


## Mini-project: Optimising location of a property

Say we would like to find the optimal location for buying a property that:
- Minimises the distance to some reference points
- Is within a given price range (not too expensive, not too cheap)

### Fundamental functions

Someone has provided price data and implemented useful functions for kNN regression, which we can use to estimate the price at any given location.

In [None]:
def euclidean_distances_jax(
    query_points: jnp.ndarray, dataset: jnp.ndarray
) -> jnp.ndarray:
    return jnp.sqrt(jnp.sum((dataset[:, jnp.newaxis, :] - query_points) ** 2, axis=-1))


euclidean_distances_jax_jit = jax.jit(euclidean_distances_jax)


def knn_search_jax(
    query_points: jnp.ndarray,
    dataset: jnp.ndarray,
    k: int,
) -> jnp.ndarray:
    distances = euclidean_distances_jax_jit(query_points, dataset)
    # top_k returns the largest entries, so negate the distances to get the k smallest
    _, nearest_indices = jax.lax.top_k(-distances.T, k)
    return nearest_indices


knn_search_jax_jit = jax.jit(knn_search_jax, static_argnames=["k"])
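
The search above relies on `jax.lax.top_k`, which returns the *largest* entries along the last axis; negating the distances turns this into a search for the smallest ones. A minimal sketch on toy data (the `toy_*` names and values are made up for illustration):

```python
import jax
import jax.numpy as jnp

# Four hypothetical 2-D data points and one query point
toy_dataset = jnp.array([[0.0, 0.0], [1.0, 0.0], [5.0, 5.0], [0.0, 2.0]])
toy_query = jnp.array([[0.9, 0.1]])

# Pairwise distances via broadcasting, shape (n_dataset, n_query)
toy_distances = jnp.sqrt(
    jnp.sum((toy_dataset[:, jnp.newaxis, :] - toy_query) ** 2, axis=-1)
)

# Transpose to (n_query, n_dataset) and negate, so that top_k
# picks the two smallest distances for each query point
_, toy_nearest = jax.lax.top_k(-toy_distances.T, 2)
print(toy_nearest)  # indices into toy_dataset, nearest first
```

Here the query lies closest to the point at index 1, followed by the point at index 0.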


In [None]:
def knn_mean_jax(query_points: jnp.ndarray, dataset: jnp.ndarray, values: jnp.ndarray, k: int) -> jnp.ndarray:
    """
    Predict target values for the provided data.

    Parameters
    ----------
    query_points : jnp.ndarray
        Array of shape (n_samples, n_features) containing the points for which predictions are required.
    dataset : jnp.ndarray
        Array of shape (n_train_samples, n_features) containing the training data.
    values : jnp.ndarray
        Array of shape (n_train_samples,) containing the values of the training data.
    k : int
        Number of nearest neighbours to consider for each prediction.

    Returns
    -------
    jnp.ndarray of shape (n_samples,)
        Predicted target values.
    """

    # Find k nearest neighbours for each query point
    neighbor_indices = knn_search_jax_jit(
        query_points, dataset, k=k
    )

    # Get the neighbour values of the nearest neighbours
    neighbor_values = values[neighbor_indices]

    # Return the mean of the neighbour targets (regression prediction)
    return jnp.mean(neighbor_values, axis=1).squeeze()


knn_mean_jax_jit = jax.jit(knn_mean_jax, static_argnames=["k"])

In [None]:
price_data = pd.read_parquet("data.parquet")

In [None]:
px.scatter(price_data.query("floor == 0"), x="x", y="y", color="price").update_layout(yaxis_scaleanchor="x", yaxis_constrain="domain")

In [None]:
dataset = price_data[["x", "y", "floor"]].to_numpy().astype(np.float32)
values = price_data["price"].to_numpy().astype(np.float32)

In [None]:
# Quick check that knn_mean_jax works
query_points = np.array([[7, 0.5, 1]])
knn_mean_jax(query_points, dataset, values, k=5)

### Objective function and constraints definition

We define the objective function as the mean distance to the reference points.

Also, we define the constraints as:
- Coordinates lie between the lower bounds [-10, -10, 0] and the upper bounds [10, 10, 10] (the dimensions are x, y, floor)
- The price is between 700 and 1500
  - Note that the price is a nonlinear constraint.
  

In [None]:
locations = np.array([
    [2.5, 3.8],    # School
    [-1.2, 0.5],   # Work
    [4.1, -2.3],   # Parents
    [-0.8, 2.9],   # Sports Club
], dtype=np.float32)
min_price_bound = 700
max_price_bound = 1500
k = 5


def objective(x: jnp.ndarray) -> jnp.ndarray:
    # Compute Euclidean distances from (x, y) to each location; the floor coordinate is ignored
    dists = euclidean_distances_jax_jit(x[:2], locations)
    # Return the mean distance
    return jnp.mean(dists)


def price(x: jnp.ndarray) -> jnp.ndarray:
    # Estimate the price at x = (x, y, floor) as the mean over its k nearest neighbours
    return knn_mean_jax_jit(x, dataset, values, k=k)

In [None]:
# Use floats: JAX cannot differentiate with respect to integer inputs
x_test = jnp.array([2.0, 1.0, 3.0])
print(f"Objective: {objective(x_test)}")
print(f"Price: {price(x_test)}")

In [None]:
import scipy.optimize


x_bounds = scipy.optimize.Bounds([-10, -10, 0], [10, 10, 10])

price_constraint = scipy.optimize.NonlinearConstraint(price, min_price_bound, max_price_bound, jac=jax.jacobian(price))
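
One thing worth checking before relying on `jax.jacobian(price)`: a kNN mean is piecewise constant in the query location, because it only changes when the set of nearest neighbours changes. The toy sketch below (all `toy_*` names and values are hypothetical) suggests that JAX then reports a zero Jacobian, so the optimiser gets no slope information from this constraint between neighbourhood changes.

```python
import jax
import jax.numpy as jnp

# Hypothetical 2-D data points with prices
toy_points = jnp.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [3.0, 3.0]])
toy_prices = jnp.array([100.0, 200.0, 300.0, 900.0])


def toy_knn_price(x, k=2):
    # Distances from x to every data point
    d = jnp.sqrt(jnp.sum((toy_points - x) ** 2, axis=-1))
    # Indices of the k nearest points; integer indices carry no gradient
    _, idx = jax.lax.top_k(-d, k)
    return jnp.mean(toy_prices[idx])


x = jnp.array([0.2, 0.1])
toy_price = toy_knn_price(x)
toy_jac = jax.jacobian(toy_knn_price)(x)

print(toy_price)  # mean price of the two nearest points
print(toy_jac)    # zero in every component
```

This does not make the constraint unusable, but it may explain slow or erratic progress near the price bounds; a distance-weighted mean would be one possible way to obtain a non-zero gradient.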


### Exercise: Use `scipy.optimize` to find the optimal location

Use the [Trust-Region Constrained Algorithm (method='trust-constr')](https://docs.scipy.org/doc/scipy/tutorial/optimize.html#trust-region-constrained-algorithm-method-trust-constr) of `scipy.optimize.minimize` to find the optimal location.
- Use some random starting point `x0`.
- Provide the Jacobian and Hessian of the objective function, constructed using JAX.
  - Optionally, JIT-compile the Jacobian and Hessian.
- Do not forget the bounds and constraints.
- Plot the result using `plot_locations_and_optimum`.
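
If the `trust-constr` interface is unfamiliar, the following sketch shows the overall call pattern on a separate toy problem, not on the exercise itself. The names `toy_objective`, `toy_constraint`, `as_f64`, and `TARGET` are made up for this illustration; the wrapper converts between SciPy's float64 NumPy arrays and JAX's default float32 arrays, and the relaxed `gtol` is an assumption chosen to suit float32 precision.

```python
import jax
import jax.numpy as jnp
import numpy as np
import scipy.optimize

TARGET = jnp.array([1.0, 2.0])


def toy_objective(x):
    # Squared distance to a hypothetical target point
    return jnp.sum((x - TARGET) ** 2)


def toy_constraint(x):
    # Squared norm; constrained below to be at most 1
    return jnp.sum(x ** 2)


def as_f64(fn):
    # Adapt a JAX function to SciPy: float32 in, float64 NumPy out
    def wrapped(x):
        return np.asarray(fn(jnp.asarray(x, dtype=jnp.float32)), dtype=np.float64)
    return wrapped


result = scipy.optimize.minimize(
    as_f64(toy_objective),
    x0=np.array([0.1, 0.1]),
    method="trust-constr",
    jac=as_f64(jax.jit(jax.grad(toy_objective))),
    hess=as_f64(jax.jit(jax.hessian(toy_objective))),
    bounds=scipy.optimize.Bounds([-5.0, -5.0], [5.0, 5.0]),
    constraints=[
        scipy.optimize.NonlinearConstraint(
            as_f64(toy_constraint),
            -np.inf,
            1.0,
            jac=as_f64(jax.jit(jax.jacobian(toy_constraint))),
        )
    ],
    options={"gtol": 1e-6},
)

# The optimum should be close to TARGET projected onto the unit circle
print(result.x)
```

The same pattern should carry over to the exercise by swapping in `objective`, `x_bounds`, and `price_constraint`.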

In [None]:
import plotly.graph_objects as go

def plot_locations_and_optimum(
    locations: np.ndarray,
    optimum: jnp.ndarray
) -> None:
    """
    Visualise the location points and the optimum found by optimisation using Plotly.

    Args:
        locations (np.ndarray): Array of shape (n_locations, 2) with the x and y coordinates of the reference points.
        optimum (jnp.ndarray): Optimised coordinates; only the first two elements (x, y) are plotted.

    This function creates an interactive scatter plot of the key locations
    (e.g., School, Work, Parents, Sports Club) and overlays the optimum
    location found by the optimisation routine. This aids in interpreting
    the optimisation result in the context of the real-world locations.

    Example:
        plot_locations_and_optimum(locations, x_opt.x)
    """
    # Extract x and y coordinates and label the points by index
    xs = locations[:, 0]
    ys = locations[:, 1]
    names = [f"Location {i}" for i in range(len(locations))]

    # Prepare the scatter plot for locations
    location_trace = go.Scatter(
        x=xs,
        y=ys,
        mode="markers+text",
        name="Locations",
        marker=dict(size=14, color="blue"),
        text=names,
        textposition="top right"
    )

    # Prepare the scatter plot for the optimum
    optimum_trace = go.Scatter(
        x=[float(optimum[0])],
        y=[float(optimum[1])],
        mode="markers+text",
        name="Optimum",
        marker=dict(size=20, color="red", symbol="star"),
        text=["Optimum"],
        textposition="bottom left"
    )

    layout = go.Layout(
        title="Location Points and Optimum from Optimisation",
        xaxis=dict(title="X coordinate"),
        yaxis=dict(title="Y coordinate"),
        legend=dict(x=0.01, y=0.99),
        width=700,
        height=700
    )

    fig = go.Figure(data=[location_trace, optimum_trace], layout=layout)
    fig.show()