Skip to content

Commit

Permalink
Scale the default scatterplot edge width by the point radius (#2078)
Browse files Browse the repository at this point in the history
* Scale the default scatterplot edge width by the point radius

* Reorder operations in scatterplot plot
  • Loading branch information
mwaskom committed May 17, 2020
1 parent 92f160a commit 7f55378
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 19 deletions.
2 changes: 2 additions & 0 deletions doc/releases/v0.11.0.txt
Expand Up @@ -12,6 +12,8 @@ v0.11.0 (Unreleased)

- Added a ``tight_layout`` method to :class:`FacetGrid` and :class:`PairGrid`, which runs the :func:`matplotlib.pyplot.tight_layout` algorithm without interference from the external legend. GH2073

- Changed how :func:`scatterplot` sets the default linewidth for the edges of the scatter points to scale with the point sizes themselves (on a plot-wise, not point-wise basis). This change also slightly reduces the default width when point sizes are not varied. Set ``linewidth=0.75`` to reproduce the previous behavior. GH2078

- Added an explicit warning in :func:`swarmplot` when more than 2% of the points overlap in the "gutters" of the swarm. GH2045

- Added the ``axes_dict`` attribute to :class:`FacetGrid` for named access to the component axes. GH2046
Expand Down
43 changes: 24 additions & 19 deletions seaborn/relational.py
Expand Up @@ -774,42 +774,47 @@ def plot(self, ax, kws):

kws.pop("color", None) # TODO is this optimal?

kws.setdefault("linewidth", .75) # TODO scale with marker size?
kws.setdefault("edgecolor", "w")

if self.markers:
# Use a representative marker so scatter sets the edgecolor
# properly for line art markers. We currently enforce either
# all or none line art so this works.
example_marker = list(self.markers.values())[0]
kws.setdefault("marker", example_marker)

# TODO this makes it impossible to vary alpha with hue which might
# otherwise be useful? Should we just pass None?
kws["alpha"] = 1 if self.alpha == "auto" else self.alpha

# Assign arguments for plt.scatter and draw the plot
# --- Determine the visual attributes of the plot

data = self.plot_data[list(self.variables)].dropna()
if not data.size:
return

# Define the vectors of x and y positions
x = data.get(["x"], np.full(len(data), np.nan))
y = data.get(["y"], np.full(len(data), np.nan))

# Define vectors of hue and size values
# There must be some reason this doesn't use data[var].map(attr_dict)
# But I do not remember what it is!
if self.palette:
c = [self.palette.get(val) for val in data["hue"]]

if self.sizes:
s = [self.sizes.get(val) for val in data["size"]]

# Set defaults for other visual attributes
kws.setdefault("linewidth", .08 * np.sqrt(np.percentile(s, 10)))
kws.setdefault("edgecolor", "w")

if self.markers:
# Use a representative marker so scatter sets the edgecolor
# properly for line art markers. We currently enforce either
# all or none line art so this works.
example_marker = list(self.markers.values())[0]
kws.setdefault("marker", example_marker)

# TODO this makes it impossible to vary alpha with hue which might
# otherwise be useful? Should we just pass None?
kws["alpha"] = 1 if self.alpha == "auto" else self.alpha

# Draw the scatter plot
args = np.asarray(x), np.asarray(y), np.asarray(s), np.asarray(c)
points = ax.scatter(*args, **kws)

# Update the paths to get different marker shapes. This has to be
# done here because plt.scatter allows varying sizes and colors
# but only a single marker shape per call.

# Update the paths to get different marker shapes.
# This has to be done here because ax.scatter allows varying sizes
# and colors but only a single marker shape per call.
if self.paths:
p = [self.paths.get(val) for val in data["style"]]
points.set_paths(p)
Expand Down
36 changes: 36 additions & 0 deletions seaborn/tests/test_relational.py
Expand Up @@ -2062,6 +2062,42 @@ def test_scatterplot_axes(self, wide_df):
ax = scatterplot(data=wide_df, ax=ax1)
assert ax is ax1

def test_linewidths(self, long_df):
    """Check that the default scatterplot edge width follows the point size.

    Larger points, whether set through a scalar ``s`` or through a ``size``
    mapping, should receive a wider default edge than smaller ones, while
    an explicit ``linewidth`` must be passed through unchanged.
    """
    f, ax = plt.subplots()

    # Pass ``ax`` explicitly so the test does not depend on which axes
    # pyplot currently treats as active.
    scatterplot(data=long_df, x="x", y="y", s=10, ax=ax)
    scatterplot(data=long_df, x="x", y="y", s=20, ax=ax)
    points1, points2 = ax.collections
    assert (
        points1.get_linewidths().item() < points2.get_linewidths().item()
    )

    # These tests don't work because changes in matplotlib cause an error
    # when we draw the scatter with non-scalar s or c. They are kept here,
    # disabled, so they can be restored once that is resolved.
    #
    # ax.clear()
    # scatterplot(data=long_df, x="x", y="y", s=long_df["x"], ax=ax)
    # scatterplot(data=long_df, x="x", y="y", s=long_df["x"] * 2, ax=ax)
    # points1, points2 = ax.collections
    # assert (
    #     points1.get_linewidths().item() < points2.get_linewidths().item()
    # )

    ax.clear()
    scatterplot(data=long_df, x="x", y="y", size=long_df["x"], ax=ax)
    scatterplot(data=long_df, x="x", y="y", size=long_df["x"] * 2, ax=ax)
    # Only the first two collections are compared; any others on the axes
    # are ignored (presumably legend-related artists -- TODO confirm).
    points1, points2, *_ = ax.collections
    assert (
        points1.get_linewidths().item() < points2.get_linewidths().item()
    )

    # A user-supplied linewidth must override the size-based default.
    ax.clear()
    lw = 2
    scatterplot(data=long_df, x="x", y="y", linewidth=lw, ax=ax)
    assert ax.collections[0].get_linewidths().item() == lw

def test_scatterplot_smoke(
self,
wide_df, wide_array,
Expand Down

0 comments on commit 7f55378

Please sign in to comment.