## STEP 0: IMPORT LIBRARIES AND VARIABLES

In [1]:
import json
import math
import random

import pandas
import requests
import networkx as nx
import plotly.graph_objects as go
from gseapy import enrichr

from indra.statements import stmts_from_json
from indra.databases import uniprot_client, hgnc_client
from indra.assemblers.html import HtmlAssembler

In [2]:
# Path of the input CSV file. The file is read by construct_df, which uses
# its 'Protein', 'Label', 'adj.pvalue' and 'issue' columns.
DATASET_PATH = "model.csv" # Set this path yourself

## STEP 1: FETCH DATA FROM INDRA

In [3]:
def uniprot_to_hgnc(uniprot_mnemonic):
    """Translate a UniProt mnemonic into an HGNC ID.

    The mnemonic is first resolved to a UniProt ID; when that lookup yields
    nothing (a falsy result), None is returned. Otherwise the result of the
    UniProt-ID-to-HGNC lookup is returned as is.
    """
    resolved_id = uniprot_client.get_id_from_mnemonic(uniprot_mnemonic)
    if not resolved_id:
        return None
    return uniprot_client.get_hgnc_id(resolved_id)

In [4]:
# Filter dataset
LABELS_FILTER = ["DMSO-DbET6"]  # values of the 'Label' column to keep
P_VALUE_FILTER = 0.05           # rows must have 'adj.pvalue' strictly below this

def construct_df(filename):
    """Return a filtered data frame from the given data file.

    Rows are kept only when all of the following hold:
      * 'adj.pvalue' is below P_VALUE_FILTER,
      * 'issue' is null,
      * 'Label' is one of LABELS_FILTER,
      * the 'Protein' mnemonic maps to an HGNC ID (stored in a new 'HGNC'
        column).

    Parameters
    ----------
    filename : path or file-like accepted by pandas.read_csv
        CSV file with 'Protein', 'Label', 'adj.pvalue' and 'issue' columns.

    Returns
    -------
    pandas.DataFrame
        The filtered rows, with the extra 'HGNC' column.
    """
    raw_df = pandas.read_csv(filename)
    keep_mask = (
        (raw_df['adj.pvalue'] < P_VALUE_FILTER)
        & raw_df['issue'].isnull()
        & raw_df['Label'].isin(LABELS_FILTER)
    )
    # Take an explicit copy: adding a column to a boolean-indexed slice
    # triggers pandas' SettingWithCopyWarning.
    filtered_df = raw_df[keep_mask].copy()
    # The mapping is only applied to rows that survived the filters above.
    filtered_df['HGNC'] = filtered_df['Protein'].apply(uniprot_to_hgnc)
    return filtered_df[filtered_df['HGNC'].notnull()]

In [5]:
def get_protein_groundings(df):
    """Return all HGNC IDs derived from the data frame.

    Parameters
    ----------
    df : pandas.DataFrame
        Data frame with an 'HGNC' column.

    Returns
    -------
    list of tuple
        Sorted, de-duplicated ('HGNC', <id>) pairs, one per distinct value
        in the 'HGNC' column.
    """
    # Iterate the column directly: it avoids building one Series per row and
    # keeps each ID's own type (iterrows coerces a row to a common dtype).
    return sorted({('HGNC', hgnc_id) for hgnc_id in df['HGNC']})

In [6]:
def query_indra_subnetwork(groundings, timeout=None):
    """Return a list of INDRA subnetwork relations for a list of groundings.

    Parameters
    ----------
    groundings : list
        Groundings to send as the 'nodes' field of the JSON request body.
    timeout : float or tuple, optional
        Timeout forwarded to requests.post. The default (None) waits
        indefinitely, as before.

    Returns
    -------
    The decoded JSON body of the response.

    Raises
    ------
    requests.HTTPError
        If the service answers with a 4xx/5xx status code.
    """
    res = requests.post(
        'https://discovery.indra.bio/api/indra_subnetwork_relations',
        json={'nodes': groundings},
        timeout=timeout,
    )
    # Fail here on an HTTP error rather than passing an error payload (or a
    # JSON decoding failure) on to the cells that consume the relations.
    res.raise_for_status()
    return res.json()

In [7]:
# Run the fetch pipeline: load and filter the dataset, derive the HGNC
# groundings from it, then request the INDRA relations for those groundings.
pandas_df = construct_df(DATASET_PATH)
groundings = get_protein_groundings(pandas_df)
subnetwork_relations = query_indra_subnetwork(groundings)

