In [116]:
import sys
sys.path.insert(1, '../../')
sys.path.insert(1, '../')
from sc_preprocessing import sc_preprocess


# general imports
import click
import warnings
import numpy as np
import os
import pandas as pd
import scanpy as sc
import pickle
from pathlib import Path
from collections import Counter
import re




In [117]:

# ---- Run parameters --------------------------------------------------------
# Cell types in each reference. Here the two lists differ in exactly one entry
# ('cd8' vs 'nk'): that cell type is the one swapped between references.
perturb_vec = "other,mono14,mono16,b,cd4_cd8_naive,cd4,cd8".split(',')
nonperturb_vec = "other,mono14,mono16,b,cd4_cd8_naive,cd4,nk".split(',')

res_name = "pbmc6k-nk"
in_name = "pbmc6k"
test_id = "pbmc6k-nk_6"

# All input/output locations live under the parent of the notebook's
# working directory; build them from one root instead of repeating it.
project_root = f"{os.getcwd()}/.."
res_path = f"{project_root}/results/single_cell_data/diva_pbmc/"
aug_data_path = f"{project_root}/data/single_cell_data/augmented_pbmc_data/"
data_path = f"{project_root}/data/single_cell_data/pbmc6k/hg19/"
scpred_path = f"{project_root}/results/single_cell_data/pbmc_cell_labels/"
cybersort_path = f"{project_root}/data/single_cell_data/cybersort_pbmc/"
bp_path = f"{project_root}/results/single_cell_data/bp_pbmc/"

# Kept as a comma-separated string: the data-loading cell parses it into ints.
num_cells_vec = "5000,5000,5000,5000,5000,5000,5000,5000,5000,5000"

num_genes = 5000

# Optional seed for numpy's global RNG (the cell-noise draws use
# np.random.lognormal; sc_preprocess presumably draws from it too — verify).
# None keeps the previous, unseeded behavior; set an int for reproducible runs.
SEED = None
if SEED is not None:
    np.random.seed(SEED)



In [118]:

# read in the data
adata = sc.read_10x_mtx(
                            data_path,                  # the directory with the `.mtx` file
                            var_names='gene_symbols',   # use gene symbols for the variable names (variables-axis index)
                            cache=True)                 # write a cache file for faster subsequent reading
adata.var_names_make_unique()

# split the strings into a vector
num_cells_vec = num_cells_vec.split(',')
num_cells_vec = [int(i) for i in num_cells_vec]



# get the perturbed and non-perturbed cell-types
perturbed_cell_type = np.setdiff1d(perturb_vec, nonperturb_vec)
nonperturbed_cell_type = np.setdiff1d(nonperturb_vec, perturb_vec)
print(f"Perturbed Cell type: {perturbed_cell_type}")
print(f"Nonperturbed Cell type: {nonperturbed_cell_type}")



# add metadata
meta_data = pd.read_csv(f"{scpred_path}/{in_name}_scpred.tsv", sep="\t", index_col='code')
barcodes = pd.read_csv(f"{data_path}/barcodes.tsv", header=None, names=['code'])
meta_df = barcodes.join(other=meta_data, on=['code'], how='left', sort=False)
adata.obs['scpred_CellType'] = meta_df['scpred_prediction'].tolist()

# filter out cells with less than 200 genes and genes expressed in less than 3 cells
sc.pp.filter_cells(adata, min_genes=200)
sc.pp.filter_genes(adata, min_cells=3)

# remove genes with high mitochondrial content
adata.var['mt'] = adata.var_names.str.startswith('MT-')  # annotate the group of mitochondrial genes as 'mt'
sc.pp.calculate_qc_metrics(adata, qc_vars=['mt'], percent_top=None, log1p=False, inplace=True)

# remove cells with more than 2000 genes
# remove cells with more than 7% MTgenes
adata = adata[adata.obs.n_genes_by_counts < 2000, :]
adata = adata[adata.obs.pct_counts_mt < 7, :]


# normalize to 10K counts per cell
sc.pp.normalize_total(adata, target_sum=1e4)

# remove cells that are unlabeled or unclassified
cell_type_id = np.unique(adata.obs["scpred_CellType"].values)
cell_type_remove = ["unassigned", "unclassified"]
cell_type_id = set(cell_type_id).difference(set(cell_type_remove))
adata = adata[adata.obs["scpred_CellType"].isin(cell_type_id)]

