<a href="https://colab.research.google.com/github/cody-mar10/protein_set_transformer/blob/main/examples/pst_inference.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# PST Inference

## 1. GPU Runtime
Ensure that you are using a runtime with GPU access:
`Runtime > Change runtime type` and choose a GPU runtime.

## 2. Setup software
Google Colab servers already have the latest version of `PyTorch`. We need to check the version of `PyTorch` and `CUDA` to properly install other `PyTorch` extension libraries required by `PST`.

All installation should take less than 1 minute.

In [1]:
import torch
from pathlib import Path

In [2]:
torch.__version__

'2.8.0+cu126'

The wheel URL should be in the form: `https://data.pyg.org/whl/torch-{TORCHVERSION}+{CUDA}.html`
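The URL can also be assembled from the version string rather than typed by hand. The helper below is only a sketch: `pyg_wheel_url` is a hypothetical name (not part of `PST` or `PyG`), and it assumes the version string carries a `+{CUDA}` suffix like the one printed above.

```python
def pyg_wheel_url(torch_version: str) -> str:
    """Build the PyG wheel index URL from a version string such as '2.8.0+cu126'."""
    base, sep, cuda = torch_version.partition("+")
    if not sep:
        # Assumption: without a '+cu126'-style suffix we cannot tell which build to use.
        raise ValueError(f"no CUDA/CPU suffix found in {torch_version!r}")
    return f"https://data.pyg.org/whl/torch-{base}+{cuda}.html"


print(pyg_wheel_url("2.8.0+cu126"))
# https://data.pyg.org/whl/torch-2.8.0+cu126.html
```

In this notebook you would pass `torch.__version__` instead of the literal string.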

Just change the command below to use the correct version information.

In [3]:
!uv pip install torch-geometric torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-2.8.0+cu126.html

Using Python 3.12.11 environment at: /usr
Resolved 25 packages in 897ms
Prepared 3 packages in 675ms
Installed 3 packages in 38ms
 + torch-geometric==2.6.1
 + torch-scatter==2.1.2+pt28cu126
 + torch-sparse==0.6.18+pt28cu126


Then install the `PST` library. *Note: the version must be `>=2.6.0`, since that minor release lifts the upper limits on the supported Python and PyTorch versions.*

In [4]:
!uv pip install "ptn-set-transformer>=2.6.0"

Using Python 3.12.11 environment at: /usr
Resolved 84 packages in 564ms
Prepared 13 packages in 327ms
Installed 13 packages in 82ms
 + boltons==25.0.0
 + cattrs==25.2.0
 + colorlog==6.9.0
 + fair-esm==2.0.0
 + jsonargparse==4.41.0
 + lightning==2.5.5
 + lightning-cv==1.1.0
 + lightning-utilities==0.15.2
 + optuna==4.5.0
 + ptn-set-transformer==2.6.0
 + pytorch-lightning==2.5.5
 + torchmetrics==1.8.2
 + typeshed-client==2.8.2


## 3. Mount Google Drive (optional)
PST inference requires specially formatted HDF5 files that can be created and stored on your Google Drive account.

Your Google Drive account can be mounted to this runtime server so that your files are accessible.

You can store your data files there (and models if you want, but those can also be downloaded locally).

-----

Mounting your Google Drive will prompt authentication each time this notebook is run:

In [5]:
from google.colab import drive
drive.mount("/content/drive")

Mounted at /content/drive


After mounting your Google Drive, your files should be accessible at the path `/content/drive/MyDrive`.

For example, I have a demo folder that has a test FASTA file (which is the first 250 scaffolds encoding 8,955 proteins from the PST training set).

In [6]:
!ls -lh /content/drive/MyDrive/pst_demo

total 30M
-rw------- 1 root root  27M Oct  1 16:01 PST_embeddings.h5
-rw------- 1 root root 3.4M Oct  1 16:12 test.faa


## 4. Compute ESM2 embeddings (optional)

If you already have ESM2 embeddings for your protein sequences, then you can skip this step.

-----

**Advanced**: If you have a *large* number of proteins (i.e., 1M+), you could benefit from splitting your large FASTA file into smaller segments (~100k sequences) that are embedded independently. The benefit is greatest if you have access to multiple GPUs to split the work across, but it also helps if you need to restart this runtime due to time limits, since finished segments do not need to be recomputed.

You will need to concatenate the embeddings back **in the same order** since the FASTA file should be sorted such that the proteins are in order for each scaffold (based on the position in the genome).
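One way to do the splitting while keeping the original order recoverable is sketched below. The helper names (`iter_fasta_records`, `split_fasta`) and the chunk file naming are hypothetical and not part of `PST`; the zero-padded index is there so that sorting the chunk names reproduces the original protein order when the embeddings are concatenated later.

```python
from pathlib import Path
from typing import Iterator


def iter_fasta_records(path: Path) -> Iterator[str]:
    """Yield each FASTA record (header plus its sequence lines) as one string."""
    record: list[str] = []
    with open(path) as fp:
        for line in fp:
            if line.startswith(">") and record:
                yield "".join(record)
                record = []
            record.append(line)
    if record:
        yield "".join(record)


def split_fasta(path: Path, chunk_size: int, outdir: Path) -> list[Path]:
    """Write consecutive chunks of `chunk_size` records and return them in order."""
    outdir.mkdir(parents=True, exist_ok=True)
    chunks: list[Path] = []
    buffer: list[str] = []

    def flush() -> None:
        # Zero-padded index: lexicographic order of the names == original order.
        out = outdir / f"{path.stem}.part{len(chunks):05d}.faa"
        out.write_text("".join(buffer))
        chunks.append(out)
        buffer.clear()

    for record in iter_fasta_records(path):
        buffer.append(record)
        if len(buffer) == chunk_size:
            flush()
    if buffer:
        flush()  # final, possibly smaller, chunk
    return chunks
```

For example, `split_fasta(Path(fasta_file), 100_000, Path("chunks"))` would return the chunk paths in the order in which their embeddings must later be concatenated.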

In [7]:
from pst.embed.model import ESM2Models


def esm_embed(
    file: str,
    esm_model: ESM2Models = ESM2Models.esm2_t30_150M,
    batch_size: int = 2048,
    outdir: Path = Path("."),
):
    """
    Compute ESM2 embeddings for a FASTA file.

    Args:
        file (str): Path to FASTA file.
        esm_model (ESM2Models, optional): ESM2 model to use. Defaults to ESM2Models.esm2_t30_150M.
        batch_size (int, optional): Batch size in number of amino acids. Defaults to 2048.
        outdir (Path, optional): Directory the embeddings file is written to. Defaults to the current directory.
    """
    from pst.embed import ModelArgs, TrainerArgs, embed

    model_args = ModelArgs(esm=esm_model, batch_size=batch_size)
    trainer_args = TrainerArgs()

    embed(input=Path(file), outdir=outdir, model_cfg=model_args, trainer_cfg=trainer_args)

This will embed the sequences to `{outdir}/{esm2_model_name}_results.h5`

This will download the relevant ESM2 model, but that could also be uploaded to your Google Drive. This would require you to adjust your `$TORCH_HOME` environment variable appropriately.
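A minimal sketch of redirecting the cache is shown here. The folder on Google Drive is a hypothetical example path; the `hub/checkpoints` layout below `$TORCH_HOME` is the standard `torch.hub` convention. The variable should be set before the embedding function is called.

```python
import os
from pathlib import Path

# Hypothetical cache folder on the mounted Drive; adjust to your own layout.
torch_home = Path("/content/drive/MyDrive/pst_demo/torch_cache")
os.environ["TORCH_HOME"] = str(torch_home)

# torch.hub stores downloaded checkpoints under $TORCH_HOME/hub/checkpoints,
# so a pre-uploaded ESM2 checkpoint would be expected at this location.
expected_checkpoint = torch_home / "hub" / "checkpoints" / "esm2_t30_150M_UR50D.pt"
print(expected_checkpoint)
```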

In [8]:
fasta_file = "/content/drive/MyDrive/pst_demo/test.faa"

Depending on the length of the proteins in your FASTA file, you may need to lower the `batch_size` (which is in units of amino acids).
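To illustrate what a batch size "in units of amino acids" means, here is a toy sketch of budget-based batching. It is not the batching code used by `PST`; `token_budget_batches` is a hypothetical helper, and treating a protein longer than the budget as an error is an assumption of this sketch only.

```python
def token_budget_batches(lengths: list[int], batch_size: int) -> list[list[int]]:
    """Group consecutive sequence indices so each batch holds at most `batch_size` residues."""
    batches: list[list[int]] = []
    current: list[int] = []
    used = 0
    for idx, length in enumerate(lengths):
        if length > batch_size:
            # Assumption of this sketch: a single protein must fit in one batch.
            raise ValueError(f"sequence {idx} has {length} residues > batch_size={batch_size}")
        if current and used + length > batch_size:
            batches.append(current)
            current, used = [], 0
        current.append(idx)
        used += length
    if current:
        batches.append(current)
    return batches


protein_lengths = [1000, 900, 300, 2000]
print(token_budget_batches(protein_lengths, batch_size=2048))  # [[0, 1], [2], [3]]
print(token_budget_batches(protein_lengths, batch_size=4096))  # [[0, 1, 2], [3]]
```

With a budget like this, the number of proteins per batch shrinks as the proteins get longer, which is why the right setting depends on the length distribution of your FASTA file.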

This step takes by far the most time in the PST inference process (~1 min per 2.5k proteins). **If you have a large number of proteins**, access to more powerful GPUs (or an upgraded Colab plan) will greatly speed up the ESM2 inference that PST requires.

In [9]:
esm_embed(fasta_file, batch_size=4096)

INFO: Seed set to 111
INFO:lightning.fabric.utilities.seed:Seed set to 111


Downloading: "https://dl.fbaipublicfiles.com/fair-esm/models/esm2_t30_150M_UR50D.pt" to /root/.cache/torch/hub/checkpoints/esm2_t30_150M_UR50D.pt
Downloading: "https://dl.fbaipublicfiles.com/fair-esm/regression/esm2_t30_150M_UR50D-contact-regression.pt" to /root/.cache/torch/hub/checkpoints/esm2_t30_150M_UR50D-contact-regression.pt


INFO: GPU available: True (cuda), used: True
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: |          | 0/? [00:00<?, ?it/s]

In [10]:
!ls -lh *.h5

-rw-r--r-- 1 root root 21M Oct  1 16:19 esm2_t30_150M_results.h5


### 4.1 Graph format embeddings
The embedding file needs to be reformatted to a graph format used by PST.

For protein FASTA files generated by prodigal/pyrodigal, the following should be sufficient. However, `OptionalArgs` also takes an optional strand mapping file that maps each protein to a strand (`-1` or `1`) in a tab-delimited format: `protein\tstrand`.

The `scaffold_map_file` maps scaffolds to genomes for multi-scaffold genomes in a tab-delimited format: `scaffold name\tgenome name`. The `scaffold name` is defined as the part of the protein name before the numerical identifier: `scaffold name_PTNID`.

This will create a new file `{outdir}/{esm2_model_name}_results.graphfmt.h5` that should be used as input to `PST` models.
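For multi-scaffold genomes, the naming rule and the scaffold map format can be sketched as follows. Both helpers are hypothetical and not part of `PST`; the sketch also assumes that the map file has no header line, which the description above does not state either way.

```python
from pathlib import Path


def scaffold_name(protein_name: str) -> str:
    """Return the part of a protein name before the trailing '_PTNID'."""
    scaffold, sep, _ptn_id = protein_name.rpartition("_")
    if not sep:
        raise ValueError(f"{protein_name!r} has no '_PTNID' suffix")
    return scaffold


def write_scaffold_map(scaffold_to_genome: dict[str, str], path: Path) -> None:
    """Write one 'scaffold name<TAB>genome name' line per scaffold (no header assumed)."""
    with open(path, "w") as fp:
        for scaffold, genome in scaffold_to_genome.items():
            fp.write(f"{scaffold}\t{genome}\n")


# Only the last underscore separates the protein ID, so scaffold names
# that contain underscores themselves are handled correctly.
print(scaffold_name("NC_000913.3_45"))  # NC_000913.3
```

The path of the written file could then be passed as `scaffold_map_file` to the `graphify` function defined in the next cell.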

In [11]:
def graphify(embeddings_file: str, fasta_file: str, scaffold_map_file: str | None = None):
    from pst.utils.graphify import IOArgs, OptionalArgs, to_graph_format

    io_args = IOArgs(file=Path(embeddings_file), fasta_file=Path(fasta_file))
    optional_args = OptionalArgs(scaffold_map_file=Path(scaffold_map_file) if scaffold_map_file is not None else None)

    to_graph_format(io_args, optional_args)

In [12]:
graphify(embeddings_file="esm2_t30_150M_results.h5", fasta_file=fasta_file)

In [13]:
!ls -lh *.h5

-rw-r--r-- 1 root root 21M Oct  1 16:19 esm2_t30_150M_results.graphfmt.h5
-rw-r--r-- 1 root root 21M Oct  1 16:19 esm2_t30_150M_results.h5


There are now extra fields in the HDF5 file that enable efficient access to all protein embeddings from each genome.

In [14]:
import tables as tb

In [15]:
esm_embeddings_file = "esm2_t30_150M_results.graphfmt.h5"
with tb.open_file(esm_embeddings_file) as fp:
    for node in fp.walk_nodes(classname="Array"):
        print(node)

/data (CArray(np.int64(8955), np.int64(640))shuffle, blosc:lz4(4)) ''
/ptr (CArray(np.int64(251),)shuffle, blosc:lz4(4)) ''
/sizes (CArray(np.int64(250),)shuffle, blosc:lz4(4)) ''
/strand (CArray(np.int64(8955),)shuffle, blosc:lz4(4)) ''


## 5. PST inference
With graph-formatted ESM2 embeddings, `PST` can be used for inference.

Only 2 different ESM2 embedding sizes were used to train PSTs:
- `esm2_t6_8M` -> `PST...__small`
- `esm2_t30_150M` -> `PST...__large`

So the correct `PST` model needs to be chosen based on the ESM2 embeddings generated.

There are also several PST models trained with different objectives and cross validation strategies:
- `PST-TL-P__small`
- `PST-TL-P__large`
- `PST-TL-T__small`
- `PST-TL-T__large`
- `PST-MLM-P__small`
- `PST-MLM-P__large`
- `PST-MLM-T__small`
- `PST-MLM-T__large`

I recommend starting with `PST-TL-P` models.
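The naming scheme can be captured in a small lookup so that the model size always matches the ESM2 embeddings. This is a sketch only: `pst_model_name` is a hypothetical helper, and reading `TL`/`MLM` as the objective and `P`/`T` as the cross validation strategy is an interpretation of the list above. The names accepted by the download step may differ from the names produced here.

```python
# ESM2 model -> PST size suffix, as listed above.
ESM2_TO_PST_SIZE = {
    "esm2_t6_8M": "small",
    "esm2_t30_150M": "large",
}


def pst_model_name(esm_model: str, objective: str = "TL", cv: str = "P") -> str:
    """Compose a PST model name such as 'PST-TL-P__large'."""
    if esm_model not in ESM2_TO_PST_SIZE:
        raise ValueError(f"no PST model was trained on {esm_model!r} embeddings")
    if objective not in {"TL", "MLM"} or cv not in {"P", "T"}:
        raise ValueError(f"unknown objective/cv combination: {objective!r}, {cv!r}")
    return f"PST-{objective}-{cv}__{ESM2_TO_PST_SIZE[esm_model]}"


print(pst_model_name("esm2_t30_150M"))  # PST-TL-P__large
```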

### 5.1 Download PST model

The following code uses the `PST` API to download the models from the DRYAD repository [https://doi.org/10.5061/dryad.d7wm37q8w](https://doi.org/10.5061/dryad.d7wm37q8w).

For simplicity, the model checkpoint is downloaded to the current directory (`/content`).

In [16]:
# choices:
# [
#     "PST-TL-P__small",
#     "PST-TL-P__large",
#     "PST-TL-T__small",
#     "PST-TL-T__large",
#     "PST-MLM", # <- NOTE: this will create a subdirectory, so keep that in mind if using
# ]
def download_model(model: str | list[str]):
    from pst.utils.download import DryadDownloader
    from pst.utils.cli.download import ManuscriptDataArgs, ClusterArgs, EmbeddingsArgs, ModelArgs

    if isinstance(model, str):
        model = [model]

    downloader = DryadDownloader(
        manuscript=ManuscriptDataArgs(),
        cluster=ClusterArgs(),
        embeddings=EmbeddingsArgs(),
        model=ModelArgs(choices=model),
        all=False,
        outdir=Path(".")
    )

    downloader.download()

download_model("PST-TL-P__large")

Downloading the following 2 files to .
	README.md
	PST-TL-P__large.ckpt.gz


[0/2] README.md: 100%|██████████| 28.7k/28.7k [00:00<00:00, 915kB/s]
[1/2] PST-TL-P__large.ckpt.gz: 100%|██████████| 221M/221M [00:03<00:00, 70.2MB/s]


[2/2] Download finished.
Decompressing all tarballs, zip files, and gzipped files.


In [17]:
!ls -lh

total 286M
drwx------ 6 root root 4.0K Oct  1 16:13 drive
-rw-r--r-- 1 root root  21M Oct  1 16:19 esm2_t30_150M_results.graphfmt.h5
-rw-r--r-- 1 root root  21M Oct  1 16:19 esm2_t30_150M_results.h5
-rw-r--r-- 1 root root 245M Oct  1 16:19 PST-TL-P__large.ckpt
-rw-r--r-- 1 root root  29K Oct  1 16:19 README.md
drwxr-xr-x 1 root root 4.0K Sep 29 13:37 sample_data


Now use the `model_inference` function to generate PST embeddings.

In [18]:
import pst
from pst.predict import model_inference
from pst.predict.predict import PredictArgs, AcceleratorOpts

Note: The `output` file should be saved to your Google Drive account so that it is permanent. Otherwise, it will be deleted if stored locally on this server.

In [19]:
esm_embeddings_file = Path("esm2_t30_150M_results.graphfmt.h5")
model_checkpoint = Path("PST-TL-P__large.ckpt") # CHANGE TO MODEL DOWNLOADED

# NOTE: you will want this to be saved to your Google Drive so it will persist after this notebook ends
output = Path("/content/drive/MyDrive/pst_demo/PST_embeddings.h5")

results = model_inference(
    model_type=pst.ProteinSetTransformer,
    file=esm_embeddings_file,
    predict=PredictArgs(checkpoint=model_checkpoint, output=output),
    lazy=True,
    accelerator=AcceleratorOpts.gpu,

    # OPTIONAL: if you want to inspect results in this notebook
    # The results are always saved to the file above
    return_predictions=True,
)



100%|██████████| 8/8 [00:00<00:00,  9.26it/s]


For the purposes of this notebook, I returned the predictions to inspect.

Notice the shape of the first (batch) dimension of each tensor:
- `protein` and `attn` are the same as the number of proteins
- `genome` is the same as the number of scaffolds in this case
  - Note: all the genomes in the test file were single scaffold, so `genome` and `scaffold` are interchangeable.
  - For datasets that include multi-scaffold genomes, there will also be a `scaffold` field **if the graph-formatted HDF5 file includes a `scaffold_label` field that maps each scaffold to a unique integer corresponding to the genome**.

In [20]:
for k, v in results.items():
    print(k, v.shape)

protein torch.Size([8955, 800])
attn torch.Size([8955, 4])
genome torch.Size([250, 800])


In [21]:
!ls -lh /content/drive/MyDrive/pst_demo

total 30M
-rw------- 1 root root  27M Oct  1 16:19 PST_embeddings.h5
-rw------- 1 root root 3.4M Oct  1 16:12 test.faa


Finally, we can inspect the HDF5 file to see that its fields match the returned results. Judging by the shapes, the `protein` tensor appears to be stored under the name `ctx_ptn`.

In [22]:
with tb.open_file(output) as fp:
    for node in fp.walk_nodes(classname="Array"):
        print(node)

/attn (EArray(np.int64(8955), np.int64(4))shuffle, blosc:lz4(4)) ''
/ctx_ptn (EArray(np.int64(8955), np.int64(800))shuffle, blosc:lz4(4)) ''
/genome (EArray(np.int64(250), np.int64(800))shuffle, blosc:lz4(4)) ''
