Skip to content

Commit

Permalink
Fix string labels in performance report / simple model comparison (#504)
Browse files Browse the repository at this point in the history
* use_histogram

* use_histogram

* use_histogram

* use_histogram
  • Loading branch information
JKL98ISR committed Jan 5, 2022
1 parent 682c4cf commit 1112b97
Show file tree
Hide file tree
Showing 4 changed files with 499 additions and 656 deletions.
10 changes: 5 additions & 5 deletions deepchecks/checks/performance/performance_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,14 +97,14 @@ def _performance_report(self, train_dataset: Dataset, test_dataset: Dataset, mod
validate_model(test_dataset, model)

task_type = task_type_check(model, train_dataset)
clasess = test_dataset.classes
clasess = train_dataset.classes

# Get default scorers if no alternative, or validate alternatives
scorers = get_scorers_list(model, test_dataset, self.alternative_scorers, multiclass_avg=False)
datasets = {'Train': train_dataset, 'Test': test_dataset}

if task_type in [ModelType.MULTICLASS, ModelType.BINARY]:
plot_x_axis = ['Class', 'Dataset']
plot_x_axis = 'Class'
results = []

for dataset_name, dataset in datasets.items():
Expand All @@ -128,7 +128,7 @@ def _performance_report(self, train_dataset: Dataset, test_dataset: Dataset, mod
]
results_df = pd.DataFrame(results, columns=['Dataset', 'Metric', 'Value', 'Number of samples'])

fig = px.bar(
fig = px.histogram(
results_df,
x=plot_x_axis,
y='Value',
Expand Down Expand Up @@ -340,7 +340,7 @@ def run_logic(self, context: ModelComparisonContext):
]
results_df = pd.DataFrame(results, columns=['Model', 'Value', 'Metric', 'Number of samples'])

fig = px.bar(
fig = px.histogram(
results_df,
x=plot_x_axis,
y='Value',
Expand All @@ -357,7 +357,7 @@ def run_logic(self, context: ModelComparisonContext):
fig.update_xaxes(title=None)

fig = (
fig.update_yaxes(title=None, matches=None, zerolinecolor='#444')
fig.update_yaxes(title=None, matches=None)
.for_each_annotation(lambda a: a.update(text=a.text.split('=')[-1]))
.for_each_yaxis(lambda yaxis: yaxis.update(showticklabels=True))
)
Expand Down
4 changes: 2 additions & 2 deletions deepchecks/checks/performance/simple_model_comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def run(

# Plot the metrics in a graph, grouping by the model and class
fig = (
px.bar(
px.histogram(
results_df,
x=['Class', 'Model'],
y='Value',
Expand Down Expand Up @@ -211,7 +211,7 @@ def run(

# Plot the metrics in a graph, grouping by the model
fig = (
px.bar(
px.histogram(
results_df,
x='Model',
y='Value',
Expand Down

0 comments on commit 1112b97

Please sign in to comment.