# GEARS demo

In [1]:
import itertools
import os

import numpy as np
import pandas as pd
import torch
from anndata import AnnData
from scipy.sparse import csr_matrix
from scipy.stats import pearsonr
import random

# GEARS imports
from gears import GEARS, PertData

# Own imports
from data_utils.metrics import MMDLoss, compute_kld

In [7]:
DATA_DIR_PATH = "data"
MODELS_DIR_PATH = "models"
RESULTS_DIR_PATH = "results"

PREDICT_DOUBLE = True

# Set to True if training only has single perturbations (train + val) and double perturbations are on test.
SINGLE_TRAIN_ONLY = True

### Code and functions

In [8]:
def set_seeds(seed: int) -> None:
    """Set random seeds for reproducibility."""
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

In [9]:

def train(
    pert_data: PertData, dataset_name: str, model_savedir: str,
    split: str, seed: int, hidden_size: int, device: str, epochs: int = 20
) -> str:
    """Set up, train, and save GEARS model."""
    print("Training GEARS model.")
    gears_model = GEARS(pert_data=pert_data, device=device)
    gears_model.model_initialize(hidden_size=hidden_size)
    gears_model.train(epochs=epochs)
    model_name = (
        f"gears_{dataset_name}_split_{split}_seed_{str(seed)}_hidden_size_{str(hidden_size)}"
    )
    gears_model.save_model(path=os.path.join(model_savedir, model_name))
    return model_name

In [10]:
def predict(pert_data: PertData,
            device: str, model_name: str,
            model_savedir: str, results_savedir: str) -> None:
    """Predict with GEARS model."""
    # Load the model.
    print("Loading GEARS model.")
    gears_model = GEARS(pert_data=pert_data, device=device)
    gears_model.load_pretrained(path=os.path.join(model_savedir, model_name))

    # Get all single perturbations.
    single_perturbations = set(
        [
            c.strip("+ctrl")
            for c in pert_data.adata.obs["condition"]
            if ("ctrl+" in c) or ("+ctrl" in c)
        ]
    )
    print(f"Number of single perturbations: {len(single_perturbations)}")

    # Get all double perturbations.
    double_perturbations = set(
        [c for c in pert_data.adata.obs["condition"] if "ctrl" not in c]
    )
    print(f"Number of double perturbations: {len(double_perturbations)}")

    # Generate all possible double perturbations (combos).
    combo_perturbations = []
    for g1 in single_perturbations:
        for g2 in single_perturbations:
            if g1 == g2:
                continue
            combo_perturbations.append(sorted([g1, g2]))
    combo_perturbations.sort()
    combo_perturbations = list(k for k, _ in itertools.groupby(combo_perturbations))
    print(f"Number of combo perturbations: {len(combo_perturbations)}")

    # Get the names of all measured genes as comma-separated list.
    var_names_str = ",".join(map(str, list(pert_data.adata.var_names)))

    if PREDICT_DOUBLE:
        # Predict all double perturbations.
        double_results_file_path = os.path.join(
            results_savedir, f"{model_name}_double.csv"
        )
        with open(file=double_results_file_path, mode="w") as f:
            print(f"double,{var_names_str}", file=f)
            for i, d in enumerate(double_perturbations):
                print(f"Predicting double {i + 1}/{len(double_perturbations)}: {d}")
                prediction = gears_model.predict(pert_list=[d.split("+")])
                double = next(iter(prediction.keys()))
                expressions = prediction[double]
                expressions_str = ",".join(map(str, expressions))
                print(f"{double},{expressions_str}", file=f)

Key points in evaluation:
- The original PertData used for training is reloaded, but only double perturbations are maintained as ground truth to compare against predictions.
- Randomly selected control samples based on current seed.
- Another randomly selected set of control samples.
- Instead of using the whole gene expression profiles (GEP), Differential Gene Expresion is used. </br>

The first set of control samples is substracted from True GEPs and Predicted GEPs for each double perturbation evaluation. </br>
For control samples, the second set of samples is substracted to ensure that baseline values are not constant, since that generates problems during computation. </br>

