maraxen/psax

psax

JAX implementation of Protein Strain Analysis (PSA): per-site weighted finite strain and deformation gradients, consistent with the reference implementation Sartori-Lab/PSA and the original paper:

Pablo Sartori and Stanislas Leibler, "Evolutionary conservation of mechanical strain distributions in functional transitions of protein structures," Phys. Rev. X 14, 011042 (2024). DOI: 10.1103/PhysRevX.14.011042.

BibTeX and preprint links: references/sartori_leibler_2024_prx.md.

Install

pip install psax
# editable dev
pip install -e ".[dev]"
# optional: parity tests against upstream PSA (install VCS pin separately — not a PyPI extra)
pip install "psa @ git+https://github.com/Sartori-Lab/PSA.git@a55f44eea3c8165d618cfd607f1c3cebe7535cbb"

Requires Python 3.11+.

PyPI artifacts vs install failures

This project publishes a single pure-Python wheel (py3-none-any) and a source distribution, not a matrix of platform wheels. Installing psax, however, also installs JAX, which pulls in jaxlib: a binary package whose PyPI wheels are specific to the platform and Python version.

If installation “fails on many platforms” or pip spends a long time backtracking, the usual cause is no matching jaxlib wheel for that OS/CPU/Python combo (less common architectures, very new Python, or Windows without the supported stack). See the official JAX installation guide and install a JAX version that supports your environment before or after pip install psax. On Linux x86_64 and recent macOS/arm64 with Python 3.11–3.13, pip install psax typically works out of the box.
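Once pip reports success, a short Python check can confirm that a usable jaxlib was resolved for your platform. This uses only public JAX calls and does not import psax.

```python
import jax
import jax.numpy as jnp

# Report the installed JAX version and the backend jaxlib resolved to.
print("jax", jax.__version__)
print("backend:", jax.default_backend())  # e.g. "cpu", "gpu" or "tpu"
devices = jax.devices()
print("devices:", devices)

# A tiny jitted computation; this fails here if jaxlib is missing or mismatched.
y = jax.jit(lambda x: x @ x.T)(jnp.eye(3))
print(float(y.sum()))
```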

Structure I/O (proxide)

proxide is not on PyPI. To use the psax run command or run_pairwise_psa_from_structures, install proxide from a local checkout into the same environment:

pip install /path/to/proxide

If this repo lives in a uv workspace next to a local proxide member, add a workspace source mapping (see uv workspace sources):

[tool.uv.sources]
proxide = { workspace = true }

PyPI and releases

Workflows:
  • .github/workflows/ci.yml: tests, docs, upstream parity, and a wheel smoke test on main.
  • .github/workflows/publish.yml: builds with python -m build and uploads to PyPI on a GitHub Release (OIDC trusted publishing).

Configure trusted publishing on PyPI for this repository and workflow, and add a GitHub Environment named pypi if you use protection rules. Tag releases (e.g. v0.1.0) and publish a GitHub Release to trigger the workflow (or run it manually via workflow_dispatch).

StableHLO / jax.export (AOT)

For deployment and compiler toolchains, psax exposes a thin export layer in psax.stablehlo. Input shapes are fixed and static (coordinates of shape (N, 3) and directed-edge arrays of shape (E,)), and method, ridge_eps, and rcond are fixed at export time. Serialized artifacts use JAX’s export format (VHLO / calling-convention version) and are tied to the JAX version you built them with; see the JAX export documentation and the OpenXLA StableHLO + JAX tutorial.

CLI:

psax emit-stablehlo --n 128 --edges 4096 --out psa_strain.bin

(run_pairwise_psa is not the export entrypoint: it builds weights outside of jit.)
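A minimal sketch of the fixed-shape export idea, using only public jax.export calls. The exported function green_lagrange_per_site is a hypothetical stand-in for the psax strain pipeline (plain solve path, no ridge term), and N = 4, E = 12 are arbitrary illustrative shapes; for the real artifact use psax emit-stablehlo or psax.stablehlo.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax import export

N, E = 4, 12  # static shapes baked into the artifact (illustrative values)

def green_lagrange_per_site(x_ref, x_def, edges_i, edges_j, w):
    """Hypothetical stand-in for the psax strain kernel, not its actual code."""
    r0 = x_ref[edges_j] - x_ref[edges_i]  # (E, 3) reference bond vectors
    r = x_def[edges_j] - x_def[edges_i]   # (E, 3) deformed bond vectors
    D = jax.ops.segment_sum(w[:, None, None] * r0[:, :, None] * r0[:, None, :], edges_i, num_segments=N)
    A = jax.ops.segment_sum(w[:, None, None] * r[:, :, None] * r0[:, None, :], edges_i, num_segments=N)
    # F_i D_i = A_i and D_i is symmetric, so F_i^T = D_i^{-1} A_i^T.
    F = jnp.swapaxes(jnp.linalg.solve(D, jnp.swapaxes(A, -1, -2)), -1, -2)
    return 0.5 * (jnp.swapaxes(F, -1, -2) @ F - jnp.eye(3))  # (N, 3, 3)

# Directed edges: every ordered pair i != j, built with NumPy outside jit.
ii, jj = np.nonzero(~np.eye(N, dtype=bool))
edges_i = jnp.asarray(ii, dtype=jnp.int32)
edges_j = jnp.asarray(jj, dtype=jnp.int32)
w = jnp.ones((E,), dtype=jnp.float32)

specs = (
    jax.ShapeDtypeStruct((N, 3), jnp.float32),
    jax.ShapeDtypeStruct((N, 3), jnp.float32),
    jax.ShapeDtypeStruct((E,), jnp.int32),
    jax.ShapeDtypeStruct((E,), jnp.int32),
    jax.ShapeDtypeStruct((E,), jnp.float32),
)
exported = export.export(jax.jit(green_lagrange_per_site))(*specs)
stablehlo_text = exported.mlir_module()  # human-readable StableHLO (MLIR) text
blob = exported.serialize()              # bytearray; tied to the JAX version used
restored = export.deserialize(blob)

# A tetrahedron under a uniform 10% stretch: F = 1.1 I, so E = 0.105 I at every site.
x_ref = jnp.array([[0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=jnp.float32)
x_def = 1.1 * x_ref
strain = restored.call(x_ref, x_def, edges_i, edges_j, w)
```

Calling restored.call with arrays of any other shape or dtype fails, which is the point of fixing N and E at export time.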

Batching ref / design pairs (vmap)

  • Aligned batch: B independent pairs (x_ref[b], x_def[b]) that share one graph (edges_i, edges_j, edge_weights). Use psax.batch.deformation_gradient_per_site_vmap_pairs, or run_pairwise_psa_batched_shared_graph when the dense (N, N) weight mask is shared.
  • Ref × design grid: all combinations (b_ref, b_design) with fixed shared edges. Use psax.batch.deformation_gradient_per_site_grid_ref_design.
  • Per-sample edges, with the same batch size and the same edge count E for every sample: use deformation_gradient_per_site_batched_edges, or edge_batch_mode="per_pair".
  • A different E per sample needs padding/masking or a Python loop with psax.utils.safe_map (see the docs and the psax.bucket placeholder).
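The shared-graph and grid patterns map directly onto jax.vmap's in_axes. The sketch below uses a hypothetical per_site_F (a plain solve-based stand-in, not the psax.batch functions) on a 4-site toy graph, with uniform stretches so the expected F is easy to check.

```python
import functools
import numpy as np
import jax
import jax.numpy as jnp

def per_site_F(x_ref, x_def, edges_i, edges_j, w, num_sites):
    """Illustrative per-site F_i = A_i D_i^{-1} (a stand-in, not psax's code)."""
    r0 = x_ref[edges_j] - x_ref[edges_i]  # (E, 3)
    r = x_def[edges_j] - x_def[edges_i]   # (E, 3)

    def wsum(a, b):  # sum over edges of w * a b^T, grouped by source site i
        return jax.ops.segment_sum(w[:, None, None] * a[:, :, None] * b[:, None, :],
                                   edges_i, num_segments=num_sites)

    D, A = wsum(r0, r0), wsum(r, r0)
    # F D = A with D symmetric  =>  F^T = D^{-1} A^T
    return jnp.swapaxes(jnp.linalg.solve(D, jnp.swapaxes(A, -1, -2)), -1, -2)

N, B = 4, 3
ii, jj = np.nonzero(~np.eye(N, dtype=bool))  # one shared directed graph
edges_i, edges_j = jnp.asarray(ii, jnp.int32), jnp.asarray(jj, jnp.int32)
w = jnp.ones(ii.shape[0], dtype=jnp.float32)

base = jnp.array([[0.0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1]])
scales = jnp.array([1.0, 1.1, 1.2])
x_ref_b = jnp.broadcast_to(base, (B, N, 3))  # (B, N, 3)
x_def_b = scales[:, None, None] * base       # (B, N, 3), uniform stretches

f = functools.partial(per_site_F, num_sites=N)  # num_sites stays a static Python int
shared = (None, None, None)                     # graph arrays are not batched

# Aligned pairs: map axis 0 of both coordinate stacks together.
F_pairs = jax.vmap(f, in_axes=(0, 0, *shared))(x_ref_b, x_def_b, edges_i, edges_j, w)

# Ref x design grid: outer vmap over references, inner vmap over designs.
grid = jax.vmap(jax.vmap(f, in_axes=(None, 0, *shared)), in_axes=(0, None, *shared))
F_grid = grid(x_ref_b, x_def_b, edges_i, edges_j, w)  # (B_ref, B_design, N, 3, 3)
```

In this formulation an edge with zero weight adds nothing to D_i or A_i, so one way to handle a different E per sample is to pad every edge list to a common length with zero-weight edges pointing at any valid site index.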

Quickstart (two coordinate sets → per-site F → strain)

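import jax
import jax.numpy as jnp
from psax.run.pipeline import run_pairwise_psa

# Replace with your own coordinates; here a synthetic point cloud and a 5% shear of it.
N = 50
x_ref = 10.0 * jax.random.uniform(jax.random.PRNGKey(0), (N, 3))  # (N, 3)
shear = jnp.array([[1.0, 0.05, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])
x_def = x_ref @ shear.T                                             # (N, 3)
out = run_pairwise_psa(x_ref, x_def, r_inner=6.0, r_outer=8.0)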
F = out.deformation_gradient       # (N, 3, 3)
E = out.green_lagrange_strain      # (N, 3, 3)
lam = out.principal_strain_eigenvalues  # (N, 3)
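For orientation, these outputs can be read through the standard weighted least-squares definition of a per-site deformation gradient. This is a summary under the assumption that the default solve path (with ridge_eps = 0) follows this form, not a transcription of psax. Here $\mathbf{r}^{0}_{ij}$ = x_ref[j] - x_ref[i] and $\mathbf{r}_{ij}$ = x_def[j] - x_def[i] are bond vectors, and $w_{ij}$ are the edge weights; setting the gradient of the objective to zero gives the normal equations on the right.

```latex
\mathbf{F}_i = \operatorname*{arg\,min}_{\mathbf{F}} \sum_{j} w_{ij}\,\bigl\lVert \mathbf{F}\,\mathbf{r}^{0}_{ij} - \mathbf{r}_{ij} \bigr\rVert^{2}
\;\Longrightarrow\;
\mathbf{F}_i \mathbf{D}_i = \mathbf{A}_i,
\qquad
\mathbf{D}_i = \sum_{j} w_{ij}\,\mathbf{r}^{0}_{ij}\,\mathbf{r}^{0\top}_{ij},
\quad
\mathbf{A}_i = \sum_{j} w_{ij}\,\mathbf{r}_{ij}\,\mathbf{r}^{0\top}_{ij}

\mathbf{E}_i = \tfrac{1}{2}\bigl(\mathbf{F}_i^{\top}\mathbf{F}_i - \mathbf{I}\bigr)
```

The principal strains are then the eigenvalues of $\mathbf{E}_i$, one triple per site.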

With structures on disk (requires proxide):

from psax.run.pipeline import run_pairwise_psa_from_structures

out = run_pairwise_psa_from_structures("ref.pdb", "def.pdb", align_kabsch=False)

CLI

psax --version   # or -V
psax version
psax run --help
psax emit-stablehlo --help
psax run ref.pdb def.pdb -o out.npz --bfactors-out colored.pdb --template-pdb ref.pdb
psax export template.pdb out.pdb --values "1.0,2.0,3.0"

The CLI dispatches to psax.run.pipeline, psax.stablehlo, and psax.io.export.
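B-factor export without BioPython is a small fixed-column edit, because the PDB format stores the temperature factor in columns 61-66 of ATOM/HETATM records. The sketch assumes one value per residue in file order; psax export's actual mapping of --values may differ, and write_bfactors and atom are hypothetical names, not psax.io's API.

```python
def write_bfactors(pdb_text: str, values: list[float]) -> str:
    """Hypothetical helper: write one value per residue into the B-factor
    field (columns 61-66) of every ATOM/HETATM record of that residue."""
    out, res_index, last_key = [], -1, None
    for line in pdb_text.splitlines():
        if line.startswith(("ATOM  ", "HETATM")):
            key = line[21:27]          # chain ID (col 22) + resSeq + iCode (cols 23-27)
            if key != last_key:        # a new residue starts
                res_index, last_key = res_index + 1, key
            padded = line.ljust(80)
            line = (padded[:60] + f"{values[res_index]:6.2f}" + padded[66:]).rstrip()
        out.append(line)
    return "\n".join(out) + "\n"

def atom(serial, name, resname, chain, resseq, x, y, z):
    """Build a fixed-column ATOM record (occupancy 1.00, B-factor 0.00)."""
    return (f"ATOM  {serial:5d} {name:<4s} {resname:3s} {chain}{resseq:4d}    "
            f"{x:8.3f}{y:8.3f}{z:8.3f}{1.0:6.2f}{0.0:6.2f}")

pdb = "\n".join([
    atom(1, "N", "ALA", "A", 1, 0.0, 0.0, 0.0),
    atom(2, "CA", "ALA", "A", 1, 1.458, 0.0, 0.0),
    atom(3, "N", "GLY", "A", 2, 2.0, 1.2, 0.0),
    "END",
])
colored = write_bfactors(pdb, [0.5, 1.25])
lines = colored.splitlines()
```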

Parity & limitations

What is validated, and how:
  • JAX vs the in-repo NumPy PSA loop: tests/test_core_parity.py (marker parity_numpy).
  • JAX vs the installed psa.elastic.deformation_gradient (dense weights, fp64): tests/parity_upstream/ (marker parity_upstream); install psa from git as described under Install.
  • Symmetrized D and the solvers vs the NumPy reference: covered indirectly when ridge_eps=0, since that path matches the upstream dense path.

Not fully cross-checked here: upstream sparse/Numba dict fast paths, the energy/rotation pipelines in PSA, and name-for-name equivalents of every PSA strain helper. Only the dense deformation-gradient path and the Green–Lagrange strain built from F are checked against upstream.

See tests/parity_upstream/README.md for dense vs sparse scope.
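The JAX-vs-NumPy parity idea listed above can be pictured with a self-contained sketch: build a directed edge list from a dense weight matrix outside jit, then compare an explicit per-site loop against a segment_sum version in float64. The function names (edges_from_dense, F_numpy_loop, F_segment_sum) and the Gaussian weights are hypothetical illustrations, not the contents of tests/test_core_parity.py or PSA's weighting scheme.

```python
import functools
import numpy as np
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # compare against NumPy in float64

def edges_from_dense(W):
    """Directed edges (i, j, w_ij) from a dense (N, N) weight matrix.
    Runs in NumPy outside jit: the edge count E depends on the data."""
    i, j = np.nonzero(W)
    return i, j, W[i, j]

def F_numpy_loop(x_ref, x_def, W):
    """Reference: explicit loop with one 3x3 solve per site."""
    N = x_ref.shape[0]
    F = np.zeros((N, 3, 3))
    for i in range(N):
        D = np.zeros((3, 3))
        A = np.zeros((3, 3))
        for j in np.nonzero(W[i])[0]:
            r0, r = x_ref[j] - x_ref[i], x_def[j] - x_def[i]
            D += W[i, j] * np.outer(r0, r0)
            A += W[i, j] * np.outer(r, r0)
        F[i] = np.linalg.solve(D, A.T).T  # F D = A with D symmetric
    return F

@functools.partial(jax.jit, static_argnames="num_sites")
def F_segment_sum(x_ref, x_def, ei, ej, w, num_sites):
    """Vectorized version: scatter-add weighted outer products per source site."""
    r0, r = x_ref[ej] - x_ref[ei], x_def[ej] - x_def[ei]

    def wsum(a, b):
        return jax.ops.segment_sum(w[:, None, None] * a[:, :, None] * b[:, None, :],
                                   ei, num_segments=num_sites)

    D, A = wsum(r0, r0), wsum(r, r0)
    return jnp.swapaxes(jnp.linalg.solve(D, jnp.swapaxes(A, -1, -2)), -1, -2)

rng = np.random.default_rng(0)
N = 20
x_ref = 3.0 * rng.normal(size=(N, 3))
x_def = x_ref + 0.1 * rng.normal(size=(N, 3))
d = np.linalg.norm(x_ref[:, None, :] - x_ref[None, :, :], axis=-1)
W = np.exp(-((d / 5.0) ** 2))  # smooth distance weights (illustrative)
np.fill_diagonal(W, 0.0)       # no self-edges

ei, ej, w = edges_from_dense(W)
F_loop = F_numpy_loop(x_ref, x_def, W)
F_jax = F_segment_sum(jnp.asarray(x_ref), jnp.asarray(x_def), jnp.asarray(ei),
                      jnp.asarray(ej), jnp.asarray(w), num_sites=N)
print("max |F_loop - F_jax| =", np.max(np.abs(F_loop - np.asarray(F_jax))))
```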

Documentation (Sphinx)

pip install -e ".[docs]"
sphinx-build -b html docs docs/_build/html

Testing & coverage

pytest
pytest --cov=psax --cov-report=term-missing

Markers: unit, parity_numpy, parity_upstream, structure_io, slow (see pyproject.toml).

API sketch

  • Build a directed edge list from a dense weight matrix outside jit (psax.graph.edges_from_dense_weight_matrix).
  • deformation_gradient_per_site: per-site $\mathbf{F}_i$ from weighted bond vectors using jax.ops.segment_sum, with method set to solve, lstsq, or svd, and optional return_diagnostics (condition numbers).
  • Strain / kinematics (psax.core): Green–Lagrange and small-strain Lagrange, Euler strains, Cauchy–Green invariants, principal stretches, rotation axis/angle (polar decomposition).
  • Energy (psax.core.energy): Saint Venant–Kirchhoff density and unit helpers.
  • Alignment (psax.alignment): vendored soft Smith–Waterman / Needleman–Wunsch (sequence), Kabsch rigid superposition (structure).
  • Spatial (psax.spatial): cylindrical/spherical coordinates, tensor rotation into cylindrical frame, rough effective volumes.
  • I/O (psax.io): load_structure via proxide (optional), B-factor export without BioPython.
  • AOT (psax.stablehlo): jax.export of the PSA strain pipeline at fixed shapes.
  • Synthetic tests (psax.testing.forms): rods, twist/spin/radial deformations.
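To make the kinematics bullet concrete, here is a sketch of the standard continuum-mechanics quantities computed from a per-site F: Green-Lagrange strain, Cauchy-Green invariants, principal stretches, and the rotation from a polar decomposition. The function kinematics and its return layout are hypothetical, not psax.core's API, and the SVD route to the polar decomposition is one common choice that may differ from what psax uses.

```python
import jax.numpy as jnp

def kinematics(F):
    """Illustrative kinematics for per-site F of shape (..., 3, 3)."""
    def tr(M):
        return jnp.trace(M, axis1=-2, axis2=-1)

    C = jnp.swapaxes(F, -1, -2) @ F                 # right Cauchy-Green tensor
    E = 0.5 * (C - jnp.eye(3))                      # Green-Lagrange strain
    invariants = (tr(C), 0.5 * (tr(C) ** 2 - tr(C @ C)), jnp.linalg.det(C))
    stretches = jnp.sqrt(jnp.linalg.eigvalsh(C))    # principal stretches, ascending
    # Polar decomposition F = R U from the SVD F = W diag(s) V^T: R = W V^T.
    # Assumes det F > 0 (no reflections).
    W, _, Vt = jnp.linalg.svd(F)
    R = W @ Vt
    angle = jnp.arccos(jnp.clip((tr(R) - 1.0) / 2.0, -1.0, 1.0))
    axis = jnp.stack([R[..., 2, 1] - R[..., 1, 2],
                      R[..., 0, 2] - R[..., 2, 0],
                      R[..., 1, 0] - R[..., 0, 1]], axis=-1)
    axis = axis / jnp.linalg.norm(axis, axis=-1, keepdims=True)  # NaN if angle is 0 or pi
    return {"E": E, "invariants": invariants, "stretches": stretches,
            "R": R, "angle": angle, "axis": axis}

# Example: one site stretched 20% along x, then rotated 30 degrees about z.
theta = jnp.pi / 6
c, s = jnp.cos(theta), jnp.sin(theta)
Rz = jnp.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])
U = jnp.diag(jnp.array([1.2, 1.0, 1.0]))
out = kinematics((Rz @ U)[None])                    # shape (1, 3, 3)
```

The strain and stretches depend only on U (here diag(1.2, 1, 1)), while the rotation is carried entirely by R, which is why the polar decomposition separates the two.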

The legacy global least-squares $\mathbf{F}$ (from the prxteinmpnn snapshot) lives under old_psa for regression tests only; it is not the per-site PSA method published in Phys. Rev. X 14, 011042 (2024).

JAX notes

  • safe_map (psax.utils.safe_map): when batching many independent structures, pass a Python int batch_size into safe_map (not a traced value). Inside jit, prefer fixed-shape kernels and bucketing (see psax.bucket) instead of compiling over Python loops.
  • Float precision: enable jax_enable_x64 if you need to match double-precision NumPy/SciPy baselines (tests set this in conftest.py).
  • Eigenvectors: principal_strain_eigensystem fixes eigenvector signs deterministically (no random hemisphere flips).
  • RNG: core PSA is deterministic; if you add stochastic workflows, derive subkeys with jax.random.fold_in(key, index).
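Two of these notes in code. The sign rule in eigh_fixed_signs (flip each eigenvector so its largest-magnitude component is positive) is one common deterministic convention, assumed here for illustration; principal_strain_eigensystem's actual rule may differ. The fold_in part uses the real jax.random.fold_in to derive reproducible per-item subkeys.

```python
import jax
import jax.numpy as jnp

def eigh_fixed_signs(S):
    """eigh for symmetric (..., 3, 3) tensors with a deterministic sign per
    eigenvector: each column is flipped so its largest-|.| entry is positive.
    Hypothetical convention for illustration."""
    vals, vecs = jnp.linalg.eigh(S)                                # columns are eigenvectors
    row = jnp.argmax(jnp.abs(vecs), axis=-2)                       # (..., 3) pivot row per column
    pivot = jnp.take_along_axis(vecs, row[..., None, :], axis=-2)  # (..., 1, 3)
    return vals, vecs * jnp.where(pivot < 0, -1.0, 1.0)

S = jnp.array([[2.0, 0.5, 0.0],
               [0.5, 1.0, 0.2],
               [0.0, 0.2, 3.0]])
vals, vecs = eigh_fixed_signs(S)

# Reproducible per-item randomness: derive subkeys from one base key and an index.
key = jax.random.PRNGKey(0)
subkeys = [jax.random.fold_in(key, i) for i in range(3)]
noise = jnp.stack([jax.random.normal(k, (3,)) for k in subkeys])  # (3, 3)
```

Because fold_in depends only on the base key and the index, item i gets the same subkey no matter how the items are ordered or batched.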

Prior art

The old_psa package is adapted from prxteinmpnn’s PSA helpers (global $\mathbf{F}$ and dense linear weights). The per-site method implemented in psax.core follows Sartori-Lab/PSA and Sartori & Leibler, Phys. Rev. X 14, 011042 (2024).

Citation

If you use this software in published work, please cite the PSA paper (and this codebase if appropriate):

@article{PhysRevX.14.011042,
  title = {Evolutionary Conservation of Mechanical Strain Distributions in Functional Transitions of Protein Structures},
  author = {Sartori, Pablo and Leibler, Stanislas},
  journal = {Phys. Rev. X},
  volume = {14},
  issue = {1},
  pages = {011042},
  year = {2024},
  publisher = {American Physical Society},
  doi = {10.1103/PhysRevX.14.011042},
  url = {https://link.aps.org/doi/10.1103/PhysRevX.14.011042}
}

Reference implementation: https://github.com/Sartori-Lab/PSA.
More detail: references/sartori_leibler_2024_prx.md.
