Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

End caps on barplot confidence intervals #606 update #898

Merged
merged 13 commits into from
Apr 24, 2016
104 changes: 82 additions & 22 deletions seaborn/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -1498,31 +1498,61 @@ def estimate_statistic(self, estimator, ci, n_boot):
self.value_label = "{}({})".format(estimator.__name__,
self.value_label)

def draw_confints(self, ax, at_group, confint, colors, **kws):

kws.setdefault("lw", mpl.rcParams["lines.linewidth"] * 1.8)
def draw_confints(self,
ax, at_group,
confint,
colors,
conf_lw=None,
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please change here and in the external interface to errwidth for consistency with errcolor.

capsize=None,
**kws):

if conf_lw:
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you are testing against None, do so explicitly.

kws.setdefault("lw", mpl.rcParams["lines.linewidth"] * conf_lw)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This way of doing this is confusing because conf_lw sounds like you are specifying a width value, but it's really a scaling factor. I would just set the linewidth directly using this parameter.

else:
kws.setdefault("lw", mpl.rcParams["lines.linewidth"] * 1.8)

for at, (ci_low, ci_high), color in zip(at_group,
confint,
colors):
if self.orient == "v":
ax.plot([at, at], [ci_low, ci_high], color=color, **kws)
else:
ax.plot([ci_low, ci_high], [at, at], color=color, **kws)
if capsize:
for at, (ci_low, ci_high), color in zip(at_group,
confint,
colors):
if self.orient == "v":
ax.plot([at, at], [ci_low, ci_high], color=color, **kws)
ax.plot([at - capsize / 2, at + capsize / 2],
[ci_low, ci_low], color=color, **kws)
ax.plot([at - capsize / 2, at + capsize / 2],
[ci_high, ci_high], color=color, **kws)
else:
ax.plot([ci_low, ci_high], [at, at], color=color, **kws)
ax.plot([ci_low, ci_low],
[at - capsize / 2, at + capsize / 2],
color=color, **kws)
ax.plot([ci_high, ci_high],
[at - capsize / 2, at + capsize / 2],
color=color, **kws)
else:
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This else clause is duplicating code. I would restructure to write the loop logic once and handle the capsize logic inside the loop.

for at, (ci_low, ci_high), color in zip(at_group, confint, colors):
if self.orient == "v":
ax.plot([at, at], [ci_low, ci_high], color=color, **kws)
else:
ax.plot([ci_low, ci_high], [at, at], color=color, **kws)


class _BarPlotter(_CategoricalStatPlotter):
"""Show point estimates and confidence intervals with bars."""

def __init__(self, x, y, hue, data, order, hue_order,
estimator, ci, n_boot, units,
orient, color, palette, saturation, errcolor):
orient, color, palette, saturation, errcolor, conf_lw=None,
capsize=None):
"""Initialize the plotter."""
self.establish_variables(x, y, hue, data, orient,
order, hue_order, units)
self.establish_colors(color, palette, saturation)
self.estimate_statistic(estimator, ci, n_boot)

self.errcolor = errcolor
self.conf_lw = conf_lw
self.capsize = capsize

def draw_bars(self, ax, kws):
"""Draw the bars onto `ax`."""
Expand All @@ -1538,7 +1568,12 @@ def draw_bars(self, ax, kws):

# Draw the confidence intervals
errcolors = [self.errcolor] * len(barpos)
self.draw_confints(ax, barpos, self.confint, errcolors)
self.draw_confints(ax,
barpos,
self.confint,
errcolors,
self.conf_lw,
self.capsize)

else:

Expand All @@ -1554,7 +1589,12 @@ def draw_bars(self, ax, kws):
if self.confint.size:
confint = self.confint[:, j]
errcolors = [self.errcolor] * len(offpos)
self.draw_confints(ax, offpos, confint, errcolors)
self.draw_confints(ax,
offpos,
confint,
errcolors,
self.conf_lw,
self.capsize)

def plot(self, ax, bar_kws):
"""Make the plot."""
Expand All @@ -1569,7 +1609,7 @@ class _PointPlotter(_CategoricalStatPlotter):
def __init__(self, x, y, hue, data, order, hue_order,
estimator, ci, n_boot, units,
markers, linestyles, dodge, join, scale,
orient, color, palette):
orient, color, palette, conf_lw=None, capsize=None):
"""Initialize the plotter."""
self.establish_variables(x, y, hue, data, orient,
order, hue_order, units)
Expand Down Expand Up @@ -1602,6 +1642,8 @@ def __init__(self, x, y, hue, data, order, hue_order,
self.dodge = dodge
self.join = join
self.scale = scale
self.conf_lw = conf_lw
self.capsize = capsize

@property
def hue_offsets(self):
Expand Down Expand Up @@ -1634,8 +1676,8 @@ def draw_points(self, ax):
color=color, ls=ls, lw=lw)

