Skip to content

Commit

Permalink
Merge branch 'size'
Browse files Browse the repository at this point in the history
  • Loading branch information
jnothman committed Feb 24, 2019
2 parents 9df00cf + d7a516d commit f83b941
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 5 deletions.
6 changes: 6 additions & 0 deletions examples/plot_generated.py
Expand Up @@ -13,7 +13,13 @@
print(example)

plot(example)
plt.title('Ordered by degree')
plt.show()

plot(example, sort_by='cardinality')
plt.title('Ordered by cardinality')
plt.show()

plot(example, show_counts='%d')
plt.title('With counts shown')
plt.show()
5 changes: 5 additions & 0 deletions examples/plot_vertical.py
Expand Up @@ -11,4 +11,9 @@

example = generate_data(aggregated=True)
plot(example, orientation='vertical')
plt.title('A vertical plot')
plt.show()

plot(example, orientation='vertical', show_counts='%d')
plt.title('A vertical plot with counts shown')
plt.show()
51 changes: 46 additions & 5 deletions upsetplot/plotting.py
Expand Up @@ -140,12 +140,17 @@ class UpSet:
totals_plot_elements : int
The totals plot should be large enough to fit this many matrix
elements.
show_counts : bool or str, default=False
Whether to label the intersection size bars with the cardinality
of the intersection. When a string, this formats the number.
For example, '%d' is equivalent to True.
"""

def __init__(self, data, orientation='horizontal', sort_by='degree',
sort_sets_by='cardinality', facecolor='black',
with_lines=True, element_size=32,
intersection_plot_elements=6, totals_plot_elements=2):
intersection_plot_elements=6, totals_plot_elements=2,
show_counts=''):

self._horizontal = orientation == 'horizontal'
self._reorient = _identity if self._horizontal else _transpose
Expand All @@ -154,6 +159,7 @@ def __init__(self, data, orientation='horizontal', sort_by='degree',
self._element_size = element_size
self._totals_plot_elements = totals_plot_elements
self._intersection_plot_elements = intersection_plot_elements
self._show_counts = show_counts

(self.intersections,
self.totals) = _process_data(data,
Expand Down Expand Up @@ -262,8 +268,11 @@ def plot_intersections(self, ax):
"""Plot bars indicating intersection size
"""
ax = self._reorient(ax)
ax.bar(np.arange(len(self.intersections)), self.intersections,
.5, color=self._facecolor, zorder=10, align='center')
rects = ax.bar(np.arange(len(self.intersections)), self.intersections,
.5, color=self._facecolor, zorder=10, align='center')

self._label_sizes(ax, rects, 'top' if self._horizontal else 'right')

ax.xaxis.set_visible(False)
for x in ['top', 'bottom', 'right']:
ax.spines[self._reorient(x)].set_visible(False)
Expand All @@ -273,13 +282,45 @@ def plot_intersections(self, ax):
tick_axis.set_label('Intersection size')
# tick_axis.set_tick_params(direction='in')

def _label_sizes(self, ax, rects, where):
if not self._show_counts:
return
fmt = '%d' if self._show_counts is True else self._show_counts
if where == 'right':
margin = 0.01 * abs(np.diff(ax.get_xlim()))
for rect in rects:
width = rect.get_width()
ax.text(width + margin,
rect.get_y() + rect.get_height() * .5,
fmt % width,
ha='left', va='center')
elif where == 'left':
margin = 0.01 * abs(np.diff(ax.get_xlim()))
for rect in rects:
width = rect.get_width()
ax.text(width + margin,
rect.get_y() + rect.get_height() * .5,
fmt % width,
ha='right', va='center')
elif where == 'top':
margin = 0.01 * abs(np.diff(ax.get_ylim()))
for rect in rects:
height = rect.get_height()
ax.text(rect.get_x() + rect.get_width() * .5,
height + margin, fmt % height,
ha='center', va='bottom')
else:
raise NotImplementedError('unhandled where: %r' % where)

def plot_totals(self, ax):
"""Plot bars indicating total set size
"""
orig_ax = ax
ax = self._reorient(ax)
ax.barh(np.arange(len(self.totals.index.values)), self.totals,
.5, color=self._facecolor, align='center')
rects = ax.barh(np.arange(len(self.totals.index.values)), self.totals,
.5, color=self._facecolor, align='center')
self._label_sizes(ax, rects, 'left' if self._horizontal else 'top')

max_total = self.totals.max()
if self._horizontal:
orig_ax.set_xlim(max_total, 0)
Expand Down
34 changes: 34 additions & 0 deletions upsetplot/tests/test_upsetplot.py
Expand Up @@ -133,3 +133,37 @@ def test_element_size():
assert figsize_before == figsize_after

# TODO: make sure axes are all within figure



def _walk_artists(el):
children = el.get_children()
yield el, children
for ch in children:
for x in _walk_artists(ch):
yield x


def _count_descendants(el):
return sum(len(children) for x, children in _walk_artists(el))


@pytest.mark.parametrize('orientation', ['horizontal', 'vertical'])
def test_show_counts(orientation):
fig = matplotlib.figure.Figure()
X = generate_data(n_samples=100)
plot(X, fig)
n_artists_no_sizes = _count_descendants(fig)

fig = matplotlib.figure.Figure()
plot(X, fig, show_counts=True)
n_artists_yes_sizes = _count_descendants(fig)
assert n_artists_yes_sizes - n_artists_no_sizes > 6

fig = matplotlib.figure.Figure()
plot(X, fig, show_counts='%0.2g')
assert n_artists_yes_sizes == _count_descendants(fig)

with pytest.raises(ValueError):
fig = matplotlib.figure.Figure()
plot(X, fig, show_counts='%0.2h')

0 comments on commit f83b941

Please sign in to comment.