In [1]:
# load packages
from pathlib import Path
import shutil
import numpy as np
import json
import tempfile
import argparse
import session_info

from gears import PertData, GEARS

In [2]:
# define 
dataset_name = "southard_rpe1_crispri"   # change if needed
test_train_config_path = "/.mounts/labs/steinlab/scratch/mtong/datasets/seq/southard_RPE1_CRISPRi/southard_data/results/set2conditions.json"
epochs = 20
seed = 1
working_dir = "/.mounts/labs/steinlab/scratch/mtong/datasets/seq/southard_RPE1_CRISPRi/southard_data"

# make sure this directory exists
out_dir = f"{working_dir}/results/GEARS"
np.random.seed(seed)

In [3]:
# load data into Pertdata object
pert_data_folder = Path("/.mounts/labs/steinlab/scratch/mtong/datasets/seq/southard_RPE1_CRISPRi")
pert_data = PertData(pert_data_folder)
pert_data.load(data_path = f"{working_dir}/{dataset_name}")

with open(test_train_config_path, "r") as json_file:
    set2conditions = json.load(json_file)

Local copy of pyg dataset is detected. Loading...
Done!


In [4]:
# define data split
print(set2conditions)
pert_data.set2conditions = set2conditions
pert_data.split = "custom"
pert_data.subgroup = None
pert_data.seed = 1
pert_data.train_gene_set_size = 0.75 # number of genes used for training
pert_data.get_dataloader(batch_size = 32, test_batch_size = 128)

Creating dataloaders....
Done!