# group together cell types that are not very frequent
all_vals = adata.obs["scpred_CellType"].to_list()
all_vals = np.char.replace(all_vals, 'adc', 'other')
all_vals = np.char.replace(all_vals, 'pdc', 'other')
all_vals = np.char.replace(all_vals, 'mk', 'other')
all_vals = np.char.replace(all_vals, 'hsc', 'other')
adata.obs["scpred_CellType"] = all_vals

# get the non-perturbed reference profiles
# we do this for CIBERSORT and BP
adata_nonperturb = adata[adata.obs["scpred_CellType"].isin(nonperturb_vec)]
all_vals = adata_nonperturb.obs["scpred_CellType"].to_list()
all_vals = [(x if x != nonperturbed_cell_type[0] else 'collapsed_celltype') for x in all_vals]
adata_nonperturb.obs["scpred_CellType"] = all_vals
print(f"non-perturbed counts: {Counter(all_vals)}")

# make it dense for BP and CIBERSORT
dense_matrix = adata_nonperturb.X.todense()


########### Make Pseudobulks

## set up the cell-noise perturbations
len_vector = len(perturb_vec)
cell_noise = [np.random.lognormal(0, 0.1, adata.var['gene_ids'].shape[0]) for i in range(len_vector)]

## get the perturbed reference profiles 
adata_perturb = adata[adata.obs["scpred_CellType"].isin(perturb_vec)]

all_vals = adata_perturb.obs["scpred_CellType"].to_list()
all_vals = [(x if x != perturbed_cell_type[0] else 'collapsed_celltype') for x in all_vals]
adata_perturb.obs["scpred_CellType"] = all_vals
print(f"perturbed counts: {Counter(all_vals)}")


Perturbed Cell type: ['cd8']
Nonperturbed Cell type: ['nk']


Trying to set attribute `.obs` of view, copying.
Trying to set attribute `.obs` of view, copying.


non-perturbed counts: Counter({'cd4_cd8_naive': 1304, 'mono14': 1014, 'cd4': 693, 'b': 690, 'mono16': 327, 'collapsed_celltype': 299, 'other': 125})


Trying to set attribute `.obs` of view, copying.


perturbed counts: Counter({'cd4_cd8_naive': 1304, 'mono14': 1014, 'cd4': 693, 'b': 690, 'collapsed_celltype': 661, 'mono16': 327, 'other': 125})


In [121]:
from tensorflow.keras.utils import to_categorical, normalize

# Binary perturbation label per training sample: one flag per group, and each
# group contributes n_train consecutive samples in the order listed.
n_train = 10
group_flags = [1, 0, 0, 0, 0, 1, 1, 1, 0, 0]
Label_perturb = np.repeat(group_flags, n_train)  # integer labels, length len(group_flags) * n_train

# one-hot encode: column 0 <-> label 0, column 1 <-> label 1
label_perturb = to_categorical(Label_perturb)
label_perturb

