# Split the basin example base model

## Notebook Setup

In [None]:
import os
import sys
import shutil
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from shapely.geometry import Polygon, LineString
import flopy
from flopy.discretization import StructuredGrid
import flopy.plot.styles as styles

In [None]:
# import all plot style information from defaults.py
from defaults import *

In [None]:
from model_splitter import Mf6Splitter

#### Load the base basin model

In [None]:
# Workspace holding the base (unsplit) basin simulation that is split below.
base_ws = "../examples/basin_base"

In [None]:
# Load the base "basin" simulation from base_ws.
sim = flopy.mf6.MFSimulation.load(sim_name="basin", sim_ws=base_ws)

In [None]:
# Groundwater-flow model of the base simulation. get_model() is called without
# a name, so this presumably relies on the simulation holding a single model.
gwf = sim.get_model()

In [None]:
# Read the layer/row/column counts from the DIS package and echo them.
nlay, nrow, ncol = (
    getattr(gwf.dis, dim_name).array for dim_name in ("nlay", "nrow", "ncol")
)
nlay, nrow, ncol

### Build a splitting array

In [None]:
# Number of partitions along the row and column directions.
nrow_blocks = 2
ncol_blocks = 1

In [None]:
# One process per sub-model block; echoed as the cell output.
nproc = nrow_blocks * ncol_blocks
nproc

In [None]:
# Nominal rows/columns per block. Integer floor division avoids a float
# round-trip; any remainder is absorbed by the last block when the block
# edges are built.
row_inc, col_inc = nrow // nrow_blocks, ncol // ncol_blocks
row_inc, col_inc

In [None]:
# Row edges of the blocks: block k spans rows row_blocks[k]:row_blocks[k + 1].
# The final edge is stretched to nrow so leftover rows join the last block.
row_blocks = [block * row_inc for block in range(nrow_blocks + 1)]
row_blocks[-1] = max(row_blocks[-1], nrow)
row_blocks

In [None]:
# Column edges of the blocks: block k spans columns col_blocks[k]:col_blocks[k + 1].
# The final edge is stretched to ncol so leftover columns join the last block.
col_blocks = [block * col_inc for block in range(ncol_blocks + 1)]
col_blocks[-1] = max(col_blocks[-1], ncol)
col_blocks

In [None]:
# Integer partition array over the (row, col) plane; filled with model
# numbers below.
mask = np.zeros((nrow, ncol), dtype=int)

In [None]:
# Create the masking array: number the blocks 1, 2, ... in row-major order and
# record each block's (first row, first column), keyed by zero-based index.
ival = 1
model_row_col_offset = {}
for row_start, row_end in zip(row_blocks[:-1], row_blocks[1:]):
    for col_start, col_end in zip(col_blocks[:-1], col_blocks[1:]):
        mask[row_start:row_end, col_start:col_end] = ival
        model_row_col_offset[ival - 1] = (row_start, col_start)
        ival += 1  # next model number

In [None]:
# Display each block's (row, col) origin, keyed by zero-based model index.
model_row_col_offset

In [None]:
# Sanity check: the mask should contain the model numbers 1..nproc.
np.unique(mask)

In [None]:
# Quick visual check of the partitioning.
plt.imshow(mask)

### Split into (nrow_blocks, ncol_blocks) models

In [None]:
# Output workspace named after the block layout and the processor count.
new_ws = f"../examples/basin_{nrow_blocks}x{ncol_blocks}_{nproc}p"
new_ws

In [None]:
# Splitter bound to the loaded base simulation; it is reused later to
# reconstruct the split heads.
mfsplit = Mf6Splitter(sim)

In [None]:
# Build the split simulation; presumably one model per distinct mask value.
new_sim = mfsplit.split_model(mask)

In [None]:
# Remove output from any previous run (errors ignored, e.g. when the directory
# does not exist yet), then point the split simulation at new_ws.
shutil.rmtree(new_ws, ignore_errors=True)
new_sim.set_sim_path(new_ws)

### Plot the submodels

