Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions autotest/t007_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1191,8 +1191,8 @@ def test_sr_with_Map():
plt.close()

def check_vertices():
xllp, yllp = pc._paths[780].vertices[3]
xulp, yulp = pc._paths[0].vertices[0]
xllp, yllp = pc._paths[0].vertices[0]
xulp, yulp = pc._paths[0].vertices[1]
assert np.abs(xllp - xll) < 1e-6
assert np.abs(yllp - yll) < 1e-6
assert np.abs(xulp - xul) < 1e-6
Expand Down Expand Up @@ -1304,7 +1304,7 @@ def test_modelgrid_with_PlotMapView():
xll, yll, rotation = 500000.0, 2934000.0, 45.0

def check_vertices():
xllp, yllp = pc._paths[780].vertices[3]
xllp, yllp = pc._paths[0].vertices[0]
assert np.abs(xllp - xll) < 1e-6
assert np.abs(yllp - yll) < 1e-6

Expand Down Expand Up @@ -1334,7 +1334,7 @@ def check_vertices():


def test_mapview_plot_bc():
from matplotlib.collections import QuadMesh, PatchCollection
from matplotlib.collections import QuadMesh, PathCollection
import matplotlib.pyplot as plt

sim_name = "mfsim.nam"
Expand All @@ -1352,7 +1352,7 @@ def test_mapview_plot_bc():
raise AssertionError("Boundary condition was not drawn")

for col in ax.collections:
if not isinstance(col, PatchCollection):
if not isinstance(col, (QuadMesh, PathCollection)):
raise AssertionError("Unexpected collection type")
plt.close()

Expand All @@ -1370,7 +1370,7 @@ def test_mapview_plot_bc():
raise AssertionError("Boundary condition was not drawn")

for col in ax.collections:
if not isinstance(col, PatchCollection):
if not isinstance(col, (QuadMesh, PathCollection)):
raise AssertionError("Unexpected collection type")
plt.close()

Expand All @@ -1395,7 +1395,7 @@ def test_mapview_plot_bc():
raise AssertionError("Boundary condition was not drawn")

for col in ax.collections:
if not isinstance(col, PatchCollection):
if not isinstance(col, (QuadMesh, PathCollection)):
raise AssertionError("Unexpected collection type")
plt.close()

Expand All @@ -1413,7 +1413,7 @@ def test_mapview_plot_bc():
raise AssertionError("Boundary condition was not drawn")

for col in ax.collections:
if not isinstance(col, PatchCollection):
if not isinstance(col, (QuadMesh, PathCollection)):
raise AssertionError("Unexpected collection type")
plt.close()

