In [None]:
import pathlib as pl
import numpy as np
import matplotlib.pyplot as plt
import flopy
from flopy.mf6.utils import Mf6Splitter

# Function to split the model domain into rectangular blocks

In [None]:
def _block_edges(n: int, nblocks: int) -> list:
    """Return ``nblocks + 1`` boundary indices splitting ``range(n)`` into equal blocks; the last block absorbs any remainder."""
    inc = n // nblocks
    edges = [i * inc for i in range(nblocks + 1)]
    # nblocks * inc <= n always holds, so stretching the final edge to n is
    # exactly "extend the last block when n is not evenly divisible".
    edges[-1] = n
    return edges


def simple_mapping(
    nrow_blocks: int,
    ncol_blocks: int,
    modelgrid: "flopy.discretization.StructuredGrid",
) -> np.ndarray:
    """
    Create a simple block-based mapping array for a structured grid.

    The grid is divided into ``nrow_blocks`` x ``ncol_blocks`` rectangular
    blocks of ``nrow // nrow_blocks`` rows and ``ncol // ncol_blocks``
    columns. Leftover rows/columns are absorbed by the last block row/column.
    Blocks are numbered from zero in row-major order: left to right, then
    top to bottom.

    Parameters
    ----------
    nrow_blocks: int
        Number of models in the row direction of a domain. Must be between 1
        and the number of grid rows.
    ncol_blocks: int
        Number of models in the column direction of a domain. Must be between
        1 and the number of grid columns.
    modelgrid: flopy.discretization.StructuredGrid
        flopy modelgrid object (only ``grid_type``, ``nrow`` and ``ncol``
        are read).

    Returns
    -------
    mask: np.ndarray
        Integer array of shape (nrow, ncol) holding the zero-based model
        number of every cell; block-based mapping array for the model splitter.

    Raises
    ------
    ValueError
        If ``modelgrid`` is not structured, or if a block count is below 1 or
        above the number of rows/columns. Too many blocks would create empty
        blocks and missing model numbers.
    """
    if modelgrid.grid_type != "structured":
        raise ValueError(
            f"modelgrid must be 'structured' not {modelgrid.grid_type}"
        )
    nrow, ncol = modelgrid.nrow, modelgrid.ncol
    if not 1 <= nrow_blocks <= nrow:
        raise ValueError(
            f"nrow_blocks must be between 1 and {nrow}, got {nrow_blocks}"
        )
    if not 1 <= ncol_blocks <= ncol:
        raise ValueError(
            f"ncol_blocks must be between 1 and {ncol}, got {ncol_blocks}"
        )

    row_edges = _block_edges(nrow, nrow_blocks)
    col_edges = _block_edges(ncol, ncol_blocks)

    # masking array - zero-based model number, assigned in row-major block order
    mask = np.zeros((nrow, ncol), dtype=int)
    for i, (r0, r1) in enumerate(zip(row_edges[:-1], row_edges[1:])):
        for j, (c0, c1) in enumerate(zip(col_edges[:-1], col_edges[1:])):
            mask[r0:r1, c0:c1] = i * ncol_blocks + j

    return mask

# Base model location and MODFLOW 6 executable

In [None]:
# Simulation name and the workspace holding the base (unsplit) model that is
# loaded below.
name = "ex1"
ws = pl.Path("working") / "single"

In [None]:
# MODFLOW 6 executable from the project's pixi environment, given relative to
# the directory the notebook is run from.
ex_pth = (
    "../../.pixi/env/bin/mf6"
)

# Load the base model

In [None]:
# Load the existing base simulation from its workspace.
sim = flopy.mf6.MFSimulation.load(
    sim_name=name,
    sim_ws=ws,
    exe_name=ex_pth,
    use_pandas=False,
)

In [None]:
# No model name is passed, so this relies on get_model() returning the
# simulation's default (presumably only) model — confirm for multi-model sims.
gwf = sim.get_model()

In [None]:
# Display the discretization (DIS) package to inspect the grid dimensions.
gwf.dis

# Split the model

## Create the splitting array

In [None]:
# Two blocks along the rows and one along the columns -> two models.
sarr = simple_mapping(nrow_blocks=2, ncol_blocks=1, modelgrid=gwf.modelgrid)

In [None]:
# The mapping array is (nrow, ncol) of the base-model grid.
np.shape(sarr)

In [None]:
# Visual check of the split: each colour is one zero-based model number.
fig, ax = plt.subplots()
im = ax.imshow(sarr)
ax.set(title="Model number assigned to each cell", xlabel="column", ylabel="row")
# Discrete ticks at the model numbers; trailing ';' hides the Colorbar repr.
fig.colorbar(im, ax=ax, ticks=np.unique(sarr), label="model number");

