1
+ from typing import Literal , Union
2
+ from pathlib import Path
3
+ from tqdm .auto import tqdm
4
+
1
5
from . import binary_classifier as bc
6
+ from . import multiclass_classifier as mc
2
7
3
- from tqdm .auto import tqdm
8
+
9
+
10
+ FILE_ENDINGS = Literal ['pdf' , 'png' , 'jpg' , 'jpeg' , 'svg' ]
4
11
5
12
6
13
7
- def binary_classifier (y_true , y_score , save_fig_path = None , plot_kwargs = {}):
14
+ def binary_classifier (y_true , y_score , save_fig_path = None , plot_kwargs = {}, file_type :FILE_ENDINGS = 'png' ):
15
+
8
16
9
17
# Create new tqdm instance
10
18
tqdm_instance = tqdm (total = 6 , desc = 'Binary classifier metrics' , leave = True )
@@ -14,32 +22,91 @@ def binary_classifier(y_true, y_score, save_fig_path=None, plot_kwargs={}):
14
22
15
23
# 1) Plot ROC curve
16
24
roc_kwargs = plot_kwargs .get ('roc' , {})
17
- bc .plot_roc_curve (y_true , y_score , save_fig_path = save_fig_path , ** roc_kwargs )
25
+ save_path = get_file_path (save_fig_path , 'roc_curve' , file_type )
26
+ bc .plot_roc_curve (y_true , y_score , save_fig_path = save_path , ** roc_kwargs )
18
27
tqdm_instance .update ()
19
28
20
29
# 2) Plot precision-recall curve
21
30
pr_kwargs = plot_kwargs .get ('pr' , {})
22
- bc .plot_pr_curve (y_true , y_score , save_fig_path = save_fig_path , ** pr_kwargs )
31
+ save_path = get_file_path (save_fig_path , 'pr_curve' , file_type )
32
+ bc .plot_pr_curve (y_true , y_score , save_fig_path = save_path , ** pr_kwargs )
23
33
tqdm_instance .update ()
24
34
25
35
# 3) Plot calibration curve
26
36
cal_kwargs = plot_kwargs .get ('cal' , {})
27
- bc .plot_calibration_curve (y_true , y_score , save_fig_path = save_fig_path , ** cal_kwargs )
37
+ save_path = get_file_path (save_fig_path , 'calibration_curve' , file_type )
38
+ bc .plot_calibration_curve (y_true , y_score , save_fig_path = save_path , ** cal_kwargs )
28
39
tqdm_instance .update ()
29
40
30
41
# 3) Plot confusion matrix
31
42
cm_kwargs = plot_kwargs .get ('cm' , {})
32
- bc .plot_confusion_matrix (y_true , y_score , save_fig_path = save_fig_path , ** cm_kwargs )
43
+ save_path = get_file_path (save_fig_path , 'confusion_matrix' , file_type )
44
+ bc .plot_confusion_matrix (y_true , y_score , save_fig_path = save_path , ** cm_kwargs )
33
45
tqdm_instance .update ()
34
46
35
47
# 5) Plot classification report
36
48
cr_kwargs = plot_kwargs .get ('cr' , {})
37
- bc .plot_classification_report (y_true , y_score , save_fig_path = save_fig_path , ** cr_kwargs )
49
+ save_path = get_file_path (save_fig_path , 'classification_report' , file_type )
50
+ bc .plot_classification_report (y_true , y_score , save_fig_path = save_path , ** cr_kwargs )
51
+ tqdm_instance .update ()
52
+
53
+ # 6) Plot y_score histogram
54
+ hist_kwargs = plot_kwargs .get ('hist' , {})
55
+ save_path = get_file_path (save_fig_path , 'y_score_histogram' , file_type )
56
+ bc .plot_y_score_histogram (y_true , y_score , save_fig_path = save_path , ** hist_kwargs )
57
+ tqdm_instance .update ()
58
+
59
+ return
60
+
61
+
62
+
63
+
64
+ def multiclass_classifier (y_true , y_score , save_fig_path = None , plot_kwargs = {}):
65
+
66
+ # Create new tqdm instance
67
+ tqdm_instance = tqdm (total = 6 , desc = 'Binary classifier metrics' , leave = True )
68
+
69
+ # Update tqdm instance
38
70
tqdm_instance .update ()
39
71
72
+ # 1) Plot ROC curve
73
+ roc_kwargs = plot_kwargs .get ('roc' , {})
74
+ mc .plot_roc_curve (y_true , y_score , save_fig_path = save_fig_path , ** roc_kwargs )
75
+ tqdm_instance .update ()
76
+
77
+ # 2) Plot precision-recall curve
78
+ # pr_kwargs = plot_kwargs.get('pr', {})
79
+ # mc.plot_pr_curve(y_true, y_score, save_fig_path=save_fig_path, **pr_kwargs)
80
+ # tqdm_instance.update()
81
+
82
+ # 3) Plot calibration curve
83
+ # cal_kwargs = plot_kwargs.get('cal', {})
84
+ # mc.plot_calibration_curve(y_true, y_score, save_fig_path=save_fig_path, **cal_kwargs)
85
+ # tqdm_instance.update()
86
+
87
+ # 3) Plot confusion matrix
88
+ # cm_kwargs = plot_kwargs.get('cm', {})
89
+ # mc.plot_confusion_matrix(y_true, y_score, save_fig_path=save_fig_path, **cm_kwargs)
90
+ # tqdm_instance.update()
91
+
92
+ # 5) Plot classification report
93
+ # cr_kwargs = plot_kwargs.get('cr', {})
94
+ # mc.plot_classification_report(y_true, y_score, save_fig_path=save_fig_path, **cr_kwargs)
95
+ # tqdm_instance.update()
96
+
40
97
# 6) Plot y_score histogram
41
98
hist_kwargs = plot_kwargs .get ('hist' , {})
42
- bc .plot_y_score_histogram (y_true , y_score , save_fig_path = save_fig_path , ** hist_kwargs )
99
+ mc .plot_y_score_histogram (y_true , y_score , save_fig_path = save_fig_path , ** hist_kwargs )
43
100
tqdm_instance .update ()
44
101
45
- return
102
+ return
103
+
104
+
105
+
106
+ def get_file_path (save_fig_path : Union [Path ,None , str ], name :str , ending :str ):
107
+ if save_fig_path is None :
108
+ return None
109
+ else :
110
+ result = Path (save_fig_path ) / f"{ name } .{ ending } "
111
+ print (result )
112
+ return str (result )
0 commit comments