<img style="float: right;" src="img/biva.png" width="20%">

BiVA: Bitcoin Network Visualization and Analysis 
---------------
This demo uses the neo4j graph database, ipython-cypher to query neo4j from Python, and python-igraph to plot graphs.

In [8]:
from IPython.display import clear_output, display, Image
from ipywidgets import *
 

#*********************#
#status variables
#********************#
#to know whether to requery the db
current_input = 0
#to know which graph to work with
current_file = 0
#to keep the same plot layout
current_layout = 0
#keep track of search depth
current_depth = 0

#********************#
#Widgets: input box
#********************#
#to enter the Bitcoin input
text1 = Text(
    #description='Input:',
    placeholder='Enter a Bitcoin address/transaction',
    #value = '1JujBBkRGAEm7JvdnCDfGw939cEtbuuWa2',
    layout = Layout(width='450px')
)
#to decide the depth of search in the database
#(starts at 1: query_db_addr/query_db_tx have no depth-0 query and send
# every unlisted depth, 0 included, to the 10-hop query)
slider1 = IntSlider(
    min=1,
    max=10,
    step=1,
    value=3,
    description='Search depth:',
    disabled=False,
    continuous_update=False,
    orientation='horizontal',
    readout=True,
    readout_format='d',
    layout = Layout(width='450px'),
    style = {'description_width': 'initial'}
)

#address or transaction?
choice1 = RadioButtons(
    options=['a Bitcoin address', 'a transaction hash'],
    description='This is:',
    #margin is a Layout property; passed directly it was an unrecognized argument
    layout = Layout(margin='5px'),
    disabled=False
)

#create one button that indicates the user is done with the input
button1 = Button(
    description='Done',
    disabled=False,
    button_style='', # 'success', 'info', 'warning', 'danger' or ''
    icon='check'
)

#displays the input box
display(HBox([VBox([text1,slider1]),choice1, button1]))


#***************#
#Widgets: tabs
#***************#
#creates the first tab, the way to plot the result
#as a function of the type of network
choice_view = RadioButtons(
    options=['dual mode', 'address network', 'transaction network'],
    description='View:',
    disabled=False
)

#as a function of the search depth 
slider_zoom = interactive(plt_zoom,x=
    IntSlider(
    min=1,
    max=slider1.max,
    step=1,
    value = slider1.value,
    description='Zoom:',
    continuous_update=False,
    orientation='horizontal',
    readout=True,
    readout_format='d',
    layout = Layout(width='260px')
))

#with or without labels
labelling = widgets.Checkbox(
    value=False,
    description='Labels',
    disabled=False,
    layout=Layout(width='30%')
)

tab1 = HBox([choice_view,slider_zoom,labelling],layout=Layout(height='70px'))

#***************#
#Widgets: tabs
#***************#
#creates the second tab, to filter edges by BTC amount
text_min = Text(
    description='BTC amt >= :',
    value = '0',
    layout=Layout(width='70%'))

text_max = Text(
    description='and <= :',
    value = '10000',
    layout=Layout(width='70%'))

#depth of the filtered search
#(starts at 1: a depth of 0 matches none of the depth cases of query_db_edge)
slider2 = IntSlider(
    min=1,
    max=20,
    step=1,
    value = slider1.value,
    description='Depth:',
    disabled=False,
    continuous_update=False,
    orientation='horizontal',
    readout=True,
    readout_format='d'
)

tab2 = HBox([VBox([text_min,text_max]),slider2],layout=Layout(height='70px'))

#***************#
#Widgets: tabs
#***************#
#creates the third tab, to find paths between addresses
#input the starting point
text_from = Text(
    #description='From:',
    placeholder='Choose a start addr/trans',
    #value = text1.value,
    layout=Layout(width='98%'))

#input the end point
text_to = Text(
    #description='to :',
    placeholder='and an end addr/trans',
    layout=Layout(width='98%'))

#input for path confluence (comma-separated list)
text_addr = Textarea(
    #description='Address :',
    placeholder='or a list of addresses for path confluence',
    layout=Layout(width='40%'))

#choose the type of paths
choice2 = RadioButtons(
    options=['directed','undirected'], #'the database (undirected)'],
    description='in:',
    disabled=False
)
tab3 = HBox([VBox([text_from,text_to]),choice2,text_addr],layout=Layout(height='70px'))

#***************#
#Widgets: tabs
#***************#
#creates the 4th tab, to do clustering
#choose clustering
choice_clus = RadioButtons(
    options=['spectral', 'probabilistic'],
    #description='Type:',
    disabled=False
)
#parameters
nb_clus = Text(
    description='#clusters:',
    value='3',
    layout=Layout(width='130px')
)
alpha = Text(
    description= r'\(\alpha\)',
    placeholder='0 or 1',
    layout=Layout(width='140px')
)
dw = Text(
    description= r'\(D_w\)',
    value = '0',
    layout=Layout(width='130px')
)
mu = Text(
    description= r'\(\mu\)',
    placeholder='0 or 1',
    layout=Layout(width='140px')
)
nb_iter = Text(
    description='#iters:',
    value='3',
    layout=Layout(width='130px')
)

tab4 = HBox([choice_clus,nb_clus,alpha,dw,mu,nb_iter],layout=Layout(height='70px'))

#***************#
#Widgets: tabs
#***************#
#create text boxes
list_neighb = Textarea(
    description='Addresses:',
    placeholder = 'Enter a list of addresses/transactions that you are looking for',
    layout=Layout(width='80%',))

#creates the 5th tab, to highlight neighbours
tab5 = VBox([list_neighb],layout=Layout(height='70px'))


#creates the set of tabs, with the respective title
tab = widgets.Tab([tab1, tab2, tab3, tab4, tab5],layout=Layout(width="95%"))
tab.set_title(0, 'plot')
tab.set_title(1, 'filter edges')
tab.set_title(2, 'find paths')
tab.set_title(3, 'cluster')
tab.set_title(4, 'neighbours')



#**************************************#
# create the empty space to put the plot
#***************************************#
out = widgets.Output(layout={'border': '1px solid black','width': '95%'})
with out:
    display(Image(filename='img/blank.png'))

#**************************************#
# create a box to write/update comments
#***************************************#
out_comment = widgets.Output(layout={'width': '80%'})

#display the tabs, the plot area and the comment area
display(VBox([tab, out, out_comment]))

#*******************************#
#when the done button is clicked

button1.on_click(handle_input)


#3225f8f6ec52bd1a0e113ebc2dcc5208bb2c331d765efd7b4dba39d88c61bb8d
#1DriJgHZrYYmY4jRiVQaKHzcJpUjCpGeUQ (test addr)
#17dpY2fUTnNB7xQEA2mTcxBiZYDFN1MNGo (test another address to get a path)
#377kST3E7qDJ1FZRoA9xZX7bK8UYCJC6Pt (not inside)
#1JujBBkRGAEm7JvdnCDfGw939cEtbuuWa2 (scam)
#17dpY2fUTnNB7xQEA2mTcxBiZYDFN1MNGo,19djNhJLi8ar8DCPrN7j5qr8JBytzU1kfm,1NqUXBCeVPf1QGM9K6eeVGKoLrGBdY88WC

