# Annotated Data changes for all tools

This notebook walks through the differences in .obs and .var between tools (SENA, GEARS, scGPT and the lgem model). Each tool names some variables in its Annotated Data (AnnData) object differently. Some tools (scGPT and lgem) only work with valid samples, i.e. samples whose perturbed genes are found in the list of features.

In [None]:
import numpy as np
import pandas as pd
import scanpy as sc
import os

from gears import PertData

### Filepaths

In [57]:
# For loading 'raw data'
univ_path = '/workspace/tfm/SENA/data'
data_name = 'Norman2019_reduced.h5ad'
filepath = os.path.join(univ_path, data_name)
print(filepath)

# For saving data processed by GEARS PertData data handler
# Save data in common directory
data_savedir = '/workspace/tfm/cris_test/data'
name_new_dataset = 'norman_reduced'

/workspace/tfm/SENA/data/Norman2019_reduced.h5ad


### Loading and data changes

AnnData requirements:
- (SENA) adata.obs['guide_ids']: Condition of each sample. **'ctrl', 'GeneA', 'GeneA,GeneB'**
- (SENA) adata.var['gene_symbols']
- (GEARS) adata.var['gene_name'] = adata.var['gene_symbols']
- (GEARS) adata.obs['condition']: Condition of each sample. **'ctrl' 'ctrl+geneA' 'geneA+ctrl' 'geneA+geneB'**
- (lgem) adata.obs['condition_fixed']: Condition of each sample. **'ctrl' 'geneA' 'geneA+geneB'**
- (scGPT) adata.var['gene_name'] = adata.var['gene_symbols']
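The naming conventions listed above can be converted into one another with plain string handling. The following is a minimal sketch and not part of any of the tools: the helper names `guide_ids_to_gears` and `gears_to_lgem` are hypothetical, and it assumes that control samples appear in `guide_ids` either as `'ctrl'` or as an empty string.

```python
def guide_ids_to_gears(guide_ids: str) -> str:
    """Convert a SENA-style label ('GeneA,GeneB') to a GEARS-style one ('GeneA+GeneB')."""
    genes = [g for g in guide_ids.split(",") if g and g != "ctrl"]
    if not genes:            # '' or 'ctrl' -> control sample (assumption, see lead-in)
        return "ctrl"
    if len(genes) == 1:      # single perturbations are padded with 'ctrl'
        return f"{genes[0]}+ctrl"
    return "+".join(genes)


def gears_to_lgem(condition: str) -> str:
    """Convert a GEARS-style label to the lgem-style one (no 'ctrl' padding, sorted genes)."""
    genes = [g for g in condition.split("+") if g != "ctrl"]
    return "+".join(sorted(genes)) if genes else "ctrl"


for label in ["", "DLX2", "CEBPE,RUNX1T1"]:
    gears = guide_ids_to_gears(label)
    print(repr(label), "->", gears, "->", gears_to_lgem(gears))
```

Sorting the gene names in `gears_to_lgem` means that `'geneA+geneB'` and `'geneB+geneA'` collapse into a single label.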

ATTENTION: Depending on the pickle version, loading the pickled dataset can fail. scGPT uses an older version. Since I couldn't find an easier way to solve this, you should recreate the dataset inside the Docker environment used for scGPT. SENA is the only tool that doesn't use the GEARS module's data handler.
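Whether a pickle protocol mismatch is involved can be checked directly, because pickles written with protocol 2 or later declare their protocol in the first opcode. The sketch below is only a diagnostic under that assumption: the helper name `pickle_protocol` is hypothetical, and a mismatch between library versions (for example in the classes stored inside the pickle) would cause loading errors that this check does not detect.

```python
import pickle
import pickletools
from typing import Optional


def pickle_protocol(data: bytes) -> Optional[int]:
    """Return the protocol declared at the start of a pickle stream, or None for protocols 0 and 1."""
    opcode, arg, _pos = next(pickletools.genops(data))
    # Protocols >= 2 begin with a PROTO opcode whose argument is the protocol number
    return arg if opcode.name == "PROTO" else None


payload = {"dataset": "norman_reduced"}  # stand-in object, not the real dataset
for protocol in (2, 4, pickle.HIGHEST_PROTOCOL):
    blob = pickle.dumps(payload, protocol=protocol)
    print(protocol, "->", pickle_protocol(blob))
```

Protocol 5 requires Python 3.8 or later, so a file written with it cannot be read by an older interpreter. If the check points to that, recreating the dataset in the target environment, as described above, remains the safest option.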

In [8]:
adata = sc.read(filepath)

In [9]:
adata

AnnData object with n_obs × n_vars = 11850 × 5000
    obs: 'guide_identity', 'read_count', 'UMI_count', 'gemgroup', 'good_coverage', 'number_of_cells', 'guide_ids', 'guide_merged', 'split', 'batch', 'condition', 'cell_type', 'dose_val', 'control', 'drug_dose_name', 'cov_drug_dose_name'
    var: 'gene_symbols', 'highly_variable', 'means', 'dispersions', 'dispersions_norm'
    uns: 'rank_genes_groups', 'rank_genes_groups_cov'
    layers: 'counts'

In [10]:
adata.obs['guide_ids']

index
CCCTCCTAGAGGTAGA-3-0-0    CEBPE,RUNX1T1
ATCTACTGTTATGCGT-1-0-0             DLX2
CGGACTGGTTGACGTT-6-0-0            ZBTB1
AACTTTCGTACGAAAT-5-0-0          AHR,FEV
GGACAGAGTGGTCCGT-2-0-0       CNN1,MAPK1
                              ...      
ACATCAGCATACTCTT-2-1-1                 
ATTGGTGTCAGCGATT-2-1-1                 
CGCCAAGGTAGCTCCG-6-1-1                 
CTAAGACTCCTGTACC-1-1-1                 
GTATCTTGTCTACCTC-2-1-1                 
Name: guide_ids, Length: 11850, dtype: category
Categories (237, object): ['', 'AHR', 'AHR,FEV', 'AHR,KLF1', ..., 'ZBTB10', 'ZBTB25', 'ZC3HAV1', 'ZNF318']

In [None]:
# gene_name
adata.var['gene_name'] = adata.var['gene_symbols']

In [None]:
# condition_fixed. Works with the Norman dataset
def fix_condition(cond):
    """Drop the 'ctrl' padding from single perturbations and sort the gene names."""
    genes = cond.split('+')
    if len(genes) == 2 and 'ctrl' in genes:
        genes.remove('ctrl')
    return '+'.join(sorted(genes)) # Makes sure that order is the same

# map() avoids integer indexing into a Series that is indexed by cell barcode
adata.obs['condition_fixed'] = adata.obs['condition'].astype(str).map(fix_condition)


In [None]:
# Check that condition_fixed and guide_ids have the same number of unique values
print(len(adata.obs['condition'].unique()))
print(len(adata.obs['condition_fixed'].unique()))
print(len(adata.obs['guide_ids'].unique()))


284
237
237
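Equal counts of unique values do not by themselves prove that the two columns label the same groups of cells. A stricter check is that the mapping between them is one-to-one. The following is a minimal sketch on a made-up `obs` table; the helper name `is_one_to_one` and the toy labels are illustrative only.

```python
import pandas as pd

# Toy stand-in for adata.obs
obs = pd.DataFrame({
    "condition": ["ctrl", "AHR+ctrl", "ctrl+AHR", "AHR+FEV", "FEV+AHR"],
    "condition_fixed": ["ctrl", "AHR", "AHR", "AHR+FEV", "AHR+FEV"],
    "guide_ids": ["", "AHR", "AHR", "AHR,FEV", "AHR,FEV"],
})


def is_one_to_one(df: pd.DataFrame, left: str, right: str) -> bool:
    """True if every value of `left` pairs with exactly one value of `right` and vice versa."""
    pairs = df[[left, right]].astype(str).drop_duplicates()
    return bool(pairs[left].is_unique and pairs[right].is_unique)


print(is_one_to_one(obs, "condition_fixed", "guide_ids"))
print(is_one_to_one(obs, "condition", "condition_fixed"))
```

In the notebook, the same helper could be called as `is_one_to_one(adata.obs, 'condition_fixed', 'guide_ids')`.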


## Creating GEARS Dataset

In [None]:
# GEARS pertdata pickle creation
# Can take around 10-15 minutes for the usual, non-reduced datafiles.
pert_data = PertData(data_savedir) # specific saved folder
pert_data.new_data_process(dataset_name = name_new_dataset, adata = adata) # specific dataset name and adata object


Found local copy...
Found local copy...
Creating pyg object for each cell in the data...
Creating dataset file...
LYL1+IER5L
IER5L+ctrl
KIAA1804+ctrl
ctrl+IER5L
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 284/284 [01:47<00:00,  2.64it/s]
Done!
Saving new dataset pyg object at /workspace/tfm/cris_test/data/norman_reduced/data_pyg/cell_graphs.pkl
Done!


In [None]:
# GEARS pertdata loading
# Mostly to check that dataloader works without any problems
pert_data.load(data_path = os.path.join(data_savedir, name_new_dataset)) # load the processed data, the path is saved folder + dataset_name
pert_data.prepare_split(split = 'simulation', seed = 42) # get data split with seed
pert_data.get_dataloader(batch_size = 32, test_batch_size = 128) # prepare data loader

Found local copy...
These perturbations are not in the GO graph and their perturbation can thus not be predicted
['LYL1+IER5L' 'IER5L+ctrl' 'KIAA1804+ctrl' 'ctrl+IER5L']
Local copy of pyg dataset is detected. Loading...
Done!
Creating new splits....
Saving new splits at /workspace/tfm/cris_test/data/norman_reduced/splits/norman_reduced_simulation_42_0.75.pkl
Simulation split test composition:
combo_seen0:5
combo_seen1:57
combo_seen2:17
unseen_single:37
Done!
Creating dataloaders....
Done!


#### Warning: PertData removes some samples

In [45]:
pert_data.adata # PertData removes samples whose perturbations aren't in the GO graph

AnnData object with n_obs × n_vars = 11700 × 5000
    obs: 'guide_identity', 'read_count', 'UMI_count', 'gemgroup', 'good_coverage', 'number_of_cells', 'guide_ids', 'guide_merged', 'split', 'batch', 'condition', 'cell_type', 'dose_val', 'control', 'drug_dose_name', 'cov_drug_dose_name', 'condition_name', 'condition_fixed'
    var: 'gene_symbols', 'highly_variable', 'means', 'dispersions', 'dispersions_norm', 'gene_name'
    uns: 'non_dropout_gene_idx', 'non_zeros_gene_idx', 'rank_genes_groups', 'rank_genes_groups_cov', 'rank_genes_groups_cov_all', 'top_non_dropout_de_20', 'top_non_zero_de_20'
    layers: 'counts'

In [44]:
adata

AnnData object with n_obs × n_vars = 11850 × 5000
    obs: 'guide_identity', 'read_count', 'UMI_count', 'gemgroup', 'good_coverage', 'number_of_cells', 'guide_ids', 'guide_merged', 'split', 'batch', 'condition', 'cell_type', 'dose_val', 'control', 'drug_dose_name', 'cov_drug_dose_name', 'condition_name', 'condition_fixed'
    var: 'gene_symbols', 'highly_variable', 'means', 'dispersions', 'dispersions_norm', 'gene_name'
    uns: 'rank_genes_groups', 'rank_genes_groups_cov', 'rank_genes_groups_cov_all', 'top_non_dropout_de_20', 'non_dropout_gene_idx', 'non_zeros_gene_idx', 'top_non_zero_de_20'
    layers: 'counts'
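To see which conditions account for the difference in `n_obs` between the two objects, the `obs` tables can be compared by condition. The following is a minimal sketch on made-up data; the helper name `dropped_conditions` is hypothetical.

```python
import pandas as pd


def dropped_conditions(obs_before: pd.DataFrame, obs_after: pd.DataFrame, key: str = "condition") -> pd.Series:
    """Return the number of cells per condition that exist before but not after filtering."""
    # astype(str) avoids counting unused categories of a categorical column
    before = obs_before[key].astype(str).value_counts()
    after = obs_after[key].astype(str).value_counts()
    missing = before.index.difference(after.index)
    return before.loc[missing].sort_index()


# Toy stand-ins for adata.obs and pert_data.adata.obs
before = pd.DataFrame({"condition": ["ctrl", "ctrl", "IER5L+ctrl", "LYL1+IER5L", "AHR+ctrl"]})
after = before[~before["condition"].str.contains("IER5L")]

dropped = dropped_conditions(before, after)
print(dropped)
print("cells removed:", int(dropped.sum()))
```

In the notebook, the equivalent call would be `dropped_conditions(adata.obs, pert_data.adata.obs)`.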

#### Data handling functions

In [None]:
def separate_data(adata = None, dataset_name: str = "norman"):
    """Split the given AnnData object into single perturbation, double perturbation and control datasets."""
    if "norman" not in dataset_name.lower():
        # Fail early: the datasets returned below would otherwise be undefined
        raise NotImplementedError("Dataset not implemented yet.")

    # Preprocessing dataset: drop the 'ctrl' padding and sort the gene names
    def fix_condition(cond):
        genes = cond.split('+')
        if len(genes) == 2 and 'ctrl' in genes:
            genes.remove('ctrl')
        return '+'.join(sorted(genes)) # Makes sure that order is the same

    adata.obs['condition_fixed'] = adata.obs['condition'].astype(str).map(fix_condition)

    # Masks for control samples and for double perturbations (those containing +)
    is_ctrl = adata.obs['condition_fixed'] == 'ctrl'
    is_double = adata.obs['condition_fixed'].str.contains('+', regex=False)

    # Dataset with single perts
    adata_single = adata[~is_double & ~is_ctrl].copy()

    # Dataset with double perts
    adata_double = adata[is_double].copy()

    # Ctrl expression
    adata_ctrl = adata[is_ctrl].copy()

    return adata_single, adata_double, adata_ctrl
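The partition that `separate_data` performs can be checked on a plain `obs` table without loading any expression data. The following is a minimal sketch with made-up labels: every sample should fall into exactly one of the three groups.

```python
import pandas as pd

# Toy stand-in for adata.obs after condition_fixed has been computed
obs = pd.DataFrame({"condition_fixed": ["ctrl", "AHR", "AHR+FEV", "DLX2", "ctrl", "CNN1+MAPK1"]})

is_ctrl = obs["condition_fixed"] == "ctrl"
is_double = obs["condition_fixed"].str.contains("+", regex=False)  # literal '+', not a regex
is_single = ~is_ctrl & ~is_double

# Number of groups each sample belongs to; must be 1 everywhere
counts = is_ctrl.astype(int) + is_double.astype(int) + is_single.astype(int)

group = pd.Series("single", index=obs.index)
group[is_double] = "double"
group[is_ctrl] = "ctrl"
print(group.value_counts().to_dict())
```

Passing `regex=False` keeps `+` from being interpreted as a regular expression quantifier, which is why the original mask escapes it as `\+`.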

In [31]:
def get_common_genes(adata = None, dataset_name: str = "norman"):
    """Keep samples whose perturbed genes are all found in the features; return all perts, kept perts, genes and the filtered adata."""
    if "norman" not in dataset_name.lower():
        # Fail early: the values returned below would otherwise be undefined
        raise NotImplementedError("Dataset not implemented yet.")

    all_perts = adata.obs['condition_fixed'].values
    genes = set(adata.var['gene_symbols'].values)

    # Makes function work even with double perts
    def valid_pert(pert):
        pair_genes = pert.split("+")
        return all(gene in genes for gene in pair_genes)

    valid_perts = {p for p in all_perts if valid_pert(p)} # set: each label is kept once
    adata_common = adata[adata.obs["condition_fixed"].isin(valid_perts)].copy() # Only keep perts that are in features
    perts = adata_common.obs["condition_fixed"].unique().tolist()
    return all_perts, perts, list(genes), adata_common

#### Samples with gene perturbations that are found in features

lgem and scGPT require that the perturbed genes of each sample are found in the list of features (adata.var['gene_symbols']). adata_common can also be saved on its own with the GEARS data handler.
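The rule can be illustrated on a handful of labels. The following is a minimal sketch with a made-up gene set, mirroring the check inside `get_common_genes`; the name `is_valid` is illustrative only.

```python
genes = {"AHR", "FEV", "DLX2"}  # stand-in for adata.var['gene_symbols']
conditions = ["AHR", "AHR+FEV", "IER5L", "DLX2+IER5L", "ctrl"]


def is_valid(condition: str, genes: set) -> bool:
    """A condition is valid only if every perturbed gene is among the features."""
    return all(gene in genes for gene in condition.split("+"))


valid = [c for c in conditions if is_valid(c, genes)]
print(valid)
```

Under this rule a double perturbation is dropped as soon as one of its genes is missing from the features. Control samples are dropped as well unless `'ctrl'` happens to be a feature name, so they need to be kept separately if a tool requires them.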

In [55]:
all_perts, perts, genes, adata_common = get_common_genes(pert_data.adata, 'norman')

In [56]:
adata_common 

AnnData object with n_obs × n_vars = 6250 × 5000
    obs: 'guide_identity', 'read_count', 'UMI_count', 'gemgroup', 'good_coverage', 'number_of_cells', 'guide_ids', 'guide_merged', 'split', 'batch', 'condition', 'cell_type', 'dose_val', 'control', 'drug_dose_name', 'cov_drug_dose_name', 'condition_name', 'condition_fixed'
    var: 'gene_symbols', 'highly_variable', 'means', 'dispersions', 'dispersions_norm', 'gene_name'
    uns: 'non_dropout_gene_idx', 'non_zeros_gene_idx', 'rank_genes_groups', 'rank_genes_groups_cov', 'rank_genes_groups_cov_all', 'top_non_dropout_de_20', 'top_non_zero_de_20'
    layers: 'counts'