Skip to content

Commit

Permalink
Backport PR #19812: FIX: size and color rendering for Path3DCollection
Browse files Browse the repository at this point in the history
  • Loading branch information
QuLogic committed Mar 31, 2021
1 parent 6547ba2 commit 040649d
Show file tree
Hide file tree
Showing 3 changed files with 119 additions and 151 deletions.
257 changes: 106 additions & 151 deletions lib/mpl_toolkits/mplot3d/art3d.py
Expand Up @@ -302,8 +302,6 @@ def do_3d_projection(self, renderer=None):
"""
Project the points according to renderer matrix.
"""
# see _update_scalarmappable docstring for why this must be here
_update_scalarmappable(self)
xyslist = [proj3d.proj_trans_points(points, self.axes.M)
for points in self._segments3d]
segments_2d = [np.column_stack([xs, ys]) for xs, ys, zs in xyslist]
Expand Down Expand Up @@ -448,16 +446,6 @@ def set_depthshade(self, depthshade):
self._depthshade = depthshade
self.stale = True

def set_facecolor(self, c):
# docstring inherited
super().set_facecolor(c)
self._facecolor3d = self.get_facecolor()

def set_edgecolor(self, c):
# docstring inherited
super().set_edgecolor(c)
self._edgecolor3d = self.get_edgecolor()

def set_sort_zpos(self, val):
"""Set the position to use for z-sorting."""
self._sort_zpos = val
Expand All @@ -474,34 +462,43 @@ def set_3d_properties(self, zs, zdir):
xs = []
ys = []
self._offsets3d = juggle_axes(xs, ys, np.atleast_1d(zs), zdir)
self._facecolor3d = self.get_facecolor()
self._edgecolor3d = self.get_edgecolor()
self._vzs = None
self.stale = True

@_api.delete_parameter('3.4', 'renderer')
def do_3d_projection(self, renderer=None):
# see _update_scalarmappable docstring for why this must be here
_update_scalarmappable(self)
xs, ys, zs = self._offsets3d
vxs, vys, vzs, vis = proj3d.proj_transform_clip(xs, ys, zs,
self.axes.M)

fcs = (_zalpha(self._facecolor3d, vzs) if self._depthshade else
self._facecolor3d)
fcs = mcolors.to_rgba_array(fcs, self._alpha)
super().set_facecolor(fcs)

ecs = (_zalpha(self._edgecolor3d, vzs) if self._depthshade else
self._edgecolor3d)
ecs = mcolors.to_rgba_array(ecs, self._alpha)
super().set_edgecolor(ecs)
self._vzs = vzs
super().set_offsets(np.column_stack([vxs, vys]))

if vzs.size > 0:
return min(vzs)
else:
return np.nan

def _maybe_depth_shade_and_sort_colors(self, color_array):
color_array = (
_zalpha(color_array, self._vzs)
if self._vzs is not None and self._depthshade
else color_array
)
if len(color_array) > 1:
color_array = color_array[self._z_markers_idx]
return mcolors.to_rgba_array(color_array, self._alpha)

def get_facecolor(self):
return self._maybe_depth_shade_and_sort_colors(super().get_facecolor())

def get_edgecolor(self):
# We need this check here to make sure we do not double-apply the depth
# based alpha shading when the edge color is "face" which means the
# edge colour should be identical to the face colour.
if cbook._str_equal(self._edgecolors, 'face'):
return self.get_facecolor()
return self._maybe_depth_shade_and_sort_colors(super().get_edgecolor())


class Path3DCollection(PathCollection):
"""
Expand All @@ -525,9 +522,14 @@ def __init__(self, *args, zs=0, zdir='z', depthshade=True, **kwargs):
This is typically desired in scatter plots.
"""
self._depthshade = depthshade
self._in_draw = False
super().__init__(*args, **kwargs)
self.set_3d_properties(zs, zdir)

def draw(self, renderer):
with cbook._setattr_cm(self, _in_draw=True):
super().draw(renderer)

def set_sort_zpos(self, val):
"""Set the position to use for z-sorting."""
self._sort_zpos = val
Expand All @@ -544,12 +546,37 @@ def set_3d_properties(self, zs, zdir):
xs = []
ys = []
self._offsets3d = juggle_axes(xs, ys, np.atleast_1d(zs), zdir)
self._facecolor3d = self.get_facecolor()
self._edgecolor3d = self.get_edgecolor()
self._sizes3d = self.get_sizes()
self._linewidth3d = self.get_linewidth()
# In the base draw methods we access the attributes directly which
# means we can not resolve the shuffling in the getter methods like
# we do for the edge and face colors.
#
# This means we need to carry around a cache of the unsorted sizes and
# widths (postfixed with 3d) and in `do_3d_projection` set the
# depth-sorted version of that data into the private state used by the
# base collection class in its draw method.
#
# Grab the current sizes and linewidths to preserve them.
self._sizes3d = self._sizes
self._linewidths3d = self._linewidths
xs, ys, zs = self._offsets3d