In [None]:
# Overlay the top elevation of every sub-model on one map to show how the base
# grid was partitioned. figsize, Lx, Ly, extent, vmin and vmax are not defined
# in this notebook; presumably they come from the star-imported defaults module.
with styles.USGSMap():
    fig = plt.figure(figsize=figsize)
    ax = fig.add_subplot()
    ax.set_xlim(0, Lx)
    ax.set_ylim(0, Ly)
    ax.set_aspect("equal")
    for name in new_sim.model_names:
        sub_model = new_sim.get_model(name)
        pmv = flopy.plot.PlotMapView(
            modelgrid=sub_model.modelgrid, ax=ax, extent=extent
        )
        pmv.plot_array(sub_model.dis.top.array, vmin=vmin, vmax=vmax)
        pmv.plot_inactive()

### Write the split simulation

In [None]:
# Write all input files of the split simulation into new_ws.
new_sim.write_simulation()

### Write the PETSc rc file

In [None]:
# Write the PETSc options file for an nproc-process run into new_ws.
# NOTE(review): write_petscdb is not defined here; presumably it comes from the
# star-imported defaults module.
write_petscdb(new_ws, nproc)

### Run the model in parallel

In [None]:
# Use the parallel MODFLOW 6 executable; assumes "mf6p" is found on PATH.
new_sim.exe_name = "mf6p"

In [None]:
# Run the split simulation on nproc processes. Stop here if the run fails, so
# the later cells do not try to read missing or stale head files.
success, buff = new_sim.run_simulation(processors=nproc)
if not success:
    raise RuntimeError(
        f"Parallel simulation in {new_ws} did not run successfully"
    )

### Plot the multi-model and single-model heads

In [None]:
# Heads from the original single-model run, used as the reference solution.
gwf_base = sim.get_model()
gwfhead_tot = gwf_base.output.head().get_data()
# Exclude MODFLOW 6 no-value sentinels from the colour limits: +1e30 for
# inactive cells and, presumably, -1e30 for dry cells. Otherwise a single
# sentinel would stretch the scale.
active_heads = gwfhead_tot[np.abs(gwfhead_tot) < 1e30]
hmin, hmax = active_heads.min(), active_heads.max()
contours = np.arange(0, 100, 10)

#### Build a dictionary with the model heads for each partition

In [None]:
# Map each split model's 1-based number to its simulated heads, following the
# order of new_sim.model_names (assumed to match the mask numbering).
model_heads = {}
for model_number, model_name in enumerate(new_sim.model_names, start=1):
    split_model = new_sim.get_model(model_name)
    model_heads[model_number] = split_model.output.head().get_data()

#### Build a single head array

In [None]:
# Stitch the per-model heads back into one array so it can be compared
# cell-by-cell with the single-model heads.
head_tot = mfsplit.reconstruct_array(model_heads)

#### Plot the results

In [None]:
# Plot the split (multi-model) heads, the single-model heads and their
# difference. The colour limits use local names so the module-level vmin/vmax
# used by the sub-model map are not overwritten.
with styles.USGSMap():
    fig = plt.figure(figsize=(figwidth, figheight * 1.3333))
    head_diff = head_tot - gwfhead_tot
    # (title, array to plot, whether it is a head field)
    panels = (
        ("Multiple models", head_tot, True),
        ("Single model", gwfhead_tot, True),
        ("Multiple - single", head_diff, False),
    )
    for iplot, (title, values, is_head) in enumerate(panels, start=1):
        ax = fig.add_subplot(3, 1, iplot)
        ax.set_aspect("equal")
        ax.set_title(title)

        # The head panels share one colour scale and contour levels; the
        # difference panel is auto-scaled and not contoured.
        if is_head:
            plot_vmin, plot_vmax = hmin, hmax
        else:
            plot_vmin = plot_vmax = None

        pmv = flopy.plot.PlotMapView(model=gwf_base, ax=ax, layer=0)
        h = pmv.plot_array(values, vmin=plot_vmin, vmax=plot_vmax)
        if is_head:
            c = pmv.contour_array(
                values,
                levels=contours,
                colors="white",
                linewidths=0.75,
                linestyles=":",
            )
            ax.clabel(c, fontsize=8)
        pmv.plot_inactive()
        plt.colorbar(h, ax=ax, shrink=0.5)