In [1]:
import numpy as np
import pandas as pd
import re
import matplotlib.pyplot as plt
import seaborn as sns
import pickle
from scipy.spatial.distance import cosine, cdist
from tqdm import tqdm

import scanpy as sc

import anndata as ad
import logging
import scgen

import os

In [2]:
import torch
# Displayed as the cell output: whether PyTorch can see a CUDA device.
torch.cuda.is_available()

True

In [3]:
# ---------------------------------------------------------------------------
# Load the L1000 training / validation tables plus gene and compound metadata.
# ---------------------------------------------------------------------------
basedir = '../l1000_data/'
# Number of leading metadata columns in each expression CSV; every column
# after that is treated as a per-gene expression value.
infoidx = 2

train_data = pd.read_csv(basedir+'p1shp2xpr/training_data_p1.csv', index_col=None)
# Expression column headers are gene ids stored as strings; cast them to int
# so they can be matched against `pr_gene_id` further down.
geneid = np.array(train_data.columns[infoidx:]).astype('int')
train_data_sh = pd.read_csv(basedir+'p1shp2xpr/training_data_sh_p1.csv', index_col=None)
train_data2 = pd.read_csv(basedir+'p1shp2xpr/training_data_p2.csv', index_col=None)
train_data_ve = pd.read_csv(basedir+'p1shp2xpr/training_data_vehicle.csv', index_col=None)
# Keep only the DMSO vehicle rows and relabel them as 'control'.
# `.copy()` detaches the filtered frame from the one read from disk so the
# assignment below is unambiguous (no SettingWithCopyWarning, and it still
# takes effect when pandas Copy-on-Write is enabled).
train_data_ve = train_data_ve.loc[train_data_ve['pert_iname']=='DMSO'].copy()
train_data_ve.loc[:,'pert_iname'] = 'control'

# Control rows of the first training table only, captured before the tables
# are concatenated.
basalccl = train_data.iloc[np.where(train_data['pert_iname']=='control')[0]]

train_data = pd.concat((train_data, train_data_sh, train_data2, train_data_ve))

val_data = pd.read_csv(basedir+'p1shp2xpr/validation_data_p1.csv', index_col=None)
val_data2 = pd.read_csv(basedir+'p1shp2xpr/validation_data_p2.csv', index_col=None)
val_data = pd.concat((val_data, val_data2))

# =============================================================================

# Split each table into a metadata array (first `infoidx` columns) and an
# expression array (the rest). From here on `train_data` / `val_data` are
# plain NumPy arrays, not DataFrames.
train_data_info = train_data.iloc[:,:infoidx].values
train_data = train_data.iloc[:,infoidx:].values
val_data_info = val_data.iloc[:,:infoidx].values
val_data = val_data.iloc[:,infoidx:].values

# Gene id -> gene symbol lookup; ids absent from the gene-info table map to NaN.
geneinfo = pd.read_csv(basedir+'GSE92742_Broad_LINCS_gene_info.txt', sep='\t')
genemapper = pd.Series(data=geneinfo['pr_gene_symbol'].values, index=geneinfo['pr_gene_id'].values)
genesym = pd.Series(geneid).map(genemapper).values

# Small-molecule metadata indexed by compound name, keeping the first row for
# each distinct name (np.unique returns first-occurrence indices, sorted by name).
mol_meta = pd.read_csv(basedir+'LINCS_small_molecules.tsv', sep='\t')
mol_meta.index = mol_meta['pert_name']
_, uid = np.unique(mol_meta.index, return_index=True)
mol_meta = mol_meta.iloc[uid]
# Subset of compounds whose 'target' field is not the '-' placeholder.
mol_meta_tar = mol_meta.loc[mol_meta['target']!='-',:]

In [4]:
# Wrap the training expression matrix in an AnnData object: rows are named by
# their position (as strings), columns by the gene symbols from genelist.csv.
landmark_genes = pd.read_csv(basedir+'genelist.csv')
train_obs_names = np.arange(train_data.shape[0]).astype('str')
train_expr_frame = pd.DataFrame(
    train_data,
    index=train_obs_names,
    columns=landmark_genes['pr_gene_symbol'],
)
adata_train = ad.AnnData(train_expr_frame)
# The two metadata columns become the per-observation annotations.
adata_train.obs = pd.DataFrame(
    train_data_info,
    index=train_obs_names,
    columns=['sample', 'perturbation'],
)

In [5]:
# Every distinct perturbation label in the training metadata (second info
# column), with the 'control' label removed.
all_labels = np.unique(train_data_info[:,1])
totalpert = all_labels[all_labels!='control']