array([[0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [1., 0.],
       [1., 0.],
       [1., 0.],
       [1., 0.],
       [1., 0.],
       [1., 0.],
       [1., 0.],
       [1., 0.],
       [1., 0.],
       [1., 0.],
       [1., 0.],
       [1., 0.],
       [1., 0.],
       [1., 0.],
       [1., 0.],
       [1., 0.],
       [1., 0.],
       [1., 0.],
       [1., 0.],
       [1., 0.],
       [1., 0.],
       [1., 0.],
       [1., 0.],
       [1., 0.],
       [1., 0.],
       [1., 0.],
       [1., 0.],
       [1., 0.],
       [1., 0.],
       [1., 0.],
       [1., 0.],
       [1., 0.],
       [1., 0.],
       [1., 0.],
       [1., 0.],
       [1., 0.],
       [1., 0.],
       [1., 0.],
       [1., 0.],
       [1., 0.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.

In [119]:
# Simulate pseudobulk samples from the perturbed reference profiles.
curr_adata = adata_perturb
(prop_df,
 pseudobulks_df,
 test_prop_df,
 test_pseudobulks_df) = sc_preprocess.make_prop_and_sum(
    curr_adata,
    num_samples=100,
    num_cells=1000,
    use_true_prop=False,
    cell_noise=cell_noise,
)

# rescale each sample's cell counts by its row total so rows are proportions
prop_df = prop_df.div(prop_df.sum(axis="columns"), axis="index")
prop_df

0
100


Unnamed: 0,cd4_cd8_naive,mono14,collapsed_celltype,b,other,cd4,mono16
0,0.012,0.036,0.006,0.511,0.253,0.181,0.001
0,0.137,0.197,0.156,0.011,0.128,0.258,0.113
0,0.002,0.84,0.005,0.052,0.028,0.0,0.073
0,0.107,0.023,0.001,0.0,0.752,0.117,0.0
0,0.116,0.048,0.057,0.01,0.062,0.124,0.583
...,...,...,...,...,...,...,...
0,0.0,0.571,0.082,0.006,0.0,0.316,0.025
0,0.566,0.012,0.011,0.129,0.023,0.055,0.204
0,0.762,0.004,0.047,0.115,0.047,0.004,0.021
0,0.054,0.13,0.106,0.039,0.008,0.208,0.455


In [120]:
# Simulate pseudobulk samples from the non-perturbed reference profiles.
curr_adata = adata_nonperturb
prop_df, pseudobulks_df, test_prop_df, test_pseudobulks_df = sc_preprocess.make_prop_and_sum(
    curr_adata,
    num_samples=100,
    num_cells=1000,
    use_true_prop=False,
    cell_noise=cell_noise,
)

# make the proportions instead of cell counts (each row sums to 1)
prop_df = prop_df.div(prop_df.sum(axis=1), axis=0)
display(prop_df)  # rich (truncated) table instead of print()

# Dominant cell type of each pseudobulk sample: the column holding the largest
# proportion (first column wins on ties). Bound once, as a pandas Index,
# rather than reusing tmp_Y for an intermediate ndarray.
tmp_Y = prop_df.columns[prop_df.to_numpy().argmax(axis=1)]
tmp_Y

0
100
   cd4_cd8_naive mono14      b  other collapsed_celltype    cd4 mono16
0           0.16  0.234  0.066  0.049              0.165  0.009  0.317
0          0.109   0.44  0.014  0.281              0.135  0.008  0.013
0          0.107    0.0  0.014  0.777              0.084    0.0  0.018
0          0.004  0.197  0.169    0.0              0.235  0.391  0.004
0          0.326  0.106   0.01  0.152              0.039  0.358  0.009
..           ...    ...    ...    ...                ...    ...    ...
0          0.619  0.003  0.015  0.002              0.337  0.012  0.012
0           0.01  0.143  0.012  0.005              0.006  0.768  0.056
0          0.003  0.033  0.013   0.05              0.167  0.696  0.038
0          0.167   0.25    0.1  0.071               0.04  0.058  0.314
0          0.332  0.053   0.31  0.083              0.011  0.157  0.054

[100 rows x 7 columns]


Index(['mono16', 'mono14', 'other', 'cd4', 'cd4', 'other', 'mono14', 'mono16',
       'b', 'other', 'mono16', 'b', 'cd4_cd8_naive', 'other', 'mono16',
       'other', 'collapsed_celltype', 'collapsed_celltype', 'b', 'mono16',
       'mono14', 'cd4_cd8_naive', 'mono14', 'mono16', 'b', 'b', 'b', 'other',
       'cd4', 'cd4', 'cd4_cd8_naive', 'cd4_cd8_naive', 'other', 'mono16',
       'cd4_cd8_naive', 'mono14', 'other', 'collapsed_celltype',
       'cd4_cd8_naive', 'mono14', 'cd4', 'cd4_cd8_naive', 'collapsed_celltype',
       'mono16', 'mono16', 'mono16', 'other', 'mono16', 'other',
       'cd4_cd8_naive', 'collapsed_celltype', 'cd4_cd8_naive', 'mono16',
       'other', 'b', 'collapsed_celltype', 'cd4_cd8_naive', 'mono14',
       'cd4_cd8_naive', 'cd4', 'mono14', 'b', 'cd4_cd8_naive', 'cd4_cd8_naive',
       'other', 'other', 'mono16', 'mono16', 'other', 'mono14',
       'cd4_cd8_naive', 'mono14', 'cd4', 'mono16', 'other', 'mono14',
       'cd4_cd8_naive', 'mono16', 'other', 'mono16', 'c