diff --git a/mplotutils/cartopy_utils.py b/mplotutils/cartopy_utils.py index 37427c6..d7dc7cd 100644 --- a/mplotutils/cartopy_utils.py +++ b/mplotutils/cartopy_utils.py @@ -1,7 +1,9 @@ +import warnings + import cartopy.crs as ccrs import matplotlib.pyplot as plt import numpy as np -import shapely.geometry as sgeom +import shapely.geometry from cartopy.mpl.gridliner import LATITUDE_FORMATTER, LONGITUDE_FORMATTER from .colormaps import _get_label_attr @@ -123,7 +125,7 @@ def ylabel_map(s, labelpad=None, size=None, weight=None, y=0.5, ax=None, **kwarg rotation_mode=rotation_mode, size=size, weight=weight, - **kwargs + **kwargs, ) return h @@ -188,7 +190,7 @@ def xlabel_map(s, labelpad=None, size=None, weight=None, x=0.5, ax=None, **kwarg rotation_mode=rotation_mode, size=size, weight=weight, - **kwargs + **kwargs, ) return h @@ -206,7 +208,7 @@ def yticklabels( ha="right", va="center", bbox_props=dict(ec="none", fc="none"), - **kwargs + **kwargs, ): """ @@ -238,18 +240,18 @@ def yticklabels( """ - plt.draw() - # get ax if necessary if ax is None: ax = plt.gca() + ax.figure.canvas.draw() + labelpad, size, weight = _get_label_attr(labelpad, size, weight) boundary_pc = _get_boundary_platecarree(ax) # ensure labels are on rhs and not in the middle - if len(boundary_pc) == 1: + if len(boundary_pc.geoms) == 1: lonmin, lonmax = -180, 180 else: lonmin, lonmax = 0, 360 @@ -265,7 +267,7 @@ def yticklabels( "WARN: no points found for ylabel\n" "y_lim is: {:0.2f} to {:0.2f}".format(y_lim[0], y_lim[1]) ) - print(msg) + warnings.warn(msg) # get a transform instance that mpl understands transform = ccrs.PlateCarree()._as_mpl_transform(ax) @@ -281,7 +283,7 @@ def yticklabels( x = _determine_intersection(boundary_pc, [lonmin, y], [lonmax, y]) if x.size > 0: - x = x[0, 0] + x = x[:, 0].min() lp = labelpad[0] + labelpad[1] * np.abs(y) / 90 ax.annotate( @@ -295,7 +297,7 @@ def yticklabels( xytext=(-lp, 0), textcoords="offset points", bbox=bbox_props, - **kwargs + **kwargs, ) @@ -308,7 +310,7 @@ def xticklabels( ha="center", va="top", bbox_props=dict(ec="none", fc="none"), - **kwargs + **kwargs, ): """ @@ -340,12 +342,17 @@ def xticklabels( """ - plt.draw() - # get ax if necessary if ax is None: ax = plt.gca() + ax.figure.canvas.draw() + + # proj = ccrs.PlateCarree() + # points = shapely.geometry.MultiPoint([shapely.geometry.Point(x, 0) for x in x_ticks]) + # points = proj.project_geometry(points, proj) + # x_ticks = [x.x for x in points.geoms] + labelpad, size, weight = _get_label_attr(labelpad, size, weight) boundary_pc = _get_boundary_platecarree(ax) @@ -361,7 +368,7 @@ def xticklabels( "WARN: no points found for xlabel\n" "x_lim is: {:0.2f} to {:0.2f}".format(x_lim[0], x_lim[1]) ) - print(msg) + warnings.warn(msg) # get a transform instance that mpl understands transform = ccrs.PlateCarree()._as_mpl_transform(ax) @@ -373,7 +380,7 @@ def xticklabels( y = _determine_intersection(boundary_pc, [x, -90], [x, 90]) if y.size > 0: - y = y[0, 1] + y = y[:, 1].min() ax.annotate( msg, @@ -386,7 +393,7 @@ def xticklabels( xytext=(0, -labelpad), textcoords="offset points", bbox=bbox_props, - **kwargs + **kwargs, ) @@ -394,17 +401,34 @@ def _get_boundary_platecarree(ax): # get the bounding box of the map in lat/ lon coordinates # after ax._get_extent_geom proj = ccrs.PlateCarree() - boundary_poly = sgeom.Polygon(ax.outline_patch.get_path().vertices) + boundary_poly = shapely.geometry.Polygon(ax.spines["geo"].get_path().vertices) eroded_boundary = boundary_poly.buffer(-ax.projection.threshold / 100) boundary_pc = proj.project_geometry(eroded_boundary, ax.projection) + # boundary_pc = proj.project_geometry(boundary_poly, ax.projection) + return boundary_pc def _determine_intersection(polygon, xy1, xy2): - p1 = sgeom.Point(xy1) - p2 = sgeom.Point(xy2) - ls = sgeom.LineString([p1, p2]) + p1 = shapely.geometry.Point(xy1) + p2 = shapely.geometry.Point(xy2) + ls = shapely.geometry.LineString([p1, p2]) + + intersection = polygon.boundary.intersection(ls) + + if isinstance(intersection, shapely.geometry.MultiPoint): + arr = np.array([x.coords for x in intersection.geoms]).squeeze() + elif isinstance(intersection, shapely.geometry.Point): + arr = np.array([intersection.coords]).squeeze() + arr = np.atleast_2d(arr) + elif isinstance(intersection, shapely.geometry.LineString): + if intersection.is_empty: + return np.array([]) + else: + return np.array(intersection.coords) + else: + raise TypeError(f"Unexpected type: {type(intersection)}") - return np.asarray(polygon.boundary.intersection(ls)) + return arr diff --git a/mplotutils/tests/test_mapticklabels.py b/mplotutils/tests/test_mapticklabels.py new file mode 100644 index 0000000..c7f946f --- /dev/null +++ b/mplotutils/tests/test_mapticklabels.py @@ -0,0 +1,111 @@ +import sys + +import cartopy.crs as ccrs +import numpy as np + +import mplotutils as mpu + +from . import subplots_context + + +def test_yticklabels_robinson(): + + with subplots_context(subplot_kw=dict(projection=ccrs.Robinson())) as (f, ax): + + ax.set_global() + + lat = np.arange(-90, 91, 20) + + mpu.yticklabels(lat, ax=ax, size=8) + + x_pos = -179.99 + + # two elements are not added because they are beyond the map limits + lat = lat[1:-1] + + # remove when dropping py 3.9 + strict = {"strict": True} if sys.version_info >= (3, 10) else {} + for t, y_pos in zip(ax.texts, lat, **strict): + + np.testing.assert_allclose((x_pos, y_pos), t.xy, atol=0.01) + + assert ax.texts[0].get_text() == "70°S" + assert ax.texts[-1].get_text() == "70°N" + + +def test_yticklabels_robinson_180(): + + proj = ccrs.Robinson(central_longitude=180) + with subplots_context(subplot_kw=dict(projection=proj)) as (f, ax): + + ax.set_global() + + lat = np.arange(-90, 91, 20) + + mpu.yticklabels(lat, ax=ax, size=8) + + x_pos = 0.0 + + # two elements are not added because they are beyond the map limits + lat = lat[1:-1] + + # remove when dropping py 3.9 + strict = {"strict": True} if sys.version_info >= (3, 10) else {} + for t, y_pos in zip(ax.texts, lat, **strict): + + np.testing.assert_allclose((x_pos, y_pos), t.xy, atol=0.01) + + assert ax.texts[0].get_text() == "70°S" + assert ax.texts[-1].get_text() == "70°N" + + +def test_xticklabels_robinson(): + + with subplots_context(subplot_kw=dict(projection=ccrs.Robinson())) as (f, ax): + + ax.set_global() + + lon = np.arange(-180, 181, 60) + + mpu.xticklabels(lon, ax=ax, size=8) + + y_pos = -89.99 + + # two elements are not added because they are beyond the map limits + lon = lon[1:-1] + + # remove when dropping py 3.9 + strict = {"strict": True} if sys.version_info >= (3, 10) else {} + + for t, x_pos in zip(ax.texts, lon, **strict): + + np.testing.assert_allclose((x_pos, y_pos), t.xy, atol=0.01) + + assert ax.texts[0].get_text() == "120°W" + assert ax.texts[-1].get_text() == "120°E" + + +# TODO: https://github.com/mathause/mplotutils/issues/48 +# def test_xticklabels_robinson_180(): + +# proj = ccrs.Robinson(central_longitude=180) +# with subplots_context(subplot_kw=dict(projection=proj)) as (f, ax): + +# ax.set_global() + +# # lon = np.arange(-180, 181, 60) +# lon = np.arange(0, 360, 60) + + +# mpu.xticklabels(lon, ax=ax, size=8) + +# y_pos = -89.99 + +# # two elements are not added because they are beyond the map limits +# lon = lon[1:-1] +# for t, x_pos in zip(ax.texts, lon, strict=True): + +# np.testing.assert_allclose((x_pos, y_pos), t.xy, atol=0.01) + +# assert ax.texts[0].get_text() == "60°E" +# assert ax.texts[-1].get_text() == "60°W"