# CROP-seq KO150 - KO cross-prediction to generate similarity graph

In [1]:
# move one directory up so that the relative 'results/...' paths used below resolve
cd ../

/home/sreichl/projects/bmdm-stim


In [2]:
# libraries
# general
import os
import pandas as pd
import numpy as np
from itertools import compress
import pickle

# visualization
import seaborn as sns
import matplotlib.pyplot as plt


# for classification
import sklearn
from sklearn.linear_model import LogisticRegression
from sklearn.linear_model import LogisticRegressionCV
from sklearn.ensemble import RandomForestClassifier

# for lineage graph
import igraph as ig

In [3]:
# configs
# input directory holding the KO150 results
dir_data = os.path.join('results', 'KO150')

# condition to analyse; one of 'untreated', 'LO28-6h', 'LO28-24h' or 'all'
cond = 'all'

# output directory: <dir_data>/KO_classifier/<cond>
dir_results = os.path.join(dir_data, 'KO_classifier', cond)

# classifier flag; 'LR' (logistic regression) or 'RF' (random forest)
clf_label = 'LR'

# approach flag; 'prob' (probabilities) or 'count' (counting)
conn_label = 'prob'

# True -> directed graph, False -> undirected (symmetric) graph
directed = True

In [4]:
# Create the output directory including any missing parent folders
# (e.g. 'KO_classifier'); exist_ok avoids the check-then-create race and
# makes the cell safe to re-run.
os.makedirs(dir_results, exist_ok=True)

## Load & prepare data and metadata

In [5]:
# Load the mixscape PRTB matrix of the selected condition
# (first CSV column becomes the row index, first row the column names)
path_data = os.path.join(dir_data, 'mixscape', cond, "KO150_mixscape_PRTB.csv")
data = pd.read_csv(path_data, index_col=0, header=0)
# data.columns = data.columns.str.replace(".", "-")
print(data.shape)
data.head()

(2735, 28303)


Unnamed: 0,A1_AAACCTGAGAATCTCC-1,A1_AAACCTGAGCGATGAC-1,A1_AAACCTGAGTGCCAGA-1,A1_AAACCTGAGTGGAGAA-1,A1_AAACCTGAGTTTAGGA-1,A1_AAACCTGCACATTCGA-1,A1_AAACCTGCATGCTGGC-1,A1_AAACCTGGTAATTGGA-1,A1_AAACCTGGTAGGCATG-1,A1_AAACCTGGTATATCCG-1,...,C3_TTTGCGCTCAAACCGT-1,C3_TTTGCGCTCACCAGGC-1,C3_TTTGCGCTCCAGAAGG-1,C3_TTTGGTTAGCCGATTT-1,C3_TTTGGTTAGTGAAGAG-1,C3_TTTGGTTCACCGATAT-1,C3_TTTGGTTGTCAGATAA-1,C3_TTTGGTTGTCGCTTTC-1,C3_TTTGGTTGTTCCATGA-1,C3_TTTGTCATCGCCCTTA-1
Saa3,0.055497,0.066633,0.044154,0.018713,0.069097,0.06236,0.099515,0.075316,0.136907,0.070422,...,-0.398294,1.486584,1.188525,1.194867,2.811732,1.760785,2.117486,0.590522,1.84869,1.478736
Cxcl3,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.046456,0.067339,0.122639,-0.451832,-0.748925,0.015083,0.03507,1.980204,0.620501,0.232576
S100a8,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.081128,0.139241,0.173401,1.684789,0.258586,0.148612,-3.013567,0.152865,0.742899,1.670336
Csf2,0.0,0.021037,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,-0.804523,0.007499,0.03704,0.20794,0.39678,0.169305,0.295606,0.0,0.0,0.09055
Gm5483,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.499594,-0.022193,0.318807,0.698555,0.107833,0.435877,1.176207,0.29447,0.075272,0.496964


In [6]:
# Load the raw per-cell metadata (index = first CSV column)
path_annot = os.path.join(dir_data, "KO150_raw_metadata.csv")
annot = pd.read_csv(path_annot, index_col=0, header=0)
print(annot.shape)
annot.head()

(76344, 16)


