In [None]:
from ipywidgets import HBox, VBox, Layout, FloatRangeSlider, FloatSlider, IntSlider, SelectionSlider, Checkbox, Button, ValueWidget, link, interact, Label
from ipywidgets.widgets.widget_description import DescriptionWidget
import plotly.graph_objs as go
import plotly.express as px
import mdtraj as md
import networkit as nk
import numpy as np
import timeit

class NetworKitGraphWidget(HBox):
    """Interactive side-by-side 3D view of a protein's residue interaction network.

    The left figure places every residue (graph node) at its C-alpha position; the
    right figure shows a NetworKit Maxent-Stress layout of the same graph.  Two
    residues are joined by an edge when their distance from ``md.compute_contacts``
    (scheme ``'closest'``) lies strictly between ``_MIN_CONTACT_DISTANCE`` and the
    cut-off chosen on the distance slider.  Node colours show the graph measure
    selected on the algorithm slider.

    Parameters
    ----------
    pdbFile : str
        Structure/topology file read by mdtraj.
    trajFile : str, optional
        Trajectory file.  Without it only the PDB structure (solvent removed) is
        shown, as a single frame.
    width : int, optional
        Total widget width in pixels; each figure is ``width * 0.5`` square.
    customColorScale : list, optional
        Plotly colour scale; defaults to ``px.colors.sequential.Plasma``.
    stride : int, optional
        Keep every ``stride``-th frame of ``trajFile`` (ignored without a trajectory).
    """

    # (display label, internal key) pairs for the graph-measure slider.  The keys are
    # the ones dispatched in __calculateNodesUpdate.
    _ALGORITHM_OPTIONS = (('None', 'None'),
                          ('Betweenness Centrality', 'BetCen'),
                          ('Closeness Centrality', 'CloCen'),
                          ('Degree Centrality', 'DegCen'),
                          ('Eigenvector Centrality', 'EigCen'),
                          ('Katz Centrality', 'KatCen'),
                          ('PLM Community Detection', 'PLMComDet'),
                          ('PLP Community Detection', 'PLPComDet'))
    # Cut-off values swept by the cut-off benchmarks.
    _BENCHMARK_CUTOFFS = (0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0)
    # Contacts at or below this distance are never turned into edges.
    _MIN_CONTACT_DISTANCE = 0.1

    def __init__(self, pdbFile, trajFile = None, width = 2000, customColorScale = None, stride = 20):
        super().__init__()

        # Basic widget values
        self.colorScale = customColorScale if customColorScale else px.colors.sequential.Plasma
        self.width = width
        self.proteinFile = pdbFile
        self.square_contact_map = None

        # Load the structure first: the number of loaded frames sizes the per-frame
        # bookkeeping below and the range of the trajectory slider.
        # There are two cases. If an opt. trajectory file is given, we try to load all
        # trajectories. Otherwise only the topology from pdb is loaded.
        if trajFile is None:
            pdb = md.load_pdb(pdbFile)
            self.protein = pdb.remove_solvent()
        else:
            # NOTE(review): solvent is not removed here, unlike the PDB-only branch.
            # Nodes are indexed by residue but positioned by C-alpha atoms, so residues
            # without a CA atom would misplace or break the protein layout — confirm the
            # trajectory is solvent-free.
            self.protein = md.load(trajFile, top = pdbFile, stride = stride)
        self.numTrajectories = self.protein.n_frames
        # C-alpha sub-trajectory, sliced once and reused for every protein layout.
        self._alphaCarbons = self.protein.atom_slice(self.protein.top.select('name CA'))

        # Value history: `current` describes what is displayed, `history` the previous
        # state; together they annotate hover labels with the effect of the last change.
        defaultCutoff = 0.5
        self.current = {'algorithm' : 'None', 'distance' : [defaultCutoff] * self.numTrajectories,
                        'trajectory' : 0, 'scores' : None, 'changeEvent' : None}
        self.history = {'algorithm' : 'None', 'distance' : [defaultCutoff] * self.numTrajectories,
                        'trajectory' : self.current['trajectory'], 'scores' : None}
        self.maxCoordinate = 0.0

        # Left VBox (static 3D graph, trajectory/distance sliders, recompute controls)
        self.proteinLayoutGraph = go.FigureWidget()
        self.distanceSlider = FloatSlider(min=0.1, max=1.0, step=0.05,
                                          value=self.current['distance'][self.current['trajectory']],
                                          continuous_update=False)
        self.algorithmSlider = SelectionSlider(options=list(self._ALGORITHM_OPTIONS),
                                               value=self.current['algorithm'],
                                               layout=Layout(align_items='stretch', width='60%'))
        self.trajectorySlider = IntSlider(min=0, max=self.numTrajectories - 1, value=0, continuous_update=False)
        self.recomputeCheckbox = Checkbox(value=True, description='Automatic recompute')
        self.recomputeButton = Button(description='Recompute', disabled = True)
        self.residueIdColoringCheckbox = Checkbox(value=False, description='ID coloring',
                                                  tooltip='Only active when no graph measure (None) is selected.')
        # mdtraj reports contact distances in nanometres, so the cut-off is in nm.
        self.leftVBox = VBox([self.proteinLayoutGraph,
                              HBox([Label('Trajectory'), self.trajectorySlider]),
                              HBox([Label('Edge Distance cut-off (nm)'), self.distanceSlider]),
                              HBox([self.recomputeButton, self.recomputeCheckbox, self.residueIdColoringCheckbox])])

        # Right VBox (dynamic 3D graph, algorithm slider)
        box_layout = Layout(display='flex',
                            flex_flow='column',
                            align_items='stretch',
                            width='80%')
        self.generatedLayoutGraph = go.FigureWidget()
        self.rightVBox = VBox([self.generatedLayoutGraph,
                               HBox([Label('Graph Measure'), self.algorithmSlider], layout = box_layout)])
        self.children = [self.leftVBox, self.rightVBox]

        # Convert protein to graph
        self.__prepareRIN()
        self.__generateGraphs()
        self.__initializeLimits()

        # Generate layout and initialize dual-plotly widget
        frame = self.current['trajectory']
        generatedCoordinates = self.__generateLayout(frame)
        self.__initializePlotlyWidget(self.generatedLayoutGraph, frame, 'Layout: Maxent-Stress', generatedCoordinates)
        proteinCoordinates = self.__getProteinLayout(frame)
        self.__initializePlotlyWidget(self.proteinLayoutGraph, frame, 'Layout: Protein-based', proteinCoordinates)

        # Add event hooks
        self.algorithmSlider.observe(self.__onAlgorithmChange, names='value')
        self.distanceSlider.observe(self.__onCutoffChange, names='value')
        self.trajectorySlider.observe(self.__onTrajectoryChange, names='value')
        self.recomputeButton.on_click(self.__recomputeClick)
        self.recomputeCheckbox.observe(self.__recomputeToggle, names='value')
        self.residueIdColoringCheckbox.observe(self.__colorNodesById, names='value')

    def benchmarkAlgorithms(self, numRuns, cutoff, frame, only_nwk = True):
        """Time every graph measure ``numRuns`` times at distance ``cutoff``.

        With ``only_nwk`` only the NetworKit computation on ``frame`` is timed;
        otherwise the full algorithm-change handler (computation plus figure update
        for the displayed frame) is timed.

        Returns ``[generalData, benchmarkData]``: node/edge counts of ``frame`` and a
        dict mapping each algorithm key to its per-run timings in milliseconds.
        """
        benchmarkData = {key: [0.0] * numRuns for _, key in self._ALGORITHM_OPTIONS if key != 'None'}
        self.__onCutoffChange({'new' : cutoff}, triggerUpdate = True)
        # The cut-off handler only updates the displayed frame; make sure the
        # benchmarked frame uses the requested cut-off as well.
        if self.current['distance'][frame] != cutoff:
            self.__generateEdges(frame, self._MIN_CONTACT_DISTANCE, cutoff)
            self.current['distance'][frame] = cutoff
        for i in range(numRuns):
            for key in benchmarkData:
                start = timeit.default_timer()
                if only_nwk:
                    self.__calculateNodesUpdate(key, frame)
                else:
                    self.__onAlgorithmChange({'new' : key}, triggerUpdate = True)
                end = timeit.default_timer()
                benchmarkData[key][i] = 1000 * (end - start)
        graph = self.graphs[frame]
        generalData = {'n' : graph.numberOfNodes(), 'e' : graph.numberOfEdges()}
        return [generalData, benchmarkData]

    def benchmarkTrajectory(self, numRuns, cutoff):
        """Time edge and layout generation over all frames, ``numRuns`` times.

        Returns ``[generalData, benchmarkData]``: average node/edge counts per frame
        and, per run, the summed time in seconds.
        """
        numFrames = self.numTrajectories
        benchmarkData = [0.0] * numRuns
        generalData = {'n' : 0, 'e' : 0}
        for i in range(numRuns):
            for j in range(numFrames):
                start = timeit.default_timer()
                self.__generateEdges(j, self._MIN_CONTACT_DISTANCE, cutoff)
                self.__generateLayout(j)
                end = timeit.default_timer()
                benchmarkData[i] += (end - start)
                # Record the cut-off now applied so later redraws stay in sync.
                self.current['distance'][j] = cutoff
                generalData['n'] += self.graphs[j].numberOfNodes()
                generalData['e'] += self.graphs[j].numberOfEdges()
        samples = numRuns * numFrames
        if samples > 0:
            generalData['n'] = generalData['n'] / samples
            generalData['e'] = generalData['e'] / samples
        return [generalData, benchmarkData]

    def benchmarkCompleteTrajectory(self, numRuns, cutoff):
        """Time the full trajectory-change handler for every frame, ``numRuns`` times.

        Returns, per run, the summed time in seconds.
        """
        benchmarkData = [0.0] * numRuns
        self.distanceSlider.value = cutoff
        for i in range(numRuns):
            for j in range(self.numTrajectories):
                start = timeit.default_timer()
                self.__onTrajectoryChange({'new' : j}, triggerUpdate = True)
                end = timeit.default_timer()
                benchmarkData[i] += (end - start)
        return benchmarkData

    def benchmarkCompleteCutOff(self, numRuns):
        """Time the full cut-off handler for each value in ``_BENCHMARK_CUTOFFS``.

        Returns a dict mapping cut-off -> ``numRuns`` timings in milliseconds.
        """
        benchmarkData = {cutoff: [0.0] * numRuns for cutoff in self._BENCHMARK_CUTOFFS}
        for i in range(numRuns):
            for cutoff in benchmarkData:
                start = timeit.default_timer()
                self.__onCutoffChange({'new' : cutoff}, triggerUpdate = True)
                end = timeit.default_timer()
                benchmarkData[cutoff][i] = 1000 * (end - start)
        return benchmarkData

    def benchmarkEdgeLayoutGeneration(self, numRuns):
        """Separately time edge generation and layout generation on frame 0.

        Returns ``[edgeTimings, layoutTimings]``; each maps cut-off -> ``numRuns``
        timings in milliseconds.
        """
        benchmarkDataEdge = {cutoff: [0.0] * numRuns for cutoff in self._BENCHMARK_CUTOFFS}
        benchmarkDataLayout = {cutoff: [0.0] * numRuns for cutoff in self._BENCHMARK_CUTOFFS}
        for i in range(numRuns):
            for cutoff in benchmarkDataEdge:
                start = timeit.default_timer()
                self.__generateEdges(0, self._MIN_CONTACT_DISTANCE, cutoff)
                end = timeit.default_timer()
                benchmarkDataEdge[cutoff][i] = 1000 * (end - start)
                # Record the cut-off now applied to frame 0 so later redraws stay in sync.
                self.current['distance'][0] = cutoff
                start = timeit.default_timer()
                self.__generateLayout(0)
                end = timeit.default_timer()
                benchmarkDataLayout[cutoff][i] = 1000 * (end - start)
        return [benchmarkDataEdge, benchmarkDataLayout]

    def __onTrajectoryChange(self, change, triggerUpdate = False):
        """Redraw both figures for frame ``change['new']`` and refresh node colours.

        Runs only when automatic recompute is on or ``triggerUpdate`` is set.
        """
        if not (self.recomputeCheckbox.value or triggerUpdate):
            return
        frame = change['new']
        # Check for real change
        if self.current['trajectory'] != frame:
            self.current['changeEvent'] = 'trajectory'
        self.history['trajectory'] = self.current['trajectory']
        self.current['trajectory'] = frame

        # The cut-off may have moved while this frame was not displayed (or while
        # automatic recompute was off); bring its edges in line with the slider.
        if self.distanceSlider.value != self.current['distance'][frame]:
            self.__generateEdges(frame, self._MIN_CONTACT_DISTANCE, self.distanceSlider.value)
            self.current['distance'][frame] = self.distanceSlider.value

        proteinCoordinates = self.__getProteinLayout(frame)
        generatedCoordinates = self.__generateLayout(frame)

        # Keep only the node trace (index 0), move it, and append freshly built edges.
        self.proteinLayoutGraph.data = [self.proteinLayoutGraph.data[0]]
        self.__createScatterData(self.proteinLayoutGraph, frame, proteinCoordinates, update = True)

        self.generatedLayoutGraph.data = [self.generatedLayoutGraph.data[0]]
        self.__createScatterData(self.generatedLayoutGraph, frame, generatedCoordinates, update = True)

        self.__applyEdgesUpdate(self.proteinLayoutGraph, frame)
        self.__applyEdgesUpdate(self.generatedLayoutGraph, frame)

        self.__onAlgorithmChange({'new' : self.algorithmSlider.value}, triggerUpdate = True)

    def __colorNodesById(self, change):
        """Recolour the nodes when the residue-ID colouring checkbox toggles."""
        self.__onAlgorithmChange({'new' : self.algorithmSlider.value}, triggerUpdate = True)

    def __recomputeToggle(self, change):
        """Enable the manual Recompute button only while automatic recompute is off."""
        self.recomputeButton.disabled = bool(change['new'])

    def __recomputeClick(self, b):
        """Manually apply the current slider settings."""
        self.__onTrajectoryChange({'new' : self.trajectorySlider.value}, triggerUpdate = True)

    def __onAlgorithmChange(self, change, triggerUpdate = False):
        """Compute graph measure ``change['new']`` on the displayed frame and recolour nodes.

        Runs only when automatic recompute is on or ``triggerUpdate`` is set.
        """
        if not (self.recomputeCheckbox.value or triggerUpdate):
            return
        # Check for real change
        if self.current['algorithm'] != change['new']:
            self.current['changeEvent'] = 'algorithm'
        self.history['algorithm'] = self.current['algorithm']
        self.current['algorithm'] = change['new']
        # Score the frame whose graph is actually drawn, not whatever the slider shows.
        frame = self.current['trajectory']
        nodesUpdate = self.__calculateNodesUpdate(change['new'], frame)
        self.__applyNodesUpdate(self.proteinLayoutGraph, nodesUpdate)
        self.__applyNodesUpdate(self.generatedLayoutGraph, nodesUpdate)
        # The change has been rendered into the hover labels; clear it so a later
        # refresh without a real change does not report a stale trajectory/distance diff.
        self.current['changeEvent'] = None

    def __calculateNodesUpdate(self, algorithm, frame):
        """Return per-node values used as marker colours for ``algorithm`` on ``frame``.

        For ``'None'`` these are either normalised node ids (residue-ID colouring) or
        the first colour of the colour scale for every node.  Otherwise the NetworKit
        scores or community ids are returned and stored in ``current['scores']``, with
        the previous scores moved to ``history['scores']``.

        Raises ValueError for an unknown algorithm key.
        """
        graph = self.graphs[frame]
        if algorithm == 'None':
            if self.residueIdColoringCheckbox.value:
                maxId = graph.upperNodeIdBound()
                return [n / maxId for n in graph.iterNodes()]
            return [self.colorScale[0] for _ in graph.iterNodes()]

        runners = {
            'BetCen': lambda: nk.centrality.Betweenness(graph, normalized = True),
            'CloCen': lambda: nk.centrality.Closeness(graph, True, nk.centrality.ClosenessVariant.Standard),
            'DegCen': lambda: nk.centrality.DegreeCentrality(graph, normalized = True),
            'EigCen': lambda: nk.centrality.EigenvectorCentrality(graph),
            'KatCen': lambda: nk.centrality.KatzCentrality(graph),
            'PLMComDet': lambda: nk.community.PLM(graph),
            'PLPComDet': lambda: nk.community.PLP(graph),
        }
        if algorithm not in runners:
            raise ValueError('Unknown graph measure: ' + str(algorithm))

        self.history['scores'] = self.current['scores']
        runner = runners[algorithm]()
        runner.run()
        if algorithm in ('PLMComDet', 'PLPComDet'):
            # Community detection yields a partition; colour nodes by community id.
            scores = runner.getPartition().getVector()
        else:
            scores = runner.scores()
        self.current['scores'] = scores
        return scores

    def __applyNodesUpdate(self, figWidget, nodesUpdate):
        """Recolour the 'nodes' trace of ``figWidget`` and refresh its hover labels."""
        labels = self.__getUpdatedLabel(nodesUpdate)
        figWidget.update_traces(patch=dict(marker=dict(symbol='circle',
                                                       size=9,
                                                       color=nodesUpdate,
                                                       colorscale=self.colorScale),
                                           text=labels),
                                selector=dict(name='nodes'))

    def __getUpdatedLabel(self, scores):
        """Build the hover text of every node from ``scores``.

        After a trajectory or distance change (and with a measure selected) each
        label also shows the absolute and relative change against ``history['scores']``.
        """
        if self.current['algorithm'] == 'None':
            # `scores` may hold colours here; show a neutral 0.0 score instead.
            scores = [0.0] * self.graphs[0].upperNodeIdBound()
        algorithmLabel = self.algorithmSlider.label
        labels = ['Residue ID: ' + str(resID) + '<br><br>' + '<b>Graph Measure</b>' + '<br>Algorithm: '
                  + algorithmLabel + '<br>Score: ' + str(round(score, 5))
                  for resID, score in enumerate(scores)]

        # Show change from last computation
        event = self.current['changeEvent']
        if self.current['algorithm'] == 'None' or event not in ('trajectory', 'distance'):
            return labels
        if event == 'trajectory':
            first, second = self.history['trajectory'], self.current['trajectory']
        else:
            frame = self.current['trajectory']
            first, second = self.history['distance'][frame], self.current['distance'][frame]
        changePrefix = '<br>' + 'Change (' + event + ' ' + str(first) + ' -> ' + str(second) + '): '
        for idx, label in enumerate(labels):
            diff = scores[idx] - self.history['scores'][idx]
            diffStr = str(round(diff, 5))
            if scores[idx] != 0.0:
                diffStr = diffStr + ' (' + str(round(diff / scores[idx] * 100)) + '%)'
            labels[idx] = label + changePrefix + diffStr
        return labels

    def __edgeCountAnnotation(self, frame, fontSize):
        """Return the plotly annotation showing the edge count of ``frame``."""
        return dict(showarrow=False,
                    text='Number of Edges=' + str(self.graphs[frame].numberOfEdges()),
                    xref='paper',
                    yref='paper',
                    x=0,
                    y=-0.1,
                    xanchor='left',
                    yanchor='bottom',
                    font=dict(size=fontSize))

    def __applyEdgesUpdate(self, figWidget, frame):
        """Replace the edge-count annotation of ``figWidget`` with that of ``frame``."""
        # NOTE(review): updates use 14pt while the initial annotation uses 12pt —
        # confirm whether the size jump is intended.
        figWidget.update_layout(dict1=dict(annotations=[self.__edgeCountAnnotation(frame, 14)]),
                                overwrite=True)

    def __onCutoffChange(self, change, triggerUpdate = False):
        """Rebuild the displayed frame's edges for cut-off ``change['new']`` and redraw.

        Runs only when automatic recompute is on or ``triggerUpdate`` is set.
        """
        if not (self.recomputeCheckbox.value or triggerUpdate):
            return
        # Apply the cut-off to the frame whose nodes are currently drawn.
        frame = self.current['trajectory']
        cutoff = change['new']
        # Check for real change
        if self.current['distance'][frame] != cutoff:
            self.__generateEdges(frame, self._MIN_CONTACT_DISTANCE, cutoff)
            self.current['changeEvent'] = 'distance'

        self.history['distance'][frame] = self.current['distance'][frame]
        self.current['distance'][frame] = cutoff

        proteinCoordinates = self.__getProteinLayout(frame)
        generatedCoordinates = self.__generateLayout(frame)

        # Protein node positions do not depend on the cut-off: only replace the edges.
        self.proteinLayoutGraph.data = [self.proteinLayoutGraph.data[0]]
        self.__createAndAddEdgeScatterData(self.proteinLayoutGraph, frame, proteinCoordinates)

        # The generated layout is recomputed, so its nodes move as well.
        self.generatedLayoutGraph.data = [self.generatedLayoutGraph.data[0]]
        self.__createScatterData(self.generatedLayoutGraph, frame, generatedCoordinates, update = True)

        self.__applyEdgesUpdate(self.proteinLayoutGraph, frame)
        self.__applyEdgesUpdate(self.generatedLayoutGraph, frame)

        self.__onAlgorithmChange({'new' : self.algorithmSlider.value}, triggerUpdate = True)

    def __prepareRIN(self):
        """Compute residue-pair contact distances for every unordered residue pair.

        Stores the ``(distances, residue_pairs)`` result of ``md.compute_contacts`` in
        ``self.contact_list``; distances are indexed ``[frame, pair]``.
        """
        rows, cols = np.triu_indices(self.protein.n_residues, k = 1)
        all_pairs = np.column_stack((rows, cols))
        self.contact_list = md.compute_contacts(self.protein,
                                                contacts = all_pairs,
                                                scheme = 'closest',
                                                ignore_nonprotein = False)

    def __generateGraphs(self):
        """Create one graph per frame (one node per residue) using the slider cut-off."""
        self.graphs = [nk.graph.Graph(self.protein.n_residues) for _ in range(self.protein.n_frames)]
        for frame in range(self.protein.n_frames):
            self.__generateEdges(frame, self._MIN_CONTACT_DISTANCE, self.distanceSlider.value)

    def __generateEdges(self, frame, low, high):
        """Replace the edges of ``frame`` with all residue pairs at distance in (low, high)."""
        graph = self.graphs[frame]
        graph.removeAllEdges()
        distances = self.contact_list[0][frame, :]
        mask = np.logical_and(distances > low, distances < high)
        for u, v in self.contact_list[1][mask]:
            graph.addEdge(u, v)

    def __initializeLimits(self):
        """Initialise score buffers and the symmetric axis limit shared by both figures."""
        numNodes = self.graphs[0].upperNodeIdBound()
        self.current['scores'] = [0.0] * numNodes
        self.history['scores'] = [0.0] * numNodes
        # Axes span [-maxCoordinate, maxCoordinate], so take the largest *absolute*
        # C-alpha coordinate over all frames; a signed maximum clips structures that
        # extend further into negative space.
        xyz = self._alphaCarbons.xyz
        if xyz.size:
            self.maxCoordinate = max(self.maxCoordinate, float(np.abs(xyz).max()))

    def __getProteinLayout(self, frame):
        """Return the C-alpha coordinates of ``frame`` (one row per CA atom)."""
        return self._alphaCarbons.xyz[frame]

    def __generateLayout(self, frame, distance=3, fast=1):
        """Return 3D Maxent-Stress coordinates for the graph of ``frame``.

        ``distance`` is forwarded as MaxentStress's third positional argument
        (presumably its neighbourhood parameter ``k`` — confirm against NetworKit docs).
        """
        maxentS = nk.viz.MaxentStress(self.graphs[frame], 3, distance, fastComputation=fast, graphDistance=0)
        maxentS.run()
        return maxentS.getCoordinates()

    def __createScatterData(self, figWidget, frame, coordinates, update = False):
        """Add (or, with ``update``, move) the node trace and append a new edge trace."""
        self.__createAndAddNodeScatterData(figWidget, frame, coordinates, update)
        self.__createAndAddEdgeScatterData(figWidget, frame, coordinates)

    def __createAndAddNodeScatterData(self, figWidget, frame, coordinates, update = False):
        """Add a 'nodes' marker trace, or move the existing one when ``update`` is set."""
        nodeIds = list(self.graphs[frame].iterNodes())
        xs = [coordinates[k][0] for k in nodeIds]
        ys = [coordinates[k][1] for k in nodeIds]
        zs = [coordinates[k][2] for k in nodeIds]

        if update:
            figWidget.update_traces(patch=dict(x=xs, y=ys, z=zs), selector=dict(name='nodes'))
            return
        nodeScatter = go.Scatter3d(x=xs,
                                   y=ys,
                                   z=zs,
                                   mode='markers',
                                   name='nodes',
                                   marker=dict(symbol='circle',
                                               size=9,
                                               colorscale=self.colorScale,
                                               color=self.colorScale[0],
                                               line=dict(color='rgb(50,50,50)', width=0.5)),
                                   hoverinfo='text')
        figWidget.add_traces(nodeScatter)

    def __createAndAddEdgeScatterData(self, figWidget, frame, coordinates):
        """Append an 'edges' line trace with one segment per edge of ``frame``."""
        graph = self.graphs[frame]
        numSlots = graph.numberOfEdges() * 3
        xs, ys, zs = [None] * numSlots, [None] * numSlots, [None] * numSlots
        # Each edge uses three slots: start, end, and a None gap so plotly does not
        # connect consecutive edges.
        for i, (u, v) in enumerate(graph.iterEdges()):
            base = 3 * i
            xs[base], xs[base + 1] = coordinates[u][0], coordinates[v][0]
            ys[base], ys[base + 1] = coordinates[u][1], coordinates[v][1]
            zs[base], zs[base + 1] = coordinates[u][2], coordinates[v][2]
        edgeScatter = go.Scatter3d(x=xs,
                                   y=ys,
                                   z=zs,
                                   mode='lines',
                                   opacity=0.7,
                                   line=dict(color='rgb(180,180,180)', width=2),
                                   hoverinfo='none',
                                   showlegend=None,
                                   name='edges')
        figWidget.add_traces(edgeScatter)

    def __initializePlotlyWidget(self, figWidget, frame, title, coordinates):
        """Populate an empty FigureWidget with node/edge traces and the figure layout."""
        self.__createScatterData(figWidget, frame, coordinates)
        self.__applyNodesUpdate(figWidget, [self.colorScale[0] for _ in self.graphs[frame].iterNodes()])

        # Fixed symmetric ranges (shared by both figures) so the view does not rescale
        # between frames; axis line and zero line hidden, grid and tick labels shown.
        axis = dict(showline=False,
                    zeroline=False,
                    showgrid=True,
                    showticklabels=True,
                    title='',
                    autorange=False,
                    range=[-self.maxCoordinate, self.maxCoordinate])

        figWidget.layout = go.Layout(
            title="Protein: " + str(self.proteinFile) + "<br> " + title + "<br> Nodes: " + str(self.graphs[0].numberOfNodes()),
            font=dict(size=8),
            showlegend=False,
            autosize=True,
            width=self.width * 0.5,
            height=self.width * 0.5,
            scene_aspectmode='cube',
            scene=dict(xaxis=dict(axis),
                       yaxis=dict(axis),
                       zaxis=dict(axis),
                       bgcolor='white',
                       camera=dict(eye={'x' : 1.3, 'y' : 1.3, 'z' : 0.5})),
            margin=go.layout.Margin(l=10, r=10, b=85, t=100),
            hovermode='closest',
            annotations=[self.__edgeCountAnnotation(frame, 12)])