# Predictions using different embedding modes

In this notebook, we show the different embedding modes available for the single-cell RNA models in the package.

In [1]:
from helical import scGPT, scGPTConfig
import torch
import anndata
from pathlib import Path
from helical.utils.downloader import Downloader
import os
from helical.constants.paths import CACHE_DIR_HELICAL

INFO:datasets:PyTorch version 2.5.1 available.
INFO:datasets:Polars version 0.20.31 available.


We demonstrate the working principle with the scGPT model. First, download the data if it is not already in the cache:

In [2]:
scgpt = scGPT()

# Download the example dataset into the Helical cache if it is not there yet
filename = "17_04_24_YolkSacRaw_F158_WE_annots.h5ad"
path = Path(CACHE_DIR_HELICAL) / filename
if not os.path.exists(path):
    downloader = Downloader()
    downloader.download_via_name(filename)

data = anndata.read_h5ad(path)

INFO:helical.models.scgpt.model:Model finished initializing.
INFO:helical.models.scgpt.model:'scGPT' model is in 'eval' mode, on device 'cpu' with embedding mode 'cls'.


To explain how the embedding modes work, it is easier to simulate the embeddings returned by the model.
We do this in the following cell:
- we define a torch tensor that simulates the embeddings,
- we overwrite the `scgpt.model._encode` function so that it returns this tensor,
- we bypass the `scgpt._normalize_embeddings` function by making it return its input unchanged.


In [3]:
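# Simulated model output for one cell, shape (1 cell, 4 tokens, 5 dims):
# row 0 plays the role of the cls token, rows 1-3 are the gene tokens
mocked_embeddings = torch.tensor([
    [[1.0, 1.0, 1.0, 1.0, 1.0],
     [5.0, 5.0, 5.0, 5.0, 5.0],
     [1.0, 2.0, 3.0, 2.0, 1.0],
     [6.0, 6.0, 6.0, 6.0, 6.0]],
])
# Mock the methods directly on the instance: the encoder always returns the tensor above...
scgpt.model._encode = lambda *args, **kwargs: mocked_embeddings
# ...and normalization becomes a no-op, so the values stay easy to read
scgpt._normalize_embeddings = lambda x: x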
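The cell above mocks the model by assigning new callables directly to instance attributes, which shadows the class methods for that one object only (assuming `scgpt.model` is an ordinary Python object). A minimal sketch of the same idea on a hypothetical `ToyEncoder` class (not part of helical), next to `unittest.mock.patch.object`, which performs the same replacement but restores the original method when the `with` block exits:

```python
from unittest import mock


class ToyEncoder:
    """Hypothetical stand-in for a model with an expensive internal _encode step."""

    def _encode(self, x):
        raise RuntimeError("would run the real forward pass")

    def embed(self, x):
        return self._encode(x)


fake_output = [[1.0, 1.0], [5.0, 5.0]]

# Approach used in the notebook: shadow the method on this one instance.
patched = ToyEncoder()
patched._encode = lambda *args, **kwargs: fake_output
print(patched.embed("any input") is fake_output)  # True

# Alternative: patch temporarily; the real method is restored when the block exits.
scoped = ToyEncoder()
with mock.patch.object(scoped, "_encode", return_value=fake_output):
    print(scoped.embed("any input") is fake_output)  # True
```

The direct assignment is simpler for a throwaway demo like this one, which is why the notebook later deletes the mocked instance and creates a fresh model; `patch.object` avoids that step.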

With this, we can run scGPT in its three embedding modes: `gene`, `cell` and `cls`.

- The `gene` mode returns one embedding for every gene.
- The `cell` mode returns the element-wise average of the gene embeddings.
- The `cls` mode returns the row that the model produces for the special `cls` token. It can be thought of as a summary of the whole observation.
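The three modes can be sketched in plain NumPy on the mocked values. This is a minimal illustration, not the helical implementation: it assumes, as the outputs below confirm for this example, that row 0 of the encoder output is the `cls` token and the remaining rows are gene tokens. The helper `embed` is hypothetical.

```python
import numpy as np

# Stand-in for the encoder output of ONE cell, shape (tokens, embedding_dim).
# Assumption: row 0 is the cls token, rows 1..n are gene tokens.
token_embeddings = np.array([
    [1.0, 1.0, 1.0, 1.0, 1.0],  # cls token
    [5.0, 5.0, 5.0, 5.0, 5.0],  # gene 1
    [1.0, 2.0, 3.0, 2.0, 1.0],  # gene 2
    [6.0, 6.0, 6.0, 6.0, 6.0],  # gene 3
], dtype=np.float32)


def embed(tokens: np.ndarray, mode: str) -> np.ndarray:
    """Hypothetical helper mirroring the three modes; not a helical function."""
    if mode == "gene":
        return tokens[1:]               # one vector per gene, cls row dropped
    if mode == "cell":
        return tokens[1:].mean(axis=0)  # element-wise average over genes only
    if mode == "cls":
        return tokens[0]                # the cls row on its own
    raise ValueError(f"unknown embedding mode: {mode!r}")


for mode in ("gene", "cell", "cls"):
    print(mode, embed(token_embeddings, mode).shape)
# gene (3, 5)
# cell (5,)
# cls (5,)
```

Only the `gene` mode keeps a per-gene dimension; `cell` and `cls` both reduce a cell to a single vector of the embedding size.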

We run scGPT on a single observation (cell) to illustrate the process.

In [4]:
# Preprocess only the first observation (cell) of the AnnData object
dataset = scgpt.process_data(data[0])

# Embed the same input once per embedding mode
scgpt.config["emb_mode"] = "gene"
gene_embeddings = scgpt.get_embeddings(dataset)

scgpt.config["emb_mode"] = "cell"
cell_embeddings = scgpt.get_embeddings(dataset)

scgpt.config["emb_mode"] = "cls"
cls_embeddings = scgpt.get_embeddings(dataset)

INFO:helical.models.scgpt.model:Processing data for scGPT.


INFO:helical.models.scgpt.model:Filtering out 10801 genes to a total of 26517 genes with an ID in the scGPT vocabulary.
INFO:helical.models.scgpt.model:Successfully processed the data for scGPT.
INFO:helical.models.scgpt.model:Started getting embeddings:
Embedding cells: 100%|██████████| 1/1 [00:00<00:00,  8.50it/s]
INFO:helical.models.scgpt.model:Finished getting embeddings.
INFO:helical.models.scgpt.model:Started getting embeddings:
Embedding cells: 100%|██████████| 1/1 [00:00<00:00, 1123.88it/s]
INFO:helical.models.scgpt.model:Finished getting embeddings.
INFO:helical.models.scgpt.model:Started getting embeddings:
Embedding cells: 100%|██████████| 1/1 [00:00<00:00, 1723.92it/s]
INFO:helical.models.scgpt.model:Finished getting embeddings.


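In `gene` mode, scGPT returns one embedding per gene, indexed by gene name. The first mocked row (`[1.0, 1.0, 1.0, 1.0, 1.0]`) does not appear, because it belongs to the `cls` token: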

In [5]:
gene_embeddings[0]

SLC39A14    [5.0, 5.0, 5.0, 5.0, 5.0]
MPDU1       [1.0, 2.0, 3.0, 2.0, 1.0]
GPHN        [6.0, 6.0, 6.0, 6.0, 6.0]
dtype: object
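The printed output suggests that each entry of `gene_embeddings` is a pandas Series indexed by gene symbol, with one embedding vector per value. Assuming that layout, here is a minimal sketch of looking up a gene and stacking the vectors into a `(n_genes, embedding_dim)` matrix; the Series is hand-built from the mocked values, and the name `cell_gene_embeddings` is hypothetical.

```python
import numpy as np
import pandas as pd

# Hypothetical stand-in for gene_embeddings[0], built from the mocked values:
# a Series indexed by gene symbol whose values are 1-D embedding vectors.
genes = ["SLC39A14", "MPDU1", "GPHN"]
vectors = [
    np.array([5.0, 5.0, 5.0, 5.0, 5.0]),
    np.array([1.0, 2.0, 3.0, 2.0, 1.0]),
    np.array([6.0, 6.0, 6.0, 6.0, 6.0]),
]
values = np.empty(len(vectors), dtype=object)
for i, vec in enumerate(vectors):
    values[i] = vec  # store each vector as a single object element
cell_gene_embeddings = pd.Series(values, index=genes, dtype=object)

# Look up one gene by name
print(cell_gene_embeddings["MPDU1"])  # [1. 2. 3. 2. 1.]

# Stack all vectors into a (n_genes, embedding_dim) matrix
gene_matrix = np.stack(cell_gene_embeddings.to_numpy())
print(gene_matrix.shape)  # (3, 5)

# Averaging over genes gives the same values as the cell-mode embedding
print(gene_matrix.mean(axis=0))
```

Stacking into a dense matrix is usually the more convenient form for downstream NumPy or PyTorch code, while the Series keeps the gene names attached to each vector.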

The `cell` embedding is the element-wise average of the three gene embeddings. For example, its third entry is (5 + 3 + 6) / 3 ≈ 4.667:

In [6]:
cell_embeddings[0]

array([4.       , 4.3333335, 4.6666665, 4.3333335, 4.       ],
      dtype=float32)

The `cls` embedding corresponds to the first row returned by the model.

In other words, scGPT in `cls` mode ignores the remaining three (gene) rows.

In [7]:
cls_embeddings[0]

array([1., 1., 1., 1., 1.], dtype=float32)

We can also run this on real data, although the output is harder to interpret visually.

First, we delete the mocked scGPT instance and instantiate a fresh one, using the GPU if one is available.

In [8]:
del scgpt
device = "cuda" if torch.cuda.is_available() else "cpu"
scgpt = scGPT(configurer=scGPTConfig(device=device))

INFO:helical.models.scgpt.model:Model finished initializing.
INFO:helical.models.scgpt.model:'scGPT' model is in 'eval' mode, on device 'cuda' with embedding mode 'cls'.


In [9]:
scgpt.config["emb_mode"] = "gene"
gene_embeddings = scgpt.get_embeddings(dataset)
gene_embeddings[0]

INFO:helical.models.scgpt.model:Started getting embeddings:
Embedding cells: 100%|██████████| 1/1 [00:00<00:00,  5.18it/s]
INFO:helical.models.scgpt.model:Finished getting embeddings.


SLC39A14    [-0.0011799249, 0.0031951678, -0.0037296554, 0...
MPDU1       [-0.0013756858, 0.017062135, -0.007849643, 0.0...
GPHN        [0.007110456, 0.025636358, 0.0028697518, 0.005...
AGFG2       [-0.008909806, 0.01073429, 0.006347002, 0.0071...
POLR3B      [-0.012140153, 0.04901718, 0.02245722, 0.00043...
                                  ...                        
TMEM258     [-0.0077039357, 0.017461302, 0.002785733, 0.01...
BNIP3L      [-0.0103421565, 0.035706572, 0.011275602, 0.00...
KPNB1       [0.0004736521, 0.032073762, 0.0024564175, 0.00...
ZSWIM5      [-0.012645806, 0.048165236, 0.02488112, -0.006...
REPIN1      [0.00678998, 0.019529147, -0.0017630243, 0.001...
Length: 1199, dtype: object

With real data, it is more informative to look at the output sizes:

In [10]:
print(f"Number of genes with embeddings: {gene_embeddings[0].shape}")
print(f"Embedding size per gene: {gene_embeddings[0].iloc[0].shape}")

Number of genes with embeddings: (1199,)
Embedding size per gene: (512,)


In [11]:
scgpt.config["emb_mode"] = "cell"
cell_embeddings = scgpt.get_embeddings(dataset)
cell_embeddings[0]

INFO:helical.models.scgpt.model:Started getting embeddings:
Embedding cells: 100%|██████████| 1/1 [00:00<00:00, 116.03it/s]
INFO:helical.models.scgpt.model:Finished getting embeddings.


array([-8.15041643e-03,  2.63333023e-02,  7.81406183e-03,  9.81337484e-03,
        1.40524013e-02, -2.95165903e-03, -1.69922467e-02, -4.81104245e-03,
       -7.60848820e-03,  4.59150635e-02,  5.87423565e-03, -5.45159075e-03,
        2.13732291e-02,  9.06534586e-03,  1.08101517e-02, -4.91651380e-03,
       -1.33220525e-02, -2.16271300e-02,  1.46969380e-02, -2.12924127e-02,
       -1.29918456e-02, -8.90347362e-03, -3.24050188e-02,  1.69960652e-02,
        6.72291825e-03,  3.31430845e-02, -3.28126512e-02, -2.09503025e-02,
       -2.97727287e-02,  7.01213209e-03,  2.35989746e-02, -2.13149730e-02,
       -2.99774809e-03, -1.64881814e-02, -1.13256490e-02, -3.82535718e-03,
       -1.26695279e-02,  5.00416942e-03, -6.52590021e-02, -4.58240882e-03,
       -2.73203477e-02, -5.88836940e-03, -2.24745702e-02, -5.89525886e-03,
       -3.59893106e-02, -2.78945770e-02,  8.49726028e-04,  1.33508444e-02,
        1.13855442e-02,  2.42668390e-02, -1.94337759e-02,  1.32291475e-02,
       -4.44257283e-04, -

In [12]:
print(f"Embedding size per cell: {cell_embeddings[0].shape}")

Embedding size per cell: (512,)


In [13]:
scgpt.config["emb_mode"] = "cls"
cls_embeddings = scgpt.get_embeddings(dataset)
cls_embeddings[0]

INFO:helical.models.scgpt.model:Started getting embeddings:
Embedding cells: 100%|██████████| 1/1 [00:00<00:00, 240.78it/s]
INFO:helical.models.scgpt.model:Finished getting embeddings.


array([-2.45153867e-02,  6.76829070e-02,  9.26359184e-03, -2.09791050e-03,
        2.24513449e-02,  2.08408223e-03, -2.53852215e-02,  3.11240065e-03,
       -6.62306789e-03,  3.31380554e-02,  2.76659112e-02,  4.46426682e-03,
        3.08492407e-02,  1.56556666e-02,  1.73311401e-02, -1.20286504e-02,
       -3.41191003e-03, -2.74797548e-02, -8.76120233e-04, -1.59723293e-02,
       -1.63279437e-02, -1.06245605e-02, -1.66277196e-02, -2.04682187e-03,
        1.51277408e-02,  5.14600612e-02, -5.13950512e-02, -3.17132138e-02,
       -2.95655672e-02, -2.13937908e-02,  1.59325860e-02, -2.51241624e-02,
        4.61029354e-03, -2.76504010e-02, -1.92681160e-02, -3.63738127e-02,
       -6.18889090e-03, -4.07493208e-03, -7.18622655e-02,  4.02772194e-03,
       -3.14314142e-02, -7.31843291e-03, -5.05978167e-02, -9.09360871e-03,
       -2.41975486e-02, -1.72051415e-02,  7.02964189e-03,  3.86377797e-02,
        9.64887347e-03,  5.07224277e-02, -2.60307938e-02,  2.90858541e-02,
        7.36009097e-03,  

In [14]:
print(f"Embedding size per cls: {cls_embeddings[0].shape}")

Embedding size per cls: (512,)
