# Preprocessing the _Norman et al._[[1]](https://doi.org/10.1126/science.aax4438) dataset

We use `GEARS`[[2]](https://doi.org/10.1101/2022.07.12.499735) to:
1. Get the `No-perturb` normalization baseline for the Perturb-Seq dataset by _Norman et al._[[1]](https://doi.org/10.1126/science.aax4438).
2. Extract the `adata` object for `biolord`.
3. Run `GEARS` as a baseline comparison to `biolord`.


[[1] Norman, T. M., Horlbeck, M. A., Replogle, J. M., Ge, A. Y., Xu, A., Jost, M., ... & Weissman, J. S. (2019). Exploring genetic interaction manifolds constructed from rich single-cell phenotypes. Science, 365(6455), 786-793.](https://doi.org/10.1126/science.aax4438)

[[2] Roohani, Y., Huang, K., & Leskovec, J. (2022). GEARS: Predicting transcriptional outcomes of novel multi-gene perturbations. BioRxiv, 2022-07.](https://doi.org/10.1101/2022.07.12.499735)


## Load packages

In [1]:
import sys
import pandas as pd
import pickle
import numpy as np
import anndata

from gears import PertData, GEARS
import torch

In [2]:
sys.path.append("../../../")
sys.path.append("../../utils/")
from paths import DATA_DIR, FIG_DIR

## Set parameters

In [3]:
DATA_DIR_LCL = str(DATA_DIR) + "/perturbations/norman/"
FIG_DIR_LCL = str(FIG_DIR) + "/perturbations/norman"

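# Machine-specific override of the path derived from `DATA_DIR` above; adjust to your local data location.
DATA_DIR_LCL = "/cs/labs/mornitzan/zoe.piran/research/projects/biolord_data/data/norman2019/"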

In [4]:
device = torch.cuda.current_device()
batch_size = 32

## Create `no_perturb` baseline

In [5]:
epoch = 0
no_perturb = True
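
To make the baseline concrete: with `epoch = 0` and `no_perturb = True` no training takes place and, as far as we understand the `GEARS` implementation, the model returns the unperturbed control expression as its prediction. The sketch below illustrates, on made-up numbers, how an MSE restricted to differentially expressed (DE) genes would then be computed. All names (`ctrl_mean`, `perturbed_mean`, `de_idx`, `mse_de`) and values are hypothetical and are not taken from `GEARS`.

```python
import numpy as np

# Hypothetical mean expression over 6 genes (illustrative values only).
ctrl_mean = np.array([1.0, 0.5, 2.0, 0.0, 1.5, 0.2])       # unperturbed cells
perturbed_mean = np.array([1.0, 1.5, 2.0, 1.0, 1.5, 0.2])  # cells under one perturbation

# Hypothetical indices of the DE genes for this perturbation.
de_idx = np.array([1, 3])


def mse_de(prediction, truth, idx):
    """Mean squared error restricted to the genes in `idx`."""
    diff = prediction[idx] - truth[idx]
    return float(np.mean(diff ** 2))


# The no-perturb baseline predicts the control profile for every perturbation.
baseline_error = mse_de(ctrl_mean, perturbed_mean, de_idx)
```

Dividing a model's DE-gene MSE by this baseline gives a value below 1 whenever the model does better than simply predicting control expression.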

In [6]:
subgroup_analysis = {}
for seed in range(1,6):
    pert_data = PertData(DATA_DIR_LCL[:-1],  gene_path=DATA_DIR_LCL + "essential_norman.pkl") # specific saved folder
    pert_data.load(data_path = DATA_DIR_LCL + "norman2019")
    pert_data.prepare_split(split = "simulation", seed = seed)
    pert_data.get_dataloader(batch_size = batch_size, test_batch_size = batch_size)
    
    gears_model = GEARS(
        pert_data, 
        device = "cuda:" + str(device), 
        weight_bias_track = False, 
        proj_name = "norman2019",
        exp_name = "no_perturb_seed" + str(seed)
    )
    gears_model.model_initialize(hidden_size = 64, no_perturb = no_perturb,  go_path = DATA_DIR_LCL + "go_essential_norman.csv")
    subgroup_analysis[seed], subgroup_analysis_deeper, test_metrics = gears_model.train(epochs = epoch)
    

Found local copy...
These perturbations are not in the GO graph and is thus not able to make prediction for...
['RHOXF2BB+ctrl' 'LYL1+IER5L' 'ctrl+IER5L' 'KIAA1804+ctrl' 'IER5L+ctrl'
 'RHOXF2BB+ZBTB25' 'RHOXF2BB+SET']
Local copy of pyg dataset is detected. Loading...
Done!
Local copy of split is detected. Loading...
Simulation split test composition:
combo_seen0:9
combo_seen1:43
combo_seen2:19
unseen_single:36
Done!
Creating dataloaders....
Done!
Start Training...
Done!
Start Testing...
Best performing model: Test Top 20 DE MSE: 0.4383
Start doing subgroup analysis for simulation split...
test_combo_seen0_mse: 0.008115437
test_combo_seen0_pearson: 0.9752434637166935
test_combo_seen0_mse_de: 0.36536023
test_combo_seen0_pearson_de: 0.6961077512802023
test_combo_seen1_mse: 0.009594609
test_combo_seen1_pearson: 0.9712218966592624
test_combo_seen1_mse_de: 0.4883878
test_combo_seen1_pearson_de: 0.7296004579033083
test_combo_seen2_mse: 0.0067827976
test_combo_seen2_pearson: 0.9796019736738077

In [7]:
res = {}
for key in subgroup_analysis:
    res[f"mse_de_seed{key}"] = pd.DataFrame(subgroup_analysis[key]).T["mse_de"]

In [8]:
with open(DATA_DIR_LCL + "no_perturb_subgroup_analysis.pkl", "wb") as f:
    pickle.dump(subgroup_analysis, f)

In [9]:
pd.DataFrame(res).to_csv(DATA_DIR_LCL + "no_perturb_mse_de_seeds.csv")

In [10]:
no_perturb_mse_de_seeds = pd.read_csv(DATA_DIR_LCL + "no_perturb_mse_de_seeds.csv", index_col=0)
no_perturb_mse_de_seeds

| | mse_de_seed1 | mse_de_seed2 | mse_de_seed3 | mse_de_seed4 | mse_de_seed5 |
|---|---|---|---|---|---|
| combo_seen0 | 0.36536 | 0.681767 | 0.745222 | 0.430637 | 0.752053 |
| combo_seen1 | 0.488388 | 0.552666 | 0.535759 | 0.451559 | 0.630864 |
| combo_seen2 | 0.609232 | 0.51447 | 0.522522 | 0.504532 | 0.542322 |
| unseen_single | 0.306378 | 0.410966 | 0.26151 | 0.238054 | 0.398873 |


## Set `adata`

In [11]:
adata = pert_data.adata.copy()

In [12]:
for seed in range(1,6):
    with open(DATA_DIR_LCL + f"norman2019/splits/norman2019_simulation_{seed}_0.75.pkl", "rb") as f:
        split_data = pickle.load(f)
        pert2set = {}
        for i,j in split_data.items():
            for x in j:
                pert2set[x] = i
        
        # Use a context manager so the subgroup file handle is closed after loading.
        with open(DATA_DIR_LCL + f"norman2019/splits/norman2019_simulation_{seed}_0.75_subgroup.pkl", "rb") as f_sub:
            subgroup = pickle.load(f_sub)
        adata.obs[f"split{seed}"] = [pert2set[i] for i in adata.obs["condition"].values]
        pert2subgroup = {}
        for i,j in subgroup["test_subgroup"].items():
            for x in j:
                pert2subgroup[x] = i
        
        adata.obs[f"subgroup{seed}"] = adata.obs["condition"].apply(lambda x: pert2subgroup[x] if x in pert2subgroup else 'Train/Val')
        rename = {
            'train': 'train',
             'test': 'ood',
             'val': 'test'
        }
        adata.obs[f'split{seed}'] = adata.obs[f'split{seed}'].apply(lambda x: rename[x])
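
The cell above inverts two dictionaries (set name to conditions, and test subgroup to conditions) and relabels the split names. A minimal pure-Python sketch of the same bookkeeping follows. The dictionaries and the condition labels (`"A+ctrl"`, `"A+B"`, ...) are hypothetical stand-ins that only assume the shape used above, not the real contents of the split files.

```python
# Hypothetical split dictionary: set name -> list of conditions.
split_data = {
    "train": ["A+ctrl", "B+ctrl"],
    "val": ["C+ctrl"],
    "test": ["A+B", "D+ctrl"],
}

# Invert to condition -> set name.
pert2set = {cond: name for name, conds in split_data.items() for cond in conds}

# Relabel as in the notebook: 'val' becomes 'test', 'test' becomes 'ood'.
rename = {"train": "train", "val": "test", "test": "ood"}
cond2split = {cond: rename[name] for cond, name in pert2set.items()}

# Hypothetical test subgroups: subgroup name -> list of conditions.
test_subgroup = {"combo_seen2": ["A+B"], "unseen_single": ["D+ctrl"]}
pert2subgroup = {c: g for g, conds in test_subgroup.items() for c in conds}

# Conditions outside every test subgroup fall back to 'Train/Val'.
cond2subgroup = {c: pert2subgroup.get(c, "Train/Val") for c in pert2set}
```

After the relabeling, `ood` marks the held-out perturbations, while `test` refers to what `GEARS` calls the validation set.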


In [13]:
adata.obs["perturbation"] = [cond.split("+")[0] for cond in adata.obs["condition"]]
adata.obs["perturbation_rep"] = [cond.split("+")[1] if len(cond.split("+")) > 1 else "ctrl" for cond in adata.obs["condition"]]
adata.obs["perturbation"] = adata.obs["perturbation"].astype("category")
adata.obs["perturbation_rep"] = adata.obs["perturbation_rep"].astype("category")
new_cats = adata.obs["perturbation_rep"].cat.categories[~adata.obs["perturbation_rep"].cat.categories.isin(adata.obs["perturbation"].cat.categories)]
adata.obs["perturbation"] = adata.obs["perturbation"].cat.add_categories(new_cats)
new_cats_rep =  adata.obs["perturbation"].cat.categories[~adata.obs["perturbation"].cat.categories.isin(adata.obs["perturbation_rep"].cat.categories)]
adata.obs["perturbation_rep"] = adata.obs["perturbation_rep"].cat.add_categories(new_cats_rep)
adata.obs["perturbation"] = adata.obs["perturbation"].cat.reorder_categories(adata.obs["perturbation_rep"].cat.categories)
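
The category juggling above ensures that `perturbation` and `perturbation_rep` share one category list, so a given gene maps to the same integer code in both columns. A compact sketch of the idea on hypothetical condition labels is shown below. It uses sorted order for the shared categories as a simplification, whereas the notebook keeps the order of `perturbation_rep`.

```python
import pandas as pd

conditions = ["A+ctrl", "ctrl+B", "A+B", "ctrl"]  # hypothetical condition labels

parts = [cond.split("+") for cond in conditions]
first = [p[0] for p in parts]
second = [p[1] if len(p) > 1 else "ctrl" for p in parts]

# One category list for both columns, so each gene gets the same code in either.
shared = sorted(set(first) | set(second))
perturbation = pd.Categorical(first, categories=shared)
perturbation_rep = pd.Categorical(second, categories=shared)
```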

In [14]:
go_path = DATA_DIR_LCL + "go_essential_norman.csv"
gene_path = DATA_DIR_LCL + "essential_norman.pkl"
df = pd.read_csv(go_path)
df = df.groupby('target').apply(lambda x: x.nlargest(20 + 1,['importance'])).reset_index(drop = True)
with open(gene_path, 'rb') as f:
    gene_list = pickle.load(f)
    
df = df[df["source"].isin(gene_list)]

In [15]:
def get_map(pert):
    """Return a vector over `gene_list` holding the GO importance of each neighbor of `pert`."""
    tmp = pd.DataFrame(np.zeros(len(gene_list)), index=gene_list)
    neighbors = df[df.target == pert]
    tmp.loc[neighbors.source.values, :] = neighbors.importance.values[:, np.newaxis]
    return tmp.values.flatten()
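
As an illustration of what `get_map` produces, the sketch below builds the same kind of vector from a toy edge list without `pandas`. The gene names, edges, and the helper `neighbor_vector` are hypothetical; only the column roles (`source`, `target`, `importance`) are taken from the cells above.

```python
import numpy as np

gene_list = ["G1", "G2", "G3", "G4"]  # hypothetical gene universe

# Hypothetical GO edges as (source, target, importance) triples.
edges = [("G1", "P1", 0.9), ("G3", "P1", 0.4), ("G2", "P2", 0.7)]


def neighbor_vector(pert):
    """Importance of every gene in `gene_list` as a neighbor of `pert` (0 if none)."""
    position = {gene: i for i, gene in enumerate(gene_list)}
    vec = np.zeros(len(gene_list))
    for source, target, importance in edges:
        if target == pert and source in position:
            vec[position[source]] = importance
    return vec


p1_vec = neighbor_vector("P1")      # non-zero at G1 and G3
ctrl_vec = neighbor_vector("ctrl")  # no edges, so all zeros
```

A perturbation without any edge in the graph, such as `ctrl` here, is represented by the all-zero vector.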
    

In [16]:
pert2neighbor =  {i: get_map(i) for i in list(adata.obs["perturbation"].cat.categories)}    
adata.uns["pert2neighbor"] = pert2neighbor

In [17]:
pert2neighbor = np.asarray([val for val in adata.uns["pert2neighbor"].values()])
keep_idx = pert2neighbor.sum(0) > 0
keep_idx1 = pert2neighbor.sum(0) > 1
keep_idx2 = pert2neighbor.sum(0) > 2
keep_idx3 = pert2neighbor.sum(0) > 3

name_map = dict(adata.obs[["condition", "condition_name"]].drop_duplicates().values)
ctrl = np.asarray(adata[adata.obs["condition"].isin(["ctrl"])].X.mean(0)).flatten() 
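
The `keep_idx*` masks select gene columns by their summed importance across all perturbations, with higher thresholds giving smaller feature sets. A toy example with a hypothetical importance matrix:

```python
import numpy as np

# Hypothetical importance matrix: rows = perturbations, columns = genes.
toy_neighbors = np.array([
    [0.9, 0.0, 0.4, 0.0],
    [0.8, 0.0, 0.0, 0.0],
    [0.7, 0.0, 0.3, 0.0],
])

col_sum = toy_neighbors.sum(0)   # total importance per gene column
toy_keep = col_sum > 0           # gene is a neighbor of at least one perturbation
toy_keep1 = col_sum > 1          # stricter: summed importance above 1

reduced = toy_neighbors[:, toy_keep]  # drop columns that are zero everywhere
```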

In [18]:
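# `.toarray()` replaces the `.A` shorthand, which recent SciPy releases no longer provide for sparse matrices.
df_perts_expression = pd.DataFrame(adata.X.toarray(), index=adata.obs_names, columns=adata.var_names)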
df_perts_expression["condition"] = adata.obs["condition"]
df_perts_expression = df_perts_expression.groupby(["condition"]).mean()
df_perts_expression = df_perts_expression.reset_index()

single_perts_condition = []
single_pert_val = []
double_perts = []
for pert in adata.obs["condition"].cat.categories:
    if len(pert.split("+")) == 1:
        continue
    elif "ctrl" in pert:
        single_perts_condition.append(pert)
        p1, p2 = pert.split("+")
        if p2 == "ctrl":
            single_pert_val.append(p1)
        else:
            single_pert_val.append(p2)
    else:
        double_perts.append(pert)
single_perts_condition.append("ctrl")
single_pert_val.append("ctrl")

df_singleperts_expression = pd.DataFrame(df_perts_expression.set_index("condition").loc[single_perts_condition].values, index=single_pert_val)
df_singleperts_emb = np.asarray([adata.uns["pert2neighbor"][p1][keep_idx] for p1 in df_singleperts_expression.index])
df_singleperts_emb1 = np.asarray([adata.uns["pert2neighbor"][p1][keep_idx1] for p1 in df_singleperts_expression.index])
df_singleperts_emb2 = np.asarray([adata.uns["pert2neighbor"][p1][keep_idx2] for p1 in df_singleperts_expression.index])
df_singleperts_emb3 = np.asarray([adata.uns["pert2neighbor"][p1][keep_idx3] for p1 in df_singleperts_expression.index])

df_singleperts_condition = pd.Index(single_perts_condition)
df_single_pert_val = pd.Index(single_pert_val)

df_doubleperts_expression = df_perts_expression.set_index("condition").loc[double_perts].values
df_doubleperts_condition = pd.Index(double_perts)
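
The loop above sorts condition labels into control, single, and double perturbations based on the `+`-separated label and the position of `ctrl`. The same rule can be sketched as a small helper. The function `classify` and the labels in the usage lines are hypothetical.

```python
def classify(condition):
    """Return (kind, genes), where kind is 'ctrl', 'single' or 'double'."""
    # Drop the 'ctrl' placeholder; whatever remains are the perturbed genes.
    genes = tuple(g for g in condition.split("+") if g != "ctrl")
    kind = {0: "ctrl", 1: "single", 2: "double"}[len(genes)]
    return kind, genes


labels = ["ctrl", "A+ctrl", "ctrl+B", "A+B"]
kinds = [classify(label)[0] for label in labels]
```

Because `ctrl` may appear on either side of the `+`, both `A+ctrl` and `ctrl+B` count as single perturbations.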

In [19]:
adata_single = anndata.AnnData(X=df_singleperts_expression.values, var=adata.var.copy(), dtype=df_singleperts_expression.values.dtype)
adata_single.obs_names = df_singleperts_condition
adata_single.obs["condition"] = df_singleperts_condition
adata_single.obs["perts_name"] = df_single_pert_val
adata_single.obsm["perturbation_neighbors"] = df_singleperts_emb
adata_single.obsm["perturbation_neighbors1"] = df_singleperts_emb1
adata_single.obsm["perturbation_neighbors2"] = df_singleperts_emb2
adata_single.obsm["perturbation_neighbors3"] = df_singleperts_emb3

In [20]:
for split_seed in range(1,6):
    adata_single.obs[f"split{split_seed}"] = None
    adata_single.obs[f"subgroup{split_seed}"] = "Train/Val"
    for cat in ["train","test","ood"]:
        cat_idx = adata_single.obs["condition"].isin(adata[adata.obs[f"split{split_seed}"] == cat].obs["condition"].cat.categories)
        adata_single.obs.loc[cat_idx ,f"split{split_seed}"] = cat
        if cat == "ood":
            for ood_set in ["combo_seen0", "combo_seen1", "combo_seen2", "unseen_single"]:
                idx_ood = adata_single.obs["condition"].isin(adata[adata.obs[f"subgroup{split_seed}"] == ood_set].obs["condition"].cat.categories)
                adata_single.obs.loc[idx_ood ,f"subgroup{split_seed}"] = ood_set

In [21]:
adata_single.write(DATA_DIR_LCL + "norman2019/norman2019_single_biolord.h5ad")

In [22]:
adata.write(DATA_DIR_LCL + "norman2019/norman2019_biolord.h5ad")

## Train `GEARS` models

In [23]:
epoch = 15
no_perturb = False

In [24]:
subgroup_analysis_gears = {}
for seed in range(1,6):
    pert_data = PertData(DATA_DIR_LCL[:-1], gene_path=DATA_DIR_LCL + "essential_norman.pkl") # specific saved folder
    pert_data.load(data_path = DATA_DIR_LCL + "norman2019")
    pert_data.prepare_split(split = "simulation", seed = seed)
    pert_data.get_dataloader(batch_size = batch_size, test_batch_size = batch_size)
    
    gears_model = GEARS(
        pert_data, 
        device = "cuda:" + str(device), 
        weight_bias_track = False, 
        proj_name = "norman2019",
        exp_name = "gears_seed" + str(seed)
    )
    gears_model.model_initialize(hidden_size = 64, no_perturb = no_perturb,  go_path = DATA_DIR_LCL + "go_essential_norman.csv")
    subgroup_analysis_gears[seed], subgroup_analysis_deeper, test_metrics = gears_model.train(epochs = epoch)

Found local copy...
These perturbations are not in the GO graph and is thus not able to make prediction for...
['RHOXF2BB+ctrl' 'LYL1+IER5L' 'ctrl+IER5L' 'KIAA1804+ctrl' 'IER5L+ctrl'
 'RHOXF2BB+ZBTB25' 'RHOXF2BB+SET']
Local copy of pyg dataset is detected. Loading...
Done!
Local copy of split is detected. Loading...
Simulation split test composition:
combo_seen0:9
combo_seen1:43
combo_seen2:19
unseen_single:36
Done!
Creating dataloaders....
Done!
Start Training...
Epoch 1 Step 1 Train Loss: 0.5212
Epoch 1 Step 51 Train Loss: 0.4818
Epoch 1 Step 101 Train Loss: 0.4277
Epoch 1 Step 151 Train Loss: 0.4037
Epoch 1 Step 201 Train Loss: 0.4787
Epoch 1 Step 251 Train Loss: 0.4899
Epoch 1 Step 301 Train Loss: 0.4857
Epoch 1 Step 351 Train Loss: 0.5484
Epoch 1 Step 401 Train Loss: 0.4155
Epoch 1 Step 451 Train Loss: 0.5776
Epoch 1 Step 501 Train Loss: 0.5211
Epoch 1 Step 551 Train Loss: 0.5474
Epoch 1 Step 601 Train Loss: 0.4326
Epoch 1 Step 651 Train Loss: 0.5109
Epoch 1 Step 701 Train Loss: 0

In [25]:
res_gears = {}
res_gears_normalized = {}
for key in subgroup_analysis_gears:
    res_gears[f"mse_de_seed{key}"] = pd.DataFrame(subgroup_analysis_gears[key]).T["mse_de"]
    res_gears_normalized[f"mse_de_seed{key}"] = pd.DataFrame(subgroup_analysis_gears[key]).T["mse_de"] / no_perturb_mse_de_seeds[f"mse_de_seed{key}"]
    

In [26]:
with open(DATA_DIR_LCL + "gears_subgroup_analysis.pkl", "wb") as f:
    pickle.dump(subgroup_analysis_gears, f)

In [27]:
pd.DataFrame(res_gears).to_csv(DATA_DIR_LCL + "gears_mse_de_seeds.csv")

In [28]:
pd.DataFrame(res_gears_normalized).to_csv(DATA_DIR_LCL + "gears_normalized_mse_de_seeds.csv")

In [29]:
df_gears = pd.DataFrame(res_gears_normalized).T
df_gears.loc["mean"] = df_gears.mean(0)
df_gears

| | combo_seen0 | combo_seen1 | combo_seen2 | unseen_single |
|---|---|---|---|---|
| mse_de_seed1 | 0.397848 | 0.399881 | 0.287964 | 0.658105 |
| mse_de_seed2 | 0.690452 | 0.393775 | 0.178161 | 0.667789 |
| mse_de_seed3 | 0.550081 | 0.399856 | 0.227736 | 0.772609 |
| mse_de_seed4 | 0.382077 | 0.34689 | 0.264457 | 0.647327 |
| mse_de_seed5 | 0.606267 | 0.499058 | 0.344923 | 0.786129 |
| mean | 0.525345 | 0.407892 | 0.260648 | 0.706392 |
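
Each entry in the table is the `GEARS` `mse_de` divided by the `no_perturb` `mse_de` for the same seed and subgroup, so values below 1 mean `GEARS` improves on predicting control expression. As a quick sanity check, the sketch below rebuilds the per-seed rows from the printed values and recomputes the `mean` row. The variable names are ours.

```python
import pandas as pd

# Normalized mse_de per seed, copied from the table above.
normalized = pd.DataFrame(
    {
        "combo_seen0": [0.397848, 0.690452, 0.550081, 0.382077, 0.606267],
        "combo_seen1": [0.399881, 0.393775, 0.399856, 0.34689, 0.499058],
        "combo_seen2": [0.287964, 0.178161, 0.227736, 0.264457, 0.344923],
        "unseen_single": [0.658105, 0.667789, 0.772609, 0.647327, 0.786129],
    },
    index=[f"mse_de_seed{seed}" for seed in range(1, 6)],
)

mean_row = normalized.mean(0)                  # average over the five seeds
beats_baseline = bool((normalized < 1).all().all())  # every ratio is below 1
```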
