Skip to content

Commit

Permalink
Swtch lineplot to use errorbar for bars-style CIs
Browse files Browse the repository at this point in the history
  • Loading branch information
mwaskom committed May 26, 2018
1 parent e933619 commit 1095bd2
Showing 1 changed file with 25 additions and 20 deletions.
45 changes: 25 additions & 20 deletions seaborn/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from .external.six import string_types

from . import utils
from .utils import categorical_order, get_color_cycle, sort_df
from .utils import categorical_order, get_color_cycle, ci_to_errsize, sort_df
from .algorithms import bootstrap
from .palettes import color_palette

Expand Down Expand Up @@ -604,46 +604,51 @@ def plot(self, ax, kws):
kws["marker"] = self.markers.get(style, orig_marker)
kws["linewidth"] = self.sizes.get(size, orig_linewidth)

# --- Draw the main line

line, = ax.plot([], [], **kws)
line_color = line.get_color()
line_alpha = line.get_alpha()
line_capstyle = line.get_solid_capstyle()
line.remove()

# --- Draw the main line

x, y = np.asarray(x), np.asarray(y)

if self.units is None:

line, = ax.plot(x.values, y.values, **kws)
# TODO standardize using asarray in this function
line, = ax.plot(x, y, **kws)

else:

for u in units.unique():
ax.plot(x[units == u].values, y[units == u].values, **kws)
rows = np.asarray(units == u)
ax.plot(x[rows], y[rows], **kws)

# --- Draw the confidence intervals

if y_ci is not None:

# TODO we want some way to get kwargs to the error plotters
low, high = np.asarray(y_ci["low"]), np.asarray(y_ci["high"])

if self.errstyle == "band":

ax.fill_between(x, y_ci["low"], y_ci["high"],
color=line_color, alpha=.2)
# TODO note that these are passed in as pandas objects
ax.fill_between(x, low, high, color=line_color, alpha=.2)

elif self.errstyle == "bars":

ci_xy = np.empty((len(x), 2, 2))
ci_xy[:, :, 0] = x[:, np.newaxis]
ci_xy[:, :, 1] = y_ci.values
lines = LineCollection(ci_xy,
color=line_color,
alpha=line_alpha)
try:
lines.set_capstyle(line_capstyle)
except AttributeError:
pass
ax.add_collection(lines)
ax.autoscale_view()
y_err = ci_to_errsize((low, high), y)
ebars = ax.errorbar(x, y, y_err, linestyle="",
color=line_color, alpha=line_alpha)

for obj in ebars.get_children():
try:
obj.set_capstyle(line_capstyle)
except AttributeError:
# Does not exist on mpl < 2.2
pass

else:
err = "`errstyle` must by 'band' or 'bars', not {}"
Expand Down Expand Up @@ -1012,7 +1017,7 @@ def lineplot(x=None, y=None, hue=None, size=None, style=None, data=None,


lineplot.__doc__ = dedent("""\
Draw a plot with numeric x and y values where the points are connected.
Draw a line plot with up to several semantic groupings.
{main_api_narrative}
Expand Down

0 comments on commit 1095bd2

Please sign in to comment.