In [1]:
import pandas as pd
import networkx as nx
from typing import List,Tuple
from pyvis.network import Network
from collections import Counter
from networkx.algorithms.components import weakly_connected_components
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from itertools import groupby,chain
from operator import itemgetter


In [3]:
# misc functions 

def get_network_range(num:int)->str:
    if num < 10:
        return('small')
    elif num < 2000:
        return('medium')
    else:
        return('large')


def get_first_tuple(tup:Tuple)->str:
    return tup[0]


def get_second_tuple(tup:Tuple)->str:
    return tup[1]

# NETWORK CREATION 

## functions to generate and manipulate network 

def create_network(data:pd.DataFrame)->nx.classes.digraph.DiGraph:
    
    cpgs:List[str] = data['CpG'].unique() # cpg nodes
    #snps:List[str]= data['Top SNP'].unique() # snp nodes
    lds:List[str] = data['LD clump'].unique() # ld clump nodes

    # networkX node format with added color attributes

    nodes_cpg = [(cpg, {'color':'#7fc97f'} ) for cpg in cpgs] 
    #nodes_snp = [(snp, {'color':'#beaed4'}) for snp in snps]
    nodes_LD = [('ld_' + ld, {'color':'#ffff99'}) for ld in lds]

    # EDGES

    #cpg_snp_edges:list[Tuple] = [(cpg,snp) for cpg,snp in zip(data['CpG'],data['Top SNP'])] # cpg-snp
    #snp_ld_edges:list[Tuple] = [(snp,'ld_'+ld) for snp,ld in zip(data['Top SNP'],data['LD clump'])]
    #cpg_ld_weighted_edges:list[Tuple] = [(cpg,'ld_'+ld,w) for cpg,ld,w in zip(data['CpG'],data['LD clump'],data['P'])]
    cpg_ld_edges:list[Tuple] = [(cpg,'ld_'+ld) for cpg,ld in zip(data['CpG'],data['LD clump'])]

    cpgNet = nx.DiGraph()

    # add nodes to the directed graph

    cpgNet.add_nodes_from(nodes_cpg)
    #cpgNet.add_nodes_from(nodes_snp)
    cpgNet.add_nodes_from(nodes_LD)

    # add edges to the directed graph 

    #cpgNet.add_edges_from(cpg_snp_edges,color='black')
    #cpgNet.add_edges_from(snp_ld_edges,color='red')
    #cpgNet.add_weighted_edges_from(cpg_ld_weighted_edges,color='black')
    '#beaed4'
    cpgNet.add_edges_from(cpg_ld_edges,color='#beaed4')

    # removing nodes which has a degree 1 : can be a LD-clump connected to a single CpG, a CpG with only one LD clump ...
    
    weak_components:list[list] = [list(weak_component) for weak_component in weakly_connected_components(cpgNet)]
    cpg_ld_singles:list[list] = np.array([cpg_ld for cpg_ld in weak_components if len(cpg_ld)<=2]).flatten()
    cpgNet.remove_nodes_from(cpg_ld_singles)


    return cpgNet

# given CpG id return the connected network 

def subgraph_by_cpg(net:nx.classes.digraph.DiGraph,cpg:str)->nx.classes.digraph.DiGraph:
   
    undirected_net:nx.classes.graph.Graph = net.to_undirected() # turn network to undirected to locate all CpG associations by LD
    allowed_jumps =  2
    reachable_nodes = [node[0] for node in nx.shortest_path_length(undirected_net,source=cpg).items() 
                       if node[1] <= allowed_jumps] 
    subGraph =  net.subgraph(reachable_nodes)
    
    # change the color of source node
    
    nx.set_node_attributes(subGraph,{cpg:'#f0027f'},'color')

    # change the color of source edges

    edges_dict = {edges : '#386cb0' for edges in subGraph.edges(cpg)}
    nx.set_edge_attributes(subGraph,edges_dict,'color')
    
    return subGraph

