In [1]:
from decomp import UDSCorpus
from decomp.semantics.uds import UDSGraph
from typing import List, Dict
import dash
import dash_core_components as dcc
import dash_html_components as html
import networkx as nx
import plotly.graph_objs as go
import numpy as np
import matplotlib


In [2]:
# Graph to visualise. Another example used during development:
# UDSCorpus(split="train")['ewt-train-9'].
SPLIT = "dev"
GRAPH_ID = 'ewt-dev-1'
problem = UDSCorpus(split=SPLIT)[GRAPH_ID]

In [64]:
class UDSVisualization:
    """
    Interactive Plotly/Dash rendering of a UDS graph.

    Syntax nodes are laid out along ``syntax_y`` and semantics nodes along
    ``semantics_y``. A node's x-coordinate is the integer after its last "-"
    times ``node_offset`` (0 for semantics nodes without a numeric suffix).
    Edges within a layer are drawn as parabolic arcs and edges between layers as
    straight segments. Each edge ends in a triangular arrowhead of the same colour.

    Parameters
    ----------
    graph : UDSGraph
        Graph to draw. It is used through ``syntax_subgraph``,
        ``semantics_subgraph``, ``nodes``, ``edges``, ``head()``,
        ``instance_edges()`` and ``name``.
    add_span_edges : bool
        Also draw edges from ``graph.instance_edges()`` that were not already
        drawn as head edges.
    add_syntax_edges : bool
        Also draw arcs for the edges of ``graph.syntax_subgraph``.
    syntax_y, semantics_y : float
        Vertical positions of the syntax and semantics layers.
    node_offset : float
        Horizontal distance between consecutive node indices.
    """

    # Node attributes listed in hover text, in display order.
    NODE_ONTOLOGY = [
        'factuality-factual',
        'genericity-arg-abstract', 'genericity-arg-kind', 'genericity-arg-particular',
        'genericity-pred-dynamic', 'genericity-pred-hypothetical', 'genericity-pred-particular',
        'time-dur-centuries', 'time-dur-days', 'time-dur-decades', 'time-dur-forever',
        'time-dur-hours', 'time-dur-instant', 'time-dur-minutes', 'time-dur-months',
        'time-dur-seconds', 'time-dur-weeks', 'time-dur-years',
        'wordsense-supersense-noun.Tops', 'wordsense-supersense-noun.act',
        'wordsense-supersense-noun.animal', 'wordsense-supersense-noun.artifact',
        'wordsense-supersense-noun.attribute', 'wordsense-supersense-noun.body',
        'wordsense-supersense-noun.cognition', 'wordsense-supersense-noun.communication',
        'wordsense-supersense-noun.event', 'wordsense-supersense-noun.feeling',
        'wordsense-supersense-noun.food', 'wordsense-supersense-noun.group',
        'wordsense-supersense-noun.location', 'wordsense-supersense-noun.motive',
        'wordsense-supersense-noun.object', 'wordsense-supersense-noun.person',
        'wordsense-supersense-noun.phenomenon', 'wordsense-supersense-noun.plant',
        'wordsense-supersense-noun.possession', 'wordsense-supersense-noun.process',
        'wordsense-supersense-noun.quantity', 'wordsense-supersense-noun.relation',
        'wordsense-supersense-noun.shape', 'wordsense-supersense-noun.state',
        'wordsense-supersense-noun.substance', 'wordsense-supersense-noun.time',
    ]
    # Edge attributes listed in hover text, in display order.
    EDGE_ONTOLOGY = [
        'protoroles-awareness', 'protoroles-change_of_location',
        'protoroles-change_of_possession', 'protoroles-change_of_state',
        'protoroles-change_of_state_continuous', 'protoroles-existed_after',
        'protoroles-existed_before', 'protoroles-existed_during',
        'protoroles-instigation', 'protoroles-partitive', 'protoroles-sentient',
        'protoroles-volition', 'protoroles-was_for_benefit', 'protoroles-was_used',
    ]

    def __init__(self,
                 graph: "UDSGraph",
                 add_span_edges: bool = True,
                 add_syntax_edges: bool = False,
                 syntax_y: float = 0.0,
                 semantics_y: float = 10.0,
                 node_offset: float = 7.0) -> None:
        # The annotation is quoted so the class can be defined without decomp imported.
        self.graph = graph
        self.annotations = []
        self.trace_list = []
        self.node_to_xy = {}
        # Truncate syntax labels to 3 characters when the syntax layer has > 10 nodes.
        self.do_shorten = len(self.graph.syntax_subgraph) > 10
        self.syntax_y = syntax_y
        self.semantics_y = semantics_y
        self.node_offset = node_offset
        self.added_edges = []
        self.add_span_edges = add_span_edges
        self.add_syntax_edges = add_syntax_edges

    def _format_line(self, start, end):
        """
        Sample a downward-opening parabolic arc between two (x, y) points.

        Returns ``(x_range, y_range, height)``:
        - ``x_range`` holds 100 samples from the leftmost to the rightmost x.
        - ``y_range`` holds the matching y values; the curve passes through both endpoints.
        - ``height`` is the maximum y.

        Returns ``(None, None, None)`` when the points coincide or share an
        x-coordinate, since no arc can be drawn over a zero-width span.
        """
        if start == end:
            return None, None, None

        x0, y0 = start
        x1, y1 = end
        if x0 > x1:
            # Always sample left to right.
            x0, y0, x1, y1 = x1, y1, x0, y0
        span = x1 - x0
        if span == 0:
            # Vertically stacked endpoints: height_factor below would divide by zero.
            return None, None, None

        height_factor = 1 / (4 * span)
        x_range = np.linspace(x0, x1, num=100)

        if y0 == y1:
            # Roots at both endpoints.
            y_range = -height_factor * (x_range - x0) * (x_range - x1) + y0
        elif y0 < y1:
            # Push the right root outward so the curve still passes through (x1, y1).
            x1_root = 4 * (y1 - y0) + x1
            y_range = -height_factor * (x_range - x0) * (x_range - x1_root) + y0
        else:
            # Push the left root outward so the curve still passes through (x0, y0).
            x0_root = -4 * (y0 - y1) + x0
            y_range = -height_factor * (x_range - x0_root) * (x_range - x1) + y1

        return x_range, y_range, np.max(y_range)

    def _add_arrowhead(self, point, root0, root1, direction, color="black"):
        """
        Append a filled triangular arrowhead, tip at ``point``, to ``trace_list``.

        For ``direction`` "left"/"right", ``root0`` and ``root1`` are the source
        and target x of an arc from ``_format_line``. The tip is aligned with the
        arc's tangent and travels towards -x ("left") or +x ("right").

        For any other direction, ``root0`` and ``root1`` are the source point's
        x and y. The tip points along the straight segment from there to ``point``.
        """
        # Load the submodule explicitly rather than relying on `import matplotlib` to have done so.
        import matplotlib.transforms

        x, y = point
        if direction in ("left", "right"):
            # For arcs whose endpoints share a y, this equals -slope at a right-hand
            # target and +slope at a left-hand target. The angles below therefore
            # both follow the direction of travel.
            slope = (2 * x - root0 - root1) / (4 * (root1 - root0))
            theta = np.arctan(slope)
            angle = (theta + np.pi) if direction == "left" else -theta
        else:
            # atan2 handles targets left of, right of, and directly below the source.
            angle = np.arctan2(y - root1, x - root0)

        # Triangle pointing along +x with its tip at the origin.
        length = 1
        template = [(0, 0), (-length, 0.1 * length), (-length, -0.1 * length), (0, 0)]
        transform = (matplotlib.transforms.Affine2D()
                     .rotate(angle)
                     .translate(x, y)
                     .frozen())
        xs, ys = zip(*(transform.transform_point(vertex) for vertex in template))

        self.trace_list.append(go.Scatter(x=list(xs),
                                          y=list(ys),
                                          hoverinfo='skip',
                                          mode='lines',
                                          fill='toself',
                                          line={'width': 0.5, "color": color},
                                          fillcolor=color))

    def _get_attribute_str(self, node, is_node: bool = True) -> str:
        """
        Build the hover text for a node id (``is_node=True``) or an edge key.

        Returns one ``"attr: value<br>"`` entry per ontology attribute present
        on the node/edge, in ontology order. Returns "" if none are present or
        the node/edge is unknown.
        """
        ontology = self.NODE_ONTOLOGY if is_node else self.EDGE_ONTOLOGY
        entries = []
        for attr in ontology:
            try:
                table = self.graph.nodes if is_node else self.graph.edges
                val = table[node][attr]
            except KeyError:
                continue
            entries.append(f"{attr}: {val}<br>")
        return "".join(entries)

    def _get_xy_from_edge(self, node_0, node_1):
        """
        Return ``(x0, y0, x1, y1)`` for an edge.

        Returns None if either endpoint was not placed by the node-layout methods.
        """
        try:
            x0, y0 = self.node_to_xy[node_0]
            x1, y1 = self.node_to_xy[node_1]
        except KeyError:
            # The original note attributes this to addressee, root and speaker nodes.
            return None
        return (x0, y0, x1, y1)

    @staticmethod
    def _arc_direction(x0, x1):
        """Return "left" if the edge target lies left of its source, else "right"."""
        return "left" if x1 < x0 else "right"

    def _add_syntax_nodes(self):
        """Place syntax nodes on the lower layer, labelled with their word form."""
        xs, ys, labels, hover = [], [], [], []
        for i, node in enumerate(self.graph.syntax_subgraph):
            node_idx = int(node.split("-")[-1])
            x = node_idx * self.node_offset
            # Alternate heights so neighbouring labels do not overlap.
            y = self.syntax_y + i % 2 * 0.5
            self.node_to_xy[node] = (x, y)

            form = self.graph.nodes[node]['form']
            xs.append(x)
            ys.append(y)
            hover.append(form)
            labels.append(form[0:3] if self.do_shorten else form)

        self.trace_list.append(go.Scatter(x=xs, y=ys, hovertext=hover, text=labels,
                                          mode='markers+text', textposition="bottom center",
                                          hoverinfo="text",
                                          marker={'size': 15, 'color': 'LightSkyBlue'}))

    def _add_semantics_nodes(self):
        """Place semantics nodes on the upper layer; hover text lists their attributes."""
        xs, ys, labels, hover = [], [], [], []
        for node in self.graph.semantics_subgraph:
            try:
                node_idx = int(node.split("-")[-1])
            except ValueError:
                # Non-numeric suffix (addressee, root, speaker nodes) goes to x = 0.
                node_idx = 0
            x = node_idx * self.node_offset
            self.node_to_xy[node] = (x, self.semantics_y)

            xs.append(x)
            ys.append(self.semantics_y)
            labels.append(node_idx)
            hover.append(self._get_attribute_str(node, is_node=True))

        self.trace_list.append(go.Scatter(x=xs, y=ys, hovertext=hover, text=labels,
                                          mode='markers+text', textposition="top center",
                                          hoverinfo="text",
                                          marker={'size': 20, 'color': 'firebrick'}))

    def _add_syntax_edges(self):
        """Draw a thin blue arc, with arrowhead, for each edge of the syntax subgraph."""
        for node_0, node_1 in self.graph.syntax_subgraph.edges:
            coords = self._get_xy_from_edge(node_0, node_1)
            if coords is None:
                continue
            x0, y0, x1, y1 = coords
            x_range, y_range, _ = self._format_line((x0, y0), (x1, y1))
            if x_range is None:
                continue

            # Lines-only traces take their colour from `line`; `marker` is ignored.
            self.trace_list.append(go.Scatter(x=tuple(x_range), y=tuple(y_range),
                                              hoverinfo='skip',
                                              mode='lines',
                                              line={'width': 0.5, 'color': 'blue'},
                                              line_shape='spline',
                                              opacity=1))
            self._add_arrowhead((x1, y1), x0, x1, self._arc_direction(x0, x1), color="blue")

    def _add_semantics_edges(self):
        """
        Draw a black arc, with arrowhead, for each edge of the semantics subgraph.

        Edges with attributes also get a small hover marker at the arc's
        middle sample x and maximum y.
        """
        for node_0, node_1 in self.graph.semantics_subgraph.edges:
            coords = self._get_xy_from_edge(node_0, node_1)
            if coords is None:
                continue
            x0, y0, x1, y1 = coords
            x_range, y_range, height = self._format_line((x0, y0), (x1, y1))
            if x_range is None:
                continue

            edge_trace = go.Scatter(x=tuple(x_range), y=tuple(y_range),
                                    hoverinfo='skip',
                                    mode='lines',
                                    line={'width': 1, 'color': 'black'},
                                    line_shape='spline',
                                    opacity=1)

            attributes = self._get_attribute_str((node_0, node_1), is_node=False)
            if attributes:
                x_mid = x_range[len(x_range) // 2]
                self.trace_list.append(go.Scatter(x=(x_mid,), y=(height,),
                                                  hovertext=attributes,
                                                  hoverinfo="text",
                                                  mode='markers+text',
                                                  textposition="top center",
                                                  marker={'size': 5, 'color': 'black'}))
            self.trace_list.append(edge_trace)
            self._add_arrowhead((x1, y1), x0, x1, self._arc_direction(x0, x1))

    def _add_head_edges(self):
        """
        Draw a thick grey segment from each semantics node to its syntactic head.

        Each drawn pair is recorded in ``added_edges``.
        """
        for node_0 in self.graph.semantics_subgraph:
            # Syntax ids are rebuilt from the first three "-"-separated parts of the
            # semantics id (e.g. "ewt-dev-1"); this assumes that prefix has exactly 3 parts.
            prefix = "-".join(node_0.split("-")[0:3])
            try:
                head_idx, _ = self.graph.head(node_0)
                node_1 = f"{prefix}-syntax-{head_idx}"
                x0, y0, x1, y1 = self._get_xy_from_edge(node_0, node_1)
            except (KeyError, IndexError, TypeError):
                # No head, or the head's syntax node was not laid out.
                continue

            self.trace_list.append(go.Scatter(x=(x0, x1), y=(y0, y1),
                                              hoverinfo='skip',
                                              mode='lines',
                                              line={'width': 3, 'color': 'grey'},
                                              line_shape='spline',
                                              opacity=1))
            self.added_edges.append((node_0, node_1))

    def _add_span_edges(self):
        """
        Draw thin grey segments, with arrowheads, for ``graph.instance_edges()``.

        Pairs already drawn by ``_add_head_edges`` are skipped.
        """
        already_drawn = set(self.added_edges)
        for node_0, node_1 in self.graph.instance_edges():
            if (node_0, node_1) in already_drawn:
                continue
            try:
                x0, y0, x1, y1 = self._get_xy_from_edge(node_0, node_1)
            except (KeyError, TypeError, IndexError):
                continue

            self.trace_list.append(go.Scatter(x=(x0, x1), y=(y0, y1),
                                              hoverinfo='skip',
                                              mode='lines',
                                              line={'width': 1, 'color': 'grey'},
                                              line_shape='spline',
                                              opacity=1))
            self._add_arrowhead((x1, y1), x0, y0, "down", color="grey")

    def prepare_graph(self) -> Dict:
        """
        Convert the UDS graph into a Dash-ready figure dict.

        All layout state is rebuilt from scratch, so calling this more than once
        (e.g. before ``serve``) does not duplicate traces.

        Returns
        -------
        dict
            ``{"data": <list of plotly traces>, "layout": <go.Layout>}``.
        """
        self.trace_list = []
        self.node_to_xy = {}
        self.added_edges = []

        # Nodes first: the edge methods look up positions in node_to_xy.
        self._add_syntax_nodes()
        self._add_semantics_nodes()
        if self.add_syntax_edges:
            self._add_syntax_edges()
        self._add_semantics_edges()
        # Head edges must precede span edges, which skip pairs recorded in added_edges.
        self._add_head_edges()
        if self.add_span_edges:
            self._add_span_edges()

        return {
            "data": self.trace_list,
            "layout": go.Layout(title=self.graph.name, showlegend=False,
                                margin={'b': 40, 'l': 0, 'r': 0, 't': 40},
                                xaxis={'showgrid': False, 'zeroline': False, 'showticklabels': False},
                                yaxis={'showgrid': False, 'zeroline': False, 'showticklabels': False},
                                width=600,
                                height=600),
        }

    def serve(self):
        """Build the figure and show it in a Dash app started with ``run_server``."""
        figure = self.prepare_graph()
        external_stylesheets = ['https://codepen.io/chriddyp/pen/bWLwgP.css']
        app = dash.Dash(__name__, external_stylesheets=external_stylesheets)
        app.title = self.graph.name

        app.layout = html.Div([
            html.Div(
                className="eight columns",
                children=[dcc.Graph(id="my-graph", figure=figure)],
            ),
        ])

        app.run_server(debug=True, use_reloader=False)

In [None]:
if __name__ == "__main__":
    # Draw the loaded graph (including syntactic dependency arcs) and serve it with Dash.
    vis = UDSVisualization(graph=problem, add_syntax_edges=True)
    vis.serve()

Running on http://127.0.0.1:8050/
Running on http://127.0.0.1:8050/
Running on http://127.0.0.1:8050/
Running on http://127.0.0.1:8050/
Running on http://127.0.0.1:8050/
Running on http://127.0.0.1:8050/
Running on http://127.0.0.1:8050/
Running on http://127.0.0.1:8050/
Running on http://127.0.0.1:8050/
Running on http://127.0.0.1:8050/
Running on http://127.0.0.1:8050/
Running on http://127.0.0.1:8050/
Running on http://127.0.0.1:8050/
Running on http://127.0.0.1:8050/
Running on http://127.0.0.1:8050/
Running on http://127.0.0.1:8050/
Running on http://127.0.0.1:8050/
Running on http://127.0.0.1:8050/
Running on http://127.0.0.1:8050/
Running on http://127.0.0.1:8050/
Running on http://127.0.0.1:8050/
Running on http://127.0.0.1:8050/
Running on http://127.0.0.1:8050/
Running on http://127.0.0.1:8050/
Running on http://127.0.0.1:8050/
Running on http://127.0.0.1:8050/
Running on http://127.0.0.1:8050/
Debugger PIN: 223-353-111
Debugger PIN: 223-353-111
Debugger PIN: 223-353-111
Debu

In [74]:
def _add_arrowhead(point, x0, x1, direction):
    # get tangent line at point
    x,y = point
#     print(f"x0{x0} x1 {x1} point {point}")
#     print(f"derivative composed of {-1/(4*(x1-x0))} and {(2*x - x0 - x1)}")
    derivative = -1/(4*(x1-x0)) * (2*x - x0 - x1)
    print(derivative)
#     theta_degrees = np.arctan(derivative)
    theta_degrees = 15
    theta_radians = 3.14 * theta_degrees/180
#     theta_radians = 90 * 3.14/180
#     print(theta_degrees)
#     print(theta_radians)
    
    #before_triangle: (x0,y0), (x0-1/2 * l, y0 + l), (x0 + 1/2 * l, y0 + l)
    l = 1
    x1 = x0 - l
    x2 = x0 - l 
    y1 = y + 0.5*l
    y2 = y - 0.5*l
    y0 = y
    
    fig = go.Figure()
    fig.add_trace(go.Scatter(x=[x0, x1, x2, x0], y=[y0, y1, y2, y0],
                    mode='lines',
                    name='original'))
    
    vertices = [[x0,y0], [x1, y1], [x2, y2], [x0,y0]]
    width = 1
    arrowhead_transformation = (matplotlib.transforms.Affine2D()
#                                 .scale(np.hypot(0.5*l,l), width)
                                .rotate(3.14/1)
                                .translate(x0, y0)
                                .frozen())
        
    vertices_prime = [arrowhead_transformation.transform_point((x,y)) for (x,y) in vertices]
    x0_prime, y0_prime = vertices_prime[0]
    x1_prime, y1_prime = vertices_prime[1]
    x2_prime, y2_prime = vertices_prime[2]
    
    print(x0_prime, y0_prime)
    print(x1_prime, y1_prime)
    print(x2_prime, y2_prime)
    
    fig.add_trace(go.Scatter(x=[x0_prime, x1_prime, x2_prime, x0_prime], y=[y0_prime, y1_prime, y2_prime, y0_prime],
                    mode='lines',
                    name='rotated'))
    
#     fig.update_xaxes(showgrid=True, gridwidth=1, range=[12, 16])
#     fig.update_yaxes(showgrid=True, gridwidth=1, range=[0, 4])
    fig.update_layout(width = 800,
                      height = 800,)

    
    fig.show()
    
    
    
# Sample input: tip at the origin, on an arc spanning x = 0..1.
# (An earlier trial used point = (21, 0) with x0, x1 = 14, 21.)
point = (0, 0)
x0 = 1
x1 = 0
_add_arrowhead(point, x0, x1, direction="right")


-0.25
1.2682724604973217e-06 0.0015926529164868282
0.9992036735417565 -0.49999936586376975
1.0007963264582433 0.49999936586376975
