# Gradients and derivatives with JAX

## JAX can compute derivatives of functions

- JAX can automatically compute gradients of scalar-valued functions using `jax.grad`, enabling efficient optimisation and machine learning workflows.
- JAX also provides `jax.jacobian` to compute the full Jacobian matrix, which is essential for sensitivity analysis and advanced optimisation.
- Higher-order derivatives are supported via repeated application of `jax.grad`, or using `jax.hessian`.
- JAX's differentiation is based on composable, functional transformations, making it easy to combine with JIT compilation (`jax.jit`) for high performance.
- JAX's autodiff works seamlessly on CPUs, GPUs, and TPUs, enabling scalable scientific computing.


In [47]:
# Run this in Google Colab, perhaps not elsewhere
# Skip if you've already done this

# !git clone https://github.com/coobas/europython-25.git
# !mkdir -p local_data
# !cp europython-25/*.parquet local_data/

In [48]:
import numpy as np
import jax.numpy as jnp
import jax
import pandas as pd
import plotly.express as px

## Basics

## Mini-project: Optimising location of a property

In [20]:
def euclidean_distances_jax(
    query_points: jnp.ndarray, dataset: jnp.ndarray
) -> jnp.ndarray:
    """Return the Euclidean distance from every dataset row to every query point.

    A singleton axis is inserted after the first axis of ``dataset`` so that
    the subtraction broadcasts against ``query_points``; the squared
    differences are then summed over the last axis.

    For a ``dataset`` of shape (n, d) and ``query_points`` of shape (q, d) the
    result has shape (n, q), i.e. it is indexed dataset-first.  A single 1-D
    query point of shape (d,) yields shape (n, 1).
    """
    deltas = dataset[:, None, :] - query_points
    squared_norms = jnp.sum(deltas ** 2, axis=-1)
    return jnp.sqrt(squared_norms)


# Compiled variant of the distance computation.
euclidean_distances_jax_jit = jax.jit(euclidean_distances_jax)


def knn_search_jax(
    query_points: jnp.ndarray,
    dataset: jnp.ndarray,
    k: int,
) -> jnp.ndarray:
    """Return the indices of the ``k`` dataset rows closest to each query point.

    Parameters
    ----------
    query_points : jnp.ndarray
        Points to search neighbours for.
    dataset : jnp.ndarray
        Points to search among.
    k : int
        Number of neighbours to return per query point.

    Returns
    -------
    jnp.ndarray
        Indices into ``dataset``; the last axis has length ``k``.
    """
    # The distance matrix is dataset-first; transpose to one row per query.
    per_query_distances = euclidean_distances_jax_jit(query_points, dataset).T
    # top_k selects the largest entries, so negate to get the smallest distances.
    _, nearest_indices = jax.lax.top_k(-per_query_distances, k)
    return nearest_indices


# ``k`` determines the output shape, so it is marked static for compilation.
knn_search_jax_jit = jax.jit(knn_search_jax, static_argnames=["k"])


In [21]:
def knn_mean_jax(
    query_points: jnp.ndarray,
    dataset: jnp.ndarray,
    values: jnp.ndarray,
    k: int,
) -> jnp.ndarray:
    """
    Estimate a value at each query point as the mean over its k nearest neighbours.

    Parameters
    ----------
    query_points : jnp.ndarray
        Array of shape (n_samples, n_features) with the points to predict for.
    dataset : jnp.ndarray
        Array of shape (n_train_samples, n_features) with the training points.
    values : jnp.ndarray
        Array of shape (n_train_samples,) with the value attached to each
        training point.
    k : int
        Number of nearest neighbours averaged for each prediction.

    Returns
    -------
    jnp.ndarray
        Mean neighbour value per query point.  The result is squeezed, so a
        single query point produces a 0-d array rather than shape (1,).
    """
    # Indices of the k closest training points, one row per query point.
    nearest = knn_search_jax_jit(query_points, dataset, k=k)

    # Look up the value attached to each selected neighbour.
    neighbour_values = values[nearest]

    # Average across the neighbour axis, then drop singleton dimensions.
    return jnp.squeeze(jnp.mean(neighbour_values, axis=1))