Unnamed: 0,orig.ident,nCount_RNA,nFeature_RNA,nCount_AB,nFeature_AB,nCount_gRNA,nFeature_gRNA,nCount_HTO,nFeature_HTO,hto_demux,pass_QC,batch,percent.mt,guide_call,KO_call,condition
A1_AAACCTGAGAATCTCC-1,PT149_5_A1_transcriptome,17884,4442,987,9,358,51,379,3,HTO-RAW264.7-b_untreated,True,A1,1.481771,Chd7-2,Chd7,untreated
A1_AAACCTGAGCGATGAC-1,PT149_5_A1_transcriptome,12849,3787,1378,9,1103,56,288,3,HTO-RAW264.7-b_untreated,True,A1,1.299712,Babam1-1,Babam1,untreated
A1_AAACCTGAGGATCGCA-1,PT149_5_A1_transcriptome,21634,4842,1097,11,74,54,1062,3,HTO-RAW264.7-a_untreated,True,A1,1.127854,Negative,Negative,untreated
A1_AAACCTGAGTGCCAGA-1,PT149_5_A1_transcriptome,14479,3851,458,10,838,72,292,3,HTO-RAW264.7-b_untreated,True,A1,1.20174,Mov10-4,Mov10,untreated
A1_AAACCTGAGTGGAGAA-1,PT149_5_A1_transcriptome,13765,3832,1437,9,270,117,963,3,HTO-RAW264.7-b_untreated,True,A1,1.031602,Spi1-1,Spi1,untreated


In [7]:
# Load the mixscape classification of the selected condition (index = first CSV column)
path_annot_mixscape = os.path.join(dir_data, 'mixscape', cond, "KO150_mixscape_annotations.csv")
annot_mixscape = pd.read_csv(path_annot_mixscape, index_col=0, header=0)
print(annot_mixscape.shape)
annot_mixscape.head()

(28303, 3)


Unnamed: 0,mixscape_class,mixscape_class_p_ko,mixscape_class.global
A1_AAACCTGAGAATCTCC-1,Chd7 NP,0.0,NP
A1_AAACCTGAGCGATGAC-1,Babam1 NP,0.0,NP
A1_AAACCTGAGTGCCAGA-1,Mov10 NP,0.0,NP
A1_AAACCTGAGTGGAGAA-1,Spi1 NP,0.067429,NP
A1_AAACCTGAGTTTAGGA-1,Jmjd1c NP,0.0,NP


In [8]:
# Append the mixscape columns to the raw metadata (column-wise concat, rows aligned on the index).
# Cells without a mixscape entry get NaN in the new columns; they are dropped by the filter below.
# NOTE: not idempotent - re-running this cell on the merged frame appends the mixscape columns again.
annot = pd.concat([annot,annot_mixscape], axis=1)
print(annot.shape)

(76344, 19)


In [9]:
# filter annotation by condition and perturbation classification:
# keep cells whose global mixscape class is present and is neither 'NP' nor 'NT'
mixscape_global = annot['mixscape_class.global']
keep = mixscape_global.notna() & ~mixscape_global.isin(['NP', 'NT'])

# restrict to a single condition unless all conditions are analysed together
if cond != 'all':
    keep = keep & (annot['condition'] == cond)

annot = annot.loc[keep]

print(annot.shape)

(3202, 19)


In [10]:
# keep only the data columns (cells) that survived the annotation filter, in annotation order
data = data[annot.index]
print(data.shape)

(2735, 3202)


In [11]:
# list the distinct mixscape classes left after filtering, then show how many there are
ko_classes = annot['mixscape_class'].unique()
print(ko_classes)
len(ko_classes)

['Smc1a KO' 'Ifnar1 KO' 'Sfpq KO' 'Tyk2 KO' 'Irf9 KO' 'Ep300 KO' 'Spi1 KO'
 'Ep400 KO' 'Arid1a KO' 'Ikzf1 KO' 'Dnttip2 KO' 'Brd2 KO' 'Stat2 KO'
 'Chd4 KO' 'Yeats2 KO' 'Jak1 KO' 'Dnmt1 KO' 'Med14 KO' 'Sf3b1 KO'
 'Chd8 KO' 'Runx1 KO' 'Smarca4 KO' 'Myd88 KO' 'Stat1 KO' 'Yeats4 KO'
 'Med8 KO' 'Irf3 KO' 'Ddx21 KO']


28

## Prepare Classifiers (Logistic Regression and RandomForest)

In [12]:
# Prepare data for training
# transpose so that every row is one column (cell) of `data`, then centre each feature column at zero
X = data.T.to_numpy()
X = X - X.mean(axis=0)

# class labels: the mixscape class, prefixed with the condition when all conditions are pooled
y = annot["mixscape_class"]
if cond == 'all':
    y = annot["condition"] + '+' + y

