JAX implementation of Protein Strain Analysis (PSA): per-site weighted finite strain and deformation gradients, consistent with the reference implementation Sartori-Lab/PSA and the original paper:
Pablo Sartori and Stanislas Leibler, Evolutionary conservation of mechanical strain distributions in functional transitions of protein structures, Phys. Rev. X 14, 011042 (2024). DOI: 10.1103/PhysRevX.14.011042 (APS: link.aps.org).
BibTeX and preprint links: references/sartori_leibler_2024_prx.md.
pip install psax
# editable dev
pip install -e ".[dev]"
# optional: parity tests against upstream PSA (install VCS pin separately — not a PyPI extra)
pip install "psa @ git+https://github.com/Sartori-Lab/PSA.git@a55f44eea3c8165d618cfd607f1c3cebe7535cbb"

Requires Python 3.11+.
This project publishes one pure-Python wheel (py3-none-any) and a source distribution, not a matrix of platform wheels. A successful pip install psax still has to install JAX, which pulls jaxlib with platform- and Python-version-specific binaries from PyPI.
If installation “fails on many platforms” or pip spends a long time backtracking, the usual cause is no matching jaxlib wheel for that OS/CPU/Python combo (less common architectures, very new Python, or Windows without the supported stack). See the official JAX installation guide and install a JAX version that supports your environment before or after pip install psax. On Linux x86_64 and recent macOS/arm64 with Python 3.11–3.13, pip install psax typically works out of the box.
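When diagnosing a failed install, it can help to see which of the relevant distributions actually resolved. The snippet below is a minimal sketch using only the standard library; `installed_version` is a hypothetical helper name, not part of psax.

```python
import sys
from importlib import metadata
from typing import Optional


def installed_version(package: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


# This README states that Python 3.11+ is required.
python_ok = sys.version_info >= (3, 11)

# Distributions that matter for a psax install (jaxlib is the binary one).
report = {name: installed_version(name) for name in ("jax", "jaxlib", "psax")}

print(f"Python {sys.version_info.major}.{sys.version_info.minor} (3.11+: {python_ok})")
for name, version in report.items():
    print(f"{name}: {version or 'not installed'}")
```

If `jaxlib` shows as not installed, the problem most likely lies with the JAX wheels for your platform rather than with psax itself.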
proxide is not on PyPI. To run psax run or run_pairwise_psa_from_structures, install proxide from a local checkout in the same environment:
pip install /path/to/proxide

If this repo lives in a uv workspace next to a local proxide member, add a workspace source mapping (see uv workspace sources):
[tool.uv.sources]
proxide = { workspace = true }

| Workflow | Purpose |
|---|---|
| `.github/workflows/ci.yml` | Tests, docs, upstream parity, wheel smoke on main |
| `.github/workflows/publish.yml` | Build with `python -m build` and upload to PyPI on GitHub Release (OIDC trusted publishing) |
Configure trusted publishing on PyPI for this repository and workflow, and add a GitHub Environment named pypi if you use protection rules. Tag releases (e.g. v0.1.0) and publish a GitHub Release to trigger the workflow (or run it manually via workflow_dispatch).
For deployment and compiler toolchains, psax exposes a thin export layer in psax.stablehlo: fixed static shapes (N, 3) and (E,) directed edges, with method / ridge_eps / rcond fixed at export time. Serialized artifacts use JAX’s export format (VHLO / calling-convention version) and are tied to the JAX version they were built with, so load them with a compatible JAX release. See the JAX export documentation and the OpenXLA StableHLO + JAX tutorial.
CLI:
psax emit-stablehlo --n 128 --edges 4096 --out psa_strain.bin

(run_pairwise_psa is not the export entrypoint: it builds weights outside of jit.)
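For orientation, the general fixed-shape export workflow with `jax.export` looks roughly like the sketch below. It uses a toy kernel (`bond_lengths`, a hypothetical stand-in) rather than the actual psax strain pipeline, and the shapes `N` and `E` are illustrative values, not psax defaults.

```python
import jax
import jax.numpy as jnp
from jax import export

N, E = 8, 16  # static shapes, fixed at export time


def bond_lengths(x, edges_i, edges_j):
    """Toy fixed-shape kernel: length of each directed edge i -> j."""
    return jnp.linalg.norm(x[edges_j] - x[edges_i], axis=-1)


# Abstract input signature: (N, 3) coordinates and two (E,) index vectors.
specs = (
    jax.ShapeDtypeStruct((N, 3), jnp.float32),
    jax.ShapeDtypeStruct((E,), jnp.int32),
    jax.ShapeDtypeStruct((E,), jnp.int32),
)
exported = export.export(jax.jit(bond_lengths))(*specs)

blob = exported.serialize()          # bytes-like artifact that can be written to disk
restored = export.deserialize(blob)  # needs a JAX version compatible with the artifact

# Call the restored artifact with inputs matching the exported shapes and dtypes.
x = jnp.arange(N * 3, dtype=jnp.float32).reshape(N, 3)
edges_i = jnp.arange(E, dtype=jnp.int32) % N
edges_j = (edges_i + 1) % N
direct = bond_lengths(x, edges_i, edges_j)
roundtrip = restored.call(x, edges_i, edges_j)
```

Because shapes and dtypes are baked into the artifact, calling `restored.call` with a different `N` or `E` is expected to fail; a new export is needed for each shape.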
- Aligned batch: `B` independent pairs `(x_ref[b], x_def[b])` with the same graph (shared `edges_i`, `edges_j`, `edge_weights`). Use `psax.batch.deformation_gradient_per_site_vmap_pairs`, or `run_pairwise_psa_batched_shared_graph` when the dense weight mask `(N, N)` is shared.
- Ref × design grid: all combinations `(b_ref, b_design)` with fixed shared edges. Use `psax.batch.deformation_gradient_per_site_grid_ref_design`.
- Per-sample edges: same batch size and edge count `E`. Use `deformation_gradient_per_site_batched_edges` / `edge_batch_mode="per_pair"`.
- Different `E` per sample needs padding/masking or a Python loop with `psax.utils.safe_map` (see `docs/bucket` placeholder).
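The padding/masking option in the last item can be sketched as follows. The idea is to pad every edge list to a common length with zero-weight edges, so that weighted per-site sums are unchanged and the batch has a fixed shape suitable for `jax.vmap`. `pad_edges` and `weighted_degree` are hypothetical helpers for illustration, not psax functions, and the weighted degree stands in for the real per-site reduction.

```python
import jax
import jax.numpy as jnp
import numpy as np


def pad_edges(edges_i, edges_j, weights, e_max):
    """Pad one sample's edge list to length e_max with zero-weight edges."""
    pad = e_max - len(weights)
    return (
        np.pad(edges_i, (0, pad)),  # padded indices point at site 0 ...
        np.pad(edges_j, (0, pad)),
        np.pad(weights, (0, pad)),  # ... but carry weight 0, so they add nothing
    )


def weighted_degree(edges_i, weights, n):
    """Stand-in per-site reduction: total weight of edges leaving each site."""
    return jax.ops.segment_sum(weights, edges_i, num_segments=n)


n = 3
samples = [
    (np.array([0, 1, 2]), np.array([1, 2, 0]), np.array([1.0, 2.0, 3.0], dtype=np.float32)),
    (np.array([0, 0, 1, 2, 2]), np.array([1, 2, 0, 0, 1]), np.ones(5, dtype=np.float32)),
]
e_max = max(len(w) for _, _, w in samples)
padded = [pad_edges(ei, ej, w, e_max) for ei, ej, w in samples]

batch_i = jnp.stack([p[0] for p in padded])  # (B, E_max)
batch_w = jnp.stack([p[2] for p in padded])  # (B, E_max)

# One fixed-shape kernel mapped over the batch; n is closed over as a Python int.
batched = jax.vmap(lambda ei, w: weighted_degree(ei, w, n))(batch_i, batch_w)  # (B, n)
```

The same trick carries over to any reduction that is linear in the edge weights. Reductions that are not (for example, counting neighbors) would need an explicit mask instead.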
import jax.numpy as jnp
from psax.run.pipeline import run_pairwise_psa
x_ref = jnp.array(...) # (N, 3)
x_def = jnp.array(...) # (N, 3)
out = run_pairwise_psa(x_ref, x_def, r_inner=6.0, r_outer=8.0)
F = out.deformation_gradient # (N, 3, 3)
E = out.green_lagrange_strain # (N, 3, 3)
lam = out.principal_strain_eigenvalues # (N, 3)

With structures on disk (requires proxide):
from psax.run.pipeline import run_pairwise_psa_from_structures
out = run_pairwise_psa_from_structures("ref.pdb", "def.pdb", align_kabsch=False)

psax --version # or -V
psax version
psax run --help
psax emit-stablehlo --help
psax run ref.pdb def.pdb -o out.npz --bfactors-out colored.pdb --template-pdb ref.pdb
psax export template.pdb out.pdb --values "1.0,2.0,3.0"

The CLI dispatches to psax.run.pipeline, psax.stablehlo, and psax.io.export.
| What is validated | How |
|---|---|
| JAX vs in-repo NumPy PSA loop | `tests/test_core_parity.py` (`parity_numpy`) |
| JAX vs installed `psa.elastic.deformation_gradient` (dense weights, fp64) | `tests/parity_upstream/` (`parity_upstream`), install `psa` from git (see Install) |
| Symmetrized D + solvers vs NumPy reference | Covered indirectly when `ridge_eps=0` matches upstream dense path |
Not fully cross-checked here: upstream sparse/Numba dict fast paths, energy/rotation pipelines in PSA, or every PSA strain helper name-for-name—only the dense deformation-gradient path and Green–Lagrange strain built from F.
See tests/parity_upstream/README.md for dense vs sparse scope.
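The general shape of such a parity check can be sketched as below: a plain NumPy loop serves as the reference and a vectorized JAX function is compared against it in double precision. The example uses the standard Green–Lagrange definition $\mathbf{E} = \tfrac{1}{2}(\mathbf{F}^\top \mathbf{F} - \mathbf{I})$; the function names and tolerances are illustrative assumptions and are not taken from the psax test suite.

```python
import jax
import jax.numpy as jnp
import numpy as np

jax.config.update("jax_enable_x64", True)  # compare in double precision


def green_lagrange_numpy(F):
    """Reference: loop over sites, E_i = 0.5 * (F_i^T F_i - I)."""
    out = np.empty_like(F)
    for i in range(F.shape[0]):
        out[i] = 0.5 * (F[i].T @ F[i] - np.eye(3))
    return out


@jax.jit
def green_lagrange_jax(F):
    """Vectorized version of the same formula over all sites."""
    C = jnp.einsum("nki,nkj->nij", F, F)  # right Cauchy-Green tensor F^T F
    return 0.5 * (C - jnp.eye(3))


rng = np.random.default_rng(0)
F = np.eye(3) + 0.1 * rng.normal(size=(10, 3, 3))  # near-identity gradients

E_ref = green_lagrange_numpy(F)
E_jax = np.asarray(green_lagrange_jax(jnp.asarray(F)))

# Parity assertion: raises if the two implementations disagree.
np.testing.assert_allclose(E_jax, E_ref, rtol=1e-12, atol=1e-12)
```

Without `jax_enable_x64`, JAX computes in float32 and the tolerances would need to be loosened accordingly.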
pip install -e ".[docs]"
sphinx-build -b html docs docs/_build/html

pytest
pytest --cov=psax --cov-report=term-missing

Markers: unit, parity_numpy, parity_upstream, structure_io, slow (see pyproject.toml).
- Build a directed edge list from a dense weight matrix outside `jit` (`psax.graph.edges_from_dense_weight_matrix`).
- `deformation_gradient_per_site`: $\mathbf{F}_i$ from weighted bond vectors using `jax.ops.segment_sum`, with `method=solve` / `lstsq` / `svd` and optional `return_diagnostics` (condition numbers).
- Strain / kinematics (`psax.core`): Green–Lagrange and small-strain Lagrange, Euler strains, Cauchy–Green invariants, principal stretches, rotation axis/angle (polar decomposition).
- Energy (`psax.core.energy`): Saint Venant–Kirchhoff density and unit helpers.
- Alignment (`psax.alignment`): vendored soft Smith–Waterman / Needleman–Wunsch (sequence), Kabsch rigid superposition (structure).
- Spatial (`psax.spatial`): cylindrical/spherical coordinates, tensor rotation into cylindrical frame, rough effective volumes.
- I/O (`psax.io`): `load_structure` via proxide (optional), B-factor export without BioPython.
- AOT (`psax.stablehlo`): `jax.export` of the PSA strain pipeline at fixed shapes.
- Synthetic tests (`psax.testing.forms`): rods, twist/spin/radial deformations.
Legacy global least-squares $\mathbf{F}$ (from the prxteinmpnn snapshot) lives under old_psa for regression tests only; it is not the published per-site PSA in Phys. Rev. X 14, 011042 (2024).
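To make the per-site computation concrete, here is a minimal sketch of a weighted least-squares deformation gradient built with `jax.ops.segment_sum`. It assumes the usual formulation in which $\mathbf{F}_i$ minimizes $\sum_j w_{ij}\,\lVert \Delta\mathbf{x}_{ij} - \mathbf{F}_i\,\Delta\mathbf{X}_{ij} \rVert^2$, giving $\mathbf{F}_i \mathbf{D}_i = \mathbf{A}_i$ with $\mathbf{D}_i = \sum_j w_{ij}\,\Delta\mathbf{X}_{ij}\Delta\mathbf{X}_{ij}^\top$ and $\mathbf{A}_i = \sum_j w_{ij}\,\Delta\mathbf{x}_{ij}\Delta\mathbf{X}_{ij}^\top$. The function name and signature are hypothetical; psax's own `deformation_gradient_per_site` may differ in arguments, regularization (`ridge_eps`), and solver choice.

```python
import jax
import jax.numpy as jnp
import numpy as np


def per_site_deformation_gradient(x_ref, x_def, edges_i, edges_j, weights):
    """Weighted least-squares F_i per site from directed edges i -> j."""
    n = x_ref.shape[0]
    dX = x_ref[edges_j] - x_ref[edges_i]  # reference bond vectors, (E, 3)
    dx = x_def[edges_j] - x_def[edges_i]  # deformed bond vectors, (E, 3)
    w = weights[:, None, None]
    D_e = w * dX[:, :, None] * dX[:, None, :]  # per-edge w * dX dX^T, (E, 3, 3)
    A_e = w * dx[:, :, None] * dX[:, None, :]  # per-edge w * dx dX^T, (E, 3, 3)
    # Accumulate edge contributions onto their source site i.
    D = jax.ops.segment_sum(D_e, edges_i, num_segments=n)
    A = jax.ops.segment_sum(A_e, edges_i, num_segments=n)
    # Solve F_i D_i = A_i. D_i is symmetric, so D_i F_i^T = A_i^T.
    Ft = jnp.linalg.solve(D, jnp.swapaxes(A, -1, -2))
    return jnp.swapaxes(Ft, -1, -2)


# Check on a homogeneous (affine) deformation, where every F_i should equal F0.
x_ref = jnp.array([[0., 0., 0.], [1., 0., 0.], [0., 1., 0.], [0., 0., 1.], [1., 1., 1.]])
F0 = jnp.array([[1.1, 0.2, 0.0], [0.0, 0.9, 0.1], [0.0, 0.0, 1.05]])
x_def = x_ref @ F0.T + jnp.array([0.5, -0.3, 2.0])  # translation drops out of bonds

n = x_ref.shape[0]
edges_i, edges_j = np.nonzero(~np.eye(n, dtype=bool))  # fully connected, built outside jit
weights = jnp.ones(edges_i.shape[0])

F_est = jax.jit(per_site_deformation_gradient)(x_ref, x_def, edges_i, edges_j, weights)
```

The direct solve requires each $\mathbf{D}_i$ to be invertible, which means every site needs neighbors whose bond vectors span three dimensions. Sites with too few or coplanar neighbors are where a ridge term or an `lstsq`/`svd` path becomes relevant.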
- `safe_map` (`psax.utils.safe_map`): when batching many independent structures, pass a Python `int` `batch_size` into `safe_map` (not a traced value). Inside `jit`, prefer fixed-shape kernels and bucketing (see `psax.bucket`) instead of compiling over Python loops.
- Float precision: enable `jax_enable_x64` if you need to match double-precision NumPy/SciPy baselines (tests set this in `conftest.py`).
- Eigenvectors: `principal_strain_eigensystem` fixes eigenvector signs deterministically (no random hemisphere flips).
- RNG: core PSA is deterministic; if you add stochastic workflows, derive subkeys with `jax.random.fold_in(key, index)`.
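The precision, eigenvector, and RNG notes above can be illustrated together. The sketch below enables x64, applies one possible deterministic sign convention (largest-magnitude component made positive), and derives subkeys with `fold_in`. The convention and the helper name `fix_eigenvector_signs` are assumptions for illustration; the rule used by `principal_strain_eigensystem` may differ.

```python
import jax
import jax.numpy as jnp

# Enable double precision before creating arrays (typically done at startup).
jax.config.update("jax_enable_x64", True)


def fix_eigenvector_signs(vecs):
    """Flip each eigenvector (column) so its largest-magnitude entry is positive."""
    idx = jnp.argmax(jnp.abs(vecs), axis=-2, keepdims=True)  # (..., 1, 3)
    pivot = jnp.take_along_axis(vecs, idx, axis=-2)          # (..., 1, 3)
    return vecs * jnp.where(pivot < 0, -1.0, 1.0)


# A small symmetric strain tensor and its eigensystem.
strain = jnp.array([[0.02, 0.01, 0.0], [0.01, -0.01, 0.0], [0.0, 0.0, 0.03]])
eigvals, eigvecs = jnp.linalg.eigh(strain)
canonical = fix_eigenvector_signs(eigvecs)

# Independent, reproducible subkeys for a hypothetical stochastic workflow.
key = jax.random.PRNGKey(0)
noise = [jax.random.normal(jax.random.fold_in(key, i), (3,)) for i in range(2)]
```

Since `eigh` may return either `v` or `-v` for each eigenvector, applying the convention to both gives the same result, which keeps per-site principal directions comparable across runs and platforms.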
The old_psa package is adapted from prxteinmpnn’s PSA helpers (global $\mathbf{F}$ and dense linear weights). The per-site method implemented in psax.core follows Sartori-Lab/PSA and Sartori & Leibler, Phys. Rev. X 14, 011042 (2024).
If you use this software in published work, please cite the PSA paper (and this codebase if appropriate):
@article{PhysRevX.14.011042,
title = {Evolutionary Conservation of Mechanical Strain Distributions in Functional Transitions of Protein Structures},
author = {Sartori, Pablo and Leibler, Stanislas},
journal = {Phys. Rev. X},
volume = {14},
issue = {1},
pages = {011042},
year = {2024},
publisher = {American Physical Society},
doi = {10.1103/PhysRevX.14.011042},
url = {https://link.aps.org/doi/10.1103/PhysRevX.14.011042}
}

Reference implementation: https://github.com/Sartori-Lab/PSA.
More detail: references/sartori_leibler_2024_prx.md.