In [None]:
import pathlib as pl
import shutil

import matplotlib.pyplot as plt
import numpy as np

import flopy
from flopy.mf6.utils import Mf6Splitter

# Model Name, Base Model Location, and MODFLOW 6 Executable

In [None]:
# Simulation name and workspace of the base (unsplit) model.
name = "ex1"
ws = pl.Path("working") / "single"

In [None]:
# MODFLOW 6 executable. Prefer the copy installed in the project's pixi
# environment (path is relative to the notebook's working directory). If that
# file is not present, fall back to an `mf6` found on PATH so the notebook also
# runs outside this repository layout.
pixi_mf6 = pl.Path("../../.pixi/env/bin/mf6")
if pixi_mf6.exists():
    ex_pth = str(pixi_mf6)
else:
    ex_pth = shutil.which("mf6") or str(pixi_mf6)
ex_pth

# Load the base model

In [None]:
# Load the base (single-model) simulation from its workspace.
sim = flopy.mf6.MFSimulation.load(
    sim_name=name,
    sim_ws=ws,
    exe_name=ex_pth,
)

In [None]:
# Base GWF model. No model name is passed, so this relies on get_model()'s
# default selection; the base simulation presumably holds a single model.
gwf = sim.get_model()

# Load the split model

In [None]:
# Workspace containing the split (multi-model) simulation.
split_ws = pl.Path("working", "split")

In [None]:
# Load the split simulation (same simulation name, different workspace).
new_sim = flopy.mf6.MFSimulation.load(
    sim_name=name,
    sim_ws=split_ws,
    exe_name=ex_pth,
)

# Change the workspace, write the split simulation, and run it in parallel (one process per sub-model)

In [None]:
# Point the split simulation at a separate workspace so the files written
# below land in their own directory rather than in the split workspace.
new_ws = pl.Path("working", "parallel")
new_sim.set_sim_path(new_ws)

In [None]:
# Write the split simulation's input files into the workspace set above.
new_sim.write_simulation()

In [None]:
# Request one process per sub-model in the split simulation.
nprocessors = len(new_sim.model_names)

In [None]:
# Run the split simulation in parallel. run_simulation returns
# (success, buffer); stop here on failure instead of silently post-processing
# output from a failed or partial parallel run.
success, buff = new_sim.run_simulation(processors=nprocessors)
assert success, "Parallel MODFLOW 6 simulation did not complete successfully"
success

## Get model output

### Single model

In [None]:
# Output time to compare, as (time step, stress period). flopy uses
# zero-based indices here.
kstpkper = (0, 3)

In [None]:
# Heads from the single model's head output at the selected time.
head_file = gwf.output.head()
head = head_file.get_data(kstpkper=kstpkper)
head.shape

In [None]:
# Specific-discharge (SPDIS) record from the single model's budget output;
# get_data returns a sequence of records, so keep the first one.
budget_file = gwf.output.budget()
spdis = budget_file.get_data(text="SPDIS", kstpkper=kstpkper)[0]

In [None]:
# Split the SPDIS record into x/y/z specific-discharge arrays for the
# vector plot of the single model.
qx, qy, qz = flopy.utils.postprocessing.get_specific_discharge(
    spdis,
    gwf,
)

### Parallel model

In [None]:
# Node-mapping JSON stored alongside the split model. The splitter uses it to
# map sub-model arrays back onto the base model grid. Build the path from
# `split_ws` so the split-model location is defined in only one place.
json_path = split_ws / "split.json"
mfsplit = Mf6Splitter(sim)
mfsplit.load_node_mapping(new_sim, json_path)

In [None]:
# Collect the heads of every sub-model at `kstpkper`, keyed by the integer
# model number parsed from the trailing "_<n>" of the model name (the key
# format passed to mfsplit.reconstruct_array below).
head_dict = {}
for modelname in new_sim.model_names:
    mnum = int(modelname.split("_")[-1])
    sub_model = new_sim.get_model(modelname)
    head_dict[mnum] = sub_model.output.head().get_data(kstpkper=kstpkper)

In [None]:
# Stitch the per-sub-model head arrays back onto the base model grid.
new_head = mfsplit.reconstruct_array(head_dict)
# The reconstructed array must line up cell-for-cell with the single-model
# heads for the comparison and difference plot to be meaningful.
assert new_head.shape == head.shape, (new_head.shape, head.shape)
new_head.shape

## Compare head results

In [None]:
# Besides the tolerance-based pass/fail check, report the largest absolute
# head difference so any small but non-zero discrepancy is visible.
max_abs_head_diff = np.abs(head - new_head).max()
np.allclose(head, new_head), max_abs_head_diff

## Plot model results

In [None]:
# Layer index shown in the map plots (passed as `layer=` to PlotMapView).
plt_lay = int(2)

In [None]:
# Colour limits shared by the single- and parallel-model head panels, taken
# from the plotted layer of the single-model heads.
layer_head = head[plt_lay]
vmin = layer_head.min()
vmax = layer_head.max()
vmin, vmax

### Map

In [None]:
# Three map panels for layer `plt_lay`: single-model heads, reconstructed
# parallel heads, and their difference.
fig, (ax_single, ax_parallel, ax_diff) = plt.subplots(1, 3, figsize=(15, 7))

# Panel 1: single (base) model heads with specific-discharge vectors.
pmv = flopy.plot.PlotMapView(model=gwf, ax=ax_single, layer=plt_lay)
hp = pmv.plot_array(head, vmin=vmin, vmax=vmax)
pmv.plot_grid()
pmv.plot_vector(qx, qy, normalize=True)
ax_single.set_title("Single model")
cb = fig.colorbar(hp, ax=ax_single, shrink=0.75, orientation="horizontal")

# Panel 2: reconstructed parallel heads on the base grid, overlaid with each
# sub-model's grid and vectors. The loop uses `model_name` / `sub_*` names so
# it does not overwrite the simulation `name` or the single-model
# `spdis`/`qx`/`qy`/`qz`. Overwriting them made a re-run of this cell break or
# mis-plot panel 1, and broke re-loading the simulation by `name`.
pmv = flopy.plot.PlotMapView(model=gwf, ax=ax_parallel, layer=plt_lay)
hp = pmv.plot_array(new_head, vmin=vmin, vmax=vmax)
for model_name in new_sim.model_names:
    sub_gwf = new_sim.get_model(model_name)
    sub_spdis = sub_gwf.output.budget().get_data(text="SPDIS", kstpkper=kstpkper)[0]
    sub_qx, sub_qy, sub_qz = flopy.utils.postprocessing.get_specific_discharge(
        sub_spdis, sub_gwf
    )

    # Share the base model's extent so every sub-model draws in the same frame.
    sub_pmv = flopy.plot.PlotMapView(
        model=sub_gwf, ax=ax_parallel, layer=plt_lay, extent=gwf.modelgrid.extent
    )
    sub_pmv.plot_grid()
    sub_pmv.plot_vector(sub_qx, sub_qy, normalize=True)
ax_parallel.set_title("Parallel (split) model")
cb = fig.colorbar(hp, ax=ax_parallel, shrink=0.75, orientation="horizontal")

# Panel 3: head difference (single minus parallel) on the base grid.
pmv = flopy.plot.PlotMapView(model=gwf, ax=ax_diff, layer=plt_lay)
dp = pmv.plot_array(head - new_head)
pmv.plot_grid()
ax_diff.set_title("Head difference (single - parallel)")
cb = fig.colorbar(dp, ax=ax_diff, shrink=0.75, orientation="horizontal")