# ``k`` is static because it fixes the shape of the neighbour index array.
knn_mean_jax_jit = jax.jit(knn_mean_jax, static_argnames=["k"])

In [22]:
# Load the price table into a DataFrame.
# NOTE(review): the setup cell copies the parquet files into local_data/;
# confirm that "data.parquet" resolves from the current working directory.
price_data = pd.read_parquet("data.parquet")

In [23]:
# Scatter plot of the rows with floor == 0, positioned by x/y and coloured by
# price; the y axis is anchored to the x axis so both share the same scale.
px.scatter(price_data.query("floor == 0"), x="x", y="y", color="price").update_layout(yaxis_scaleanchor="x", yaxis_constrain="domain")

In [24]:
# Feature matrix with columns (x, y, floor), cast to float32.
dataset = price_data[["x", "y", "floor"]].to_numpy().astype(np.float32)
# Target vector: the price of each row, cast to float32.
values = price_data["price"].to_numpy().astype(np.float32)

In [29]:
# A single query point in (x, y, floor) order, matching the dataset columns.
# NOTE(review): NumPy infers float64 for this literal while the dataset is
# float32 — consider passing dtype=np.float32 for consistency.
query_points = np.array([[7, 0.5, 1]])

In [30]:
# Predict the value at the query point from its 5 nearest neighbours.  This
# calls the plain Python function; only the inner neighbour search is jitted.
knn_mean_jax(query_points, dataset, values, k=5)

Array(552.65857, dtype=float32)

In [37]:
# Fixed places of interest as (x, y) coordinates, one row per place.
locations = np.array(
    [
        [2.5, 3.8],  # School
        [-1.2, 0.5],  # Work
        [4.1, -2.3],  # Parents
        [-0.8, 2.9],  # Sports Club
    ],
    dtype=np.float32,
)
# Lower and upper limits for the predicted price.
min_price_bound, max_price_bound = 700, 1500
# Number of nearest neighbours used by the price prediction.
k = 5


def objective(x: jnp.ndarray) -> float:
    """Return the mean Euclidean distance from ``x`` to the fixed locations.

    Only the first two components of ``x`` are used; any further components
    do not influence the result.  The value is produced by ``jnp.mean``, so it
    is a JAX scalar array rather than a built-in ``float``.
    """
    xy = x[:2]
    distances_to_locations = euclidean_distances_jax_jit(xy, locations)
    return jnp.mean(distances_to_locations)


def price(x: jnp.ndarray) -> float:
    """Return the k-nearest-neighbour price prediction at point ``x``.

    Uses the module-level ``dataset``, ``values`` and ``k``.
    """
    return knn_mean_jax_jit(x, dataset, values, k=k)

In [38]:
# Smoke-test both functions at one point before running the optimiser.
# NOTE(review): the literal is all integers, so the array gets an integer
# dtype.  That is fine for these forward evaluations, but JAX differentiation
# expects floating-point inputs — use floats if this point is differentiated.
x_test = jnp.array([2, 1, 3])
print(f"Objective: {objective(x_test)}")
print(f"Price: {price(x_test)}")

Objective: 3.344606399536133
Price: 1570.916015625


In [41]:
import scipy.optimize


# Box bounds on the decision vector: lower limits first, then upper limits.
# The components follow the dataset column order (x, y, floor).
x_bounds = scipy.optimize.Bounds([-10, -10, 0], [10, 10, 10])

# Keep the predicted price within [min_price_bound, max_price_bound]; the
# constraint Jacobian comes from JAX rather than finite differences.
# NOTE(review): the recorded run reports an all-zero Jacobian for this
# constraint — presumably because a k-NN mean is piecewise constant in the
# query location — so the solver gets no slope information from it.  Confirm
# this is acceptable for the demonstration.
price_constraint = scipy.optimize.NonlinearConstraint(price, min_price_bound, max_price_bound, jac=jax.jacobian(price))


# First derivatives of the objective via automatic differentiation.
objective_jac = jax.jacobian(objective)

