Skip to content

Commit

Permalink
raise value error with y_missing
Browse files Browse the repository at this point in the history
certain regression plots depend on y, but when no y was passed to explainer, now raise ValueError
  • Loading branch information
oegedijk committed Nov 18, 2020
1 parent dac44ce commit 47a5fce
Showing 1 changed file with 11 additions and 0 deletions.
11 changes: 11 additions & 0 deletions explainerdashboard/explainers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2471,6 +2471,9 @@ def prediction_result_markdown(self, index, include_percentile=True, round=2, **

def metrics(self):
"""dict of performance metrics: rmse, mae and R^2"""

if self.y_missing:
raise ValueError("No y was passed to explainer, so cannot calculate metrics!")
metrics_dict = {
'rmse' : np.sqrt(mean_squared_error(self.y, self.preds)),
'mae' : mean_absolute_error(self.y, self.preds),
Expand Down Expand Up @@ -2513,6 +2516,8 @@ def plot_predicted_vs_actual(self, round=2, logs=False, log_x=False, log_y=False
Plotly fig
"""
if self.y_missing:
raise ValueError("No y was passed to explainer, so cannot plot predicted vs actual!")
return plotly_predicted_vs_actual(self.y, self.preds,
target=self.target, units=self.units, idxs=self.idxs.values,
logs=logs, log_x=log_x, log_y=log_y, round=round,
Expand All @@ -2531,6 +2536,8 @@ def plot_residuals(self, vs_actual=False, round=2, residuals='difference'):
Plotly fig
"""
if self.y_missing:
raise ValueError("No y was passed to explainer, so cannot plot residuals!")
return plotly_plot_residuals(self.y, self.preds, idxs=self.idxs.values,
vs_actual=vs_actual, target=self.target,
units=self.units, residuals=residuals,
Expand All @@ -2554,6 +2561,8 @@ def plot_residuals_vs_feature(self, col, residuals='difference', round=2,
Returns:
plotly fig
"""
if self.y_missing:
raise ValueError("No y was passed to explainer, so cannot plot residuals!")
assert col in self.columns or col in self.columns_cats, \
f'{col} not in columns or columns_cats!'
col_vals = self.X_cats[col] if self.check_cats(col) else self.X[col]
Expand All @@ -2579,6 +2588,8 @@ def plot_y_vs_feature(self, col, residuals='difference', round=2,
Returns:
plotly fig
"""
if self.y_missing:
raise ValueError("No y was passed to explainer, so cannot plot y vs feature!")
assert col in self.columns or col in self.columns_cats, \
f'{col} not in columns or columns_cats!'
col_vals = self.X_cats[col] if self.check_cats(col) else self.X[col]
Expand Down

0 comments on commit 47a5fce

Please sign in to comment.