-
Notifications
You must be signed in to change notification settings - Fork 2
/
app.py
105 lines (91 loc) · 3.57 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
from qa.question_answering import QuestionAnswering
from qa.data_loader import MetaQADataLoader
import gradio as gr
import networkx as nx
import plotly.graph_objects as go
class QADemo:
    """Gradio demo that answers MetaQA questions and plots the reasoning graph.

    Questions are answered by ``QuestionAnswering`` (model
    ``navidmadani/nl2logic_t5small_metaqa``) over data loaded by
    ``MetaQADataLoader`` from ``./data``; the query trace behind the answer is
    rendered as a Plotly network figure.
    """

    def __init__(self):
        """Load the MetaQA data and the question-answering model."""
        self.data_loader = MetaQADataLoader('./data')
        self.qa = QuestionAnswering('navidmadani/nl2logic_t5small_metaqa', self.data_loader)

    def build_graph(self, query_trace, relation_trace, question_entity):
        """Build a directed graph of the reasoning paths behind an answer.

        Args:
            query_trace: Iterable of dicts, one per reasoning path, mapping a
                hop variable ('X', 'Y' or 'Z') to the entity reached at that
                hop. The dict's iteration order is used as the hop order.
            relation_trace: Sequence of relation names; entry ``i`` labels the
                edges traversed at hop ``i + 1``.
            question_entity: Entity from the question; every path starts here.

        Returns:
            A ``networkx.DiGraph`` whose edges carry a ``label`` attribute
            holding the relation name as a string.

        Raises:
            KeyError: if a trace dict uses a variable other than X, Y or Z.
        """
        id2hop = {'X': 1, 'Y': 2, 'Z': 3}
        G = nx.DiGraph()
        G.add_node(question_entity, color='red', name=question_entity)
        for path in query_trace:
            # Every path is chained starting from the question entity.
            prev_node = question_entity
            for var, ent in path.items():
                G.add_node(ent, color='blue', name=ent)
                G.add_edge(prev_node, ent, label=f'{relation_trace[id2hop[var] - 1]}')
                prev_node = ent
        return G

    def run(self, question):
        """Answer *question* and return a graph figure plus the raw answer.

        Args:
            question: Natural-language question text from the Gradio textbox.

        Returns:
            ``(fig, answer_component)``: a Plotly figure of the reasoning graph
            and the object returned by ``QuestionAnswering.answer_question``
            (this method reads its ``trace``, ``relation_trace``, ``qent`` and
            ``answers`` keys).
        """
        answer_component = self.qa.answer_question(question)
        G = self.build_graph(
            answer_component['trace'],
            answer_component['relation_trace'],
            answer_component['qent'],
        )
        pos = nx.spring_layout(G)
        # Take each label from its own edge so labels can never be paired
        # with the wrong edge.
        edges = list(G.edges(data='label'))
        fig = go.Figure(
            data=[
                self._edge_trace(edges, pos),
                self._node_trace(G, pos, answer_component),
            ],
            layout=go.Layout(
                title=dict(text='Graph', font=dict(size=16)),
                showlegend=False,
                hovermode='closest',
                margin=dict(b=20, l=5, r=5, t=40),
                annotations=self._edge_annotations(edges, pos),  # edge names
                xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
                yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
            ),
        )
        return fig, answer_component

    @staticmethod
    def _node_trace(G, pos, answer_component):
        """Return a marker trace with one hoverable, name-labelled point per node.

        Answer nodes are green, the question entity orange, all others black.
        """
        answers = answer_component['answers']
        qent = answer_component['qent']
        xs, ys, colors, texts = [], [], [], []
        for node in G.nodes():
            x, y = pos[node]
            name = str(node)
            xs.append(x)
            ys.append(y)
            texts.append(name)  # display node name on hover
            if name in answers:
                colors.append('green')
            elif name == qent:
                colors.append('orange')
            else:
                colors.append('black')
        return go.Scatter(
            x=xs,
            y=ys,
            text=texts,
            mode='markers',
            hoverinfo='text',
            marker=dict(color=colors, size=15, line_width=2),
        )

    @staticmethod
    def _edge_trace(edges, pos):
        """Return a line trace drawing every ``(src, dst, label)`` edge."""
        xs, ys, texts = [], [], []
        for src, dst, label in edges:
            x0, y0 = pos[src]
            x1, y1 = pos[dst]
            # The trailing None breaks the polyline between consecutive edges.
            xs.extend([x0, x1, None])
            ys.extend([y0, y1, None])
            # One text entry per point so hover text stays aligned with this
            # edge's coordinates (the whole label, not its characters).
            texts.extend([label, label, None])
        return go.Scatter(
            x=xs,
            y=ys,
            text=texts,
            line=dict(width=1.0, color='#888'),
            hoverinfo='text',
            mode='lines',
        )

    @staticmethod
    def _edge_annotations(edges, pos):
        """Return one text annotation per edge, placed at the edge midpoint."""
        annotations = []
        for src, dst, label in edges:
            x0, y0 = pos[src]
            x1, y1 = pos[dst]
            annotations.append(
                go.layout.Annotation(
                    x=(x0 + x1) / 2,
                    y=(y0 + y1) / 2,
                    text=label,
                    showarrow=False,
                )
            )
        return annotations
# Build the demo at module level so importing this module exposes ``demo``
# without starting a blocking server; launch only when run as a script.
qa_demo = QADemo()
demo = gr.Interface(fn=qa_demo.run, inputs="text", outputs=["plot", "text"])

if __name__ == "__main__":
    demo.launch()