Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: Implement legend (colorbar) for noncategorical dataframe plots #172

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
21 changes: 15 additions & 6 deletions geopandas/plotting.py
Expand Up @@ -3,6 +3,7 @@
import numpy as np
from six import next
from six.moves import xrange
from matplotlib.colorbar import make_axes


def plot_polygon(ax, poly, facecolor='red', edgecolor='black', alpha=0.5):
Expand Down Expand Up @@ -153,8 +154,7 @@ def plot_dataframe(s, column=None, colormap=None, alpha=0.5,
lines or points.

legend : bool (default False)
Plot a legend (Experimental; currently for categorical
plots only)
Plot a legend or colorbar (Experimental)

axes : matplotlib.pyplot.Artist (default None)
axes on which to draw the plot
Expand All @@ -169,7 +169,11 @@ def plot_dataframe(s, column=None, colormap=None, alpha=0.5,
Returns
-------

matplotlib axes instance
ax : matplotlib axes instance
Axes on which the dataframe was plotted
cbar : matplotlib.colorbar.Colorbar (optional)
Colorbar of the noncategorical plot. Returned only if `categorical`
is False and `legend` is True.
"""
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
Expand Down Expand Up @@ -207,6 +211,7 @@ def plot_dataframe(s, column=None, colormap=None, alpha=0.5,
# TODO: color point geometries
elif geom.type == 'Point':
plot_point(ax, geom)
cbar = None
if legend:
if categorical:
patches = []
Expand All @@ -216,10 +221,13 @@ def plot_dataframe(s, column=None, colormap=None, alpha=0.5,
markersize=10, markerfacecolor=cmap.to_rgba(value)))
ax.legend(patches, categories, numpoints=1, loc='best')
else:
# TODO: show a colorbar
raise NotImplementedError
cax = make_axes(ax)[0]
cbar = ax.get_figure().colorbar(cmap, cax=cax)
plt.draw()
return ax
if cbar:
return ax, cbar
else:
return ax


def __pysal_choro(values, scheme, k=5):
Expand Down Expand Up @@ -296,6 +304,7 @@ def norm_cmap(values, cmap, normalize, cm):
mn, mx = min(values), max(values)
norm = normalize(vmin=mn, vmax=mx)
n_cmap = cm.ScalarMappable(norm=norm, cmap=cmap)
n_cmap.set_array(values)
return n_cmap


Expand Down
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tests/baseline_images/test_plotting/df_plot.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified tests/baseline_images/test_plotting/lines_plot.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified tests/baseline_images/test_plotting/points_plot.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified tests/baseline_images/test_plotting/poly_plot.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
48 changes: 46 additions & 2 deletions tests/test_plotting.py
Expand Up @@ -8,23 +8,38 @@
import matplotlib
matplotlib.use('Agg', warn=False)
from matplotlib.pyplot import Artist, savefig, clf
from matplotlib.colorbar import Colorbar
from matplotlib.backends import backend_agg
from matplotlib.testing.noseclasses import ImageComparisonFailure
from matplotlib.testing.compare import compare_images
from shapely.geometry import Polygon, LineString, Point
from six.moves import xrange

from geopandas import GeoSeries
from geopandas import GeoSeries, read_file
from .util import download_nybb


# If set to True, generate images rather than perform tests (all tests will pass!)
GENERATE_BASELINE = False

BASELINE_DIR = os.path.join(os.path.dirname(__file__), 'baseline_images', 'test_plotting')


class PlotTests(unittest.TestCase):

def setUp(self):
# hardcode settings for comparison tests
# settings adapted from ggplot test suite
matplotlib.rcdefaults() # Start with all defaults
matplotlib.rcParams['text.hinting'] = True
matplotlib.rcParams['text.antialiased'] = True
matplotlib.rcParams['font.sans-serif'] = 'Bitstream Vera Sans'
backend_agg.RendererAgg._fontd.clear()

nybb_filename = download_nybb()

self.df = read_file('/nybb_14a_av/nybb.shp',
vfs='zip://' + nybb_filename)
self.df['values'] = [0.1, 0.2, 0.1, 0.3, 0.4]
self.tempdir = tempfile.mkdtemp()
return

Expand Down Expand Up @@ -74,5 +89,34 @@ def test_line_plot(self):
ax = lines.plot()
self._compare_images(ax=ax, filename=filename)

def test_dataframe_plot(self):
""" Test plotting of a dataframe """
clf()
filename = 'df_plot.png'
ax = self.df.plot()
self._compare_images(ax=ax, filename=filename)

def test_dataframe_categorical_plot(self):
""" Test plotting of a categorical GeoDataFrame with legend """
clf()
filename = 'df_cat_leg_plot.png'
ax = self.df.plot(column='values', categorical=True, legend=True)
self._compare_images(ax=ax, filename=filename)

def test_dataframe_noncategorical_plot(self):
""" Test plotting of a noncategorical GeoDataFrame"""
clf()
filename = 'df_noncat_plot.png'
ax = self.df.plot(column='values', categorical=False)
self._compare_images(ax=ax, filename=filename)

def test_dataframe_noncategorical_leg_plot(self):
""" Test plotting of a noncategorical GeoDataFrame"""
clf()
filename = 'df_noncat_leg_plot.png'
ax, cbar = self.df.plot(column='values', categorical=False, legend=True)
self._compare_images(ax=ax, filename=filename)
self.assertTrue(isinstance(cbar, Colorbar))

if __name__ == '__main__':
unittest.main()