def aggregate_similar_nodes(net:nx.classes.digraph.DiGraph,num:int)->nx.classes.digraph.DiGraph:
    
    if nx.is_frozen(net):    
        subnet = nx.DiGraph(net)
    else:
        subnet = net
    
    
    subgraph_cpgs = [cpg for cpg in subnet.nodes 
                 if cpg[0] == 'c'] 
                 
    cpgs_with_one_degree = [cpg for cpg in subgraph_cpgs 
                    if subnet.degree(cpg)==num] # cpg nodes with single connections
                    
    cpg_ld_pairs = [(cpg,list(subnet.edges(cpg))[0][1]) for cpg in cpgs_with_one_degree] # cpg-ld edges with single connections

    cpg_ld_pairs.sort(key=lambda x: x[1]) # sort by the LD clump connection before grouping
    cpg_ld_pairs_group = {key : list(map(get_first_tuple,list(group))) 
                for key, group in groupby(cpg_ld_pairs, itemgetter(1))} # group tuples by second value 
    cpg_ld_groups = {key: group for key,group in cpg_ld_pairs_group.items() 
                if len(group)>=2}
    cpgs_to_remove = list(chain(*list(cpg_ld_groups.values())))

    # refactorizes this
    new_edges = []
    weights = []
    new_nodes = []
    
    
    i = num * 100 

    for key, val in cpg_ld_groups.items():   
        print(key)
        code = (str(i),key)
        weight = len(val)
        new_nodes.append(str(i))
        new_edges.append(code)
        weights.append(weight)
        i += 1

    subnet.remove_nodes_from(cpgs_to_remove)
    subnet.add_nodes_from(new_nodes)
    subnet.add_edges_from(new_edges)
    return subnet



# NETWORK plotting

def plot_visjs(net:nx.classes.digraph.DiGraph,file_name:str):
    net_gen:Network = Network()
    net_gen.from_nx(net)
    net_gen.set_options('''
    const options = {
    "layout":{
        "improvedLayout": "false"
    },
    "physics": {
        "solver": "forceAtlas2Based",
        "adaptiveTimestep": "true",
        "stabilization": "true",
        "avoidOverlap":"0"

    }

    }''')
    net_gen.show(file_name)



### to-dos: 
- add summary statistics-based filters to filter-out non-significant snp-cpg pairs
- add weight attribute to network graphs to encode pair significance

- removing SnPs (LD clump) show specific snp by clicking to the LD clump (SNPs removed)
- try prototype on the whole network
	- to see the clusters 
- check color scheme 
	- saturated colors 
	- Rampvis color scheme 
- given cpg, write a function to calculate the associated CpGs in the network 

### given cpg function 
given cpg plot all LD-clumps and other cpgs connected to the network 

In [4]:
file_name = "ld_clump_assoc.txt"  # ld - clump association file


# find cpgs with highest number of snp connections in the network to plot 

top_n = 10000
cpg_ids = [i[0] for i in Counter(data['CpG'].values).most_common()[:top_n]] 

cpgNet = create_network(data)
print(len(cpgNet))
subNet = subgraph_by_cpg(cpgNet,cpg_ids[0])
print(len(subNet))
plot_visjs(subNet,'topCpG.html')
subNet = aggregate_similar_nodes(subNet,1)
print(len(subNet))
plot_visjs(subNet,'topCpGReduced1.html')
subNet = aggregate_similar_nodes(subNet,2)
print(len(subNet))
plot_visjs(subNet,'topCpGReduced2.html')
subNet = aggregate_similar_nodes(subNet,3)
print(len(subNet))
plot_visjs(subNet,'topCpGReduced3.html')

315920
568
ld_8:10009949_C_G
ld_8:10339479_C_T
ld_8:8730488_G_A
ld_8:9314344_C_T
ld_8:9367743_C_G
ld_8:9796321_C_T
396
ld_8:10009949_C_G
ld_8:10339479_C_T
ld_8:8214996_C_G
ld_8:8298285_A_T
ld_8:8400723_A_G
ld_8:8506404_T_C
ld_8:8730488_G_A
ld_8:8839813_CAA_C
ld_8:9222081_C_G
ld_8:9314344_C_T
ld_8:9690898_G_T
ld_8:9796321_C_T
316
ld_8:10009949_C_G
ld_8:10339479_C_T
ld_8:8214996_C_G
ld_8:8506404_T_C
ld_8:8730488_G_A
ld_8:8916376_T_G
ld_8:9314344_C_T
ld_8:9690898_G_T
ld_8:9796321_C_T
267


In [164]:
components:list[list] = [list(weak_component) for weak_component in weakly_connected_components(cpgNet)]
subGraphs = [cpgNet.subgraph(component) for component in components]
subGraphs_length = [len(subgraph.nodes) for subgraph in subGraphs]
ranges = Counter([get_network_range(length) for length in subGraphs_length])
sorted_subGraphs = sorted(subGraphs,key= lambda x: len(x.nodes),reverse=True)
test_subGraph = sorted_subGraphs[1]
len(test_subGraph)


2909