Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 45 additions & 21 deletions mplotutils/cartopy_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import warnings

import cartopy.crs as ccrs
import matplotlib.pyplot as plt
import numpy as np
import shapely.geometry as sgeom
import shapely.geometry
from cartopy.mpl.gridliner import LATITUDE_FORMATTER, LONGITUDE_FORMATTER

from .colormaps import _get_label_attr
Expand Down Expand Up @@ -123,7 +125,7 @@ def ylabel_map(s, labelpad=None, size=None, weight=None, y=0.5, ax=None, **kwarg
rotation_mode=rotation_mode,
size=size,
weight=weight,
**kwargs
**kwargs,
)

return h
Expand Down Expand Up @@ -188,7 +190,7 @@ def xlabel_map(s, labelpad=None, size=None, weight=None, x=0.5, ax=None, **kwarg
rotation_mode=rotation_mode,
size=size,
weight=weight,
**kwargs
**kwargs,
)

return h
Expand All @@ -206,7 +208,7 @@ def yticklabels(
ha="right",
va="center",
bbox_props=dict(ec="none", fc="none"),
**kwargs
**kwargs,
):

"""
Expand Down Expand Up @@ -238,18 +240,18 @@ def yticklabels(

"""

plt.draw()

# get ax if necessary
if ax is None:
ax = plt.gca()

ax.figure.canvas.draw()

labelpad, size, weight = _get_label_attr(labelpad, size, weight)

boundary_pc = _get_boundary_platecarree(ax)

# ensure labels are on rhs and not in the middle
if len(boundary_pc) == 1:
if len(boundary_pc.geoms) == 1:
lonmin, lonmax = -180, 180
else:
lonmin, lonmax = 0, 360
Expand All @@ -265,7 +267,7 @@ def yticklabels(
"WARN: no points found for ylabel\n"
"y_lim is: {:0.2f} to {:0.2f}".format(y_lim[0], y_lim[1])
)
print(msg)
warnings.warn(msg)

