diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 8f0f992..37a37ef 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -4,9 +4,12 @@ What's new in version 0.9 - Fixes a bug where ``show_percentages`` used the incorrect denominator if filtering (e.g. ``min_subset_size``) was applied. This bug was a regression introduced in version 0.7. (:issue:`248`) +- Add a ``style_categories`` method to customize category plot styles, including + shading of rows in the intersection matrix, and bars in the totals plot. + (:issue:`261` with thanks to :user:`Marcel Albus `). - Ability to disable totals plot with `totals_plot_elements=0`. (:issue:`246`) - Ability to set totals y axis label (:issue:`243`) -- Added ``max_subset_rank`` to get only n most populous subsets. +- Added ``max_subset_rank`` to get only n most populous subsets. (:issue:`253`) - Added support for ``min_subset_size`` and ``max_subset_size`` specified as percentage. (:issue:`264`) diff --git a/examples/plot_customize_after_plot.py b/examples/plot_customize_after_plot.py index b763bcd..45a06d5 100644 --- a/examples/plot_customize_after_plot.py +++ b/examples/plot_customize_after_plot.py @@ -4,7 +4,7 @@ ======================= This example illustrates how the return value of the plot method can be used -to customize aspects of the plot, such as axis labels. +to customize aspects of the plot, such as axis labels, legend position, etc. """ from matplotlib import pyplot as plt @@ -18,14 +18,5 @@ plot_result = plot(example) plot_result["intersections"].set_ylabel("Subset size") -plot_result["totals"].set_ylabel("Category size") -plot_result["matrix"].set_xlabel("Subsets between categories") -plt.show() - - -########################################################################## -# Or we can place the totals label on the x axis - -plot_result = plot(example) plot_result["totals"].set_xlabel("Category size") plt.show() diff --git a/examples/plot_highlight.py b/examples/plot_highlight.py index 1cffbc6..bccc332 100644 --- a/examples/plot_highlight.py +++ b/examples/plot_highlight.py @@ -62,7 +62,9 @@ upset = UpSet(example, facecolor="gray") upset.style_subsets(present="cat0", label="Contains cat0", facecolor="blue") -upset.style_subsets(present="cat1", label="Contains cat1", hatch="xx") +upset.style_subsets( + present="cat1", label="Contains cat1", hatch="xx", edgecolor="black" +) upset.style_subsets(present="cat2", label="Contains cat2", edgecolor="red") # reduce legend size: diff --git a/examples/plot_highlight_categories.py b/examples/plot_highlight_categories.py new file mode 100644 index 0000000..eae84d9 --- /dev/null +++ b/examples/plot_highlight_categories.py @@ -0,0 +1,41 @@ +""" +================================ +Highlighting selected categories +================================ + +Demonstrates use of the `style_categories` method to mark some +categories differently. +""" + +from matplotlib import pyplot as plt + +from upsetplot import UpSet, generate_counts + +example = generate_counts() + + +########################################################################## +# Categories can be shaded by name with the ``shading_`` parameters. + +upset = UpSet(example) +upset.style_categories("cat2", shading_edgecolor="darkgreen", shading_linewidth=1) +upset.style_categories( + "cat1", + shading_facecolor="lavender", +) +upset.plot() +plt.suptitle("Shade or edge a category with color") +plt.show() + + +########################################################################## +# Category total bars can be styled with the ``bar_`` parameters. +# You can also specify categories using a list of names. + +upset = UpSet(example) +upset.style_categories( + ["cat2", "cat1"], bar_facecolor="aqua", bar_hatch="xx", bar_edgecolor="black" +) +upset.plot() +plt.suptitle("") +plt.show() diff --git a/upsetplot/plotting.py b/upsetplot/plotting.py index ce2eb9e..3d5630f 100644 --- a/upsetplot/plotting.py +++ b/upsetplot/plotting.py @@ -276,6 +276,7 @@ class UpSet: """ _default_figsize = (10, 6) + DPI = 100 # standard matplotlib value def __init__( self, @@ -348,6 +349,7 @@ def __init__( reverse=not self._horizontal, include_empty_subsets=include_empty_subsets, ) + self.category_styles = {} self.subset_styles = [ {"facecolor": facecolor} for i in range(len(self.intersections)) ] @@ -838,7 +840,7 @@ def plot_matrix(self, ax): tick_axis.set_ticklabels( data.index.names, rotation=0 if self._horizontal else -90 ) - ax.xaxis.set_ticks([]) + ax.xaxis.set_visible(False) ax.tick_params(axis="both", which="both", length=0) if not self._horizontal: ax.yaxis.set_ticks_position("top") @@ -942,28 +944,55 @@ def plot_totals(self, ax): ) self._label_sizes(ax, rects, "left" if self._horizontal else "top") + for category, rect in zip(self.totals.index.values, rects): + style = { + k[len("bar_") :]: v + for k, v in self.category_styles.get(category, {}).items() + if k.startswith("bar_") + } + style.setdefault("edgecolor", style.get("facecolor", self._facecolor)) + for attr, val in style.items(): + getattr(rect, "set_" + attr)(val) + max_total = self.totals.max() if self._horizontal: orig_ax.set_xlim(max_total, 0) for x in ["top", "left", "right"]: ax.spines[self._reorient(x)].set_visible(False) - ax.yaxis.set_visible(True) - ax.yaxis.set_ticklabels([]) - ax.yaxis.set_ticks([]) + ax.yaxis.set_visible(False) ax.xaxis.grid(True) ax.yaxis.grid(False) ax.patch.set_visible(False) def plot_shading(self, ax): - # alternating row shading (XXX: use add_patch(Rectangle)?) - for i in range(0, len(self.totals), 2): + # shade all rows, set every second row to zero visibility + for i, category in enumerate(self.totals.index): + default_shading = ( + self._shading_color if i % 2 == 0 else (0.0, 0.0, 0.0, 0.0) + ) + shading_style = { + k[len("shading_") :]: v + for k, v in self.category_styles.get(category, {}).items() + if k.startswith("shading_") + } + + lw = shading_style.get( + "linewidth", 1 if shading_style.get("edgecolor") else 0 + ) + lw_padding = lw / (self._default_figsize[0] * self.DPI) + start_x = lw_padding + end_x = 1 - lw_padding * 3 + rect = plt.Rectangle( - self._swapaxes(0, i - 0.4), - *self._swapaxes(*(1, 0.8)), - facecolor=self._shading_color, - lw=0, + self._swapaxes(start_x, i - 0.4), + *self._swapaxes(end_x, 0.8), + facecolor=shading_style.get("facecolor", default_shading), + edgecolor=shading_style.get("edgecolor", None), + ls=shading_style.get("linestyle", "-"), + lw=lw, zorder=0, ) + ax.add_patch(rect) ax.set_frame_on(False) ax.tick_params( @@ -982,6 +1011,66 @@ def plot_shading(self, ax): ax.set_xticklabels([]) ax.set_yticklabels([]) + def style_categories( + self, + categories, + *, + bar_facecolor=None, + bar_hatch=None, + bar_edgecolor=None, + bar_linewidth=None, + bar_linestyle=None, + shading_facecolor=None, + shading_edgecolor=None, + shading_linewidth=None, + shading_linestyle=None, + ): + """Updates the style of the categories. + + Select a category by name, and style either its total bar or its shading. + + .. versionadded:: 0.9 + + Parameters + ---------- + categories : str or list[str] + Category names where the changed style applies. + bar_facecolor : str or RGBA matplotlib color tuple, optional. + Override the default facecolor in the totals plot. + bar_hatch : str, optional + Set a hatch for the totals plot. + bar_edgecolor : str or matplotlib color, optional + Set the edgecolor for total bars. + bar_linewidth : int, optional + Line width in points for total bar edges. + bar_linestyle : str, optional + Line style for edges. + shading_facecolor : str or RGBA matplotlib color tuple, optional. + Override the default alternating shading for specified categories. + shading_edgecolor : str or matplotlib color, optional + Set the edgecolor for bars, dots, and the line between dots. + shading_linewidth : int, optional + Line width in points for edges. + shading_linestyle : str, optional + Line style for edges. + """ + if isinstance(categories, str): + categories = [categories] + style = { + "bar_facecolor": bar_facecolor, + "bar_hatch": bar_hatch, + "bar_edgecolor": bar_edgecolor, + "bar_linewidth": bar_linewidth, + "bar_linestyle": bar_linestyle, + "shading_facecolor": shading_facecolor, + "shading_edgecolor": shading_edgecolor, + "shading_linewidth": shading_linewidth, + "shading_linestyle": shading_linestyle, + } + style = {k: v for k, v in style.items() if v is not None} + for category_name in categories: + self.category_styles.setdefault(category_name, {}).update(style) + def plot(self, fig=None): """Draw all parts of the plot onto fig or a new figure diff --git a/upsetplot/tests/test_upsetplot.py b/upsetplot/tests/test_upsetplot.py index 8d10f35..3c33dea 100644 --- a/upsetplot/tests/test_upsetplot.py +++ b/upsetplot/tests/test_upsetplot.py @@ -1177,6 +1177,51 @@ def test_style_subsets_artists(orientation): # matrix_line_collection = upset_axes["matrix"].collections[1] +@pytest.mark.parametrize( + ( + "kwarg_list", + "expected_category_styles", + ), + [ + # Different forms of including two categories + ( + [{"categories": ["cat1", "cat2"], "shading_facecolor": "red"}], + { + "cat1": {"shading_facecolor": "red"}, + "cat2": {"shading_facecolor": "red"}, + }, + ), + ( + [ + {"categories": ["cat1", "cat2"], "shading_facecolor": "red"}, + {"categories": "cat1", "shading_facecolor": "green"}, + ], + { + "cat1": {"shading_facecolor": "green"}, + "cat2": {"shading_facecolor": "red"}, + }, + ), + ( + [ + {"categories": ["cat1", "cat2"], "shading_facecolor": "red"}, + {"categories": "cat1", "shading_edgecolor": "green"}, + ], + { + "cat1": {"shading_facecolor": "red", "shading_edgecolor": "green"}, + "cat2": {"shading_facecolor": "red"}, + }, + ), + ], +) +def test_categories(kwarg_list, expected_category_styles): + data = generate_counts() + upset = UpSet(data, facecolor="blue") + for kw in kwarg_list: + upset.style_categories(**kw) + actual_category_styles = upset.category_styles + assert actual_category_styles == expected_category_styles + + def test_many_categories(): # Tests regressions against GH#193 n_cats = 250