## Split the base model

In [None]:
# Separate workspace for the split simulation, next to the base-model workspace.
new_ws = pl.Path("working") / "split"

In [None]:
# Create a splitter for the loaded base simulation.
mfsplit = Mf6Splitter(sim)

In [None]:
# Split according to the mapping array; the result is a new simulation object
# (presumably one model per unique value in sarr).
new_sim = mfsplit.split_model(sarr)

In [None]:
# Write the split simulation to its own workspace (new_ws) rather than ws.
new_sim.set_sim_path(new_ws)

In [None]:
# Run the split simulation with the same MODFLOW 6 executable as the base model.
new_sim.exe_name = ex_pth

## Write the model files and run the simulation

In [None]:
# Write all input files of the split simulation to new_ws.
new_sim.write_simulation()

In [None]:
# Run the split simulation and stop here if MODFLOW 6 fails. Otherwise the
# plotting cells below would read stale or missing output files.
run_ok, run_log = new_sim.run_simulation()
assert run_ok, f"split simulation in {new_ws} did not terminate normally"

## Get base model output

In [None]:
# List the output accessors available on the base model (head, budget, ...).
# NOTE(review): the base simulation is only loaded above, never run in this
# notebook — these cells rely on its output already existing in ws.
gwf.output.methods()

In [None]:
# Simulation times stored in the base model's head file.
gwf.output.head().get_times()

In [None]:
# (time step, stress period) pairs stored in the head file; used to choose kstpkper.
gwf.output.head().get_kstpkper()

In [None]:
# Output time to compare, as (kstp, kper); flopy uses zero-based indices here.
kstpkper = (0, 2)

In [None]:
# Base-model heads at the selected output time (indexed by layer further below).
head = gwf.output.head().get_data(kstpkper=kstpkper)

In [None]:
# Specific-discharge (SPDIS) record from the base-model budget file; get_data
# returns a list, so take its first (presumably only) entry.
spdis = gwf.output.budget().get_data(text="SPDIS", kstpkper=kstpkper)[0]

In [None]:
# Convert the SPDIS record into x, y and z specific-discharge arrays for plotting.
qx, qy, qz = flopy.utils.postprocessing.get_specific_discharge(spdis, gwf)

## Plot model results

In [None]:
# Zero-based layer index shown in the map plots.
plt_lay: int = 2

In [None]:
# Shared colour limits for both map panels, from base-model heads in the plotted layer.
# NOTE(review): if this layer has dry/inactive cells written as large flag
# values, these limits will be dominated by them — confirm.
vmin = np.min(head[plt_lay])
vmax = np.max(head[plt_lay])
vmin, vmax

### Map view: base model vs. split models

In [None]:
# Side-by-side maps of the plotted layer: base model (left) and all split
# models drawn on one axes (right), with shared colour limits.
# Loop variables use their own names so this cell does not overwrite `name`,
# `spdis` or the base-model `qx`/`qy`/`qz`. That keeps it safe to re-run.
fig, (ax_base, ax_split) = plt.subplots(1, 2, figsize=(10, 7))

pmv = flopy.plot.PlotMapView(model=gwf, ax=ax_base, layer=plt_lay)
hp = pmv.plot_array(head, vmin=vmin, vmax=vmax)
pmv.plot_grid()
pmv.plot_vector(qx, qy, normalize=True)
ax_base.set_title(f"Base model (layer {plt_lay})")
fig.colorbar(hp, ax=ax_base, shrink=0.75, orientation="horizontal")

for model_name in new_sim.model_names:
    split_gwf = new_sim.get_model(model_name)
    split_head = split_gwf.output.head().get_data(kstpkper=kstpkper)
    split_spdis = split_gwf.output.budget().get_data(text="SPDIS", kstpkper=kstpkper)[0]
    split_qx, split_qy, split_qz = flopy.utils.postprocessing.get_specific_discharge(
        split_spdis, split_gwf
    )

    # Clip every split model to the base-model extent so the panels line up.
    split_pmv = flopy.plot.PlotMapView(
        model=split_gwf, ax=ax_split, layer=plt_lay, extent=gwf.modelgrid.extent
    )
    split_hp = split_pmv.plot_array(split_head, vmin=vmin, vmax=vmax)
    split_pmv.plot_grid()
    split_pmv.plot_vector(split_qx, split_qy, normalize=True)
ax_split.set_title(f"Split models (layer {plt_lay})")
fig.colorbar(split_hp, ax=ax_split, shrink=0.75, orientation="horizontal");