# Assessment of high-influence nodes in boolean models

This notebook presents a study that identifies influential nodes in a predictive model by fixing or inverting their activity states (knock-in, KI, and knock-out, KO) and checking how the model predictions change compared to the wild-type (WT) model.

We test this approach on four models, one for each of the following cell lines: AGS, SW-620, COLO 205 and DU-145.

For each cell line, a logical model was manually built and simulated to predict the effects of combinations of 18 single drugs. Predictions were tested against our experimental drug screen (Flobak et al., Scientific Data, 2019).

In this study, we tested the influence of each single node in the network on the predictions for the tested drug combinations. To do so, we applied a perturbation to a single node in addition to the drug target perturbations. Each perturbation combined a fixed or an inverted node state with either a single drug target perturbation or a combination of two drug target perturbations. This implies $2\ (\text{fixed or inverted}) \times 144\ (\text{nodes}) \times 171\ (\text{drug perturbations})$ mutations to test per cell line.
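
The arithmetic behind this count can be checked with a short sketch. The values (18 drugs, 144 nodes, two mutation modes) are taken from the text; the variable names are illustrative only and are not part of `utils.py`.

```python
from itertools import combinations

# Values taken from the text: 18 single drugs, 144 nodes, 2 mutation modes.
n_drugs = 18
n_nodes = 144
mutation_modes = ("fixed", "inverted")

# Every unordered pair of distinct drugs is one double perturbation.
n_pairs = len(list(combinations(range(n_drugs), 2)))

# Single-drug and double-drug perturbations together.
n_perturbations = n_drugs + n_pairs

# One simulation per (mode, node, perturbation) triple.
n_simulations = len(mutation_modes) * n_nodes * n_perturbations

print(n_pairs, n_perturbations, n_simulations)
```

With 153 drug pairs and 18 single drugs this gives 171 perturbations, hence 49248 simulations per cell line.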

Once the mutation is set in the model, we compute the stable states using the bioLQM library. From the obtained stable states, we check whether the predicted drug synergies agree with the experimentally observed synergies and whether the predictions are altered compared to the WT model.

---
**NOTE**

To simplify the notebook, some functions are defined in `utils.py`  and called in this notebook. 

---


In [None]:
import sys
import csv
import biolqm
import pandas as pd
import matplotlib.pyplot as plt
import itertools
import seaborn as sns
from sklearn.metrics.pairwise import euclidean_distances
from sklearn.metrics import accuracy_score
from mpl_toolkits import mplot3d
from scipy import spatial
import numpy as np
import utils

# Set color palette for all plots
c_map = sns.cubehelix_palette(dark=0, light=0.8, as_cmap=True)

## I Data import

* Import the combinations of drugs with their targets (nodes) to be knocked out.
* Import the experimental synergy data for each cell line (HSA synergy).
* Import the four cell line models with bioLQM.

---
**NOTE**

`cellLines` should contain a list of tuples, where each tuple contains the name of the cell line as a string (identical to the column headers in all data) and its GINML model.

---

In [None]:
# Import combination of drugs
file_comb_drugs = '../data/perturbations.txt'
with open(file_comb_drugs, 'r') as f:
    comb_drugs = list(csv.reader(f, delimiter='\t'))

# Import drugs and their targets - header is removed
file_drug_targets = '../data/drugpanel.txt'
with open(file_drug_targets, 'r') as f:
    drug_targets = list(csv.reader(f, delimiter='\t'))[1:]

# Combination of drugs and ko of their targets
perturbations = utils.get_perturbations(comb_drugs, drug_targets)
print('Combinations of drugs loaded in \'perturbations\'...')

# Free the memory
del comb_drugs
del drug_targets

# Get the number of perturbations and the number of drugs for later analysis
nb_perturbations = utils.get_nb_perturbations(perturbations)
nb_drugs = utils.get_nb_drugs(perturbations)  

# HSA synergy = experimental data stating, for each cell line, whether a combination of drugs is synergistic or not
file_hsa_synergy = '../data/20190205_Synergy_HSA.txt'
hsa_synergy = pd.read_csv(file_hsa_synergy, delimiter='\t')
print('Experimental HSA synergy data imported...')

# Models to analyse
file_ags = "../data/models/AGS_refined-model.zginml"
file_colo205 = "../data/models/COLO205_refined-model.zginml"
file_du145 = "../data/models/DU145_refined-model.zginml"
file_sw620 = "../data/models/SW620_refined-model.zginml"

# Load models with bioLQM
ags_model = biolqm.load(file_ags)
colo205_model = biolqm.load(file_colo205)
du145_model = biolqm.load(file_du145)
sw620_model = biolqm.load(file_sw620)

print('Models imported as ginml and loaded with biolqm...')

# Create a list of cell lines with the bioLQM model, the name of the cell line
cellLines = [('AGS', ags_model), ('COLO205', colo205_model), ('DU145', du145_model), ('SW620', sw620_model)]

# Number of double combinations = 153
nb_double_comb = nb_perturbations - nb_drugs

# Repositories name
predictions = "results/predictions/"
simulations = "results/simulationStates/"
randomF = "results/randomForest/"
figures = "results/figures/"

## II Analysis of the Wild-Type models

### 1. Compute stable states with bioLQM (time consuming)
Stable states are called fixpoints in bioLQM. If a stable state is found, the growth of the cell is calculated:

$Growth = Prosurvival - Antisurvival$.

This is necessary to predict the synergistic effects of a drug combination.

The results (mutation with inverted/fixed node, drugs - growth - nb_fixpoints) are stored in `results/simulationStates/`*`model`*`_WT_growth.txt`

---

**Note** 

If multiple stable states are found, the mean of the Prosurvival & the mean of the Antisurvival are calculated to provide a unique Growth value.

---
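
A minimal sketch of this averaging rule, assuming each stable state is available as a dictionary holding the values of the `Prosurvival` and `Antisurvival` output nodes. This layout and the function name `compute_growth` are hypothetical; the actual computation is done inside `utils.py`.

```python
from statistics import mean


def compute_growth(stable_states):
    """Return the Growth value for a list of stable states, or None if the list is empty.

    Each stable state is assumed to be a dict with 'Prosurvival' and
    'Antisurvival' entries (hypothetical layout, for illustration only).
    """
    if not stable_states:
        # No fixpoint found: growth cannot be computed.
        return None
    # Average each output node over all stable states before subtracting.
    prosurvival = mean(state["Prosurvival"] for state in stable_states)
    antisurvival = mean(state["Antisurvival"] for state in stable_states)
    return prosurvival - antisurvival


states = [
    {"Prosurvival": 3, "Antisurvival": 1},
    {"Prosurvival": 1, "Antisurvival": 1},
]
print(compute_growth(states))
```

In this toy example the mean Prosurvival is 2 and the mean Antisurvival is 1, so Growth is 1.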


In [None]:
print('This may take some time, please be patient...')
for name, model in cellLines:
    file_fxpts = simulations + name + '_WT_growth.txt'
    # The context manager closes the output file once the stable states are written
    with open(file_fxpts, 'w+') as fixpoints:
        utils.get_stable_states_wt(fixpoints, model, perturbations)
    print('Computed stable states for ' + name)
print('Done.')

### 2. Analysis of stable states with experimental HSA synergy data 
Classify the prediction into different categories by defining whether the prediction is a __true positive__ (TP), __true negative__ (TN), __false positive__ (FP) or __false negative__ (FN) result based on the experimental HSA synergy data.


These metrics will later be compared with those obtained for the mutated node states.
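
The four categories follow the usual confusion-matrix definitions. As an illustrative sketch (the function `classify_prediction` is hypothetical and not taken from `utils.py`), a predicted and an observed synergy call map to a category as follows:

```python
def classify_prediction(predicted_synergy, observed_synergy):
    """Map a (predicted, observed) pair of booleans to TP, TN, FP or FN."""
    if predicted_synergy and observed_synergy:
        return "TP"  # synergy predicted and observed
    if not predicted_synergy and not observed_synergy:
        return "TN"  # no synergy predicted, none observed
    if predicted_synergy:
        return "FP"  # synergy predicted but not observed
    return "FN"      # synergy observed but not predicted


for predicted, observed in [(True, True), (False, False), (True, False), (False, True)]:
    print(predicted, observed, classify_prediction(predicted, observed))
```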

In [None]:
# Compute predictions: synergy or not? Compare with HSA synergy data
print('Computation of predictions...')

phenotype = 'WT'

for name, _ in cellLines:
    # Read fixpoints to compute predictions
    file_fxpts = simulations + name + '_WT_growth.txt'
    file_pred = predictions + name + '_WT.txt'
    synergy = hsa_synergy[['Combination', name]]
    # Write down the predictions
    utils.get_predictions(file_fxpts, file_pred, synergy, nb_perturbations, nb_drugs, phenotype, name)
print('Done.')

## III Analysis of fixed nodes activities and inverted nodes activities models

### 1. Compute stable states for fixed nodes and inverted nodes activities (time consuming)

Import the activities of each node in specific cell lines (__node_activities__) and invert their activity (__inverted_node_activities__) in order to apply a mutation to the models.

For the models with an inverted node activity:
* If a node is active in a cell line model, we do a knock-out
* If a node is inactive in a cell line model, we do a knock-in (ectopic mutation)

It is the other way around for the study of fixed node activities: an active node is fixed as active (knock-in) and an inactive node is fixed as inactive (knock-out).
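
The rule can be summarised in a few lines. This is a sketch only: the function `mutation_for` and the mode names are hypothetical and are not part of `utils.py`.

```python
def mutation_for(activity, mode):
    """Return 'KI' or 'KO' for a node with measured activity 0 or 1.

    mode is 'fixed' (keep the measured state) or 'inverted' (flip it).
    """
    if mode not in ("fixed", "inverted"):
        raise ValueError("mode must be 'fixed' or 'inverted'")
    # The state imposed on the node: measured state, or its opposite.
    target_state = activity if mode == "fixed" else 1 - activity
    return "KI" if target_state == 1 else "KO"


for mode in ("fixed", "inverted"):
    for activity in (0, 1):
        print(mode, activity, mutation_for(activity, mode))
```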

In [None]:
# Import file with node activities for each model
node_activity_file = '../data/20180828_Ginsim_node_activity.txt'
fixed_nodes_activities = pd.read_csv(node_activity_file, sep='\t', index_col=0)

# Invert activity of nodes: 1 becomes 0 and 0 becomes 1
inverted_nodes_activities = 1 - fixed_nodes_activities

Similarly to the study for the WT models, we define a function to automatically:

* Compute the stable states for each mutation
    * Fix the node activities as measured in GINsim (apply a mutation) and compute the stable states for each perturbation
    * Invert the node activities measured and do the same process as in step one.
* Get Prosurvival node and Antisurvival node values
* Calculate Growth: $Growth = Prosurvival - Antisurvival$
* If multiple stable states are found, get the mean of the Prosurvival & the mean of the Antisurvival to calculate a unique Growth value.
* Store the mutation (mutated node - targets KO - drugs) - growth - nb_fixpoints - stablestate in the results folder, for example: `results/simulationStates/AGS_fixed_growth.txt`.





---
**Note**

This is a time-consuming computation. The results are already computed and available (look for `results/simulationStates/`*`model`*`_fixed_growth.txt` or `results/simulationStates/`*`model`*`_inverted_growth.txt`). If you don't want to run the analysis, you can go directly to the section `Analysis of stable states with experimental HSA synergy data`.

---

In [None]:
print('This will take some time, please be patient...')
for name, model in cellLines:
    print('Computing fixed node activities mutations for ' + name)
    file_fixed_fxpts = simulations + name + '_fixed_growth.txt'
    utils.get_stable_states_mut(model, file_fixed_fxpts, fixed_nodes_activities[[name]], perturbations)

    print('Computing inverted node activities mutations for ' + name)
    file_inverted_fxpts = simulations + name + '_inverted_growth.txt'
    utils.get_stable_states_mut(model, file_inverted_fxpts, inverted_nodes_activities[[name]], perturbations)

### 2. Analysis of stable states with experimental HSA synergy data
Define a function that goes through the resulting stable states of each cell line to:
* Split the data into chunks of stable state results, one chunk per mutated node: each mutated node has 171 perturbations computed.
* For each mutated model, find the drug synergies by comparing the single drugs *(AK, BI)* with the drug combination *(AK-BI)*
    * Create a triplet of information on both the single drug and the double drug KOs and go through this list
    * Get the corresponding experimental HSA synergy data
    * If fixpoints are missing, define what is missing:
        * __none__: when neither the single drugs nor the combination find a fixpoint
        * __double__: when both single drugs don't find a fixpoint but the combination does
        * __single__: when the combination doesn't find a fixpoint but both single drugs do
        * __either__: when either one of the single drugs or the combination doesn't find a fixpoint
    * If fixpoints are found: 
        * Define whether the prediction is a __true positive__ (TP), __true negative__ (TN), __false positive__ (FP) or __false negative__ (FN) result based on the HSA synergy data.
        
The predictions are stored in a separate file for each cell line.
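
The decision steps listed above can be sketched as follows. This is a hedged illustration: the function names are hypothetical, a missing fixpoint is represented by `None`, and the synergy rule (the combination reduces growth below the best single drug, in the spirit of the highest single agent model) is an assumption, since the exact rule lives in `utils.get_predictions`.

```python
def missing_category(growth_a, growth_b, growth_ab):
    """Name what is missing for a (single A, single B, combination) triplet.

    A growth of None means that no fixpoint was found. Returns None when
    all three simulations found a fixpoint.
    """
    singles_missing = [growth_a is None, growth_b is None]
    combo_missing = growth_ab is None
    if not any(singles_missing) and not combo_missing:
        return None
    if all(singles_missing) and combo_missing:
        return "none"    # nothing found a fixpoint
    if all(singles_missing):
        return "double"  # only the combination found a fixpoint
    if not any(singles_missing):
        return "single"  # only the single drugs found a fixpoint
    return "either"      # any other partial pattern


def predict_synergy(growth_a, growth_b, growth_ab):
    """Assumed rule: synergy if the combination lowers growth below both single drugs."""
    return growth_ab < min(growth_a, growth_b)


print(missing_category(None, None, 1.0))
print(missing_category(0.5, 1.0, 0.0), predict_synergy(0.5, 1.0, 0.0))
```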


In [None]:
print('Computation of predictions... this may take some time')
# About 4 min per cell line

phenotype = 'mutant'

for name, _ in cellLines:
    print(name)
    # Read both fixed and inverted fixpoints
    file_fixed_fxpts = simulations + name + '_fixed_growth.txt'
    file_inv_fxpts = simulations + name + '_inverted_growth.txt'

    # Get the synergy column of the cell line of interest
    synergy = hsa_synergy[['Combination', name]]

    # Output files: store predictions
    file_inv_pred = predictions + name + '_inverted.txt'
    file_fixed_pred = predictions + name + '_fixed.txt'

    # Compute predictions of fixed and inverted models
    utils.get_predictions(file_inv_fxpts, file_inv_pred, synergy, nb_perturbations, nb_drugs, phenotype, name)
    utils.get_predictions(file_fixed_fxpts, file_fixed_pred, synergy, nb_perturbations, nb_drugs, phenotype, name)

print('Done.')

### 3. Comparison of results between WT and mutants
We now classify nodes according to their importance in the different models. 
The importance, or influential character, of a node is assessed by whether fixing (or inverting) its activity changes the predictions compared to the WT analysis. 

First, we classify the observations into four categories, from the WT to the mutant:
- gain of TP
- gain of TN
- loss of TP
- loss of TN
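
One way to read these four categories is sketched below. The function `classify_change` is hypothetical and the exact bookkeeping is an assumption; the notebook delegates it to `utils.get_classification_gain_loss`. A gain means that the mutant reaches a correct label (TP or TN) that the WT model did not have, and a loss means that a correct WT label is no longer obtained in the mutant.

```python
from collections import Counter


def classify_change(wt_label, mutant_label):
    """Return the gain/loss category for one drug combination, or None if unchanged."""
    if wt_label == mutant_label:
        return None
    if mutant_label in ("TP", "TN"):
        return "gain of " + mutant_label
    if wt_label in ("TP", "TN"):
        return "loss of " + wt_label
    return None


# Toy (WT label, mutant label) pairs for one mutated node.
pairs = [("FN", "TP"), ("TP", "FN"), ("TN", "TN"), ("FP", "TN")]
changes = (classify_change(wt, mutant) for wt, mutant in pairs)
counts = Counter(change for change in changes if change is not None)
print(counts)
```

Summing such counts per mutated node gives one row of the gain/loss table for that node.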



In [None]:
for name, _ in cellLines:
    # Load prediction files to read
    file_inv_pred = predictions + name + '_inverted.txt'
    file_fix_pred = predictions + name + '_fixed.txt'
    file_WT_pred = predictions + name + '_WT.txt'

    # Write classification of gain and loss - return classification file name
    file_classification_fix = utils.get_classification_gain_loss(file_fix_pred, file_WT_pred, nb_double_comb)
    file_classification_inv = utils.get_classification_gain_loss(file_inv_pred, file_WT_pred, nb_double_comb)

    # Create heatmaps
    utils.get_classification_heatmap(file_classification_fix, c_map)
    utils.get_classification_heatmap(file_classification_inv, c_map)

In [None]:
# Get a list of all the nodes
nodes_list = []
with open('results/AGS_fixed_classification_gainloss.txt') as csvFile:
    reader = csv.reader(csvFile, delimiter = '\t')
    next(reader)
    for line in reader:
        nodes_list.append(line[0][:-2])

# Create an empty dataframe with row names corresponding to the nodes of the model
influentialNodes = pd.DataFrame(index = nodes_list)

def read_classification(model, mutation):
    """Read the gain/loss classification file of a model ('fixed' or 'inverted') as a list of rows."""
    path = 'results/' + model + '_' + mutation + '_classification_gainloss.txt'
    with open(path, 'r') as handle:
        return list(csv.reader(handle, delimiter='\t'))

def set_influential_node(model):
    """Flag a node '1' if columns 1 to 5 of its fixed and inverted rows do not sum to zero, else '0'."""
    mod_f = read_classification(model, 'fixed')
    mod_i = read_classification(model, 'inverted')
    list_influence = []
    # Row 0 is the header
    for i in range(1, len(mod_f)):
        total = sum(map(int, mod_f[i][1:6])) + sum(map(int, mod_i[i][1:6]))
        list_influence.append('0' if total == 0 else '1')
    return list_influence

def rank_per_cell(model):
    """Rank each node by the sum of columns 1 to 4 of its fixed and inverted rows."""
    mod_f = read_classification(model, 'fixed')
    mod_i = read_classification(model, 'inverted')
    list_ranked = []
    # Row 0 is the header
    for i in range(1, len(mod_f)):
        list_ranked.append(sum(map(int, mod_f[i][1:5] + mod_i[i][1:5])))
    return list_ranked
    
influentialNodes['AGS'] = set_influential_node('AGS')
influentialNodes['AGS_rank'] = rank_per_cell('AGS')
influentialNodes['COLO205'] = set_influential_node('COLO205')
influentialNodes['COLO205_rank'] = rank_per_cell('COLO205')
influentialNodes['SW620'] = set_influential_node('SW620')
influentialNodes['SW620_rank'] = rank_per_cell('SW620')
influentialNodes['DU145'] = set_influential_node('DU145')
influentialNodes['DU145_rank'] = rank_per_cell('DU145')

# A node is influential in all cells if it is flagged in every cell line,
# and in any cell if it is flagged in at least one cell line
cell_names = ['AGS', 'COLO205', 'SW620', 'DU145']
flags = influentialNodes[cell_names].astype(int)
influentialNodes['allCells'] = flags.all(axis=1).astype(int).astype(str)
influentialNodes['anyCells'] = flags.any(axis=1).astype(int).astype(str)

influentialNodes.to_csv('results/Classification_influential_nodes.txt', sep='\t')

### 4. Classification of influential nodes based on Random Forest

A random forest analysis was performed, taking into account biological and network features, to assess whether any feature would allow a node to be classified as influential or not. 
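
The cell below only loads precomputed random forest results. To illustrate the kind of output it reads, here is a minimal sketch of how feature importances can be obtained with scikit-learn on synthetic data. The use of scikit-learn, the feature names and the data are all assumptions made for illustration; they are not the pipeline that produced `importance.tsv`.

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier

# Synthetic data: 200 "nodes" described by three made-up features.
rng = np.random.default_rng(0)
feature_names = ["feature_a", "feature_b", "feature_c"]
X = rng.random((200, 3))

# Toy label: a node is "influential" when feature_a is high, so only
# feature_a carries information about the class.
y = (X[:, 0] > 0.5).astype(int)

forest = RandomForestClassifier(n_estimators=200, random_state=0)
forest.fit(X, y)

# Impurity-based importances, normalised to sum to 1.
importances = dict(zip(feature_names, forest.feature_importances_))
for name, value in sorted(importances.items(), key=lambda item: item[1], reverse=True):
    print(f"{name}: {value:.3f}")
```

A feature that separates influential from non-influential nodes receives a high importance, which is what the heatmaps below display for each data set.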


In [None]:
file_rf_results = randomF + 'importance.tsv'
rf_results = pd.read_csv(file_rf_results, sep='\t', header = 1, index_col=0)

file_error = randomF + 'batch'
error = pd.read_csv(file_error, sep = '\t', header= 1, index_col=0)

def keep_cols(DataFrame, keep_these):
    """Keep only the columns [keep_these] in a DataFrame, delete all other columns."""
    drop_these = list(set(list(DataFrame)) - set(keep_these))
    return DataFrame.drop(drop_these, axis = 1)

# fillna returns a new DataFrame, so the result must be assigned back
rf_results = rf_results.fillna(value=0)

l_imp = []
for val in error.iterrows():
    if(val[1]['Error'] < 0.4):
        l_imp.append(val[0])
    else:
        print(val)

df = keep_cols(rf_results, l_imp)

def filterDataInterest(index, filter_name):
    column_renaming = ["AGS", "COLO205", "SW620", "DU145", "allCells", "anyCells"]
    data = df[df.filter(like=filter_name).columns]
    if isinstance(index, list): 
        data = data.iloc[index[0]:index[1], :]
    else:
        data = data.iloc[index:, :]
    data.columns = column_renaming
    return data

'''Balance data - numeric features'''
filter_col = [0,13]
df_num_balanced = filterDataInterest(filter_col, "balance_data-numeric_features.txt")
heatmap_n = sns.clustermap(df_num_balanced, figsize=(7,5), col_cluster=False, cmap = c_map)
plt.setp(heatmap_n.ax_heatmap.yaxis.get_majorticklabels(), rotation=0)
heatmap_n.savefig(figures+'cluster_RF_balanced_numeric_feature.svg')

'''Non-balanced data - numeric features'''
df_num_nonbalanced = filterDataInterest(filter_col, "non_balanced_data-numeric_features.txt")
heatmap_nb = sns.clustermap(df_num_nonbalanced, figsize=(7,5), col_cluster=False, cmap = c_map)
plt.setp(heatmap_nb.ax_heatmap.yaxis.get_majorticklabels(), rotation=0)
heatmap_nb.savefig(figures+'cluster_RF_non-balanced_numeric_feature.svg')

'''Balanced data - biological features'''
filter_col = 14
df_binary_balanced = filterDataInterest(filter_col, "balance_data-binary_all.txt")
heatmap_b = sns.clustermap(df_binary_balanced, figsize=(20,30), col_cluster = False, cmap = c_map, yticklabels=True)
heatmap_b.savefig(figures+'cluster_RF_balanced_bio_feature.svg')

''' Non-balanced data - biological features'''
df_binary_nonbalanced = filterDataInterest(filter_col, "non_balanced_data-binary_all.txt")
heatmap_bnb = sns.clustermap(df_binary_nonbalanced, figsize=(20,30), col_cluster = False, cmap = c_map, yticklabels=True)
heatmap_bnb.savefig(figures+'cluster_RF_non_balanced_bio_feature.svg')

From the heatmaps above, only network features seem to distinguish the nodes and classify them into two categories: PCI, closeness centrality and betweenness centrality. No biological feature appears to stand out.

#### Network features combination 

We would now like to see whether a combination of these three features allows the nodes to be sorted completely.
The following cell generates 3D scatter plots in which nodes are positioned according to their PCI, betweenness centrality and closeness centrality, in the different cell lines.
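
As a reminder of what one of these features measures, the sketch below computes closeness centrality on a toy undirected graph using the common definition (number of reachable nodes divided by the sum of shortest-path distances to them). The graph, the function name and the choice of definition are assumptions for illustration; the values in `node_features.txt` were computed beforehand and may follow a different convention, for instance for directed edges.

```python
from collections import deque


def closeness_centrality(adjacency, source):
    """Closeness of `source`: reachable nodes divided by the sum of their distances."""
    # Breadth-first search gives shortest-path distances in an unweighted graph.
    distances = {source: 0}
    queue = deque([source])
    while queue:
        node = queue.popleft()
        for neighbour in adjacency[node]:
            if neighbour not in distances:
                distances[neighbour] = distances[node] + 1
                queue.append(neighbour)
    total = sum(distances.values())
    if total == 0:
        # Isolated node: no other node is reachable.
        return 0.0
    return (len(distances) - 1) / total


# Toy path graph A - B - C.
graph = {"A": ["B"], "B": ["A", "C"], "C": ["B"]}
for node in graph:
    print(node, round(closeness_centrality(graph, node), 3))
```

The central node B is one step away from every other node and gets the highest closeness, while the end nodes A and C score lower.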

In [None]:
def plot_scatter_3d(x, y, z, points, c_map, title, outputfile): 
    # Passing projection to gca() is no longer supported in recent Matplotlib releases;
    # add_subplot(projection='3d') creates the 3D axes instead
    fig = plt.figure(figsize=(10,8))
    ax = fig.add_subplot(projection='3d')
    scatter = ax.scatter(x, y, z, c = points, cmap = c_map)
    ax.set_xlabel('PCI')
    ax.set_ylabel('Betweenness centrality')
    ax.set_zlabel('Closeness centrality')
    ax.set_title(title)
    clb = fig.colorbar(scatter)
    clb.ax.set_title('Important node')
    fig.savefig(outputfile, dpi = 300)
    plt.show()
    

file_node_features = '../data/features/node_features.txt'
file_influential_node = 'results/Classification_influential_nodes.txt'

df_node_features = pd.read_csv(file_node_features, sep='\t', header = 0, index_col=0)
df_influential_nodes = pd.read_csv(file_influential_node, sep = '\t', header = 0, index_col = 0)
df = pd.concat([df_node_features, df_influential_nodes], axis=1, sort=False)

# Plot for the different cell lines, colouring nodes by their rank in that cell line
for name, _ in cellLines:
    print(name.lower())
    plot_scatter_3d(df['PCI'], df['BetweennessCentrality'], df['ClosenessCentrality'], df[name + '_rank'], c_map,
                    'Scatter plot of nodes ranked on their importance in ' + name,
                    figures + 'scatter_3dplot_features_' + name.lower() + '.svg')

# Plot for all cell lines
plot_scatter_3d(df['PCI'], df['BetweennessCentrality'], df['ClosenessCentrality'], df['allCells'], c_map,
                'Scatter plot of nodes classified as important (1) or not (0) in all cells',
                figures + 'scatter_3dplot_features_allCells.svg')

# Plot for any of the cell lines
plot_scatter_3d(df['PCI'], df['BetweennessCentrality'], df['ClosenessCentrality'], df['anyCells'], c_map,
                'Scatter plot of nodes classified as important (1) or not (0) in any cells',
                figures + 'scatter_3dplot_features_anyCells.svg')