# LGEM demo

In [1]:
# imports
import torch
import torch.nn as nn
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import pandas as pd
from typing import List, Tuple
from scipy.stats import pearsonr
from sklearn.decomposition import PCA

import os

from gears import PertData

# own imports
from lgem.models import (
    LinearGeneExpressionModelLearned,
    LinearGeneExpressionModelOptimized,
)
from lgem.test import test as test_lgem
from lgem.train import train as train_lgem
from lgem.data import pseudobulk_data_per_perturbation
from data_utils.single_norman_utils import separate_data, get_common_genes

from lgem.utils import predict_evaluate_lgem_double

#### Some relevant code to see

In [2]:
def compute_embeddings_double(
    Y: torch.Tensor,  # noqa: N803
    perts: List[str], # all perturbations, single and double
    genes: List[str],
    d_embed: int = 10,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Compute gene and perturbation embeddings.

    Args:
        Y: Data matrix with shape (n_genes, n_perturbations).
        perts: List of perturbations.
        genes: List of genes.
        d_embed: Embedding dimension.

    Returns:
        G: Gene embedding matrix with shape (n_genes, d_embed).
        P: Perturbation embedding matrix with shape (n_perturbations, d_embed).
        b: Bias vector with shape (n_genes).
    """
    # Perform a PCA on Y to obtain the top d_embed principal components, which will
    # serve as the gene embeddings G.
    pca = PCA(n_components=d_embed)
    G = pca.fit_transform(Y)  # noqa: N806

    gene_to_idx = {gene: i for i, gene in enumerate(genes)}
    gene_to_emb = {gene: G[i] for gene, i in gene_to_idx.items()}

    P = []
    missing = []
    for pert in perts:
        genes_in_pert = pert.split("+")
        try:
            emb_list = [gene_to_emb[g] for g in genes_in_pert]
            pert_emb = np.mean(emb_list, axis=0)  # average embeddings
            P.append(pert_emb)
        except KeyError as e:
            missing.append(pert)
            P.append(np.zeros(d_embed))  # fallback if gene not found

    P = np.array(P)
    if missing:
        print(f"{len(missing)}/{len(perts)} missing embeddings.")
        print(f"Missing embeddings for perturbations: {missing}")

    # Compute b as the average expression of each gene across all perturbations.
    b = Y.mean(axis=1)

    return torch.from_numpy(G).float(), torch.from_numpy(P).float(), b

## LGEM

In [9]:
# Parameters
criterion = nn.MSELoss()
name = 'demo_test' # Name of project
savedir='/workspace/tfm/cris_test/models/'
eval_dir = '/workspace/tfm/cris_test/results/'
data_dir = '/workspace/tfm/cris_test/data'

eval_dir = os.path.join(eval_dir, name)
os.makedirs(eval_dir, exist_ok=True)
savedir = os.path.join(savedir, name)
os.makedirs(savedir, exist_ok=True)

dataset_name='norman_reduced'
prediction_type='double'
seed=1
num_runs=1
n_epochs=200
batch_size=8
top_deg=20 # Top Differentially expressed genes to evalaute from

run_name = f"lgem_{dataset_name}_epochs_{n_epochs}_batch_{batch_size}"
device = torch.device("cuda" if torch.cuda.is_available else "cpu")

torch.manual_seed(seed)

# Make results file path.
results_file_path = os.path.join(
    eval_dir, f"{run_name}_train_test_metrics.csv"
    )

In [12]:
# Load data
print(f"Loading dataset from {data_dir}.")
datahandler = PertData(data_dir)
datahandler.load(data_path = os.path.join(data_dir, dataset_name)) # load the processed data, the path is saved folder + dataset_name

pertdata = datahandler.adata
# pertdata.prepare_split(split = 'simulation', seed = 42) # get data split with seed
# pertdata.get_dataloader(batch_size = 32, test_batch_size = 128) # prepare data loader

Found local copy...


Loading dataset from /workspace/tfm/cris_test/data.


Found local copy...
These perturbations are not in the GO graph and their perturbation can thus not be predicted
['LYL1+IER5L' 'IER5L+ctrl' 'KIAA1804+ctrl' 'ctrl+IER5L']
Local copy of pyg dataset is detected. Loading...
Done!


In [5]:
pertdata

View of AnnData object with n_obs × n_vars = 11700 × 5000
    obs: 'guide_identity', 'read_count', 'UMI_count', 'gemgroup', 'good_coverage', 'number_of_cells', 'guide_ids', 'guide_merged', 'split', 'batch', 'condition', 'cell_type', 'dose_val', 'control', 'drug_dose_name', 'cov_drug_dose_name', 'condition_name', 'condition_fixed'
    var: 'gene_symbols', 'highly_variable', 'means', 'dispersions', 'dispersions_norm', 'gene_name'
    uns: 'non_dropout_gene_idx', 'non_zeros_gene_idx', 'rank_genes_groups', 'rank_genes_groups_cov', 'rank_genes_groups_cov_all', 'top_non_dropout_de_20', 'top_non_zero_de_20'
    layers: 'counts'

In [13]:
# Get separate dataset for single, double perts and controls
# Get resulting RNA-seq dataset with valid samples (sample pert genes are genes found in features)
pertdata_single, pertdata_double, pertdata_ctrl = separate_data(adata = pertdata, dataset_name = dataset_name)

if prediction_type == "double":
    # Join both AnnData datasets
    pertdata_both = pertdata_single.concatenate(pertdata_double, join = 'outer', index_unique = '-')
    all_perts, perts, genes, pertdata_common = get_common_genes(adata = pertdata_both, dataset_name = dataset_name)
else:
    # Y, perts and genes
    all_perts, perts, genes, pertdata_common = get_common_genes(adata_single = pertdata_single, dataset_name = dataset_name)    

print(f"Number of unique perturbations: {len(perts)}/{len(all_perts)}")


Number of unique perturbations: 125/11650


- lgem works with individual samples for each condition
- Ordered perturbations equivalent to Y is in perts

This model:
- Separates single gene perturbation samples from double perturbation samples
- single perts: split into train and validation
- double perts: all of them used as test

### Training lgem

In [14]:
Y = pseudobulk_data_per_perturbation(perts, genes, pertdata_common)

# Compute the embeddings on the entire dataset (with singles and doubles)
G, P, b = compute_embeddings_double(Y.T, perts, genes)  # noqa: N806

# Keeping only embedding related to single perturbations for training
singles_idx = [i for i, pert in enumerate(perts) if "+" not in pert]
sY = Y[singles_idx, :]
sP = P[singles_idx, :]

with open(file=results_file_path, mode="w") as f:
    print(
        "seed,optimized_train_loss,optimized_model_loss,learned_train_loss,learned_model_loss",
        file=f,
    )

    # Setting up several runs (for this example, only 1)
    for current_run in range(num_runs):
        current_seed = seed + current_run
        torch.manual_seed(current_seed)
        model_name = f"lgem_{dataset_name}_seed_{current_seed}_epoch_{n_epochs}_batch_{batch_size}"

        # Directory for custom name for model where more pickle files will be saved
        # savedir = args.savedir
        current_savedir = os.path.join(savedir, model_name)
        os.makedirs(current_savedir, exist_ok=True)


# If prediction type is not double, splitter only splits into train and test
# Y, P used
        if prediction_type != "double":
            # Split the data into training and test sets and create the dataloaders.
            Y_train, Y_test, P_train, P_test = train_test_split(  # noqa: N806
                Y, P, test_size=0.2, random_state=current_seed
            )
            # Get indices of perturbations in the train/test sets
            train_indices, test_indices = train_test_split(range(len(perts)), test_size=0.2, random_state=42)
            # Get the equivalent perturbations for the training and test sets
            perts_train = [perts[i] for i in train_indices]
            perts_test = [perts[i] for i in test_indices]

            # Dataloaders
            train_dataset = TensorDataset(P_train, Y_train)
            test_dataset = TensorDataset(P_test, Y_test)
            train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
            test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

            # # Saving dataloaders as pickles and perturbation order from perts
            # torch.save(train_dataloader, os.path.join(current_savedir, "train_dataloader.pt"))
            # torch.save(test_dataloader, os.path.join(current_savedir, "train_dataloader.pt"))
            # torch.save({"perts_train": perts_train,
            #             "perts_test": perts_test},
            #             os.path.join(current_savedir, "perts.pt"))
            # val_dataloader = None

# If prediction is DOUBLE, train and valdation are taken from single pert sampels
# sY, sP (previously separated) are used
        else:
            # Keep train and val dataloaders for training
            Y_train, Y_val, P_train, P_val = train_test_split(  # noqa: N806
                sY, sP, test_size=0.2, random_state=current_seed
            )

            # Dataloaders
            train_dataset = TensorDataset(P_train, Y_train)
            val_dataset = TensorDataset(P_val, Y_val)
            train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
            val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)


            # Test
            # Get indices of perturbations in the train/test set
            doubles_idx = [i for i, pert in enumerate(perts) if "+" in pert]
            doubles_dataset = TensorDataset(P[doubles_idx, :], Y[doubles_idx, :])
            test_dataloader = DataLoader(doubles_dataset, batch_size=batch_size, shuffle=False)

# Save dataloaders as pickles and perturbation order from perts
            # torch.save(train_dataloader, os.path.join(current_savedir, "train_dataloader.pt"))
            # torch.save(val_dataloader, os.path.join(current_savedir, "val_dataloader.pt"))
            # torch.save(test_dataloader, os.path.join(current_savedir, "test_dataloader.pt")) # equivalent to doubles only
            # torch.save({"perts": perts}, os.path.join(current_savedir, "perts.pt"))

        # Fit the OPTIMIZED model to the training data.
        model_optimized = LinearGeneExpressionModelOptimized(Y_train.T, G, P_train, b)
        train_loss_op = test_lgem(model_optimized, criterion, train_dataloader, device)
        print(f"Train loss (optimized model) | Run {current_run+1}: {train_loss_op:.4f}")
        test_loss_op = 0

        if prediction_type != "double":
            # Test the optimized model.
            test_loss_op = test_lgem(model_optimized, criterion, test_dataloader, device)
            print(f"Val loss (optimized model): {test_loss_op:.4f}")
        else:
            # Test the optimized model on validation set
            test_loss_op = test_lgem(model_optimized, criterion, val_dataloader, device)
            validation=True
            print(f"Val loss (optimized model): {test_loss_op:.4f}")

        # Fit the LEARNED model to the training data.
        model_learned = LinearGeneExpressionModelLearned(G, b)
        optimizer = torch.optim.Adam(params=model_learned.parameters(), lr=1e-3)
        model_learned = train_lgem(
            model_learned, criterion, optimizer, train_dataloader, val_dataloader, n_epochs, device, validation=validation
        )
        train_loss_learn = test_lgem(model_optimized, criterion, val_dataloader, device)
        print(f"Train loss (optimized model) | Run {current_run+1}: {train_loss_op:.4f}")
        test_loss_learn = 0

        if prediction_type != "double":
            # Test the learned model.
            test_loss_learn = test_lgem(model_learned, criterion, test_dataloader, device)
            print("Val loss (learned model): {test_loss_learn:.4f}")  
        else:
            test_loss_learn = test_lgem(model_learned, criterion, val_dataloader, device)
            print(f"Val loss (learned model): {test_loss_learn:.4f}")

        # Save results to file
        print(f"{current_seed},{train_loss_op}, {test_loss_op},{train_loss_learn}, {test_loss_learn}",
            file=f,
        )

        # Save models and embeddings
        # torch.save(model_optimized.state_dict(), os.path.join(current_savedir, "optimized_best_model.pt"))
        # torch.save(model_learned.state_dict(), os.path.join(current_savedir, "learned_best_model.pt"))
        # torch.save(G, os.path.join(current_savedir, "G.pt"))
        # torch.save(b, os.path.join(current_savedir, "b.pt"))
        # torch.save(P_train, os.path.join(current_savedir, "P.pt"))
        # torch.save(Y_train, os.path.join(current_savedir, "Y.pt"))
        # print(f"Saved models and embedding at {current_savedir}")


Pseudobulking:   0%|                                                                                               | 0/125 [00:00<?, ?perturbation/s]

Pseudobulking: 100%|████████████████████████████████████████████████████████████████████████████████████| 125/125 [00:01<00:00, 111.73perturbation/s]


Train loss (optimized model) | Run 1: 0.0030
Val loss (optimized model): 0.0037


Training: 100%|████████████████████████████████████████████████| 200/200 [04:07<00:00,  1.24s/epoch, Training Loss: 1.2132 | Validation Loss: 6.0498]


Train loss (optimized model) | Run 1: 0.0030
Val loss (learned model): 6.0498


### Testing on doubles

In [15]:
# Variables and parameters
# In script it's usually loaded from previous configurations and saved files

# Trained model path
current_seed = seed
model_name = f"lgem_{dataset_name}_seed_{current_seed}_epoch_{n_epochs}_batch_{batch_size}"
current_savedir = os.path.join(savedir, model_name)

# Perturbation list
perturbation_list = perts

# Embeddings
G = G
b = b
P = P_train
Y = Y_train

# # Load dataset (dataloader)
# test_dataloader = torch.load(os.path.join(savedir, "test_dataloader.pt"))
# perturbation_list = torch.load(os.path.join(savedir, "perts.pt"))
# perturbation_list = perturbation_list["perts"]

# Load Embeddings
# G = torch.load(os.path.join(savedir, "G.pt"))
# P = torch.load(os.path.join(savedir, "P.pt"))
# Y = torch.load(os.path.join(savedir, "Y.pt"))
# b = torch.load(os.path.join(savedir, "b.pt"))

In [17]:
# Optimized model
model_optimized = LinearGeneExpressionModelOptimized(Y.T, G, P, b)
# model_optimized.load_state_dict(torch.load(os.path.join(savedir, "optimized_best_model.pt")))
double_perts_list_op, double_predictions_op, ground_truth, mse_pred_op = predict_evaluate_lgem_double(model_optimized, device, test_dataloader, perturbation_list)


Predicting and calculating loss for double perturbations.


In [18]:
# Learned model
model_learned = LinearGeneExpressionModelLearned(G, b)
# model_learned.load_state_dict(torch.load(os.path.join(savedir, "learned_best_model.pt")))
_, double_predictions_learn, _, mse_pred_learn= predict_evaluate_lgem_double(model_learned, device, test_dataloader, perturbation_list)


Predicting and calculating loss for double perturbations.


In [47]:
# Randomly chosen control cells for baseline
rand_idx = np.random.randint(low=0, high=pertdata_ctrl.X.shape[0], size=len(double_perts_list_op))
baseline_control = pertdata_ctrl.X[rand_idx, :].toarray()

# Turning profiles into arrays
double_predictions_op = np.asarray(double_predictions_op)
double_predictions_learn = np.asarray(double_predictions_learn)
gt = np.asarray(ground_truth)
baseline_control = np.asarray(baseline_control) # Redundant I tink but it broke


# Calculate true - pred MSE
mse_control_op = np.mean((gt - baseline_control) ** 2, axis = 1)
mse_control_learn = np.mean((gt - baseline_control) ** 2, axis = 1)
print("MSE calculation done.")


MSE calculation done.


In [52]:
# Calculate Pearson correlation
gt_deg = gt - baseline_control
deg_idx = np.argsort(abs(gt_deg), axis=1)[:, -top_deg:]

# Select values along the top DEG indices for each sample
pred_op_selected = np.take_along_axis(double_predictions_op - baseline_control, deg_idx, axis=1)
pred_learn_selected = np.take_along_axis(double_predictions_learn - baseline_control, deg_idx, axis=1)
gt_selected = np.take_along_axis(gt_deg, deg_idx, axis=1)


pearson_op = np.array([pearsonr(pred_op_selected[i], gt_selected[i])
                       for i in range(pred_op_selected.shape[0])
                       ])
pearson_learn = np.array([pearsonr(pred_learn_selected[i], gt_selected[i])
                       for i in range(pred_learn_selected.shape[0])
                       ])

print("Pearson calculation done")

Pearson calculation done


In [55]:
# Save metrics to result dir
result_df = pd.DataFrame({"double": double_perts_list_op,
                        "mse_true_vs_control_op": mse_control_op,
                        "mse_true_vs_control_learn": mse_control_learn,
                        "mse_true_vs_pred_op": mse_pred_op,
                        "mse_true_vs_pred_learn": mse_pred_learn,
                        "pearson_op": pearson_op[:, 0],
                        "pearson_op_pvalue": pearson_op[:, 1],
                        "pearson_learn": pearson_learn[:, 0],
                        "pearson_learn_pvalue": pearson_learn[:, 1]})

double_pred_op = pd.DataFrame(double_predictions_op, columns=pertdata_ctrl.var_names)
double_pred_op.insert(0, 'double', double_perts_list_op)
double_pred_learn = pd.DataFrame(double_predictions_learn, columns=pertdata_ctrl.var_names)
double_pred_learn.insert(0, 'double', double_perts_list_op)


# double_pred_op.to_csv(os.path.join(args.eval_dir, f"{model_name}_double_predictions_op.csv"), index=False)
# double_pred_learn.to_csv(os.path.join(args.eval_dir, f"{model_name}_double_predictions_learn.csv"), index=False)
# result_df.to_csv(os.path.join(args.eval_dir, f"{model_name}_double_metrics.csv"), index=False)
# print(f"Results saved to {os.path.join(args.eval_dir, f'{model_name}_double_metrics.csv')}")

In [56]:
result_df

Unnamed: 0,double,mse_true_vs_control_op,mse_true_vs_control_learn,mse_true_vs_pred_op,mse_true_vs_pred_learn,pearson_op,pearson_op_pvalue,pearson_learn,pearson_learn_pvalue
0,CEBPE+RUNX1T1,0.027876,0.027876,0.002543,11.850765,0.963878,8.721589e-12,0.214407,0.364017
1,AHR+FEV,0.095165,0.095165,0.013285,11.573503,0.968748,2.411408e-12,0.717365,0.000370
2,FOXA1+HOXB9,0.037915,0.037915,0.005112,11.186444,0.980608,3.433014e-14,-0.043954,0.854015
3,ETS2+MAP7D1,0.020817,0.020817,0.004105,11.542901,0.969774,1.792068e-12,0.154671,0.514978
4,FOXA3+FOXF1,0.042748,0.042748,0.002878,11.095216,0.986575,1.281330e-15,0.174350,0.462234
...,...,...,...,...,...,...,...,...,...
61,OSR2+PTPN12,0.054968,0.054968,0.001859,9.575863,0.996485,7.685355e-21,0.519939,0.018780
62,FOSB+UBASH3B,0.034607,0.034607,0.001474,23.589426,0.995367,9.187693e-20,-0.148715,0.531483
63,ETS2+IGDCC3,0.025769,0.025769,0.002357,3.367628,0.988671,2.802764e-16,0.716451,0.000380
64,TBX2+TBX3,0.033200,0.033200,0.005162,11.857831,0.829380,6.153504e-06,0.324285,0.163045


For lgem, the following results are saved:
- Mean predictions for learned model
- Mean predictions for optimised model
- Metrics for learned model
- Metrics for optimised model