Expand Down
29 changes: 2 additions & 27 deletions flopy/discretization/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,32 +493,7 @@ def cross_section_set_contour_arrays(

@property
def map_polygons(self):
"""
Get a list of matplotlib Polygon patches for plotting

Returns
-------
list of Polygon objects
"""
try:
from matplotlib.patches import Polygon
except ImportError:
raise ImportError("matplotlib required to use this method")
cache_index = "xyzgrid"
if (
cache_index not in self._cache_dict
or self._cache_dict[cache_index].out_of_date
):
self.xyzvertices
self._polygons = None

if self._polygons is None:
self._polygons = [
Polygon(self.get_cell_vertices(nn), closed=True)
for nn in range(self.ncpl)
]

return copy.copy(self._polygons)
raise NotImplementedError("must define map_polygons in child class")

def get_plottable_layer_array(self, plotarray, layer):
raise NotImplementedError(
Expand Down Expand Up @@ -558,7 +533,7 @@ def get_coords(self, x, y):
x = np.array(x)
y = np.array(y)
if not np.isscalar(x):
x, y = x.copy(), y.copy()
x, y = x.astype(float, copy=True), y.astype(float, copy=True)

x += self._xoff
y += self._yoff
Expand Down
26 changes: 26 additions & 0 deletions flopy/discretization/structuredgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -726,6 +726,32 @@ def is_rectilinear(self):
else:
return self._cache_dict[cache_index].data_nocopy

@property
def map_polygons(self):
"""
Get a list of matplotlib Polygon patches for plotting

Returns
-------
list of Polygon objects
"""
try:
import matplotlib.path as mpath
except ImportError:
raise ImportError("matplotlib required to use this method")
cache_index = "xyzgrid"
if (
cache_index not in self._cache_dict
or self._cache_dict[cache_index].out_of_date
):
self.xyzvertices
self._polygons = None

if self._polygons is None:
self._polygons = (self.xvertices, self.yvertices)

return self._polygons

###############
### Methods ###
###############
Expand Down
6 changes: 3 additions & 3 deletions flopy/discretization/unstructuredgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,7 @@ def map_polygons(self):
list or dict of matplotlib.collections.Polygon
"""
try:
from matplotlib.patches import Polygon
from matplotlib.path import Path
except ImportError:
raise ImportError("matplotlib required to use this method")

Expand All @@ -429,11 +429,11 @@ def map_polygons(self):
if ilay not in self._polygons:
self._polygons[ilay] = []

p = Polygon(self.get_cell_vertices(nn), closed=True)
p = Path(self.get_cell_vertices(nn))
self._polygons[ilay].append(p)
else:
self._polygons = [
Polygon(self.get_cell_vertices(nn), closed=True)
Path(self.get_cell_vertices(nn))
for nn in range(self.ncpl[0])
]

Expand Down
28 changes: 28 additions & 0 deletions flopy/discretization/vertexgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,34 @@ def xyzvertices(self):
else:
return self._cache_dict[cache_index].data_nocopy

@property
def map_polygons(self):
"""
Get a list of matplotlib Polygon patches for plotting

Returns
-------
list of Polygon objects
"""
try:
import matplotlib.path as mpath
except ImportError:
raise ImportError("matplotlib required to use this method")
cache_index = "xyzgrid"
if (
cache_index not in self._cache_dict
or self._cache_dict[cache_index].out_of_date
):
self.xyzvertices
self._polygons = None
if self._polygons is None:
self._polygons = [
mpath.Path(self.get_cell_vertices(nn))
for nn in range(self.ncpl)
]

return copy.copy(self._polygons)

def intersect(self, x, y, local=False, forgive=False):
"""
Get the CELL2D number of a point with coordinates x and y
Expand Down
54 changes: 25 additions & 29 deletions flopy/plot/map.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
try:
import matplotlib.pyplot as plt
import matplotlib.colors
from matplotlib.collections import PatchCollection
from matplotlib.patches import Polygon
from matplotlib.collections import PathCollection, LineCollection
from matplotlib.path import Path
except (ImportError, ModuleNotFoundError):
plt = None

Expand Down Expand Up @@ -112,7 +112,6 @@ def plot_array(self, a, masked_values=None, **kwargs):

# Use the model grid to pass back an array of the correct shape
plotarray = self.mg.get_plottable_layer_array(a, self.layer)
plotarray = plotarray.ravel()

# if masked_values are provided mask the plotting array
if masked_values is not None:
Expand All @@ -129,20 +128,26 @@ def plot_array(self, a, masked_values=None, **kwargs):
if isinstance(polygons, dict):
polygons = polygons[self.layer]

collection = PatchCollection(polygons)
collection.set_array(plotarray)
if len(polygons) == 0:
return

if not isinstance(polygons[0], Path):
collection = ax.pcolormesh(
self.mg.xvertices, self.mg.yvertices, plotarray
)

else:
plotarray = plotarray.ravel()
collection = PathCollection(polygons)
collection.set_array(plotarray)

# set max and min
vmin = kwargs.pop("vmin", None)
vmax = kwargs.pop("vmax", None)

# limit the color range
# set matplotlib kwargs
collection.set_clim(vmin=vmin, vmax=vmax)

# send rest of kwargs to quadmesh
collection.set(**kwargs)

# add collection to axis
ax.add_collection(collection)

# set limits
Expand Down Expand Up @@ -362,30 +367,21 @@ def plot_grid(self, **kwargs):
from matplotlib.collections import PatchCollection

ax = kwargs.pop("ax", self.ax)
edgecolor = kwargs.pop("colors", "grey")
edgecolor = kwargs.pop("color", edgecolor)
edgecolor = kwargs.pop("ec", edgecolor)
edgecolor = kwargs.pop("edgecolor", edgecolor)
facecolor = kwargs.pop("facecolor", "none")
facecolor = kwargs.pop("fc", facecolor)
colors = kwargs.pop("colors", "grey")
colors = kwargs.pop("color", colors)
colors = kwargs.pop("ec", colors)
colors = kwargs.pop("edgecolor", colors)

# use cached patch collection for plotting
polygons = self.mg.map_polygons
if isinstance(polygons, dict):
polygons = polygons[self.layer]
grid_lines = self.mg.grid_lines
if isinstance(grid_lines, dict):
grid_lines = grid_lines[self.layer]

if len(polygons) > 0:
patches = PatchCollection(
polygons, edgecolor=edgecolor, facecolor=facecolor, **kwargs
)
else:
patches = None
collection = LineCollection(grid_lines, colors=colors, **kwargs)

ax.add_collection(patches)
ax.add_collection(collection)
ax.set_xlim(self.extent[0], self.extent[1])
ax.set_ylim(self.extent[2], self.extent[3])

return patches
return collection

def plot_bc(
self,
Expand Down