# Second derivatives (Hessian) of the objective via automatic differentiation.
objective_hess = jax.hessian(objective)

# Minimise the mean distance subject to the box bounds and price constraint.
x_opt = scipy.optimize.minimize(
    objective,
    x0=[2, 2, 0],
    method="trust-constr",
    jac=objective_jac,
    hess=objective_hess,
    constraints=[price_constraint],
    # Alternative without JAX derivatives (finite differences plus a BFGS
    # Hessian approximation); swap these in for the jac/hess arguments above:
    # jac="2-point",
    # hess=scipy.optimize.BFGS(),
    bounds=x_bounds,
    options={"verbose": 1},
    tol=1e-8,
)

# Display the full optimisation result.
x_opt


`xtol` termination condition is satisfied.
Number of iterations: 188, function evaluations: 207, CG iterations: 185, optimality: 4.74e-02, constraint violation: 0.00e+00, execution time: 0.82 s.


           message: `xtol` termination condition is satisfied.
           success: True
            status: 2
               fun: 3.0321483612060547
                 x: [ 5.041e-01  1.814e+00  1.046e+00]
               nit: 188
              nfev: 207
              njev: 63
              nhev: 63
          cg_niter: 185
      cg_stop_cond: 2
              grad: [ 4.835e-02  4.485e-03  0.000e+00]
   lagrangian_grad: [ 4.739e-02  4.388e-03 -8.976e-10]
            constr: [array([ 1.453e+03], dtype=float32), array([ 5.041e-01,  1.814e+00,  1.046e+00])]
               jac: [array([[ 0.000e+00,  0.000e+00,  0.000e+00]],
                          dtype=float32), array([[ 1.000e+00,  0.000e+00,  0.000e+00],
                           [ 0.000e+00,  1.000e+00,  0.000e+00],
                           [ 0.000e+00,  0.000e+00,  1.000e+00]])]
       constr_nfev: [207, 0]
       constr_njev: [207, 0]
       constr_nhev: [0, 0]
                 v: [array([ 4.050e-11]), array([-9.551e-04, -9.691e-05, 

In [46]:
import plotly.graph_objects as go

def plot_locations_and_optimum(
    locations: np.ndarray,
    optimum: jnp.ndarray
) -> None:
    """
    Show the fixed locations and the optimised point in one Plotly scatter plot.

    Args:
        locations (np.ndarray): 2-D array with one row per location; column 0
            is used as the x coordinate and column 1 as the y coordinate.
        optimum (jnp.ndarray): Optimised coordinates; only the first two
            entries (x and y) are plotted.

    The locations are drawn as blue markers labelled "Location 0",
    "Location 1", ... in row order, and the optimum as a red star labelled
    "Optimum".  The figure is displayed with ``fig.show()``; nothing is
    returned.

    Example:
        plot_locations_and_optimum(locations, x_opt.x)
    """
    # The array carries no names, so label each row by its index.
    labels = [f"Location {idx}" for idx, _ in enumerate(locations)]

    figure_layout = go.Layout(
        title="Location Points and Optimum from Optimisation",
        xaxis={"title": "X coordinate"},
        yaxis={"title": "Y coordinate"},
        legend={"x": 0.01, "y": 0.99},
        width=700,
        height=700,
    )

    # Trace with the fixed locations.
    places = go.Scatter(
        x=locations[:, 0],
        y=locations[:, 1],
        mode="markers+text",
        name="Locations",
        marker={"size": 14, "color": "blue"},
        text=labels,
        textposition="top right",
    )

    # Trace with the single optimised point.
    best_point = go.Scatter(
        x=[float(optimum[0])],
        y=[float(optimum[1])],
        mode="markers+text",
        name="Optimum",
        marker={"size": 20, "color": "red", "symbol": "star"},
        text=["Optimum"],
        textposition="bottom left",
    )

    fig = go.Figure(data=[places, best_point], layout=figure_layout)
    fig.show()

# Show the optimiser's solution (x_opt.x) relative to the fixed locations.
plot_locations_and_optimum(locations, x_opt.x)