# FloPy - Voronoi grid model for VTK export

In [None]:
import os
import sys
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import matplotlib.ticker as mticker
from matplotlib import colors
import shapely
from shapely.geometry import Polygon, LineString
import flopy
from flopy.discretization import StructuredGrid, VertexGrid
from flopy.utils.triangle import Triangle
from flopy.utils.voronoi import VoronoiGrid
from flopy.utils.gridgen import Gridgen
import flopy.plot.styles as styles

In [None]:
# Load the fine-resolution topography raster (presumably an ESRI ASCII grid,
# judging by the .asc extension); it is resampled onto the Voronoi grid later
# to provide the model top.
fine_topo = flopy.utils.Raster.load("./grid_data/fine_topo.asc")

In [None]:
# Quick-look plot of the raster; the trailing ";" suppresses display of the
# return value in the notebook.
fine_topo.plot();

In [None]:
# Model-domain and plotting constants (lengths in m; the map axes below
# label x/1000 as km).
Lx = 180000  # domain length in x
Ly = 100000  # domain length in y
nlay = 5  # number of model layers
dv0 = 5.0  # top-layer thickness; each deeper layer is 1.5x the one above
extent = (0, Lx, 0, Ly)  # (xmin, xmax, ymin, ymax)
levels = np.arange(10, 110, 10)  # contour levels 10, 20, ..., 100
vmin, vmax = 0.0, 100.0  # color limits for head plots

In [None]:
# Scratch workspace for the Triangle mesh-generation files (model_ws below).
temp_path = "./temp"
# makedirs(exist_ok=True) replaces the check-then-create pattern: no race
# between the existence test and the creation, and missing parents are made.
os.makedirs(temp_path, exist_ok=True)

# Basin Example

In [None]:
# Basin outline: one "x y" vertex per line (single-space separated, parsed
# by string2geom). The first vertex is not repeated at the end, i.e. the
# ring is not explicitly closed.
boundary = """1.868012422360248456e+05 4.695652173913043953e+04
1.790372670807453396e+05 5.204968944099379587e+04
1.729813664596273447e+05 5.590062111801243009e+04
1.672360248447204940e+05 5.987577639751553215e+04
1.631987577639751253e+05 6.335403726708075556e+04
1.563664596273291972e+05 6.819875776397516893e+04
1.509316770186335489e+05 7.229813664596274612e+04
1.453416149068323139e+05 7.527950310559007630e+04
1.395962732919254631e+05 7.627329192546584818e+04
1.357142857142857101e+05 7.664596273291927355e+04
1.329192546583850926e+05 7.751552795031057030e+04
1.268633540372670832e+05 8.062111801242237561e+04
1.218944099378881947e+05 8.285714285714286962e+04
1.145962732919254486e+05 8.571428571428572468e+04
1.069875776397515583e+05 8.869565217391305487e+04
1.023291925465838431e+05 8.931677018633540138e+04
9.456521739130433707e+04 9.068322981366459862e+04
8.804347826086955320e+04 9.080745341614908830e+04
7.950310559006211406e+04 9.267080745341615693e+04
7.562111801242236106e+04 9.391304347826087906e+04
6.692546583850930620e+04 9.602484472049689793e+04
5.667701863354037778e+04 9.763975155279504543e+04
4.906832298136646568e+04 9.689440993788820924e+04
3.897515527950309479e+04 9.540372670807455142e+04
3.167701863354036323e+04 9.304347826086958230e+04
2.375776397515527788e+04 8.757763975155279331e+04
1.847826086956521613e+04 8.161490683229814749e+04
1.164596273291925172e+04 7.739130434782608063e+04
6.211180124223596977e+03 7.055900621118013805e+04
4.347826086956512881e+03 6.422360248447205959e+04
1.863354037267072272e+03 6.037267080745341809e+04
2.639751552795024509e+03 5.602484472049689793e+04
1.552795031055893560e+03 5.279503105590062478e+04
7.763975155279410956e+02 4.186335403726709046e+04
2.018633540372667312e+03 3.813664596273292409e+04
6.055900621118013078e+03 3.341614906832297856e+04
1.335403726708074100e+04 2.782608695652173992e+04
2.577639751552794405e+04 2.086956521739130767e+04
3.416149068322980747e+04 1.763975155279503815e+04
4.642857142857142753e+04 1.440993788819875044e+04
5.636645962732918997e+04 1.130434782608694877e+04
6.459627329192546313e+04 9.813664596273290954e+03
8.555900621118012350e+04 6.832298136645956220e+03
9.829192546583850344e+04 5.093167701863346338e+03
1.085403726708074391e+05 4.347826086956525614e+03
1.200310559006211115e+05 4.223602484472040487e+03
1.296583850931677007e+05 4.347826086956525614e+03
1.354037267080745369e+05 5.590062111801232277e+03
1.467391304347825935e+05 1.267080745341615875e+04
1.563664596273291972e+05 1.937888198757762802e+04
1.630434782608695677e+05 2.198757763975155467e+04
1.694099378881987650e+05 2.434782608695652743e+04
1.782608695652173774e+05 2.981366459627329095e+04
1.833850931677018234e+05 3.180124223602484562e+04
1.868012422360248456e+05 3.577639751552795497e+04"""

