Skip to content

Commit

Permalink
fixes regression row_from_input test
Browse files Browse the repository at this point in the history
  • Loading branch information
oegedijk committed Dec 16, 2020
1 parent 83b5531 commit 13df2a7
Showing 1 changed file with 17 additions and 1 deletion.
18 changes: 17 additions & 1 deletion tests/test_regression_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,23 @@ def test_random_index(self):

def test_row_from_input(self):
input_row = self.explainer.get_row_from_input(
1, 13, 2, 12, 12, 'Sex_male', 'A', 'Southampton')
self.explainer.X.iloc[[0]].values.tolist())
self.assertIsInstance(input_row, pd.DataFrame)

input_row = self.explainer.get_row_from_input(
self.explainer.X_cats.iloc[[0]].values.tolist())
self.assertIsInstance(input_row, pd.DataFrame)

input_row = self.explainer.get_row_from_input(
self.explainer.X_cats
[self.explainer.columns_ranked_by_shap(cats=True)]
.iloc[[0]].values.tolist(), ranked_by_shap=True)
self.assertIsInstance(input_row, pd.DataFrame)

input_row = self.explainer.get_row_from_input(
self.explainer.X
[self.explainer.columns_ranked_by_shap(cats=False)]
.iloc[[0]].values.tolist(), ranked_by_shap=True)
self.assertIsInstance(input_row, pd.DataFrame)

def test_prediction_result_df(self):
Expand Down

0 comments on commit 13df2a7

Please sign in to comment.