# Evaluate a `PertrubNet` model with `Sci-Plex 3` data 

We use the following pre-processed files:
1. anndata file obtained in [1_perturbations_sciplex3_perturbnet-preprocess.ipynb](https://github.com/nitzanlab/biolord_reproducibility/blob/main/scripts/PerturbNet/1_perturbations_sciplex3_perturbnet-preprocess.ipynb)
2. trained scVI model using [perturbnet_scvi.py](https://github.com/nitzanlab/biolord_reproducibility/blob/main/scripts/PerturbNet/perturbnet_scvi.py),
3. trained cINN model obtained by [perturbnet_cinn.py](https://github.com/nitzanlab/biolord_reproducibility/blob/main/scripts/PerturbNet/perturbnet_cinn.py)

All PerturbNet training and evaluation follow guidelines provided in the package impplementation, [PerturbNet](https://github.com/welch-lab/PerturbNet). 

[[1] Srivatsan, S. R., McFaline-Figueroa, J. L., Ramani, V., Saunders, L., Cao, J., Packer, J., ... & Trapnell, C. (2020). Massively multiplex chemical transcriptomics at single-cell resolution. Science, 367(6473), 45-51.](https://www.science.org/doi/10.1126/science.aax6234)

[[2] Yu, H. and Welch, J.D., 2022. PerturbNet predicts single-cell responses to unseen chemical and genetic perturbations. bioRxiv, pp.2022-07.](https://doi.org/10.1101/2022.07.20.500854)


In [1]:
import sys
import os

import scanpy as sc
import scvi
import numpy as np
from scipy import sparse

from tqdm import tqdm

  self.seed = seed
  self.dl_pin_memory_gpu_training = (


In [2]:
sys.path.insert(0, "path_to_PerturbNet")
from pytorch_scvi.distributions import *
from pytorch_scvi.scvi_generate_z import *

from perturbnet.perturb.util import * 
from perturbnet.perturb.cinn.module.flow import * 
from perturbnet.perturb.chemicalvae.chemicalVAE import *
from perturbnet.perturb.cinn.module.flow_generate import SCVIZ_CheckNet2Net

In [3]:
sys.path.append("../../../")
sys.path.append("../../../utils/")
from paths import DATA_DIR, FIG_DIR

## Utility functions

In [4]:
def bool2idx(x):
    """
    Returns the indices of the True-valued entries in a boolean array `x`
    """
    return np.where(x)[0]

In [5]:
def compute_prediction_normmodel_ctrl(
    adata_ood,
    adata,
    dose_cell_onehot_ood,
    scvi_model_cinn,
    model_g,
    std_model,
    perturbnet_model,
    device="cpu",
    verbose=False
):
    pert_categories_index_ood = pd.Index(adata_ood.obs["cov_drug_dose_name"].values, dtype="category")
    
    Zsample = scvi_model_cinn.get_latent_representation(adata=adata, give_mean=False)
    Lsample = scvi_model_cinn.get_latent_library_size(adata=adata, give_mean=False)
    
    drug_r2 = {}
    drug_r2_full = {}

    normModel = NormalizedRevisionRSquare(largeCountData=adata.layers["counts"].A)

    for cell_drug_dose_comb, category_count in tqdm(
        zip(*np.unique(pert_categories_index_ood.values, return_counts=True))
    ):
        # estimate metrics only for reasonably-sized drug/cell-type combos
        if category_count <= 5:
            continue
        # doesn"t make sense to evaluate DMSO (=control) as a perturbation
        if (
            "dmso" in cell_drug_dose_comb.lower()
            or "control" in cell_drug_dose_comb.lower()
        ):
            continue


        bool_category = pert_categories_index_ood.get_loc(cell_drug_dose_comb)
        idx_all = bool2idx(bool_category)
        idx = idx_all[0]

        real_data = adata_ood[idx_all, :].layers["counts"].A

        # make predictions
        cell_line = adata_ood[idx, :].obs["cell_type"].values[0]
        trt = adata_ood[idx, :].obs["treatment"].values[0]
        dose = adata_ood[idx, :].obs["dose"].values[0]
        idx_trt = data_trt[data_trt["treatment"] == trt]["Indices"].values[0]
        celldose_onehot = dose_cell_onehot_ood[idx]

        # take ctrl cells from same cell line
        idx_base = np.argwhere((adata.obs["cell_type"] == cell_line) & (adata.obs["vehicle"] == "True")).flatten()
        adata_base = adata[idx_base].copy()
        drug_r2_full[cell_drug_dose_comb] = {}
        r2_m_arr = []
        input_ctrl_latent_base =  Zsample[idx_base]
        input_ctrl_library_base = np.log(Lsample[idx_base])
        onehot_indice_trt_other = np.tile(
                data_sciplex_onehot[idx_trt], 
                (input_ctrl_latent_base.shape[0], 1, 1)
        )

        _, _, _, embdata_torch_other = model_g(torch.tensor(onehot_indice_trt_other).float().to(device))
        trt_onehot_otherTo = std_model.standardize_z(embdata_torch_other.cpu().detach().numpy())

        celldose_onehot_other = np.tile(
            celldose_onehot,
            (len(idx_base), 1)
        )
        
        trt_onehot_otherTo = np.concatenate((trt_onehot_otherTo, celldose_onehot_other), axis=1)
            
        recon_latent, recon_data = perturbnet_model.recon_data(
                input_ctrl_latent_base, 
                trt_onehot_otherTo,
                input_ctrl_library_base
            )

        r2_m, _, _ = normModel.calculate_r_square(real_data, recon_data)
        drug_r2[cell_drug_dose_comb] = r2_m
        
    return drug_r2

In [6]:
def compute_prediction_normmodel(
    adata_ood,
    adata_ref,
    adata,
    dose_cell_onehot_ood,
    scvi_model_cinn,
    model_g,
    std_model,
    perturbnet_model,
    device="cpu",
    verbose=False
):
    pert_categories_index_ood = pd.Index(adata_ood.obs["cov_drug_dose_name"].values, dtype="category")
    
    Zsample = scvi_model_cinn.get_latent_representation(adata=adata_ref, give_mean=False)
    Lsample = scvi_model_cinn.get_latent_library_size(adata=adata_ref, give_mean=False)
    
    drug_r2 = {}
    drug_r2_full = {}

    normModel = NormalizedRevisionRSquare(largeCountData=adata.layers["counts"].A)

    for cell_drug_dose_comb, category_count in tqdm(
        zip(*np.unique(pert_categories_index_ood.values, return_counts=True))
    ):
        # estimate metrics only for reasonably-sized drug/cell-type combos
        if category_count <= 5:
            continue
        # doesn"t make sense to evaluate DMSO (=control) as a perturbation
        if (
            "dmso" in cell_drug_dose_comb.lower()
            or "control" in cell_drug_dose_comb.lower()
        ):
            continue


        bool_category = pert_categories_index_ood.get_loc(cell_drug_dose_comb)
        idx_all = bool2idx(bool_category)
        idx = idx_all[0]

        real_data = adata_ood[idx_all, :].layers["counts"].A

        # make predictions
        cell_line = adata_ood[idx, :].obs["cell_type"].values[0]
        trt = adata_ood[idx, :].obs["treatment"].values[0]
        dose = adata_ood[idx, :].obs["dose"].values[0]
        idx_trt = data_trt[data_trt["treatment"] == trt]["Indices"].values[0]
        celldose_onehot = dose_cell_onehot_ood[idx]

        adata_cmp = adata_ref[(adata_ref.obs["cell_type"] == cell_line) & (adata_ref.obs["dose"] == dose)]
        drug_r2_full[cell_drug_dose_comb] = {}
        r2_m_arr = []
        for trt_base in adata_cmp.obs["treatment"].cat.categories:
            idx_trt_base = data_trt[data_trt["treatment"] == trt_base]["Indices"].values[0]
            idx_trt_type_base = np.where(adata_cmp.obs["treatment"] == trt_base)[0]
            
            onehot_indice_trt_base = np.tile(
                data_sciplex_onehot[idx_trt_base],
                (len(idx_trt_type_base), 1, 1)
            )
            
            _, _, _, embdata_torch_base = model_g(torch.tensor(onehot_indice_trt_base).float().to(device))
            
            input_trt_latent_base, trt_onehot_base = Zsample[idx_trt_type_base], std_model.standardize_z(
                embdata_torch_base.cpu().detach().numpy()
            )
            celldose_onehot_base = np.tile(
                celldose_onehot,
                (len(idx_trt_type_base), 1)
            )
			
            trt_onehot_base = np.concatenate((trt_onehot_base, celldose_onehot_base), axis=1)
            input_trt_library_base = np.log(Lsample[idx_trt_type_base])
            
            onehot_indice_trt_other = np.tile(
                data_sciplex_onehot[idx_trt], 
                (input_trt_latent_base.shape[0], 1, 1)
            )
            
            _, _, _, embdata_torch_other = model_g(torch.tensor(onehot_indice_trt_other).float().to(device))
            trt_onehot_otherTo = std_model.standardize_z(embdata_torch_other.cpu().detach().numpy())
			
            trt_onehot_otherTo = np.concatenate((trt_onehot_otherTo, celldose_onehot_base), axis=1)
            
            recon_latent, recon_data = perturbnet_model.trans_data(
                input_trt_latent_base, 
                trt_onehot_base,
                trt_onehot_otherTo,
                input_trt_library_base
            )
        
            r2_m, _, _ = normModel.calculate_r_square(real_data, recon_data)
        
            drug_r2_full[cell_drug_dose_comb][trt_base] = r2_m
            r2_m_arr.append(r2_m)
            
        drug_r2[cell_drug_dose_comb] = np.mean(r2_m_arr)
    return drug_r2, drug_r2_full

In [7]:
def create_df_max(res):
    
    df = pd.DataFrame.from_dict(res, orient="index", columns=["r2"])

    df["r"] = df["r2"].apply(lambda x: max(x,0))
    df["cell_line"] = pd.Series(df.index.values).apply(lambda x: x.split("_")[0]).values
    df["drug"] = pd.Series(df.index.values).apply(lambda x: x.split("_")[1]).values
    df["dose"] = pd.Series(df.index.values).apply(lambda x: x.split("_")[2]).values
    df["dose"] = df["dose"].astype(float)

    df["combination"] = df.index.values
    df = df.reset_index()
    return df

## Set parameters

In [8]:
DATA_DIR_LCL = str(DATA_DIR) + "/perturbations/sciplex3/"

## files
path_chemvae_model = DATA_DIR_LCL + 'models/chemvae/model_params.pt'
path_scvi_model = DATA_DIR_LCL + "models/scvi"

path_sciplex_onehot = DATA_DIR_LCL + 'OnehotData_188.npy'
path_chem_onehot = DATA_DIR_LCL + 'OnehotData_ZINC.npy'

path_cinn_model_save = DATA_DIR_LCL + "models/cinn_cov"

In [9]:
adata_ref =  sc.read(
    DATA_DIR_LCL + "sciplex3_biolord.h5ad",
    backup_url="https://figshare.com/ndownloader/files/39324305",
)

In [10]:
ref_obs = pd.Index([obs.split("-")[0] for obs in adata_ref.obs_names])
adata_ref.obs_names = ref_obs

## Load data

In [11]:
adata_orig = sc.read(os.path.join(DATA_DIR_LCL, 'sciPlex3_whole_filtered_NormBYHighGenes_processed.h5ad')) 
adata = adata_orig[adata_orig.obs_names.isin(ref_obs)].copy()
adata.obs["split_ood"] = adata_ref.obs.loc[adata.obs_names ,"split_ood"]
adata.obs["cov_drug_dose_name"] = adata_ref.obs.loc[adata.obs_names ,"cov_drug_dose_name"]


In [12]:
## remove 9 unseene drugs (using reference adata)
input_ltpm_label = adata.obs.copy()
kept_indices = list(np.where((input_ltpm_label["split_ood"] != "ood") & (input_ltpm_label["treatment"] != "S0000"))[0])


In [13]:
adata_train = adata[kept_indices, :].copy()

In [14]:
## onehot
data_sciplex_onehot = np.load(path_sciplex_onehot)
data_chem_onehot = np.load(path_chem_onehot)

# remove
input_ltpm_label1 = input_ltpm_label.iloc[kept_indices, :]
input_ltpm_label1.index = list(range(input_ltpm_label1.shape[0]))

## meta information
perturb_with_onehot_overall = np.array(list(input_ltpm_label['treatment']))
perturb_with_onehot_kept = perturb_with_onehot_overall[kept_indices]

In [15]:
data_trt = pd.read_csv(DATA_DIR_LCL + 'emb_named_chemvae_canonize.csv')
data_trt['Indices'] = list(range(data_trt.shape[0]))

cell_embdata = input_ltpm_label1.loc[:, ['treatment']].merge(data_trt, how = 'left', on = 'treatment')
indices_onehot = list(cell_embdata['Indices'])

data_sciplexKept_onehot = data_sciplex_onehot[indices_onehot]

## Load models

In [16]:
scvi.model.SCVI.setup_anndata(adata_train, layer = "counts")
scvi_model_cinn = scvi.model.SCVI.load(path_scvi_model, adata_train, use_gpu=False)
scvi_model_de = scvi_predictive_z(scvi_model_cinn)

[34mINFO    [0m File                                                                                                      
         [35m/cs/labs/mornitzan/zoe.piran/research/projects/biolord_data/data/perturbation-celltype/models/scvi/[0m[95mmodel.p[0m
         [95mt[0m already downloaded                                                                                      


  _, _, device = parse_device_args(


In [17]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

## ChemicalVAE
model_chemvae = ChemicalVAE(n_char = data_chem_onehot.shape[2], max_len = data_chem_onehot.shape[1]).to(device)
model_chemvae.load_state_dict(torch.load(path_chemvae_model, map_location = device))
model_chemvae.eval()

ChemicalVAE(
  (conv_1): Conv1d(120, 9, kernel_size=(9,), stride=(1,))
  (conv_2): Conv1d(9, 9, kernel_size=(9,), stride=(1,))
  (conv_3): Conv1d(9, 10, kernel_size=(11,), stride=(1,))
  (bnConv1): BatchNorm1d(9, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bnConv2): BatchNorm1d(9, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bnConv3): BatchNorm1d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (linear_0): Linear(in_features=90, out_features=196, bias=True)
  (linear_1): Linear(in_features=196, out_features=196, bias=True)
  (linear_2): Linear(in_features=196, out_features=196, bias=True)
  (dropout1): Dropout(p=0.08283292970479479, inplace=False)
  (bn1): BatchNorm1d(196, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (dropout2): Dropout(p=0.08283292970479479, inplace=False)
  (bn2): BatchNorm1d(196, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (linear_3): Linear(in_features=196, 

### i) estimate latent means and stds

In [18]:
std_model = Standardize(data_all = data_chem_onehot, model = model_chemvae, device = device)

# cell type and dose covariates
dose_pd = pd.get_dummies(list(adata.obs['dose'].astype(int).astype(str)))
dose_onehot_data = dose_pd.values.astype('float64')

cell_type_pd = pd.get_dummies(list(adata.obs['cell_type'].astype(str)))
cell_onehot_data = cell_type_pd.values.astype('float64')

dose_cell_onehot = np.concatenate((dose_onehot_data, cell_onehot_data), axis=1)

if sparse.issparse(adata.X):
    usedata = adata.X.A
else:
    usedata = adata.X

if sparse.issparse(adata.layers['counts']):
    usedata_count = adata.layers['counts'].A
else:
    usedata_count = adata.layers['counts']


flow_model = ConditionalFlatCouplingFlow(
    conditioning_dim=204, # extra 7 columns from cell type and dose
											 # condition dimensions
    embedding_dim=10,
    conditioning_depth=2,
    n_flows=20,
    in_channels=10,
    hidden_dim=1024,
    hidden_depth=2,
    activation="none",
    conditioner_use_bn=True
)

model_c = Net2NetFlow_scVIChemStdStatesFlow(
    configured_flow = flow_model,
    first_stage_data = usedata_count[kept_indices], 
    cond_stage_data = data_sciplexKept_onehot, 
    model_con = model_chemvae, 
    scvi_model = scvi_model_cinn, 
    std_model = std_model,
    cell_type_onehot = cell_onehot_data[kept_indices],
    dose_onehot = dose_onehot_data[kept_indices]
)


model_c.to(device = device)

Note: Conditioning network uses batch-normalization. Make sure to train with a sufficiently large batch size


Net2NetFlow_scVIChemStdStatesFlow(
  (flow): ConditionalFlatCouplingFlow(
    (embedder): BasicFullyConnectedNet(
      (main): Sequential(
        (0): Linear(in_features=204, out_features=256, bias=True)
        (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): LeakyReLU(negative_slope=0.01)
        (3): Linear(in_features=256, out_features=256, bias=True)
        (4): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (5): LeakyReLU(negative_slope=0.01)
        (6): Linear(in_features=256, out_features=256, bias=True)
        (7): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (8): LeakyReLU(negative_slope=0.01)
        (9): Linear(in_features=256, out_features=10, bias=True)
      )
    )
    (sub_layers): ModuleList(
      (0-19): 20 x ConditionalFlatDoubleCouplingFlowBlock(
        (norm_layer): ActNorm()
        (coupling): ConditionalDoubleVectorCoupl

### ii) Load the trained model

In [19]:
# (2) evaluation
model_c.train(n_epochs = 1, batch_size = 128, lr = 4.5e-6)
model_c.load(path_cinn_model_save)

[Epoch 1/1] [Batch 515/515] [loss: 3.590674/3.194298]


## Evaluate

In [20]:
model_c.eval()

model_g = model_c.model_con
model_g.eval()

ChemicalVAE(
  (conv_1): Conv1d(120, 9, kernel_size=(9,), stride=(1,))
  (conv_2): Conv1d(9, 9, kernel_size=(9,), stride=(1,))
  (conv_3): Conv1d(9, 10, kernel_size=(11,), stride=(1,))
  (bnConv1): BatchNorm1d(9, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bnConv2): BatchNorm1d(9, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bnConv3): BatchNorm1d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (linear_0): Linear(in_features=90, out_features=196, bias=True)
  (linear_1): Linear(in_features=196, out_features=196, bias=True)
  (linear_2): Linear(in_features=196, out_features=196, bias=True)
  (dropout1): Dropout(p=0.08283292970479479, inplace=False)
  (bn1): BatchNorm1d(196, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (dropout2): Dropout(p=0.08283292970479479, inplace=False)
  (bn2): BatchNorm1d(196, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (linear_3): Linear(in_features=196, 

In [21]:
perturbnet_model = SCVIZ_CheckNet2Net(model_c, device, scvi_model_de)

In [22]:
idx_ood = np.where((adata.obs["split_ood"] == "ood"))[0]
adata_ood = adata[idx_ood].copy()
dose_cell_onehot_ood = dose_cell_onehot[idx_ood]

In [23]:
idx_test_trts = np.where(
    (adata.obs["split_ood"] == "test") & (adata.obs["vehicle"] == "False")
)[0]

adata_test_trt = adata[idx_test_trts].copy()


In [24]:
idx_train_trts = np.where(
    (adata.obs["split_ood"] == "train") & (adata.obs["vehicle"] == "False")
)[0]

adata_train_trt = adata[idx_train_trts].copy()

In [25]:
drug_r2_ctrl = compute_prediction_normmodel_ctrl(
    adata_ood=adata_ood,
    adata=adata,
    dose_cell_onehot_ood=dose_cell_onehot_ood,
    scvi_model_cinn=scvi_model_cinn,
    model_g=model_g,
    std_model=std_model,
    perturbnet_model=perturbnet_model,
    device=device,
    verbose=False
)

[34mINFO    [0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             


108it [05:04,  2.82s/it]


In [26]:
df_all_ctrl = create_df_max(drug_r2_ctrl)
mean_all_ctrl = df_all_ctrl.groupby(by=["dose"]).mean("r2_de").reset_index()
median_all_ctrl = df_all_ctrl.groupby(by=["dose"]).median("r2_de").reset_index()

In [27]:
mean_all_ctrl

Unnamed: 0,dose,r2,r
0,0.001,0.61322,0.61322
1,0.01,0.511364,0.511364
2,0.1,0.400939,0.400939
3,1.0,0.224406,0.224406


In [28]:
drug_r2, drug_r2_full = compute_prediction_normmodel(
    adata_ood=adata_ood,
    adata_ref=adata_train_trt,
    adata=adata,
    dose_cell_onehot_ood=dose_cell_onehot_ood,
    scvi_model_cinn=scvi_model_cinn,
    model_g=model_g,
    std_model=std_model,
    perturbnet_model=perturbnet_model,
    device=device,
    verbose=False
)

[34mINFO    [0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             


108it [40:31, 22.51s/it]


In [29]:
drug_r2_max = {}
for comb in drug_r2_full:
    drug_r2_max[comb] = np.max(list(drug_r2_full[comb].values()))

In [30]:
df_all_max = create_df_max(drug_r2_max)
mean_all_max = df_all_max.groupby(by=["dose"]).mean("r2_de").reset_index()
median_all_max = df_all_max.groupby(by=["dose"]).median("r2_de").reset_index()

In [31]:
mean_all_max

Unnamed: 0,dose,r2,r
0,0.001,0.187643,0.187643
1,0.01,0.165635,0.165635
2,0.1,0.154935,0.154935
3,1.0,0.101398,0.101398