# Draw the confidence intervals
self.draw_confints(ax, pointpos, self.confint, self.colors, lw=lw)

self.draw_confints(ax, pointpos, self.confint, self.colors,
self.conf_lw, self.capsize)
# Draw the estimate points
marker = self.markers[0]
if self.orient == "h":
Expand Down Expand Up @@ -1675,7 +1717,8 @@ def draw_points(self, ax):
confint = self.confint[:, j]
errcolors = [self.colors[j]] * len(offpos)
self.draw_confints(ax, offpos, confint, errcolors,
zorder=z, lw=lw)
self.conf_lw, self.capsize,
zorder=z)

# Draw the estimate points
marker = self.markers[j]
Expand Down Expand Up @@ -2025,6 +2068,17 @@ def plot(self, ax, boxplot_kws):
``1`` if you want the plot colors to perfectly match the input color
spec.\
"""),
capsize=dedent("""\
capsize : float, optional
Length of caps on confidence interval (drawn perpendicular to
primary line. If 0.0 (default), no caps will be drawn.
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Your'e missing a closing parens.

Typical values are between 0.03 and 0.1.\
"""),
conf_lw=dedent("""\
conf_lw : float, optional
Thickness of lines draw for the confidence interval (and caps).
Default is 1.8.\
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't specify the default value in the docstring.

"""),
width=dedent("""\
width : float, optional
Width of a full element when not using hue nesting, or width of all the
Expand Down Expand Up @@ -2074,6 +2128,9 @@ def plot(self, ax, boxplot_kws):
lvplot=dedent("""\
lvplot : An extension of the boxplot for long-tailed and large data sets.
"""),



)

_categorical_docs.update(_facet_docs)
Expand Down Expand Up @@ -2831,7 +2888,7 @@ def swarmplot(x=None, y=None, hue=None, data=None, order=None, hue_order=None,
def barplot(x=None, y=None, hue=None, data=None, order=None, hue_order=None,
estimator=np.mean, ci=95, n_boot=1000, units=None,
orient=None, color=None, palette=None, saturation=.75,
errcolor=".26", ax=None, **kwargs):
errcolor=".26", conf_lw=None, capsize=None, ax=None, **kwargs):

# Handle some deprecated arguments
if "hline" in kwargs:
Expand All @@ -2850,7 +2907,7 @@ def barplot(x=None, y=None, hue=None, data=None, order=None, hue_order=None,
plotter = _BarPlotter(x, y, hue, data, order, hue_order,
estimator, ci, n_boot, units,
orient, color, palette, saturation,
errcolor)
errcolor, conf_lw, capsize)

if ax is None:
ax = plt.gca()
Expand Down Expand Up @@ -2894,6 +2951,8 @@ def barplot(x=None, y=None, hue=None, data=None, order=None, hue_order=None,
errcolor : matplotlib color
Color for the lines that represent the confidence interval.
{ax_in}
{conf_lw}
{capsize}
kwargs : key, value mappings
Other keyword arguments are passed through to ``plt.bar`` at draw
time.
Expand Down Expand Up @@ -2989,7 +3048,8 @@ def barplot(x=None, y=None, hue=None, data=None, order=None, hue_order=None,
def pointplot(x=None, y=None, hue=None, data=None, order=None, hue_order=None,
estimator=np.mean, ci=95, n_boot=1000, units=None,
markers="o", linestyles="-", dodge=False, join=True, scale=1,
orient=None, color=None, palette=None, ax=None, **kwargs):
orient=None, color=None, palette=None, ax=None, conf_lw=None,
capsize=None, **kwargs):

# Handle some deprecated arguments
if "hline" in kwargs:
Expand All @@ -3008,7 +3068,7 @@ def pointplot(x=None, y=None, hue=None, data=None, order=None, hue_order=None,
plotter = _PointPlotter(x, y, hue, data, order, hue_order,
estimator, ci, n_boot, units,
markers, linestyles, dodge, join, scale,
orient, color, palette)
orient, color, palette, conf_lw, capsize)

if ax is None:
ax = plt.gca()
Expand Down Expand Up @@ -3187,7 +3247,7 @@ def countplot(x=None, y=None, hue=None, data=None, order=None, hue_order=None,
elif x is not None and y is not None:
raise TypeError("Cannot pass values for both `x` and `y`")
else:
raise TypeError("Must pass valus for either `x` or `y`")
raise TypeError("Must pass values for either `x` or `y`")

plotter = _BarPlotter(x, y, hue, data, order, hue_order,
estimator, ci, n_boot, units,
Expand Down