HBox(children=(VBox(children=(Text(value='', layout=Layout(width='450px'), placeholder='Enter a Bitcoin addres…

VBox(children=(Tab(children=(HBox(children=(RadioButtons(description='View:', options=('dual mode', 'address n…

245 rows affected.
11 rows affected.
2 rows affected.
1531 rows affected.
245 rows affected.


In [3]:
# load the ipython-cypher extension, which provides the %cypher magic used by the query_db_* functions
%load_ext cypher

In [7]:
import networkx as nx
import igraph as ig
import math
import numpy as np

from sklearn.cluster import SpectralClustering

from collections import defaultdict
from collections import deque

import re 
from hashlib import sha256

#******************#
# Settings
#******************#
# node-count threshold: larger graphs are cut down (dual mode) or not plotted
max_size = 1500

#********************************#
# check bitcoin address validity
# (Base58Check, adapted from code found online)
#*******************************#
digits58 = '123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz'


def decode_base58(bc, length):
    """Decode the base58 string ``bc`` into ``length`` big-endian bytes.

    Leading '1' characters (base58 zero digits) add nothing to the number, so
    they are not distinguishable in the result; see ``check_bc``.

    Raises:
        ValueError: if ``bc`` contains a character outside the base58 alphabet.
        OverflowError: if the decoded number does not fit in ``length`` bytes.
        TypeError: if ``bc`` is not an iterable of str characters.
    """
    n = 0
    for char in bc:
        n = n * 58 + digits58.index(char)
    return n.to_bytes(length, 'big')


def check_bc(bc):
    """Return True if the string ``bc`` is a valid Base58Check Bitcoin address.

    The address must decode to 25 bytes (version byte, 20-byte hash, 4-byte
    checksum), the checksum must equal the first 4 bytes of the double
    SHA-256 of the first 21 bytes, and the number of leading '1' characters
    must equal the number of leading zero bytes, as the encoding requires
    (previously "11<valid address>" was also accepted).

    Only base58 addresses can pass; anything that does not decode (invalid
    characters, too long, not a string) returns False.
    """
    try:
        bcbytes = decode_base58(bc, 25)
    except (ValueError, OverflowError, TypeError):
        # invalid character, number too large for 25 bytes, or not a string
        return False
    # each leading zero byte is encoded by exactly one leading '1'
    n_ones = len(bc) - len(bc.lstrip('1'))
    n_zero_bytes = len(bcbytes) - len(bcbytes.lstrip(b'\x00'))
    if n_ones != n_zero_bytes:
        return False
    return bcbytes[-4:] == sha256(sha256(bcbytes[:-4]).digest()).digest()[:4]

#***********************************************#    
#decides what happens once the button is clicked 
#***********************************************#    
def handle_input(sender):
    """Run the action of the currently selected tab ("Done" button callback).

    Dispatches on ``tab.selected_index``:

    * 0 ``plot``         -- query neo4j if the input or depth changed, then plot;
    * 1 ``filter edges`` -- re-query keeping only edges in a BTC amount range;
    * 2 ``find paths``   -- highlight paths (or a path confluence) on the
      current graph;
    * 3 ``cluster``      -- spectral or probabilistic clustering of the
      current graph;
    * 4 ``neighbours``   -- highlight a list of addresses/transactions.

    All inputs are read from the module-level widgets. The tab handlers update
    the module-level status variables ``current_input``, ``current_file``,
    ``current_layout`` and ``current_depth``.

    Args:
        sender: the clicked ``Button`` (unused; required by ``on_click``).

    Raises:
        TypeError: on the plot tab, when the input is declared to be a Bitcoin
            address but fails ``check_bc``.
    """
    actions = {
        0: _plot_tab,
        1: _filter_edges_tab,
        2: _paths_tab,
        3: _cluster_tab,
        4: _neighbours_tab,
    }
    action = actions.get(tab.selected_index)
    if action is not None:
        action()


#**************************#
#tab 0: db querying and basic plot
#**************************#
def _plot_tab():
    """Query the database for ``text1`` when needed, then plot the result."""
    global current_input, current_file, current_depth
    depth = slider1.value
    labflag = labelling.value
    currvalue = text1.value
    is_addr = choice1.value == 'a Bitcoin address'

    # requery only when the input or the search depth changed since the last
    # successful query (the depth used to be ignored for transactions)
    if currvalue != current_input or depth != current_depth:
        if is_addr:
            if not check_bc(currvalue):
                with out_comment:
                    clear_output(wait=True)
                    print(currvalue, ' is not a valid Bitcoin address.')
                raise TypeError("This is not a valid Bitcoin address")
            found = _query_and_save(query_db_addr, currvalue, depth)
        else:
            found = _query_and_save(query_db_tx, currvalue, depth)
        if found:
            current_input = currvalue
            current_depth = depth

    # at this point, the graph of currvalue is stored in Gq<currvalue>.gml
    filename = "Gq%s.gml" % currvalue
    Gnx = nx.read_gml(filename)
    currtype = 'out' if is_addr else 'tx'
    # read by the paths/cluster/neighbours tabs (and used when returning after zoom)
    current_file = filename
    if choice_view.value == 'dual mode':
        _plot_dual_mode(Gnx, currvalue, currtype, labflag)
    else:
        _plot_single_mode(Gnx, currvalue, currtype, labflag)


def _query_and_save(query, value, depth):
    """Query neo4j around ``value`` and save the result to ``Gq<value>.gml``.

    Args:
        query: ``query_db_addr`` or ``query_db_tx``.
        value: the address / transaction hash typed by the user.
        depth: search radius passed on to ``query``.

    Returns:
        True if the database returned something. Otherwise an empty graph is
        saved, a "not found" message is printed, and False is returned.
    """
    filename = "Gq%s.gml" % value
    with out_comment:
        clear_output(wait=True)
        print('Please wait...querying the database...')
    answer = query(value, depth)
    if len(answer) == 0:
        #create an empty graph
        nx.write_gml(nx.Graph(), filename)
        with out_comment:
            clear_output(wait=True)
            print(value, " not found.")
        return False
    #export to networkx and save the data for further use
    Gq = answer.get_graph()
    nx.write_gml(Gq, filename)
    with out_comment:
        clear_output(wait=True)
        print('The network contains ', Gq.order(), ' nodes and ', Gq.size(), ' edges.')
    return True


def _plot_dual_mode(Gnx, currvalue, currtype, labflag):
    """Plot the address+transaction graph ``Gnx`` of ``currvalue``.

    A graph with more than ``max_size`` nodes is first cut down to the nodes
    closest to the start node (distances from ``compd``). The reduced graph is
    saved back to ``Gq<currvalue>.gml``.
    """
    global current_layout
    filename = "Gq%s.gml" % currvalue
    if Gnx.order() > max_size:
        distfromaddr = compd(Gnx.to_undirected(), currvalue, currtype)
        farthest = max(distfromaddr.values(), default=0)
        # grow the radius until the ball around the start holds more than
        # max_size nodes; the `farthest` bound ends the loop when every node
        # fits (this used to loop forever in that case)
        dmax = 1
        while dmax <= farthest and sum(1 for d in distfromaddr.values() if d <= dmax) <= max_size:
            dmax += 1
        Gnx.remove_nodes_from([n for n, d in distfromaddr.items() if d > dmax - 1])
        nx.write_gml(Gnx, filename)
        with out_comment:
            print('Extracted a subnetwork of ', Gnx.order(), ' nodes and ', Gnx.size(), ' edges.')

    if Gnx.order() > 0:
        G, plotstyle, all_labels = styleG(filename, labflag)
        # highlight the starting address / transaction
        _paint(plotstyle, _vertex_ids(G, (currvalue,), types=(currtype,)), 'cyan')
        pngfile = 'img/Gq%s.png' % currvalue
        ig.plot(G, pngfile, **plotstyle, vertex_frame_width=0)
        current_layout = plotstyle
        _show_png(pngfile)
        if labflag:
            _print_labels(all_labels)


def _plot_single_mode(Gnx, currvalue, currtype, labflag):
    """Plot only the address network or only the transaction network of ``Gnx``."""
    global current_file, current_layout
    nxtype = 'out' if choice_view.value == 'address network' else 'tx'
    Gextr = extract_graph(Gnx, nxtype)
    # information about the contraction is removed so it can be saved in gml format
    for _, attrs in Gextr.nodes(data=True):
        attrs.pop('contraction', None)
    extrfile = "Gextr%s.gml" % currvalue
    nx.write_gml(Gextr, extrfile)
    current_file = extrfile
    with out_comment:
        clear_output(wait=True)
        print('The network contains ', Gextr.order(), ' nodes')

    if 0 < Gextr.order() < max_size:
        Gext, plotstyle_ext, all_labels_ext = styleG(extrfile, labflag)
        # highlight the start node only when the view shows nodes of its type
        if currtype == nxtype:
            _paint(plotstyle_ext, _vertex_ids(Gext, (currvalue,), types=(currtype,)), 'cyan')
        pngfile = 'img/Gextr%s.png' % currvalue
        ig.plot(Gext, pngfile, **plotstyle_ext, vertex_frame_width=0)
        current_layout = plotstyle_ext
        _show_png(pngfile)
        with out_comment:
            clear_output()
        if labflag:
            _print_labels(all_labels_ext)


#**************************#
#tab 1: edge filtering
#**************************#
def _filter_edges_tab():
    """Re-query ``text1`` keeping only edges whose BTC amount is in [min, max], then plot."""
    global current_file, current_layout
    with out_comment:
        clear_output(wait=True)
        print('Please wait...querying the database...')
    lwbnd = float(text_min.value)
    upbnd = float(text_max.value)
    currvalue = text1.value
    labflag = labelling.value
    filename = "Gq%s.gml" % currvalue
    #cypher query, exported to networkx and saved for further use
    answer = query_db_edge(currvalue, slider2.value, lwbnd, upbnd)
    Gq = answer.get_graph()
    nx.write_gml(Gq, filename)
    with out_comment:
        clear_output(wait=True)
        print('The transaction network contains ', Gq.order(), ' nodes and ', Gq.size(), ' edges.')

    Gnx = nx.read_gml(filename)
    if 0 < Gnx.order() < max_size:
        G, plotstyle, all_labels = styleG(filename, labflag)
        # the filter query starts from an address ('out') node
        _paint(plotstyle, _vertex_ids(G, (currvalue,), types=('out',)), 'cyan')
        pngfile = 'img/Gq%s.png' % currvalue
        ig.plot(G, pngfile, **plotstyle, vertex_frame_width=0)
        current_layout = plotstyle
        # keep the stored file and layout describing the same (plotted) graph
        current_file = filename
        _show_png(pngfile)
        with out_comment:
            clear_output()
        if labflag:
            _print_labels(all_labels)


#**************************#
#tab 2: paths
#**************************#
def _paths_tab():
    """Highlight paths between two nodes, or the confluence of paths from a list of addresses."""
    addrvalue1 = text_from.value
    addrvalue2 = text_to.value
    confluence = text_addr.value != ''
    if confluence and (addrvalue1 != '' or addrvalue2 != ''):
        # both modes requested: refuse instead of continuing without paths
        with out_comment:
            clear_output(wait=True)
            print('Choose paths between addresses or path confluence.')
        return
    if not _have_graph(need_layout=True):
        return

    if confluence:
        addrvalue3 = text_addr.value.split(',')
        Gnx = nx.read_gml(current_file)
        idxlist = [n for n, d in Gnx.nodes(data=True) if d['labels'] == 'out' and d['addr'] in addrvalue3]
        idxlist.extend(n for n, d in Gnx.nodes(data=True) if d['labels'] == 'tx' and d['txhash'] in addrvalue3)
        jointPaths = findjointPath(Gnx, idxlist)
        listJointPaths = exportJointPath(Gnx, jointPaths)
    else:
        pq = findpath(addrvalue1, addrvalue2)

    Gnx = nx.read_gml(current_file)
    if Gnx.order() >= max_size:
        return
    #************************#
    #draw paths
    #***********************#
    G = ig.Graph.Read_GML(current_file)
    plotstyle = _reset_style(G)

    allpathid = []
    if confluence:
        listnodepath = [pair[1] for pair in listJointPaths]
        allpathid = _vertex_ids(G, listnodepath)
        _paint(plotstyle, allpathid, 'maroon')
        startid = []
        endid = []
        with out_comment:
            clear_output(wait=True)
            print(listnodepath)
    else:
        for path in pq:
            pathid = _vertex_ids(G, path)
            _paint(plotstyle, pathid, 'maroon')
            allpathid.extend(pathid)
        startid = _vertex_ids(G, (addrvalue1,))
        endid = _vertex_ids(G, (addrvalue2,))
        with out_comment:
            clear_output(wait=True)
            print('Path(s):', pq)

    vertex_set = set(allpathid)
    on_path = [edge.source in vertex_set and edge.target in vertex_set for edge in G.es]
    G.es["color"] = ["red" if p else "black" for p in on_path]
    #otherwise the edge setting supersedes the width change
    plotstyle.pop('edge_width', None)
    G.es["width"] = [4 if p else 0.5 for p in on_path]

    #highlights start and end
    _paint(plotstyle, startid, 'yellow')
    _paint(plotstyle, endid, 'green')
    pngfile = _png_name(current_file)
    ig.plot(G, pngfile, **plotstyle, vertex_frame_width=0)
    _show_png(pngfile)


#**************************#
#tab 3: cluster
#**************************#
def _cluster_tab():
    """Cluster the current graph (spectral or probabilistic) and plot nodes coloured by cluster."""
    if not _have_graph(need_layout=False):
        return
    Gnx = nx.read_gml(current_file)
    nbclus = int(nb_clus.value)
    currvalue = text1.value
    pngfile = 'img/Gqclus%s.png' % currvalue
    #**************************#
    #spectral clustering
    #**************************#
    if choice_clus.value == 'spectral':
        #compute spectral clustering from the adjacency matrix
        adj_mat = nx.to_numpy_matrix(Gnx.to_undirected())
        sc = SpectralClustering(nbclus, affinity='precomputed', n_init=100)
        sc.fit(adj_mat)
        clusters = sc.labels_
        #plot the result
        Gclus, plotstyle_clus, all_labels = styleG(current_file, labelling.value)
        pal = ig.drawing.colors.ClusterColoringPalette(nbclus)
        for n in Gclus.vs():
            idx = int(n['id'])
            plotstyle_clus["vertex_color"][idx] = pal.get(clusters[idx])
        # plot once, after every vertex is coloured (was re-plotted per vertex)
        ig.plot(Gclus, pngfile, **plotstyle_clus, vertex_frame_width=0)
        _show_png(pngfile)
        with out_comment:
            clear_output()
    #**************************#
    # probabilistic clustering
    #**************************#
    if choice_clus.value == 'probabilistic':
        alph = int(alpha.value)
        dwf = float(dw.value)
        muf = int(mu.value)
        nbiter = int(nb_iter.value)
        Gnxinput = nx.convert_node_labels_to_integers(Gnx, first_label=0)
        #keep first parameters to 3 and 0.3, somehow changing 3 sometimes gives errors
        clusters = ClusteringProbDist(Gnxinput, 3, 0.30, nbiter, alph, dwf, 0, muf)
        #plot the result
        Gclus, plotstyle_clus, all_labels = styleG(current_file, labelling.value)
        pal = ig.drawing.colors.ClusterColoringPalette(len(clusters))
        for n in Gclus.vs():
            idx = int(n['id'])
            # colour = index of the first cluster containing this vertex
            plotstyle_clus["vertex_color"][idx] = pal.get([clusters.index(sblist) for sblist in clusters if idx in sblist][0])
        ig.plot(Gclus, pngfile, **plotstyle_clus, vertex_frame_width=0)
        _show_png(pngfile)


#**************************#
#tab 4: neighbours
#**************************#
def _neighbours_tab():
    """Highlight in green the nodes that ``findnodes`` returns for the entered list."""
    if not _have_graph(need_layout=True):
        return
    list_addr = list_neighb.value.split(',')
    lst = findnodes(list_addr)
    with out_comment:
        clear_output(wait=True)
        print('Found nodes:', lst)
    Gnx = nx.read_gml(current_file)
    if 0 < Gnx.order() < max_size:
        G = ig.Graph.Read_GML(current_file)
        plotstyle = _reset_style(G)
        _paint(plotstyle, _vertex_ids(G, lst), 'green')
        pngfile = _png_name(current_file)
        ig.plot(G, pngfile, **plotstyle, vertex_frame_width=0)
        _show_png(pngfile)


#**************************#
#shared helpers
#**************************#
def _have_graph(need_layout):
    """Return True if a graph (and, if ``need_layout``, a plot style) exists; else print a hint."""
    # 0 is the initial value of both status variables
    if current_file != 0 and (not need_layout or current_layout != 0):
        return True
    with out_comment:
        clear_output(wait=True)
        print('Please plot a network first (tab "plot").')
    return False


def _vertex_ids(G, values, types=('out', 'tx')):
    """Return ``int(v['id'])`` for every vertex of ``G`` whose key is ``in values``.

    The key is ``addr`` for address ('out') vertices and ``txhash`` for
    transaction ('tx') vertices. Only the node types in ``types`` are
    considered, and address ids come first.
    """
    ids = []
    for nodetype, key in (('out', 'addr'), ('tx', 'txhash')):
        if nodetype in types:
            ids.extend(int(n['id']) for n in G.vs if n['labels'] == nodetype and n[key] in values)
    return ids


def _paint(plotstyle, ids, color):
    """Set ``plotstyle['vertex_color'][i] = color`` for every ``i`` in ``ids``."""
    for i in ids:
        plotstyle["vertex_color"][i] = color


def _reset_style(G):
    """Return a copy of ``current_layout`` with vertices coloured by node type only."""
    # copy so that this plot's highlights and the popped edge_width do not
    # leak into the stored layout (it used to be modified in place)
    plotstyle = dict(current_layout)
    color_dict = {"out": "light blue", "tx": "pink"}
    plotstyle["vertex_color"] = [color_dict[types] for types in G.vs["labels"]]
    return plotstyle


def _png_name(gmlfile):
    """Return ``gmlfile`` with its extension replaced by ``.png``."""
    return gmlfile.rsplit('.', 1)[0] + '.png'


def _show_png(pngfile):
    """Replace the content of the plot area with the image ``pngfile``."""
    with out:
        clear_output(wait=True)
        display(Image(filename=pngfile))


def _print_labels(labels):
    """Replace the content of the comment area with one label per line."""
    with out_comment:
        clear_output(wait=True)
        for label in labels:
            print(label)
        
#****************#    
#queries neo4j
#****************#    
def query_db_addr(nodeid,radius):
    """Return all paths of at most ``radius`` hops around the address ``nodeid``.

    Each branch runs the ipython-cypher ``%cypher`` line magic against the
    neo4j server at localhost:7474. It matches the paths
    ``(n:out {addr: nodeid})-[*..radius]-(m)``, with relationships followed in
    either direction.

    Args:
        nodeid: the Bitcoin address, matched on the ``addr`` property of
            ``out`` nodes. NOTE(review): ``{nodeid}`` in the Cypher text is
            presumably bound from this local variable as a query parameter
            by ipython-cypher -- confirm.
        radius: maximum path length. Only 1..9 have their own branch; every
            other value (e.g. 0, or anything above 10) uses 10 hops.

    Returns:
        The ipython-cypher result object. Callers take ``len()`` of it and
        call ``.get_graph()`` on it to obtain a networkx graph.
    """
    # one hard-coded query per radius: the hop bound is written into each %cypher line
    # NOTE(review): the neo4j credentials are hard-coded in the connection URL
    if radius == 1:
        query = %cypher http://neo4j:neo4jpswd@localhost:7474/db/data MATCH g=(n:out {addr: {nodeid}})-[*..1]-(m) RETURN g
    elif  radius == 2:
        query = %cypher http://neo4j:neo4jpswd@localhost:7474/db/data MATCH g=(n:out {addr: {nodeid}})-[*..2]-(m) RETURN g
    elif radius == 3:
        query = %cypher http://neo4j:neo4jpswd@localhost:7474/db/data MATCH g=(n:out {addr: {nodeid}})-[*..3]-(m) RETURN g
    elif  radius == 4:
        query = %cypher http://neo4j:neo4jpswd@localhost:7474/db/data MATCH g=(n:out {addr: {nodeid}})-[*..4]-(m) RETURN g
    elif  radius == 5:
        query = %cypher http://neo4j:neo4jpswd@localhost:7474/db/data MATCH g=(n:out {addr: {nodeid}})-[*..5]-(m) RETURN g
    elif  radius == 6:
        query = %cypher http://neo4j:neo4jpswd@localhost:7474/db/data MATCH g=(n:out {addr: {nodeid}})-[*..6]-(m) RETURN g
    elif  radius == 7:
        query = %cypher http://neo4j:neo4jpswd@localhost:7474/db/data MATCH g=(n:out {addr: {nodeid}})-[*..7]-(m) RETURN g
    elif  radius == 8:
        query = %cypher http://neo4j:neo4jpswd@localhost:7474/db/data MATCH g=(n:out {addr: {nodeid}})-[*..8]-(m) RETURN g
    elif  radius == 9:
        query = %cypher http://neo4j:neo4jpswd@localhost:7474/db/data MATCH g=(n:out {addr: {nodeid}})-[*..9]-(m) RETURN g
    else :
        # radius 10 and every value not handled above (including 0)
        query = %cypher http://neo4j:neo4jpswd@localhost:7474/db/data MATCH g=(n:out {addr: {nodeid}})-[*..10]-(m) RETURN g
    return query     

def query_db_tx(nodeid,radius):
    """Return all paths of at most ``radius`` hops around the transaction ``nodeid``.

    Each branch runs the ipython-cypher ``%cypher`` line magic against the
    neo4j server at localhost:7474. It matches the paths
    ``(n:tx {txhash: nodeid})-[*..radius]-(m)``, with relationships followed in
    either direction.

    Args:
        nodeid: the transaction hash, matched on the ``txhash`` property of
            ``tx`` nodes. NOTE(review): ``{nodeid}`` in the Cypher text is
            presumably bound from this local variable as a query parameter
            by ipython-cypher -- confirm.
        radius: maximum path length. Only 1..9 have their own branch; every
            other value (e.g. 0, or anything above 10) uses 10 hops.

    Returns:
        The ipython-cypher result object. Callers take ``len()`` of it and
        call ``.get_graph()`` on it to obtain a networkx graph.
    """
    # one hard-coded query per radius: the hop bound is written into each %cypher line
    # NOTE(review): the neo4j credentials are hard-coded in the connection URL
    if radius == 1:
        query = %cypher http://neo4j:neo4jpswd@localhost:7474/db/data MATCH g=(n:tx {txhash: {nodeid}})-[*..1]-(m) RETURN g
    elif radius == 2:
        query = %cypher http://neo4j:neo4jpswd@localhost:7474/db/data MATCH g=(n:tx {txhash: {nodeid}})-[*..2]-(m) RETURN g
    elif radius == 3:
        query = %cypher http://neo4j:neo4jpswd@localhost:7474/db/data MATCH g=(n:tx {txhash: {nodeid}})-[*..3]-(m) RETURN g
    elif radius == 4:
        query = %cypher http://neo4j:neo4jpswd@localhost:7474/db/data MATCH g=(n:tx {txhash: {nodeid}})-[*..4]-(m) RETURN g
    elif radius == 5:
        query = %cypher http://neo4j:neo4jpswd@localhost:7474/db/data MATCH g=(n:tx {txhash: {nodeid}})-[*..5]-(m) RETURN g
    elif radius == 6:
        query = %cypher http://neo4j:neo4jpswd@localhost:7474/db/data MATCH g=(n:tx {txhash: {nodeid}})-[*..6]-(m) RETURN g
    elif radius == 7:
        query = %cypher http://neo4j:neo4jpswd@localhost:7474/db/data MATCH g=(n:tx {txhash: {nodeid}})-[*..7]-(m) RETURN g
    elif radius == 8:
        query = %cypher http://neo4j:neo4jpswd@localhost:7474/db/data MATCH g=(n:tx {txhash: {nodeid}})-[*..8]-(m) RETURN g
    elif radius == 9:
        query = %cypher http://neo4j:neo4jpswd@localhost:7474/db/data MATCH g=(n:tx {txhash: {nodeid}})-[*..9]-(m) RETURN g
    else :
        # radius 10 and every value not handled above (including 0)
        query = %cypher http://neo4j:neo4jpswd@localhost:7474/db/data MATCH g=(n:tx {txhash: {nodeid}})-[*..10]-(m) RETURN g
    return query

def query_db_edge(addrvalue,depth,lwbnd,upbnd):
    """Query neo4j for amount-filtered paths of length <= depth around an address.

    Runs through the ipython-cypher `%cypher` line magic (IPython only). The
    pattern starts from the `out` node whose `addr` equals `{addrvalue}` and
    keeps a path `g` only if ALL of its relationships either have
    lwbnd <= amount <= upbnd or have no `amount` property. `{addrvalue}`,
    `{lwbnd}` and `{upbnd}` are presumably filled in by ipython-cypher from
    the local variables of the same name -- TODO confirm.

    addrvalue: Bitcoin address to start from (not a transaction hash).
    depth: maximum path length; 1..19 select the matching query, any other
        value (including 0) falls through to the `else` branch, depth 20.
    lwbnd, upbnd: inclusive bounds on the relationship `amount`.

    Returns whatever `%cypher` returns for `RETURN g`.
    NOTE(review): the connection URL embeds the neo4j user/password in clear.
    """
    # One literal query per depth: the `[r*..k]` bound is written inline,
    # presumably because it could not be passed as a query parameter.
    if depth == 1:
        query = %cypher http://neo4j:neo4jpswd@localhost:7474/db/data MATCH g=(n:out {addr: {addrvalue}})-[r*..1]-(m)  WHERE ALL (z IN r WHERE (z.amount >= {lwbnd} AND z.amount <= {upbnd} ) OR z.amount IS NULL) RETURN g
    elif  depth == 2:
        query = %cypher http://neo4j:neo4jpswd@localhost:7474/db/data MATCH g=(n:out {addr: {addrvalue}})-[r*..2]-(m)  WHERE ALL (z IN r WHERE (z.amount >= {lwbnd} AND z.amount <= {upbnd} ) OR z.amount IS NULL) RETURN g
    elif  depth == 3:
        query = %cypher http://neo4j:neo4jpswd@localhost:7474/db/data MATCH g=(n:out {addr: {addrvalue}})-[r*..3]-(m)  WHERE ALL (z IN r WHERE (z.amount >= {lwbnd} AND z.amount <= {upbnd} ) OR z.amount IS NULL) RETURN g
    elif  depth == 4:
        query = %cypher http://neo4j:neo4jpswd@localhost:7474/db/data MATCH g=(n:out {addr: {addrvalue}})-[r*..4]-(m)  WHERE ALL (z IN r WHERE (z.amount >= {lwbnd} AND z.amount <= {upbnd} ) OR z.amount IS NULL) RETURN g
    elif  depth == 5:
        query = %cypher http://neo4j:neo4jpswd@localhost:7474/db/data MATCH g=(n:out {addr: {addrvalue}})-[r*..5]-(m)  WHERE ALL (z IN r WHERE (z.amount >= {lwbnd} AND z.amount <= {upbnd} ) OR z.amount IS NULL) RETURN g
    elif  depth == 6:
        query = %cypher http://neo4j:neo4jpswd@localhost:7474/db/data MATCH g=(n:out {addr: {addrvalue}})-[r*..6]-(m)  WHERE ALL (z IN r WHERE (z.amount >= {lwbnd} AND z.amount <= {upbnd} ) OR z.amount IS NULL) RETURN g
    elif  depth == 7:
        query = %cypher http://neo4j:neo4jpswd@localhost:7474/db/data MATCH g=(n:out {addr: {addrvalue}})-[r*..7]-(m)  WHERE ALL (z IN r WHERE (z.amount >= {lwbnd} AND z.amount <= {upbnd} ) OR z.amount IS NULL) RETURN g
    elif  depth == 8:
        query = %cypher http://neo4j:neo4jpswd@localhost:7474/db/data MATCH g=(n:out {addr: {addrvalue}})-[r*..8]-(m)  WHERE ALL (z IN r WHERE (z.amount >= {lwbnd} AND z.amount <= {upbnd} ) OR z.amount IS NULL) RETURN g
    elif  depth == 9:
        query = %cypher http://neo4j:neo4jpswd@localhost:7474/db/data MATCH g=(n:out {addr: {addrvalue}})-[r*..9]-(m)  WHERE ALL (z IN r WHERE (z.amount >= {lwbnd} AND z.amount <= {upbnd} ) OR z.amount IS NULL) RETURN g
    elif depth == 10 :
        query = %cypher http://neo4j:neo4jpswd@localhost:7474/db/data MATCH g = (n:out {addr: {addrvalue}})-[r*..10]-(m) WHERE ALL (z IN r WHERE (z.amount >= {lwbnd} AND z.amount <= {upbnd} ) OR z.amount IS NULL)  RETURN g
    elif depth == 11:
        query = %cypher http://neo4j:neo4jpswd@localhost:7474/db/data MATCH g=(n:out {addr: {addrvalue}})-[r*..11]-(m)  WHERE ALL (z IN r WHERE (z.amount >= {lwbnd} AND z.amount <= {upbnd} ) OR z.amount IS NULL) RETURN g
    elif  depth == 12:
        query = %cypher http://neo4j:neo4jpswd@localhost:7474/db/data MATCH g=(n:out {addr: {addrvalue}})-[r*..12]-(m)  WHERE ALL (z IN r WHERE (z.amount >= {lwbnd} AND z.amount <= {upbnd} ) OR z.amount IS NULL) RETURN g
    elif  depth == 13:
        query = %cypher http://neo4j:neo4jpswd@localhost:7474/db/data MATCH g=(n:out {addr: {addrvalue}})-[r*..13]-(m)  WHERE ALL (z IN r WHERE (z.amount >= {lwbnd} AND z.amount <= {upbnd} ) OR z.amount IS NULL) RETURN g
    elif  depth == 14:
        query = %cypher http://neo4j:neo4jpswd@localhost:7474/db/data MATCH g=(n:out {addr: {addrvalue}})-[r*..14]-(m)  WHERE ALL (z IN r WHERE (z.amount >= {lwbnd} AND z.amount <= {upbnd} ) OR z.amount IS NULL) RETURN g
    elif  depth == 15:
        query = %cypher http://neo4j:neo4jpswd@localhost:7474/db/data MATCH g=(n:out {addr: {addrvalue}})-[r*..15]-(m)  WHERE ALL (z IN r WHERE (z.amount >= {lwbnd} AND z.amount <= {upbnd} ) OR z.amount IS NULL) RETURN g
    elif  depth == 16:
        query = %cypher http://neo4j:neo4jpswd@localhost:7474/db/data MATCH g=(n:out {addr: {addrvalue}})-[r*..16]-(m)  WHERE ALL (z IN r WHERE (z.amount >= {lwbnd} AND z.amount <= {upbnd} ) OR z.amount IS NULL) RETURN g
    elif  depth == 17:
        query = %cypher http://neo4j:neo4jpswd@localhost:7474/db/data MATCH g=(n:out {addr: {addrvalue}})-[r*..17]-(m)  WHERE ALL (z IN r WHERE (z.amount >= {lwbnd} AND z.amount <= {upbnd} ) OR z.amount IS NULL) RETURN g
    elif  depth == 18:
        query = %cypher http://neo4j:neo4jpswd@localhost:7474/db/data MATCH g=(n:out {addr: {addrvalue}})-[r*..18]-(m)  WHERE ALL (z IN r WHERE (z.amount >= {lwbnd} AND z.amount <= {upbnd} ) OR z.amount IS NULL) RETURN g
    elif  depth == 19:
        query = %cypher http://neo4j:neo4jpswd@localhost:7474/db/data MATCH g=(n:out {addr: {addrvalue}})-[r*..19]-(m)  WHERE ALL (z IN r WHERE (z.amount >= {lwbnd} AND z.amount <= {upbnd} ) OR z.amount IS NULL) RETURN g
    else :
        # every other depth (including 0 and values > 20) uses depth 20
        query = %cypher http://neo4j:neo4jpswd@localhost:7474/db/data MATCH g = (n:out {addr: {addrvalue}})-[r*..20]-(m) WHERE ALL (z IN r WHERE (z.amount >= {lwbnd} AND z.amount <= {upbnd} ) OR z.amount IS NULL)  RETURN g
    return query

#*****************#
# plot with igraph
#*****************#
def styleG(fileG,l):
    """Read a GML graph with igraph and build the style used to plot it.

    fileG: name of the GML file to read.
    l: label flag. When true, vertices are labelled with their integer GML
       'id', and the returned lst_labels holds (id, txhash) for 'tx' nodes
       and (id, addr) for all other nodes. When false, every vertex gets a
       blank label and lst_labels is empty.

    Returns (G, visual_style, lst_labels); visual_style is meant to be passed
    as keyword arguments to ig.plot(G, ...).
    """
    G = ig.Graph.Read_GML(fileG)
    visual_style = {}
    # Kamada-Kawai layout
    visual_style["layout"] = G.layout('kk')
    # The former expression max(1, round(10/math.log10(10))) always equals 10
    # (log10(10) == 1); the vertex count it was probably meant to use was
    # computed but never read.
    # NOTE(review): confirm whether the size should scale with len(G.vs).
    visual_style["vertex_size"] = 10
    # see x11 color names
    color_dict = {"out": "light blue", "tx": "pink"}
    visual_style["vertex_color"] = [color_dict[types] for types in G.vs["labels"]]
    visual_style["bbox"] = (600, 500)

    lst_labels = []
    if l:
        # label with the integer GML ids
        visual_style["vertex_label"] = [str(int(vid)) for vid in G.vs["id"]]
        # id -> tx hash / address lookup table for the caller
        lst_labels = [(int(n['id']), n['txhash']) if n['labels'] == 'tx'
                      else (int(n['id']), n['addr'])
                      for n in G.vs]
    else:
        # no label
        visual_style["vertex_label"] = [' ' for _ in G.vs["id"]]

    visual_style["edge_width"] = 0.3
    visual_style["edge_arrow_size"] = [0.5 for _ in G.es]
    return G, visual_style, lst_labels

#************************************#
# refresh plot with zoom and interact
#************************************#
def plt_zoom(x):
    """Redraw the dual-mode plot of the current graph for zoom radius x.

    Does nothing unless a graph has been loaded (current_file != 0) and the
    'dual mode' view is selected.

    x: zoom radius. When x < slider1.value (the search depth), every node
        whose hop distance from the queried address/transaction exceeds x is
        drawn white and unlabelled, and its edges are hidden. The distance is
        computed by compd on the undirected graph. For address queries the
        queried address nodes are drawn cyan. Otherwise the whole graph is
        redrawn with the default colours.

    Side effects: writes Gqz<value>.gml (zoom case), then img/Gqz<value>.png
    or img/Gq<value>.png, and displays that image in the `out` widget. The
    shared current_layout style is not modified.
    """
    if current_file == 0 or choice_view.value != 'dual mode':
        return
    # current_file is "Gq<value>.gml": drop the extension and the "Gq" prefix
    currvalue = current_file.split('.')[0][2:]
    # Work on a copy of the shared style. The previous code aliased
    # current_layout and blanked far-node labels in place, so those labels
    # stayed blank even after zooming back out.
    plotstyle_zoom = dict(current_layout)
    if "vertex_label" in plotstyle_zoom:
        plotstyle_zoom["vertex_label"] = list(plotstyle_zoom["vertex_label"])
    color_dict = {"out": "light blue", "tx": "pink"}

    if x < slider1.value:
        currtype = 'out' if choice1.value == 'a Bitcoin address' else 'tx'
        Gnx = nx.read_gml(current_file)
        # hop distance of every node from the queried address/transaction
        distfromaddr = compd(Gnx.to_undirected(), currvalue, currtype)
        Gnx.remove_nodes_from([n for n, d in distfromaddr.items() if d > x])
        nx.write_gml(Gnx, "Gqz%s.gml" % currvalue)

        # load for plotting
        G = ig.Graph.Read_GML(current_file)
        Gz = ig.Graph.Read_GML("Gqz%s.gml" % currvalue)
        # addresses / tx hashes still present after the zoom
        kept_values = {n['addr'] if n['labels'] == 'out' else n['txhash'] for n in Gz.vs}
        far_nodes = [int(n['id']) for n in G.vs
                     if n['labels'] == 'out' and n['addr'] not in kept_values]
        far_nodes.extend(int(n['id']) for n in G.vs
                         if n['labels'] == 'tx' and n['txhash'] not in kept_values)
        far_set = set(far_nodes)

        plotstyle_zoom["vertex_color"] = [color_dict[types] for types in G.vs["labels"]]
        if currtype == 'out':
            # highlight every node carrying the queried address
            for i in [int(n['id']) for n in G.vs
                      if n['labels'] == 'out' and n['addr'] == currvalue]:
                plotstyle_zoom["vertex_color"][i] = 'cyan'
        for i in far_nodes:
            plotstyle_zoom["vertex_color"][i] = 'white'
            plotstyle_zoom["vertex_label"][i] = ' '
        # hide every edge touching a far node
        plotstyle_zoom["edge_width"] = [
            0 if (edge.source in far_set or edge.target in far_set) else 0.3
            for edge in G.es]
        plotstyle_zoom["edge_arrow_size"] = [
            0 if (edge.source in far_set or edge.target in far_set) else 0.5
            for edge in G.es]
        imgfile = 'img/Gqz%s.png' % currvalue
    else:
        G = ig.Graph.Read_GML(current_file)
        plotstyle_zoom["vertex_color"] = [color_dict[types] for types in G.vs["labels"]]
        plotstyle_zoom["edge_width"] = 0.3
        plotstyle_zoom["edge_arrow_size"] = 0.5
        imgfile = 'img/Gq%s.png' % currvalue

    ig.plot(G, imgfile, **plotstyle_zoom, vertex_frame_width=0)
    with out:
        clear_output(wait=True)
        display(Image(filename=imgfile))


#**************************************#
#extracts transaction/address network
#**************************************#
def extract_graph(G,label):
    """Project G onto the nodes whose 'labels' attribute equals label.

    G: networkx graph read from GML. Every node carries a 'labels' attribute
       ('tx' or 'out'), and 'out' nodes also carry 'addr'.
    label: 'tx' for the transaction network, 'out' for the address network.

    Returns a new nx.DiGraph holding the kept nodes (with their data) and an
    edge u -> w for every pair of kept nodes joined in G by a two-step path
    u -> v -> w (direction as listed by G.edges()). For label == 'out', nodes
    sharing the same 'addr' are then merged, via nx.contracted_nodes, into
    the first such node.
    """
    # Successor lists in edge order, built once. The former code rescanned
    # the whole edge list for every kept node and every neighbour (O(E^2)).
    succ = defaultdict(list)
    for src, dst in G.edges():
        succ[src].append(dst)

    # IDs of the transaction/output nodes to keep (list preserves node order,
    # set gives O(1) membership tests)
    keptnodes = [n for n, data in G.nodes(data=True) if data['labels'] == label]
    kept_set = set(keptnodes)

    listsubedges = []
    for u in keptnodes:
        for v in succ.get(u, ()):
            listsubedges.extend((u, w) for w in succ.get(v, ()) if w in kept_set)

    Gnew = nx.DiGraph()
    # nodes are added separately from edges in case there are isolated nodes
    Gnew.add_nodes_from([(n, data) for n, data in G.nodes(data=True)
                         if data['labels'] == label])
    Gnew.add_edges_from(listsubedges)

    if label == 'out':
        # the same address can appear in several output nodes: group them
        addr_map = defaultdict(list)
        for n, data in list(Gnew.nodes(data=True)):
            addr_map[data['addr']].append(n)
        # contract every repeated address into its first occurrence
        for ll in addr_map.values():
            for n in ll[1:]:
                Gnew = nx.contracted_nodes(Gnew, ll[0], n)

    return Gnew

#**********************************#
#find paths between two addresses
#**********************************#
def findpath(addrvalue1,addrvalue2):
    """Return shortest paths between two addresses/transactions of current_file.

    addrvalue1, addrvalue2: Bitcoin address or transaction hash. check_bc
        decides which (presumably it validates a Bitcoin address -- TODO
        confirm). An address can match several 'out' nodes, so every pair of
        matching nodes is tried.

    The search is directed when choice2.value == 'directed' and undirected
    otherwise. When two nodes are not connected, a message is printed in
    out_comment and an empty path is recorded for that pair. Previously the
    undirected case raised NetworkXNoPath instead.

    Returns a list with one path per node pair. Each path is a list of the
    'addr' (or, when absent, 'txhash') values along the path.
    """
    Gnx = nx.read_gml(current_file)

    def _matching_nodes(value):
        # nodes whose address (or tx hash) equals value
        if check_bc(value):
            kind, key = 'out', 'addr'
        else:
            kind, key = 'tx', 'txhash'
        return [n for n, data in Gnx.nodes(data=True)
                if data['labels'] == kind and data[key] == value]

    nodeid1 = _matching_nodes(addrvalue1)
    nodeid2 = _matching_nodes(addrvalue2)

    directed = choice2.value == 'directed'
    # build the undirected view once instead of once per node pair
    search_graph = Gnx if directed else Gnx.to_undirected()

    ppall = []
    for addr1 in nodeid1:
        for addr2 in nodeid2:
            try:
                pt = nx.shortest_path(search_graph, addr1, addr2)
            except nx.NetworkXNoPath:
                with out_comment:
                    print("No directed path." if directed else "No undirected path.")
                pt = []
            ppall.append([Gnx.nodes[n]['addr'] if 'addr' in Gnx.nodes[n]
                          else Gnx.nodes[n]['txhash']
                          for n in pt])

    return ppall

#***************************#
#find neighbours
#***************************#
def findnodes(listaddr):
    """Return the entries of listaddr present in the graph of the current input.

    Reads Gq<text1.value>.gml and collects the 'addr' of every 'out' node and
    the 'txhash' of every 'tx' node. Returns, in their original order, the
    items of listaddr found among them.
    """
    addrvalue = text1.value
    Gnx = nx.read_gml('Gq%s.gml' % addrvalue)
    # a set makes each membership test O(1) instead of a scan of all nodes
    listinG = set()
    for _, data in Gnx.nodes(data=True):
        if data['labels'] == 'out':
            listinG.add(data['addr'])
        elif data['labels'] == 'tx':
            listinG.add(data['txhash'])

    return [addr for addr in listaddr if addr in listinG]


#*************************************#
#compute which nodes at which distance
#*************************************#
def compd(currG,currvalue,currtype):
    """Map every node of currG to its hop distance from the queried node(s).

    currtype == 'out' selects the 'out' nodes whose 'addr' equals currvalue;
    any other value selects the 'tx' nodes whose 'txhash' equals currvalue.
    Several nodes may match (the same address can occur more than once). The
    distance kept is the minimum over all of them. Nodes unreachable from
    every match keep the sentinel distance 10000.
    """
    if currtype == 'out':
        wanted_label, key = 'out', 'addr'
    else:
        wanted_label, key = 'tx', 'txhash'

    sources = [node for node, attrs in currG.nodes(data=True)
               if attrs['labels'] == wanted_label and attrs[key] == currvalue]

    # 10000 stands for "not reached from any source"
    at_dist = dict.fromkeys(currG.nodes(), 10000)
    for source in sources:
        reach = nx.single_source_shortest_path_length(currG, source)
        for node, hops in reach.items():
            at_dist[node] = min(at_dist[node], hops)
    return at_dist



In [6]:
import time
import operator
import random
from sklearn import cluster

#******************************#
#Bui's code for path confluence
#******************************#

def BFS_findjoin(s,G,successors,cutoff):
    #written by Phetsouvanh Silivanxay
    """Breadth-first search from s that stops at the first re-encountered node.

    s: start node in G.
    G: directed graph (G.successors / G.predecessors are used).
    successors: True -> expand successors, False -> expand predecessors.
    cutoff: intended maximum number of BFS levels.

    Returns the first neighbour that is met again while expanding a node
    other than s (self-loops excluded), i.e. a node already reached through
    another branch. Returns the string "empty" when no such node is found,
    when the level counter reaches cutoff, or when the queue runs out.

    NOTE(review): numNewChild is never reset and nodeCount also counts s, so
    `level` is not a true BFS depth. In some graphs (e.g. s with a single
    neighbour that itself has several) nodeCount overtakes numChild and
    never equals it again; the cutoff is then never applied and the search
    runs until the queue is empty.
    """
    # Create a queue for BFS
    queue = deque()
    visited = set()
    # Enqueue the source. NOTE(review): despite the original comment, s is
    # never added to `visited`, so it can be enqueued again if reached back.
    queue.append(s)
    level = 0
    numChild = 1;
    numNewChild = 0;
    nodeCount = 0;
    while queue:
        node = queue.popleft()
        iteration = []
        # direction of the search
        if( successors):
            iteration = G.successors(node)
            
        else:
            iteration = G.predecessors(node)
        for neighbor in iteration:
            if neighbor not in visited:
                numNewChild+=1;
                queue.append(neighbor)
                visited.add(neighbor)
            # already seen via another node: this is the join point
            elif neighbor != node and node!=s and neighbor in visited:
                return neighbor
        nodeCount+=1;
        # level bookkeeping (see NOTE in the docstring)
        if( nodeCount == numChild):
            level+=1
            numChild = numNewChild;
            if(level==cutoff):
                return "empty"
        

    return "empty"

def findjointPath(Kcg,extracted_addr):
    """Find path confluences along shortest paths between the given addresses.

    Kcg: directed networkx graph.
    extracted_addr: nodes of interest. Every ordered pair of distinct nodes is
        joined by all its shortest paths in the undirected version of Kcg.
        Pairs with a missing node or no path are skipped.

    For every consecutive triple (pre, mid, suc) on such a path:
    - if both pre -> mid and suc -> mid exist in Kcg, BFS_findjoin searches the
      predecessors of mid; if it returns a node, the simple paths node -> mid
      (cutoff 10) are collected;
    - elif both mid -> pre and mid -> suc exist, BFS_findjoin searches the
      successors of mid; if it returns a node `end`, the simple paths
      mid -> end (cutoff 1) are collected.
      NOTE(review): cutoff 1 here vs 10 above looks asymmetric -- confirm.
    When paths were collected, the shortest path is appended to the first of
    them and the group is added to the result.

    Returns a list of groups, each a list of node paths.
    """
    #written by Phetsouvanh Silivanxay
    Kcg_u = Kcg.to_undirected()
    shortest_paths = []
    for sourceAddr in extracted_addr:
        for destinationAddr in extracted_addr:
            if sourceAddr == destinationAddr:
                continue
            try:
                shortest_paths.extend(nx.all_shortest_paths(Kcg_u, sourceAddr, destinationAddr))
            except (nx.exception.NodeNotFound, nx.exception.NetworkXNoPath):
                # absent or disconnected endpoints are simply skipped
                pass

    jointPaths = []
    for path in shortest_paths:
        # paths with fewer than 3 nodes have no inner node
        for i in range(len(path) - 2):
            pre, mid, suc = path[i], path[i + 1], path[i + 2]
            # has_edge is O(1); `x in Kcg.predecessors(mid)` scanned a list
            if Kcg.has_edge(pre, mid) and Kcg.has_edge(suc, mid):
                # both neighbours point into mid
                paths = []
                node = BFS_findjoin(mid, Kcg, False, 10)
                if node != 'empty':
                    paths = list(nx.all_simple_paths(Kcg, source=node, target=mid, cutoff=10))
            elif Kcg.has_edge(mid, pre) and Kcg.has_edge(mid, suc):
                # mid points to both neighbours
                paths = []
                end = BFS_findjoin(mid, Kcg, True, 10)
                if end != 'empty':
                    paths = list(nx.all_simple_paths(Kcg, source=mid, target=end, cutoff=1))
            else:
                continue
            if paths:
                paths[0] = paths[0] + path
                jointPaths.append(paths)

    return jointPaths

def exportJointPath(Kcg,jointPaths):
    """Return (node id, address / tx hash) pairs for nodes lying on a joint path.

    Kcg: graph whose node data may hold 'addr' and/or 'txhash'.
    jointPaths: list of groups of paths, as returned by findjointPath.

    Nodes are reported once each, in Kcg's node order. A node carrying both
    'addr' and 'txhash' contributes two pairs (addr first). Nodes on no path
    are omitted.
    """
    #modified to return instead of write
    # every node appearing on any path, collected once
    on_paths = {node for multipaths in jointPaths for path in multipaths for node in path}
    pathlist = []
    # public node-data view instead of the private Kcg._node mapping
    for idn, label in Kcg.nodes(data=True):
        if idn not in on_paths:
            continue
        if 'addr' in label:
            pathlist.append((idn, label['addr']))
        if 'txhash' in label:
            pathlist.append((idn, label['txhash']))
    return pathlist

#******************************#
#Bui's code for clustering
#******************************#

def ClusteringProbDist(Kcg,num_cluster,topN,iter,alpha,dw,t,mu_f):
    #written by Phetsouvanh Silivanxay
    """Cluster the nodes of Kcg using the probability matrix computed by Ht.

    Kcg: graph whose nodes are the integers 0..N-1 (required by Ht).
    num_cluster: number of groups used when clustering a query node's
        probability row (see getQueryNodeCluster).
    topN: fraction of nodes; the int(N*topN) highest-entropy nodes are the
        "top entropy" nodes used by getValidQueryNodeCluster.
    iter: number of ClusterAgglomertion refinement rounds (shadows the
        builtin `iter` inside this function).
    alpha, dw, t, mu_f: forwarded to Ht.

    Seeding: repeatedly takes the remaining node of lowest entropy as query
    node, builds and filters its cluster and merges it with the clusters
    found so far. Covered nodes, and nodes recorded as exceptions, are
    removed from the remaining ones.
    NOTE(review): the loop only progresses if each round covers the query
    node; presumably findProbdistClusterWithQueryNode always returns a
    cluster containing it -- TODO confirm.
    Refinement: iter rounds of ClusterAgglomertion.

    Prints the elapsed time into the out_comment widget and returns the list
    of clusters (sets of node indices).
    """
    tt = time.time()
    N = Kcg.order()
    #Pkcg, D = MatrixWithAuxilary(Kcg,w)
    #Pkcg, D = MatrixWithAuxilaryFanOut(Kcg,w)
    #Hinf, mt = Hijt(Pkcg,D,t)

    # Hinf: entropy per node, mt: probability matrix, Pkcg: transition matrix P
    #Hinf,mt = Hijinf(Pkcg,D)
    Hinf,mt,Pkcg = Ht(Kcg, alpha,dw,t,mu_f)

    #select query nodes starting from the lowest entropy values
    includedSet = []
    queryNodesClusterAggreation = []
    
    TopEntropyNodes = getBottomN(Hinf,int(N*topN),True)

    remainingNodes = getRemainingNodeSortedBylowestEntropy(Kcg,includedSet,Hinf,False)

    exceptionNodes = []
    while(len(remainingNodes) > 0):
        # the query node is the first (lowest-entropy) remaining node
        queryNodes = []
        for remainingNode in remainingNodes:
            queryNodes.append(remainingNode)
            break
        #print('queryNodes',queryNodes)
        queryNodesCluster = getQueryNodeCluster(queryNodes,N,mt,num_cluster)
        queryNodesCluster = getValidQueryNodeCluster(queryNodesCluster,TopEntropyNodes,queryNodes,exceptionNodes,mt,queryNodesClusterAggreation)

        # merge the new cluster with those found so far
        queryNodesCluster.extend(queryNodesClusterAggreation)
        queryNodesClusterAggreation, includedSet = ClusterAggregation(queryNodesCluster)
        includedSet = includedSet.union(exceptionNodes)

        # drop covered nodes from the candidates
        for eachNode in includedSet:
            if ( eachNode in includedSet and eachNode in remainingNodes):
                remainingNodes.remove(eachNode)

    
    clusters = queryNodesClusterAggreation
    #print ('seeds.append (',queryNodesClusterAggreation,')')
    for i in range (iter):
        #print ('round:', i)
        queryNodesClusterAggreation =ClusterAgglomertion(clusters, mt, num_cluster,Hinf,Pkcg)
        clusters = queryNodesClusterAggreation
        #print ('final:',clusters)
    
    
    elapsed = time.time() - tt
    with out_comment: 
        print('Complete  time in secs',elapsed)
    return clusters

def Ht(G,alpha,dw,t,mu_f):
    """Compute a random-walk entropy for every node of G.

    G: the graph. Nodes must be the integers 0..n-1 in insertion order,
       because they are used directly as matrix row/column indices (use
       nx.convert_node_labels_to_integers(G, first_label=0) if needed).
    alpha: 0 -> the graph is treated as unweighted (alpha(w(e)) = 1 for all
       edges e); use it when the graph is actually unweighted.
       1 -> the 'weight' attributes are used (alpha(w(e)) = w(e)).
    dw: matrix of auxiliary probabilities. 0 -> built by Dunif, i.e.
       1 / ((weighted) degree + 2) on the diagonal; dw > 0 -> dw * I.
    t: walk length; t == 0 computes the asymptotic behaviour.
    mu_f: weight mu[j] of the entropy terms.
       0 -> unweighted entropy;
       1 -> weighted out-degree / out-degree;
       2 -> log2(weighted out-degree) / log2(out-degree).
       Nodes whose weighted out-degree equals their out-degree get 1.
       Options 1 and 2 use G.successors and 'weight', so G must be a weighted
       directed graph.
       NOTE(review): with mu_f == 2, a node of out-degree 1 whose edge weight
       is not 1 divides by log2(1) == 0.

    Returns (H, mt, P):
       H  -- list of entropies, one per node;
       mt -- for t == 0, inv(I - Pt) @ D (n x n); otherwise the t-th power of
             the 2n x 2n matrix [[Pt, D], [0, I]];
       P  -- transition matrix including a self-loop of weight 1 per node.
    """
    #number of nodes
    n = G.order()

    def Dunif(G,alpha):
        # auxiliary probabilities: 1 / (degree + self-loop + auxiliary node),
        # using the weighted (out-)degree when alpha == 1
        n = G.order()
        if alpha == 0:
            Du = np.zeros((n,n))
            for i in range(n):
                Du[i][i] = 1/(len(G[i])+2)
        if alpha == 1:
            Du = np.zeros((n,n))
            if nx.is_directed(G):
                weighted_deg = G.out_degree(weight='weight')
            else:
                weighted_deg = G.degree(weight='weight')
            for i in range(n):
                Du[i][i] = 1/(weighted_deg[i]+2)
        return Du

    #transition probability matrices P and Ptilde
    P = np.zeros((n,n))
    Pt = np.zeros((n,n))
    # Adjacency matrix with a self-loop of weight 1 added. to_numpy_array
    # gives the same values and node ordering as the former to_numpy_matrix
    # (removed in networkx 3.0) as a plain ndarray, so A[i, j] replaces the
    # costly per-element A[i].getA()[0][j].
    A = nx.to_numpy_array(G)+np.identity(n)

    #matrix of auxiliary probabilities
    if dw == 0:
        D = Dunif(G,alpha)
    else:
        D = dw*np.identity(n)

    if alpha == 0:
        for i in range(n):
            deg = len(G[i])
            for j in range(n):
                if A[i, j] != 0:
                    P[i][j] = 1/(deg+1)
                    Pt[i][j] = P[i][j] - D[i][i]/(deg+1)
    if alpha == 1:
        #weighted (out-)degrees
        if nx.is_directed(G):
            weighted_deg = G.out_degree(weight='weight')
        else:
            weighted_deg = G.degree(weight='weight')
        for i in range(n):
            for j in range(n):
                Aij = A[i, j]
                if Aij != 0:
                    #a weight of 1 is given to the self-loop
                    P[i][j] = Aij/(weighted_deg[i]+1)
                    Pt[i][j] = P[i][j] -D[i][i]*Aij/(weighted_deg[i]+1)

    if t == 0:
        #asymptotic case
        mt = np.matmul(np.linalg.inv(np.identity(n)-Pt),D)
    else:
        #finite case: power of [[Pt, D], [0, I]]
        mup = np.hstack((Pt, D))
        mdown = np.hstack((np.zeros((n,n)), np.identity(n)))
        mt = np.linalg.matrix_power(np.vstack((mup,mdown)),t)

    #weight mu for each node
    mu = []
    if mu_f == 0:
        mu = np.ones(n)
    if mu_f in (1, 2):
        for u in G.nodes():
            dwout = sum([G[u][v]['weight'] for v in G.successors(u)])
            outdeg = G.out_degree(u)
            if dwout == outdeg:
                mu.append(1)
            elif mu_f == 1:
                #ratio of degree as weight
                mu.append(dwout/outdeg)
            else:
                #ratio of entropic centralities as weight
                mu.append(np.log2(dwout)/np.log2(outdeg))

    H = [0] * n
    if t == 0:
        for i in range(n):
            for j in range(n):
                pij = mt[i][j]
                #skip zeros and numerical approximations of zero
                if pij != 0 and abs(pij) > 1.0e-14:
                    H[i] = H[i] - pij*np.log2(pij)*mu[j]
    else:
        for i in range(n):
            for j in range(n):
                #probability of being at j, or absorbed from j, after t steps
                pij = mt[i][j] + mt[i][n+j]
                if pij != 0:
                    H[i] = H[i] - pij*np.log2(pij)*mu[j]

    return H,mt,P

def getBottomN(H,n,rev):
    #written by Phetsouvanh Silivanxay
    """Return the indices of the first n entries of H once ranked by value.

    rev=False ranks ascending (n smallest values), rev=True descending
    (n largest). Ties keep index order. If n is negative or >= len(H), every
    index is returned.
    """
    ranking = sorted(range(len(H)), key=lambda idx: H[idx], reverse=rev)
    picked = []
    for rank, idx in enumerate(ranking):
        if rank == n:
            break
        picked.append(idx)
    return picked

def getRemainingNodeSortedBylowestEntropy(Kcg,includedSet,H,rev):
    #written by Phetsouvanh Silivanxay
    """Return the indices of H not in includedSet, ordered by their value.

    Ascending (lowest entropy first) when rev is False, descending when rev is
    True; ties keep index order. Kcg is accepted for interface compatibility
    and is not used.
    """
    ordered = sorted(range(len(H)), key=lambda idx: H[idx], reverse=rev)
    return [idx for idx in ordered if idx not in includedSet]

def getQueryNodeCluster(queryNodes,N,mt,num_cluster):
    #written by Phetsouvanh Silivanxay
    """Return the selected cluster of every query node, in order.

    For each query node, findProbdistClusterWithQueryNode groups the N nodes
    according to that node's row of mt into num_cluster groups and reports
    the index of the selected group; that group is collected.
    """
    def pick(query_node):
        groups, best_index = findProbdistClusterWithQueryNode(mt, num_cluster, query_node, N)
        return groups[best_index]

    return [pick(query_node) for query_node in queryNodes]

def getValidQueryNodeCluster(queryNodesCluster,TopEntropyNodes,queryNodes,exceptionNodes,mt,queryNodesClusterAggreation):
    #written by Phetsouvanh Silivanxay
    """Filter each query node's cluster in place and return the clusters.

    queryNodesCluster[k] is the cluster of queryNodes[k].
    TopEntropyNodes: the high-entropy nodes.
    exceptionNodes: list that receives query nodes whose cluster had several
        other high-entropy nodes (appended to in place).
    mt: probability matrix, indexed mt[query][node].
    queryNodesClusterAggreation: the clusters aggregated so far.

    For each cluster:
    - If the query node is a high-entropy node, or
      isContainedAllTopEntropyNodes(...) (defined elsewhere) holds, the
      aggregated cluster sharing the most nodes with it is chosen (index 0
      when none overlaps). Nodes belonging to any other aggregated cluster
      are then removed.
    - If more than one other high-entropy node is in the cluster, and the
      all-top-entropy condition does not hold, every such node except the
      one with the highest mt probability from the query node is removed,
      and the query node is appended to exceptionNodes.

    The clusters are mutated with remove(), so they must be mutable
    (list/set). Returns them in a new list.
    NOTE(review): if a node belongs to two aggregated clusters other than the
    majority one, remove() is called twice for it and raises.
    """
    validCluster = []
    count = 0;
    # NOTE: `cluster` shadows the sklearn `cluster` module inside this function
    for cluster in queryNodesCluster:
        numOfHighEntropy = 0
        highEntropyNode = []
        maxProb = 0
        maxProbNode = 0
        countMajority = dict()
        isContainedAllTopEntropy = isContainedAllTopEntropyNodes(TopEntropyNodes, queryNodes[count],cluster)
        
        # overlap count with each aggregated cluster
        for i in range (len(queryNodesClusterAggreation)):
            countMajority[i] = 0
        for node in cluster:
            if queryNodes[count] in TopEntropyNodes or isContainedAllTopEntropy:
                for i in range (len(queryNodesClusterAggreation)):
                    if( node in queryNodesClusterAggreation[i]):
                        countMajority[i] = countMajority[i]+1
                
            # other high-entropy nodes, tracking the most probable one
            if ( node in TopEntropyNodes and node != queryNodes[count]):
                numOfHighEntropy = numOfHighEntropy+1
                highEntropyNode.append(node)
                currentProbNode = mt[queryNodes[count]][node]
                #print ('highEntropyNode',queryNodes[count],node, currentProbNode)
                
                if ( maxProb < currentProbNode):
                    maxProb = currentProbNode
                    maxProbNode = node
        if  (queryNodes[count] in TopEntropyNodes or isContainedAllTopEntropy):
            # keep only the majority aggregated cluster's nodes (plus nodes
            # that belong to no aggregated cluster)
            maxMajority = 0#countMajority[0]
            maxMajorityCluster = 0
            for i in range (len(queryNodesClusterAggreation)):
                if ( maxMajority < countMajority[i]):
                        maxMajority = countMajority[i]
                        maxMajorityCluster = i
            iterable_cluster = list(cluster)
            for i in range (len(queryNodesClusterAggreation)):
                if ( i != maxMajorityCluster):
                    for node in iterable_cluster:
                        if( node in queryNodesClusterAggreation[i]):
                            cluster.remove(node)
            #print ('validCluster',cluster)
        if ( numOfHighEntropy > 1 and not isContainedAllTopEntropy):
            # keep only the most probable other high-entropy node
            for node in highEntropyNode:
                if( node != maxProbNode and node in cluster):
                    cluster.remove(node)
            #print ('numOfHighEntropy >1 -> validCluster',cluster,'maxProbNode:',maxProbNode)
            exceptionNodes.append(queryNodes[count])
        count = count+1
        validCluster.append(cluster)
        #print ('validCluster',validCluster)
    return validCluster 

def ClusterAgglomertion(clusters, mt, num_cluster,Hinf,Pkcg):
    #written by Phetsouvanh Silivanxay
    """One refinement round: merge clusters treated as super-nodes.

    clusters: list of node-index collections.
    mt: node-level probability matrix (from Ht).
    num_cluster: forwarded to findProbdistClusterWithQueryNode.
    Hinf: node entropies, used by getBottomNClusters (defined elsewhere).
    Pkcg: node-level transition matrix P (from Ht).

    Builds N x N cluster-level matrices:
      probdistMatrix[i][j] -- smallest non-zero mt value between the two
                              clusters (1 if there is none);
      connectMatrix[i][j]  -- the last non-zero Pkcg value found between them
                              (0 when not directly connected).
    Then takes cluster indices in the order returned by getBottomNClusters,
    clusters each one's row of probdistMatrix, filters the result with
    getValidQueryNodeCluster, and drops indices not directly connected. The
    results are merged with ClusterAggregation. Finally the seeds are printed
    in out_comment, the index groups are mapped back to node sets, the
    original clusters are appended, and everything is merged again.

    Returns the merged clusters (list of sets of node indices).
    """
    aggClusters=[]
    N = len(clusters)
    probdistMatrix = [[0 for x in range(N)] for y in range(N)]
    probdistMatrixMax = [[0 for x in range(N)] for y in range(N)]  # computed but unused
    connectMatrix = [[0 for x in range(N)] for y in range(N)]
    for i in range  (N):
        for j in range  (N):
            if ( i != j):
                Min = 1
                Max = 0
                Sum = 0
                connect = 0
                for nodeI in clusters[i]:
                    for nodeJ in clusters[j]:
                        if( Max < mt[nodeI][nodeJ]):
                            Max = mt[nodeI][nodeJ]
                        if( Min > mt[nodeI][nodeJ] and mt[nodeI][nodeJ]!=0):
                            Min = mt[nodeI][nodeJ]
                        if ( Pkcg[nodeI][nodeJ]!= 0):
                            connect = Pkcg[nodeI][nodeJ]
                        # NOTE(review): overwrites instead of accumulating, so
                        # avg below is not an average (avg is unused anyway)
                        Sum = mt[nodeI][nodeJ]
                avg = Sum / (len(clusters[i])*len(clusters[j]))
                #print (i, j,'avg:', avg,'min:',min,'max:',max)
                probdistMatrix[i][j] =  Min
                probdistMatrixMax[i][j] = Max
                connectMatrix[i][j] = connect
    N = len(clusters)
    
    
    clusterList = getBottomNClusters(Hinf,int(N),False,clusters)
    TopEntropyClusterList = getBottomNClusters(Hinf,int(N*0.3),True,clusters)
    # NOTE(review): randomList is unused, but random.sample still advances the
    # global RNG state
    randomList = random.sample(range(0,N),N)
    queryNodesClusterAggreation = []
    #clusterList = randomList
    #print ('clusterList',clusterList)
    #print ('TopEntropyClusterList',TopEntropyClusterList)
    while(len(clusterList) > 0):
        # next cluster index to use as query
        node = []
        for acluster in clusterList:
            node = acluster
            break
        clusterList.remove(node)
        indxedClusters, maxIndex = findProbdistClusterWithQueryNode(probdistMatrix,num_cluster,node,N)
        exceptionNodes = []
        # getValidQueryNodeCluster returns a list, so the selected group is
        # now wrapped: indxedClusters[maxIndex][0] below
        indxedClusters[maxIndex] = getValidQueryNodeCluster([indxedClusters[maxIndex]],TopEntropyClusterList,[node],exceptionNodes,probdistMatrix,queryNodesClusterAggreation)
        copy_indexCluster = list(indxedClusters[maxIndex][0])
        # keep only clusters directly connected to the query cluster
        for index  in copy_indexCluster:
            if (connectMatrix[node][index] ==0 and node != index):
                indxedClusters[maxIndex][0].remove(index)
                #print ('remove not connect',indxedClusters[maxIndex][0])
        
        
        for i in range (len(indxedClusters)):
            if ( i ==maxIndex):
                #resultSet = set()
                #print ('indxedClusters[i][0]',indxedClusters[i][0])
                # grouped clusters are no longer candidates as queries
                for index  in indxedClusters[i][0]:
                    if ( index in clusterList):
                        clusterList.remove(index)

                aggClusters.append(set(indxedClusters[maxIndex][0]))
        #print ('aggClusters',aggClusters)
        aggClusters.extend(queryNodesClusterAggreation)
        # ClusterAggregation empties aggClusters in place
        queryNodesClusterAggreation, includedSet = ClusterAggregation(aggClusters)
        #print ('final queryNodesClusterAggreation:',queryNodesClusterAggreation)

    MappingBackCluster = []
    with out_comment:
        clear_output(wait=True)
        print ('seeds.append (',queryNodesClusterAggreation,')')
    # map groups of cluster indices back to sets of graph nodes
    for acluster in queryNodesClusterAggreation:
        resultSet = set()
        for index in acluster:
            resultSet = resultSet.union(clusters[index])
        MappingBackCluster.append(resultSet)
    
    MappingBackCluster.extend(clusters)
    queryNodesClusterAggreation, includedSet = ClusterAggregation(MappingBackCluster)
    #print ('final MappingBackCluster:',queryNodesClusterAggreation)
    return queryNodesClusterAggreation

def ClusterAggregation(queryNodesCluster):
    """Merge clusters that share at least one member into disjoint groups.

    Merging is transitive: if cluster A overlaps B and B overlaps C, all three
    end up in the same group even when A and C have no member in common.

    Args:
        queryNodesCluster: mutable collection (e.g. a list) of iterables of
            hashable node ids.  It is consumed in place -- every cluster is
            removed from it, so it is empty on return (same as before).

    Returns:
        A pair ``(queryNodesClusterAggreation, includedSet)``:
        ``queryNodesClusterAggreation`` is a list of sets, one per merged
        group, in the order their first (seed) cluster appeared;
        ``includedSet`` is the union of all nodes seen.
    """
    #written by Phetsouvanh Silivanxay
    queryNodesClusterAggreation = []
    includedSet = set()
    while queryNodesCluster:
        # Seed a new group with the first remaining cluster.
        seed = next(iter(queryNodesCluster))
        queryNodesCluster.remove(seed)
        mergeClusterSet = set(seed)
        # Keep absorbing overlapping clusters until a full pass adds nothing.
        # One pass is not enough: a cluster scanned before the group grew can
        # overlap members absorbed later in that same pass, which previously
        # left overlapping groups in the output.
        while True:
            absorbed = []
            for singleCluster in queryNodesCluster:
                singleClusterSet = set(singleCluster)
                if singleClusterSet & mergeClusterSet:
                    mergeClusterSet |= singleClusterSet
                    absorbed.append(singleCluster)
            if not absorbed:
                break
            # Remove only after the scan so the collection is not mutated
            # while it is being iterated.
            for item in absorbed:
                queryNodesCluster.remove(item)
        queryNodesClusterAggreation.append(mergeClusterSet)
        includedSet |= mergeClusterSet
    return queryNodesClusterAggreation, includedSet


def findProbdistClusterWithQueryNode(probdistMatrix,num_cluster,node,N):
    """Group the N nodes by their value relative to ``node`` and pick one group.

    Each node ``i`` becomes the 2-D point ``[probdistMatrix[node][i], 0]``.
    The query node itself becomes ``[0, 0]``.  The points are grouped with
    Ward agglomerative clustering, and the group whose mean first coordinate
    is largest is selected.  Only means strictly greater than 0 and not equal
    to 1 qualify.

    Args:
        probdistMatrix: indexable as ``probdistMatrix[node][i]`` for i < N
            (presumably pairwise probabilities in [0, 1] -- TODO confirm).
        num_cluster: number of groups to form when ``N > 2``.  With two or
            fewer points a single group is used.
        node: index of the query node (used as a row index of probdistMatrix).
        N: number of nodes.

    Returns:
        ``(indxedClusters, maxIndex)``.  ``indxedClusters`` is a list of
        ``num_cluster`` lists of node indices, and ``maxIndex`` is the index of
        the selected group.  The selected group lists ``node`` exactly once, as
        its last element.  If no group qualified, ``maxIndex`` is 0 and that
        group is reset to ``[node]``.
    """
    #written by Phetsouvanh Silivanxay
    probdist = [[0, 0] if i == node else [probdistMatrix[node][i], 0]
                for i in range(N)]

    # Ward linkage only supports Euclidean distances, which is the default.
    # No affinity/metric keyword is passed: ``affinity`` was removed in
    # scikit-learn 1.4, and ``metric`` only exists from 1.2 onwards.
    n_clusters = num_cluster if N > 2 else 1
    agglomerative = cluster.AgglomerativeClustering(n_clusters=n_clusters, linkage="ward")
    agglomerative.fit(probdist)
    labels = list(agglomerative.labels_)
    #print (node,'results',labels)

    # Per-group sum, count and member indices.
    sums = [0] * num_cluster
    counts = [0] * num_cluster
    indxedClusters = [[] for _ in range(num_cluster)]
    for i, label in enumerate(labels):
        sums[label] += probdist[i][0]
        counts[label] += 1
        indxedClusters[label].append(i)

    best_avg = 0
    maxIndex = 0
    for i in range(num_cluster):
        if counts[i] == 0:
            continue
        avg = sums[i] / float(counts[i])
        # Groups whose mean is exactly 0 or exactly 1 are never selected.
        if best_avg < avg and avg != 0 and avg != 1:
            best_avg = avg
            maxIndex = i

    if best_avg == 0:
        # No group qualified: fall back to a group holding only the query node.
        members = []
    else:
        # The query node's own [0, 0] point may already sit in the chosen
        # group; drop it so it is not listed twice once appended below.
        members = [i for i in indxedClusters[maxIndex] if i != node]
    members.append(node)
    indxedClusters[maxIndex] = members
    #print ('mini result',maxIndex, indxedClusters[maxIndex])
    return indxedClusters, maxIndex

def isContainedAllTopEntropyNodes(TopEntropyNodes,queryNode,cluster):
    """Tell whether ``cluster`` holds only top-entropy nodes besides ``queryNode``.

    Returns True when every member of ``cluster`` is either ``queryNode``
    itself or an element of ``TopEntropyNodes``.  Otherwise returns False.
    An empty cluster, or one containing only ``queryNode``, yields True.
    """
    #written by Phetsouvanh Silivanxay
    return all(member == queryNode or member in TopEntropyNodes
               for member in cluster)


def getBottomNClusters(H,n,rev,clusters):
    """Rank clusters by the mean ``H`` value of their members and return the first ``n``.

    Args:
        H: sequence indexed by node id; ``H[i]`` is node i's score.
        n: how many cluster indices to return.  0 returns an empty list, and
            a value never reached by the count (e.g. negative) returns all.
        rev: False ranks from lowest mean to highest; True from highest to lowest.
        clusters: sequence of clusters, each an iterable of node ids.

    Returns:
        List of cluster indices (positions in ``clusters``).  Ties keep
        ascending index order, because ``sorted`` is stable.

    Raises:
        ZeroDivisionError: a cluster is empty.
        KeyError: a member id is not in ``range(len(H))``.
    """
    #written by Phetsouvanh Silivanxay
    score_of = {idx: H[idx] for idx in range(len(H))}

    mean_of = {}
    for cluster_idx, members in enumerate(clusters):
        total = 0
        for member in members:
            total = total + score_of[member]
        mean_of[cluster_idx] = total / len(members)

    ranked = sorted(mean_of, key=mean_of.get, reverse=rev)

    picked = []
    for cluster_idx in ranked:
        if len(picked) == n:
            break
        picked.append(cluster_idx)
    return picked


In [5]:
#http://chris-said.io/2016/02/13/how-to-make-polished-jupyter-presentations-with-optional-code-visibility/
# Presentation helper, adapted from the post linked above.  It renders a
# <script> that, once the page is ready:
#   - hides every code input cell (div.input);
#   - hides prompts and page chrome (prompts, back-to-top, menubar,
#     breadcrumb and .hidden-print elements).
# It also renders a "Show Code" button whose handler shows/hides the input
# cells and swaps the label between "Show Code" and "Hide Code".
# NOTE(review): relies on jQuery ($) and classic-Notebook DOM classes such as
# div.input / div.prompt being present in the page; confirm before relying on
# it under JupyterLab.
from IPython.display import HTML

# The HTML object is the cell's last expression, so Jupyter displays it.
HTML('''<script>
function code_toggle() {
    if (code_shown){
      $('div.input').hide('300');
      $('#toggleButton').val('Show Code')
    } else {
      $('div.input').show('300');
      $('#toggleButton').val('Hide Code')
    }
    code_shown = !code_shown
  }

  $( document ).ready(function(){
    code_shown=false;
    $('div.input').hide()
  });
  $(document).ready(function(){
    $('div.prompt').hide();
    $('div.back-to-top').hide();
    $('nav#menubar').hide();
    $('.breadcrumb').hide();
    $('.hidden-print').hide();
  });
</script>
<form action="javascript:code_toggle()"><input type="submit" id="toggleButton" value="Show Code"></form>
''')