## Instructions for Performing Inference on HEST-Bench with STPath

This tutorial will guide you to:
* Prepare the data (AnnData object and visual embeddings) from a selected dataset in HEST-Bench for inference.
* Run inference using STPath and compute the corresponding evaluation metrics.

Note that the data in [HEST-1K](https://huggingface.co/datasets/MahmoodLab/hest/tree/main) and [HEST-Bench](https://huggingface.co/datasets/MahmoodLab/hest-bench/tree/main) differ. HEST-Bench is a subset of HEST-1K, but the tissue segmentation in HEST-1K has been refined in its latest release (see this [issue](https://github.com/mahmoodlab/HEST/issues/68) for details). For the experiments in the paper, we use the dataset IDs and highly variable genes (HVGs) defined in HEST-Bench but take the actual samples from HEST-1K. For simplicity, this tutorial demonstrates the workflow on the data provided directly by HEST-Bench.

### Download dataset from HuggingFace

Here we download only the LUNG dataset from HEST-Bench.

In [7]:
import os

In [2]:
from huggingface_hub import snapshot_download

snapshot_download(
    repo_id="MahmoodLab/hest-bench", 
    repo_type='dataset', 
    local_dir='/home/ti.huang/scratch/hest-bench',
    allow_patterns=['LUNG/']
)

Fetching 11 files:   0%|          | 0/11 [00:00<?, ?it/s]

TENX118.png:   0%|          | 0.00/1.43M [00:00<?, ?B/s]

TENX141.png:   0%|          | 0.00/2.40M [00:00<?, ?B/s]

TENX118.h5:   0%|          | 0.00/301M [00:00<?, ?B/s]

TENX141.h5ad:   0%|          | 0.00/6.84M [00:00<?, ?B/s]

TENX141.h5:   0%|          | 0.00/505M [00:00<?, ?B/s]

test_0.csv:   0%|          | 0.00/79.0 [00:00<?, ?B/s]

TENX118.h5ad:   0%|          | 0.00/4.80M [00:00<?, ?B/s]

test_1.csv:   0%|          | 0.00/79.0 [00:00<?, ?B/s]

train_0.csv:   0%|          | 0.00/79.0 [00:00<?, ?B/s]

train_1.csv:   0%|          | 0.00/79.0 [00:00<?, ?B/s]

var_50genes.json:   0%|          | 0.00/452 [00:00<?, ?B/s]

'/scratch/ti.huang/hest-bench'

In [10]:
# adata includes the spatial transcriptomics profiles
# patches includes the image patches
# splits includes the multi-fold split, which is not used in our setup
# var_50genes.json is the top 50 HVG list
os.listdir('/home/ti.huang/scratch/hest-bench/LUNG')

['adata', 'patches', 'splits', 'var_50genes.json']

In [11]:
# LUNG dataset includes two samples
os.listdir('/home/ti.huang/scratch/hest-bench/LUNG/adata')

['TENX118.h5ad', 'TENX141.h5ad']

### Initialize the Gigapath encoder

Please first download the pretrained Gigapath weights (`https://huggingface.co/prov-gigapath/prov-gigapath/blob/main/pytorch_model.bin`) into `your_weight_dir/gigapath`.
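The download can also be scripted. The following is a minimal sketch, assuming `hf_hub_download` from `huggingface_hub` is used and that the encoder expects the file at `your_weight_dir/gigapath/pytorch_model.bin`. The helper names `gigapath_weight_dir` and `download_gigapath_weights` are hypothetical and not part of STPath, and the repository may require accepting its terms and logging in to Hugging Face first.

```python
from pathlib import Path


def gigapath_weight_dir(weights_root: str) -> Path:
    """Return the folder this tutorial expects the Gigapath weights in."""
    return Path(weights_root) / "gigapath"


def download_gigapath_weights(weights_root: str) -> Path:
    """Fetch pytorch_model.bin into <weights_root>/gigapath (needs network access)."""
    # Imported lazily so the path helper works without huggingface_hub installed.
    from huggingface_hub import hf_hub_download

    target_dir = gigapath_weight_dir(weights_root)
    target_dir.mkdir(parents=True, exist_ok=True)
    local_path = hf_hub_download(
        repo_id="prov-gigapath/prov-gigapath",
        filename="pytorch_model.bin",
        local_dir=str(target_dir),
    )
    return Path(local_path)


# Example with a placeholder path (not executed here, since it downloads a large file):
# download_gigapath_weights("your_weight_dir")
```

The same `your_weight_dir` value is then what gets passed as `weights_root` when constructing the encoder.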

In [3]:
from stpath.app.preprocess.hest import LazyEncoder

# weights_root should be `your_weight_dir` (the folder that contains `gigapath/`)
lazy_enc = LazyEncoder("gigapath", weights_root="/home/ti.huang/project/single_cell_dataset/weights_root")
gigapath_encoder, img_transforms = lazy_enc.get_model(device=0)  # set the gpu device as 0

Encode each image patch with Gigapath and save the embeddings in the `embeddings` folder.

In [17]:
import torch
from stpath.app.preprocess.hest import embed_tiles
from stpath.hest_utils.st_dataset import H5TileDataset

batch_size = 128

for sample_id in ['TENX118', 'TENX141']:
    tile_h5_path = os.path.join('/home/ti.huang/scratch/hest-bench/LUNG', "patches", f"{sample_id}.h5")
    assert os.path.isfile(tile_h5_path), f"Tile file {tile_h5_path} does not exist"
    embedding_dir = os.path.join('/home/ti.huang/scratch/hest-bench/LUNG/embeddings', sample_id, lazy_enc.name, 'fp32')
    os.makedirs(embedding_dir, exist_ok=True)
    embed_path = os.path.join(embedding_dir, f'{sample_id}.h5')
    
    tile_dataset = H5TileDataset(tile_h5_path, chunk_size=batch_size, img_transform=img_transforms)
    tile_dataloader = torch.utils.data.DataLoader(
        tile_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=1,
    )
    # use the encoder returned by lazy_enc.get_model
    _ = embed_tiles(sample_id, tile_dataloader, gigapath_encoder, embed_path, device=0)

Embedding Tiles TENX118: 100%|██████████████████████████████████████| 16/16 [00:37<00:00,  2.32s/it]
Embedding Tiles TENX141: 100%|██████████████████████████████████████| 26/26 [01:02<00:00,  2.41s/it]
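In the loop above, `chunk_size=batch_size` is passed to `H5TileDataset` while the `DataLoader` itself uses `batch_size=1`. One plausible reading, which is an assumption and not confirmed from the STPath source, is that each dataset item is already a chunk of up to `chunk_size` tiles, so the progress bars count chunks and not individual tiles. The chunking arithmetic under that assumption can be sketched as follows; `chunk_ranges` is a hypothetical helper and the tile count is made up for illustration.

```python
import math


def chunk_ranges(n_tiles: int, chunk_size: int) -> list[tuple[int, int]]:
    """Split tile indices [0, n_tiles) into consecutive [start, stop) chunks."""
    n_chunks = math.ceil(n_tiles / chunk_size)
    return [
        (i * chunk_size, min((i + 1) * chunk_size, n_tiles))
        for i in range(n_chunks)
    ]


# Hypothetical tile count, chosen only for illustration.
ranges = chunk_ranges(n_tiles=300, chunk_size=128)
print(ranges)  # [(0, 128), (128, 256), (256, 300)]
```

Under this reading, a larger `batch_size` means fewer, larger chunks per sample and a higher peak GPU memory use per forward pass.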


In [18]:
os.listdir('/home/ti.huang/scratch/hest-bench/LUNG/embeddings')

['TENX118', 'TENX141']
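Before running inference, it can be worth checking that the saved arrays describe the same set of patches. The sketch below assumes the embedding file exposes the keys `coords`, `embeddings`, and `barcodes`, which are the keys this tutorial reads back in the inference step. The helper `check_embedding_assets` is hypothetical, and the data is a synthetic stand-in, not a real embedding file.

```python
import numpy as np


def check_embedding_assets(data_dict: dict) -> int:
    """Check that coords, embeddings and barcodes describe the same patches.

    Returns the number of patches if the arrays are consistent.
    """
    coords = np.asarray(data_dict["coords"])
    embeddings = np.asarray(data_dict["embeddings"])
    barcodes = np.asarray(data_dict["barcodes"]).flatten()

    n_patches = embeddings.shape[0]
    if coords.shape[0] != n_patches or barcodes.shape[0] != n_patches:
        raise ValueError(
            f"row mismatch: coords={coords.shape[0]}, "
            f"embeddings={n_patches}, barcodes={barcodes.shape[0]}"
        )
    # Barcodes are used to index the AnnData object, so they must be unique.
    if len(set(barcodes.astype(str).tolist())) != n_patches:
        raise ValueError("barcodes are not unique")
    return n_patches


# Synthetic stand-in for the dictionary loaded from an embedding .h5 file.
toy = {
    "coords": np.zeros((4, 2)),
    "embeddings": np.zeros((4, 8)),
    "barcodes": np.array([[b"s0"], [b"s1"], [b"s2"], [b"s3"]]),
}
print(check_embedding_assets(toy))  # 4
```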

### Perform inference with STPath

Initialize the STPath inference agent using the pretrained weights downloaded from `https://huggingface.co/tlhuang/STPath/tree/main`.

In [29]:
from stpath.app.pipeline.inference import STPathInference

agent = STPathInference(
    gene_voc_path='/home/ti.huang/STPath/utils_data/symbol2ensembl.json',
    model_weight_path='/home/ti.huang/project/stfm/stpath/backup/stfm.pth', 
    device=0
)

n_genes: 38984, n_tech: 5, n_species: 6, n_organs: 25, n_cancer_annos: 5, n_domain_annos: 10
Model loaded from /home/ti.huang/project/stfm/stpath/backup/stfm.pth


Run inference and collect the predicted expression of the HVGs for each sample.

In [56]:
import json
import anndata as ad
import numpy as np

from scipy.stats import pearsonr
from stpath.hest_utils.st_dataset import load_adata
from stpath.hest_utils.file_utils import read_assets_from_h5

all_pred, all_gt = [], []
for sample_id in ['TENX118', 'TENX141']:
    source_dataroot = '/home/ti.huang/scratch/hest-bench/LUNG/'  # root directory of the downloaded LUNG dataset
    with open(os.path.join(source_dataroot, "var_50genes.json")) as f:
        hvg_list = json.load(f)['genes']

    data_dict, _ = read_assets_from_h5(os.path.join(source_dataroot, f"embeddings/{sample_id}/gigapath/fp32/{sample_id}.h5"))  # load the data from the h5 file
    coords = data_dict["coords"]
    embeddings = data_dict["embeddings"]
    barcodes = data_dict["barcodes"].flatten().astype(str).tolist()
    adata = ad.read_h5ad(os.path.join(source_dataroot, f"adata/{sample_id}.h5ad"))[barcodes, :]
    
    # The returned pred_adata contains the predicted expression of the genes in hvg_list (the highly variable genes).
    pred_adata = agent.inference(
        coords=coords, 
        img_features=embeddings, 
        organ_type="Lung", 
        tech_type=None,
        save_gene_names=hvg_list  # we only need the highly variable genes for evaluation
    )

    gt = np.log1p(adata[:, hvg_list].X.toarray())  # sparse -> dense

    all_gt.append(gt)
    all_pred.append(pred_adata.X)

Starting inference...
Return results...
Starting inference...
Return results...


Concatenate the HVG expression across the two samples (TENX118, TENX141), compute the Pearson correlation for each gene, and average over genes. Note that the resulting value (0.607) is higher than the one reported in the paper (0.559 in Figure 2(c)), due to differences between the samples in HEST-1K and HEST-Bench.

In [57]:
all_gt = np.concatenate(all_gt, axis=0)
all_pred = np.concatenate(all_pred, axis=0)
# calculate the Pearson correlation coefficient between the predicted and ground truth gene expression
all_pearson_list = []
# go through each gene in the highly variable genes list
for i in range(len(hvg_list)):
    pearson_corr, _ = pearsonr(all_gt[:, i], all_pred[:, i])
    all_pearson_list.append(pearson_corr.item())
print(f"Pearson correlation for LUNG dataset: {np.mean(all_pearson_list)}")

Pearson correlation for LUNG dataset: 0.6075147917866707
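The metric itself can be illustrated independently of STPath. The following is a minimal sketch on synthetic data, assuming the same layout as above: arrays of shape (spots, genes), ground truth transformed with `log1p`, samples concatenated along the spot axis, and one Pearson correlation per gene. The function `per_gene_pearson` is a hypothetical vectorized equivalent of the `pearsonr` loop, and all values here are randomly generated, not HEST-Bench data.

```python
import numpy as np


def per_gene_pearson(gt: np.ndarray, pred: np.ndarray) -> np.ndarray:
    """Pearson correlation between matching columns of two (spots, genes) arrays."""
    gt_c = gt - gt.mean(axis=0)
    pred_c = pred - pred.mean(axis=0)
    num = (gt_c * pred_c).sum(axis=0)
    den = np.sqrt((gt_c ** 2).sum(axis=0) * (pred_c ** 2).sum(axis=0))
    return num / den


rng = np.random.default_rng(0)
# Two synthetic "samples" with different spot counts and 5 genes each.
counts = [rng.poisson(5.0, size=(n, 5)) for n in (30, 40)]
gts = [np.log1p(c) for c in counts]  # same transform as the ground truth above
preds = [g + rng.normal(0.0, 0.3, size=g.shape) for g in gts]

# Concatenate along the spot axis, then score each gene column.
toy_gt = np.concatenate(gts, axis=0)
toy_pred = np.concatenate(preds, axis=0)
scores = per_gene_pearson(toy_gt, toy_pred)
print(scores.shape)  # (5,)
print(float(scores.mean()))
```

A gene whose ground truth or prediction is constant across all spots has zero variance, so its correlation is undefined (NaN) and would need to be handled before averaging.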