In [None]:
def evaluate_double(adata: AnnData, model_name: str, results_savedir: str,
                    pool_size: int = 250, seed: int = 42, top_deg: int = 20) -> None:
    """Evaluate the predicted GEPs of double perturbations."""
    # Load predicted GEPs.
    df = pd.read_csv(
        filepath_or_buffer=os.path.join(results_savedir, f"{model_name}_double.csv")
    )

    # Make results file path.
    results_file_path = os.path.join(
        results_savedir, f"{model_name}_double_metrics.csv"
    )

    with open(file=results_file_path, mode="w") as f:
        print(
            f"double,mmd_true_vs_ctrl,mmd_true_vs_pred,mse_true_vs_ctrl,mse_true_vs_pred,kld_true_vs_ctrl,kld_true_vs_pred,pearsonTop{top_deg}_true_vs_ctrl,pearson_pval_true_vs_ctrl,pearsonTop{top_deg}_true_vs_pred,pearson_pval_true_vs_pred",
            file=f,
        )

        for i, double in enumerate(df["double"]):
            # Get the predicted GEP for the current double perturbation.
            pred_geps = df.loc[df["double"] == double]
            pred_geps = pred_geps.iloc[0, 1:].tolist()
            pred_geps = np.array([pred_geps])

            # Get all the true GEPs with the current double perturbation.
            double = double.replace("_", "+")
            print(f"Evaluating double {i + 1}/{len(df['double'])}: {double}")
            true_geps = adata[adata.obs["condition"] == double]

            # Limiting n
            if true_geps.n_obs>pool_size:
                n = pool_size
                random_indices = np.random.choice(true_geps.n_obs, size=n, replace=False)
                true_geps = true_geps[random_indices, :]  
            else:
                # If less than pool size, randomly sample from all true_geps to avoid error in MMD computation
                n = true_geps.n_obs

            set_seeds(seed)

            # Obtaining random sample of ctrl GEP
            all_ctrl_geps = adata[adata.obs["condition"] == "ctrl"]
            random_indices = np.random.choice(
                all_ctrl_geps.n_obs, size=n, replace=False
            )
            ctrl_geps = all_ctrl_geps[random_indices, :]
            pred_geps = csr_matrix(np.tile(pred_geps, reps=(n, 1)))

            # Another random ctrl_gep
            random_indices_2 = np.random.choice(
                all_ctrl_geps.n_obs, size=n, replace=False
            )
            ctrl_geps_2 = all_ctrl_geps[random_indices_2, :]

            # Tensor conversion and differential expression
            ctrl_geps_tensor = torch.tensor(ctrl_geps.X.toarray())
            ctrl_ctrl_geps_tensor = torch.tensor(ctrl_geps_2.X.toarray()) - ctrl_geps_tensor
            true_ctrl_geps_tensor = torch.tensor(true_geps.X.toarray()) - ctrl_geps_tensor
            pred_ctrl_geps_tensor = torch.tensor(pred_geps.toarray()) - ctrl_geps_tensor

            # MMD setup.
            mmd_sigma = 200.0
            kernel_num = 10
            mmd_loss = MMDLoss(fix_sigma=mmd_sigma, kernel_num=kernel_num)

            # Compute MMD 
            mmd_true_vs_ctrl = mmd_loss.forward(
                            source=ctrl_ctrl_geps_tensor, target=true_ctrl_geps_tensor
                        )

            mmd_true_vs_pred = mmd_loss.forward(
                source=pred_ctrl_geps_tensor, target=true_ctrl_geps_tensor
            )

            # Compute MSE
            mse_true_vs_ctrl = torch.mean(
                (true_ctrl_geps_tensor - ctrl_ctrl_geps_tensor) ** 2
            ).item()
            mse_true_vs_pred = torch.mean(
                (true_ctrl_geps_tensor - pred_ctrl_geps_tensor) ** 2
            ).item()

            # Compute KLD
            kld_true_vs_ctrl = compute_kld(true_ctrl_geps_tensor, ctrl_ctrl_geps_tensor)
            kld_true_vs_pred = compute_kld(true_ctrl_geps_tensor, pred_ctrl_geps_tensor)

            # Compute Pearson for top DEG
            true_deg = true_ctrl_geps_tensor.mean(dim=0).cpu().detach().numpy()
            ctrl_ctrl_deg = ctrl_ctrl_geps_tensor.mean(dim=0).cpu().detach().numpy()
            pred_deg = pred_ctrl_geps_tensor.mean(dim=0).cpu().detach().numpy()
            topdeg_idx = np.argsort(abs(true_deg))[-top_deg:]

            pearson_true_vs_ctrl = pearsonr(true_deg[topdeg_idx], ctrl_ctrl_deg[topdeg_idx])
            pearson_true_vs_pred = pearsonr(true_deg[topdeg_idx], pred_deg[topdeg_idx])


            print(f"MMD (true vs. control):   {mmd_true_vs_ctrl:10.6f}")
            print(f"MMD (true vs. predicted): {mmd_true_vs_pred:10.6f}")
            print(f"MSE (true vs. control):   {mse_true_vs_ctrl:10.6f}")
            print(f"MSE (true vs. predicted): {mse_true_vs_pred:10.6f}")
            print(f"KLD (true vs. control):   {kld_true_vs_ctrl:10.6f}")
            print(f"KLD (true vs. predicted): {kld_true_vs_pred:10.6f}")
            print(f"Pearson Top {top_deg} DEG (true vs. control): {pearson_true_vs_ctrl.statistic:.6f} | p-value: {pearson_true_vs_ctrl.pvalue:.6f}")
            print(f"Pearson Top {top_deg} DEG (true vs. predicted): {pearson_true_vs_pred.statistic:.6f} | p-value: {pearson_true_vs_pred.pvalue:.6f}")

            print(
                f"{double},{mmd_true_vs_ctrl},{mmd_true_vs_pred},{mse_true_vs_ctrl},{mse_true_vs_pred},{kld_true_vs_ctrl},{kld_true_vs_pred},{pearson_true_vs_ctrl.statistic},{pearson_true_vs_ctrl.pvalue},{pearson_true_vs_pred.statistic},{pearson_true_vs_pred.pvalue}",
                file=f,
            )

