# Load libraries

In [None]:
import os
import numpy as np
import pandas as pd
import glob as glob
import matplotlib.pyplot as plt
import scipy.stats
import seaborn as sns
from sklearn import metrics as skm
import math
from collections import Counter
import networkx as nx
from matplotlib.patches import Rectangle


# Define parameters

In [None]:
# Fraction of each screen's ranked gene list (taken from the top, after sorting
# by score in descending order) that is kept as gene hits.
gene_percentile = 0.05  # empirical significance cutoff for gene hits
# Fraction used to size the per-screen list of top-ranked compounds.
drug_percentile = 0.01  # empirical significance cutoff for drug hits
# The main loop recognises 'ndc', 'betweenness', 'eigenvector' and 'closeness'.
metric = 'ndc'          # centrality metric to apply
# A compound is kept when it is a top compound in at least this many screens.
# NOTE(review): with only the two placeholder screens listed below, a threshold
# of 3 cannot be reached if each screen lists a compound at most once - adjust
# together with `screennames`.
required_number_of_screen_hits = 3  # number of screens in which compounds should be significant

# Screen names indicate .rnk files saved at the user-supplied directory.
# Each screen is expected as a tab-separated, header-less '<screenname>.rnk'
# file whose first column is the gene and second column is the score.
rnk_directory = 'rnk_folder' # directory in which screen .rnk files are stored
screennames = ['screen1','screen2'] # list of names of .rnk files within directory


In [None]:
# Tab-separated gene ID mapping table; make_full_graph() reads its 'From' and 'To' columns.
gene_dictionary = 'uniprotid_maps.txt'
# Tab-separated drug table; make_full_graph() reads its 'drugbank_id' and 'name' columns.
drug_dictionary = 'drugbank.tsv'
# Tab-separated chemical-gene edge list; make_full_graph() reads its '#Drug' and 'Gene' columns.
miner_network = 'ChG-Miner_miner-chem-gene.tsv'

# Helper functions

In [None]:
def get_screenhits(screenname, gene_percentile, verbose=False, directory=None):
    """
    Load a screen's .rnk file and return its top-ranked genes.

    The file '<directory>/<screenname>.rnk' is read as a tab-separated table
    without a header row: column 0 holds the gene, column 1 the score.
    Genes are ordered by descending score (stable sort, so ties keep their
    file order) and the leading ``int(gene_percentile * n_rows)`` genes are
    returned.

    Parameters
    ----------
    screenname : str
        File name of the screen without the '.rnk' extension.
    gene_percentile : float
        Fraction of the ranked list to keep (e.g. 0.05 for the top 5 %).
    verbose : bool
        If True, print the screen name, the number of hits and the first
        ten hits.
    directory : str or None
        Directory holding the .rnk files. Defaults to the module-level
        ``rnk_directory`` setting.

    Returns
    -------
    list
        Top-ranked genes, best first. Empty when the cutoff rounds down to 0.
    """
    # The original code read an undefined name (`rnk_folder`); the configured
    # parameter is `rnk_directory`.
    if directory is None:
        directory = rnk_directory
    path = os.path.join(directory, f'{screenname}.rnk')

    ranked = pd.read_csv(path, sep='\t', header=None)
    ranked = ranked.sort_values(by=1, ascending=False, kind='mergesort')

    genes = list(ranked[0])
    screenhits = genes[:int(gene_percentile * len(genes))]
    if verbose:
        print(screenname)
        print(len(screenhits))
        print(screenhits[0:10] + ['...'])
    return screenhits

