Skip to content

Commit

Permalink
dtreeviz 2.0 compatibility
Browse files Browse the repository at this point in the history
  • Loading branch information
oegedijk committed Feb 11, 2023
1 parent 6b21dc1 commit bfefde5
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 20 deletions.
38 changes: 19 additions & 19 deletions explainerdashboard/explainers.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

import shap

from dtreeviz import *
from dtreeviz import DTreeVizAPI
from dtreeviz.models.shadow_decision_tree import ShadowDecTree

from sklearn.model_selection import KFold
Expand Down Expand Up @@ -3628,7 +3628,7 @@ def get_decisionpath_summary_df(self, tree_idx, index, round=2, pos_label=None):
self.get_decisionpath_df(tree_idx, index, pos_label=pos_label),
classifier=self.is_classifier, round=round, units=self.units)

def decisiontree_file(self, tree_idx, index, show_just_path=False):
def decisiontree_view(self, tree_idx, index, show_just_path=False):
"""get a dtreeviz visualization of a particular tree in the random forest.
Args:
Expand All @@ -3638,19 +3638,23 @@ def decisiontree_file(self, tree_idx, index, show_just_path=False):
tree. Defaults to False.
Returns:
the path where the .svg file is stored.
DTreeVizRender
"""
if not self.graphviz_available:
print("No graphviz 'dot' executable available!")
return None

viz = DTreeVizAPI(self.shadow_trees[tree_idx])

viz = dtreeviz(self.shadow_trees[tree_idx],
X=self.get_X_row(index).squeeze(),
return viz.view(x=self.get_X_row(index).squeeze(),
fancy=False,
show_node_labels = False,
show_just_path=show_just_path)
return viz.save_svg()


def decisiontree_file(self, tree_idx, index, show_just_path=False):
return self.decisiontree_view(tree_idx, index, show_just_path).save_svg()

def decisiontree(self, tree_idx, index, show_just_path=False):
"""get a dtreeviz visualization of a particular tree in the random forest.
Expand All @@ -3665,13 +3669,9 @@ def decisiontree(self, tree_idx, index, show_just_path=False):
a IPython display SVG object for e.g. jupyter notebook.
"""
if not self.graphviz_available:
print("No graphviz 'dot' executable available!")
return None

from IPython.display import SVG
svg_file = self.decisiontree_file(tree_idx, index, show_just_path)
return SVG(open(svg_file,'rb').read())

return SVG(self.decisiontree_view(tree_idx, index, show_just_path).svg())

def decisiontree_encoded(self, tree_idx, index, show_just_path=False):
"""get a dtreeviz visualization of a particular tree in the random forest.
Expand All @@ -3690,9 +3690,8 @@ def decisiontree_encoded(self, tree_idx, index, show_just_path=False):
if not self.graphviz_available:
print("No graphviz 'dot' executable available!")
return None

svg_file = self.decisiontree_file(tree_idx, index, show_just_path)
encoded = base64.b64encode(open(svg_file,'rb').read())
svg = open(self.decisiontree_file(tree_idx, index, show_just_path), "rb").read()
encoded = base64.b64encode(svg)
svg_encoded = 'data:image/svg+xml;base64,{}'.format(encoded.decode())
return svg_encoded

Expand Down Expand Up @@ -3865,7 +3864,7 @@ def get_decisionpath_summary_df(self, tree_idx, index, round=2, pos_label=None):
return get_xgboost_path_summary_df(self.get_decisionpath_df(tree_idx, index, pos_label=pos_label))

@insert_pos_label
def decisiontree_file(self, tree_idx, index, show_just_path=False, pos_label=None):
def decisiontree_view(self, tree_idx, index, show_just_path=False, pos_label=None):
"""get a dtreeviz visualization of a particular tree in the random forest.
Args:
Expand All @@ -3887,12 +3886,13 @@ def decisiontree_file(self, tree_idx, index, show_just_path=False, pos_label=Non
if len(self.labels) > 2:
tree_idx = tree_idx * len(self.labels) + pos_label

viz = dtreeviz(self.shadow_trees[tree_idx],
X=self.get_X_row(index).squeeze(),
viz = DTreeVizAPI(self.shadow_trees[tree_idx])

return viz.view(x=self.get_X_row(index).squeeze(),
fancy=False,
show_node_labels = False,
show_just_path=show_just_path)
return viz.save_svg()


@insert_pos_label
def plot_trees(self, index, highlight_tree=None, round=2,
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ click
dash-auth
dash-bootstrap-components>=1
dash>=2.3.1
dtreeviz>=1.3
dtreeviz>=2.1
flask_simplelogin
graphviz>=0.18.2
joblib
Expand Down

0 comments on commit bfefde5

Please sign in to comment.