### GEARS

The full script to run GEARS from command line also allows to change some hyperparameters. </br>

Data is loaded from the data path where the PertData object is stored. </br>
'Raw' data as .h5ad file has to be previously converted through GEARS' own data handler for the training to work. </br>

In this project, all models have been trained with only single perturbation and control data. GEARS' data handler is already built to be able to withold the fraction of double perturbation required.
 </br>

 Calculations are made for:
- MSE (overall prediction)
- MMD (distrubution)
- Pearson (prediction of direction of change) </br>

Pearson is calculated from the mean expression of all taken samples for each group (perturbation prediction, perturbation truth and control). </br>
For Pearson, using the whole vector of expression values will result in P close to 1 in most cases, as gene expression vector are often filled with genes that have 0 expression. Instead, the Top X Differentially Expressed Genes can be chosen, which are chosen based on the True perturbation samples. </br>
In this Notebook, only 1 iteration of the training is done, but the script is written to allow more iterations with different seeds.

In [None]:
split = 'simulation' # Divides into Train/Val and Test
seed = 42
hidden_size = 64
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
project_name = 'demo_test'
dataset_name = 'norman_reduced'
pool_size = 50
num_runs = 1
epochs = 3
top_deg = 20

In [None]:
# Create directories.
os.makedirs(name=DATA_DIR_PATH, exist_ok=True)
os.makedirs(name=MODELS_DIR_PATH, exist_ok=True)
os.makedirs(name=RESULTS_DIR_PATH, exist_ok=True)

# Create directory for custom project name in MODELS_DIR_PATH and RESULTS_DIR_PATH.
model_savedir = os.path.join(MODELS_DIR_PATH, project_name)
results_savedir = os.path.join(RESULTS_DIR_PATH, project_name)
os.makedirs(name=model_savedir, exist_ok=True)
os.makedirs(name=results_savedir, exist_ok=True)

In [None]:
# Load "norman" data.
data_name = dataset_name
print(f"Loading '{data_name}' data.")
pertdata = PertData(DATA_DIR_PATH)
pertdata.load(data_path=os.path.join(DATA_DIR_PATH, data_name))

Found local copy...


Loading 'norman_reduced' data.


Found local copy...
These perturbations are not in the GO graph and their perturbation can thus not be predicted
['LYL1+IER5L' 'IER5L+ctrl' 'KIAA1804+ctrl' 'ctrl+IER5L']
Local copy of pyg dataset is detected. Loading...
Done!