def make_full_graph(verbose=False):
    """
    Build the chemical-gene interaction graph from the MINER edge list.

    Reads three tab-separated files named by the module-level settings
    ``gene_dictionary``, ``drug_dictionary`` and ``miner_network``,
    translates the IDs in the edge list to names, and loads the resulting
    (gene name, drug name) pairs into an undirected networkx graph.
    Nodes listed in the module-level ``exclude_from_graph`` are removed.
    Every remaining node carries a 'type' attribute: 'compound' or 'target'.

    Parameters
    ----------
    verbose : bool
        If True, print node/edge counts and the number of nodes per type.

    Returns
    -------
    G : networkx.Graph
        The full compound-target graph.
    types : dict
        Node -> 'compound' / 'target' for the nodes present in ``G``.
    all_targets : list
        All gene names found in the edge list (before exclusions).
    all_compounds : list
        All drug names found in the edge list (before exclusions).

    Raises
    ------
    KeyError
        If a gene ID in the edge list is missing from the gene dictionary,
        or if row label 13541 is not present in the edge list.
    """

    # ID -> name lookups. If a source ID occurs more than once, the last row wins.
    genemap = pd.read_csv(gene_dictionary, sep='\t')
    genemapdict = dict(zip(genemap['From'], genemap['To']))

    drugbank = pd.read_csv(drug_dictionary, sep='\t')
    drugmapdict = dict(zip(drugbank['drugbank_id'], drugbank['name']))

    # Import network.
    # The drug column is matched against `drugbank_id` values and the gene
    # column against the 'From' column of the gene dictionary.
    # NOTE(review): an earlier comment described these as STITCH chemical IDs
    # and NCBI Entrez Gene IDs, which does not match how they are used here.
    miner = pd.read_csv(miner_network, sep='\t').rename(columns={'#Drug': 'drug', 'Gene': 'gene'})
    # NOTE(review): row label 13541 is dropped unconditionally; the reason is
    # not recorded in this file (presumably a malformed row) - confirm
    # against the input file before swapping in another network.
    miner = miner.drop(index=13541)

    # Keep only edges whose drug is known; copy so that the column
    # assignments below act on an independent frame rather than on a slice.
    miner = miner[miner.drug.isin(list(drugmapdict))].copy()

    miner['genename'] = miner['gene'].apply(lambda gene_id: genemapdict[gene_id])
    miner['drugname'] = miner['drug'].apply(lambda drug_id: drugmapdict[drug_id])
    miner = miner[['genename', 'drugname']].drop_duplicates()

    # Construct Nx object
    G = nx.Graph()

    all_compounds = list(set(miner.drugname))
    all_targets = list(set(miner.genename))
    G.add_nodes_from(all_compounds + all_targets)

    # One (genename, drugname) edge per row; a stray header row is skipped.
    positive_edges = [
        edge for edge in miner.itertuples(index=False, name=None)
        if edge != ('genename', 'drugname')
    ]
    G.add_edges_from(positive_edges)

    # Materialise the list first so the graph is not mutated while iterating.
    G.remove_nodes_from([node for node in G.nodes() if node in exclude_from_graph])

    # Targets are written last, so a name that is both a compound and a
    # target ends up typed as 'target' (same precedence as before).
    node_types = {name: {'type': 'compound'} for name in all_compounds}
    node_types.update({name: {'type': 'target'} for name in all_targets})
    nx.set_node_attributes(G, node_types)

    types = nx.get_node_attributes(G, 'type')

    if verbose:
        print('Full graph:')
        print('Number of nodes:', G.number_of_nodes())
        print('Number of edges:', G.number_of_edges())
        for key, value in Counter(types.values()).items():
            print(f'Number of type {key}', value)

    return G, types, all_targets, all_compounds


def make_hits_subgraph(G_full, screenhits, remove_singletons=True, verbose=False):
    """
    Restrict the full graph to the targets that were hits in a screen.

    Works on a copy of ``G_full``; the input graph is left untouched.

    Parameters
    ----------
    G_full : networkx.Graph
        Graph whose nodes all carry a 'type' attribute
        ('compound' or 'target').
    screenhits : iterable
        Names of the genes that count as hits.
    remove_singletons : bool
        If True, also drop compounds left without any edge once the
        non-hit targets are gone.
    verbose : bool
        If True, print node/edge counts and the number of nodes per type.

    Returns
    -------
    networkx.Graph
        The hits subgraph.

    Raises
    ------
    KeyError
        If a node of ``G_full`` has no 'type' attribute.
    """
    G_hits = G_full.copy()
    node_types = nx.get_node_attributes(G_hits, 'type')

    # A set makes each membership test O(1); a list would be rescanned
    # for every target node in the graph.
    hit_set = set(screenhits)

    non_hit_targets = [
        node for node in G_full.nodes()
        if node_types[node] == 'target' and node not in hit_set
    ]
    G_hits.remove_nodes_from(non_hit_targets)

    if remove_singletons:
        # Collect first, remove afterwards: dropping an isolated node cannot
        # change the degree of any other node.
        isolated_compounds = [
            node for node in G_hits.nodes()
            if node_types[node] == 'compound' and G_hits.degree(node) == 0
        ]
        G_hits.remove_nodes_from(isolated_compounds)

    if verbose:
        types = nx.get_node_attributes(G_hits, 'type')
        print('\nHits subgraph:')
        print('Number of nodes:', G_hits.number_of_nodes())
        print('Number of edges:', G_hits.number_of_edges())
        for key, value in Counter(types.values()).items():
            print(f'Number of type {key}', value)

    return G_hits