# Stream segment 1: its first vertex lies on the basin's eastern boundary
# edge (x = 1.868e5); same "x y" per-line format as the boundary.
streamseg1 = """1.868012422360248456e+05 4.086956521739130403e+04
1.824534161490683327e+05 4.086956521739130403e+04
1.770186335403726553e+05 4.124223602484472940e+04
1.737577639751552779e+05 4.186335403726709046e+04
1.703416149068323139e+05 4.310559006211180531e+04
1.670807453416148783e+05 4.397515527950310934e+04
1.636645962732919143e+05 4.484472049689441337e+04
1.590062111801242281e+05 4.559006211180124228e+04
1.555900621118012350e+05 4.559006211180124228e+04
1.510869565217391064e+05 4.546583850931677443e+04
1.479813664596273156e+05 4.534161490683229931e+04
1.453416149068323139e+05 4.496894409937888850e+04
1.377329192546583654e+05 4.447204968944099528e+04
1.326086956521739194e+05 4.447204968944099528e+04
1.285714285714285652e+05 4.434782608695652743e+04
1.245341614906832110e+05 4.472049689440993825e+04
1.215838509316770069e+05 4.509316770186335634e+04
1.161490683229813585e+05 4.509316770186335634e+04
1.125776397515527933e+05 4.459627329192547040e+04
1.074534161490683036e+05 4.385093167701864149e+04
1.018633540372670686e+05 4.347826086956522340e+04
9.798136645962731563e+04 4.360248447204969125e+04
9.223602484472049400e+04 4.310559006211180531e+04
8.602484472049689793e+04 4.198757763975155831e+04
7.981366459627327276e+04 4.173913043478261534e+04
7.468944099378881219e+04 4.248447204968944425e+04
7.034161490683228476e+04 4.385093167701864149e+04
6.785714285714285506e+04 4.621118012422360334e+04
6.583850931677018525e+04 4.919254658385094081e+04
6.319875776397513982e+04 5.192546583850932075e+04
6.009316770186335634e+04 5.677018633540373412e+04
5.605590062111800216e+04 5.950310559006211406e+04
5.279503105590060295e+04 6.124223602484472940e+04
4.751552795031056303e+04 6.211180124223603343e+04
3.990683229813664366e+04 6.335403726708075556e+04
3.276397515527949508e+04 6.409937888198757719e+04
2.934782608695651652e+04 6.509316770186336362e+04
2.546583850931676716e+04 6.832298136645962950e+04"""

# Segments 2-4 presumably are tributaries branching off segment 1 (each
# first vertex lies close to a segment-1 vertex) — confirm.
streamseg2 = """6.972049689440995280e+04 4.347826086956522340e+04
6.816770186335404287e+04 4.273291925465839449e+04
6.490683229813665093e+04 4.211180124223603343e+04
6.164596273291925900e+04 4.173913043478262261e+04
5.776397515527951327e+04 4.124223602484472940e+04
5.450310559006211406e+04 4.049689440993789322e+04
4.984472049689442065e+04 3.937888198757764621e+04
4.534161490683231386e+04 3.801242236024845624e+04
4.114906832298137306e+04 3.664596273291926627e+04
3.913043478260868869e+04 3.565217391304348712e+04
3.649068322981366509e+04 3.416149068322981475e+04
3.322981366459628043e+04 3.242236024844721760e+04
3.012422360248447148e+04 3.105590062111801672e+04
2.608695652173913550e+04 2.956521739130435890e+04"""

streamseg3 = """1.059006211180124228e+05 4.335403726708074828e+04
1.029503105590062187e+05 4.223602484472050128e+04
1.004658385093167890e+05 4.024844720496894297e+04
9.937888198757765349e+04 3.788819875776398112e+04
9.627329192546584818e+04 3.490683229813664366e+04
9.285714285714286962e+04 3.316770186335403559e+04
8.897515527950311662e+04 3.093167701863354159e+04
8.338509316770188161e+04 2.795031055900621504e+04
7.872670807453416637e+04 2.670807453416148928e+04
7.329192546583851799e+04 2.385093167701863058e+04
6.863354037267081731e+04 2.111801242236025064e+04
6.304347826086958230e+04 1.863354037267081003e+04"""

streamseg4 = """1.371118012422360480e+05 4.472049689440994553e+04
1.321428571428571595e+05 4.720496894409938250e+04
1.285714285714285652e+05 4.981366459627330187e+04
1.243788819875776535e+05 5.341614906832298584e+04
1.189440993788819906e+05 5.540372670807454415e+04
1.125776397515527933e+05 5.627329192546584818e+04
1.065217391304347839e+05 5.726708074534162733e+04
1.020186335403726698e+05 5.913043478260870324e+04
9.409937888198759174e+04 6.273291925465840177e+04
9.192546583850932075e+04 6.633540372670808574e+04
8.881987577639751544e+04 7.242236024844722124e+04
8.586956521739131131e+04 7.552795031055902655e+04
8.369565217391305487e+04 7.962732919254660374e+04"""


def string2geom(geostring):
    """Parse newline-separated ``"x y"`` coordinate pairs into a vertex list.

    Parameters
    ----------
    geostring : str
        Text with one vertex per line; the first two whitespace-separated
        fields on each line are the x and y coordinates. Blank lines
        (including a trailing newline) are ignored.

    Returns
    -------
    list of tuple of float
        ``[(x0, y0), (x1, y1), ...]`` in input order.
    """
    res = []
    for line in geostring.splitlines():
        # split() with no argument tolerates repeated spaces, tabs and
        # leading/trailing whitespace, unlike split(" ").
        fields = line.split()
        if not fields:
            continue
        res.append((float(fields[0]), float(fields[1])))
    return res


