# Preprocessing the _Adamson et al._[[1]](https://doi.org/10.1016/j.cell.2016.11.048) dataset

We use `GEARS`[[2]](https://doi.org/10.1101/2022.07.12.499735) to: 
1. Compute the `no_perturb` normalization baseline for the Perturb-seq dataset of _Adamson et al._ [[1]](https://doi.org/10.1016/j.cell.2016.11.048).
2. Extract the `adata` objects for `biolord`.
3. Train `GEARS` as a baseline for comparison with `biolord`.
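The `no_perturb` baseline predicts the unperturbed (control) expression for every perturbation. A model's error on the differentially expressed (DE) genes is then reported relative to this baseline. The NumPy sketch below illustrates the idea on synthetic data. It assumes the metric is the mean squared error between the mean predicted and the mean observed profile over a perturbation's top DE genes. That is our reading of GEARS's `mse_de`, not its exact implementation. `mse_de` and `normalized_mse_de` are hypothetical helper names.

```python
import numpy as np


def mse_de(pred_mean: np.ndarray, true_mean: np.ndarray, de_idx: np.ndarray) -> float:
    """MSE between mean predicted and mean observed expression, restricted to the DE genes."""
    diff = pred_mean[de_idx] - true_mean[de_idx]
    return float(np.mean(diff ** 2))


def normalized_mse_de(model_pred: np.ndarray, baseline_pred: np.ndarray,
                      true_mean: np.ndarray, de_idx: np.ndarray) -> float:
    """Model DE-MSE divided by the baseline DE-MSE; values below 1 beat the baseline."""
    return mse_de(model_pred, true_mean, de_idx) / mse_de(baseline_pred, true_mean, de_idx)


# Synthetic example: 6 genes, the first 3 are the perturbation's DE genes.
ctrl_mean = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])   # no_perturb: predict control expression
true_mean = np.array([2.0, 0.0, 4.0, 4.0, 5.0, 6.0])   # observed post-perturbation mean
model_pred = np.array([1.5, 1.0, 3.5, 4.0, 5.0, 6.0])  # hypothetical model prediction
de_idx = np.array([0, 1, 2])

print(mse_de(ctrl_mean, true_mean, de_idx))                         # 2.0
print(normalized_mse_de(model_pred, ctrl_mean, true_mean, de_idx))  # 0.25
```

Genes outside `de_idx` do not affect the score. A baseline that is already correct on non-DE genes therefore gains nothing from them.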


[[1] Adamson, B., Norman, T. M., Jost, M., Cho, M. Y., Nuñez, J. K., Chen, Y., ... & Weissman, J. S. (2016). A multiplexed single-cell CRISPR screening platform enables systematic dissection of the unfolded protein response. Cell, 167(7), 1867-1882.](https://doi.org/10.1016/j.cell.2016.11.048)

[[2] Roohani, Y., Huang, K., & Leskovec, J. (2022). GEARS: Predicting transcriptional outcomes of novel multi-gene perturbations. bioRxiv, 2022-07.](https://doi.org/10.1101/2022.07.12.499735)


## Load packages

In [1]:
import sys
import pandas as pd
import pickle
import numpy as np
import anndata

sys.path.insert(0, "/cs/labs/mornitzan/zoe.piran/research/projects/GEARS")  # local GEARS checkout; adjust to your environment
from gears import PertData, GEARS
import torch

In [2]:
sys.path.append("../../../")
sys.path.append("../../utils/")
from paths import DATA_DIR, FIG_DIR

## Set parameters

In [3]:
DATA_DIR_LCL = str(DATA_DIR) + "/perturbations/adamson/"
FIG_DIR_LCL = str(FIG_DIR) + "/perturbations/adamson"

In [4]:
device = torch.cuda.current_device()
batch_size = 32

## Create `no_perturb` baseline

In [5]:
epoch = 0  # no training
no_perturb = True  # the model returns the unperturbed (control) input expression

In [6]:
subgroup_analysis = {}
for seed in range(1,6):
    pert_data = PertData(DATA_DIR_LCL[:-1]) # specific saved folder
    pert_data.load(data_name = "adamson")
    pert_data.prepare_split(split = "simulation", seed = seed)
    pert_data.get_dataloader(batch_size = batch_size, test_batch_size = batch_size)
    
    gears_model = GEARS(
        pert_data, 
        device = "cuda:" + str(device), 
        weight_bias_track = False, 
        proj_name = "adamson",
        exp_name = "no_perturb_seed" + str(seed)
    )
    gears_model.model_initialize(hidden_size = 64, no_perturb = no_perturb)
    subgroup_analysis[seed], subgroup_analysis_deeper, test_metrics = gears_model.train(epochs = epoch)

Found local copy...
Found local copy...
Found local copy...
These perturbations are not in the GO graph and is thus not able to make prediction for...
['SRPR+ctrl' 'SLMO2+ctrl' 'TIMM23+ctrl' 'AMIGO3+ctrl' 'KCTD16+ctrl']
Local copy of pyg dataset is detected. Loading...
Done!
Local copy of split is detected. Loading...
Simulation split test composition:
combo_seen0:0
combo_seen1:0
combo_seen2:0
unseen_single:21
Done!
Creating dataloaders....
Done!
Found local copy...
Start Training...
Done!
Start Testing...
Best performing model: Test Top 20 DE MSE: 0.3849
Start doing subgroup analysis for simulation split...
test_combo_seen0_mse: nan
test_combo_seen0_pearson: nan
test_combo_seen0_mse_de: nan
test_combo_seen0_pearson_de: nan
test_combo_seen1_mse: nan
test_combo_seen1_pearson: nan
test_combo_seen1_mse_de: nan
test_combo_seen1_pearson_de: nan
test_combo_seen2_mse: nan
test_combo_seen2_pearson: nan
test_combo_seen2_mse_de: nan
test_combo_seen2_pearson_de: nan
test_unseen_single_mse: 0.0100

In [7]:
res = {}
for key in subgroup_analysis:
    res[f"mse_de_seed{key}"] = pd.DataFrame(subgroup_analysis[key]).T["mse_de"]

In [8]:
with open(DATA_DIR_LCL + "no_perturb_subgroup_analysis.pkl", "wb") as f:
    pickle.dump(subgroup_analysis, f)

In [9]:
pd.DataFrame(res).to_csv(DATA_DIR_LCL + "no_perturb_mse_de_seeds.csv")

In [11]:
pd.DataFrame(res)

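| subgroup | mse_de_seed1 | mse_de_seed2 | mse_de_seed3 | mse_de_seed4 | mse_de_seed5 |
|---|---|---|---|---|---|
| combo_seen0 | NaN | NaN | NaN | NaN | NaN |
| combo_seen1 | NaN | NaN | NaN | NaN | NaN |
| combo_seen2 | NaN | NaN | NaN | NaN | NaN |
| unseen_single | 0.384905 | 0.4207 | 0.339157 | 0.28872 | 0.317135 |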
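The table is saved with its subgroup index as the first CSV column. Reading it back without `index_col=0` turns the subgroup names into an anonymous `Unnamed: 0` column.

The minimal sketch below reloads the table and summarizes the baseline across the five seeds. It uses an in-memory copy of the values shown in the table instead of the file on disk. Only `unseen_single` is populated, because the simulation split of this dataset contains no combinatorial perturbations (`combo_seen0/1/2` are all 0 in the split log).

```python
import io

import pandas as pd

# In-memory stand-in for no_perturb_mse_de_seeds.csv (values copied from the table above).
csv_text = """,mse_de_seed1,mse_de_seed2,mse_de_seed3,mse_de_seed4,mse_de_seed5
combo_seen0,,,,,
combo_seen1,,,,,
combo_seen2,,,,,
unseen_single,0.384905,0.4207,0.339157,0.28872,0.317135
"""

# index_col=0 turns the first CSV column back into the subgroup index.
baseline = pd.read_csv(io.StringIO(csv_text), index_col=0)

# Drop subgroups without test perturbations (all-NaN rows), then summarize across seeds.
baseline = baseline.dropna(how="all")
summary = pd.DataFrame({"mean": baseline.mean(axis=1), "std": baseline.std(axis=1)})
print(summary)
```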


## Prepare the `adata` objects for `biolord`

In [12]:
adata = pert_data.adata.copy()

In [13]:
# Rename the GEARS split labels for biolord: GEARS "test" perturbations are held out
# as out-of-distribution ("ood"), and the GEARS validation set becomes the test set.
rename = {"train": "train", "test": "ood", "val": "test"}

for seed in range(1, 6):
    split_prefix = DATA_DIR_LCL + f"adamson/splits/adamson_simulation_{seed}_0.75"
    with open(split_prefix + ".pkl", "rb") as f:
        split_data = pickle.load(f)
    with open(split_prefix + "_subgroup.pkl", "rb") as f:
        subgroup = pickle.load(f)

    # condition -> GEARS split, renamed for biolord
    pert2set = {x: i for i, j in split_data.items() for x in j}
    adata.obs[f"split{seed}"] = [rename[pert2set[c]] for c in adata.obs["condition"].values]

    # condition -> test subgroup; training and validation conditions are labelled "Train/Val"
    pert2subgroup = {x: i for i, j in subgroup["test_subgroup"].items() for x in j}
    adata.obs[f"subgroup{seed}"] = adata.obs["condition"].apply(
        lambda x: pert2subgroup.get(x, "Train/Val")
    )
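The label mapping is easiest to check on a toy example. In this sketch, the perturbation names and split contents are made up. Only two things come from the cell above: the renaming rule (`test` becomes `ood`, `val` becomes `test`) and the `Train/Val` fallback for conditions outside the test subgroups.

```python
import pandas as pd

# GEARS split labels -> labels used for biolord (same rule as the cell above).
RENAME = {"train": "train", "test": "ood", "val": "test"}

# Hypothetical GEARS split and subgroup files, already loaded.
split_data = {"train": ["A+ctrl", "ctrl"], "val": ["B+ctrl"], "test": ["C+ctrl"]}
test_subgroup = {"unseen_single": ["C+ctrl"]}

# One row per cell; several cells can share a condition.
obs = pd.DataFrame({"condition": ["ctrl", "A+ctrl", "B+ctrl", "C+ctrl", "C+ctrl"]})

pert2set = {pert: split for split, perts in split_data.items() for pert in perts}
pert2subgroup = {pert: sg for sg, perts in test_subgroup.items() for pert in perts}

obs["split1"] = obs["condition"].map(lambda c: RENAME[pert2set[c]])
obs["subgroup1"] = obs["condition"].map(lambda c: pert2subgroup.get(c, "Train/Val"))
print(obs)
```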


In [14]:
# perturbed gene of each cell: "X+ctrl" -> "X", "ctrl" -> "ctrl"
adata.obs["perturbation"] = [cond.split("+")[0] for cond in adata.obs["condition"]]
adata.obs["perturbation"] = adata.obs["perturbation"].astype("category")

In [15]:
go_path = DATA_DIR_LCL + 'adamson/go.csv'
gene_path = DATA_DIR_LCL + 'essential_all_data_pert_genes.pkl'

# GO-similarity graph: keep the 20 + 1 highest-importance edges of every target gene
df = pd.read_csv(go_path)
df = df.groupby('target').apply(lambda x: x.nlargest(20 + 1, ['importance'])).reset_index(drop = True)
with open(gene_path, 'rb') as f:
    gene_list = pickle.load(f)

# restrict the neighbors (sources) to genes in the perturbation gene list
df = df[df["source"].isin(gene_list)]
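The neighbor selection above keeps, for every target gene, the 20 + 1 GO-similarity edges with the highest `importance`. The extra slot presumably leaves room for the gene's own self-edge. That is an assumption, not something this notebook states.

The sketch below uses toy edges and K = 2. It also uses a swapped-in formulation: a stable descending sort followed by `groupby(...).head(K + 1)`, instead of `groupby(...).apply(nlargest)`. The two select the same rows up to tie-breaking.

```python
import pandas as pd

K = 2  # illustration only; the cell above keeps K = 20 neighbors

# Hypothetical GO-similarity edges.
go = pd.DataFrame({
    "source": ["A", "B", "C", "D", "A", "B", "C", "D"],
    "target": ["A", "A", "A", "A", "B", "B", "B", "B"],
    "importance": [1.0, 0.9, 0.5, 0.1, 0.7, 1.0, 0.8, 0.2],
})

# Keep the K + 1 highest-importance edges of every target gene.
top = (
    go.sort_values("importance", ascending=False, kind="stable")
    .groupby("target")
    .head(K + 1)
    .sort_values(["target", "importance"], ascending=[True, False])
    .reset_index(drop=True)
)
print(top)
```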

In [16]:
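def get_map(pert):
    """GO-neighbor vector of `pert` over `gene_list`: the edge importance at each
    neighbor gene, zero elsewhere."""
    tmp = pd.DataFrame(np.zeros(len(gene_list)), index=gene_list)
    edges = df[df.target == pert]
    tmp.loc[edges.source.values, :] = edges.importance.values[:, np.newaxis]
    return tmp.values.flatten()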

In [17]:
pert2neighbor = {i: get_map(i) for i in adata.obs["perturbation"].cat.categories}
adata.uns["pert2neighbor"] = pert2neighbor

In [18]:
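# Stack the per-perturbation neighbor vectors and keep only genes that are
# a GO neighbor of at least one perturbation.
neighbor_matrix = np.stack(list(adata.uns["pert2neighbor"].values()))
keep_idx = neighbor_matrix.sum(0) > 0

name_map = dict(adata.obs[["condition", "condition_name"]].drop_duplicates().values)
# mean expression of the control cells
ctrl = np.asarray(adata[adata.obs["condition"].isin(["ctrl"])].X.mean(0)).flatten()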
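The two cells above build one GO-neighbor vector per perturbation (`get_map`). They then keep only genes that neighbor at least one perturbation (`keep_idx`).

The sketch below is a swapped-in, vectorized alternative to `get_map` based on `Series.reindex`, run on hypothetical genes and edges. It assumes each target has at most one edge per source, because `reindex` rejects a duplicated index. A perturbation with no edges, such as `ctrl`, gets an all-zero vector.

```python
import numpy as np
import pandas as pd

# Hypothetical gene vocabulary and (already filtered) GO edges.
gene_list = ["A", "B", "C", "D"]
edges = pd.DataFrame({
    "source": ["A", "B", "C", "B", "C"],
    "target": ["A", "A", "A", "B", "B"],
    "importance": [1.0, 0.9, 0.5, 1.0, 0.8],
})


def neighbor_vector(pert: str) -> np.ndarray:
    """Importance of each gene in gene_list as a GO neighbor of `pert`, 0 elsewhere."""
    weights = edges.loc[edges["target"] == pert].set_index("source")["importance"]
    return weights.reindex(gene_list, fill_value=0.0).to_numpy()


pert2neighbor = {p: neighbor_vector(p) for p in ["A", "B", "ctrl"]}

# Genes that are a neighbor of at least one perturbation.
neighbor_matrix = np.stack(list(pert2neighbor.values()))
keep_idx = neighbor_matrix.sum(0) > 0
print(keep_idx)
```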

In [19]:
# Pseudobulk: mean expression of every condition across its cells
# (to_df densifies X; the sparse-matrix `.A` shortcut is removed in recent SciPy).
df_perts_expression = adata.to_df()
df_perts_expression["condition"] = adata.obs["condition"]
df_perts_expression = df_perts_expression.groupby("condition").mean()

# Collect single perturbations ("X+ctrl" or "ctrl+X") and name them by the perturbed gene;
# the control condition is appended last. Double perturbations, if any, are skipped.
single_perts_condition = []
single_pert_val = []
for pert in adata.obs["condition"].cat.categories:
    if "+" not in pert or "ctrl" not in pert:
        continue
    single_perts_condition.append(pert)
    p1, p2 = pert.split("+")
    single_pert_val.append(p1 if p2 == "ctrl" else p2)
single_perts_condition.append("ctrl")
single_pert_val.append("ctrl")

df_singleperts_expression = pd.DataFrame(
    df_perts_expression.loc[single_perts_condition].values, index=single_pert_val
)
df_singleperts_emb = np.asarray(
    [adata.uns["pert2neighbor"][p][keep_idx] for p in df_singleperts_expression.index]
)

df_singleperts_condition = pd.Index(single_perts_condition)
df_single_pert_val = pd.Index(single_pert_val)
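The cell above collapses the cells of each condition into one pseudobulk profile and names every single perturbation by its perturbed gene. The sketch below reproduces those steps on a hypothetical two-gene matrix. `perturbed_gene` is an illustrative helper, not part of the original code.

```python
import numpy as np
import pandas as pd

# Hypothetical cell-by-gene matrix with one condition label per cell.
cells = pd.DataFrame(
    np.array([[1.0, 0.0], [3.0, 2.0], [5.0, 4.0], [7.0, 6.0]]),
    columns=["gene1", "gene2"],
)
cells["condition"] = ["ctrl", "ctrl", "X+ctrl", "X+ctrl"]

# Pseudobulk: average the cells of each condition.
means = cells.groupby("condition").mean()


def perturbed_gene(condition: str) -> str:
    """Return 'X' for 'X+ctrl' or 'ctrl+X', and 'ctrl' for the control condition."""
    genes = [g for g in condition.split("+") if g != "ctrl"]
    return genes[0] if genes else "ctrl"


# Single perturbations first, control appended last.
conditions = [c for c in means.index if "+" in c and "ctrl" in c] + ["ctrl"]
single = pd.DataFrame(
    means.loc[conditions].to_numpy(),
    index=[perturbed_gene(c) for c in conditions],
    columns=means.columns,
)
print(single)
```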


In [20]:
adata_single = anndata.AnnData(X=df_singleperts_expression.values, var=adata.var.copy(), dtype=df_singleperts_expression.values.dtype)
adata_single.obs_names = df_singleperts_condition
adata_single.obs["condition"] = df_singleperts_condition
adata_single.obs["perts_name"] = df_single_pert_val
adata_single.obsm["perturbation_neighbors"] = df_singleperts_emb

In [21]:
# Copy each seed's per-cell split labels onto the pseudobulk rows; the unseen ("ood")
# single perturbations form the "unseen_single" subgroup.
for split_seed in range(1, 6):
    adata_single.obs[f"split{split_seed}"] = None
    adata_single.obs[f"subgroup{split_seed}"] = "Train/Val"
    for cat in ["train", "test", "ood"]:
        conds = adata[adata.obs[f"split{split_seed}"] == cat].obs["condition"].cat.categories
        cat_idx = adata_single.obs["condition"].isin(conds)
        adata_single.obs.loc[cat_idx, f"split{split_seed}"] = cat
        if cat == "ood":
            adata_single.obs.loc[cat_idx, f"subgroup{split_seed}"] = "unseen_single"
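Splits are assigned per condition, so all cells of a condition share one split label. The pseudobulk labels can therefore also be read from a condition-to-split lookup. The sketch below shows this swapped-in formulation on hypothetical labels for a single seed. It uses `Series.map` in place of the `isin` loop above.

```python
import pandas as pd

# Hypothetical per-cell split labels for one seed, as stored in adata.obs["split1"].
cell_obs = pd.DataFrame({
    "condition": ["ctrl", "ctrl", "A+ctrl", "B+ctrl", "C+ctrl"],
    "split1": ["train", "train", "train", "test", "ood"],
})
# One row per condition, as in adata_single.obs.
single_obs = pd.DataFrame({"condition": ["A+ctrl", "B+ctrl", "C+ctrl", "ctrl"]})

# Every cell of a condition shares its split, so one row per condition is enough.
cond2split = cell_obs.drop_duplicates("condition").set_index("condition")["split1"]
single_obs["split1"] = single_obs["condition"].map(cond2split)
single_obs["subgroup1"] = single_obs["split1"].map(
    lambda s: "unseen_single" if s == "ood" else "Train/Val"
)
print(single_obs)
```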

In [22]:
adata_single.write(DATA_DIR_LCL + "adamson/adamson_single_biolord.h5ad")

In [23]:
adata.write(DATA_DIR_LCL + "adamson/adamson_biolord.h5ad")

## Train `GEARS` models

In [24]:
epoch = 15
no_perturb = False

In [25]:
subgroup_analysis_gears = {}
for seed in range(1,6):
    pert_data = PertData(DATA_DIR_LCL[:-1]) # specific saved folder
    pert_data.load(data_name = "adamson")
    pert_data.prepare_split(split = "simulation", seed = seed)
    pert_data.get_dataloader(batch_size = batch_size, test_batch_size = batch_size)
    
    gears_model = GEARS(
        pert_data, 
        device = "cuda:" + str(device), 
        weight_bias_track = False, 
        proj_name = "adamson",
        exp_name = "gears_seed" + str(seed)
    )
    gears_model.model_initialize(hidden_size = 64, no_perturb = no_perturb)
    subgroup_analysis_gears[seed], subgroup_analysis_deeper, test_metrics = gears_model.train(epochs = epoch)

Found local copy...
Found local copy...
Found local copy...
These perturbations are not in the GO graph and is thus not able to make prediction for...
['SRPR+ctrl' 'SLMO2+ctrl' 'TIMM23+ctrl' 'AMIGO3+ctrl' 'KCTD16+ctrl']
Local copy of pyg dataset is detected. Loading...
Done!
Local copy of split is detected. Loading...
Simulation split test composition:
combo_seen0:0
combo_seen1:0
combo_seen2:0
unseen_single:21
Done!
Creating dataloaders....
Done!
Found local copy...
Start Training...
Epoch 1 Step 1 Train Loss: 0.4359
Epoch 1 Step 51 Train Loss: 0.5284
Epoch 1 Step 101 Train Loss: 0.6695
Epoch 1 Step 151 Train Loss: 0.5774
Epoch 1 Step 201 Train Loss: 0.6081
Epoch 1 Step 251 Train Loss: 0.4904
Epoch 1 Step 301 Train Loss: 0.5563
Epoch 1 Step 351 Train Loss: 0.4902
Epoch 1 Step 401 Train Loss: 0.4975
Epoch 1 Step 451 Train Loss: 0.5356
Epoch 1 Step 501 Train Loss: 0.5083
Epoch 1 Step 551 Train Loss: 0.5493
Epoch 1 Step 601 Train Loss: 0.5228
Epoch 1 Step 651 Train Loss: 0.5923
Epoch 1 St

In [27]:
no_perturb_mse_de_seeds = pd.DataFrame(res)
res_gears = {}
res_gears_normalized = {}
for key in subgroup_analysis_gears:
    mse_de = pd.DataFrame(subgroup_analysis_gears[key]).T["mse_de"]
    res_gears[f"mse_de_seed{key}"] = mse_de
    # divide by the no_perturb baseline of the same seed (aligned on subgroup name)
    res_gears_normalized[f"mse_de_seed{key}"] = mse_de / no_perturb_mse_de_seeds[f"mse_de_seed{key}"]

In [30]:
with open(DATA_DIR_LCL + "gears_subgroup_analysis.pkl", "wb") as f:
    pickle.dump(subgroup_analysis_gears, f)

In [31]:
pd.DataFrame(res_gears).to_csv(DATA_DIR_LCL + "gears_mse_de_seeds.csv")

In [32]:
pd.DataFrame(res_gears_normalized).to_csv(DATA_DIR_LCL + "gears_normalized_mse_de_seeds.csv")
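The normalized scores divide each seed's `GEARS` DE-MSE by the `no_perturb` DE-MSE of the same seed. pandas aligns the two Series on their subgroup labels, not on row position. A ratio below 1 means the model's DE error is lower than that of the baseline. The sketch below uses hypothetical values, and the two indices are deliberately given in different orders.

```python
import numpy as np
import pandas as pd

# Hypothetical per-subgroup mse_de values for one seed; note the different index order.
baseline = pd.Series(
    [np.nan, np.nan, np.nan, 0.4],
    index=["combo_seen0", "combo_seen1", "combo_seen2", "unseen_single"],
)
model = pd.Series(
    [0.1, np.nan, np.nan, np.nan],
    index=["unseen_single", "combo_seen2", "combo_seen1", "combo_seen0"],
)

# Division aligns on the subgroup labels; empty subgroups stay NaN.
normalized = model / baseline
print(normalized)
```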