In [13]:
# list the distinct class labels and show how many there are
labels_unique = y.unique()
print(labels_unique)
len(labels_unique)

['untreated+Smc1a KO' 'untreated+Ifnar1 KO' 'untreated+Sfpq KO'
 'untreated+Tyk2 KO' 'untreated+Irf9 KO' 'untreated+Ep300 KO'
 'untreated+Spi1 KO' 'untreated+Ep400 KO' 'untreated+Arid1a KO'
 'untreated+Ikzf1 KO' 'untreated+Dnttip2 KO' 'untreated+Brd2 KO'
 'untreated+Stat2 KO' 'untreated+Chd4 KO' 'untreated+Yeats2 KO'
 'untreated+Jak1 KO' 'untreated+Dnmt1 KO' 'untreated+Med14 KO'
 'untreated+Sf3b1 KO' 'untreated+Chd8 KO' 'untreated+Runx1 KO'
 'LO28-6h+Irf9 KO' 'LO28-6h+Smarca4 KO' 'LO28-6h+Ep300 KO'
 'LO28-6h+Spi1 KO' 'LO28-6h+Arid1a KO' 'LO28-6h+Dnmt1 KO'
 'LO28-6h+Ep400 KO' 'LO28-6h+Ikzf1 KO' 'LO28-6h+Myd88 KO'
 'LO28-6h+Brd2 KO' 'LO28-6h+Runx1 KO' 'LO28-6h+Smc1a KO'
 'LO28-6h+Stat1 KO' 'LO28-6h+Yeats4 KO' 'LO28-6h+Med14 KO'
 'LO28-6h+Jak1 KO' 'LO28-6h+Med8 KO' 'LO28-6h+Dnttip2 KO'
 'LO28-6h+Ifnar1 KO' 'LO28-6h+Yeats2 KO' 'LO28-6h+Stat2 KO'
 'LO28-6h+Chd4 KO' 'LO28-6h+Irf3 KO' 'LO28-6h+Sfpq KO' 'LO28-6h+Ddx21 KO'
 'LO28-24h+Stat2 KO' 'LO28-24h+Smc1a KO' 'LO28-24h+Spi1 KO'
 'LO28-24h+T

58

In [14]:
# RANDOM FOREST w/ mostly default parameters
# Only class_weight and random_state are set explicitly; every other
# hyper-parameter keeps the library default of the installed scikit-learn version.
clf_RF = RandomForestClassifier(
    class_weight="balanced",
    random_state=42, 
)

In [17]:
# LOGISTIC REGRESSION w/ mostly default parameters
# Multinomial model with elastic-net penalty, fitted with the saga solver.
clf_LR = LogisticRegression(
    penalty="elasticnet",
    solver="saga",
    multi_class="multinomial",
    max_iter=100, # same value as the default (100)
    n_jobs=-1,
    random_state=42,
    verbose=1,
    # NOTE(review): the original note says l1_ratio & C were determined via CV,
    # but C is not passed here (commented out below) and l1_ratio is hard-coded - confirm
#         C=C, # before 0.1
    l1_ratio=0.5 # previous value was also 0.5
)

## Get connectivity matrix from leave-one-class-out (LOO) cross-prediction

In [18]:
# display the classifier flag chosen in the config cell
clf_label

'LR'

In [19]:
# Select the classifier configured via clf_label.
# An unknown flag raises immediately instead of leaving `clf` undefined
# (or silently re-using a stale `clf` from an earlier run of this cell).
if clf_label == 'RF':
    clf = clf_RF
elif clf_label == 'LR':
    clf = clf_LR
else:
    raise ValueError("unknown clf_label %r; expected 'RF' or 'LR'" % (clf_label,))
clf

LogisticRegression(l1_ratio=0.5, multi_class='multinomial', n_jobs=-1,
                   penalty='elasticnet', random_state=42, solver='saga',
                   verbose=1)

In [25]:
# # adjacency count matrix for the misclassification; rows are real ct, columns are misclassification
# conn_count = pd.DataFrame(0, index=y.unique(), columns=y.unique(), dtype=float)
# print(conn_count.shape)
# conn_count.head()

In [26]:
# # adjacency matrix for the prediction probabilities of the loo classifier; rows are real ct, columns are misclassification
# conn_prob = pd.DataFrame(0, index=y.unique(), columns=y.unique(), dtype=float)
# print(conn_prob.shape)
# conn_prob.head()

In [27]:
# OLD WAY (slow)
# for ct in y.unique():
#     # manual train/test split according to left out cell type
#     X_train = X[y != ct, :]
#     y_train = y[y != ct]
#     X_test = X[y == ct, :]
#     # train classifier leaving one class out & predict
#     clf.fit(X_train, y_train)
#     predictions = clf.predict(X_test)
#     probs = clf.predict_proba(X_test)
#     # determine probability / fraction where samples went
#     for j in range(len(predictions)):
#         conn_count.loc[ct, predictions[j]] = conn_count.loc[ct, predictions[j]] + 1
#         conn_prob.loc[ct, conn_prob.columns != ct] = conn_prob.loc[ct, conn_prob.columns != ct] + probs[j]

In [20]:
# classnames: sorted unique class labels; groups: for every sample the position of its label in classnames.
# `groups` is passed as the CV grouping below, i.e. one group per class.
classnames, groups = np.unique(y, return_inverse=True)

In [21]:
# leave-one-group-out splitter; used with groups == classes, so each fold holds out one whole class
cv = sklearn.model_selection.LeaveOneGroupOut()

In [22]:
# Cross-predict every held-out class with a classifier trained on all other classes.
# 'count' -> hard label predictions, 'prob' -> per-class probabilities.
# An unknown conn_label raises instead of leaving `pred` undefined or stale.
cv_methods = {'count': 'predict', 'prob': 'predict_proba'}
if conn_label not in cv_methods:
    raise ValueError("unknown conn_label %r; expected 'count' or 'prob'" % (conn_label,))

pred = sklearn.model_selection.cross_val_predict(
    estimator=clf, X=X, y=y, groups=groups, cv=cv, n_jobs=-1, method=cv_methods[conn_label]
)

Exception in thread Thread-4:
Traceback (most recent call last):
  File "/nobackup/lab_bock/users/sreichl/miniconda3/envs/basics/lib/python3.8/threading.py", line 932, in _bootstrap_inner
    self.run()
  File "/nobackup/lab_bock/users/sreichl/miniconda3/envs/basics/lib/python3.8/site-packages/joblib/externals/loky/process_executor.py", line 567, in run
    self.flag_executor_shutting_down()
  File "/nobackup/lab_bock/users/sreichl/miniconda3/envs/basics/lib/python3.8/site-packages/joblib/externals/loky/process_executor.py", line 756, in flag_executor_shutting_down
    self.kill_workers()
  File "/nobackup/lab_bock/users/sreichl/miniconda3/envs/basics/lib/python3.8/site-packages/joblib/externals/loky/process_executor.py", line 766, in kill_workers
    recursive_terminate(p)
  File "/nobackup/lab_bock/users/sreichl/miniconda3/envs/basics/lib/python3.8/site-packages/joblib/externals/loky/backend/utils.py", line 28, in recursive_terminate
    _recursive_terminate_without_psutil(process)
 

KeyboardInterrupt: 

In [None]:
# cross_val_predict returns a NumPy array, which has no .head(); preview the first rows by slicing
print(pred.shape)
pred[:5]

In [31]:
# generate & save raw and normalized count based connectivity matrix
# (confusion matrix of true class vs. cross-predicted class)
if conn_label == 'count':
    # common path prefix of both output files
    out_prefix = os.path.join(dir_results, "KO150_clf"+clf_label+"_conn"+conn_label)

    # raw counts
    counts_raw = sklearn.metrics.confusion_matrix(y, pred)
    conn = pd.DataFrame(counts_raw, index=classnames, columns=classnames, dtype=float)
    print(conn.shape)
    print(conn.head())
    conn.to_csv(out_prefix + ".csv")

    # counts normalized over the true class (normalize='true')
    counts_norm = sklearn.metrics.confusion_matrix(y, pred, normalize='true')
    conn_norm = pd.DataFrame(counts_norm, index=classnames, columns=classnames, dtype=float)
    print(conn_norm.shape)
    print(conn_norm.head())
    conn_norm.to_csv(out_prefix + "Norm.csv")

In [32]:
# sanity check: the row sums of the count matrix must equal the number of cells per class
if conn_label == 'count' and cond != 'all':
    cells_per_class = annot.groupby("mixscape_class").size()
    row_totals = conn.sum(axis=1).astype(int)
    assert cells_per_class.equals(row_totals)

In [33]:
# reload the raw & normalized count connectivity matrices written above
if conn_label == 'count':
    in_prefix = os.path.join(dir_results, "KO150_clf"+clf_label+"_conn"+conn_label)

    # raw counts
    conn = pd.read_csv(in_prefix + ".csv", index_col=0, header=0)
    print(conn.shape)
    print(conn.head())

    # normalized counts
    conn_norm = pd.read_csv(in_prefix + "Norm.csv", index_col=0, header=0)
    print(conn_norm.shape)
    print(conn_norm.head())

In [34]:
# Probability based connectivity matrix: mean predicted probability per
# true class (rows) and predicted class (columns), written to CSV.
if conn_label == 'prob':

    # one row per sample (labelled with its true class), one column per class
    result = pd.DataFrame(pred, columns=classnames, index=y)

    # sanity check if approach works as intended -> pred. prob for correct label has to be 0!
    # (vectorised: pick, for every row, the column of its own class)
    own_col = result.columns.get_indexer(result.index)
    own_prob = result.to_numpy()[np.arange(len(result)), own_col]
    n_bad = int((own_prob != 0).sum())
    if n_bad > 0:
        print('problem:', n_bad, 'sample(s) with non-zero probability for their own class')

    # average the probabilities over all samples of each true class in one pass
    # (numeric dtype instead of the object dtype produced by cell-wise filling)
    conn_norm = result.groupby(level=0).mean().reindex(index=classnames, columns=classnames)
    conn_norm.index.name = None  # keep the CSV header cell empty

    # sanity check if approach works as intended -> diagonal has to be 0
    if (np.diag(conn_norm.to_numpy()) != 0).any():
        print('problem')

    # save normalized probability connectivity matrix
    conn_norm.to_csv(os.path.join(dir_results, "KO150_clf"+clf_label+"_conn"+conn_label+"Norm.csv"))

In [5]:
# reload the normalized probability connectivity matrix written above
if conn_label == 'prob':
    path_conn_norm = os.path.join(dir_results, "KO150_clf"+clf_label+"_conn"+conn_label+"Norm.csv")
    conn_norm = pd.read_csv(path_conn_norm, index_col=0, header=0)
    print(conn_norm.shape)
    conn_norm

(12, 12)


### simplify & make symmetric to get similarity connectivity matrix

In [35]:
# Symmetric similarity matrix for the undirected graph:
# every entry is the mean of conn_norm[a, b] and conn_norm[b, a].
if not directed:
    # vectorised (label-aligned) instead of growing an empty DataFrame cell by cell
    conn_values = conn_norm.astype(float)
    conn_sim = (conn_values + conn_values.T) / 2
    # keep the row/column order of conn_norm
    conn_sim = conn_sim.reindex(index=conn_norm.index, columns=conn_norm.columns)
    print(conn_sim.head())

    # save similarity count connectivity matrix
    conn_sim.to_csv(os.path.join(dir_results, "KO150_clf"+clf_label+"_conn"+conn_label+"Sim.csv"))

In [36]:
# reload the similarity connectivity matrix written above (undirected case only)
if not directed:
    path_conn_sim = os.path.join(dir_results, "KO150_clf"+clf_label+"_conn"+conn_label+"Sim.csv")
    conn_sim = pd.read_csv(path_conn_sim, index_col=0, header=0)
    print(conn_sim.shape)
    conn_sim

### build, fill and visualize graph with python-igraph

In [37]:
# matrix the graph is built from: the normalized matrix for a directed graph,
# the symmetric similarity matrix for an undirected one
conn_graph = conn_norm if directed else conn_sim

print(conn_graph.index)

Index(['LO28-24h+Chd4 KO', 'LO28-24h+Ep400 KO', 'LO28-24h+Ifnar1 KO',
       'LO28-24h+Ikzf1 KO', 'LO28-24h+Irf9 KO', 'LO28-24h+Jak1 KO',
       'LO28-24h+Myd88 KO', 'LO28-24h+Sfpq KO', 'LO28-24h+Smc1a KO',
       'LO28-24h+Spi1 KO', 'LO28-24h+Stat2 KO', 'LO28-24h+Tyk2 KO',
       'LO28-6h+Arid1a KO', 'LO28-6h+Brd2 KO', 'LO28-6h+Chd4 KO',
       'LO28-6h+Ddx21 KO', 'LO28-6h+Dnmt1 KO', 'LO28-6h+Dnttip2 KO',
       'LO28-6h+Ep300 KO', 'LO28-6h+Ep400 KO', 'LO28-6h+Ifnar1 KO',
       'LO28-6h+Ikzf1 KO', 'LO28-6h+Irf3 KO', 'LO28-6h+Irf9 KO',
       'LO28-6h+Jak1 KO', 'LO28-6h+Med14 KO', 'LO28-6h+Med8 KO',
       'LO28-6h+Myd88 KO', 'LO28-6h+Runx1 KO', 'LO28-6h+Sfpq KO',
       'LO28-6h+Smarca4 KO', 'LO28-6h+Smc1a KO', 'LO28-6h+Spi1 KO',
       'LO28-6h+Stat1 KO', 'LO28-6h+Stat2 KO', 'LO28-6h+Yeats2 KO',
       'LO28-6h+Yeats4 KO', 'untreated+Arid1a KO', 'untreated+Brd2 KO',
       'untreated+Chd4 KO', 'untreated+Chd8 KO', 'untreated+Dnmt1 KO',
       'untreated+Dnttip2 KO', 'untreated+Ep300

In [39]:
# Build, draw and save one graph per edge-weight cut-off.
# Vertices are the classes of conn_graph; an edge a -> b is added when
# conn_graph.loc[a, b] exceeds the cut-off and carries that value as weight.
cut_offs = [0,0.1,0.2,0.25,0.5]

# vertex (row) and column names in matrix order, as plain lists for igraph
row_names = list(conn_graph.index)
col_names = list(conn_graph.columns)

# tab20 provides only 20 colours -> cycle through the palette so that every vertex gets a colour
# (plt.get_cmap replaces plt.cm.get_cmap, which is deprecated/removed in newer matplotlib)
palette = plt.get_cmap("tab20").colors
vertex_colors = [palette[k % len(palette)] for k in range(len(row_names))]

for cut_off in cut_offs:
    # initiate graph
    g = ig.Graph(directed=directed)
    g.add_vertices(row_names)
    g.vs["label"] = row_names

    if cond!='all':
        g.vs["num_samples"] = list(annot.groupby("mixscape_class").size()[conn_graph.index])

    # create edges; only pairs below the diagonal are visited (self-loops are skipped)
    for i, ct1 in enumerate(row_names):
        for ct2 in col_names[:i]:
            weight_fwd = conn_graph.loc[ct1, ct2]
            if weight_fwd > cut_off:
                g.add_edge(ct1, ct2, weight=float(weight_fwd))
            # the reverse direction is a separate edge only in a directed graph;
            # adding it to an undirected graph would duplicate the edge above
            if directed:
                weight_rev = conn_graph.loc[ct2, ct1]
                if weight_rev > cut_off:
                    g.add_edge(ct2, ct1, weight=float(weight_rev))

    # draw graph
    layout = g.layout("fr") # Fruchterman-Reingold algorithm
    # colors = [list(color_ct_dict[ct][:3]) for ct in conn_graph.index]

    visual_style = {}
    # actually use the layout computed above (it was previously computed but never passed to the plot)
    visual_style["layout"] = layout

    # VERTICES
    # vertex size according to no. of samples
    if cond!='all':
        visual_style["vertex_size"] = np.array(g.vs["num_samples"])/5
    visual_style["vertex_color"] = vertex_colors

    # # VERTEX LABELS
    visual_style["vertex_label"] = [s.replace(" KO", "") for s in g.vs["label"]]
    if cond=='all':
        visual_style["vertex_label"] = [s.replace("untreated", "ut") for s in visual_style["vertex_label"]]
        visual_style["vertex_label"] = [s.replace("LO28-", "") for s in visual_style["vertex_label"]]
    visual_style["vertex_label_size"] = 11
    visual_style["vertex_label_dist"] = 1.5
    visual_style["vertex_label_angle"] = 2

    # # EDGES
    # the "weight" attribute only exists once at least one edge was added,
    # so skip the width setting for graphs without edges
    if g.ecount() > 0:
        visual_style["edge_width"] = np.array(g.es["weight"]) * 10  # multiply to see difference
    visual_style["edge_arrow_size"] = 1/100 #rendering arrows invisible
    visual_style["edge_curved"] = False
#     visual_style["edge_tapered"] = True
#     visual_style["edge_drawer_factory"] = TaperedEdgeDrawer

    # GENERAL
    visual_style["margin"] = 50
    visual_style["bbox"] = (400, 400)
    if cond=='all':
        visual_style["bbox"] = (600, 600)

    # ig.plot(g, **visual_style)

    ig.plot(g, os.path.join(dir_results, "KO150_clf"+clf_label+"_graph"+conn_label+"_cutoff"+str(cut_off)+".svg"), **visual_style)