# Parse the basin outline and the four stream segments into vertex lists.
boundary_polygon = string2geom(boundary)
print("len boundary", len(boundary_polygon))
bp = np.array(boundary_polygon)

sgs = [
    string2geom(sg) for sg in (streamseg1, streamseg2, streamseg3, streamseg4)
]

# Distinct name for the per-segment plot colors: rebinding `colors` would
# shadow the `matplotlib.colors` module imported at the top of the notebook.
segment_colors = ("blue", "cyan", "green", "magenta")
fig = plt.figure(figsize=(15, 10))
ax = fig.add_subplot()
ax.set_aspect("equal")

# Basin outline in black; each stream segment in its own color.
ax.plot(bp[:, 0], bp[:, 1], "ko-")
for sg, seg_color in zip(sgs, segment_colors):
    print("Len segment: ", len(sg))
    sa = np.array(sg)
    ax.plot(sa[:, 0], sa[:, 1], ls="-", color=seg_color, marker="o")

In [None]:
def densify_geometry(line, step):
    """Resample a polyline at a fixed spacing along its length.

    Parameters
    ----------
    line : sequence of (x, y)
        Polyline vertices, e.g. a list returned by ``string2geom``.
    step : float
        Distance between consecutive output points, in coordinate units.

    Returns
    -------
    list of tuple of float
        Points at distances ``0, step, 2*step, ...`` along the line that are
        strictly less than ``int(line length)``. The line's final vertex is
        therefore not included, and a line shorter than 1 unit yields an
        empty list.
    """
    line_geometry = LineString(line)
    length_m = line_geometry.length
    xy = []
    for distance_along_old_line in np.arange(0, int(length_m), step):
        # interpolate() returns the point at this distance along the line,
        # measured from its first vertex.
        point = line_geometry.interpolate(distance_along_old_line)
        xy.append((point.x, point.y))
    return xy

In [None]:
# Upper bound on triangle area passed to Triangle (5000 x 5000, squared
# coordinate units).
maximum_area = 5000 * 5000

# Extra mesh nodes every 2000 units along each stream segment, presumably so
# the mesh is refined along the streams.
nodes = []
for sg in sgs:
    sg_densify = densify_geometry(sg, 2000)
    nodes += sg_densify
nodes = np.array(nodes)

# Triangulate inside the basin outline; angle=30 is presumably Triangle's
# minimum-angle quality constraint (confirm). Working files go to temp_path.
tri = Triangle(
    maximum_area=maximum_area, angle=30, nodes=nodes, model_ws=temp_path
)
poly = bp
tri.add_polygon(poly)
tri.build(verbose=False)

# create vor object and VertexGrid
# The Voronoi grid derived from the triangulation is wrapped in a nlay-layer
# VertexGrid whose idomain is all ones.
vor = VoronoiGrid(tri)
gridprops = vor.get_gridprops_vertexgrid()
idomain = np.ones((nlay, vor.ncpl), dtype=int)
voronoi_grid = VertexGrid(**gridprops, nlay=nlay, idomain=idomain)

In [None]:
# Area of each Voronoi cell, from its vertex ring via a shapely Polygon;
# the final expression displays the smallest and largest cell area.
areas = np.array(
    [
        Polygon(np.array(voronoi_grid.get_cell_vertices(icell))).area
        for icell in range(voronoi_grid.ncpl)
    ]
)
areas.min(), areas.max()

In [None]:
# Interpolate the raster's first band onto the Voronoi cells; the result is
# used as the model top (DISV `top`) and as drain elevations.
top_vg = fine_topo.resample_to_grid(
    voronoi_grid,
    band=fine_topo.bands[0],
    method="linear",
    # presumably fills cells outside the raster's interpolation hull — confirm
    extrapolate_edges=True,
)

In [None]:
ixs = flopy.utils.GridIntersect(voronoi_grid, method="vertex")

# Every cell id crossed by any stream segment, in segment order.
cellids = []
for segment in sgs:
    hits = ixs.intersect(LineString(segment), sort_by_cellid=True)
    cellids.extend(hits["cellids"].tolist())

# Per-cell flag array (one value per cell in a layer): 1 where a stream
# crosses the cell, 0 elsewhere.
intersection_vg = np.zeros(voronoi_grid.shape[1:])
intersection_vg[np.asarray(cellids, dtype=int)] = 1

In [None]:
# Land-surface elevation on the Voronoi grid, with stream-crossed cells
# shaded red, the basin outline, stream segments and elevation contours.
fig = plt.figure()
ax = fig.add_subplot()
# No ax is passed, so PlotMapView presumably draws on the current axes (the
# one just created) — confirm.
pmv = flopy.plot.PlotMapView(modelgrid=voronoi_grid)
ax.set_aspect("equal")
pmv.plot_array(top_vg)
# Zeros are masked, so only stream-crossed cells (value 1) are shaded.
pmv.plot_array(
    intersection_vg,
    masked_values=[
        0,
    ],
    alpha=0.2,
    cmap="Reds_r",
)
# pmv.plot_grid()  (uncomment to draw cell outlines)
pmv.plot_inactive()
ax.plot(bp[:, 0], bp[:, 1], "k-")
for sg in sgs:
    sa = np.array(sg)
    ax.plot(sa[:, 0], sa[:, 1], "b-")