In [None]:
# Running train, predict, evaluate for multiple runs.
current_run = 0
current_seed = seed + current_run
print(f"Current run: {current_run + 1}/{num_runs}, Seed: {current_seed}")

# Split data and get dataloaders.
# This split of train test sizes keeps singles in training and validation and doubles in testing
print("Preparing data split.")
if SINGLE_TRAIN_ONLY:
    # If training only has single perturbations, we need to set the train_gene_set_size
    # to 1.0 (keep all single pertubation genes) and combo_seen2_train_frac to 0.0 (don't keep any double perturbation genes)
    print("Keeping only single perturbation samples in training.")
    pertdata.prepare_split(
        split=split,
        seed=current_seed,
        train_gene_set_size=1.0,
        combo_seen2_train_frac=0.0
    )
else:
    pertdata.prepare_split(split=split, seed=current_seed)
pertdata.get_dataloader(batch_size=32)

Local copy of split is detected. Loading...
Simulation split test composition:
combo_seen0:0
combo_seen1:0
combo_seen2:130
unseen_single:0
Done!
Creating dataloaders....
Done!


Current run: 1/1, Seed: 42
Preparing data split.
Keeping only single perturbation samples in training.
here1


In [None]:
# Train.
model_name = train(
    pert_data=pertdata,
    dataset_name=dataset_name,
    model_savedir=model_savedir,
    split=split,
    seed=current_seed,
    hidden_size=hidden_size,
    device=device,
    epochs=epochs
)

Training GEARS model.


Found local copy...
Start Training...
Epoch 1 Step 1 Train Loss: 0.6946
Epoch 1 Step 51 Train Loss: 0.6973
Epoch 1 Step 101 Train Loss: 0.6872
Epoch 1: Train Overall MSE: 0.0068 Validation Overall MSE: 0.0049. 
Train Top 20 DE MSE: 0.1392 Validation Top 20 DE MSE: 0.2115. 
Epoch 2 Step 1 Train Loss: 0.5751
Epoch 2 Step 51 Train Loss: 0.7926
Epoch 2 Step 101 Train Loss: 0.7243
Epoch 2: Train Overall MSE: 0.0058 Validation Overall MSE: 0.0053. 
Train Top 20 DE MSE: 0.1113 Validation Top 20 DE MSE: 0.1954. 
Epoch 3 Step 1 Train Loss: 0.6627
Epoch 3 Step 51 Train Loss: 0.7356
Epoch 3 Step 101 Train Loss: 0.7539
Epoch 3: Train Overall MSE: 0.0055 Validation Overall MSE: 0.0053. 
Train Top 20 DE MSE: 0.0978 Validation Top 20 DE MSE: 0.2014. 
Done!
Start Testing...
Best performing model: Test Top 20 DE MSE: 0.2586
Start doing subgroup analysis for simulation split...
test_combo_seen0_mse: nan
test_combo_seen0_pearson: nan
test_combo_seen0_mse_de: nan
test_combo_seen0_pearson_de: nan
test_comb

In [None]:
# Predict.
predict(pert_data=pertdata, device=device, model_name=model_name,
        model_savedir=model_savedir, results_savedir=results_savedir)



Loading GEARS model.
Number of single perturbations: 103
Number of double perturbations: 130
Number of combo perturbations: 5253
Predicting double 1/130: C3orf72+FOXL2
Predicting double 2/130: ZBTB10+PTPN12
Predicting double 3/130: SGK1+S1PR2
Predicting double 4/130: RHOXF2+ZBTB25
Predicting double 5/130: KLF1+TGFBR2
Predicting double 6/130: CEBPE+SPI1
Predicting double 7/130: DUSP9+PRTG
Predicting double 8/130: ETS2+MAP7D1
Predicting double 9/130: CDKN1C+CDKN1B
Predicting double 10/130: MAP2K3+MAP2K6
Predicting double 11/130: SAMD1+UBASH3B
Predicting double 12/130: CEBPE+CEBPB
Predicting double 13/130: FOSB+CEBPB
Predicting double 14/130: DUSP9+SNAI1
Predicting double 15/130: CBL+CNN1
Predicting double 16/130: FOXA3+FOXF1
Predicting double 17/130: CBL+UBASH3B
Predicting double 18/130: ZC3HAV1+CEBPE
Predicting double 19/130: AHR+KLF1
Predicting double 20/130: BCL2L11+BAK1
Predicting double 21/130: FEV+ISL2
Predicting double 22/130: JUN+CEBPB
Predicting double 23/130: LYL1+CEBPB
Predict

