# A tour of the ComPert model

In [1]:
import sys
sys.path.append("../")

In [2]:
# some standard packages to assist this tutorial
from pprint import pprint
import compert as cpa
import pandas as pd


In [3]:
# from compert.train import train_compert
# from compert.data import load_dataset_splits
# from compert.plotting import CompertVisuals
# from compert.api import ComPertAPI


# Training your model

IMPORTANT. Currenlty because of the standartized evaluation procedure, we need to provide adata.obs['control'] (0 if not control, 1 for cells to use as control). And we also need to provide de_genes in .uns['rank_genes_groups']. This genes corresond to the groups of drugs in adata.obs['drug_dose_name']. adata.obs['drug_dose_name'] is a neccessary field to standartized evaluation even for the datasets that don't have doses.

### Init your model

In [4]:
dataset_path = '/home/anna/cpa_binaries/datasets/GSM_new.h5ad'
compert_api = cpa.api.ComPertAPI(
    dataset_path, 
    pretrained=None,
    perturbation_key='condition',
    split_key='split',
    covariate_keys=['cell_type'],
#     pretrained='/home/anna/cpa_binaries/pretrained_models/GSM/sweep_GSM_new_logsigm_model_seed=60_epoch=1120.pt',
    hparams={},
    device='cuda'
)

### Start training

In [5]:
compert_api.train(max_epochs=1)

Rec: -0.7740, AdvPert: 0.67, AdvCov: 0.00:   0%|          | 0/1 [00:04<?, ?it/s]