cg = pmv.contour_array(top_vg, levels=levels, linewidths=0.3, colors="0.75")

### Build the MODFLOW 6 model

In [None]:
# NOTE(review): same grid and method as the intersector created in the
# intersection_vg cell; rebuilding it here is redundant.
ixs = flopy.utils.GridIntersect(voronoi_grid, method="vertex")

In [None]:
# Intersect each stream segment with the grid. Keep the raw result per
# segment, plus flat lists of crossed cell ids and the stream length inside
# each of those cells (not de-duplicated across segments).
drn_intersection = []
drn_cellids = []
drn_lengths = []
for sg in sgs:
    v = ixs.intersect(LineString(sg), sort_by_cellid=True)
    drn_intersection.append(v)
    drn_cellids += v["cellids"].tolist()
    drn_lengths += v["lengths"].tolist()

In [None]:
# Streambed leakance = kv / b, with kv = 1.0 (same value as NPF k below) and
# b = half the top-layer thickness.
leakance = 1.0 / (0.5 * dv0)  # kv / b
drn_data = []
for node, length in zip(drn_cellids, drn_lengths):
    x = voronoi_grid.xcellcenters[node]
    # River width varies linearly in x: 19 at x = 0 down to 5 at x = Lx.
    width = 5.0 + (14.0 / Lx) * (Lx - x)
    # Conductance = leakance * stream length in the cell * width.
    conductance = leakance * length * width
    # Presumably (layer, cell index, elevation, conductance) per the MF6
    # DISV drain list layout; drain elevation is the cell's land surface.
    drn_data.append((0, node, top_vg[node], conductance))
drn_data[:10]

In [None]:
# groundwater discharge to surface: a seepage drain in every top-layer cell
# that does not already hold a river drain.
# A set built once gives O(1) membership tests; testing against the list
# rescanned drn_cellids for every cell.
drn_cellid_set = set(drn_cellids)
gw_discharge_data = []
for node in range(voronoi_grid.ncpl):
    if node in drn_cellid_set:
        continue
    vertices = np.array(voronoi_grid.get_cell_vertices(node))
    # Conductance scales with the full cell area.
    conductance = leakance * Polygon(vertices).area
    # Drain elevation 0.5 below land surface; the trailing 1.0 is the
    # "depth" auxiliary value declared on the gwd DRN package.
    gw_discharge_data.append(
        (0, node, top_vg[node] - 0.5, conductance, 1.0)
    )
gw_discharge_data[:10]

In [None]:
# Layer tops/bottoms: the top layer is dv0 thick and each deeper layer is
# 1.5x thicker than the layer above; layers stack without gaps.
topc = np.zeros((nlay, vor.ncpl), dtype=float)
botm = np.zeros((nlay, vor.ncpl), dtype=float)
dv = dv0
topc[0] = top_vg.copy()
botm[0] = topc[0] - dv
for idx in range(1, nlay):
    dv *= 1.5
    topc[idx] = botm[idx - 1]
    botm[idx] = topc[idx] - dv

In [None]:
# Sanity check: mean thickness per layer (5, 7.5, 11.25, ... for dv0 = 5).
for k in range(nlay):
    print((topc[k] - botm[k]).mean())

In [None]:
# MODFLOW 6 executable, expected to be found on the PATH.
exe_name = "mf6"
sim = flopy.mf6.MFSimulation(
    sim_name="create_vtk",
    sim_ws="temp_vtk",
    exe_name=exe_name,
)

tdis = flopy.mf6.ModflowTdis(sim)
ims = flopy.mf6.ModflowIms(
    sim, linear_acceleration="bicgstab", complexity="simple"
)
# Newton-Raphson formulation with under-relaxation.
gwf = flopy.mf6.ModflowGwf(
    sim, save_flows=True, newtonoptions="NEWTON UNDER_RELAXATION"
)

# Generate the DISV grid properties once instead of once per argument.
disv_props = vor.get_disv_gridprops()
dis = flopy.mf6.ModflowGwfdisv(
    gwf,
    nlay=nlay,
    ncpl=vor.ncpl,
    nvert=disv_props["nvert"],
    vertices=disv_props["vertices"],
    cell2d=disv_props["cell2d"],
    top=top_vg,
    botm=botm,
)

# Uniform starting head equal to the highest land-surface elevation.
ic = flopy.mf6.ModflowGwfic(gwf, strt=top_vg.max())
npf = flopy.mf6.ModflowGwfnpf(
    gwf,
    save_specific_discharge=True,
    icelltype=1,
    k=1.0,
)
rch = flopy.mf6.ModflowGwfrcha(
    gwf,
    recharge=0.000001,
)
# Two DRN packages: river drains along the streams and seepage ("gwd")
# drains in all other top-layer cells.
drn = flopy.mf6.ModflowGwfdrn(
    gwf,
    stress_period_data=drn_data,
    pname="river",
)
drn_gwd = flopy.mf6.ModflowGwfdrn(
    gwf,
    auxiliary=["depth"],
    auxdepthname="depth",
    stress_period_data=gw_discharge_data,
    pname="gwd",
)
oc = flopy.mf6.ModflowGwfoc(
    gwf,
    head_filerecord=f"{gwf.name}.hds",
    budget_filerecord=f"{gwf.name}.cbc",
    saverecord=[("HEAD", "ALL"), ("BUDGET", "ALL")],
)