In [None]:
# Evaluate.
pertdata = PertData(DATA_DIR_PATH)
pertdata.load(data_path=os.path.join(DATA_DIR_PATH, data_name))
evaluate_double(adata=pertdata.adata, model_name=model_name,
                results_savedir=results_savedir, pool_size=pool_size,
                seed=current_seed, top_deg=top_deg)

Found local copy...
Found local copy...
These perturbations are not in the GO graph and their perturbation can thus not be predicted
['LYL1+IER5L' 'IER5L+ctrl' 'KIAA1804+ctrl' 'ctrl+IER5L']
Local copy of pyg dataset is detected. Loading...
Done!


Evaluating double 1/130: C3orf72+FOXL2
MMD (true vs. control):     0.339430
MMD (true vs. predicted):   0.623679
MSE (true vs. control):     0.086623
MSE (true vs. predicted):   0.053890
KLD (true vs. control):     0.002077
KLD (true vs. predicted):   0.001739
Pearson Top 20 DEG (true vs. control): -0.066646 | p-value: 0.780115
Pearson Top 20 DEG (true vs. predicted): 0.750948 | p-value: 0.000136
Evaluating double 2/130: ZBTB10+PTPN12
MMD (true vs. control):     0.376876
MMD (true vs. predicted):   0.703073
MSE (true vs. control):     0.091089
MSE (true vs. predicted):   0.058669
KLD (true vs. control):     0.003361
KLD (true vs. predicted):   0.003322
Pearson Top 20 DEG (true vs. control): -0.133943 | p-value: 0.573443
Pearson Top 20 DEG (true vs. predicted): 0.854320 | p-value: 0.000002
Evaluating double 3/130: SGK1+S1PR2
MMD (true vs. control):     0.385643
MMD (true vs. predicted):   0.638535
MSE (true vs. control):     0.083990
MSE (true vs. predicted):   0.050543
KLD (true vs. co

In [None]:
import pandas as pd
filepath = '../cris_test/results/demo_test/gears_norman_reduced_split_simulation_seed_42_hidden_size_64_double_metrics.csv'
results = pd.read_csv(filepath)

In [None]:
results

Unnamed: 0,double,mmd_true_vs_ctrl,mmd_true_vs_pred,mse_true_vs_ctrl,mse_true_vs_pred,kld_true_vs_ctrl,kld_true_vs_pred,pearsonTop20_true_vs_ctrl,pearson_pval_true_vs_ctrl,pearsonTop20_true_vs_pred,pearson_pval_true_vs_pred
0,C3orf72+FOXL2,0.339430,0.623679,0.086623,0.053890,0.002077,0.001739,-0.066646,0.780115,0.750948,1.358936e-04
1,ZBTB10+PTPN12,0.376876,0.703073,0.091089,0.058669,0.003361,0.003322,-0.133943,0.573443,0.854320,1.633744e-06
2,SGK1+S1PR2,0.385643,0.638535,0.083990,0.050543,0.004669,0.003529,0.367951,0.110454,0.710611,4.456751e-04
3,RHOXF2+ZBTB25,0.410998,0.645421,0.085957,0.050997,0.006405,0.003855,0.215142,0.362335,0.920853,8.662332e-09
4,KLF1+TGFBR2,0.332317,0.551943,0.067557,0.036747,0.003217,0.003444,-0.250833,0.286107,0.771662,6.756299e-05
...,...,...,...,...,...,...,...,...,...,...,...
125,CEBPB+PTPN12,0.314081,0.585183,0.077841,0.045880,0.001748,0.002482,0.297266,0.203080,0.678285,1.011904e-03
126,IGDCC3+ZBTB25,0.394384,0.607104,0.083960,0.048898,0.005252,0.003518,0.272205,0.245624,0.916520,1.376997e-08
127,MAP2K3+SLC38A2,0.296137,0.498100,0.069197,0.037744,0.001132,0.001821,0.141487,0.551835,0.837527,4.088613e-06
128,CDKN1B+CDKN1A,0.371046,0.604317,0.079099,0.045504,0.003326,0.002521,0.213503,0.366091,0.928270,3.672547e-09
