Skip to content

Commit

Permalink
refine tests
Browse files Browse the repository at this point in the history
  • Loading branch information
joshuawe committed Dec 16, 2023
1 parent e8472f1 commit b3d2895
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 14 deletions.
11 changes: 11 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,14 @@
tests/test_results/










# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
Expand Down
Binary file added images/multiclass/roc_curves_multiclass.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
25 changes: 16 additions & 9 deletions plotsandgraphs/multiclass_classifier.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from pathlib import Path
from typing import Optional, List, Callable, Dict, Tuple
from typing import Optional, List, Callable, Dict, Tuple, Union
import matplotlib.pyplot as plt
from matplotlib.patches import BoxStyle
from matplotlib.colors import to_rgba
Expand All @@ -26,14 +26,14 @@
def plot_roc_curve(
y_true: np.ndarray,
y_score: np.ndarray,
confidence_interval: float = 0.95,
confidence_interval: Optional[float] = 0.95,
highlight_roc_area: bool=True,
n_bootstraps: int=1,
figsize: Optional[Tuple[float, float]]=None,
class_labels: Optional[List[str]]=None,
split_plots: bool=True,
save_fig_path=None,
) -> Tuple[Figure, Figure]:
save_fig_path=Optional[Union[str, Tuple[str, str]]],
) -> Tuple[Figure, Union[Figure, None]]:
"""
Creates two plots.
1) ROC curves for a multiclass classifier. Includes the option for bootstrapping.
Expand All @@ -60,14 +60,19 @@ def plot_roc_curve(
The labels of the classes. By default None.
split_plots : bool, optional
Whether to split the plots into two separate figures. By default True.
save_fig_path : str, optional
Path to folder where the figure should be saved. If None then plot is not saved, by default None. E.g. 'figures/'.
save_fig_path : Optional[Union[str, Tuple[str, str]]], optional
Path to folder where the figure(s) should be saved. If None then plot is not saved, by default None. If `split_plots` is False, then a single str is required. If True, then a tuple of strings (Pos 1 Roc curves comparison, Pos 2 AUROCs comparison). E.g. `save_fig_path=('figures/roc_curves.png', 'figures/aurocs_comparison.png')`.
Returns
-------
figures : Tuple[Figure, Figure]
The figures of the calibration plot. First the roc curves, then the AUROC overview.
"""
# Sanity checks
if confidence_interval is None and highlight_roc_area is True:
raise ValueError("Confidence interval must be set when highlight_roc_area is True.")
if confidence_interval is None:
confidence_interval = 0.95 # default value, but it will not be displayed in the plot

num_classes = y_true.shape[-1]
class_labels = [f"Class {i}" for i in range(num_classes)] if class_labels is None else class_labels
Expand Down Expand Up @@ -194,6 +199,7 @@ def auroc_metric_function(y_true, y_score, average, multi_class):


# Either create a new figure or use the same figure as the roc curves for the auroc overview
fig_aurocs = None
if split_plots:
fig_aurocs = plt.figure(figsize=(5,5))
ax = fig_aurocs.add_subplot(111)
Expand All @@ -220,12 +226,13 @@ def auroc_metric_function(y_true, y_score, average, multi_class):

