Shared ProstT5 inference library for phold and baktfold.
pholdlib provides the protein language model (pLM) inference layer for Foldseek 3Di prediction that both tools share: loading the ProstT5 protein language model, running the CNN prediction head, and writing 3Di outputs and probabilities. Keeping this logic in one library makes it consistent across tools and avoids duplication.
If you use pholdlib in your research, please cite:
Bouras G., Grigson S.R., Mirdita M., Heinzinger M., Papudeshi B.,
Mallawaarachchi V., Green R., Kim S.R., Mihalia V., Psaltis A.J.,
Wormald P-J., Vreugde S., Steinegger M., Edwards R.A.
Protein Structure Informed Bacteriophage Genome Annotation with Phold
Nucleic Acids Research, Volume 54, Issue 1, 13 January 2026
https://doi.org/10.1093/nar/gkaf1448
License: MIT
pip install pholdlib

With test dependencies:

pip install "pholdlib[test]"

from pathlib import Path
import torch
from pholdlib.prostt5.model import get_T5_model, load_predictor
from pholdlib.prostt5.inference import run_prostt5_inference
from pholdlib.prostt5.output import SS_MAPPING, write_probs, write_fail_ids
# 1. Load ProstT5 + CNN predictor head
model, vocab, device = get_T5_model(
    model_dir=Path("/path/to/model_cache"),
    model_name="Rostlab/ProstT5_fp16",
    cpu=True,
    threads=4,
)
predictor = load_predictor("/path/to/cnn_checkpoint.pt", device)
# 2. Prepare sequences: list of (id, sequence, length), sorted descending by length
seqs = [
    ("protein_A", "MKTIIALSYIFCLVFA", 16),
    ("protein_B", "MASMTGGQQMGRDLY", 15),
]
# 3. Run inference
predictions, emb_residue, emb_protein, fail_ids = run_prostt5_inference(
    seq_dict=seqs,
    model=model,
    vocab=vocab,
    predictor=predictor,
    device=device,
    output_probs=True,
    save_per_residue_embeddings=False,
    save_per_protein_embeddings=False,
)
# 4. Decode predictions
for seq_id, (pred_array, mean_prob, all_probs) in predictions.items():
    threedi = "".join(SS_MAPPING[int(c)] for c in pred_array)
    print(f"{seq_id}: {threedi} (mean confidence {mean_prob:.1f}%)")

| Function / Class | Description |
|---|---|
| `get_T5_model(model_dir, model_name, cpu, threads, check_fn, zenodo_fn)` | Load ProstT5 encoder + tokenizer; optionally download via HuggingFace with a Zenodo fallback. Returns `(model, vocab, device)`. |
| `load_predictor(checkpoint_path, device)` | Load a CNN prediction head from a `.pt` checkpoint or state-dict file. Returns the CNN in eval mode. |
| `CNN` | Two-layer Conv2d head: `(B, L, 1024)` embeddings → `(B, 20, L)` 3Di logits. |
| `toCPU(tensor)` | Detach a tensor, move it to CPU, and return a NumPy array. |
| Function | Description |
|---|---|
| `run_prostt5_inference(seq_dict, model, vocab, predictor, device, ...)` | Batch inference over a list of `(id, sequence, length)` tuples. Returns `(predictions, embeddings_per_residue, embeddings_per_protein, fail_ids)`. |

`predictions` is a dict `{seq_id: (pred_array, mean_prob, all_prob)}`:
- `pred_array`: `np.byte` array of 3Di class indices (0–19; 20 = masked).
- `mean_prob`: mean per-residue confidence, 0–100.
- `all_prob`: `float32` array of shape `(1, L)`, or `None` when `output_probs=False`.
Key batching parameters:
| Parameter | Default | Meaning |
|---|---|---|
| `max_residues` | 100,000 | Max total residues per batch before flush |
| `max_seq_len` | 30,000 | Any single sequence longer than this triggers an immediate flush |
| `max_batch` | 10,000 | Max sequences per batch |
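
The flush rule these three parameters describe can be sketched in a few lines of Python. This is an illustration only, not pholdlib's implementation: the helper name `iter_batches` is hypothetical, and the exact comparisons (`>=` versus `>`) are assumptions.

```python
from typing import Iterator, List, Tuple

Seq = Tuple[str, str, int]  # (id, sequence, length)


def iter_batches(
    seqs: List[Seq],
    max_residues: int = 100_000,
    max_seq_len: int = 30_000,
    max_batch: int = 10_000,
) -> Iterator[List[Seq]]:
    """Yield batches, flushing when any of the three limits is reached.

    Hypothetical helper for illustration only; not part of pholdlib.
    """
    batch: List[Seq] = []
    n_residues = 0
    for seq_id, seq, seq_len in seqs:
        batch.append((seq_id, seq, seq_len))
        n_residues += seq_len
        if (
            len(batch) >= max_batch          # too many sequences
            or n_residues >= max_residues    # too many residues in total
            or seq_len > max_seq_len         # one very long sequence
        ):
            yield batch
            batch, n_residues = [], 0
    if batch:  # flush whatever is left after the last sequence
        yield batch


seqs = [("a", "M" * 6, 6), ("b", "M" * 5, 5), ("c", "M" * 3, 3), ("d", "M" * 2, 2)]
batch_ids = [[s[0] for s in b] for b in iter_batches(seqs, max_residues=10, max_batch=3)]
print(batch_ids)  # [['a', 'b'], ['c', 'd']]
```

Sorting the input by descending length, as the usage example does, puts the longest sequences into the earliest batches.
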
| Function / Constant | Description |
|---|---|
| `SS_MAPPING` | `{0…20: char}`: maps each 3Di class index to its single-letter code (20 → `'X'`, masked). |
| `mask_low_confidence_aa(sequence, scores, threshold)` | Replace residues whose confidence score is below `threshold` with `'X'`. `scores` should be a shape-`(1, L)` array or equivalent. |
| `write_probs(predictions, output_path_mean, output_path_all, original_keys)` | Write mean probabilities to a CSV and per-residue probabilities to a JSONL file. |
| `write_fail_ids(fail_ids, out_path)` | Write a list of failed sequence IDs to a TSV. No-ops on an empty list. |
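
To make the masking rule concrete, here is a pure-Python sketch of what `mask_low_confidence_aa` is described as doing. The name `mask_low_confidence_sketch` is hypothetical, and both the 0–100 score scale and the handling of a nested `(1, L)` list are assumptions; the real function may treat NumPy arrays differently.

```python
from typing import Sequence, Union

Scores = Union[Sequence[float], Sequence[Sequence[float]]]


def mask_low_confidence_sketch(sequence: str, scores: Scores, threshold: float) -> str:
    """Replace residues whose score is below threshold with 'X'.

    Illustrative stand-in, not pholdlib's mask_low_confidence_aa.
    """
    # Accept either a flat list of L scores or a nested (1, L) list.
    if len(scores) > 0 and isinstance(scores[0], (list, tuple)):
        flat = list(scores[0])
    else:
        flat = list(scores)
    if len(flat) != len(sequence):
        raise ValueError("need exactly one score per residue")
    return "".join(
        "X" if score < threshold else residue
        for residue, score in zip(sequence, flat)
    )


masked = mask_low_confidence_sketch(
    "DVLVD", [[91.0, 42.5, 88.0, 12.0, 60.0]], threshold=50.0
)
print(masked)  # DXLXD
```
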
| Function / Constant | Description |
|---|---|
| `PROSTT5_MD5_DICTIONARY` | MD5 hashes for `Rostlab/ProstT5_fp16` model files. |
| `check_prostT5_download(model_dir, model_name, md5_dict, model_subdir)` | Returns `True` if the model is absent or corrupt and needs to be (re)downloaded. |
| `download_zenodo_prostT5(model_dir, logdir, threads, backup_url, backup_md5, backup_tarball)` | Download and extract a ProstT5 tarball from a Zenodo backup URL. |
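
The integrity check can be pictured as comparing each expected file's MD5 digest against a reference dictionary. The sketch below assumes a flat `{filename: md5_hex}` dictionary and uses hypothetical helpers `md5_of` and `needs_download`; the actual layout of `PROSTT5_MD5_DICTIONARY` and the internals of `check_prostT5_download` may differ.

```python
import hashlib
import tempfile
from pathlib import Path
from typing import Dict


def md5_of(path: Path, chunk_size: int = 1 << 20) -> str:
    """Return the hex MD5 digest of a file, read in chunks."""
    digest = hashlib.md5()
    with path.open("rb") as handle:
        for chunk in iter(lambda: handle.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def needs_download(model_dir: Path, md5_dict: Dict[str, str]) -> bool:
    """Return True if any expected file is missing or has the wrong MD5.

    Hypothetical helper; assumes md5_dict maps file names to hex digests.
    """
    for file_name, expected in md5_dict.items():
        path = model_dir / file_name
        if not path.is_file() or md5_of(path) != expected:
            return True
    return False


with tempfile.TemporaryDirectory() as tmp:
    model_dir = Path(tmp)
    expected = {"config.json": hashlib.md5(b"{}").hexdigest()}
    missing = needs_download(model_dir, expected)   # file absent
    (model_dir / "config.json").write_bytes(b"{}")
    intact = needs_download(model_dir, expected)    # digest matches
    (model_dir / "config.json").write_bytes(b"{ }")
    corrupt = needs_download(model_dir, expected)   # digest differs

print(missing, intact, corrupt)  # True False True
```
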
- `check_fn` / `zenodo_fn` hooks: `get_T5_model` accepts optional callables for model integrity checking and Zenodo fallback download. Tool-specific implementations (e.g. phold's `check_prostT5_download`) are passed in; pholdlib ships the base `Rostlab/ProstT5_fp16` MD5 dict and a generic `download_zenodo_prostT5` that callers configure with their own backup URL.
- FP32 on CPU: `get_T5_model` automatically casts the model to `float()` when `cpu=True` to avoid errors with half-precision operations on CPU.
- Sequence pre-processing: `run_prostt5_inference` replaces `U`, `Z`, and `O` with `X` internally; callers do not need to sanitise sequences beforehand.
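
The sequence pre-processing described above amounts to a simple character substitution. A minimal sketch, assuming uppercase input and a hypothetical helper name `sanitise_sequence` (pholdlib performs this step inside `run_prostt5_inference`, possibly by different means):

```python
# Map the rare or ambiguous residues U, Z and O to the unknown residue X.
_RARE_TO_X = str.maketrans("UZO", "XXX")


def sanitise_sequence(sequence: str) -> str:
    """Return the sequence with U, Z and O replaced by X (illustrative only)."""
    return sequence.translate(_RARE_TO_X)


cleaned = sanitise_sequence("MKUZOLA")
print(cleaned)  # MKXXXLA
```
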
# Install test dependencies
pip install "pholdlib[test]"
# Unit tests — no model download, no GPU required (completes in ~1 s)
pytest tests/
# Integration tests — ProstT5 is downloaded automatically on first run
# (~1.6 GB fp16) into tests/test_data/model_cache/
pytest tests/ --run_integration
# Point at a pre-existing model cache
pytest tests/ --run_integration --model_dir /path/to/model_cache
# Use the real trained phold CNN weights (ships with the phold repo)
pytest tests/ --run_integration \
--checkpoint /path/to/phold/src/phold/cnn/cnn_chkpnt/model.pt
# With GPU + multiple threads
pytest tests/ --run_integration --gpu_available --threads 8

| File | Contents |
|---|---|
| `tests/test_output.py` | `SS_MAPPING`, `mask_low_confidence_aa`, `write_probs`, `write_fail_ids` |
| `tests/test_databases.py` | MD5 dict structure, `_calc_md5`, `check_prostT5_download`, `download_zenodo_prostT5` |
| `tests/test_model.py` | `CNN` forward pass and architecture, `toCPU`, `load_predictor` |
| `tests/test_integration.py` | Model loading, tokenizer, predictor, batch inference, per-residue/protein embeddings, output round-trip |
Unit tests load module files directly via `importlib` to avoid triggering the `transformers` import chain in `pholdlib/prostt5/__init__.py`, so they run without a working HuggingFace / scipy install. `conftest.py` also mocks `transformers` automatically if it cannot be imported, keeping the PyTorch-only `CNN` and `toCPU` tests green in broken environments.
Integration tests are gated behind `--run_integration`. When no `--checkpoint` is supplied they use a randomly initialised CNN predictor; predictions are meaningless, but shapes, types, and probability ranges are all verified. Pass `--checkpoint` to run against the real trained weights.