UMAP in pure MLX for Apple Silicon. Entire pipeline runs on Metal GPU.
30-46x faster than umap-learn. Fashion-MNIST 70K in 3.4 seconds for 500 epochs.
*(demo video: `animation.1.mp4`)*
| | umap-mlx | tsne-mlx | pacmap-mlx |
|---|---|---|---|
| Algorithm | UMAP (McInnes 2018) | t-SNE (van der Maaten 2008) | PaCMAP (Wang 2021) |
| Speedup | 30x vs umap-learn | 12x vs sklearn | 13x vs PaCMAP |
| 70K time | 3.4s | 4.9s | 2.3s |
| Best for | General purpose, fast, preserves local + some global | Local cluster separation, publication-quality plots | Global structure, balanced local/global trade-off |
| Repulsion | Negative sampling SGD | FFT interpolation (N>16K) / exact compiled (N<16K) | Near/mid-near/far pair scheduling |
All three: pure MLX + numpy, no scipy/PyTorch, Metal GPU, pip install -e .
```bash
git clone https://github.com/hanxiao/umap-mlx.git && cd umap-mlx
uv venv .venv && source .venv/bin/activate
uv pip install -e .
```

```python
from umap_mlx import UMAP
import numpy as np

X = np.random.randn(1000, 128).astype(np.float32)
Y = UMAP(n_components=2, n_neighbors=15).fit_transform(X)
```

Parameters:

- `n_components`: output dimensions (default 2)
- `n_neighbors`: k for the k-NN graph (default 15)
- `min_dist`: minimum distance in low-dim space (default 0.1)
- `spread`: scale of embedded points (default 1.0)
- `n_epochs`: optimization epochs (default: 500 for N<=10K, 200 for larger)
- `learning_rate`: SGD learning rate (default 1.0)
- `random_state`: seed for reproducibility
- `verbose`: print progress
| N | umap-learn | MLX | speedup |
|---|---|---|---|
| 1,000 | 4.87s | 0.40s | 12x |
| 2,000 | 6.18s | 0.36s | 17x |
| 5,000 | 17.22s | 0.44s | 40x |
| 10,000 | 25.85s | 0.56s | 46x |
| 20,000 | 22.01s | 0.54s | 41x |
| 60,000 | 68.99s | 2.04s | 34x |
| 70,000 | 81.40s | 2.65s | 31x |
The bottleneck at large N is exact k-NN (O(n^2) pairwise distances). umap-learn uses approximate k-NN (pynndescent), which has better asymptotic scaling.
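As a concrete illustration of the chunked exact k-NN step, here is a NumPy sketch of the same idea (`np.argpartition` standing in for `mx.argpartition`; the function name and chunk size are illustrative, not the library's API):

```python
import numpy as np

def knn_chunked(X, k=15, chunk=256):
    """Exact k-NN via chunked squared-Euclidean distances.

    Chunking bounds peak memory at chunk*N distances instead of N*N.
    Returns (indices, distances), each of shape (N, k), excluding self.
    """
    n = X.shape[0]
    sq = (X ** 2).sum(axis=1)                    # precomputed squared norms
    idx = np.empty((n, k), dtype=np.int64)
    dst = np.empty((n, k), dtype=X.dtype)
    for s in range(0, n, chunk):
        e = min(s + chunk, n)
        # ||a-b||^2 = ||a||^2 - 2 a.b + ||b||^2, for one row block at a time
        d = sq[s:e, None] - 2.0 * X[s:e] @ X.T + sq[None, :]
        d[np.arange(e - s), np.arange(s, e)] = np.inf      # mask self
        part = np.argpartition(d, k, axis=1)[:, :k]        # unordered top-k
        dpart = np.take_along_axis(d, part, axis=1)
        order = np.argsort(dpart, axis=1)                  # sort only the k
        idx[s:e] = np.take_along_axis(part, order, axis=1)
        dst[s:e] = np.take_along_axis(dpart, order, axis=1)
    return idx, np.sqrt(np.maximum(dst, 0.0))
```

`argpartition` over the full row followed by a sort of only k candidates is what keeps the per-chunk cost at O(chunk * N) rather than O(chunk * N log N).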
Fashion-MNIST full (70,000 samples, 784 dims, 10 classes): spectral initialization, then after optimization (plots omitted).
- Chunked pairwise distances + `mx.argpartition` for exact k-NN on Metal GPU
- Vectorized binary search for per-point sigma (all N points simultaneously on GPU)
- Sparse graph symmetrization via GPU `_searchsorted` + edge pruning
- Spectral initialization via power iteration on the normalized graph Laplacian
- Compiled SGD on Metal GPU with `mx.array.at[].add()` for scatter accumulation
- Negative sampling with repulsive forces
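The vectorized per-point sigma search can be sketched in NumPy: UMAP calibrates a bandwidth sigma_i per point so that sum_j exp(-max(d_ij - rho_i, 0) / sigma_i) = log2(k), and all N bisections advance in lockstep, which is the trick that makes the search GPU-friendly (the function name and iteration count here are illustrative):

```python
import numpy as np

def smooth_knn_sigma(dists, n_iter=64):
    """dists: (N, k) sorted k-NN distances. Returns sigma of shape (N,).

    Solves sum_j exp(-max(d_ij - rho_i, 0) / sigma_i) = log2(k)
    for every row at once by lockstep bisection.
    """
    n, k = dists.shape
    target = np.log2(k)
    rho = dists[:, 0]                          # distance to nearest neighbor
    lo = np.zeros(n)
    hi = np.full(n, np.inf)
    mid = np.ones(n)
    for _ in range(n_iter):
        psum = np.exp(-np.maximum(dists - rho[:, None], 0.0)
                      / mid[:, None]).sum(axis=1)
        too_big = psum > target                # sigma too large for this row
        hi = np.where(too_big, mid, hi)
        lo = np.where(too_big, lo, mid)
        # double until bracketed, then bisect
        mid = np.where(np.isinf(hi), lo * 2.0, (lo + hi) / 2.0)
    return mid
```

Because every row runs the same fixed number of iterations with no data-dependent branching, the whole search is a handful of elementwise kernels regardless of N.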
Dependencies: mlx and numpy only. No scipy, no PyTorch.
Approximate KNN via NNDescent. Current exact KNN computes O(n^2) pairwise distances, which dominates runtime at large N (e.g. 2.0s of 3.3s at 70K). NNDescent builds an approximate k-NN graph in O(n log n) by iteratively improving neighbors through neighbor-of-neighbor exploration. The challenge is that NNDescent is inherently iterative with random access patterns (heap operations, conditional updates per candidate), which maps poorly to GPU SIMD execution. umap-learn's implementation is 1500 lines of Numba JIT. A practical MLX version would need to replace heaps with fixed-size sorted buffers and batch the neighbor-of-neighbor lookups as matrix gathers. Likely only worthwhile above 100K+ samples where the O(n^2) distance matrix exceeds memory.
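A minimal sketch of the heap-free variant suggested above, in NumPy: fixed-size sorted neighbor buffers, with neighbor-of-neighbor lookups batched as matrix gathers. This is illustrative only; real NNDescent adds candidate sampling, reverse neighbors, and early termination:

```python
import numpy as np

def nn_descent(X, k=10, n_iters=10, seed=0):
    """Approximate k-NN by iterative neighbor-of-neighbor refinement."""
    rng = np.random.default_rng(seed)
    n = X.shape[0]
    # start from random candidate neighbor lists
    idx = np.stack([rng.choice(n, size=k, replace=False) for _ in range(n)])
    d = np.linalg.norm(X[:, None, :] - X[idx], axis=-1)
    d[idx == np.arange(n)[:, None]] = np.inf        # never keep self
    for _ in range(n_iters):
        # neighbor-of-neighbor candidates, gathered as one (n, k*k) matrix op
        cand = idx[idx].reshape(n, -1)
        cd = np.linalg.norm(X[:, None, :] - X[cand], axis=-1)
        cd[cand == np.arange(n)[:, None]] = np.inf
        allc = np.concatenate([idx, cand], axis=1)
        alld = np.concatenate([d, cd], axis=1)
        for i in range(n):                          # per-row dedupe + top-k
            u, first = np.unique(allc[i], return_index=True)
            di = alld[i][first]
            best = np.argsort(di)[:k]
            idx[i] = u[best]
            d[i] = di[best]
    return idx, d
```

The per-row dedupe loop is the part that would need a vectorized replacement (e.g. sort-and-mask) before this maps cleanly onto MLX.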
Custom Metal kernel for SGD. The optimization loop currently dispatches
~10 separate MLX operations per epoch (gather, subtract, multiply, power,
clip, scatter-add x3), each as an independent Metal kernel. A fused Metal
shader via mx.fast.metal_kernel() could handle one edge per thread: read
Y[head] and Y[tail], compute the gradient in registers, and atomic-add the
update back to Y. This eliminates intermediate buffer allocations and reduces
kernel dispatch overhead from ~5000 launches to ~500 (one per epoch). Metal 3.0
supports atomic_fetch_add_explicit for float on all Apple Silicon. Estimated
improvement: ~10% on total runtime (SGD portion from 1.3s to ~1.0s at 70K).
The 500-epoch Python loop cannot be collapsed into a single kernel because
Metal lacks global barriers across threadgroups -- each epoch's scatter-add
must complete before the next epoch reads Y.
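For reference, the per-epoch operation sequence being described can be sketched in NumPy, with `np.add.at` as the CPU analog of the `mx.array.at[].add()` scatter accumulation (function name, epsilon, and the `a`/`b` curve constants, which roughly correspond to `min_dist=0.1`, are illustrative, not the library's API):

```python
import numpy as np

def sgd_epoch(Y, head, tail, rng, a=1.577, b=0.895, lr=1.0, n_neg=5, clip=4.0):
    """One UMAP SGD epoch over the edge list (head[i] -- tail[i])."""
    n = Y.shape[0]
    # attractive forces along graph edges
    diff = Y[head] - Y[tail]                         # gather both endpoints
    d2 = (diff ** 2).sum(axis=1, keepdims=True) + 1e-10  # eps avoids 0**neg
    grad_coef = -2.0 * a * b * d2 ** (b - 1.0) / (1.0 + a * d2 ** b)
    g = np.clip(grad_coef * diff, -clip, clip)
    np.add.at(Y, head, lr * g)                       # scatter-add updates
    np.add.at(Y, tail, -lr * g)
    # repulsive forces against random negative samples
    for _ in range(n_neg):
        neg = rng.integers(0, n, size=head.shape[0])
        diff = Y[head] - Y[neg]
        d2 = (diff ** 2).sum(axis=1, keepdims=True) + 1e-10
        grad_coef = 2.0 * b / ((0.001 + d2) * (1.0 + a * d2 ** b))
        g = np.clip(grad_coef * diff, -clip, clip)
        np.add.at(Y, head, lr * g)
    return Y
```

Each line here corresponds to one of the dispatched kernels mentioned above (gather, subtract, multiply, power, clip, scatter-add), which is exactly what a fused per-edge Metal kernel would collapse into a single launch.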
The demo video above was generated by `fashion_mnist_anim.py`.
MIT

