This part of the tutorial shows how to obtain ESM-2 embeddings for genes and how to use stImpute to predict spatial transcriptomics data.

### Query and Download

We first download the protein sequence of each gene from [UniProt](https://www.uniprot.org/) in FASTA format. Here we use the accessions `P61922`, `Q3UJF9` and `Q91ZH7` as an example. The following script fetches each sequence and writes it to `gene_protein.txt`:

In [1]:
from urllib.request import urlopen

In [2]:
gene_id_list = ['Q3UJF9', 'P61922', 'Q91ZH7']

In [3]:
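# Write one FASTA record per accession, with the sequence joined onto a single line
with open('gene_protein.txt', 'w') as f:
    for gene_id in gene_id_list:
        with urlopen('https://rest.uniprot.org/uniprotkb/' + gene_id + '.fasta') as response:
            lines = response.read().decode('utf-8').splitlines()
        f.write('>' + gene_id + '\n')
        # lines[0] is the UniProt header; the remaining lines hold the sequence
        f.write(''.join(lines[1:]) + '\n')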
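
Before running ESM-2 it can help to read the file back and confirm that every accession has a non-empty sequence. The sketch below is one way to do that; `read_fasta` is a hypothetical helper written for this tutorial, not part of stImpute or ESM.

```python
def read_fasta(path):
    """Return {record_id: sequence} for a FASTA file (hypothetical helper)."""
    records = {}
    current_id = None
    with open(path) as handle:
        for line in handle:
            line = line.strip()
            if not line:
                continue
            if line.startswith('>'):
                # Header line: start a new record
                current_id = line[1:]
                records[current_id] = ''
            elif current_id is not None:
                # Sequence lines may be wrapped, so append rather than assign
                records[current_id] += line
    return records
```

With the file written above, `read_fasta('gene_protein.txt')` should return one entry per accession in `gene_id_list`.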

### ESM-2 Embedding

Install the ESM-2 model (see [ESM-2](https://github.com/facebookresearch/esm)), then **run the following command** from the ESM repository directory (`esm-main`):

CUDA_VISIBLE_DEVICES=0 python scripts/extract.py esm2_t36_3B_UR50D gene_protein.txt examples/data/some_proteins_emb_esm2 --repr_layers 36 --include mean per_tok

After the command finishes, one embedding file per gene (here `P61922.pt`, `Q3UJF9.pt` and `Q91ZH7.pt`) is written to the `esm-main/examples/data/some_proteins_emb_esm2/` directory. We combine them into a single table with one column per gene:

In [4]:
import os
import torch
import pickle
import pandas as pd

df = pd.DataFrame()

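# Collect one embedding column per gene from the saved .pt files
for path, dir_lst, file_lst in os.walk(r'examples/data/some_proteins_emb_esm2'):
    for file_name in file_lst:
        with open(os.path.join(path, file_name), 'rb') as pt_file:
            data = torch.load(pt_file)
        # Last position of the per-token layer-36 representation, used as the gene embedding
        df.insert(df.shape[1], data['label'], data['representations'][36][-1].numpy())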

df

| | Q3UJF9 | P61922 | Q91ZH7 |
|---|---|---|---|
| 0 | 0.134797 | 0.025256 | -0.149445 |
| 1 | -0.053382 | -0.031979 | -0.187018 |
| 2 | -0.012659 | 0.004654 | -0.018185 |
| 3 | 0.043372 | -0.081000 | -0.261306 |
| 4 | -0.072901 | 0.091502 | -0.136699 |
| ... | ... | ... | ... |
| 2555 | -0.071132 | -0.050865 | 0.162222 |
| 2556 | -0.076116 | -0.019745 | -0.090902 |
| 2557 | 0.065875 | 0.103102 | 0.006431 |
| 2558 | -0.208240 | -0.061796 | -0.100855 |


In [5]:
# Save the embedding table; the context manager closes the file after writing
with open('emb.pkl', 'wb') as f:
    pickle.dump(df, f)

### Spatial transcriptomic data prediction

Next, we use the osmFISH_Zeisel dataset pair (osmFISH spatial data with the Zeisel scRNA-seq reference) as an example of spatial transcriptomics prediction with stImpute.

In [1]:
import torch
import warnings
import numpy as np
import scanpy as sc
import pandas as pd
import anndata as ad
from sklearn import metrics
from sklearn.model_selection import KFold
from scipy.spatial.distance import cosine
from model import *
warnings.filterwarnings('ignore')

seed=42
set_seed(seed)

Load data

In [2]:
st_adata = sc.read_h5ad('/public/home/syj/stImpute/dataset/st-seq/osmFISH.h5ad')
sc_adata = sc.read_h5ad('/public/home/syj/stImpute/dataset/scRNA-seq/Zeisel.h5ad')
emb_file = 'embed/osmFISH_emb.pkl'  # alternative: emb_file = None

In [3]:
raw_spatial_df = pd.DataFrame(st_adata.X, columns=st_adata.var_names)
raw_scrna_df = pd.DataFrame(sc_adata.X, columns=sc_adata.var_names)
raw_shared_gene = np.intersect1d(raw_spatial_df.columns, raw_scrna_df.columns)

all_pred_res = pd.DataFrame(np.zeros((raw_spatial_df.shape[0], raw_shared_gene.shape[0])), columns=raw_shared_gene)
all_reliable_res = pd.DataFrame(np.zeros((1, raw_shared_gene.shape[0])), columns=raw_shared_gene)

5-fold cross validation

In [4]:
kf = KFold(n_splits=5, shuffle=True, random_state=0)
# Hold out one fold of the shared genes at a time and predict it from the rest
for idx, (train_ind, test_ind) in enumerate(kf.split(raw_shared_gene), start=1):
    print("\n===== Fold %d =====\nNumber of train genes: %d, Number of test genes: %d" % (idx, len(train_ind), len(test_ind)))
    train_gene = raw_shared_gene[train_ind]
    test_gene  = raw_shared_gene[test_ind]
    spatial_df = raw_spatial_df[train_gene]
    scrna_df   = raw_scrna_df
    all_pred_res[test_gene], all_reliable_res[test_gene] = stImpute(spatial_df, scrna_df, train_gene, test_gene, seed=seed, emb_file=emb_file)


===== Fold 1 =====
Number of train genes: 26, Number of test genes: 7
ST data:            3405 cells * 26 genes
scRNA-seq data:     1691 cells * 15075 genes
7 genes to be predicted



Embedding        : 100%|██████████| 30/30 [00:10<00:00,  2.94it/s]
EM_training      : 100%|██████████| 2/2 [00:05<00:00,  2.59s/it]
Reliable training: 100%|██████████| 100/100 [00:00<00:00, 425.00it/s]



===== Fold 2 =====
Number of train genes: 26, Number of test genes: 7
ST data:            3405 cells * 26 genes
scRNA-seq data:     1691 cells * 15075 genes
7 genes to be predicted



Embedding        : 100%|██████████| 30/30 [00:09<00:00,  3.21it/s]
EM_training      : 100%|██████████| 2/2 [00:05<00:00,  2.57s/it]
Reliable training: 100%|██████████| 100/100 [00:00<00:00, 403.92it/s]



===== Fold 3 =====
Number of train genes: 26, Number of test genes: 7
ST data:            3405 cells * 26 genes
scRNA-seq data:     1691 cells * 15075 genes
7 genes to be predicted



Embedding        : 100%|██████████| 30/30 [00:09<00:00,  3.29it/s]
EM_training      : 100%|██████████| 2/2 [00:04<00:00,  2.50s/it]
Reliable training: 100%|██████████| 100/100 [00:00<00:00, 487.20it/s]



===== Fold 4 =====
Number of train genes: 27, Number of test genes: 6
ST data:            3405 cells * 27 genes
scRNA-seq data:     1691 cells * 15075 genes
6 genes to be predicted



Embedding        : 100%|██████████| 30/30 [00:09<00:00,  3.09it/s]
EM_training      : 100%|██████████| 2/2 [00:05<00:00,  2.62s/it]
Reliable training: 100%|██████████| 100/100 [00:00<00:00, 386.50it/s]



===== Fold 5 =====
Number of train genes: 27, Number of test genes: 6
ST data:            3405 cells * 27 genes
scRNA-seq data:     1691 cells * 15075 genes
6 genes to be predicted



Embedding        : 100%|██████████| 30/30 [00:09<00:00,  3.09it/s]
EM_training      : 100%|██████████| 2/2 [00:05<00:00,  2.56s/it]
Reliable training: 100%|██████████| 100/100 [00:00<00:00, 447.65it/s]
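
The fold sizes in the log follow from how `KFold` divides the 33 shared genes. The first `n_samples % n_splits` folds receive one extra item, so three folds hold out 7 genes and two hold out 6. A minimal sketch of that arithmetic, where `fold_sizes` is a hypothetical helper and not a scikit-learn function:

```python
def fold_sizes(n_samples, n_splits):
    """Sizes of the test folds when n_samples items are split as evenly as possible."""
    base, extra = divmod(n_samples, n_splits)
    # The first `extra` folds receive one additional item
    return [base + 1 if i < extra else base for i in range(n_splits)]


n_genes = 33  # 26 train + 7 test genes, as reported for Fold 1
for fold, n_test in enumerate(fold_sizes(n_genes, 5), start=1):
    print('Fold %d: %d train genes, %d test genes' % (fold, n_genes - n_test, n_test))
```

Because `shuffle=True`, which genes land in each fold depends on `random_state`, but the sizes do not.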


### Calculating metrics

Cosine similarity and MSE:

In [5]:
print('result: ')
print('gene-wise cosine: %.2f' % (np.median(calc_all(raw_spatial_df[raw_shared_gene], all_pred_res, cal='cosine'))))
print('gene-wise mse: %.2f' % (np.median(calc_all(raw_spatial_df[raw_shared_gene], all_pred_res, cal='mse'))))
print('cell-wise cosine: %.2f' % (np.median(calc_all(raw_spatial_df[raw_shared_gene].T, all_pred_res.T, cal='cosine'))))
print('cell-wise mse: %.2f' % (np.median(calc_all(raw_spatial_df[raw_shared_gene].T, all_pred_res.T, cal='mse'))))

result: 
gene-wise cosine: 0.79
gene-wise mse: 0.82
cell-wise cosine: 0.76
cell-wise mse: 1.03
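
To make the difference between gene-wise and cell-wise scores concrete, the sketch below computes both on made-up data using only the standard library. It assumes that `calc_all` returns one score per column of its inputs, which is why transposing switches from genes to cells; the internals of `calc_all` are not shown in this tutorial, so treat `score_columns` as an illustrative stand-in.

```python
import math
from statistics import median


def cosine_similarity(u, v):
    """Cosine of the angle between two non-zero vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def mse(u, v):
    """Mean squared error between two equal-length vectors."""
    return sum((a - b) ** 2 for a, b in zip(u, v)) / len(u)


def score_columns(true_rows, pred_rows, metric):
    """Apply `metric` to each pair of matching columns (hypothetical stand-in for calc_all)."""
    return [metric(t, p) for t, p in zip(zip(*true_rows), zip(*pred_rows))]


# Toy data: 3 cells (rows) x 2 genes (columns); the values are made up
measured = [[1.0, 0.0], [2.0, 1.0], [3.0, 1.0]]
predicted = [[1.0, 0.0], [2.0, 2.0], [3.0, 1.0]]

# One score per gene (column)
gene_wise = score_columns(measured, predicted, mse)
# Transposing first gives one score per cell instead
cell_wise = score_columns(list(zip(*measured)), list(zip(*predicted)), mse)

print('gene-wise mse: %.2f' % median(gene_wise))
print('cell-wise mse: %.2f' % median(cell_wise))
```

The same `score_columns` call works with `cosine_similarity` in place of `mse`, provided no column is all zeros.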


Gene-wise cosine similarity of the top 50% of genes, ranked by the reliability score returned by stImpute:

In [6]:
top50_ind = np.argsort(-all_reliable_res.values.squeeze())[:raw_shared_gene.shape[0]//2]
top50_gene_cos = [1-cosine(all_pred_res[gene], raw_spatial_df[gene]) for gene in raw_shared_gene[top50_ind]]
print('top50%% gene-wise cosine: %.2f' % (np.median(top50_gene_cos)))

top50% gene-wise cosine: 0.82
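
The selection step is a rank-and-truncate operation: sort genes by reliability in descending order, keep the first `n // 2`, and summarise only those. With 33 shared genes, integer division keeps 16. A small sketch on made-up scores, where the gene names and values are purely illustrative:

```python
from statistics import median


def top_half_by_reliability(reliability):
    """Return the names of the top half of genes, highest reliability first."""
    ranked = sorted(reliability, key=reliability.get, reverse=True)
    return ranked[:len(ranked) // 2]


# Made-up reliability scores and per-gene cosine similarities
reliability = {'geneA': 0.9, 'geneB': 0.2, 'geneC': 0.7, 'geneD': 0.4, 'geneE': 0.1}
cosine_by_gene = {'geneA': 0.85, 'geneB': 0.60, 'geneC': 0.80, 'geneD': 0.70, 'geneE': 0.50}

selected = top_half_by_reliability(reliability)
print(selected)
print('median cosine of selected genes: %.3f' % median(cosine_by_gene[g] for g in selected))
```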


Cluster result:

In [7]:
imp_adata = ad.AnnData(all_pred_res)
try:
    imp_adata.obs['Cluster'] = st_adata.obs['Cluster'].values.astype('category')
    sc.pp.neighbors(imp_adata, n_neighbors=10, n_pcs=40)
    sc.tl.umap(imp_adata)
    sc.tl.leiden(imp_adata)

    # Map each cluster name to an integer label in order of first appearance
    labels_true = []
    label_ids = dict()
    for cluster in st_adata.obs['Cluster']:
        if cluster not in label_ids:
            label_ids[cluster] = len(label_ids)
        labels_true.append(label_ids[cluster])
    labels_true = np.array(labels_true)
    labels_pred = imp_adata.obs['leiden'].astype('int').to_numpy()

    print('ARI: %.2f' % (metrics.adjusted_rand_score(labels_true, labels_pred)))
    print('FMI: %.2f' % (metrics.fowlkes_mallows_score(labels_true, labels_pred)))
    print('Comp: %.2f' % (metrics.completeness_score(labels_true, labels_pred)))
except Exception as err:
    # Report the problem (for example, a missing 'Cluster' column) instead of exiting silently
    print('Skipping clustering metrics: %s' % err)

ARI: 0.29
FMI: 0.35
Comp: 0.43
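
The integer codes assigned to the cluster names above are arbitrary, which is fine because the adjusted Rand index depends only on which cells are grouped together, not on what the groups are called. The sketch below illustrates this with a standard-library version written from the usual pair-counting definition. It is not scikit-learn's implementation, and `adjusted_rand_index` is a name chosen here for illustration.

```python
from collections import Counter
from math import comb


def adjusted_rand_index(labels_true, labels_pred):
    """Adjusted Rand index from pair counts (illustrative, needs at least 2 samples)."""
    n = len(labels_true)
    pair_counts = Counter(zip(labels_true, labels_pred))
    true_counts = Counter(labels_true)
    pred_counts = Counter(labels_pred)

    # Number of sample pairs that share a cluster in both / true / predicted labelings
    sum_pairs = sum(comb(c, 2) for c in pair_counts.values())
    sum_true = sum(comb(c, 2) for c in true_counts.values())
    sum_pred = sum(comb(c, 2) for c in pred_counts.values())

    expected = sum_true * sum_pred / comb(n, 2)
    maximum = (sum_true + sum_pred) / 2
    if maximum == expected:
        return 1.0  # degenerate case, e.g. both labelings put everything in one cluster
    return (sum_pairs - expected) / (maximum - expected)


# Made-up labels: cluster names for the annotation, integer ids for the prediction
true_names = ['neuron', 'neuron', 'glia', 'glia']
pred_ids = [0, 0, 1, 2]
print('ARI: %.2f' % adjusted_rand_index(true_names, pred_ids))

# Renaming the annotated clusters leaves the score unchanged
renamed = [1, 1, 0, 0]
print('ARI after renaming: %.2f' % adjusted_rand_index(renamed, pred_ids))
```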
