-
Notifications
You must be signed in to change notification settings - Fork 717
/
interactive.py
237 lines (178 loc) · 7.1 KB
/
interactive.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
# Copyright (c) 2019 Microsoft Corporation
# Distributed under the MIT software license
import sys
import logging
from ..provider.visualize import AutoVisualizeProvider, PreserveProvider, DashProvider
# Module logger.
log = logging.getLogger(__name__)

# Handle to this very module. Functions below assign through it
# (e.g. `this.visualize_provider = ...`) instead of using `global` statements.
this = sys.modules[__name__]

# Both providers start unset; they are filled in lazily by the functions below.
this._preserve_provider = this.visualize_provider = None
def get_visualize_provider():
    """Return the provider currently used by show() and related calls.

    Returns:
        The module-level visualization provider, or None if none has been set yet.
    """
    current = this.visualize_provider
    return current
def set_visualize_provider(provider):
    """Install the provider used by show() and related calls.

    Args:
        provider: Visualization provider found in "interpret.provider.visualize",
            or None to clear the current provider. Any non-None object must
            expose a `render` attribute.

    Raises:
        ValueError: If `provider` is not None and has no `render` attribute.
    """
    if provider is not None and not hasattr(provider, "render"):  # pragma: no cover
        raise ValueError(
            "Object of type {} is not a visualize provider.".format(type(provider))
        )
    this.visualize_provider = provider
def set_show_addr(addr):
    """Point the show() backing server at a new (ip, port) address.

    Side effect: (re)initializes the show server via init_show_server, which
    stops any running Dash provider and starts a new one at `addr`.

    Args:
        addr: (ip, port) pair; the port is coerced to int.

    Returns:
        None.
    """
    ip, port = addr[0], addr[1]
    init_show_server((ip, int(port)))
def get_show_addr():
    """Report the address the show() backing server is bound to.

    Returns:
        (ip, port) tuple taken from the Dash provider's app runner, or None
        when the current provider is not a DashProvider.
    """
    provider = this.visualize_provider
    if not isinstance(provider, DashProvider):
        return None
    runner = provider.app_runner
    return (runner.ip, runner.port)
def shutdown_show_server():
    """Hard-stop the server backing show().

    Returns:
        The result of the app runner's stop() when a DashProvider is active;
        True otherwise, since there is no server to stop.
    """
    provider = this.visualize_provider
    if not isinstance(provider, DashProvider):
        return True  # pragma: no cover
    return provider.app_runner.stop()
def status_show_server():
    """Describe the state of the server backing show().

    Returns:
        Dict with an "app_runner_exists" flag. When a DashProvider is active,
        the flag is True and the dict is extended with the app runner's
        status() entries.
    """
    provider = this.visualize_provider
    if not isinstance(provider, DashProvider):
        return {"app_runner_exists": False}
    report = {"app_runner_exists": True}
    report.update(provider.app_runner.status())
    return report
def init_show_server(addr=None, base_url=None, use_relative_links=False):
    """Initialize (or restart) the server backing show().

    Any currently active DashProvider is shut down first. The module's
    visualize provider is then replaced by a new DashProvider built from the
    arguments below, and that provider is started.

    Args:
        addr: (ip, port) tuple as address to assign show method to.
        base_url: Base url path as string. Used mostly when server is running behind a proxy.
        use_relative_links: Use relative links for what's returned to client. Otherwise have absolute links.

    Returns:
        None.
    """
    # Legacy entry point: calling it immediately overrides whatever
    # visualize provider is currently installed.
    if isinstance(this.visualize_provider, DashProvider):
        log.info("Stopping previous dash provider")
        shutdown_show_server()

    # Log the replacement class itself; `type(DashProvider)` would only
    # report DashProvider's metaclass (e.g. <class 'type'>).
    log.info(
        "Replacing visualize provider: %s with %s",
        type(this.visualize_provider),
        DashProvider,
    )
    set_visualize_provider(
        DashProvider.from_address(
            addr=addr, base_url=base_url, use_relative_links=use_relative_links
        )
    )
    this.visualize_provider.idempotent_start()

    runner = this.visualize_provider.app_runner
    log.info("Running dash provider at %s", (runner.ip, runner.port))
    return None
def _get_integer_key(key, explanation):
if key is not None and not isinstance(key, int):
series = explanation.selector[explanation.selector.columns[0]]
if key not in series.values: # pragma: no cover
raise ValueError("Key {} not in explanation's selector".format(key))
key = series[series == key].index[0]
return key
def show(explanation, key=-1, **kwargs):
    """Provides an interactive visualization for a given explanation(s).

    By default, visualization provided is not preserved when the notebook exits.
    If no visualize provider has been set, an AutoVisualizeProvider is created
    and installed on first use.

    Args:
        explanation: Either a scalar Explanation or list of Explanations to render as visualization.
        key: Specific index of explanation to visualize. A non-int key is looked
            up in the first column of the explanation's selector.
        **kwargs: Kwargs passed down to provider's render() call.

    Returns:
        None.

    Raises:
        Any exception from key resolution or rendering, after it has been logged.
    """
    try:
        key = _get_integer_key(key, explanation)

        # Lazily install the default provider.
        if this.visualize_provider is None:
            this.visualize_provider = AutoVisualizeProvider()

        this.visualize_provider.render(explanation, key=key, **kwargs)
    except Exception as e:  # pragma: no cover
        log.exception(e)
        # Bare raise re-raises with the original traceback untouched.
        raise
    return None
def show_link(explanation, share_tables=None):
    """Provides the backing URL link behind the associated 'show' call for explanation.

    If the current visualize provider is not a DashProvider, the show server
    is initialized first (replacing the current provider).

    Args:
        explanation: Either a scalar Explanation or list of Explanations
            that would be provided to 'show'.
        share_tables: Boolean or dictionary that dictates if Explanations
            should all use the same selector as provided to 'show'.
            (table used for selecting in the Dashboard).

    Returns:
        URL as a string.

    Raises:
        Any exception from server initialization, registration, or link
        generation, after it has been logged.
    """
    # Every step is inside the try so failures are logged, matching show()
    # and preserve().
    try:
        if not isinstance(this.visualize_provider, DashProvider):  # pragma: no cover
            init_show_server()

        app_runner = this.visualize_provider.app_runner
        app_runner.register(explanation, share_tables=share_tables)
        return app_runner.display_link(explanation)
    except Exception as e:  # pragma: no cover
        log.exception(e)
        raise
def preserve(explanation, selector_key=None, file_name=None, **kwargs):
    """Preserves an explanation's visualization for Jupyter cell, or file.

    If file_name is not None the following occurs:
        - For Plotly figures, saves to HTML using `plot`.
        - For dataframes, saves to HTML using `to_html`.
        - For strings (html), saves to HTML.
        - For Dash components, fails with exception. This is currently not supported.

    Args:
        explanation: An explanation.
        selector_key: If integer, treat as index for explanation. Otherwise, looks up value in first column, gets index.
        file_name: If assigned, will save the visualization to this filename.
        **kwargs: Kwargs which are passed to the underlying render/export call.

    Returns:
        None.

    Raises:
        Any exception from provider creation, key resolution, or rendering,
        after it has been logged.
    """
    try:
        # A single PreserveProvider is created lazily and reused across calls.
        if this._preserve_provider is None:
            this._preserve_provider = PreserveProvider()

        key = _get_integer_key(selector_key, explanation)
        this._preserve_provider.render(
            explanation, key=key, file_name=file_name, **kwargs
        )
    except Exception as e:  # pragma: no cover
        log.exception(e)
        raise
    return None