In [1]:
import numpy as np
import pandas as pd
import re
import matplotlib.pyplot as plt
import seaborn as sns
import pickle
from scipy.spatial.distance import cosine, cdist
from tqdm import tqdm

import scanpy as sc

import anndata as ad
import logging
import scgen

In [2]:
import torch
torch.cuda.is_available()

True

In [24]:
train_data = pd.read_csv('../l1000_data/p1shp2xpr/training_data_p1.csv', index_col=None)
geneid = np.array(train_data.columns[2:]).astype('int')
#train_data_sh = pd.read_csv('../l1000_data/p1shp2xpr/training_data_sh_p1.csv', index_col=None)
train_data2 = pd.read_csv('../l1000_data/p1shp2xpr/training_data_p2.csv', index_col=None)

basalccl = train_data.iloc[np.where(train_data['pert_iname']=='control')[0]]

train_data = pd.concat((train_data, train_data2))

val_data = pd.read_csv('../l1000_data/p1shp2xpr/validation_data_p1.csv', index_col=None)
val_data2 = pd.read_csv('../l1000_data/p1shp2xpr/validation_data_p2.csv', index_col=None)
val_data = pd.concat((val_data, val_data2))
infoidx = 2

# =============================================================================

train_data_info = train_data.iloc[:,:infoidx].values
train_data = train_data.iloc[:,infoidx:].values
val_data_info = val_data.iloc[:,:infoidx].values
val_data = val_data.iloc[:,infoidx:].values

geneinfo = pd.read_csv('../l1000_data/GSE92742_Broad_LINCS_gene_info.txt', sep='\t')
genemapper = pd.Series(data=geneinfo['pr_gene_symbol'].values, index=geneinfo['pr_gene_id'].values)
genesym = pd.Series(geneid).map(genemapper).values

mol_meta = pd.read_csv('../l1000_data/LINCS_small_molecules.tsv', sep='\t')
mol_meta.index = mol_meta['pert_name']
_, uid = np.unique(mol_meta.index, return_index=True)
mol_meta = mol_meta.iloc[uid]
mol_meta_tar = mol_meta.loc[mol_meta['target']!='-',:]

In [25]:
landmark_genes = pd.read_csv('../l1000_data/genelist.csv')
adata_train = ad.AnnData(pd.DataFrame(train_data, index=np.arange(train_data.shape[0]).astype('str'), columns=landmark_genes['pr_gene_symbol']))
adata_train.obs = pd.DataFrame(train_data_info, index=np.arange(train_data.shape[0]).astype('str'), columns=['sample', 'perturbation'])
adata_train

AnnData object with n_obs × n_vars = 338031 × 976
    obs: 'sample', 'perturbation'

In [33]:
totalpert = np.unique(train_data_info[:,1])
totalpert = totalpert[totalpert!='control']
totalpert.size

12794

