Skip to content

Commit

Permalink
refactor(PlotMapView): support color kwarg in plot_endpoint() (#1745)
Browse files Browse the repository at this point in the history
  • Loading branch information
wpbonelli committed Mar 21, 2023
1 parent 656751a commit c307420
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 18 deletions.
75 changes: 65 additions & 10 deletions autotest/test_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
)
from flopy.modpath import Modpath6, Modpath6Bas
from flopy.plot import PlotCrossSection, PlotMapView
from flopy.utils import CellBudgetFile, HeadFile, PathlineFile
from flopy.utils import CellBudgetFile, EndpointFile, HeadFile, PathlineFile


def test_map_view():
Expand Down Expand Up @@ -355,9 +355,8 @@ def test_model_dot_plot_export(function_tmpdir, example_data_path):
raise AssertionError("Plot filenames not written correctly")


@requires_pkg("pandas")
@requires_exe("mf2005")
def test_pathline_plot_xc(function_tmpdir, example_data_path):
@pytest.fixture
def modpath_model(function_tmpdir, example_data_path):
# test with multi-layer example
load_ws = example_data_path / "mp6"

Expand All @@ -383,11 +382,18 @@ def test_pathline_plot_xc(function_tmpdir, example_data_path):
packages="RCH",
start_time=(2, 0, 1.0),
)
mp.write_input()
return ml, mp, sim


@requires_pkg("pandas")
@requires_exe("mf2005", "mp6")
def test_xc_plot_particle_pathlines(modpath_model):
ml, mp, sim = modpath_model

mp.write_input()
mp.run_model(silent=False)

pthobj = PathlineFile(os.path.join(function_tmpdir, "ex6.mppth"))
pthobj = PathlineFile(os.path.join(mp.model_ws, "ex6.mppth"))
well_pathlines = pthobj.get_destination_pathline_data(
dest_cells=[(4, 12, 12)]
)
Expand All @@ -396,11 +402,60 @@ def test_pathline_plot_xc(function_tmpdir, example_data_path):
mx.plot_bc("WEL", kper=2, color="blue")
pth = mx.plot_pathline(well_pathlines, method="cell", colors="red")

if not isinstance(pth, LineCollection):
raise AssertionError()
assert isinstance(pth, LineCollection)
assert len(pth._paths) == 6

if len(pth._paths) != 6:
raise AssertionError()

@requires_pkg("pandas")
@requires_exe("mf2005", "mp6")
def test_map_plot_particle_endpoints(modpath_model):
ml, mp, sim = modpath_model
mp.write_input()
mp.run_model(silent=False)

pthobj = EndpointFile(os.path.join(mp.model_ws, "ex6.mpend"))
endpts = pthobj.get_alldata()

# color kwarg as scalar
mv = PlotMapView(model=ml)
mv.plot_bc("WEL", kper=2, color="blue")
ep = mv.plot_endpoint(endpts, direction="ending", color="red")
# plt.show()
assert isinstance(ep, PathCollection)

# c kwarg as array
mv = PlotMapView(model=ml)
mv.plot_bc("WEL", kper=2, color="blue")
ep = mv.plot_endpoint(
endpts,
direction="ending",
c=np.random.rand(625) * -1000,
cmap="viridis",
)
# plt.show()
assert isinstance(ep, PathCollection)

# colorbar: color by time to termination
mv = PlotMapView(model=ml)
mv.plot_bc("WEL", kper=2, color="blue")
ep = mv.plot_endpoint(
endpts, direction="ending", shrink=0.5, colorbar=True
)
# plt.show()
assert isinstance(ep, PathCollection)

# if both color and c are provided, c takes precedence
mv = PlotMapView(model=ml)
mv.plot_bc("WEL", kper=2, color="blue")
ep = mv.plot_endpoint(
endpts,
direction="ending",
color="red",
c=np.random.rand(625) * -1000,
cmap="viridis",
)
# plt.show()
assert isinstance(ep, PathCollection)


@pytest.fixture
Expand Down
16 changes: 8 additions & 8 deletions flopy/plot/map.py
Original file line number Diff line number Diff line change
Expand Up @@ -873,17 +873,11 @@ def plot_endpoint(
"""

ax = kwargs.pop("ax", self.ax)

tep, _, xp, yp = plotutil.parse_modpath_selection_options(
ep, direction, selection, selection_direction
)
# scatter kwargs that users may redefine
if "c" not in kwargs:
c = tep["time"] - tep["time0"]
else:
c = np.empty((tep.shape[0]), dtype="S30")
c.fill(kwargs.pop("c"))

# marker size
s = kwargs.pop("s", np.sqrt(50))
s = float(kwargs.pop("size", s)) ** 2.0

Expand All @@ -904,7 +898,13 @@ def plot_endpoint(
arr = np.vstack((x0r, y0r)).T

# plot the end point data
sp = ax.scatter(arr[:, 0], arr[:, 1], c=c, s=s, **kwargs)
if "c" in kwargs or "color" in kwargs:
if "c" in kwargs and "color" in kwargs:
kwargs.pop("color")
sp = ax.scatter(arr[:, 0], arr[:, 1], s=s, **kwargs)
else:
c = tep["time"] - tep["time0"]
sp = ax.scatter(arr[:, 0], arr[:, 1], c=c, s=s, **kwargs)

# add a colorbar for travel times
if createcb:
Expand Down

0 comments on commit c307420

Please sign in to comment.