def compute_compound_normalized_degree_centrality(G_hits):
    """
    Rank the compounds of a hits subgraph by degree centrality.

    Centrality is computed over every node of ``G_hits`` with
    ``nx.degree_centrality``; only nodes listed in the module-level
    ``all_compounds`` are kept in the result.

    Parameters
    ----------
    G_hits : networkx.Graph
        Hits subgraph.

    Returns
    -------
    pandas.DataFrame
        Indexed by compound, with a single 'score' column, sorted by
        descending score. Ties are ordered by descending index.
        Empty (but still carrying the 'score' column) when the graph has
        no compound nodes.
    """
    scores = nx.degree_centrality(G_hits)
    # Name the column at construction: for an empty score dict there is no
    # column 0 to rename, and sorting by 'score' would raise KeyError.
    ranking = pd.DataFrame.from_dict(scores, orient='index', columns=['score'])
    ranking = ranking.loc[ranking.index.isin(all_compounds)]
    # Index sort first, then a stable score sort, fixes the order of ties.
    ranking = ranking.sort_index(ascending=False)
    return ranking.sort_values(by='score', ascending=False, kind='mergesort')


def compute_compound_betweenness_centrality(G_hits):
    """
    Rank the compounds of a hits subgraph by betweenness centrality.

    Centrality is computed over every node of ``G_hits`` with
    ``nx.betweenness_centrality``; only nodes listed in the module-level
    ``all_compounds`` are kept in the result.

    Parameters
    ----------
    G_hits : networkx.Graph
        Hits subgraph.

    Returns
    -------
    pandas.DataFrame
        Indexed by compound, with a single 'score' column, sorted by
        descending score (stable, so ties keep the graph's node order).
        Empty (but still carrying the 'score' column) when the graph has
        no compound nodes.
    """
    scores = nx.betweenness_centrality(G_hits)
    # Name the column at construction: for an empty score dict there is no
    # column 0 to rename, and sorting by 'score' would raise KeyError.
    ranking = pd.DataFrame.from_dict(scores, orient='index', columns=['score'])
    ranking = ranking.loc[ranking.index.isin(all_compounds)]
    return ranking.sort_values(by='score', ascending=False, kind='mergesort')

def compute_compound_eigenvector_centrality(G_hits):
    """
    Rank the compounds of a hits subgraph by eigenvector centrality.

    Scores come from ``nx.eigenvector_centrality_numpy`` over all nodes of
    ``G_hits``; rows are then limited to the names in the module-level
    ``all_compounds``. The result is a DataFrame indexed by compound with
    one 'score' column, in descending score order (stable sort).
    """
    # Alternative left disabled in the original: nx.katz_centrality_numpy(G_hits)
    centrality = nx.eigenvector_centrality_numpy(G_hits)
    table = pd.DataFrame.from_dict(centrality, orient='index')
    is_compound = table.index.isin(all_compounds)
    table = table.loc[is_compound].rename(columns={0: 'score'})
    return table.sort_values(by='score', ascending=False, kind='mergesort')

def compute_compound_closeness_centrality(G_hits):
    """
    Rank the compounds of a hits subgraph by closeness centrality.

    Centrality is computed over every node of ``G_hits`` with
    ``nx.closeness_centrality``; only nodes listed in the module-level
    ``all_compounds`` are kept in the result.

    Parameters
    ----------
    G_hits : networkx.Graph
        Hits subgraph.

    Returns
    -------
    pandas.DataFrame
        Indexed by compound, with a single 'score' column, sorted by
        descending score (stable, so ties keep the graph's node order).
        Empty (but still carrying the 'score' column) when the graph has
        no compound nodes.
    """
    scores = nx.closeness_centrality(G_hits)
    # Name the column at construction: for an empty score dict there is no
    # column 0 to rename, and sorting by 'score' would raise KeyError.
    ranking = pd.DataFrame.from_dict(scores, orient='index', columns=['score'])
    ranking = ranking.loc[ranking.index.isin(all_compounds)]
    return ranking.sort_values(by='score', ascending=False, kind='mergesort')

# Node names to remove from the full graph; make_full_graph() reads this
# module-level list when it is called, so fill it before building the graph.
exclude_from_graph = []


# Perform RxGRID using given parameters

In [None]:
# Validate the requested metric before any expensive work. Previously an
# unrecognised value left `ranking` unassigned and only surfaced as a
# NameError at the end of the first loop iteration (or silently reused a
# stale `ranking` from an earlier notebook run).
_known_metrics = ('ndc', 'betweenness', 'eigenvector', 'closeness')
if metric not in _known_metrics:
    raise ValueError(f'Unknown metric {metric!r}; expected one of {_known_metrics}')

G_full, types, all_targets, all_compounds = make_full_graph()

# Set copies give O(1) membership tests in the per-screen statistics below.
_target_set = set(all_targets)
_compound_set = set(all_compounds)

# Number of top compounds kept per screen.
# NOTE(review): 4927 is hard-coded; presumably the number of compounds in the
# full graph at the time of writing - confirm before changing the input data.
_compound_universe_size = 4927
_num_top_drugs = int(np.ceil(_compound_universe_size * drug_percentile))