In [None]:
sim.write_simulation()
# run_simulation() reports failure through its return value instead of
# raising; stop here so later cells don't read missing or stale
# head/budget files.
success, buff = sim.run_simulation()
if not success:
    raise RuntimeError(
        "MODFLOW 6 did not terminate normally; "
        "check the listing files in the simulation workspace"
    )

In [None]:
def plot_river(
    ax=None,
    lw=1.0,
):
    """Draw every stream segment in the global ``sgs`` as a blue line.

    Each segment is resampled every 2000 units with ``densify_geometry``
    before plotting, so the drawn line may stop short of the segment's last
    vertex.

    Parameters
    ----------
    ax : matplotlib Axes, optional
        Axes to draw on; ``plt.gca()`` when omitted.
    lw : float
        Line width.

    Returns
    -------
    The axes that was drawn on.
    """
    target_ax = plt.gca() if ax is None else ax
    for segment in sgs:
        resampled = np.array(densify_geometry(segment, 2000))
        target_ax.plot(resampled[:, 0], resampled[:, 1], "b-", lw=lw)
    return target_ax

In [None]:
# Simulated heads; squeeze() drops singleton dimensions so the array is
# indexed [layer, cell] below.
head = gwf.output.head().get_data().squeeze()

In [None]:
# Head range in the top layer.
head[0].min(), head[0].max()

In [None]:
# Output accessors available on the model.
gwf.output.methods()

In [None]:
# Budget records and packages saved to the .cbc file.
cbc = gwf.output.budget()
cbc.list_unique_records(), cbc.list_unique_packages()

In [None]:
# Split the saved specific-discharge record into x/y/z components.
spdis = cbc.get_data(text="DATA-SPDIS")[0]
qx, qy, qz = flopy.utils.postprocessing.get_specific_discharge(spdis, gwf)
qx.shape

In [None]:
# DRN flows as full grid arrays. Unpacking assumes exactly two DRN records,
# presumably in package-creation order (river, then gwd) — TODO confirm.
riv_q, gwd_q = cbc.get_data(text="DRN", full3D=True)
riv_q.shape, gwd_q.shape

In [None]:
# Integer masks (1 where the package discharges) for each drain package.
riv_q_loc = np.zeros(riv_q.shape, dtype=int)
gwd_q_loc = np.zeros(gwd_q.shape, dtype=int)

In [None]:
# Combined discharge-type code: 1 where the river drain has negative flow,
# 2 where the seepage drain does. Negative presumably means water leaving
# the aquifer per the MF6 budget sign convention. Seepage (2) overwrites
# river (1) where both occur.
drn_q_loc = np.zeros(gwd_q.shape, dtype=int)
drn_q_loc[riv_q < 0.0] = 1
drn_q_loc[gwd_q < 0.0] = 2

In [None]:
idx = riv_q < 0.0
riv_q_loc[idx] = 1

In [None]:
idx = gwd_q < 0.0
gwd_q_loc[idx] = 1

In [None]:
# 1 where the simulated head is below the cell bottom (treated as dry).
dry_cell_loc = np.zeros(head.shape, dtype=int)
dry_cell_loc[head < botm] = 1

In [None]:
# Index of the uppermost non-dry layer in each cell column. 99 is a
# sentinel and remains 99 if every layer in the column is dry.
upper_active_layer = np.ones(head.shape[1], dtype=int) * 99
for k in range(nlay):
    # Only fill still-unassigned columns, so the top-most non-dry layer wins.
    idx = (dry_cell_loc[k] < 1) & (upper_active_layer == 99)
    upper_active_layer[idx] = k

In [None]:
# Specific discharge (x, y) taken from each column's uppermost non-dry
# layer. Columns in which every layer is dry still hold the sentinel value
# (>= nlay) in upper_active_layer; indexing qx with it would raise
# IndexError, so those columns keep zero discharge instead.
qx_top = np.zeros(head.shape[1])
qy_top = np.zeros(head.shape[1])
active_cols = np.flatnonzero(upper_active_layer < nlay)
active_layers = upper_active_layer[active_cols]
qx_top[active_cols] = qx[active_layers, active_cols]
qy_top[active_cols] = qy[active_layers, active_cols]

In [None]:
# Plan view of simulated heads (presumably layer 0, PlotMapView's default
# layer — confirm), gray cell edges, the river, and normalized flow-direction
# arrows. Red lines mark the cross sections plotted next.
fig = plt.figure(figsize=(10, 6), constrained_layout=True)
mm = flopy.plot.PlotMapView(model=gwf)
cb = mm.plot_array(head, ec="0.5")
plot_river(ax=mm.ax)
mm.plot_vector(qx_top, qy_top, normalize=True)
mm.ax.axhline(y=42500, lw=2, color="red")
mm.ax.axvline(x=72500, lw=2, color="red")
plt.colorbar(cb, orientation="horizontal");

In [None]:
# Overwrite `extent` with the plotted map's extent (presumably
# xmin, xmax, ymin, ymax — confirm).
extent = mm.extent
extent

In [None]:
# East-west cross section along y = 42500 (horizontal red line on the map).
fx = flopy.plot.PlotCrossSection(
    model=gwf, line={"line": [(0, 42500), (extent[1], 42500)]}
)
fx.plot_array(head, head=head)
fx.plot_grid()

