In [1]:
import time
from pathlib import Path

import numpy as np
import pandas as pd
import scanpy as sc
from anndata import AnnData
from gears import PertData

from bmfm_targets.datasets.perturbation import GearsDataModule
from bmfm_targets.tokenization import get_all_genes_tokenizer

  from .autonotebook import tqdm as notebook_tqdm


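# Perturbation splits exploration

Compare the GEARS and scPerturb versions of each dataset, inspect the GEARS train/val/test perturbation splits, and save the GEARS data with an `scgpt_split` column.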

In [6]:
def get_gears_data(data_name, split, data_path="./data", seed=1):
    """Load a GEARS perturbation dataset and prepare its split.

    Also creates a timestamped ``./save/dev_perturb_<data_name>-<timestamp>/``
    directory and prints its location; the directory is not returned or used
    further by this function.

    Args:
        data_name: dataset name passed to ``PertData.load``.
        split: split strategy passed to ``PertData.prepare_split``.
        data_path: directory handed to the ``PertData`` constructor
            (defaults to ``"./data"`` as before).
        seed: seed passed to ``prepare_split`` (defaults to 1 as before).

    Returns:
        The ``PertData`` object with the requested split prepared.
    """
    save_dir = Path(f"./save/dev_perturb_{data_name}-{time.strftime('%b%d-%H-%M')}/")
    save_dir.mkdir(parents=True, exist_ok=True)
    print(f"saving to {save_dir}")

    pert_data = PertData(data_path)
    pert_data.load(data_name=data_name)
    pert_data.prepare_split(split=split, seed=seed)
    return pert_data

def get_scperturb_data(
    data_name,
    path='/dccstor/bmfm-targets/data/omics/transcriptome/scRNA/finetune/Perturbation/scPerturb/',
):
    """Load the scPerturb AnnData that corresponds to a GEARS dataset name.

    Args:
        data_name: one of ``"norman"``, ``"adamson"`` or ``"replogle"``.
        path: directory (with trailing separator) containing the scPerturb
            ``.h5ad`` files; file names are appended to it by string concatenation.

    Returns:
        The AnnData read from the dataset's file. When a dataset has several
        files (Adamson), they are combined with ``AnnData.concatenate``.

    Raises:
        ValueError: if ``data_name`` is not a supported dataset name.
    """
    file_names_by_dataset = {
        "norman": ['NormanWeissman2019_filtered.h5ad'],
        "adamson": [
            'AdamsonWeissman2016_GSM2406675_10X001.h5ad',
            'AdamsonWeissman2016_GSM2406677_10X005.h5ad',
            'AdamsonWeissman2016_GSM2406681_10X010.h5ad',
        ],
        "replogle": ['ReplogleWeissman2022_K562_essential.h5ad'],
    }
    if data_name not in file_names_by_dataset:
        # Previously this only printed and then crashed with UnboundLocalError below.
        raise ValueError(
            f"unsupported data name: {data_name!r}; "
            f"expected one of {sorted(file_names_by_dataset)}"
        )

    source_h5ad_file = [path + file_name for file_name in file_names_by_dataset[data_name]]
    if len(source_h5ad_file) == 1:
        return sc.read_h5ad(source_h5ad_file[0])

    adatas = [sc.read_h5ad(file) for file in source_h5ad_file]
    # concatenate is an instance method: the first AnnData acts as `self`.
    return AnnData.concatenate(*adatas)

def compare_sizes(gears_adata, scperturb_adata):
    """Print the ``.shape`` of the GEARS object and then of the scPerturb object."""
    for label, adata in (('gears', gears_adata), ('scperturb', scperturb_adata)):
        print(f'shape of {label} data: {adata.shape}')

def print_num_cell_types(gears_adata, scperturb_adata):
    """Print value counts of the cell-type / cell-line columns of both datasets.

    GEARS uses ``obs['cell_type']``; scPerturb uses ``obs['celltype']`` and
    ``obs['cell_line']``. Each table is preceded by a header line.
    """
    report = (
        ('GEARS data - number of cell types:', gears_adata, 'cell_type'),
        ('Scperturb data - number of cell types:', scperturb_adata, 'celltype'),
        ('Scperturb data - number of cell lines:', scperturb_adata, 'cell_line'),
    )
    for header, adata, column in report:
        print(header)
        print(adata.obs[column].value_counts())

def print_unique_perturbations(gears_adata, scperturb_adata):
    """Print the distinct perturbation labels of each dataset.

    GEARS labels come from ``obs['condition']``; scPerturb labels come from
    ``obs['perturbation']``.
    """
    for header, adata, column in (
        ('GEARS data - perturbations:', gears_adata, 'condition'),
        ('Scperturb_adata data - perturbations:', scperturb_adata, 'perturbation'),
    ):
        print(header)
        print(adata.obs[column].unique())
    
    
def _to_scperturb_condition_names(conditions):
    """Rewrite GEARS condition names into scPerturb-style names (as a Series).

    ``"A+ctrl"``/``"ctrl+A"`` -> ``"A"``, ``"A+B"`` -> ``"A_B"``, ``"ctrl"`` -> ``"Control"``.
    Replacements are literal (``regex=False``): ``"+"`` is a regex metacharacter,
    and pandas < 2.0 defaulted ``Series.str.replace`` to regex mode.
    """
    return (
        pd.Series(list(conditions), dtype=object)
        .str.replace("+ctrl", "", regex=False)
        .str.replace("ctrl+", "", regex=False)
        .str.replace("+", "_", regex=False)
        .str.replace("ctrl", "Control", regex=False)
    )


def _print_overlap_sizes(test_perts, train_perts, val_perts):
    """Print |test & train|, |val & train| and |test & val|, one per line."""
    test_set, train_set, val_set = set(test_perts), set(train_perts), set(val_perts)
    print(len(test_set & train_set))
    print(len(val_set & train_set))
    print(len(test_set & val_set))


def get_perturbations_split_gears(gears_obj):
    """Report the GEARS train/val/test perturbation split and check it is disjoint.

    Prints the number of conditions in each split and the pairwise overlap sizes.
    The overlaps are computed twice: once on the raw GEARS names and once after
    renaming them to scPerturb style. The second check catches pairs such as
    ``"A+ctrl"``/``"ctrl+A"`` that differ only in ordering around ``ctrl``.

    Args:
        gears_obj: object exposing ``set2conditions`` with ``"train"``, ``"val"``
            and ``"test"`` entries (lists of condition names).

    Returns:
        ``(test_perts, train_perts, val_perts)`` exactly as stored in
        ``gears_obj.set2conditions``.
    """
    test_perts = gears_obj.set2conditions["test"]
    train_perts = gears_obj.set2conditions["train"]
    val_perts = gears_obj.set2conditions["val"]
    print('num train perturbations:', len(train_perts))
    print('num val perturbations:', len(val_perts))
    print('num test perturbations:', len(test_perts))
    print('check that test,train,val groups are distinct in their belonging perturbations in GEARS data, sizes of intersections are:')
    _print_overlap_sizes(test_perts, train_perts, val_perts)

    print('After renaming the perturbations, the intersection is:')
    _print_overlap_sizes(
        _to_scperturb_condition_names(test_perts),
        _to_scperturb_condition_names(train_perts),
        _to_scperturb_condition_names(val_perts),
    )
    return test_perts, train_perts, val_perts

In [3]:
def explore_dataset(data_name, split):
    """Print a side-by-side summary of the GEARS and scPerturb versions of a dataset.

    Loads both versions, then prints their shapes, cell-type / cell-line counts,
    their unique perturbation labels, and the GEARS split sizes and overlaps.

    Args:
        data_name: dataset name passed to both ``get_gears_data`` and
            ``get_scperturb_data``.
        split: GEARS split strategy passed to ``get_gears_data``.

    Returns:
        ``(gears_obj, scperturb_adata)``: the GEARS ``PertData`` object and the
        scPerturb AnnData.
    """
    gears_obj = get_gears_data(data_name, split)
    gears_adata = gears_obj.adata
    scperturb_adata = get_scperturb_data(data_name)
    compare_sizes(gears_adata, scperturb_adata)
    print_num_cell_types(gears_adata, scperturb_adata)
    print_unique_perturbations(gears_adata, scperturb_adata)
    # Called for its printed split report; the returned split lists are not needed here.
    get_perturbations_split_gears(gears_obj)
    return gears_obj, scperturb_adata

In [4]:
def create_scgpt_split_columm_and_save_h5ad(
    data_name,
    gears_obj,
    output_dir='/dccstor/bmfm-targets/data/omics/transcriptome/scRNA/finetune/Perturbation/GEARS/',
):
    """Add an ``scgpt_split`` column to ``gears_obj.adata.obs`` and write the AnnData to disk.

    Each cell is labelled ``'train'``, ``'dev'`` or ``'test'`` according to which
    GEARS split (``set2conditions``) its ``condition`` belongs to; if a condition
    appears in several splits the first match in that order wins. Cells whose
    condition is in none of the splits are left missing (``None``).

    Condition names are written unchanged; no ``ctrl`` -> ``Control`` renaming
    is done here.

    Args:
        data_name: used to name the output file ``<data_name>_scgpt_split.h5ad``.
        gears_obj: GEARS ``PertData``-like object with ``adata`` and
            ``set2conditions``; its ``adata.obs`` is modified in place.
        output_dir: directory the h5ad file is written to (defaults to the
            previously hard-coded location).

    Returns:
        ``pathlib.Path`` of the written h5ad file.
    """
    test_perts, train_perts, val_perts = get_perturbations_split_gears(gears_obj)
    condition = gears_obj.adata.obs['condition']
    conditions = [
        condition.isin(train_perts),
        condition.isin(val_perts),
        condition.isin(test_perts),
    ]
    values = ['train', 'dev', 'test']

    # default=None rather than np.nan: NumPy promotes a float NaN default together
    # with string choices to a string dtype, so unmatched cells would otherwise get
    # the literal label 'nan' (a bogus category in the saved file).
    gears_obj.adata.obs['scgpt_split'] = np.select(conditions, values, default=None)

    output_path = Path(output_dir) / f'{data_name}_scgpt_split.h5ad'
    gears_obj.adata.write_h5ad(output_path)
    return output_path

## Norman dataset

In [7]:
# Explore GEARS vs scPerturb for Norman, then save the GEARS AnnData with an `scgpt_split` column.
data_name, split = "norman", "simulation"
gears_obj, scperturb_adata = explore_dataset(data_name, split)
create_scgpt_split_columm_and_save_h5ad(data_name, gears_obj)

Found local copy...


saving to save/dev_perturb_norman-Nov13-02-01


Found local copy...
Found local copy...
These perturbations are not in the GO graph and their perturbation can thus not be predicted
['RHOXF2BB+ctrl' 'LYL1+IER5L' 'ctrl+IER5L' 'KIAA1804+ctrl' 'IER5L+ctrl'
 'RHOXF2BB+ZBTB25' 'RHOXF2BB+SET']
Local copy of pyg dataset is detected. Loading...
Done!
Local copy of split is detected. Loading...
Simulation split test composition:
combo_seen0:9
combo_seen1:43
combo_seen2:19
unseen_single:36
Done!


here1
shape of gears data: (89357, 5045)
shape of scperturb data: (111445, 33694)
GEARS data - number of cell types:
cell_type
A549    89357
Name: count, dtype: int64
Scperturb data - number of cell types:
celltype
lymphoblasts    111445
Name: count, dtype: int64
Scperturb data - number of cell lines:
cell_line
K562    111445
Name: count, dtype: int64
GEARS data - perturbations:
['TSC22D1+ctrl', 'KLF1+MAP2K6', 'ctrl', 'CEBPE+RUNX1T1', 'MAML2+ctrl', ..., 'STIL+ctrl', 'CDKN1C+ctrl', 'ctrl+CDKN1B', 'CDKN1B+CDKN1A', 'C3orf72+FOXL2']
Length: 277
Categories (277, object): ['AHR+FEV', 'AHR+KLF1', 'AHR+ctrl', 'ARID1A+ctrl', ..., 'ZC3HAV1+HOXC13', 'ZC3HAV1+ctrl', 'ZNF318+FOXL2', 'ZNF318+ctrl']
Scperturb_adata data - perturbations:
['ARID1A', 'BCORL1', 'FOSB', 'SET_KLF1', 'OSR2', ..., 'CEBPB_OSR2', 'PRDM1_CBFA2T3', 'FOSB_CEBPB', 'ZBTB10_DLX2', 'FEV_CBFA2T3']
Length: 237
Categories (237, object): ['AHR', 'AHR_FEV', 'AHR_KLF1', 'ARID1A', ..., 'ZC3HAV1_HOXC13', 'ZNF318', 'ZNF318_FOXL2', 'control']


... storing 'scgpt_split' as categorical


## Adamson dataset

In [8]:
# Same exploration and save for Adamson (single-gene perturbations only in this split).
data_name, split = "adamson", "simulation"
gears_obj, scperturb_adata = explore_dataset(data_name, split)
create_scgpt_split_columm_and_save_h5ad(data_name, gears_obj)

Found local copy...


saving to save/dev_perturb_adamson-Nov13-02-23


Found local copy...
Found local copy...
These perturbations are not in the GO graph and their perturbation can thus not be predicted
['SRPR+ctrl' 'SLMO2+ctrl' 'TIMM23+ctrl' 'AMIGO3+ctrl' 'KCTD16+ctrl']
Local copy of pyg dataset is detected. Loading...
Done!
Local copy of split is detected. Loading...
Simulation split test composition:
combo_seen0:0
combo_seen1:0
combo_seen2:0
unseen_single:21
Done!


here1
shape of gears data: (65899, 5060)
shape of scperturb data: (86111, 32738)
GEARS data - number of cell types:
cell_type
K562(?)    65899
Name: count, dtype: int64
Scperturb data - number of cell types:
celltype
lymphoblasts    86111
Name: count, dtype: int64
Scperturb data - number of cell lines:
cell_line
K562    86111
Name: count, dtype: int64
GEARS data - perturbations:
['CREB1+ctrl', 'ctrl', 'ZNF326+ctrl', 'BHLHE40+ctrl', 'DDIT3+ctrl', ..., 'CARS+ctrl', 'TMED2+ctrl', 'P4HB+ctrl', 'SPCS3+ctrl', 'SPCS2+ctrl']
Length: 82
Categories (82, object): ['AARS+ctrl', 'ARHGAP22+ctrl', 'ASCC3+ctrl', 'ATP5B+ctrl', ..., 'XRN1+ctrl', 'YIPF5+ctrl', 'ZNF326+ctrl', 'ctrl']
Scperturb_adata data - perturbations:
['CREB1_pDS269', 'SNAI1_pDS266', '62(mod)_pBA581', 'EP300_pDS268', 'ZNF326_pDS262', ..., 'ERN1_pBA575', 'CCND3_pDS005', 'ATF6_pBA586', 'EIF2AK3_pBA573', 'STT3A_pDS010']
Length: 131
Categories (130, object): ['3x_neg_ctrl_pMJ144-1', '3x_neg_ctrl_pMJ144-2', '62(mod)_pBA581', '63(mod)_pBA580

... storing 'scgpt_split' as categorical


In [2]:
# Data module over the Norman h5ad saved above, split by its `scgpt_split` column.
dataset_kwargs = dict(split_column_name='scgpt_split')
transform_kwargs = dict(
    source_h5ad_file_names=['/dccstor/bmfm-targets/data/omics/transcriptome/scRNA/finetune/Perturbation/GEARS/norman_scgpt_split.h5ad'],
    transforms=[],
    stratifying_label='condition',
)

dm = GearsDataModule(
    tokenizer=get_all_genes_tokenizer(),
    fields=[],
    data_dir='/dccstor/bmfm-targets/data/omics/transcriptome/scRNA/finetune/Perturbation/GEARS/',
    processed_name='norman_scgpt_split_processed',
    collation_strategy="sequence_labeling",
    transform_datasets=True,
    dataset_kwargs=dataset_kwargs,
    transform_kwargs=transform_kwargs,
    perturbation_column_name='condition',
)

In [3]:
# Presumably transforms the source h5ad and writes the processed dataset under data_dir — TODO confirm.
dm.prepare_data()

... storing 'condition' as categorical
... storing 'split_stratified_condition' as categorical


In [4]:
# Run the data module's 'fit'-stage setup, then keep its training dataset for inspection.
dm.setup('fit')
dataset = dm.train_dataset

In [6]:
from bmfm_targets.datasets import GearsPerturbationDatasetTransformer

# Re-load the saved GEARS h5ad and check how the transformer's cleaning step renames conditions.
data_name = 'norman'
gears_h5ad_path = f'/dccstor/bmfm-targets/data/omics/transcriptome/scRNA/finetune/Perturbation/GEARS/{data_name}_scgpt_split.h5ad'
pert_data = sc.read_h5ad(gears_h5ad_path)
data_transformer_gears = GearsPerturbationDatasetTransformer(
    tokenizer=get_all_genes_tokenizer(),
    source_h5ad_file_names=[gears_h5ad_path],
    stratification_type='simulation',
    stratifying_label='condition',
)

cleaned_gears = data_transformer_gears._clean_dataset(pert_data)
# Only a cell's last expression is displayed, so this count must be reported explicitly:
# number of cells whose cleaned condition still contains a GEARS-style '+'.
n_plus_conditions = cleaned_gears.obs['condition'].str.contains(r'\+').sum()
print(f"cells whose condition still contains '+': {n_plus_conditions}")
cleaned_gears.obs['condition']

cell_barcode
AAACCTGAGGCATGTG-1          TSC22D1
AAACCTGAGGCCCTTG-1      KLF1_MAP2K6
AAACCTGCACGAAGCA-1          Control
AAACCTGCAGACGTAG-1    CEBPE_RUNX1T1
AAACCTGCAGCCTTGG-1            MAML2
                          ...      
TTTGTCAGTCAGAATA-8          Control
TTTGTCATCAGTACGT-8            FOXA3
TTTGTCATCCACTCCA-8            CELF2
TTTGTCATCCCAACGG-8           BCORL1
TTTGTCATCTGGCGAC-8           MAP4K3
Name: condition, Length: 89091, dtype: object

In [7]:
# Number of distinct condition labels after cleaning (pandas `nunique` excludes NaN by default);
# compare with the raw GEARS condition count printed in the exploration above.
cleaned_gears.obs['condition'].nunique()

229

In [None]:
#gears_perts = pert_data.adata.obs['condition'].unique()

#data_transformer_gears = PerturbationDatasetTransformer(tokenizer=get_all_genes_tokenizer(), source_h5ad_file_names=[f'/dccstor/bmfm-targets/data/omics/transcriptome/scRNA/finetune/Perturbation/GEARS/{data_name}_scgpt_split.h5ad'], stratification_type='simulation', stratifying_label='condition')

#cleaned_gears = data_transformer_gears._clean_dataset(pert_data.adata)
#sum(cleaned_gears.obs['condition'].str.contains(r'\+'))
#cleaned_gears.obs['condition']

#data_transformer = PerturbationDatasetTransformer(tokenizer=get_all_genes_tokenizer(), source_h5ad_file_names=source_h5ad_file, stratification_type='simulation')
#if len(source_h5ad_file) == 1:
#    clean_adata = data_transformer._clean_dataset(adata)
#else:
#    clean_adatas = []
#    for file in source_h5ad_file:
#        clean_adatas.append(data_transformer._clean_dataset(sc.read_h5ad(file)))
#    clean_adata = AnnData.concatenate(*clean_adatas)
#clean_adata

#clean_adata.obs['perturbation'].unique()
