# In-Silico Perturbation

In [1]:
from stFormer.perturbation.stFormer_perturb import InSilicoPerturber
from stFormer.perturbation.perturb_stats import InSilicoPerturberStats
import pandas as pd 
import pickle
import os
import re

### 0.1 Parameter Descriptions

### 1.1 Option A: Delete Single Gene and View both Gene and Cell Embedding Shifts

This option does the following:
1. Deletes a single gene: PDCD1 (`ENSG00000188389`).
2. Loads the model pretrained on the masked-learning objective.
3. Looks at the cosine similarity between perturbed and unperturbed embeddings when PDCD1 is deleted, for both per-gene and per-cell embeddings.
4. Batches the perturbations at each model forward pass (`forward_batch_size`).
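
As a rough illustration of the steps above, the sketch below deletes one token from a few rank-ordered token lists and groups the perturbed lists into batches. It is not the stFormer implementation: the helper names `delete_tokens` and `batched` and the example spots are hypothetical, and only the PDCD1 token ID (15038) is taken from the results table in this notebook.

```python
def delete_tokens(token_ids, tokens_to_delete):
    """Return a copy of token_ids without the perturbed tokens (order preserved)."""
    drop = set(tokens_to_delete)
    return [t for t in token_ids if t not in drop]


def batched(items, batch_size):
    """Yield consecutive chunks of at most batch_size items."""
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]


# Hypothetical rank-ordered token IDs for three spots.
spots = [[101, 15038, 7, 42], [15038, 9, 3], [5, 6, 8]]
PDCD1 = 15038  # token ID shown for PDCD1 in the results table

# Only spots that actually contain the gene can be perturbed.
detected = [s for s in spots if PDCD1 in s]
perturbed = [delete_tokens(s, [PDCD1]) for s in detected]

# Each batch would correspond to one model forward pass.
for batch in batched(perturbed, batch_size=2):
    print(batch)
```

Spots that do not contain the gene are skipped here, which would be one reason the number of detections can be smaller than `max_ncells`.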


In [None]:
isp = InSilicoPerturber(
            mode = 'spot', #perturbing either spot or neighborhood model (based on tokenization)
            perturb_type="delete",
            genes_to_perturb=['ENSG00000188389'],
            #perturb_rank_shift = None,
            model_type="Pretrained",
            num_classes=0,
            emb_mode="cell_and_gene",
            cell_emb_style='mean_pool',
            max_ncells=1000,
            emb_layer=0,
            forward_batch_size=100,
            nproc=12,
            token_dictionary_file='output/token_dictionary.pickle',
         )
         
isp.perturb_data(
    model_directory='output/spot/models/250422_102707_stFormer_L6_E3/final',
    input_data_file='output/spot/visium_spot.dataset',
    output_directory='output/perturb',
    output_prefix='perturb_spot')

Load cell and gene embedding shifts

In [3]:
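# Load the perturbation outputs
perturb_dir = 'output/perturb'

# Full paths to all .pickle files in the output directory
emb_files = [os.path.join(perturb_dir, fname) for fname in os.listdir(perturb_dir) if fname.endswith('.pickle')]

cell_embs, gene_embs = None, None
for path in emb_files:
    fname = os.path.basename(path)
    with open(path, 'rb') as f:
        # Match on the file name so unrelated pickles are ignored.
        # If several files match, the last one loaded wins.
        if 'cell_embs_dict' in fname:
            cell_embs = pickle.load(f)
        elif 'gene_embs_dict' in fname:
            gene_embs = pickle.load(f)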

Next, compute statistics that summarize the effect of deleting PDCD1 on:
1. Cell embeddings: similarity of perturbed cell embeddings to the unperturbed cells.
2. Gene embeddings: similarity of perturbed gene embeddings to the unperturbed genes, where the gene shifts are aggregated across cells for every other token in the dataset.
    - Example: how does deleting PDCD1 affect another gene such as CCR5AS?
    - Results are ranked by mean cosine similarity (lowest -> highest). Low similarity signifies a greater impact on that gene and a stronger dependence on the expression of the perturbed gene.
    - The output also reports the standard deviation of the cosine similarity between gene embeddings and the number of detections.
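
To make the aggregation concrete, here is a minimal sketch of how per-cell cosine similarities could be reduced to a mean, a standard deviation, and a detection count, then ranked. It assumes a dictionary keyed by `(perturbed_token, affected_token)` that holds one cosine similarity per cell; that layout, the helper name `aggregate_shifts`, the toy values, and the choice of the sample standard deviation are all assumptions for illustration, not the confirmed stFormer format.

```python
from statistics import mean, stdev


def aggregate_shifts(cos_sims_by_pair):
    """Summarize cosine similarities per (perturbed, affected) pair, lowest mean first."""
    rows = []
    for (perturbed, affected), sims in cos_sims_by_pair.items():
        rows.append({
            "Perturbed": perturbed,
            "Affected": affected,
            "Cosine_sim_mean": mean(sims),
            # A single detection has no spread, so report 0.0.
            "Cosine_sim_stdev": stdev(sims) if len(sims) > 1 else 0.0,
            "N_Detections": len(sims),
        })
    return sorted(rows, key=lambda r: r["Cosine_sim_mean"])


# Toy values (hypothetical); the token IDs are borrowed from the table in this notebook.
toy = {
    (15038, "cell_emb"): [0.98, 0.96, 0.97],
    (15038, 6342): [0.72],
    (15038, 15084): [0.77, 0.83],
}

summary = aggregate_shifts(toy)
for row in summary:
    print(row)
```

Rows with a single detection carry a standard deviation of 0 and deserve less weight than rows supported by many cells.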

In [None]:
from stFormer.perturbation.perturb_stats import InSilicoPerturberStats
ispstats = InSilicoPerturberStats(
    mode='aggregate_gene_shifts',
    genes_perturbed = ['ENSG00000188389'],
    pickle_suffix = '_raw.pickle',
    token_dictionary_file='output/token_dictionary.pickle',
    gene_name_id_dictionary_file='output/ensembl_mapping_dict.pickle'
)
ispstats.get_stats(
    input_data_directory = 'output/perturb',
    null_dist_data_directory=None,
    output_directory = 'output/perturb/perturb_stats',
    output_prefix= 'perturb_spot'
)

In [None]:
import pandas as pd
results = pd.read_csv('output/perturb/perturb_stats/perturb_spot.csv',index_col = 0)
results #view head of results

Unnamed: 0,Perturbed,Gene_name,Ensembl_ID,Affected,Affected_gene_name,Affected_Ensembl_ID,Cosine_sim_mean,Cosine_sim_stdev,N_Detections
16096,15038,PDCD1,ENSG00000188389,cell_emb,,,0.974385,0.012724,490
14053,15038,PDCD1,ENSG00000188389,6342,MORF4L2-AS1,ENSG00000231154,0.723488,0.000000,1
15117,15038,PDCD1,ENSG00000188389,15084,CNGA1,ENSG00000198515,0.796645,0.038457,2
14433,15038,PDCD1,ENSG00000188389,10688,SERINC4,ENSG00000184716,0.802099,0.000000,1
9147,15038,PDCD1,ENSG00000188389,15941,CCR5AS,ENSG00000223552,0.814862,0.000000,1
...,...,...,...,...,...,...,...,...,...
15714,15038,PDCD1,ENSG00000188389,13172,MTRNR2L3,ENSG00000256222,0.999999,0.000000,1
15972,15038,PDCD1,ENSG00000188389,16722,MPO,ENSG00000005381,0.999999,0.000000,1
15752,15038,PDCD1,ENSG00000188389,11045,HBA1,ENSG00000206172,0.999999,0.000000,1
13381,15038,PDCD1,ENSG00000188389,15799,LINC02608,ENSG00000226251,0.999999,0.000000,1
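
Many of the rows above rest on a single detection, so it can help to filter on `N_Detections` before reading the ranking. The sketch below rebuilds a few rows from the output shown above and applies such a filter; the threshold `min_detections` is a hypothetical value chosen only for the example.

```python
import pandas as pd

# A few rows copied from the output above.
results = pd.DataFrame(
    {
        "Affected_gene_name": ["MORF4L2-AS1", "CNGA1", "SERINC4", "CCR5AS", "MPO"],
        "Cosine_sim_mean": [0.723488, 0.796645, 0.802099, 0.814862, 0.999999],
        "Cosine_sim_stdev": [0.0, 0.038457, 0.0, 0.0, 0.0],
        "N_Detections": [1, 2, 1, 1, 1],
    }
)

min_detections = 2  # hypothetical threshold; choose one suited to your data

# Keep genes seen in enough cells, then rank by lowest similarity (largest shift).
robust = (
    results[results["N_Detections"] >= min_detections]
    .sort_values("Cosine_sim_mean")
    .reset_index(drop=True)
)
print(robust)
```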


### 1.2 Option B: Delete all Genes and compare Cell State Shifts (Group A -> Group B)


1. **Data Filtering & Batching**  
   - `perturb_data()` loads and filters your input dataset.  
   - In `isp_perturb_set()` or `isp_perturb_all()`, cells are grouped into batches of size `forward_batch_size`.

2. **Define Transition**  
   - `cell_states_to_model` names the start state, the goal state, and any alternate states; `EmbExtractor.get_state_embs()` computes a reference embedding for each of them.

3. **Embedding Extraction**  
   - **Original** (`full_orig`) and **perturbed** (`full_pert`) token sequences are passed through the model (via `get_embs`) to collect hidden-state tensors.  
   - With `perturb_type="delete"`, the perturbed sequence is the original sequence with the target gene's token removed.

4. **Cosine-Similarity Quantification**  
   - **Gene-level**: `pu.quant_cos_sims(pert_emb, original_emb, …, emb_mode="gene")` computes how each remaining gene embedding shifts when the target gene is deleted.  
   - **Cell-level** (if `cell_states_to_model` is set): average the non-padding embeddings before and after deletion and quantify that shift against your state-embedding targets.

5. **Aggregation & Output**  
   - For each cell or state, cosine similarities are averaged (or bucketed) and stored in a dictionary keyed by the perturbed gene(s).  
   - Final results are written out as “_cell_embs_dict_…” files and, if `emb_mode="cell_and_gene"`, also as “_gene_embs_dict_…” files, which you can analyze with `InSilicoPerturberStats`.
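
The cell-level quantity in step 4 can be sketched as a difference of cosine similarities. This assumes the shift is defined as the similarity of the perturbed cell embedding to a state embedding minus the similarity of the original cell embedding to that same state embedding. That definition, the helper names, and the two-dimensional toy vectors are assumptions for illustration and are not taken from the stFormer source.

```python
import math


def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def shift_toward_state(original_emb, perturbed_emb, state_emb):
    """Positive values mean the perturbation moved the cell toward the state."""
    return cosine(perturbed_emb, state_emb) - cosine(original_emb, state_emb)


# Toy 2-D embeddings (hypothetical).
goal = [1.0, 0.0]
original = [0.0, 1.0]   # orthogonal to the goal state
perturbed = [1.0, 1.0]  # rotated halfway toward the goal state

shift = shift_toward_state(original, perturbed, goal)
print(round(shift, 4))
```

Under this definition, a positive shift toward the goal state together with a smaller or negative shift toward the alternate states would mark a gene whose deletion moves cells specifically toward the goal.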

In [None]:
cell_states_to_model = {'state_key':'classification',
                        'start_state': 'Invasive cancer',
                        'goal_state': 'Invasive cancer + lymphocytes',
                        'alt_states': ['Invasive cancer + stroma','Invasive cancer + stroma + lymphocytes']}

In [None]:
from stFormer.tokenization.embedding_extractor import EmbExtractor
embex = EmbExtractor(
    model_type = 'Pretrained',
    num_classes=0,  # the pretrained model is not a classifier
    filter_data = None,
    max_ncells = 1000,
    emb_layer=0,
    summary_stat='exact_mean',
    forward_batch_size=100,
    token_dictionary_file='output/token_dictionary.pickle',
    nproc = 16
)

state_embs_dict = embex.get_state_embs(
    cell_states_to_model,
    model_directory='output/spot/models/250422_102707_stFormer_L6_E3/final',
    input_data_file='output/spot/visium_spot.dataset',
    output_directory='output/perturb/cell_shift',
    output_prefix='spot_state_shift'
)

List the states for which embeddings were computed

In [11]:
for key,value in state_embs_dict.items():
    print(key)

Invasive cancer
Invasive cancer + lymphocytes
Invasive cancer + stroma
Invasive cancer + stroma + lymphocytes


Perturb all genes, compare the cosine similarity shift from the start state toward the goal state, and rank genes by their shift toward the goal state

In [1]:
from stFormer.perturbation.stFormer_perturb import InSilicoPerturber
from stFormer.perturbation.perturb_stats import InSilicoPerturberStats

In [None]:
isp = InSilicoPerturber(
            perturb_type="delete",
            perturb_rank_shift=None,
            combos=0,
            anchor_gene=None,
            genes_to_perturb='all',
            cell_states_to_model=cell_states_to_model,
            state_embs_dict = state_embs_dict,
            model_type="Pretrained",
            num_classes=0,
            emb_mode="cell",
            cell_emb_style='mean_pool',
            max_ncells=None,
            emb_layer=0,
            forward_batch_size=100,
            nproc=12,
            token_dictionary_file='output/token_dictionary.pickle',
         )
         
isp.perturb_data(
    model_directory='output/spot/models/250422_102707_stFormer_L6_E3/final',
    input_data_file='output/spot/visium_spot.dataset',
    output_directory='output/perturb/cell_shift',
    output_prefix='perturb_spot')

In [None]:
ispstats = InSilicoPerturberStats(mode="goal_state_shift",
                                  genes_perturbed="all",
                                  combos=0,
                                  anchor_gene=None,
                                  cell_states_to_model=cell_states_to_model,
                                  token_dictionary_file='output/token_dictionary.pickle',
                                  gene_name_id_dictionary_file='output/ensembl_mapping_dict.pickle')
ispstats.get_stats(
    input_data_directory="output/perturb/cell_shift", # directory that perturb_data() wrote to (its output_directory)
    null_dist_data_directory=None,
    output_directory="output/perturb/cell_shift/perturb_stats",
    output_prefix="visium_spot")

The output table has the following columns (`<state>` stands for each entry in `alt_states`):
1. `Gene`: token ID
2. `Gene_name`: HGNC gene symbol
3. `Ensembl_ID`: matching Ensembl gene ID
4. `Shift_to_goal_end`: cosine shift from the start state toward the goal state in response to the given perturbation
5. `Goal_end_vs_random_pval`: p-value of the cosine shift toward the goal state, by Wilcoxon test against a random distribution of at most 10,000 cells
6. `N_Detections`: number of cells in which the perturbed gene was detected and perturbed
7. `Shift_to_alt_end_<state>`: cosine shift from the start state toward the alternate state in response to the perturbation
8. `Alt_end_vs_random_pval_<state>`: p-value of the cosine shift toward the alternate state against the random distribution
9. `Goal_end_FDR`: Benjamini-Hochberg correction of `Goal_end_vs_random_pval`
10. `Alt_end_FDR_<state>`: multiple-hypothesis-test correction of `Alt_end_vs_random_pval_<state>`
11. `Sig`: 1 if the FDR is below 0.05, otherwise 0
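
For readers unfamiliar with the FDR columns, the following is a minimal, self-contained sketch of the Benjamini-Hochberg adjustment and of the binarized significance flag. It is a generic textbook implementation with made-up p-values, not the code path used by `InSilicoPerturberStats`.

```python
def benjamini_hochberg(pvalues):
    """Return Benjamini-Hochberg adjusted p-values in the original order."""
    m = len(pvalues)
    order = sorted(range(m), key=lambda i: pvalues[i])
    adjusted = [0.0] * m
    running_min = 1.0
    # Walk from the largest p-value down so that each adjusted value is the
    # minimum of p * m / rank over the current rank and all larger ranks.
    for rank in range(m, 0, -1):
        i = order[rank - 1]
        running_min = min(running_min, pvalues[i] * m / rank)
        adjusted[i] = running_min
    return adjusted


pvals = [0.01, 0.04, 0.03, 0.20]  # hypothetical p-values
fdr = benjamini_hochberg(pvals)
sig = [int(q < 0.05) for q in fdr]  # same 0/1 encoding as the Sig column

print(fdr)
print(sig)
```

Because the adjustment scales with the number of tests, perturbing all genes at once can leave every FDR well above 0.05 even when several raw p-values look small.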

In [5]:
import pandas as pd

In [8]:
cell_shift_stats = pd.read_csv('output/perturb/cell_shift/perturb_stats/visium_spot.csv',index_col=0)
cell_shift_stats

Unnamed: 0,Gene,Gene_name,Ensembl_ID,Shift_to_goal_end,Goal_end_vs_random_pval,N_Detections,Shift_to_alt_end_Invasive cancer + stroma,Alt_end_vs_random_pval_Invasive cancer + stroma,Shift_to_alt_end_Invasive cancer + stroma + lymphocytes,Alt_end_vs_random_pval_Invasive cancer + stroma + lymphocytes,Goal_end_FDR,Alt_end_FDR_Invasive cancer + stroma,Alt_end_FDR_Invasive cancer + stroma + lymphocytes,Sig
2979,3907,ANKRD37,ENSG00000186352,0.001014,0.027741,5,0.000712,0.174332,0.000750,0.090140,0.924671,0.913586,0.873315,0
11646,17551,SPINK4,ENSG00000122711,0.000910,0.175442,6,0.001164,0.132797,0.001139,0.133500,0.953533,0.886458,0.888611,0
7884,10285,ARID4A,ENSG00000032219,0.000824,0.023335,10,0.000348,0.601956,0.000403,0.424766,0.924671,0.982160,0.959837,0
9497,12278,HOXB-AS1,ENSG00000230148,0.000818,0.024909,6,0.000750,0.044214,0.000770,0.026993,0.924671,0.823552,0.790271,0
10644,13690,PBX4,ENSG00000105717,0.000778,0.003643,7,0.001221,0.005439,0.001154,0.005040,0.924671,0.734657,0.715931,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
53,64,TTC34,ENSG00000215912,-0.000763,0.013673,9,0.000024,0.509069,-0.000080,0.630417,0.924671,0.974761,0.974686,0
5066,6672,IDO1,ENSG00000131203,-0.000782,0.055477,7,-0.000479,0.125197,-0.000522,0.074521,0.924671,0.886458,0.853843,0
5856,7701,CAVIN3,ENSG00000170955,-0.000879,0.013125,6,-0.000491,0.008692,-0.000545,0.005454,0.924671,0.734657,0.715931,0
479,604,FYB2,ENSG00000187889,-0.000925,0.005857,11,-0.000171,0.891572,-0.000269,0.694267,0.924671,0.996187,0.985234,0
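
The alternate-state columns carry the state label in their names, spaces included, so selecting them by prefix is usually easier than typing them out. The sketch below rebuilds three rows from the output above and shows one possible way to find genes that shift spots toward the goal state more than toward any alternate state; that comparison rule is an illustrative choice, not part of stFormer.

```python
import pandas as pd

# Three rows copied from the output above (subset of columns).
cell_shift_stats = pd.DataFrame(
    {
        "Gene_name": ["ANKRD37", "PBX4", "FYB2"],
        "Shift_to_goal_end": [0.001014, 0.000778, -0.000925],
        "Goal_end_FDR": [0.924671, 0.924671, 0.924671],
        "Shift_to_alt_end_Invasive cancer + stroma": [0.000712, 0.001221, -0.000171],
        "Shift_to_alt_end_Invasive cancer + stroma + lymphocytes": [0.000750, 0.001154, -0.000269],
        "Sig": [0, 0, 0],
    }
)

# Collect the alternate-state shift columns and recover the state labels.
prefix = "Shift_to_alt_end_"
alt_cols = [c for c in cell_shift_stats.columns if c.startswith(prefix)]
alt_states = [c[len(prefix):] for c in alt_cols]

# Genes whose deletion shifts spots toward the goal state more than toward any alternate state.
goal_specific = cell_shift_stats[
    cell_shift_stats["Shift_to_goal_end"] > cell_shift_stats[alt_cols].max(axis=1)
]

# Genes flagged as significant after FDR correction.
significant = cell_shift_stats[cell_shift_stats["Sig"] == 1]

print(alt_states)
print(goal_specific["Gene_name"].tolist())
print(len(significant))
```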


### 1.3 Option C: Overexpress Gene


1. **Data Filtering & Batching**  
   - `perturb_data()` loads and filters your input dataset.  
   - In `isp_perturb_set()` or `isp_perturb_all()`, cells are grouped into batches of size `forward_batch_size`.

2. **Applying Overexpression**  
   - For each example, `pu.overexpress_tokens(...)` takes the token indices of your gene(s) and **inserts** them at the **front** of the token list.  

3. **Embedding Extraction**  
   - **Original** (`full_orig`) and **perturbed** (`full_pert`) token sequences are passed through the model (via `get_embs`) to collect hidden‐state tensors.  
   - Because overexpressed tokens were prepended, the **first N** positions of `full_pert` are your newly-inserted genes; the rest align with the original token order.

4. **Cosine‐Similarity Quantification**  
   - **Gene-level**: `pu.quant_cos_sims(pert_emb, original_emb, …, emb_mode="gene")` computes how each remaining gene embedding shifts when your target gene is overexpressed.  
   - **Cell-level** (if `cell_states_to_model` is set): average the non-padding embeddings before & after overexpression and quantify that shift against your state-embedding targets.

5. **Aggregation & Output**  
   - For each cell or state, cosine-similarities are averaged (or bucketed) and stored in a dictionary keyed by the perturbed gene(s).  
   - Final results are written out as “_cell_embs_dict_…” files and, if `emb_mode="cell_and_gene"`, also as “_gene_embs_dict_…” files, which you can analyze with `InSilicoPerturberStats`.
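
The prepend-and-align idea in steps 2 and 3 can be sketched in a few lines. The helper `prepend_tokens` below is hypothetical and is not `pu.overexpress_tokens`; in particular, moving a gene that is already present to the front (rather than duplicating it) is an assumption, and the token IDs are invented.

```python
def prepend_tokens(token_ids, tokens_to_overexpress):
    """Place the overexpressed tokens at the front of a rank-ordered token list.

    Assumption: a token that is already present is moved rather than duplicated.
    """
    boosted = list(tokens_to_overexpress)
    boosted_set = set(boosted)
    rest = [t for t in token_ids if t not in boosted_set]
    return boosted + rest


original = [11, 22, 33, 44]  # hypothetical rank-ordered token IDs
ESR1_TOKEN = 99              # hypothetical token ID for the overexpressed gene

perturbed = prepend_tokens(original, [ESR1_TOKEN])
n_inserted = 1

# The first n_inserted positions hold the new gene; the rest line up with the original order.
assert perturbed[:n_inserted] == [ESR1_TOKEN]
assert perturbed[n_inserted:] == original

# A gene that is already expressed is moved to rank 1 under the assumption above.
print(prepend_tokens(original, [33]))
```

The alignment check matters because per-gene cosine similarities are only meaningful when position `i` in the original sequence and its counterpart in the perturbed sequence refer to the same gene.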

In [1]:
from stFormer.perturbation.stFormer_perturb import InSilicoPerturber
from stFormer.perturbation.perturb_stats import InSilicoPerturberStats

In [None]:
isp = InSilicoPerturber(
            perturb_type="overexpress",
            genes_to_perturb=['ENSG00000091831'], #ESR1,
            #perturb_rank_shift = None,
            model_type="Pretrained",
            filter_data={'subtype':'TNBC'}, #include filtering to TNBC to see overexpression of ESR1 on these cells
            num_classes=0,
            emb_mode="cell_and_gene",
            cell_emb_style='mean_pool',
            max_ncells=1000,
            emb_layer=0,
            forward_batch_size=10,
            nproc=12,
            token_dictionary_file='output/token_dictionary.pickle',
         )
         
isp.perturb_data(
    model_directory='output/spot/models/250422_102707_stFormer_L6_E3/final',
    input_data_file='output/spot/visium_spot.dataset',
    output_directory='output/perturb/overexpress',
    output_prefix='perturb_spot')

In [None]:
ispstats = InSilicoPerturberStats(mode='aggregate_gene_shifts',
                                  genes_perturbed=['ENSG00000091831'],
                                  token_dictionary_file='output/token_dictionary.pickle',
                                  gene_name_id_dictionary_file='output/ensembl_mapping_dict.pickle'
                                  )
ispstats.get_stats(
    input_data_directory='output/perturb/overexpress',
    null_dist_data_directory=None,
    output_directory='output/perturb/overexpress/stats',
    output_prefix='visium_spot'
)

In [None]:
import pandas as pd
stats = pd.read_csv('output/perturb/overexpress/stats/visium_spot.csv')
stats[stats['N_Detections'] > 100]

### 1.5 Option E: Load in Cell Classifier for Perturbation

1. **Data Filtering & Batching**  
   - `perturb_data()` loads the trained `CellClassifier` model and filters the input dataset.  
   - Each cell is selected individually and padded into batch form using `padding_collate_fn_orig`.

2. **Generating Perturbations**  
   - For each cell, `pu.make_perturbation_batch` creates a batch of perturbed versions of that cell (here gene deletions, since `genes_to_perturb='all'`).  
   - Perturbed versions are padded into batch form (`pert_batch`).  
   - The perturbed embeddings (`full_pert`) are collected.

3. **Cosine-Similarity Quantification**  
   - **Gene-level**: `pu.quant_cos_sims` quantifies how each gene's embedding shifts under perturbation.  
   - **Cell-level**: mean-pooled cell embeddings are compared using `quant_cos_sims` to measure the overall state shift.

4. **Aggregation & Dictionary Update**  
   - For every perturbed gene, cosine similarities to all other genes are saved in `gene_embs`.  
   - Once all cells are processed, the final cosine similarity and gene embedding dictionaries are saved.
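
Mean pooling over a padded batch has to ignore the padding positions, otherwise short cells are pulled toward the padding vector. The sketch below shows the idea for a single cell; it assumes right-padding and a known count of real tokens, and the helper name `mean_pool` and the toy embeddings are hypothetical rather than taken from stFormer.

```python
def mean_pool(token_embs, n_real_tokens):
    """Average the first n_real_tokens embedding vectors, ignoring trailing padding rows."""
    real = token_embs[:n_real_tokens]
    dim = len(real[0])
    return [sum(vec[d] for vec in real) / len(real) for d in range(dim)]


# One cell with two real tokens followed by one padding row (toy 2-D embeddings).
token_embs = [
    [1.0, 3.0],
    [3.0, 5.0],
    [0.0, 0.0],  # padding position
]

cell_emb = mean_pool(token_embs, n_real_tokens=2)
naive = mean_pool(token_embs, n_real_tokens=3)  # wrongly includes the padding row

print(cell_emb)
print(naive)
```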

In [1]:
from stFormer.perturbation.stFormer_perturb import InSilicoPerturber
from stFormer.perturbation.perturb_stats import InSilicoPerturberStats

In [2]:
from datasets.utils.logging import disable_progress_bar
disable_progress_bar()

In [None]:
isp = InSilicoPerturber(
            perturb_type="delete",
            genes_to_perturb='all',
            #perturb_rank_shift = None,
            model_type="CellClassifier",
            num_classes=2, #number of classes originally trained on
            emb_mode="cell",
            cell_emb_style='mean_pool',
            max_ncells=500,
            emb_layer=0,
            forward_batch_size=100,
            nproc=12,
            token_dictionary_file='output/token_dictionary.pickle',
         )
         
isp.perturb_data(
    model_directory='output/models/classification/final_model',
    input_data_file='output/spot/visium_spot.dataset',
    output_directory='output/perturb',
    output_prefix='perturb_spot')