In [None]:
# North-south cross section along x = 72500 (vertical red line on the map).
fx = flopy.plot.PlotCrossSection(
    model=gwf,
    line={"line": [(72500, extent[2]), (72500, extent[3])]},
)
fx.plot_array(head, head=head)
fx.plot_grid()

In [None]:
topc.shape, head.shape

In [None]:
# Depth to water in the top layer: land surface minus simulated head.
dtw = topc[0] - head[0]
dtw.shape, dtw.min(), dtw.max()

In [None]:
# Depth-to-water map with white, labelled contours at 1, 5 and 10.
fig = plt.figure(figsize=(10, 6), constrained_layout=True)
mm = flopy.plot.PlotMapView(model=gwf)
cb = mm.plot_array(dtw)
cs = mm.contour_array(
    dtw,
    levels=[1, 5, 10],
    colors="white",
    linewidths=1,
)
mm.ax.clabel(cs, inline=1, fmt="%2.0f", fontsize=12, inline_spacing=0)
plot_river(ax=mm.ax)
plt.colorbar(cb, orientation="horizontal");

In [None]:
# Colormaps and text style shared by the active-layer / discharge-type maps.
# NOTE: `colors` must still be the matplotlib.colors module imported at the
# top of the notebook; if an earlier cell rebinds that name (e.g. to a tuple
# of color names), these calls fail with AttributeError.
layer_cmap = colors.ListedColormap(["white", "green", "blue"])  # layers 1, 2, 3
drain_cmap = colors.ListedColormap(["red", "cyan"])  # codes 1 (river), 2 (seepage)
font_dict = {"fontsize": 5, "color": "black"}
contour_color = "0.8"  # light-gray head contours

In [None]:
# Diagnostic map: discharge type over the uppermost non-dry layer, each with
# its own inset colorbar.
fig = plt.figure(figsize=(10, 6), constrained_layout=True)
mm = flopy.plot.PlotMapView(model=gwf)
# Zeros (no discharge) masked; vmin/vmax 0.5/2.5 center the two colormap
# bins on the codes 1 and 2.
dp = mm.plot_array(drn_q_loc, masked_values=[0], cmap=drain_cmap, edgecolor="none", alpha=0.5, vmin=0.5, vmax=2.5)
# 1-based layer numbers; vmin/vmax 0.5/3.5 center the three bins on 1, 2, 3.
al = mm.plot_array(upper_active_layer + 1, alpha=0.25, edgecolor="none", cmap=layer_cmap, vmin=0.5, vmax=3.5,)
mm.plot_grid(color="black", lw=0.5)
plot_river(ax=mm.ax)

# Inset colorbar (upper right) for the uppermost active layer.
cax = mm.ax.inset_axes([.75, .85, .2, .05],)
cbar = plt.colorbar(al, orientation="horizontal", cax=cax)
cbar.ax.tick_params(
    labelsize=5,
    labelcolor="black",
    color="white",
    length=6,
    pad=2,
)
cax.set_xticks([1, 2, 3])
cax.set_xticklabels([1, 2, 3])
cbar.ax.set_title(
    "Upper most active model layer",
    pad=2.5,
    loc="left",
    fontdict=font_dict,
)

# Inset colorbar (lower left) for the discharge type.
cax = mm.ax.inset_axes([.02, .065, .2, .05],)
cbar = plt.colorbar(dp, orientation="horizontal", cax=cax)
cbar.ax.tick_params(
    labelsize=5,
    labelcolor="black",
    color="white",
    length=6,
    pad=2,
)
cax.set_xticks([1, 2])
cax.set_xticklabels(["River", "Groundwater\nseepage"])
cbar.ax.set_title(
    "Discharge type",
    pad=2.5,
    loc="left",
    fontdict=font_dict,
);