# get a transform instance that mpl understands
transform = ccrs.PlateCarree()._as_mpl_transform(ax)
Expand All @@ -281,7 +283,7 @@ def yticklabels(
x = _determine_intersection(boundary_pc, [lonmin, y], [lonmax, y])

if x.size > 0:
x = x[0, 0]
x = x[:, 0].min()
lp = labelpad[0] + labelpad[1] * np.abs(y) / 90

ax.annotate(
Expand All @@ -295,7 +297,7 @@ def yticklabels(
xytext=(-lp, 0),
textcoords="offset points",
bbox=bbox_props,
**kwargs
**kwargs,
)


Expand All @@ -308,7 +310,7 @@ def xticklabels(
ha="center",
va="top",
bbox_props=dict(ec="none", fc="none"),
**kwargs
**kwargs,
):

"""
Expand Down Expand Up @@ -340,12 +342,17 @@ def xticklabels(

"""

plt.draw()

# get ax if necessary
if ax is None:
ax = plt.gca()

ax.figure.canvas.draw()

# proj = ccrs.PlateCarree()
# points = shapely.geometry.MultiPoint([shapely.geometry.Point(x, 0) for x in x_ticks])
# points = proj.project_geometry(points, proj)
# x_ticks = [x.x for x in points.geoms]

labelpad, size, weight = _get_label_attr(labelpad, size, weight)

boundary_pc = _get_boundary_platecarree(ax)
Expand All @@ -361,7 +368,7 @@ def xticklabels(
"WARN: no points found for xlabel\n"
"x_lim is: {:0.2f} to {:0.2f}".format(x_lim[0], x_lim[1])
)
print(msg)
warnings.warn(msg)

# get a transform instance that mpl understands
transform = ccrs.PlateCarree()._as_mpl_transform(ax)
Expand All @@ -373,7 +380,7 @@ def xticklabels(

y = _determine_intersection(boundary_pc, [x, -90], [x, 90])
if y.size > 0:
y = y[0, 1]
y = y[:, 1].min()

ax.annotate(
msg,
Expand All @@ -386,25 +393,42 @@ def xticklabels(
xytext=(0, -labelpad),
textcoords="offset points",
bbox=bbox_props,
**kwargs
**kwargs,
)


def _get_boundary_platecarree(ax):
# get the bounding box of the map in lat/ lon coordinates
# after ax._get_extent_geom
proj = ccrs.PlateCarree()
boundary_poly = sgeom.Polygon(ax.outline_patch.get_path().vertices)
boundary_poly = shapely.geometry.Polygon(ax.spines["geo"].get_path().vertices)
eroded_boundary = boundary_poly.buffer(-ax.projection.threshold / 100)
boundary_pc = proj.project_geometry(eroded_boundary, ax.projection)

# boundary_pc = proj.project_geometry(boundary_poly, ax.projection)

return boundary_pc


def _determine_intersection(polygon, xy1, xy2):

p1 = sgeom.Point(xy1)
p2 = sgeom.Point(xy2)
ls = sgeom.LineString([p1, p2])
p1 = shapely.geometry.Point(xy1)
p2 = shapely.geometry.Point(xy2)
ls = shapely.geometry.LineString([p1, p2])

intersection = polygon.boundary.intersection(ls)

if isinstance(intersection, shapely.geometry.MultiPoint):
arr = np.array([x.coords for x in intersection.geoms]).squeeze()
elif isinstance(intersection, shapely.geometry.Point):
arr = np.array([intersection.coords]).squeeze()
arr = np.atleast_2d(arr)
elif isinstance(intersection, shapely.geometry.LineString):
if intersection.is_empty:
return np.array([])
else:
return np.array(intersection.coords)
else:
raise TypeError(f"Unexpected type: {type(intersection)}")

return np.asarray(polygon.boundary.intersection(ls))
return arr
111 changes: 111 additions & 0 deletions mplotutils/tests/test_mapticklabels.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
import sys

import cartopy.crs as ccrs
import numpy as np

import mplotutils as mpu

from . import subplots_context


def test_yticklabels_robinson():

with subplots_context(subplot_kw=dict(projection=ccrs.Robinson())) as (f, ax):

ax.set_global()

lat = np.arange(-90, 91, 20)

mpu.yticklabels(lat, ax=ax, size=8)

x_pos = -179.99

# two elements are not added because they are beyond the map limits
lat = lat[1:-1]

# remove when dropping py 3.9
strict = {"strict": True} if sys.version_info >= (3, 10) else {}
for t, y_pos in zip(ax.texts, lat, **strict):

np.testing.assert_allclose((x_pos, y_pos), t.xy, atol=0.01)

assert ax.texts[0].get_text() == "70°S"
assert ax.texts[-1].get_text() == "70°N"


def test_yticklabels_robinson_180():

proj = ccrs.Robinson(central_longitude=180)
with subplots_context(subplot_kw=dict(projection=proj)) as (f, ax):

ax.set_global()

lat = np.arange(-90, 91, 20)

mpu.yticklabels(lat, ax=ax, size=8)

x_pos = 0.0

# two elements are not added because they are beyond the map limits
lat = lat[1:-1]

# remove when dropping py 3.9
strict = {"strict": True} if sys.version_info >= (3, 10) else {}
for t, y_pos in zip(ax.texts, lat, **strict):

np.testing.assert_allclose((x_pos, y_pos), t.xy, atol=0.01)

assert ax.texts[0].get_text() == "70°S"
assert ax.texts[-1].get_text() == "70°N"


def test_xticklabels_robinson():

with subplots_context(subplot_kw=dict(projection=ccrs.Robinson())) as (f, ax):

ax.set_global()

lon = np.arange(-180, 181, 60)

mpu.xticklabels(lon, ax=ax, size=8)

y_pos = -89.99

# two elements are not added because they are beyond the map limits
lon = lon[1:-1]

# remove when dropping py 3.9
strict = {"strict": True} if sys.version_info >= (3, 10) else {}

for t, x_pos in zip(ax.texts, lon, **strict):

np.testing.assert_allclose((x_pos, y_pos), t.xy, atol=0.01)

assert ax.texts[0].get_text() == "120°W"
assert ax.texts[-1].get_text() == "120°E"


# TODO: https://github.com/mathause/mplotutils/issues/48
# def test_xticklabels_robinson_180():

# proj = ccrs.Robinson(central_longitude=180)
# with subplots_context(subplot_kw=dict(projection=proj)) as (f, ax):

# ax.set_global()

# # lon = np.arange(-180, 181, 60)
# lon = np.arange(0, 360, 60)


# mpu.xticklabels(lon, ax=ax, size=8)

# y_pos = -89.99

# # two elements are not added because they are beyond the map limits
# lon = lon[1:-1]
# for t, x_pos in zip(ax.texts, lon, strict=True):

# np.testing.assert_allclose((x_pos, y_pos), t.xy, atol=0.01)

# assert ax.texts[0].get_text() == "60°E"
# assert ax.texts[-1].get_text() == "60°W"