Skip to content

Commit

Permalink
Fix a regression in show_percentages / total calculation (#248)
Browse files Browse the repository at this point in the history
Fix #226

Fix #223
  • Loading branch information
jnothman committed Dec 28, 2023
1 parent 3bf5ad5 commit 7a6387f
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 19 deletions.
9 changes: 7 additions & 2 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
What's new in version 0.9
-------------------------
- Ability to disable totals plot with `totals_plot_elements=0`.

- Fixes a bug where ``show_percentages`` used the incorrect denominator if
filtering (e.g. ``min_subset_size``) was applied. This bug was a regression
introduced in version 0.7. (:issue:`248`)
- Ability to disable totals plot with `totals_plot_elements=0`. (:issue:`246`)
- Ability to set totals y axis label (:issue:`243`)

What's new in version 0.8
-------------------------
Expand All @@ -14,7 +19,7 @@ What's new in version 0.8
- Added `subsets` attribute to QueryResult. (:issue:`198`)
- Fixed a bug where more than 64 categories could result in an error. (:issue:`193`)

Patch release 0.8.1 handles deprecations in dependencies.
Patch release 0.8.2 handles deprecations in dependencies.

What's new in version 0.7
-------------------------
Expand Down
6 changes: 1 addition & 5 deletions upsetplot/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,6 @@ def _process_data(

df = results.data
agg = results.subset_sizes
totals = results.category_totals
total = agg.sum()

# add '_bin' to df indicating index in agg
# XXX: ugly!
Expand All @@ -75,7 +73,7 @@ def _pack_binary(X):
if reverse:
agg = agg[::-1]

return total, df, agg, totals
return results.total, df, agg, results.category_totals


def _multiply_alpha(c, mult):
Expand Down Expand Up @@ -673,8 +671,6 @@ def make_grid(self, fig=None):
fig.set_figheight((colw * (n_cats + sizes.sum())) / render_ratio)

text_nelems = int(np.ceil(figw / colw - non_text_nelems))
# print('textw', textw, 'figw', figw, 'colw', colw,
# 'ncols', figw/colw, 'text_nelems', text_nelems)

GS = self._reorient(matplotlib.gridspec.GridSpec)
gridspec = GS(
Expand Down
30 changes: 18 additions & 12 deletions upsetplot/reformat.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,17 +162,20 @@ class QueryResult:
for `data`.
category_totals : Series
Total size of each category, regardless of selection.
total : number
Total number of samples / sum of value
"""

def __init__(self, data, subset_sizes, category_totals):
def __init__(self, data, subset_sizes, category_totals, total):
self.data = data
self.subset_sizes = subset_sizes
self.category_totals = category_totals
self.total = total

def __repr__(self):
return (
"QueryResult(data={data}, subset_sizes={subset_sizes}, "
"category_totals={category_totals}".format(**vars(self))
"category_totals={category_totals}, total={total}".format(**vars(self))
)

@property
Expand Down Expand Up @@ -267,7 +270,7 @@ def query(
-------
QueryResult
Including filtered ``data``, filtered and sorted ``subset_sizes`` and
overall ``category_totals``.
overall ``category_totals`` and ``total``.
Examples
--------
Expand Down Expand Up @@ -322,11 +325,12 @@ def query(

data, agg = _aggregate_data(data, subset_size, sum_over)
data = _check_index(data)
totals = [
grand_total = agg.sum()
category_totals = [
agg[agg.index.get_level_values(name).values.astype(bool)].sum()
for name in agg.index.names
]
totals = pd.Series(totals, index=agg.index.names)
category_totals = pd.Series(category_totals, index=agg.index.names)

if include_empty_subsets:
nlevels = len(agg.index.levels)
Expand Down Expand Up @@ -358,15 +362,17 @@ def query(

# sort:
if sort_categories_by in ("cardinality", "-cardinality"):
totals.sort_values(ascending=sort_categories_by[:1] == "-", inplace=True)
category_totals.sort_values(
ascending=sort_categories_by[:1] == "-", inplace=True
)
elif sort_categories_by == "-input":
totals = totals[::-1]
category_totals = category_totals[::-1]
elif sort_categories_by in (None, "input"):
pass
else:
raise ValueError("Unknown sort_categories_by: %r" % sort_categories_by)
data = data.reorder_levels(totals.index.values)
agg = agg.reorder_levels(totals.index.values)
data = data.reorder_levels(category_totals.index.values)
agg = agg.reorder_levels(category_totals.index.values)

if sort_by in ("cardinality", "-cardinality"):
agg = agg.sort_values(ascending=sort_by[:1] == "-")
Expand All @@ -380,12 +386,12 @@ def query(
pd.MultiIndex.from_tuples(index_tuples, names=agg.index.names)
)
elif sort_by == "-input":
print("<", agg)
agg = agg[::-1]
print(">", agg)
elif sort_by in (None, "input"):
pass
else:
raise ValueError("Unknown sort_by: %r" % sort_by)

return QueryResult(data=data, subset_sizes=agg, category_totals=totals)
return QueryResult(
data=data, subset_sizes=agg, category_totals=category_totals, total=grand_total
)
1 change: 1 addition & 0 deletions upsetplot/tests/test_upsetplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -822,6 +822,7 @@ def test_filter_subsets(filter_params, expected, sort_by):
)
# category totals should not be affected
assert_series_equal(upset_full.totals, upset_filtered.totals)
assert upset_full.total == pytest.approx(upset_filtered.total)


@pytest.mark.parametrize(
Expand Down

0 comments on commit 7a6387f

Please sign in to comment.