In [6]:
# Register adata_train with scGEN: the 'perturbation' obs column is passed as
# the batch key and the 'sample' obs column as the labels key.
scgen.SCGEN.setup_anndata(adata_train, batch_key="perturbation", labels_key="sample")

  categorical_mapping = _make_column_categorical(


# Training

In [15]:
# Train the scGEN model, or restore its weights if a checkpoint already exists.
#
# NOTE: despite the ".pt" suffix, `savedir` is used as a directory; the
# checkpoint file itself is `<savedir>/model.pt`.
savedir = 'scgen_saved_models/model_perturbation_prediction.pt'
model = scgen.SCGEN(adata_train)
if os.path.exists(savedir+"/model.pt"):
    # Reuse the stored weights instead of retraining.
    _model = torch.load(savedir+"/model.pt")
    model.module.load_state_dict(_model['model_state_dict'])
else:
    # Save only AFTER training has finished. Writing a checkpoint before
    # `train()` would leave an untrained `model.pt` on disk if training were
    # interrupted, and the branch above would then silently load those
    # untrained weights on every later run.
    # NOTE(review): a checkpoint written by the earlier version of this cell
    # may be untrained; delete `savedir` to force retraining if in doubt.
    model.train(
        max_epochs=100,
        batch_size=32,
        early_stopping=True,
        early_stopping_patience=25
    )
    model.save(savedir, overwrite=True)

In [16]:
# Materialise the underlying module's parameters as a list for inspection.
list(model.module.parameters())

[Parameter containing:
 tensor([[ 0.1586,  0.3435,  0.4464,  ...,  0.2957,  0.1271, -0.0204],
         [ 0.0858,  0.1906, -0.5277,  ...,  0.0455, -0.3366, -0.0382],
         [ 0.0686,  0.1848, -0.2581,  ..., -0.0266,  0.0026,  0.1188],
         ...,
         [ 0.7419, -0.2091,  0.2366,  ...,  0.1962, -0.0010, -0.3355],
         [ 0.2296,  0.1261,  0.3472,  ...,  0.0603, -0.1343, -0.0219],
         [-0.0959, -0.1333,  0.2987,  ...,  0.1626, -0.3310,  0.1065]],
        requires_grad=True),
 Parameter containing:
 tensor([ 2.6350e-02, -1.1266e-02, -2.8826e-02, -2.7148e-02,  2.3419e-02,
          3.5233e-03,  1.3414e-02, -1.9411e-02, -1.1340e-02,  1.6240e-02,
         -1.8583e-02, -2.9410e-02,  1.1967e-02,  5.8096e-03, -2.2264e-02,
          2.2568e-02, -2.7384e-02,  2.6876e-02,  6.4958e-03,  1.4030e-02,
         -1.9547e-02, -1.7894e-03,  1.6325e-02,  2.7074e-02, -3.3565e-03,
          2.9343e-02,  1.7218e-02,  1.6301e-02, -8.2071e-03,  6.7270e-03,
         -2.2585e-04, -7.0114e-03,  1.53

# ChangYe2021

In [17]:
# Load the ChangYe2021 SEACells AnnData.
adata_mc_landmark = sc.read_h5ad('ChangYe2021_SEACells.h5ad')

# Stack the training data ('ref') and the new data ('new'); the batch label is
# what ComBat corrects for below.
adata_mc_concat = adata_train.concatenate(
    adata_mc_landmark,
    batch_categories=['ref', 'new'],
)

# inplace=False: take the corrected values as a return value instead of
# overwriting adata_mc_concat.
adata_mc_concat_combat = sc.pp.combat(adata_mc_concat, key='batch', inplace=False)

import qnorm
# Per-gene median of the training matrix, used as the quantile-normalisation target.
target_dist = np.median(adata_train.X, axis=0)
# The new dataset's rows are the trailing block of the concatenated matrix.
n_mc_rows = adata_mc_landmark.shape[0]
mc_rows_combat = adata_mc_concat_combat[-n_mc_rows:,:]
corrected_mc = qnorm.quantile_normalize(mc_rows_combat.T, target=target_dist).T

# Corrected profiles split by perturbation label.
mc_perturbation = adata_mc_landmark.obs['perturbation']
Aalpha_ = corrected_mc[mc_perturbation=='control']
Abeta_ = corrected_mc[mc_perturbation=='erlotinib']

  adata_mc_concat = adata_train.concatenate(adata_mc_landmark, batch_categories=['ref', 'new'])
  (abs(g_new - g_old) / g_old).max(), (abs(d_new - d_old) / d_old).max()


In [18]:
# Build an AnnData holding only the corrected control profiles of the new dataset.
is_control = adata_mc_landmark.obs['perturbation'] == 'control'
onlycont = corrected_mc[is_control]
control_frame = pd.DataFrame(
    onlycont,
    index=np.arange(onlycont.shape[0]).astype('str'),
    columns=landmark_genes['pr_gene_symbol'],
)
adata_total_ = ad.AnnData(control_frame)
# Carry over the matching 'sample' / 'perturbation' annotations for those rows.
adata_total_.obs = adata_mc_landmark.obs.loc[is_control, ['sample', 'perturbation']]
adata_total_

AnnData object with n_obs × n_vars = 204 × 976
    obs: 'sample', 'perturbation'

In [19]:
# Start adata_total from the training AnnData. This binds a second name to the
# same object; it is not a copy.
adata_total = adata_train

In [20]:
# Append the new dataset's control (basal-state) profiles to adata_total,
# labelling existing rows 'ref' and the appended rows 'new'.
# NOTE: this rebinds adata_total to the concatenated result, so running this
# cell again without first resetting adata_total appends the controls again.
adata_total = adata_total.concatenate(adata_total_, batch_categories=['ref', 'new'])

  adata_total = adata_total.concatenate(adata_total_, batch_categories=['ref', 'new'])


In [21]:
# Ensure every (sample, perturbation) combination appears at least twice by
# appending one copy of each row whose combination occurs exactly once.
# NOTE(review): presumably needed by the downstream scGEN setup/prediction;
# the reason is not visible in this notebook - confirm.
combo_labels = adata_total.obs['sample'] + ' ' + adata_total.obs['perturbation']
# np.unique returns (values, first-occurrence indices, counts); the values
# themselves are not needed.
_, first_idx, combo_counts = np.unique(combo_labels, return_index=True, return_counts=True)
singleton_rows = adata_total[first_idx[combo_counts==1]].copy()
adata_total = ad.concat((adata_total, singleton_rows))
# The appended copies carry the same obs names as their originals, which is
# what triggered the duplicate-obs-names warning; give every row a unique name.
adata_total.obs_names_make_unique()

  utils.warn_names_duplicates("obs")


In [22]:
# Register adata_total with scGEN using the same keys as for adata_train:
# 'perturbation' as the batch key and 'sample' as the labels key.
scgen.SCGEN.setup_anndata(adata_total, batch_key="perturbation", labels_key="sample")

  categorical_mapping = _make_column_categorical(


In [23]:
# Build a model around adata_total and fill it with the saved weights.
pred_model = scgen.SCGEN(adata_total)
checkpoint_path = savedir+"/model.pt"
_model = torch.load(checkpoint_path)
# Last expression of the cell, so the key-matching report is displayed.
pred_model.module.load_state_dict(_model['model_state_dict'])

<All keys matched successfully>

In [24]:
# The weights were loaded from a checkpoint rather than produced by calling
# train() on this object, so set the trained flag by hand.
pred_model.is_trained=True

In [None]:
# For each non-control perturbation, predict the response of one fixed sample
# and collect the predictions in the same order as `totalpert`.
pred_house = []
for stim_name in totalpert:
    # predict() returns a pair; only the first element is kept.
    predicted, _ = pred_model.predict(
        ctrl_key='control',
        stim_key=stim_name,
        celltype_to_predict='SAMb6ba60b525',
    )
    predicted.obs['condition'] = 'pred'
    pred_house.append(predicted)

In [28]:
# Score each prediction by its negated mean correlation distance to the
# observed erlotinib profiles, so that larger scores mean closer matches.
mc_info = adata_mc_landmark.obs[['sample', 'perturbation']].values
erlotinib_profiles = corrected_mc[(mc_info[:,1]=='erlotinib'),:]

cors_a_ = np.zeros(len(pred_house))
for idx, predicted in enumerate(pred_house):
    pairwise_dist = cdist(erlotinib_profiles, predicted.X, 'correlation')
    cors_a_[idx] = -pairwise_dist.ravel().mean()

# Nested as [[scores]] because the ranking cell reads rankperf_house[0][0].
rankperf_house = [[cors_a_]]

In [29]:
# Order perturbations from highest to lowest score and report the 1-based
# position of 'erlotinib' out of all candidate perturbations.
best_first = np.argsort(rankperf_house[0][0])[::-1]
erlotinib_rank = np.where(totalpert[best_first]=='erlotinib')[0]+1
print(erlotinib_rank, '/', totalpert.size,)

[11139] / 12794