# Sort the points based on z coordinates
# Performance optimization: Create a sorted index array and reorder
# points and point properties according to the index array
self._z_markers_idx = slice(-1)
self._vzs = None
self.stale = True

def set_sizes(self, sizes, dpi=72.0):
super().set_sizes(sizes, dpi)
if not self._in_draw:
self._sizes3d = sizes

def set_linewidth(self, lw):
super().set_linewidth(lw)
if not self._in_draw:
self._linewidth3d = lw

def get_depthshade(self):
return self._depthshade

Expand All @@ -566,142 +593,57 @@ def set_depthshade(self, depthshade):
self._depthshade = depthshade
self.stale = True

def set_facecolor(self, c):
# docstring inherited
super().set_facecolor(c)
self._facecolor3d = self.get_facecolor()

def set_edgecolor(self, c):
# docstring inherited
super().set_edgecolor(c)
self._edgecolor3d = self.get_edgecolor()

def set_sizes(self, sizes, dpi=72.0):
# docstring inherited
super().set_sizes(sizes, dpi=dpi)
self._sizes3d = self.get_sizes()

def set_linewidth(self, lw):
# docstring inherited
super().set_linewidth(lw)
self._linewidth3d = self.get_linewidth()

@_api.delete_parameter('3.4', 'renderer')
def do_3d_projection(self, renderer=None):
# see _update_scalarmappable docstring for why this must be here
_update_scalarmappable(self)
xs, ys, zs = self._offsets3d
vxs, vys, vzs, vis = proj3d.proj_transform_clip(xs, ys, zs,
self.axes.M)

fcs = (_zalpha(self._facecolor3d, vzs) if self._depthshade else
self._facecolor3d)
ecs = (_zalpha(self._edgecolor3d, vzs) if self._depthshade else
self._edgecolor3d)
sizes = self._sizes3d
lws = self._linewidth3d

# Sort the points based on z coordinates
# Performance optimization: Create a sorted index array and reorder
# points and point properties according to the index array
z_markers_idx = np.argsort(vzs)[::-1]
z_markers_idx = self._z_markers_idx = np.argsort(vzs)[::-1]
self._vzs = vzs

# we have to special case the sizes because of code in collections.py
# as the draw method does
# self.set_sizes(self._sizes, self.figure.dpi)
# so we can not rely on doing the sorting on the way out via get_*

if len(self._sizes3d) > 1:
self._sizes = self._sizes3d[z_markers_idx]

if len(self._linewidths3d) > 1:
self._linewidths = self._linewidths3d[z_markers_idx]

# Re-order items
vzs = vzs[z_markers_idx]
vxs = vxs[z_markers_idx]
vys = vys[z_markers_idx]
if len(fcs) > 1:
fcs = fcs[z_markers_idx]
if len(ecs) > 1:
ecs = ecs[z_markers_idx]
if len(sizes) > 1:
sizes = sizes[z_markers_idx]
if len(lws) > 1:
lws = lws[z_markers_idx]
vps = np.column_stack((vxs, vys))

fcs = mcolors.to_rgba_array(fcs, self._alpha)
ecs = mcolors.to_rgba_array(ecs, self._alpha)

super().set_edgecolor(ecs)
super().set_facecolor(fcs)
super().set_sizes(sizes)
super().set_linewidth(lws)

PathCollection.set_offsets(self, vps)

return np.min(vzs) if vzs.size else np.nan
PathCollection.set_offsets(self, np.column_stack((vxs, vys)))

return np.min(vzs) if vzs.size else np.nan

