Skip to content

Commit 1cc259d

Browse files
committed
pipeline multiclass
1 parent 59aa72b commit 1cc259d

File tree

4 files changed

+81
-12
lines changed

4 files changed

+81
-12
lines changed

plotsandgraphs/binary_classifier.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -580,6 +580,8 @@ def plot_pr_curve(
580580

581581
# Save the figure if save_fig_path is specified
582582
if save_fig_path:
583-
plt.savefig(save_fig_path, bbox_inches="tight")
583+
path = Path(save_fig_path)
584+
path.parent.mkdir(parents=True, exist_ok=True)
585+
fig.savefig(save_fig_path, bbox_inches="tight")
584586

585587
return fig

plotsandgraphs/multiclass_classifier.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -267,7 +267,7 @@ def auroc_metric_function(y_true, y_score, average, multi_class):
267267
return fig, fig_aurocs
268268

269269

270-
def plot_y_prob_histogram(
270+
def plot_y_score_histogram(
271271
y_true: np.ndarray, y_score: Optional[np.ndarray] = None, save_fig_path: Optional[str]=None
272272
) -> Figure:
273273
"""

plotsandgraphs/pipeline.py

Lines changed: 76 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,18 @@
1+
from typing import Literal, Union
2+
from pathlib import Path
3+
from tqdm.auto import tqdm
4+
15
from . import binary_classifier as bc
6+
from . import multiclass_classifier as mc
27

3-
from tqdm.auto import tqdm
8+
9+
10+
FILE_ENDINGS = Literal['pdf', 'png', 'jpg', 'jpeg', 'svg']
411

512

613

7-
def binary_classifier(y_true, y_score, save_fig_path=None, plot_kwargs={}):
14+
def binary_classifier(y_true, y_score, save_fig_path=None, plot_kwargs={}, file_type:FILE_ENDINGS='png'):
15+
816

917
# Create new tqdm instance
1018
tqdm_instance = tqdm(total=6, desc='Binary classifier metrics', leave=True)
@@ -14,32 +22,91 @@ def binary_classifier(y_true, y_score, save_fig_path=None, plot_kwargs={}):
1422

1523
# 1) Plot ROC curve
1624
roc_kwargs = plot_kwargs.get('roc', {})
17-
bc.plot_roc_curve(y_true, y_score, save_fig_path=save_fig_path, **roc_kwargs)
25+
save_path = get_file_path(save_fig_path, 'roc_curve', file_type)
26+
bc.plot_roc_curve(y_true, y_score, save_fig_path=save_path, **roc_kwargs)
1827
tqdm_instance.update()
1928

2029
# 2) Plot precision-recall curve
2130
pr_kwargs = plot_kwargs.get('pr', {})
22-
bc.plot_pr_curve(y_true, y_score, save_fig_path=save_fig_path, **pr_kwargs)
31+
save_path = get_file_path(save_fig_path, 'pr_curve', file_type)
32+
bc.plot_pr_curve(y_true, y_score, save_fig_path=save_path, **pr_kwargs)
2333
tqdm_instance.update()
2434

2535
# 3) Plot calibration curve
2636
cal_kwargs = plot_kwargs.get('cal', {})
27-
bc.plot_calibration_curve(y_true, y_score, save_fig_path=save_fig_path, **cal_kwargs)
37+
save_path = get_file_path(save_fig_path, 'calibration_curve', file_type)
38+
bc.plot_calibration_curve(y_true, y_score, save_fig_path=save_path, **cal_kwargs)
2839
tqdm_instance.update()
2940

3041
# 3) Plot confusion matrix
3142
cm_kwargs = plot_kwargs.get('cm', {})
32-
bc.plot_confusion_matrix(y_true, y_score, save_fig_path=save_fig_path, **cm_kwargs)
43+
save_path = get_file_path(save_fig_path, 'confusion_matrix', file_type)
44+
bc.plot_confusion_matrix(y_true, y_score, save_fig_path=save_path, **cm_kwargs)
3345
tqdm_instance.update()
3446

3547
# 5) Plot classification report
3648
cr_kwargs = plot_kwargs.get('cr', {})
37-
bc.plot_classification_report(y_true, y_score, save_fig_path=save_fig_path, **cr_kwargs)
49+
save_path = get_file_path(save_fig_path, 'classification_report', file_type)
50+
bc.plot_classification_report(y_true, y_score, save_fig_path=save_path, **cr_kwargs)
51+
tqdm_instance.update()
52+
53+
# 6) Plot y_score histogram
54+
hist_kwargs = plot_kwargs.get('hist', {})
55+
save_path = get_file_path(save_fig_path, 'y_score_histogram', file_type)
56+
bc.plot_y_score_histogram(y_true, y_score, save_fig_path=save_path, **hist_kwargs)
57+
tqdm_instance.update()
58+
59+
return
60+
61+
62+
63+
64+
def multiclass_classifier(y_true, y_score, save_fig_path=None, plot_kwargs={}):
65+
66+
# Create new tqdm instance
67+
tqdm_instance = tqdm(total=6, desc='Binary classifier metrics', leave=True)
68+
69+
# Update tqdm instance
3870
tqdm_instance.update()
3971

72+
# 1) Plot ROC curve
73+
roc_kwargs = plot_kwargs.get('roc', {})
74+
mc.plot_roc_curve(y_true, y_score, save_fig_path=save_fig_path, **roc_kwargs)
75+
tqdm_instance.update()
76+
77+
# 2) Plot precision-recall curve
78+
# pr_kwargs = plot_kwargs.get('pr', {})
79+
# mc.plot_pr_curve(y_true, y_score, save_fig_path=save_fig_path, **pr_kwargs)
80+
# tqdm_instance.update()
81+
82+
# 3) Plot calibration curve
83+
# cal_kwargs = plot_kwargs.get('cal', {})
84+
# mc.plot_calibration_curve(y_true, y_score, save_fig_path=save_fig_path, **cal_kwargs)
85+
# tqdm_instance.update()
86+
87+
# 3) Plot confusion matrix
88+
# cm_kwargs = plot_kwargs.get('cm', {})
89+
# mc.plot_confusion_matrix(y_true, y_score, save_fig_path=save_fig_path, **cm_kwargs)
90+
# tqdm_instance.update()
91+
92+
# 5) Plot classification report
93+
# cr_kwargs = plot_kwargs.get('cr', {})
94+
# mc.plot_classification_report(y_true, y_score, save_fig_path=save_fig_path, **cr_kwargs)
95+
# tqdm_instance.update()
96+
4097
# 6) Plot y_score histogram
4198
hist_kwargs = plot_kwargs.get('hist', {})
42-
bc.plot_y_score_histogram(y_true, y_score, save_fig_path=save_fig_path, **hist_kwargs)
99+
mc.plot_y_score_histogram(y_true, y_score, save_fig_path=save_fig_path, **hist_kwargs)
43100
tqdm_instance.update()
44101

45-
return
102+
return
103+
104+
105+
106+
def get_file_path(save_fig_path: Union[Path,None, str], name:str, ending:str):
107+
if save_fig_path is None:
108+
return None
109+
else:
110+
result = Path(save_fig_path) / f"{name}.{ending}"
111+
print(result)
112+
return str(result)

tests/test_binary_classifier.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import pytest
55
import plotsandgraphs.binary_classifier as binary
66

7-
TEST_RESULTS_PATH = Path(r"tests\test_results")
7+
TEST_RESULTS_PATH = Path(r"tests\test_results\binary_classifier")
88

99

1010
@pytest.fixture(scope="module")

0 commit comments

Comments
 (0)