graph_stats = pd.DataFrame()   # one row per screen, one column per statistic

drug_dict = {}                 # screen name -> list of top compounds
drug_dict_info = {}            # screen name -> ranking for the selected `metric`
drug_screen_hits = pd.DataFrame()

for screenname in screennames:

    screenhits = get_screenhits(screenname, gene_percentile, verbose=True)
    G_hits = make_hits_subgraph(G_full, screenhits, verbose=True, remove_singletons=True)

    # calculate descriptive statistics
    gene_nodes = [x for x in G_hits.nodes if x in _target_set]
    drug_nodes = [x for x in G_hits.nodes if x in _compound_set]
    num_genes = len(gene_nodes)
    num_drugs = len(drug_nodes)
    num_edges = G_hits.number_of_edges()
    mean_degree = np.mean([G_hits.degree(x) for x in G_hits.nodes])
    mean_degree_genes = np.mean([G_hits.degree(x) for x in gene_nodes])
    mean_degree_drugs = np.mean([G_hits.degree(x) for x in drug_nodes])

    # Edges over the number of possible gene-compound pairs. A screen with
    # no surviving genes or compounds yields NaN instead of ZeroDivisionError.
    possible_edges = num_genes * num_drugs
    density = num_edges / possible_edges if possible_edges else np.nan
    # Fraction of hit genes that no compound in the graph targets.
    isolated_genes = sum(1 for x in gene_nodes if G_hits.degree(x) == 0)
    single_gene_frac = isolated_genes / num_genes if num_genes else np.nan

    for stat_name, stat_value in (
        ('num_genes', num_genes),
        ('num_drugs', num_drugs),
        ('num_edges', num_edges),
        ('mean_degree', mean_degree),
        ('mean_degree_genes', mean_degree_genes),
        ('mean_degree_drugs', mean_degree_drugs),
        ('density', density),
        ('single_gene_frac', single_gene_frac),
    ):
        graph_stats.loc[screenname, stat_name] = stat_value

    betweenness_ranking = compute_compound_betweenness_centrality(G_hits)
    # Not used further in this cell; still computed so the variable stays
    # available to later cells.
    closeness_ranking = compute_compound_closeness_centrality(G_hits)
    ndc_ranking = compute_compound_normalized_degree_centrality(G_hits)

    # Ranking for the user-selected metric; reuse what was computed above
    # rather than running the same centrality twice.
    if metric == 'eigenvector':
        ranking = compute_compound_eigenvector_centrality(G_hits)
    else:
        ranking = {
            'ndc': ndc_ranking,
            'betweenness': betweenness_ranking,
            'closeness': closeness_ranking,
        }[metric]

    # Top compounds are always chosen by degree centrality, with betweenness
    # and then compound name breaking ties - independent of `metric`.
    agg_ranking = pd.merge(ndc_ranking, betweenness_ranking, left_index=True, right_index=True,
                           suffixes=('_ndc', '_btw'))

    # Successive stable sorts: the last key applied is the primary one.
    agg_ranking = agg_ranking.sort_index(ascending=True)
    agg_ranking = agg_ranking.sort_values(by='score_btw', ascending=False, kind='mergesort')
    agg_ranking = agg_ranking.sort_values(by='score_ndc', ascending=False, kind='mergesort')
    top_drugs = list(agg_ranking.index[:_num_top_drugs])

    drug_dict[screenname] = top_drugs

    drug_dict_info[screenname] = ranking



In [None]:
# Pool the top compounds of every screen into one flat list (screen order kept).
drug_hits_combined = [drug for hits in drug_dict.values() for drug in hits]

# Number of screens in which each compound was a top hit, and the same
# tallies ordered from most to least frequent.
drug_hits_count = Counter(drug_hits_combined)
drug_hits_sorted = drug_hits_count.most_common()

# Compounds that reach the required number of screens.
collective_hits = [
    drug
    for drug, n_screens in drug_hits_count.items()
    if n_screens >= required_number_of_screen_hits
]

print(len(collective_hits))
print(collective_hits)


In [None]:
# Report each graph statistic across screens as "mean (standard deviation)".
for label, data in graph_stats.items():
    print(label+':')
    print(f'{np.mean(data)} ({np.std(data)})\n')


In [None]:
# Compounds to leave out of the final list (none by default).
exclude_from_downstream = []

# Cross-screen hits minus the manual exclusions, original order preserved.
final_drugs = list(filter(lambda drug: drug not in exclude_from_downstream, collective_hits))

In [None]:
# Show the cross-screen compound hits that remain after the manual exclusions.
print(final_drugs)