In [8]:
# Show a single relation record as an example of the returned structure.
subnetwork_relations[4]

{'data': {'belief': 0.65,
  'evidence_count': 1,
  'has_database_evidence': False,
  'has_reader_evidence': True,
  'has_retracted_evidence': False,
  'medscan_only': False,
  'source_counts': '{"reach": 1}',
  'sparser_only': False,
  'stmt_hash': 24950175850600212,
  'stmt_json': '{"type": "DecreaseAmount", "subj": {"name": "BRD4", "mods": [{"mod_type": "modification", "is_modified": true}], "db_refs": {"UP": "O60885", "HGNC": "13575", "TEXT": "Brd4", "EGID": "23476"}}, "obj": {"name": "BRD2", "db_refs": {"UP": "P25440", "HGNC": "1103", "TEXT": "BRD2", "EGID": "6046"}}, "belief": 0.65, "evidence": [{"source_api": "reach", "pmid": "28931940", "text": "In this study, we have shown that BRD2 protein expression was not only induced during adipocyte differentiation, but was also reduced by Brd4 shRNA expression at a late stage (8 days after differentiation), but not at an early stage (2 days after differentiation) of adipocyte differentiation in 3T3-L1 adipocytes.", "annotations": {"found

## STEP 2: CONSTRUCT NETWORK FROM INDRA RELATIONS

In [9]:
# Evidence count of each directed (source_id, target_id) pair in the fetched
# relations. If several relations share the same pair, the one listed last
# overwrites the earlier ones.
ev_counts = {
    (entry['source_id'], entry['target_id']): entry['data']['evidence_count']
    for entry in subnetwork_relations
}

def initialize_networkx_graph(subnetwork_relations, filter_bidirectional=False):
    """Return a networkx graph from the INDRA relations.

    Parameters
    ----------
    subnetwork_relations : list of dict
        Relation records with 'source_id', 'target_id' and a 'data' dict
        holding 'evidence_count', 'belief' and 'stmt_type'.
    filter_bidirectional : bool, optional
        When True, a relation is skipped if the opposite direction has
        strictly more evidence. Ties keep both directions.

    Returns
    -------
    networkx.DiGraph
        Nodes carry a 'label' attribute (the HGNC name of the node ID);
        edges carry 'evidence_count', 'belief' and 'stmt_type'.
    """
    # Derive the per-pair evidence counts from the relations passed in, so the
    # filter always matches this call's input and does not depend on
    # notebook-level state. As in a plain dict build, the last relation of a
    # repeated (source, target) pair wins.
    pair_evidence = {
        (entry['source_id'], entry['target_id']): entry['data']['evidence_count']
        for entry in subnetwork_relations
    }

    G = nx.DiGraph()
    for entry in subnetwork_relations:
        source = entry['source_id']
        target = entry['target_id']
        data = entry['data']

        if filter_bidirectional:
            # If there is a statement with opposite direction and more evidence
            # then we skip this one
            if pair_evidence[(source, target)] < pair_evidence.get((target, source), 0):
                continue

        # Add nodes to graph
        source_name = hgnc_client.get_hgnc_name(source)
        target_name = hgnc_client.get_hgnc_name(target)
        G.add_node(source, label=source_name)
        G.add_node(target, label=target_name)

        # Add the edge to the graph
        G.add_edge(
            source,
            target,
            evidence_count=data['evidence_count'],
            belief=data['belief'],
            stmt_type=data['stmt_type']
        )

    return G

In [10]:
def find_communities(G, weight='evidence_count', seed=None):
    """Return the communities of a networkx graph using a custom weight attribute.

    Parameters
    ----------
    G : networkx graph
        Graph to partition.
    weight : str, optional
        Name of the edge attribute used as the edge weight.
    seed : int or random state, optional
        Forwarded to the Louvain routine. Pass a fixed value to get the same
        partition on every run; the default (None) leaves it unseeded.

    Returns
    -------
    The communities as returned by networkx' louvain_communities.
    """
    return nx.community.louvain_communities(G, weight=weight, seed=seed)

In [11]:
def generate_node_initial_positions(G, communities, seed=None):
    """Return node positions of a networkx graph based on communities.

    Community centers are spread evenly on a circle of radius 1 around the
    origin; every node gets a random position inside a disc of radius 1
    around the center of its community. The community index of each node is
    also stored on the graph as the 'community' node attribute.

    Parameters
    ----------
    G : networkx graph
        Graph whose nodes receive the 'community' attribute.
    communities : sequence of iterables of nodes
        Node groups; the position in the sequence is the community index.
    seed : optional
        Seed for a private random generator, for reproducible positions.
        With the default (None) the module-level `random` state is used.

    Returns
    -------
    dict
        Maps each node to its initial [x, y] position.
    """
    rng = random if seed is None else random.Random(seed)
    circle_r = 1  # radius of the disc that holds one community's nodes
    big_r = 1     # radius of the circle the community centers sit on
    community_count = len(communities)

    initial_pos = {}
    for index, nodes in enumerate(communities):
        center_angle = 2 * math.pi / community_count * index
        center_x = math.cos(center_angle) * big_r
        center_y = math.sin(center_angle) * big_r
        for node in nodes:
            alpha = 2 * math.pi * rng.random()
            # Taking the square root spreads points evenly over the disc
            # instead of crowding them near its center.
            r = circle_r * math.sqrt(rng.random())
            initial_pos[node] = [r * math.cos(alpha) + center_x,
                                 r * math.sin(alpha) + center_y]
        nx.set_node_attributes(G, dict.fromkeys(nodes, index), name='community')

    return initial_pos

In [12]:
def apply_layout_to_graph(G, initial_pos, k=30, iterations=100, weight='evidence_count'):
    """Apply custom layout positions to a networkx graph.

    Runs networkx' spring layout starting from `initial_pos` and stores the
    resulting [x, y] of every node as its 'pos' attribute.

    Parameters
    ----------
    G : networkx graph
        Graph to lay out; modified in place.
    initial_pos : dict
        Starting position per node, handed to the spring layout.
    k : float, optional
        Scaled by 1 / sqrt(number of nodes) before being passed to the
        spring layout as its `k` argument.
    iterations : int, optional
        Iteration count passed to the spring layout.
    weight : str, optional
        Edge attribute used as the edge weight by the spring layout.

    Returns
    -------
    The same graph object `G`.
    """
    node_count = len(G.nodes)
    if node_count == 0:
        # Nothing to place, and k / sqrt(0) below would raise ZeroDivisionError.
        return G

    pos = nx.spring_layout(G, weight=weight, k=k / math.sqrt(node_count),
                           pos=initial_pos, iterations=iterations)
    nx.set_node_attributes(
        G,
        {node: [pos[node][0], pos[node][1]] for node in G.nodes()},
        name='pos',
    )
    return G

In [13]:
def apply_gsea_to_graph(G, communities):
    """Set graph node attributes reflecting top GO enrichment for each node in a community.

    For every community the member nodes are translated to HGNC names and
    sent to Enrichr against three GO libraries. The first row of the result
    table is formatted as '<Term> w/ p-value <Adjusted P-value>' and stored
    as the 'gsea' attribute of every node in that community. When the result
    table is empty, a placeholder text is stored instead, so every node still
    gets a 'gsea' attribute.

    Parameters
    ----------
    G : networkx graph
        Graph whose nodes receive the 'gsea' attribute; modified in place.
    communities : iterable of iterables of nodes
        Node groups to run the enrichment for.
    """
    for nodes in communities:
        # Do GSEA and fetch top gene set
        gene_list = [hgnc_client.get_hgnc_name(node) for node in nodes]
        gene_sets = enrichr(gene_list=gene_list,
                            gene_sets=['GO_Biological_Process_2023',
                                       'GO_Cellular_Component_2023',
                                       'GO_Molecular_Function_2023'],
                            organism='Human').results
        if gene_sets.empty:
            # No enriched term came back; indexing row 0 would raise.
            top_gene_set = 'No enriched GO term found'
        else:
            # Positional access: do not rely on the table having a row
            # labelled 0.
            # NOTE(review): this takes the first row as returned; whether
            # that is the best term across all three libraries depends on
            # how the result table is ordered - confirm.
            top_term = gene_sets["Term"].iloc[0]
            top_p_value = gene_sets["Adjusted P-value"].iloc[0]
            top_gene_set = (f'{top_term} w/ p-value '
                            f'{str(top_p_value)}')
        nx.set_node_attributes(G, dict.fromkeys(nodes, top_gene_set), name='gsea')

In [14]:
def construct_networkx_graph(subnetwork_relations):
    """Build the fully annotated networkx graph for the INDRA relations.

    The graph is created from the relations, split into communities, laid
    out starting from community-based positions, and finally annotated with
    an enrichment label per community. The finished graph is returned.
    """
    graph = initialize_networkx_graph(subnetwork_relations)
    node_groups = find_communities(graph)
    starting_positions = generate_node_initial_positions(graph, node_groups)
    # The layout and enrichment steps annotate the graph in place.
    apply_layout_to_graph(graph, starting_positions)
    apply_gsea_to_graph(graph, node_groups)
    return graph


# Build the annotated graph from the fetched relations; later cells read `G`.
G = construct_networkx_graph(subnetwork_relations)

In [15]:
# Inspect the attributes stored on one node (looked up by its node ID).
G.nodes['13575']

{'label': 'BRD4',
 'community': 0,
 'pos': [0.9395495999755249, -0.20174452430318776],
 'gsea': 'rRNA Base Methylation (GO:0070475) w/ p-value 0.0221614275546731'}

## STEP 3: GENERATE VISUALIZATION

In [16]:
def construct_arrows(G):
    """Return list of directed edges for network visualization.

    Each edge of the graph becomes one Plotly arrow annotation that starts
    at the source node's 'pos' and points at the target node's 'pos'. The
    arrow width is the edge's 'evidence_count', capped at 5.

    Parameters
    ----------
    G : networkx graph
        Graph whose nodes have a 'pos' attribute and whose edges have an
        'evidence_count' attribute.

    Returns
    -------
    list
        One go.layout.Annotation per edge.
    """
    arrow_list = []
    for source, target in G.edges():
        source_x, source_y = G.nodes[source]['pos']
        target_x, target_y = G.nodes[target]['pos']

        # Plotly draws the arrowhead at (x, y) and the tail at (ax, ay), so
        # the head must sit on the target for the arrow to follow the edge
        # direction source -> target.
        arrow = go.layout.Annotation(
            x=target_x,
            y=target_y,
            xref="x", yref="y",
            showarrow=True,
            axref="x", ayref='y',
            ax=source_x,
            ay=source_y,
            arrowhead=3,
            arrowwidth=min(5, G.edges[(source, target)]['evidence_count']),
            arrowcolor='lightgreen',
        )

        arrow_list.append(arrow)
    return arrow_list

In [17]:
def construct_arrow_labels(G):
    """Build the hover labels shown along the edges of the network plot.

    One fully transparent marker is placed at the midpoint of every edge;
    hovering it shows the labels of both end nodes and the edge's evidence
    count. All markers are returned as a single scatter trace.
    """
    midpoints_x = []
    midpoints_y = []
    hover_labels = []
    for source, target in G.edges():
        source_x, source_y = G.nodes[source]['pos']
        target_x, target_y = G.nodes[target]['pos']
        source_label = G.nodes[source]['label']
        target_label = G.nodes[target]['label']

        midpoints_x.append((source_x + target_x) / 2)
        midpoints_y.append((source_y + target_y) / 2)
        evidence = G.edges[(source, target)]["evidence_count"]
        hover_labels.append(f'{source_label}->{target_label} evidence count: {evidence}')

    # opacity=0 hides the markers themselves; only their hover text is visible.
    invisible_marker = go.scatter.Marker(opacity=0)
    return go.Scatter(
        x=midpoints_x, y=midpoints_y,
        mode="markers",
        showlegend=False,
        hovertext=hover_labels,
        hovertemplate="Edge %{hovertext}<extra></extra>",
        marker=invisible_marker
    )

In [18]:
def construct_node_trace(G):
    """Return custom nodes for network visualization.

    Builds one scatter trace with a marker per node: placed at the node's
    'pos', captioned with its 'label', coloured by its 'community' and
    showing its 'gsea' text on hover.

    Parameters
    ----------
    G : networkx graph
        Graph whose nodes have 'pos', 'label', 'community' and 'gsea'
        attributes.

    Returns
    -------
    go.Scatter
        The node trace.
    """
    node_x, node_y = [], []
    node_labels, node_hover, node_colors = [], [], []
    # Collect every per-node list in a single pass so they stay aligned.
    for node, data in G.nodes(data=True):
        x, y = data['pos']
        node_x.append(x)
        node_y.append(y)
        node_labels.append(data['label'])
        node_hover.append(data['gsea'])
        node_colors.append(data['community'])

    return go.Scatter(
        x=node_x, y=node_y,
        mode='markers+text',
        hoverinfo='text',
        text=node_labels,
        hovertext=node_hover,
        textposition="bottom center",
        marker=dict(
            showscale=True,
            colorscale='YlGnBu',
            reversescale=True,
            color=node_colors,
            size=30,
            colorbar=dict(
                thickness=15,
                # The flat `titleside` attribute is deprecated and rejected by
                # newer Plotly releases; the nested title form is the
                # supported spelling.
                title=dict(text='Cluster ID', side='right'),
                xanchor='left'
            ),
            line_width=2))

In [27]:
def show_plotly_graph(nodes, arrows):
    """Render the network figure.

    `nodes` is used as the figure's trace data and `arrows` as the layout
    annotations. Both axes are drawn without grid, zero line or tick labels.
    The figure is displayed; nothing is returned.
    """
    # Same settings for both axes; each axis gets its own copy.
    bare_axis = {'showgrid': False, 'zeroline': False, 'showticklabels': False}
    layout = go.Layout(
        title='<br>Network graph made with Python',
        font={'family': "Courier New, monospace", 'size': 10, 'color': "Black"},
        annotations=arrows,
        showlegend=False,
        hovermode='closest',
        margin={'b': 20, 'l': 5, 'r': 5, 't': 40},
        xaxis=dict(bare_axis),
        yaxis=dict(bare_axis),
    )
    figure = go.Figure(data=nodes, layout=layout)
    figure.show()

In [28]:
def create_plotly_graph(G):
    """Draw the network held in a networkx graph.

    Builds the arrow annotations, the edge hover-label trace and the node
    trace from `G`, then displays them as one figure. Nothing is returned.
    """
    arrow_annotations = construct_arrows(G)
    # Trace order is kept: edge hover markers first, node markers second.
    traces = [construct_arrow_labels(G), construct_node_trace(G)]
    show_plotly_graph(traces, arrow_annotations)

In [29]:
# Render the interactive network figure for the graph built in step 2.
create_plotly_graph(G)

In [24]:
# List the directed edges of the graph as (source node ID, target node ID) pairs.
G.edges()

OutEdgeView([('13575', '1103'), ('13575', '7867'), ('13575', '1104'), ('1103', '13575'), ('1103', '1104'), ('1104', '1103'), ('1104', '13575'), ('5176', '7867'), ('5176', '25758'), ('7867', '13575'), ('7867', '5176'), ('7867', '21535'), ('7867', '24218'), ('24218', '7867'), ('24218', '21535'), ('21535', '7867'), ('21535', '14098'), ('21535', '24218'), ('6316', '4005'), ('6316', '2683'), ('6316', '5036'), ('6316', '11802'), ('4005', '2683'), ('4005', '6316'), ('4005', '5036'), ('4005', '21100'), ('5036', '4005'), ('5036', '6316'), ('5036', '26087'), ('5036', '11802'), ('2683', '4005'), ('2683', '6316'), ('2683', '11802'), ('21100', '4005'), ('25758', '28945'), ('25758', '5176'), ('28945', '25758'), ('11802', '2683'), ('11802', '6316'), ('11802', '5036'), ('26087', '5036'), ('14098', '21535')])

## STEP 4: GENERATE STATEMENT EVIDENCE BROWSER

In [None]:
# Gather statistics for HTML presentation.
# All three dicts are keyed by statement hash, so relations that share a hash
# collapse into a single entry (the last one listed wins).
# Parsed statement JSON per hash:
unique_stmts = {entry['data']['stmt_hash']: json.loads(entry['data']['stmt_json'])
                for entry in subnetwork_relations}
# Evidence count per hash:
ev_counts_by_hash = {entry['data']['stmt_hash']: entry['data']['evidence_count']
                     for entry in subnetwork_relations}
# Parsed source-count JSON per hash:
source_counts_by_hash = {entry['data']['stmt_hash']: json.loads(entry['data']['source_counts'])
                         for entry in subnetwork_relations}
# Turn the collected statement JSON into statement objects via INDRA's stmts_from_json.
stmts = stmts_from_json(list(unique_stmts.values()))

In [None]:
# Assemble the statements into an HTML page, passing along the per-hash
# evidence counts and source counts gathered in the previous cell.
ha = HtmlAssembler(stmts,
                   title='INDRA subnetwork statements',
                   db_rest_url='https://db.indra.bio',
                   ev_counts=ev_counts_by_hash,
                   source_counts=source_counts_by_hash)
# html_str holds the result of make_model() - presumably the page as an HTML
# string; confirm against the HtmlAssembler documentation.
html_str = ha.make_model()

In [None]:
# HTML is imported so the assembled page can be rendered inline in the notebook.
# NOTE(review): the import sits here rather than in the import cell at the top,
# and IPython.display is the usual public import path for HTML - confirm and
# consider moving it.
from IPython.core.display import HTML
# Rendering is left disabled; uncomment the next line to display html_str.
#HTML(html_str)