-
Notifications
You must be signed in to change notification settings - Fork 0
/
streamlit_shap.py
126 lines (97 loc) · 4.38 KB
/
streamlit_shap.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import base64
import os
from io import BytesIO

import matplotlib.pyplot as plt
import shap
import streamlit.components.v1 as components
from matplotlib.figure import Figure

# SHAP plots internally call plt.show().
# On Linux, prevent plt.show() from emitting a non-GUI backend warning.
os.environ.pop("DISPLAY", None)


def _disable_text_display(text_func):
    """Default the ``display`` parameter of ``text_func`` to False, in place.

    Only the default of the parameter named ``display`` is changed; every
    other default keeps the value defined by the installed SHAP version, so
    a change in SHAP's signature cannot shift values onto the wrong
    parameters. If ``text_func`` is not a plain Python function, or has no
    ``display`` parameter with a default, it is left untouched.
    """
    code = getattr(text_func, "__code__", None)
    if code is None:
        return
    positional_names = code.co_varnames[: code.co_argcount]
    defaults = list(text_func.__defaults__ or ())
    # __defaults__ holds the defaults of the *last* len(defaults) positional
    # parameters, in order.
    defaulted_names = positional_names[len(positional_names) - len(defaults):]
    if "display" in defaulted_names:
        defaults[defaulted_names.index("display")] = False
        text_func.__defaults__ = tuple(defaults)
    elif text_func.__kwdefaults__ and "display" in text_func.__kwdefaults__:
        # Keyword-only parameter: defaults live in __kwdefaults__ instead.
        kwdefaults = dict(text_func.__kwdefaults__)
        kwdefaults["display"] = False
        text_func.__kwdefaults__ = kwdefaults


# Text plots return an IPython.core.display.HTML object.
# Defaulting display=False makes them return the HTML string instead.
_disable_text_display(shap.plots.text)

# Prevent clipping of the ticks and axis labels.
plt.rcParams["figure.autolayout"] = True

# Note: colorbar changes (which introduced bugs) in matplotlib > 3.4.3
# cause the colorbar of certain SHAP plots (e.g. beeswarm) to not display properly.
# See: https://github.com/matplotlib/matplotlib/issues/22625 and
# https://github.com/matplotlib/matplotlib/issues/22087
# If colorbars are not displayed properly, try downgrading matplotlib to 3.4.3.
def _figure_to_html(fig, height=None, width=None):
    """Render ``fig`` as an inline base64 PNG ``<img>`` tag, then close it.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        The figure to render. It is resized in place and then closed.
    height : int or None
        Target height in pixels; defaults to the figure's current height.
    width : int or None
        Target width in pixels; defaults to the figure's current width.

    Returns
    -------
    tuple
        ``(html_str, height, width)``: the ``<img>`` markup and the pixel
        dimensions actually used.
    """
    # Current size in pixels (inches * dots per inch).
    current_width, current_height = fig.get_size_inches() * fig.dpi
    if height is None:
        height = current_height
    if width is None:
        width = current_width
    fig.set_size_inches(width / fig.dpi, height / fig.dpi, forward=True)

    # Save to an in-memory buffer and embed the PNG as a data URI.
    buf = BytesIO()
    fig.savefig(buf, format="png")
    data = base64.b64encode(buf.getbuffer()).decode("ascii")
    html_str = f"<img src='data:image/png;base64,{data}'/>"

    # Let pyplot release the figure. Only ``fig`` itself is closed:
    # plt.cla()/plt.gca() act on pyplot's *current* figure, which need not be
    # ``fig`` and is created on demand when none exists.
    plt.close(fig)
    return html_str, height, width


def st_shap(plot, height=None, width=None):
    """Render a SHAP plot in Streamlit and return the resulting element.

    It is recommended to set the height and width parameters so the plot
    fits the window.

    Matplotlib-based plots (``plot`` is None or a Figure) are rasterized to
    PNG, resized to ``height`` x ``width`` pixels, and closed afterwards.
    JS/HTML plots are embedded together with SHAP's JavaScript bundle.

    Parameters
    ----------
    plot : None or matplotlib.figure.Figure or str or SHAP plot object
        The SHAP plot. ``None`` means "render the current pyplot figure"
        (for plots such as waterfall and bar that draw and return nothing);
        a ``str`` is HTML from ``shap.plots.text``.
    height: int or None
        The height of the plot in pixels.
    width: int or None
        The width of the plot in pixels.

    Returns
    -------
    streamlit.delta_generator.DeltaGenerator
        A SHAP plot as a streamlit.delta_generator.DeltaGenerator object.
        If the input cannot be rendered, an HTML error message is shown instead.
    """
    # Plots such as waterfall and bar have no return value.
    # They create a new figure and call plt.show().
    if plot is None:
        # A Figure is on the pyplot figure stack if the SHAP plot drew one.
        if not plt.get_fignums():
            return components.html(
                "<p>[Error] No plot to display. Received object of type <class 'NoneType'>.</p>"
            )
        html_str, height, width = _figure_to_html(plt.gcf(), height, width)
        return components.html(html_str, height=height, width=width)

    # SHAP plots return a matplotlib.figure.Figure object when passed show=False.
    if isinstance(plot, Figure):
        html_str, height, width = _figure_to_html(plot, height, width)
        return components.html(html_str, height=height, width=width)

    # SHAP plots containing JS/HTML expose a callable html() method.
    # (Checking for generic attributes such as `data` would also match e.g.
    # NumPy arrays, which then crash on plot.html().)
    if callable(getattr(plot, "html", None)):
        shap_js = f"{shap.getjs()}"
        # Only patch the bundle's hard-coded dimensions when explicit values
        # were given; otherwise the Python literal `None` would be spliced
        # into the JavaScript source.
        if height is not None:
            shap_js = shap_js.replace("height=350", f"height={height}")
        if width is not None:
            shap_js = shap_js.replace("width=100", f"width={width}")
        shap_html = f"<head>{shap_js}</head><body>{plot.html()}</body>"
        return components.html(shap_html, height=height, width=width)

    # shap.plots.text plots have been overridden to return a string.
    if isinstance(plot, str):
        return components.html(plot, height=height, width=width, scrolling=True)

    return components.html(
        "<p>[Error] No plot to display. Unable to understand input.</p>"
    )