{"model_saved": "./model.pt"}
{'ellapsed_minutes': 0.07614408334096273,
 'epoch': 0,
 'evaluation_stats': {'covariate disentanglement': [0.0],
                      'ood': [0.7870016772472433,
                              0.29051685915402303,
                              -0.38455639914283113,
                              -14.998942293197016],
                      'optimal for covariates': [1.0],
                      'optimal for perturbations': 0.2,
                      'perturbation disentanglement': 0.7174778255356932,
                      'test': [0.8104991585013892,
                               0.412436435788468,
                               -0.6482918258551059,
                               -15.159100742885807],
                      'training': [0.8180323798801574,
                                   0.4184610634785806,
                                   -0.5210729513080057,
                                   -14.02289937656397]},
 'training_stats': defaultdict(<class 

Rec: -0.7740, AdvPert: 0.67, AdvCov: 0.00:   0%|          | 0/1 [00:47<?, ?it/s]


{"model_saved": "./model.pt"}


### Visualize training history

In [6]:
# from compert.plotting import ComPertHistory
# pretty_history = ComPertHistory(compert_api.model.history)
# pretty_history.print_time()
# pretty_history.plot_losses()
# pretty_history.plot_metrics(epoch_min=0)

In [7]:
# from compert.train import evaluate

# ComPert API for compatibility with scanpy

Print and plot drug embeddings.

In [8]:
perts_anndata = compert_api.get_drug_embeddings()
perts_anndata

AnnData object with n_obs × n_vars = 5 × 256
    obs: 'condition'

Print and plot covars embeddings.

In [9]:
covars_anndata = compert_api.get_covars_embeddings('cell_type')
covars_anndata

AnnData object with n_obs × n_vars = 1 × 256
    obs: 'cell_type'

In [10]:
compert_api.num_measured_points['training']

{'A549_BMS_0.001': 442,
 'A549_BMS_0.005': 391,
 'A549_BMS_0.01': 262,
 'A549_BMS_0.05': 134,
 'A549_BMS_0.1': 103,
 'A549_BMS_1.0': 13,
 'A549_Dex_0.001': 204,
 'A549_Dex_0.005': 264,
 'A549_Dex_0.01': 479,
 'A549_Dex_0.05': 484,
 'A549_Dex_0.1': 486,
 'A549_Dex_1.0': 568,
 'A549_Nutlin_0.001': 284,
 'A549_Nutlin_0.005': 252,
 'A549_Nutlin_0.01': 387,
 'A549_Nutlin_0.05': 350,
 'A549_Nutlin_0.1': 457,
 'A549_Nutlin_1.0': 6,
 'A549_SAHA_0.001': 392,
 'A549_SAHA_0.005': 376,
 'A549_SAHA_0.01': 383,
 'A549_SAHA_0.05': 299,
 'A549_SAHA_0.1': 297,
 'A549_SAHA_1.0': 282,
 'A549_Vehicle_1.0': 1535}

In [11]:
compert_api.compute_comb_emb(thrh=0)
compert_api.compute_uncertainty(
                    cov='A549', 
                    pert='Nutlin', 
                    dose='1.0'
                )

TypeError: compute_comb_emb() missing 1 required positional argument: 'cov'

In [None]:
compert_api.measured_points['training']

In [None]:
compert_api

Setting up a variable for automatic plotting. The plots also could be used separately.

In [None]:
compert_plots = CompertVisuals(compert_api, fileprefix=None)

In [None]:
compert_plots.plot_latent_embeddings(compert_api.emb_perts, kind='perturbations', show_text=True)

If your have a lot of cell types or a lot of perturbations, you can also chose to not display their names.

In [None]:
compert_plots.plot_latent_embeddings(compert_api.emb_perts, kind='perturbations', show_text=False)

Or I can change the color scheme for the emebddings.

In [None]:
perts_palette = {'BMS': '#999999',                 
                 'SAHA': '#4daf4a',
                 'Dex': '#377eb8',
                 'Nutlin': '#e41a1c',
                 'Vehicle': '#000000'
    
                }

compert_plots.perts_palette = perts_palette
compert_plots.plot_latent_embeddings(compert_api.emb_perts, kind='perturbations', show_text=True)

In [None]:
compert_plots.plot_latent_embeddings(compert_api.emb_covars, kind='covars')

In [None]:
latent_response = compert_api.latent_dose_response(perturbations=None)
compert_plots.plot_contvar_response(
    latent_response, 
    postfix='latent',
    var_name=compert_api.perturbation_key,
    title_name='Latent dose response')

In [None]:
perturbations_pair = ['Nutlin', 'BMS']
latent_dose_2D = compert_api.latent_dose_response2D(perturbations_pair, n_points=100)
compert_plots.plot_contvar_response2D(latent_dose_2D, 
        title_name='Latent dose-response')


In [None]:
%%time
reconstructed_response2D = compert_api.get_response2D(datasets, perturbations_pair, compert_api.unique_сovars[0])
compert_plots.plot_contvar_response2D(reconstructed_response2D,
                                              title_name='Reconstructed dose-response  2D',
                                              logdose=False,
                                              # xlims=(-3, 0), ylims=(-3, 0)
                                              )

compert_plots.plot_contvar_response2D(reconstructed_response2D,
                                      title_name='Reconstructed log10-dose-response 2D',
                                      logdose=True,
                                      xlims=(-3, 0), ylims=(-3, 0)
                                      )

If you want to plot in on a log scale, you can just log values in the data frame.

In [None]:
# %%time
df_reference = compert_api.get_response_reference(datasets)        

In [None]:
# %%time
reconstructed_response = compert_api.get_response(datasets)

You can plot an average response (saved under "response" column) among all genes, however, we don't consider it to be a good metric and strongly advise to look at the individual response among DE genes.

Solid lines in this plot correspond to the model predictions, dashed lines -- linear interpolations between measured points. Dots represent measured points, their color is proportional to the number of cells in this condition. Black dots represent points used in training and red dots correspond to the out-of-distribution examples.

In [None]:
df_reference = df_reference.replace('training_treated', 'train')
compert_plots.plot_contvar_response(
    reconstructed_response, 
    df_ref=df_reference, 
    postfix='reconstructed',
    title_name='Reconstructed dose response')

For example we can take of the top 50 DE genes for Nutlin - MDM2. MDM2 is itself transcriptionally-regulated by p53. And p53 is the target of Nutlin. Therefore, we expect our model to learn it.

In [None]:
compert_plots.plot_contvar_response(
    reconstructed_response, 
    df_ref=df_reference,
    response_name='MDM2',
    postfix='MDM2',
    title_name='Reconstructed dose response of MDM2')

We can also look at this plot on the log10-scale. It makes sense for this dataset, because the measured doses were not evenly distributed.

In [None]:
compert_plots.plot_contvar_response(
    reconstructed_response, 
    df_ref=df_reference,
    response_name='MDM2',
    postfix='MDM2',
    logdose=True,
    title_name='Reconstructed log10-dose response of MDM2')

# Predictions

In [None]:
print('Perturbations:', compert_api.unique_perts)
print('Covariates:', compert_api.unique_сovars)
print('Datasets splits:', datasets.keys())

We can chose control cells from which we want to make our predictions. It is easy to chose these cells from either training or test splits.

In [None]:
genes_control = datasets['test_control'].genes

In [None]:
df = pd.DataFrame({args['perturbation_key']: ['BMS', 'Dex'], 
                   args['dose_key']: ['1.0', '0.5'], 
                   args['cell_type_key']: ['A549', 'A549']})

By default, the prediction function returns means and variances of the applied perturbations. 

In [None]:
%%time
compert_api.predict(genes_control, df, return_anndata=True)

By default, the prediction function returns means and variances of the applied perturbations. 

In [None]:
anndata_predicted = compert_api.predict(genes_control, df, return_anndata=True, sample=False)

However, in some cases you want to sample from this distribution, so you can explicitly specify it in the predict function.

In [None]:
anndata_predicted_samples = compert_api.predict(genes_control, df, return_anndata=True, sample=True, n_samples=10)

# Evaluation

In [None]:
genes_control = datasets['training_control'].genes
df_train = compert_api.evaluate_r2(datasets['training_treated'], genes_control)
df_train['benchmark'] = 'CPA'

In [None]:
genes_control = datasets['test_control'].genes
df_ood = compert_api.evaluate_r2(datasets['ood'], genes_control)
df_ood['benchmark'] = 'CPA'

In [None]:
genes_control = datasets['test_control'].genes
df_test = compert_api.evaluate_r2(datasets['test_treated'], genes_control)
df_test['benchmark'] = 'CPA'

In [None]:
df_test = compert_api.evaluate_r2(datasets['test_treated'], genes_control)
df_test['benchmark'] = 'CPA'

In [None]:
df_ood['split'] = 'ood'
df_test['split'] ='test'
df_train['split'] ='train'

In [None]:
df_score = pd.concat([df_train, df_test, df_ood])
df_score.round(2).sort_values(by=['condition', 'R2_mean', 'R2_mean_DE'], ascending=False)

In [None]:
cols_print = ['condition', 'dose_val','R2_mean', 'R2_mean_DE', 'R2_var', 'R2_var_DE', 'split', 'num_cells']
df_score = df_score.round(2).sort_values(by=['condition', 'R2_mean', 'R2_mean_DE'], ascending=False)
print(df_score[cols_print])
# print(df_score[cols_print].to_latex(index=False))

# Uncertainty

We can profile all the predictions with an uncertainty score. Low uncertainty means "good/trustworthy" predictions, high values mean "bad/unknown quality" predictions.

In [None]:
import compert.plotting as pl

for drug in ['Nutlin', 'BMS', 'Dex', 'SAHA']:
    df_pred = pl.plot_uncertainty_dose(
        compert_api,
        cov='A549',
        pert=drug,
        N=51,
        measured_points=compert_api.measured_points['all'],
        cond_key='condition',
        log=True,
        metric='cosine'
    )

Previously, we demonstrated CPA predictions for drugs combinations. But our training data didn't contain any combinations examples. How much can we trust these examples? We can try to asses by running model uncertainty predictions:

In [None]:
df_pred2D = pl.plot_uncertainty_comb_dose(
    compert_api=compert_api,
    cov='A549',
    pert='Nutlin+BMS',
    N=51,
    cond_key='treatment',
    metric='cosine',
)

And here is the predicted response we plotted before:

In [None]:
compert_plots.plot_contvar_response2D(reconstructed_response2D, 
    title_name='Reconstructed dose-response', logdose=False)

And now keep in mind, that the highest uncertainty for OOD cases (for which we know that their predictions were fairly good) is 0.002:

In [None]:
df_pred['uncertainty_cosine'].max().round(3)

Therefore, be careful with interpreting drug combinations in this dataset!

In [None]:
compert_api.measured_points['training']['A549']['Nutlin']