def _update_scalarmappable(sm):
"""
Update a 3D ScalarMappable.
With ScalarMappable objects if the data, colormap, or norm are
changed, we need to update the computed colors. This is handled
by the base class method update_scalarmappable. This method works
by detecting if work needs to be done, and if so stashing it on
the ``self._facecolors`` attribute.
With 3D collections we internally sort the components so that
things that should be "in front" are rendered later to simulate
having a z-buffer (in addition to doing the projections). This is
handled in the ``do_3d_projection`` methods which are called from the
draw method of the 3D Axes. These methods:
- do the projection from 3D -> 2D
- internally sort based on depth
- stash the results of the above in the 2D analogs of state
- return the z-depth of the whole artist
the last step is so that we can, at the Axes level, sort the children by
depth.
The base `draw` method of the 2D artists unconditionally calls
update_scalarmappable and rely on the method's internal caching logic to
lazily evaluate.
These things together mean you can have the sequence of events:
- we create the artist, do the color mapping and stash the results
in a 3D specific state.
- change something about the ScalarMappable that marks it as in
need of an update (`ScalarMappable.changed` and friends).
- We call do_3d_projection and shuffle the stashed colors into the
2D version of face colors
- the draw method calls the update_scalarmappable method which
overwrites our shuffled colors
- we get a render that is wrong
- if we re-render (either with a second save or implicitly via
tight_layout / constrained_layout / bbox_inches='tight' (ex via
inline's defaults)) we again shuffle the 3D colors
- because the CM is not marked as changed update_scalarmappable is
a no-op and we get a correct looking render.
This function is an internal helper to:
- sort out if we need to do the color mapping at all (has data!)
- sort out if update_scalarmappable is going to be a no-op
- copy the data over from the 2D -> 3D version
This must be called first thing in do_3d_projection to make sure that
the correct colors get shuffled.
def _maybe_depth_shade_and_sort_colors(self, color_array):
color_array = (
_zalpha(color_array, self._vzs)
if self._vzs is not None and self._depthshade
else color_array
)
if len(color_array) > 1:
color_array = color_array[self._z_markers_idx]
return mcolors.to_rgba_array(color_array, self._alpha)

Parameters
----------
sm : ScalarMappable
The ScalarMappable to update and stash the 3D data from
def get_facecolor(self):
return self._maybe_depth_shade_and_sort_colors(super().get_facecolor())

"""
if sm._A is None:
return
copy_state = sm._update_dict['array']
ret = sm.update_scalarmappable()
if copy_state:
if sm._face_is_mapped:
sm._facecolor3d = sm._facecolors
elif sm._edge_is_mapped: # Should this be plain "if"?
sm._edgecolor3d = sm._edgecolors
def get_edgecolor(self):
# We need this check here to make sure we do not double-apply the depth
# based alpha shading when the edge color is "face" which means the
# edge colour should be identical to the face colour.
if cbook._str_equal(self._edgecolors, 'face'):
return self.get_facecolor()
return self._maybe_depth_shade_and_sort_colors(super().get_edgecolor())


def patch_collection_2d_to_3d(col, zs=0, zdir='z', depthshade=True):
Expand All @@ -727,6 +669,7 @@ def patch_collection_2d_to_3d(col, zs=0, zdir='z', depthshade=True):
elif isinstance(col, PatchCollection):
col.__class__ = Patch3DCollection
col._depthshade = depthshade
col._in_draw = False
col.set_3d_properties(zs, zdir)


Expand Down Expand Up @@ -841,9 +784,21 @@ def do_3d_projection(self, renderer=None):
"""
Perform the 3D projection for this object.
"""
# see _update_scalarmappable docstring for why this must be here
_update_scalarmappable(self)

if self._A is not None:
# force update of color mapping because we re-order them
# below. If we do not do this here, the 2D draw will call
# this, but we will never port the color mapped values back
# to the 3D versions.
#
# We hold the 3D versions in a fixed order (the order the user
# passed in) and sort the 2D version by view depth.
copy_state = self._update_dict['array']
self.update_scalarmappable()
if copy_state:
if self._face_is_mapped:
self._facecolor3d = self._facecolors
if self._edge_is_mapped:
self._edgecolor3d = self._edgecolors
txs, tys, tzs = proj3d._proj_transform_vec(self._vec, self.axes.M)
xyzlist = [(txs[sl], tys[sl], tzs[sl]) for sl in self._segslices]

Expand Down
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
13 changes: 13 additions & 0 deletions lib/mpl_toolkits/tests/test_mplot3d.py
Expand Up @@ -1427,3 +1427,16 @@ def test_subfigure_simple():
sf = fig.subfigures(1, 2)
ax = sf[0].add_subplot(1, 1, 1, projection='3d')
ax = sf[1].add_subplot(1, 1, 1, projection='3d', label='other')


@image_comparison(baseline_images=['scatter_spiral.png'],
remove_text=True,
style='default')
def test_scatter_spiral():
fig = plt.figure()
ax = fig.add_subplot(projection='3d')
th = np.linspace(0, 2 * np.pi * 6, 256)
sc = ax.scatter(np.sin(th), np.cos(th), th, s=(1 + th * 5), c=th ** 2)

# force at least 1 draw!
fig.canvas.draw()

0 comments on commit 040649d

Please sign in to comment.