Skip to content

Commit de8a750

Browse files
authored
Black box titles for multiclass histograms (#24)
* update readme * add black boxes to histplot, update test
1 parent 23c83cc commit de8a750

File tree

6 files changed

+88
-203
lines changed

6 files changed

+88
-203
lines changed

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,9 +64,9 @@ Furthermore, this library presents other useful visualizations, such as **compar
6464
| ROC Curve (AUROC) with bootstrapping | Precision-Recall Curve | y_prob histogram |
6565

6666

67-
| <img src="https://github.com/joshuawe/plots_and_graphs/blob/main/images/raincloud.png?raw=true" width="300" alt="Your Image"> | <img src="" width="300" height="300" alt=""> | <img src="" width="300" height="300" alt=""> |
67+
| <img src="https://github.com/joshuawe/plots_and_graphs/blob/main/images/multiclass/histogram_4_classes.png?raw=true" width="300" alt="Histogram (y_scores) for 4 classes"> | <img src="https://github.com/joshuawe/plots_and_graphs/blob/main/images/multiclass/roc_curves_multiclass.png?raw=true" width="300" alt="Multiclass ROC curves (AUROC) with bootstrapping"> | <img src="" width="300" height="300" alt=""> |
6868
|:--------------------------------------------------:|:-------------------------------------------------:| :-------------------------------------------------:|
69-
| Raincloud | | |
69+
| Histogram (y_scores) | ROC curves (AUROC) with bootstrapping | |
7070

7171

7272

7.19 KB
Loading

notebooks/multiclass_classification.ipynb

Lines changed: 26 additions & 168 deletions
Large diffs are not rendered by default.

plotsandgraphs/multiclass_classifier.py

Lines changed: 17 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from sklearn.utils import resample
2121
from tqdm import tqdm
2222

23-
from plotsandgraphs.utils import bootstrap, set_black_title_box, scale_ax_bbox, get_cmap
23+
from plotsandgraphs.utils import bootstrap, set_black_title_boxes, scale_ax_bbox, get_cmap
2424

2525

2626
def plot_roc_curve(
@@ -32,7 +32,7 @@ def plot_roc_curve(
3232
figsize: Optional[Tuple[float, float]] = None,
3333
class_labels: Optional[List[str]] = None,
3434
split_plots: bool = True,
35-
save_fig_path=Optional[Union[str, Tuple[str, str]]],
35+
save_fig_path:Optional[Union[str, Tuple[str, str]]] = None,
3636
) -> Tuple[Figure, Union[Figure, None]]:
3737
"""
3838
Creates two plots.
@@ -188,22 +188,11 @@ def roc_metric_function(y_true, y_score):
188188
for i in range(num_classes, len(axes.flat)):
189189
axes.flat[i].axis("off")
190190

191-
# make the subplot tiles (and black boxes)
192-
for i in range(num_classes):
193-
set_black_title_box(axes.flat[i], f"Class {i}")
194-
plt.tight_layout(h_pad=1.5)
195-
# make the subplot tiles (and black boxes)
196-
# First time to get the approx. correct spacing with plt.tight_layout()
197-
# Second time to get the correct width of the black box
198-
# Thank you matplotlib ...
199-
for i in range(num_classes):
200-
set_black_title_box(
201-
axes.flat[i],
202-
f"Class {i}",
203-
set_title_kwargs={
204-
"fontdict": {"fontname": "Arial Black", "fontweight": "bold"}
205-
},
206-
)
191+
# create the subplot titles (and black boxes)
192+
set_black_title_boxes(axes.flat[:num_classes], class_labels)
193+
194+
195+
207196

208197
# ---------- AUROC overview plot comparing classes ----------
209198
# Make an AUROC overview plot comparing the aurocs per class and combined
@@ -281,13 +270,12 @@ def auroc_metric_function(y_true, y_score, average, multi_class):
281270
def plot_y_prob_histogram(
282271
y_true: np.ndarray, y_prob: Optional[np.ndarray] = None, save_fig_path=None
283272
) -> Figure:
273+
num_classes = y_true.shape[-1]
274+
class_labels = [f"Class {i}" for i in range(num_classes)]
275+
284276
# Aiming for a square plot
285-
plot_cols = np.ceil(np.sqrt(y_true.shape[-1])).astype(
286-
int
287-
) # Number of plots in a row
288-
plot_rows = np.ceil(y_true.shape[-1] / plot_cols).astype(
289-
int
290-
) # Number of plots in a column
277+
plot_cols = np.ceil(np.sqrt(num_classes)).astype(int) # Number of plots in a row # noqa
278+
plot_rows = np.ceil(num_classes / plot_cols).astype(int) # Number of plots in a column # noqa
291279
fig, axes = plt.subplots(
292280
nrows=plot_rows,
293281
ncols=plot_cols,
@@ -298,11 +286,11 @@ def plot_y_prob_histogram(
298286
plt.suptitle("Predicted probability histogram")
299287

300288
# Flatten axes if there is only one class, even though this function is designed for multiclasses
301-
if y_true.shape[-1] == 1:
289+
if num_classes == 1:
302290
axes = np.array([axes])
303291

304292
for i, ax in enumerate(axes.flat):
305-
if i >= y_true.shape[-1]:
293+
if i >= num_classes:
306294
ax.axis("off")
307295
continue
308296

@@ -327,7 +315,7 @@ def plot_y_prob_histogram(
327315
linewidth=2,
328316
rwidth=1,
329317
)
330-
ax.set_title(f"Class {i}")
318+
ax.set_title(class_labels[i])
331319
ax.set_xlim((-0.005, 1.0))
332320
# if subplot in first column
333321
if (i % plot_cols) == 0:
@@ -342,7 +330,8 @@ def plot_y_prob_histogram(
342330
if i == 0:
343331
ax.legend()
344332

345-
plt.tight_layout()
333+
# create the subplot titles (and black boxes)
334+
set_black_title_boxes(axes.flat[:num_classes], class_labels)
346335

347336
# save plot
348337
if save_fig_path is not None:

plotsandgraphs/utils.py

Lines changed: 42 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1-
from typing import Optional, List, Callable, Dict, Tuple, Union, TYPE_CHECKING
1+
from typing import Optional, List, Callable, Dict, Tuple, Union, TYPE_CHECKING, Literal
22
from tqdm import tqdm
33
from sklearn.utils import resample
44
import numpy as np
5+
import matplotlib.pyplot as plt
56
from matplotlib.path import Path
67
from matplotlib.patches import BoxStyle
78
from matplotlib.colors import LinearSegmentedColormap
@@ -99,8 +100,9 @@ def __call__(self, x0, y0, width, height, mutation_size):
99100
closed=True)
100101

101102

102-
def set_black_title_box(ax: "Axes", title=str, backgroundcolor='black', color='white', set_title_kwargs: Dict={}):
103+
def _set_black_title_box(ax: "Axes", title:str, backgroundcolor='black', color='white', title_kwargs: Optional[Dict]=None):
103104
"""
105+
Note: Do not use this function by itself, instead use `set_black_title_boxes()`.
104106
Sets the title of the given axes with a black bounding box.
105107
Note: When using `plt.tight_layout()` the box might not have the correct width. First call `plt.tight_layout()` and then `_set_black_title_box()`.
106108
@@ -111,14 +113,50 @@ def set_black_title_box(ax: "Axes", title=str, backgroundcolor='black', color='w
111113
- color: The color of the title text (default: 'white').
112114
- title_kwargs: Keyword arguments to pass to `ax.set_title()`. If None, a bold "Arial Black" fontdict is used.
113115
"""
116+
if title_kwargs is None:
117+
title_kwargs = {'fontdict': {"fontname": "Arial Black", "fontweight": "bold"}}
114118
BoxStyle._style_list["ext"] = ExtendedTextBox_v2
115119
ax_width = ax.get_window_extent().width
116120
# make title with black bounding box
117-
title = ax.set_title(title, backgroundcolor=backgroundcolor, color=color, **set_title_kwargs)
118-
bb = title.get_bbox_patch() # get bbox from title
121+
title_instance = ax.set_title(title, backgroundcolor=backgroundcolor, color=color, **title_kwargs)
122+
bb = title_instance.get_bbox_patch() # get bbox from title
119123
bb.set_boxstyle("ext", pad=0.1, width=ax_width) # use custom style
120124

121125

126+
def set_black_title_boxes(
    axes: "np.ndarray[Axes]",
    titles: List[str],
    backgroundcolor='black',
    color='white',
    title_kwargs: Optional[Dict] = None,
    tight_layout_kwargs: Optional[Dict] = None,
):
    """
    Creates black boxes for the subtitles above the given axes with the given titles. The subtitles are centered above the axes.

    The titles are drawn twice with a `plt.tight_layout()` call in between:
    the first pass gives `tight_layout` approximately correct spacing to work
    with, the second pass re-reads the (now final) axes width so that each
    box spans its axes exactly.

    Parameters
    ----------
    axes : np.ndarray["Axes"]
        np.ndarray of matplotlib.axes.Axes objects. (Usually returned by plt.subplots() call)
    titles : List[str]
        List of titles for the axes. Must provide at least one title per axes;
        titles are assigned in `axes.flat` order.
    backgroundcolor : str, optional
        Background color of boxes, by default 'black'
    color : str, optional
        Font color, by default 'white'
    title_kwargs : Dict, optional
        kwargs for the `ax.set_title()` call, by default None, in which case
        `_set_black_title_box()` applies its own default font settings.
    tight_layout_kwargs : Dict, optional
        kwargs for the `plt.tight_layout()` call, by default None (no extra kwargs).

    Raises
    ------
    IndexError
        If fewer titles than axes are given.
    """
    # Avoid a shared mutable default argument.
    if tight_layout_kwargs is None:
        tight_layout_kwargs = {}

    flat_axes = list(axes.flat)

    # Fail before touching any axes instead of half-way through the first pass.
    if len(titles) < len(flat_axes):
        raise IndexError(
            f"Got {len(titles)} titles for {len(flat_axes)} axes; "
            "one title per axes is required."
        )

    def _draw_title_boxes():
        # Surplus titles (if any) are ignored because zip stops at the last axes.
        for ax, title in zip(flat_axes, titles):
            _set_black_title_box(ax, title, backgroundcolor, color, title_kwargs)

    # First pass: titles must exist so tight_layout can reserve room for them.
    _draw_title_boxes()
    plt.tight_layout(**tight_layout_kwargs)
    # Second pass: the axes width changed during tight_layout, so the box
    # width has to be recomputed from the final axes extent.
    _draw_title_boxes()
156+
157+
158+
159+
122160
def scale_ax_bbox(ax: "Axes", factor: float):
123161
# Get the current position of the subplot
124162
box = ax.get_position()

tests/test_multiclass_classifier.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def test_hist_plot():
5454
random_data_binary_classifier : Tuple[np.ndarray, np.ndarray]
5555
The simulated data.
5656
"""
57-
for num_classes in [2, 3, 4, 5, 10, 16, 25]:
57+
for num_classes in [1, 2, 3, 4, 5, 10, 16, 25]:
5858
y_true, y_prob = random_data_multiclass_classifier(num_classes=num_classes)
5959
print(TEST_RESULTS_PATH)
6060
multiclass.plot_y_prob_histogram(y_true=y_true, y_prob=y_prob, save_fig_path=TEST_RESULTS_PATH / f"histogram_{num_classes}_classes.png")

0 commit comments

Comments
 (0)