Skip to content

Commit

Permalink
Merge pull request #19816 from QuLogic/fix-scatter-legend
Browse files Browse the repository at this point in the history
Fix legend of colour-mapped scatter plots.
  • Loading branch information
tacaswell committed Mar 31, 2021
2 parents 3bfe69c + ee8dc54 commit 6547ba2
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 7 deletions.
5 changes: 4 additions & 1 deletion lib/matplotlib/collections.py
Expand Up @@ -944,8 +944,11 @@ def update_from(self, other):

artist.Artist.update_from(self, other)
self._antialiaseds = other._antialiaseds
self._mapped_colors = other._mapped_colors
self._edge_is_mapped = other._edge_is_mapped
self._original_edgecolor = other._original_edgecolor
self._edgecolors = other._edgecolors
self._face_is_mapped = other._face_is_mapped
self._original_facecolor = other._original_facecolor
self._facecolors = other._facecolors
self._linewidths = other._linewidths
Expand All @@ -958,7 +961,7 @@ def update_from(self, other):
self._A = other._A
self.norm = other.norm
self.cmap = other.cmap
# do we need to copy self._update_dict? -JJL
self._update_dict = other._update_dict.copy()
self.stale = True


Expand Down
18 changes: 12 additions & 6 deletions lib/matplotlib/tests/test_axes.py
Expand Up @@ -52,14 +52,16 @@ def test_get_labels():
@check_figures_equal()
def test_label_loc_vertical(fig_test, fig_ref):
ax = fig_test.subplots()
sc = ax.scatter([1, 2], [1, 2], c=[1, 2])
sc = ax.scatter([1, 2], [1, 2], c=[1, 2], label='scatter')
ax.legend()
ax.set_ylabel('Y Label', loc='top')
ax.set_xlabel('X Label', loc='right')
cbar = fig_test.colorbar(sc)
cbar.set_label("Z Label", loc='top')

ax = fig_ref.subplots()
sc = ax.scatter([1, 2], [1, 2], c=[1, 2])
sc = ax.scatter([1, 2], [1, 2], c=[1, 2], label='scatter')
ax.legend()
ax.set_ylabel('Y Label', y=1, ha='right')
ax.set_xlabel('X Label', x=1, ha='right')
cbar = fig_ref.colorbar(sc)
Expand All @@ -69,14 +71,16 @@ def test_label_loc_vertical(fig_test, fig_ref):
@check_figures_equal()
def test_label_loc_horizontal(fig_test, fig_ref):
ax = fig_test.subplots()
sc = ax.scatter([1, 2], [1, 2], c=[1, 2])
sc = ax.scatter([1, 2], [1, 2], c=[1, 2], label='scatter')
ax.legend()
ax.set_ylabel('Y Label', loc='bottom')
ax.set_xlabel('X Label', loc='left')
cbar = fig_test.colorbar(sc, orientation='horizontal')
cbar.set_label("Z Label", loc='left')

ax = fig_ref.subplots()
sc = ax.scatter([1, 2], [1, 2], c=[1, 2])
sc = ax.scatter([1, 2], [1, 2], c=[1, 2], label='scatter')
ax.legend()
ax.set_ylabel('Y Label', y=0, ha='left')
ax.set_xlabel('X Label', x=0, ha='left')
cbar = fig_ref.colorbar(sc, orientation='horizontal')
Expand All @@ -88,14 +92,16 @@ def test_label_loc_rc(fig_test, fig_ref):
with matplotlib.rc_context({"xaxis.labellocation": "right",
"yaxis.labellocation": "top"}):
ax = fig_test.subplots()
sc = ax.scatter([1, 2], [1, 2], c=[1, 2])
sc = ax.scatter([1, 2], [1, 2], c=[1, 2], label='scatter')
ax.legend()
ax.set_ylabel('Y Label')
ax.set_xlabel('X Label')
cbar = fig_test.colorbar(sc, orientation='horizontal')
cbar.set_label("Z Label")

ax = fig_ref.subplots()
sc = ax.scatter([1, 2], [1, 2], c=[1, 2])
sc = ax.scatter([1, 2], [1, 2], c=[1, 2], label='scatter')
ax.legend()
ax.set_ylabel('Y Label', y=1, ha='right')
ax.set_xlabel('X Label', x=1, ha='right')
cbar = fig_ref.colorbar(sc, orientation='horizontal')
Expand Down

0 comments on commit 6547ba2

Please sign in to comment.