{'test': ['RPL18A+ctrl', 'RPS8+ctrl', 'POGLUT3+ctrl', 'EIF3A+ctrl', 'RPL15+ctrl', 'POLR3E+ctrl', 'UXT+ctrl', 'PPP1R37+ctrl', 'RPL4+ctrl', 'KPNB1+ctrl', 'RPL38+ctrl', 'MED30+ctrl', 'RPS3A+ctrl', 'RPS15A+ctrl', 'TYK2+ctrl', 'RPS19+ctrl', 'EIF3CL+ctrl', 'ERCC3+ctrl', 'EXOSC8+ctrl', 'EIF1AX+ctrl', 'NUP214+ctrl', 'FUNDC2+ctrl', 'EIF3E+ctrl', 'RPTOR+ctrl', 'EXOSC5+ctrl', 'EIF3M+ctrl', 'RPS26+ctrl', 'GTF2E1+ctrl', 'TSR2+ctrl', 'MED1+ctrl', 'RPL35+ctrl', 'RPL23+ctrl', 'RPS4X+ctrl', 'RPL6+ctrl'], 'train': ['ctrl', 'ADAM10+ctrl', 'RPL14+ctrl', 'MED21+ctrl', 'MED11+ctrl', 'UTP18+ctrl', 'NUP205+ctrl', 'WDR43+ctrl', 'SIRT7+ctrl', 'MED18+ctrl', 'RPL37A+ctrl', 'NUP98+ctrl', 'EIF3I+ctrl', 'RPL32+ctrl', 'RPL39+ctrl', 'GLE1+ctrl', 'TMX2+ctrl', 'URI1+ctrl', 'RPS14+ctrl', 'XRCC5+ctrl', 'MED9+ctrl', 'RPS29+ctrl', 'RPS6+ctrl', 'TBCB+ctrl', 'POLR3D+ctrl', 'RPS13+ctrl', 'EXOSC3+ctrl', 'TAF5+ctrl', 'NUP85+ctrl', 'NUP88+ctrl', 'MED12+ctrl', 'RPS2+ctrl', 'RPLP0+ctrl', 'RPL30+ctrl', 'MED20+ctrl', 'EIF3D+ctrl', 'U

In [7]:
# generate perturbation relationship graph 
import pandas as pd
k = 20
df_jaccard = pd.read_csv('/.mounts/labs/steinlab/scratch/mtong/datasets/model/GEARS/go_essential_all.csv')
df_out = df_jaccard.groupby('target').apply(lambda x: x.nlargest(k + 1, ['importance'])).reset_index(drop = True)
pert_list = list(pert_data.adata.obs['pert_gene'].unique())
node_map_pert = {pert: i for i, pert in enumerate(pert_list)}
genes_in_dataset = list(node_map_pert)
df_filtered = df_out[
    df_out['source'].isin(genes_in_dataset) & df_out['target'].isin(genes_in_dataset)
]

In [8]:
from gears import utils
sim_network = utils.GeneSimNetwork(df_filtered, pert_list, node_map = node_map_pert)

In [9]:
# initiate model training
gears_model = GEARS(pert_data, device = 'cuda')
gears_model.model_initialize(hidden_size = 64, G_go=sim_network.edge_index, G_go_weight=sim_network.edge_weight)
gears_model.train(epochs = epochs)

Start Training...
Epoch 1 Step 1 Train Loss: 1.1750
Epoch 1 Step 51 Train Loss: 1.0682
Epoch 1 Step 101 Train Loss: 1.4802
Epoch 1 Step 151 Train Loss: 0.9636
Epoch 1 Step 201 Train Loss: 0.9985
Epoch 1 Step 251 Train Loss: 1.3827
Epoch 1 Step 301 Train Loss: 1.3094
Epoch 1 Step 351 Train Loss: 0.8933
Epoch 1 Step 401 Train Loss: 1.1013
Epoch 1 Step 451 Train Loss: 0.8587
Epoch 1 Step 501 Train Loss: 1.0574
Epoch 1 Step 551 Train Loss: 1.1351
Epoch 1 Step 601 Train Loss: 1.3079
Epoch 1 Step 651 Train Loss: 1.2927
Epoch 1 Step 701 Train Loss: 1.1259
Epoch 1 Step 751 Train Loss: 1.1658
Epoch 1 Step 801 Train Loss: 0.9216
Epoch 1 Step 851 Train Loss: 1.3044
Epoch 1 Step 901 Train Loss: 1.1850
Epoch 1 Step 951 Train Loss: 1.5485
Epoch 1 Step 1001 Train Loss: 1.5330
Epoch 1 Step 1051 Train Loss: 1.1749
Epoch 1 Step 1101 Train Loss: 1.3811
Epoch 1 Step 1151 Train Loss: 1.2651
Epoch 1 Step 1201 Train Loss: 1.5043
Epoch 1 Step 1251 Train Loss: 1.1507
Epoch 1 Step 1301 Train Loss: 1.4207
Epoch 

In [10]:
# save model and generate predictions
gears_model.save_model(out_dir)

conds = pert_data.adata.obs["condition"].cat.remove_unused_categories().cat.categories.tolist()
split_conds = [x.split("+") for x in conds]
split_conds = [list(filter(lambda y: y != "ctrl", x)) for x in split_conds]

all_pred_vals = gears_model.predict(split_conds)
all_pred_vals = {k: v.tolist() for k, v in all_pred_vals.items()}

# Save ground truth
ground_truth_vals = {}
for cond in conds:
    obs_idx = pert_data.adata.obs['condition'] == cond
    mean_expr = np.asarray(pert_data.adata[obs_idx, :].X.mean(axis=0)).ravel()
    # Remove "+ctrl" from the key to match prediction
    key = cond.replace("+ctrl", "")
    ground_truth_vals[key] = mean_expr.tolist()

# Convert empty string key to 'ctrl'
if '' in all_pred_vals:
    all_pred_vals['ctrl'] = all_pred_vals.pop('')

# sanity check
pred_conditions = set(all_pred_vals.keys())
gt_conditions = set(ground_truth_vals.keys())
assert pred_conditions == gt_conditions, f"Mismatch in conditions: {pred_conditions ^ gt_conditions}"

In [11]:
# save 
with open(f"{out_dir}/all_predictions.json", 'w', encoding="utf8") as handle:
    json.dump(all_pred_vals, handle, indent = 4)
with open(f"{out_dir}/gene_names.json", 'w', encoding="utf8") as handle:
    json.dump(pert_data.adata.var["gene_name"].values.tolist(), handle, indent = 4)
with open(f"{out_dir}/all_ground_truth.json", 'w', encoding="utf8") as handle:
    json.dump(ground_truth_vals, handle, indent=4)

session_info.show()
print("Python done")

Python done