# save auroc comparison plot
if save_fig_path and split_plots is True:
path = Path(save_fig_path) / "aurocs_comparison.png"
path = Path(save_fig_path[1])
path.parent.mkdir(parents=True, exist_ok=True)
fig_aurocs.savefig(path, bbox_inches="tight")
# save roc curves plot
if save_fig_path:
path = Path(save_fig_path) / "roc_curves.png"
if save_fig_path is not None:
path = save_fig_path[0] if split_plots is True else save_fig_path
path = Path(path)
path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(path, bbox_inches="tight")
return fig, fig_aurocs
Expand Down
64 changes: 59 additions & 5 deletions tests/test_multiclass_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,26 +70,80 @@ def test_roc_curve(random_data_multiclass_classifier):
random_data_binary_classifier : Tuple[np.ndarray, np.ndarray]
The simulated data.
"""
# helper function for file name, to avoid repeating code
def get_path_name(confidence_interval, highlight_roc_area, n_bootstraps, figsize, split_plots):
if split_plots is False:
fig_path = TEST_RESULTS_PATH / f"roc_curves_conf_{confidence_interval}_highlight_{highlight_roc_area}_nboot_{n_bootstraps}_figsize_{figsize}.png"
else:
fig_path_1 = TEST_RESULTS_PATH / f"roc_curves_split_conf_{confidence_interval}_highlight_{highlight_roc_area}_nboot_{n_bootstraps}_figsize_{figsize}.png"
fig_path_2 = TEST_RESULTS_PATH / f"auroc_comparison_conf_{confidence_interval}_highlight_{highlight_roc_area}_nboot_{n_bootstraps}_figsize_{figsize}.png"
fig_path = [fig_path_1, fig_path_2]
return fig_path


y_true, y_prob = random_data_multiclass_classifier

confidence_intervals = [None, 0.99]
highlight_roc_area = [True, False]
n_bootstraps = [1, 10000]
figsizes = [None, (10,10)]
n_bootstraps = [1, 1000]
figsizes = [None]
split_plots = [True, False]

# From the previous lists I want all possible combinations
combinations = list(product(confidence_intervals, highlight_roc_area, n_bootstraps, figsizes, split_plots))

for confidence_interval, highlight_roc_area, n_bootstraps, figsize, split_plots in combinations:
# WE NEED THE CORRECT SAVE FIG PATH NAMES!!!!!
multiclass.plot_roc_curve(y_true=y_true, y_score=y_prob,
# check if one or two figures should be saved (splot_plots=True or False)
# if split_plots is False:
# fig_path = TEST_RESULTS_PATH / f"roc_curves_conf_{confidence_interval}_highlight_{highlight_roc_area}_nboot_{n_bootstraps}_figsize_{figsize}.png"
# else:
# fig_path_1 = TEST_RESULTS_PATH / f"roc_curves_conf_{confidence_interval}_highlight_{highlight_roc_area}_nboot_{n_bootstraps}_figsize_{figsize}.png"
# fig_path_2 = TEST_RESULTS_PATH / f"auroc_comparison_conf_{confidence_interval}_highlight_{highlight_roc_area}_nboot_{n_bootstraps}_figsize_{figsize}.png"
# fig_path = [fig_path_1, fig_path_2]

fig_path = get_path_name(confidence_interval, highlight_roc_area, n_bootstraps, figsize, split_plots)

# It should raise an error when confidence_interval is None but highlight_roc_area is True
if confidence_interval is None and highlight_roc_area is True:
with pytest.raises(ValueError):
multiclass.plot_roc_curve(y_true=y_true, y_score=y_prob,
confidence_interval=confidence_interval,
highlight_roc_area=highlight_roc_area,
n_bootstraps=n_bootstraps,
figsize=figsize,
split_plots=split_plots,
save_fig_path=fig_path)
# Otherwise no error
else:
multiclass.plot_roc_curve(y_true=y_true, y_score=y_prob,
confidence_interval=confidence_interval,
highlight_roc_area=highlight_roc_area,
n_bootstraps=n_bootstraps,
figsize=figsize,
split_plots=split_plots,
save_fig_path=fig_path)

# check for SMALL figure size
confidence_interval, highlight_roc_area, n_bootstraps, figsize, split_plots = 0.95, True, 100, (3,3), False
fig_path = get_path_name(confidence_interval, highlight_roc_area, n_bootstraps, figsize, split_plots)
multiclass.plot_roc_curve(y_true=y_true, y_score=y_prob,
confidence_interval=confidence_interval,
highlight_roc_area=highlight_roc_area,
n_bootstraps=n_bootstraps,
figsize=figsize,
split_plots=split_plots,
save_fig_path=fig_path)

# check for BIG figure size
confidence_interval, highlight_roc_area, n_bootstraps, figsize, split_plots = 0.95, True, 100, (15, 15), False
fig_path = get_path_name(confidence_interval, highlight_roc_area, n_bootstraps, figsize, split_plots)
multiclass.plot_roc_curve(y_true=y_true, y_score=y_prob,
confidence_interval=confidence_interval,
highlight_roc_area=highlight_roc_area,
n_bootstraps=n_bootstraps,
figsize=figsize,
split_plots=split_plots,
save_fig_path=TEST_RESULTS_PATH / "roc_curve.png",)
save_fig_path=fig_path)



Expand Down

0 comments on commit b3d2895

Please sign in to comment.