In [27]:
scgen.SCGEN.setup_anndata(adata_train, batch_key="perturbation", labels_key="sample")
adata_train

  categorical_mapping = _make_column_categorical(


AnnData object with n_obs × n_vars = 338031 × 976
    obs: 'sample', 'perturbation', '_scvi_batch', '_scvi_labels'
    uns: '_scvi_uuid', '_scvi_manager_uuid'

# Training

In [28]:
adata_train

AnnData object with n_obs × n_vars = 338031 × 976
    obs: 'sample', 'perturbation', '_scvi_batch', '_scvi_labels'
    uns: '_scvi_uuid', '_scvi_manager_uuid'

In [29]:
model = scgen.SCGEN(adata_train)
model.save("scgen_saved_models/model_perturbation_prediction.pt", overwrite=True)

In [30]:
model.train(
    max_epochs=100,
    batch_size=32,
    early_stopping=True,
    early_stopping_patience=25
)

Trainer will use only 1 of 8 GPUs because it is running inside an interactive / notebook environment. You may try to set `Trainer(devices=8)` but please note that multi-GPU inside interactive / notebook environments is considered experimental and unstable. Your mileage may vary.
/data/yhhan/PGAN/scgen-env/lib/python3.9/site-packages/lightning/fabric/plugins/environments/slurm.py:191: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python3.9 /data/yhhan/PGAN/scgen-env/lib/python3.9/site-pac ...
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/data/yhhan/PGAN/scgen-env/lib/python3.9/site-packages/lightning/fabric/plugins/environments/slurm.py:191: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, 

Epoch 69/100:  68%|▋| 68/100 [1:55:23<55:26, 103.96s/it, v_num=1, train_loss_ste
Monitored metric elbo_validation did not improve in the last 25 records. Best score: 2.120. Signaling Trainer to stop.


In [31]:
model.save("scgen_saved_models/model_perturbation_prediction.pt", overwrite=True)

In [32]:
[_ for _ in model.module.parameters()]

[Parameter containing:
 tensor([[-0.2278, -0.2932, -0.5332,  ...,  0.5487, -0.0281, -0.0058],
         [-0.3547,  0.0570,  0.0067,  ..., -0.0371, -0.1163, -0.0261],
         [ 0.5417,  0.1518, -0.3283,  ..., -0.4686, -0.4017,  0.2086],
         ...,
         [ 0.0374, -0.0542, -0.5965,  ...,  0.2778, -0.2250,  0.2603],
         [ 0.2548,  0.0036,  0.2441,  ...,  0.2310,  0.3415, -0.1540],
         [-0.1731,  1.2971,  0.0907,  ...,  0.1439,  0.3020,  0.0039]],
        device='cuda:0', requires_grad=True),
 Parameter containing:
 tensor([ 1.9068e-02,  2.7142e-02,  1.8848e-02,  1.5695e-02, -6.7707e-03,
         -1.6416e-02,  1.1915e-02, -1.4855e-02, -2.5270e-03,  1.1741e-02,
         -8.7362e-03,  1.3568e-02, -2.9552e-03,  2.2311e-02,  2.8877e-03,
         -2.4296e-02,  6.0171e-03,  2.9519e-03,  1.1124e-02, -1.7822e-02,
          2.5372e-02,  2.4866e-02,  1.7810e-02, -1.5271e-02, -2.3536e-03,
          6.6598e-03,  2.6651e-02, -9.9876e-03,  2.8491e-02, -2.4203e-02,
         -1.1954e-03,  

In [None]:
# scVIDR
import os
import sys
# Code need to be downloaded from https://github.com/BhattacharyaLab/scVIDR
sys.path.insert(1, '../vidr')

from vidr import VIDR
from utils import normalize_data, prepare_data, prepare_cont_data

SINGLE_DOSE_COMMAND = 'single_dose'

CELLTYPE_COLUMN = 'sample'
DOSE_COLUMN = 'Dose'
DOSE_CATEGORICAL_COLUMN = 'dose'
TEST_CELLTYPE = 'HELA' # dummy
CONTROL_DOSE = 0
TREATED_DOSE = 10

MODEL_OUTPUT_DIR = 'myresult/'

CELLTYPES_OF_INTEREST = 'ALL'

TRAIN_COMMAND = 'single_dose'

adata = adata_train.copy()
adata.obs[DOSE_CATEGORICAL_COLUMN] = adata.obs[DOSE_COLUMN].astype(str)
available_doses = adata.obs[DOSE_CATEGORICAL_COLUMN].unique()
available_cell_types = adata.obs[CELLTYPE_COLUMN].unique()
CELLTYPES_OF_INTEREST = available_cell_types
adata = adata[adata.obs[CELLTYPE_COLUMN].isin(CELLTYPES_OF_INTEREST)]

train_adata, test_adata = prepare_data(
    adata,
    CELLTYPE_COLUMN,
    DOSE_CATEGORICAL_COLUMN,
    TEST_CELLTYPE,
    TREATED_DOSE,
    normalized = True
)
model = VIDR(train_adata, linear_decoder = False)
model.train(
    max_epochs=100,
    batch_size=128,
    early_stopping=True,
    early_stopping_patience=25
)
model.save(MODEL_OUTPUT_DIR)

pred, delta, *other = model.predict(
    ctrl_key='0',
    treat_key='10',
    cell_type_to_predict='SAMb6ba60b525,
    regression = False
)

# Chanye

In [35]:
adata_mc_landmark = sc.read_h5ad('ChangYe2021_SEACells.h5ad')
adata_mc_landmark

adata_mc_concat = adata_train.concatenate(adata_mc_landmark, batch_categories=['ref', 'new'])
adata_mc_concat

adata_mc_concat_combat = sc.pp.combat(adata_mc_concat, key='batch', inplace=False)

import qnorm
target_dist = np.median(adata_train.X, axis=0)
corrected_mc = qnorm.quantile_normalize(adata_mc_concat_combat[-adata_mc_landmark.shape[0]:,:].T, target=target_dist).T

Aalpha_ = corrected_mc[adata_mc_landmark.obs['perturbation']=='control']
Abeta_ = corrected_mc[adata_mc_landmark.obs['perturbation']=='erlotinib']

#Aalpha_ = adata_mc_landmark.X[adata_mc_landmark.obs['perturbation']=='control']
#Abeta_ = adata_mc_landmark.X[adata_mc_landmark.obs['perturbation']=='erlotinib']

  adata_mc_concat = adata_train.concatenate(adata_mc_landmark, batch_categories=['ref', 'new'])
  (abs(g_new - g_old) / g_old).max(), (abs(d_new - d_old) / d_old).max()


In [36]:
onlycont = corrected_mc[adata_mc_landmark.obs['perturbation'] == 'control']
adata_total_ = ad.AnnData(pd.DataFrame(onlycont, index=np.arange(onlycont.shape[0]).astype('str'), columns=landmark_genes['pr_gene_symbol']))
adata_total_.obs = adata_mc_landmark.obs[['sample', 'perturbation']][adata_mc_landmark.obs['perturbation'] == 'control']
adata_total_

AnnData object with n_obs × n_vars = 204 × 976
    obs: 'sample', 'perturbation'

In [37]:
adata_total = adata_train

In [38]:
adata_total = adata_total.concatenate(adata_total_, batch_categories=['ref', 'new'])
adata_total

  adata_total = adata_total.concatenate(adata_total_, batch_categories=['ref', 'new'])


AnnData object with n_obs × n_vars = 338235 × 976
    obs: 'sample', 'perturbation', '_scvi_batch', '_scvi_labels', 'batch'

In [39]:
z,x,c = np.unique(adata_total.obs['sample'] + ' ' + adata_total.obs['perturbation'], return_counts=True, return_index=True)
adata_total = ad.concat((adata_total, adata_total[x[c==1]].copy()))
adata_total

  utils.warn_names_duplicates("obs")


AnnData object with n_obs × n_vars = 338689 × 976
    obs: 'sample', 'perturbation', '_scvi_batch', '_scvi_labels', 'batch'

In [40]:
scgen.SCGEN.setup_anndata(adata_total, batch_key="perturbation", labels_key="sample")
adata_total

  categorical_mapping = _make_column_categorical(


AnnData object with n_obs × n_vars = 338689 × 976
    obs: 'sample', 'perturbation', '_scvi_batch', '_scvi_labels', 'batch'
    uns: '_scvi_uuid', '_scvi_manager_uuid'

In [42]:
_model = torch.load("scgen_saved_models/model_perturbation_prediction.pt/model.pt")
pred_model = scgen.SCGEN(adata_total)
pred_model.module.load_state_dict(_model['model_state_dict'])

<All keys matched successfully>

In [44]:
pred_model.is_trained=True

In [None]:
pred_house = []
for p in totalpert:
    pred, _ = pred_model.predict(
        ctrl_key='control',
        stim_key=p,
        celltype_to_predict='SAMb6ba60b525'
    )
    pred.obs['condition'] = 'pred'

    pred_house.append(pred)

In [59]:
rankperf_house = []

mc_info = adata_mc_landmark.obs[['sample', 'perturbation']].values

cors_a_ = np.zeros(len(pred_house))
for p in range(len(pred_house)):
    cors_a_[p] = -cdist(corrected_mc[(mc_info[:,1]=='erlotinib'),:], pred_house[p].X, 'correlation').ravel().mean()
rankperf_house.append([cors_a_])

In [60]:
print(np.where(totalpert[np.argsort(rankperf_house[0][0])[::-1]]=='erlotinib')[0]+1, '/', totalpert.size,)

[11139] / 12465