In [None]:
def set_map_axis_labels(ax):
    """Label plan-view map axes in kilometres.

    Ticks sit every 50,000 model units (0-150,000 in x, 0-100,000 in y) and
    are labelled with the value divided by 1000.
    """
    xticks_m = np.arange(0, 200000, 50000)
    yticks_m = np.arange(0, 150000, 50000)
    ax.set_xticks(xticks_m)
    ax.set_xticklabels(xticks_m // 1000)
    ax.set_yticks(yticks_m)
    ax.set_yticklabels(yticks_m // 1000)
    ax.set_xlabel("x position (km)")
    ax.set_ylabel("y position (km)")

In [None]:
def km_tick_labels(upper_m, spacing_m=25000):
    """Return tick positions and matching kilometre labels.

    Parameters
    ----------
    upper_m : float
        Exclusive upper bound for the ticks, in model units (m).
    spacing_m : float, optional
        Tick spacing in model units; default 25000 (25 km).

    Returns
    -------
    ticks : numpy.ndarray
        ``np.arange(0, upper_m, spacing_m)``.
    labels : list of str
        Each tick divided by 1000 and formatted with no decimals. Built from
        ``ticks`` itself, so both always have the same length.
    """
    ticks = np.arange(0, upper_m, spacing_m)
    labels = [f"{tick / 1000:.0f}" for tick in ticks]
    return ticks, labels


def set_xsection_axis_labels(ax):
    """Format a cross-section axes: km distance along x, elevation in m on y.

    The y range is fixed to -75..125 m with 25 m ticks. The x ticks are
    every 25 km from 0 up to (not including) the current right x limit.
    """
    xlim = ax.get_xlim()
    ax.set_ylim(-75, 125)
    # Derive labels from the tick array instead of from a second,
    # separately rounded np.arange in km. Two aranges can differ in length
    # by one element due to floating-point rounding, and set_xticklabels
    # rejects a label count that does not match the ticks.
    xticks, xlabels = km_tick_labels(xlim[1])
    ax.set_xticks(xticks)
    ax.set_xticklabels(xlabels)
    yticks = np.arange(-75, 150, 25)
    ax.set_yticks(yticks)
    ax.set_yticklabels([f"{value:.0f}" for value in yticks])
    ax.set_xlabel("cross-section distance (km)")
    ax.set_ylabel("elevation (m)")

In [None]:
# Publication figure: axs[0] plan-view heads, axs[1] water-table layer and
# discharge type, axs[2] east-west head section, axs[3] north-south head
# section, axs[4] shared legend strip.
figwidth = 17.15 / 2.54  # 17.15 cm in inches
figheight = 2.55 * (Ly / Lx) * 8.25 / 2.54
# Model-domain extent (xmin, xmax, ymin, ymax). It is rebuilt here because
# an earlier cell overwrote `extent` with the plotted map extent.
extent = (0, Lx, 0, Ly)
grid_dict = {"color": "black", "lw": 0.5}
arrowprops = dict(
    arrowstyle="-",
    edgecolor="red",
    lw=0.5,
    shrinkA=0.15,
    shrinkB=0.15,
)
# Inset colorbar position [x0, y0, width, height] in axes coordinates.
cbar_axis = [0.75, 0.825, 0.2, 0.05]

with styles.USGSMap():
    fig = plt.figure(figsize=(figwidth, figheight), constrained_layout=True)
    # 11 rows: two 5-row bands of panels and a 1-row legend strip.
    gs = gridspec.GridSpec(ncols=2, nrows=11, figure=fig)
    axs = [fig.add_subplot(gs[:5, 0])]
    axs.append(fig.add_subplot(gs[:5, 1]))
    axs.append(fig.add_subplot(gs[5:10, 0]))
    axs.append(fig.add_subplot(gs[5:10, 1]))
    axs.append(fig.add_subplot(gs[10:, :]))

    # head
    ax = axs[0]
    ax.set_aspect("equal", "box")
    styles.heading(ax=ax, idx=0)
    mm = flopy.plot.PlotMapView(model=gwf, ax=ax)
    cb = mm.plot_array(head, ec="none", vmin=vmin, vmax=vmax)
    mm.plot_grid(**grid_dict)
    plot_river(ax=ax)
    cs = mm.contour_array(
        head,
        colors=contour_color,
        levels=levels,
        linewidths=1,
    )
    ax.clabel(cs, inline=1, fmt="%1.0f", fontsize=6, inline_spacing=0)

    q = mm.plot_vector(qx_top, qy_top, normalize=True)
    # Attach the key to the head-map axes explicitly. plt.quiverkey() would
    # use pyplot's "current" axes, which at this point is most likely the
    # last subplot created above (the legend strip), not this map.
    # NOTE(review): the vectors are normalized, so the key arrow's length
    # does not represent an actual 1 m/d discharge — confirm the label.
    qk = ax.quiverkey(
        q,
        0.95,
        1.05,
        1,
        label="1 m/d",
        labelpos="W",
        labelcolor="black",
        fontproperties={"size": 8},
    )
    mm.ax.axhline(y=42500, lw=1, color="red")
    set_map_axis_labels(ax)

    # colorbar for head
    cax = mm.ax.inset_axes(cbar_axis)
    cbar = plt.colorbar(cb, orientation="horizontal", cax=cax)
    cbar.ax.tick_params(
        labelsize=5,
        labelcolor="black",
        color="white",
        length=5,
        pad=2,
    )
    cbar.ax.set_title(
        "Head (m)",
        pad=2.5,
        loc="left",
        fontdict=font_dict,
    )

    # cross-section lines: A-A' east-west at y = 42500, B-B' north-south
    # at x = 72500
    styles.add_text(
        ax=ax,
        text="A",
        x=400,
        y=42600,
        transform=False,
        bold=True,
        color="red",
    )
    styles.add_text(
        ax=ax,
        text="A'",
        x=extent[1] - 400,
        y=42600,
        transform=False,
        bold=True,
        color="red",
    )
    mm.ax.plot([72500, 72500], [9000, 95000], lw=1, color="red")
    styles.add_annotation(
        ax=ax,
        text="B",
        xy=(72500, 95000),
        xytext=(-15, -15),
        textcoords="offset points",
        arrowprops=arrowprops,
        bold=True,
        color="red",
    )
    styles.add_annotation(
        ax=ax,
        text="B'",
        xy=(72500, 9000),
        xytext=(15, 10),
        textcoords="offset points",
        arrowprops=arrowprops,
        bold=True,
        color="red",
    )

    # cell-by-cell
    ax = axs[1]
    ax.set_aspect("equal", "box")
    styles.heading(ax=ax, idx=1)
    mm = flopy.plot.PlotMapView(model=gwf, ax=ax)
    # Uppermost non-dry layer (1-based); 0.5/3.5 center the bins on 1, 2, 3.
    aml = mm.plot_array(
        upper_active_layer + 1,
        edgecolor="none",
        cmap=layer_cmap,
        vmin=0.5,
        vmax=3.5,
    )
    # Discharge type: 1 = river, 2 = seepage; zeros (no discharge) masked.
    dp = mm.plot_array(
        drn_q_loc,
        masked_values=[0],
        cmap=drain_cmap,
        edgecolor="none",
        vmin=0.5,
        vmax=2.5,
    )
    mm.plot_grid(**grid_dict)
    plot_river(ax=ax)
    set_map_axis_labels(ax)

    # color bar for B (model layer)
    cax = mm.ax.inset_axes(cbar_axis)
    cbar = plt.colorbar(aml, orientation="horizontal", cax=cax)
    cbar.ax.tick_params(
        labelsize=5,
        labelcolor="black",
        color="none",
        length=5,
        pad=2,
    )
    cax.set_xticks([1, 2, 3])
    cax.set_xticklabels([1, 2, 3])
    cbar.ax.set_title(
        "Water-table layer",
        pad=2.5,
        loc="left",
        fontdict=font_dict,
    )

    # color bar for B (drain locations)
    cax = mm.ax.inset_axes([0.02, 0.065, 0.2, 0.05])
    cbar = plt.colorbar(dp, orientation="horizontal", cax=cax)
    cbar.ax.tick_params(
        labelsize=5,
        labelcolor="black",
        color="none",
        length=6,
        pad=2,
    )
    cax.set_xticks([1, 2])
    cax.set_xticklabels(["River", "Seepage"])
    cbar.ax.set_title(
        "Discharge type",
        pad=2.5,
        loc="left",
        fontdict=font_dict,
    )

    # east-west cross-section (A-A')
    ax = axs[2]
    styles.heading(ax=ax, idx=2)
    fx = flopy.plot.PlotCrossSection(
        model=gwf,
        ax=ax,
        line={"line": [(0, 42500), (extent[1], 42500)]},
    )
    cb = fx.plot_array(head, head=head, vmin=vmin, vmax=vmax)
    fx.plot_grid(**grid_dict)
    set_xsection_axis_labels(ax)

    # colorbar for head
    cax = fx.ax.inset_axes(cbar_axis)
    cbar = plt.colorbar(cb, orientation="horizontal", cax=cax)
    cbar.ax.tick_params(
        labelsize=5,
        labelcolor="black",
        color="white",
        length=5,
        pad=2,
    )
    cbar.ax.set_title(
        "Head (m)",
        pad=2.5,
        loc="left",
        fontdict=font_dict,
    )
    styles.add_annotation(
        ax=ax,
        text="A",
        xy=(0, 105),
        xytext=(15, 2),
        textcoords="offset points",
        arrowprops=arrowprops,
        bold=True,
        color="red",
    )
    styles.add_annotation(
        ax=ax,
        text="A'",
        xy=(ax.get_xlim()[1], 0),
        xytext=(-15, 17),
        textcoords="offset points",
        arrowprops=arrowprops,
        bold=True,
        color="red",
    )

    # north-south cross-section (B-B'), drawn from north to south
    ax = axs[3]
    styles.heading(ax=ax, idx=3)
    fx = flopy.plot.PlotCrossSection(
        model=gwf,
        ax=ax,
        line={"line": [(72500, extent[3]), (72500, extent[2])]},
    )
    # This section uses a narrower 60-70 m head color range.
    cb = fx.plot_array(head, head=head, vmin=60, vmax=70)
    fx.plot_grid(**grid_dict)
    set_xsection_axis_labels(ax)

    # colorbar for head
    cax = fx.ax.inset_axes(cbar_axis)
    cbar = plt.colorbar(cb, orientation="horizontal", cax=cax)
    cbar.ax.tick_params(
        labelsize=5,
        labelcolor="black",
        color="white",
        length=5,
        pad=2,
    )
    cbar.ax.set_title(
        "Head (m)",
        pad=2.5,
        loc="left",
        fontdict=font_dict,
    )
    styles.add_annotation(
        ax=ax,
        text="B",
        xy=(0, 80),
        xytext=(15, 15),
        textcoords="offset points",
        arrowprops=arrowprops,
        bold=True,
        color="red",
    )
    styles.add_annotation(
        ax=ax,
        text="B'",
        xy=(ax.get_xlim()[1], 68),
        xytext=(-11, 19),
        textcoords="offset points",
        arrowprops=arrowprops,
        bold=True,
        color="red",
    )

    # legend: proxy lines placed at y = -100, outside the 0-1 axis limits,
    # exist only to create the legend entries.
    ax = axs[4]
    xy0 = (-100, -100)
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.set_axis_off()

    ax.axhline(xy0[0], color="blue", lw=0.5, label="River")
    ax.axhline(xy0[0], color="red", lw=1, label="Cross-section line")
    ax.axhline(xy0[0], color=contour_color, lw=1, label="Head contour (m)")
    styles.graph_legend(
        ax,
        ncol=3,
        loc="lower center",
        labelspacing=0.5,
        columnspacing=0.6,
        handletextpad=0.3,
    )

    # NOTE(review): savefig does not create ../doc/figures; it must exist.
    fpth = os.path.join("..", "doc", "figures", "grids_flopy_plots.png")
    fig.savefig(fpth, dpi=300)