Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/examples/quickstart.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
save_budget={"*": "all"},
)

# sim.write()
sim.run(verbose=True)

# check CHD
Expand Down
7 changes: 3 additions & 4 deletions flopy4/mf6/codec/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from flopy4.mf6.codec.converter import (
structure_array,
unstructure_array,
unstructure_chd,
unstructure_component,
unstructure_oc,
unstructure_tdis,
Expand Down Expand Up @@ -40,24 +41,22 @@

def _make_converter() -> Converter:
    """Build the cattrs ``Converter`` used to unstructure MF6 components.

    A generic fallback hook (xattree's ``asdict``) is registered first for
    any xattree-decorated type; more specific hooks are registered after it
    so they take precedence for their own classes.
    """
    from flopy4.mf6.component import Component
    from flopy4.mf6.gwf.chd import Chd
    from flopy4.mf6.gwf.oc import Oc
    from flopy4.mf6.tdis import Tdis

    conv = Converter()
    conv.register_unstructure_hook_factory(xattree.has, lambda _: xattree.asdict)
    # Registration order preserved: later/specific hooks win for their class.
    for cls, hook in (
        (Component, unstructure_component),
        (Tdis, unstructure_tdis),
        (Chd, unstructure_chd),
        (Oc, unstructure_oc),
    ):
        conv.register_unstructure_hook(cls, hook)
    return conv


_CONVERTER = _make_converter()


# TODO unstructure arrays into sparse dicts
# TODO combine OC fields into list input as defined in the MF6 dfn


def loads(data: str) -> Any:
    """Deserialize MF6 input text into component data.

    Not yet implemented; this stub currently returns ``None``.
    """
    # TODO
    pass
Expand Down
171 changes: 119 additions & 52 deletions flopy4/mf6/codec/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,18 @@
import numpy as np
import sparse
import xattree
from flopy.discretization.grid import Grid
from flopy.discretization.structuredgrid import StructuredGrid
from flopy.discretization.unstructuredgrid import UnstructuredGrid
from flopy.discretization.vertexgrid import VertexGrid
from numpy.typing import NDArray
from xarray import DataArray
from xattree import get_xatspec

from flopy4.mf6.component import Component
from flopy4.mf6.config import SPARSE_THRESHOLD
from flopy4.mf6.constants import FILL_DNODATA
from flopy4.mf6.spec import get_blocks
from flopy4.mf6.spec import get_blocks, is_list_field


# TODO: convert to a cattrs structuring hook so we don't have to
Expand Down Expand Up @@ -87,20 +91,16 @@ def _get_nn(cellid):
match len(shape):
case 1:
set_(a, period, kper)
# a[(kper,)] = period
case _:
for cellid, v in period.items():
nn = _get_nn(cellid)
set_(a, v, kper, nn)
# a[(kper, nn)] = v
if kper == "*":
break
else:
for cellid, v in value.items():
nn = _get_nn(cellid)
set_(a, v, nn)
# a[(nn,)] = v

return final(a)


Expand All @@ -109,36 +109,47 @@ def unstructure_array(value: DataArray) -> dict:
Convert a dense numpy array or a sparse COO array to a sparse
dictionary representation suitable for serialization into the
MF6 list-based input format.

The input array must have a time dimension named 'nper', i.e.
it must be stress period data for some MODFLOW 6 component.

Returns:
dict: {kper: {spatial indices: value, ...}, ...}
"""
# make sure dim 'kper' is present
time_dim = "nper"
if time_dim not in value.dims:
if (time_dim := "nper") not in value.dims:
raise ValueError(f"Array must have dimension '{time_dim}'")

if isinstance(value.data, sparse.COO):
coords = value.coords
data = value.data
else:
coords = np.array(np.nonzero(value.data)).T # type: ignore
coords = np.array(np.where(value.data != FILL_DNODATA)).T # type: ignore
data = value.data[tuple(coords.T)] # type: ignore
if not coords.size: # type: ignore
return {}
result = {}
match value.ndim:
case 1:
return {int(k): v for k, v in zip(coords[:, 0], data)} # type: ignore
case 2:
return {(int(k), int(j)): v for (k, j), v in zip(coords, data)} # type: ignore
case 3:
return {(int(k), int(i), int(j)): v for (k, i, j), v in zip(coords, data)} # type: ignore
return {}
# Only kper, no spatial dims
for kper, v in zip(coords[:, 0], data):
result[int(kper)] = v
case _:
# kper + spatial dims
for row, v in zip(coords, data):
kper = int(row[0]) # type: ignore
spatial = tuple(int(x) for x in row[1:]) # type: ignore
if kper not in result:
result[kper] = {}
# flatten spatial index if only one spatial dim
key = spatial[0] if len(spatial) == 1 else spatial
result[kper][key] = v
return result


def unstructure_component(value: Component) -> dict[str, Any]:
    """Unstructure a generic MF6 component into a plain dict.

    Fields destined for MF6 list-based input are converted from arrays to
    sparse period dictionaries via ``unstructure_array``; everything else is
    left as produced by ``xattree.asdict``.
    """
    # NOTE(review): the pasted diff contained both the old inline recarray
    # check and the new `is_list_field` call; the new side is kept here.
    data = xattree.asdict(value)
    for block in get_blocks(value.dfn).values():
        for field_name, field in block.items():
            # unstructure arrays destined for list-based input
            if is_list_field(field):
                data[field_name] = unstructure_array(data[field_name])
    return data

Expand All @@ -148,63 +159,119 @@ def unstructure_tdis(value: Any) -> dict[str, Any]:
blocks = get_blocks(value.dfn)
for block_name, block in blocks.items():
if block_name == "perioddata":
array_fields = list(block.keys())

# Unstructure all arrays and collect all unique periods
arrays = {}
arrs_d = {}
periods = set() # type: ignore
for field_name in array_fields:
arr = unstructure_array(data.get(field_name, {}))
arrays[field_name] = arr
periods.update(arr.keys())
for field_name in block.keys():
arr = data.get(field_name, None)
arr_d = {} if arr is None else unstructure_array(arr)
arrs_d[field_name] = arr_d
periods.update(arr_d.keys())
periods = sorted(periods) # type: ignore

perioddata = {} # type: ignore
for kper in periods:
line = []
for arr in arrays.values():
if kper not in perioddata:
perioddata[kper] = [] # type: ignore
line.append(arr[kper])
if kper not in perioddata:
perioddata[kper] = [] # type: ignore
for arr_d in arrs_d.values():
if val := arr_d.get(kper, None):
line.append(val)
perioddata[kper] = tuple(line)

data["perioddata"] = perioddata
return data


def get_kij(nn: int, nlay: int, nrow: int, ncol: int) -> tuple[int, int, int]:
    """Convert a 1-based node number to a 1-based (layer, row, column) index.

    Parameters
    ----------
    nn : int
        1-based node number in natural (layer-major) ordering.
    nlay, nrow, ncol : int
        Structured grid dimensions.

    Returns
    -------
    tuple[int, int, int]
        1-based (k, i, j) cell index.

    Raises
    ------
    ValueError
        If ``nn`` is outside ``1..nlay*nrow*ncol``.
    """
    nodes = nlay * nrow * ncol
    # The decomposition below assumes 1-based node numbers (matching the
    # error message), so the valid range is 1..nodes inclusive.
    if nn < 1 or nn > nodes:
        raise ValueError(f"Node number {nn} is out of bounds (1 to {nodes})")
    # Integer division is required here: true division (`/`) yields a float
    # k, which then corrupts the ij/i/j arithmetic for mid-layer nodes.
    k = (nn - 1) // (ncol * nrow) + 1
    ij = nn - (k - 1) * ncol * nrow
    i = (ij - 1) // ncol + 1
    j = ij - (i - 1) * ncol
    return k, i, j


def get_jk(nn: int, ncpl: int) -> tuple[int, int]:
    """Convert a 1-based node number to a 1-based (cell2d, layer) pair.

    Parameters
    ----------
    nn : int
        1-based node number in natural ordering.
    ncpl : int
        Number of cells per layer (vertex/DISV grid).

    Returns
    -------
    tuple[int, int]
        1-based (j, k) = (cell2d, layer) index.

    Raises
    ------
    ValueError
        If ``nn`` is outside ``1..ncpl``.

    NOTE(review): MF6 DISV cellids are written as (layer, cell2d); this
    returns (cell2d, layer) — confirm the expected order at the call site.
    The upper bound also assumes a single layer; multi-layer grids would
    need ``nlay`` to validate properly — TODO confirm.
    """
    # 1-based contract (matching the error message): valid range 1..ncpl.
    if nn < 1 or nn > ncpl:
        raise ValueError(f"Node number {nn} is out of bounds (1 to {ncpl})")
    # Integer division required; true division yields a float k that would
    # corrupt the j arithmetic.
    k = (nn - 1) // ncpl + 1
    j = nn - (k - 1) * ncpl
    return j, k


def get_cellid(nn: int, grid: Grid) -> tuple[int, ...]:
    """Map a node number onto the cellid tuple appropriate for *grid*.

    Structured grids get (k, i, j), vertex grids get the (j, k) pair from
    ``get_jk``, and unstructured grids use the node number directly.

    Raises
    ------
    TypeError
        If the grid is not a recognized discretization type.
    """
    if isinstance(grid, StructuredGrid):
        return get_kij(nn, *grid.shape)
    if isinstance(grid, VertexGrid):
        return get_jk(nn, grid.ncpl)
    if isinstance(grid, UnstructuredGrid):
        return (nn,)
    raise TypeError(f"Unsupported grid type: {type(grid)}")


def unstructure_chd(value: Any) -> dict[str, Any]:
    """Unstructure a CHD package into a dict whose ``period`` block is a
    mapping of stress period -> tuple of (cellid..., value) records.

    The parent model's grid is needed to translate flat node numbers back
    into grid-appropriate cellids.

    Raises
    ------
    ValueError
        If the package has no parent model (and hence no grid).
    """
    if (parent := value.parent) is None:
        raise ValueError(
            "CHD cannot be unstructured without a parent "
            "model and corresponding grid discretization."
        )
    grid = parent.grid
    data = xattree.asdict(value)
    for block_name, block in get_blocks(value.dfn).items():
        if block_name != "period":
            continue
        # Unstructure every field's array and collect the union of stress
        # periods that have any data.
        arrs_d = {}
        periods = set()  # type: ignore
        for field_name in block.keys():
            arr = data.get(field_name, None)
            arr_d = {} if arr is None else unstructure_array(arr)
            arrs_d[field_name] = arr_d
            periods.update(arr_d.keys())
        perioddata = {}  # type: ignore
        for kper in sorted(periods):
            # Build one record per active node: (*cellid, value).
            # (The original also pre-seeded perioddata[kper] with a list
            # that was immediately overwritten — dead code, removed.)
            line = []
            for arr_d in arrs_d.values():
                val = arr_d.get(kper)
                if val:  # skip periods with no data for this field
                    for nn, v in val.items():
                        cellid = get_cellid(nn, grid)
                        line.append((*cellid, v))
            perioddata[kper] = tuple(line)
        data["period"] = perioddata
    return data


def unstructure_oc(value: Any) -> dict[str, Any]:
    """Unstructure an OC (output control) package into a dict whose
    ``period`` block maps stress period -> list of (action, rtype, data)
    records, e.g. ("save", "head", ...).

    NOTE(review): the pasted diff interleaved old (`array_fields`/`arrays`)
    and new (`fields`/`arrs_d`) lines; the new side is reconstructed here.
    """
    data = xattree.asdict(value)
    blocks = get_blocks(value.dfn)
    for block_name, block in blocks.items():
        if block_name == "period":
            # Each period-block field name encodes an action and a record
            # type, e.g. save_head -> ("save", "head"). maxsplit=1 keeps
            # this safe should a record type itself contain an underscore.
            fields = []
            for field_name in block.keys():
                action, rtype = field_name.split("_", 1)
                fields.append((action, rtype, field_name))

            # Unstructure every field's array and collect the union of
            # stress periods that have any data.
            arrs_d = {}
            periods = set()  # type: ignore
            for action, rtype, field_name in fields:
                arr = data.get(field_name, None)
                arr_d = {} if arr is None else unstructure_array(arr)
                arrs_d[(action, rtype)] = arr_d
                periods.update(arr_d.keys())

            perioddata = {}  # type: ignore
            for kper in sorted(periods):
                perioddata.setdefault(kper, [])
                for (action, rtype), arr_d in arrs_d.items():
                    val = arr_d.get(kper)
                    if val:  # skip fields with no data for this period
                        perioddata[kper].append((action, rtype, val))
            data["period"] = perioddata
        else:
            for field_name, field in block.items():
                # unstructure arrays destined for list-based input
                if is_list_field(field):
                    data[field_name] = unstructure_array(data[field_name])
    return data
14 changes: 3 additions & 11 deletions flopy4/mf6/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,7 @@
from modflow_devtools.dfn import Dfn, Field
from numpy.typing import NDArray

from flopy4.mf6.spec import get_blocks


def _is_list_block(block: dict) -> bool:
return (
len(block) == 1
and (field := next(iter(block.values())))["type"] == "recarray"
and field["reader"] != "readarray"
) or (all(f["type"] == "recarray" and f["reader"] != "readarray" for f in block.values()))
from flopy4.mf6.spec import get_blocks, is_list_block


def dict_blocks(dfn: Dfn) -> dict:
Expand All @@ -28,13 +20,13 @@ def dict_blocks(dfn: Dfn) -> dict:
return {
block_name: block
for block_name, block in get_blocks(dfn).items()
if not _is_list_block(block)
if not is_list_block(block)
}


def list_blocks(dfn: Dfn) -> dict:
    """Return the blocks of *dfn* that are written as MF6 list-based input.

    NOTE(review): the pasted diff showed both the old `_is_list_block` line
    and the new `is_list_block` line; the new side is kept here.
    """
    return {
        block_name: block
        for block_name, block in get_blocks(dfn).items()
        if is_list_block(block)
    }


Expand Down
8 changes: 4 additions & 4 deletions flopy4/mf6/gwf/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,16 +37,16 @@ class Output:
def head(self) -> xr.DataArray:
# TODO support other extensions than .hds (e.g. .hed)
return open_hds(
self.parent.parent.path / f"{self.parent.name}.hds", # type: ignore
self.parent.parent.path / f"{self.parent.name}.dis.grb", # type: ignore
self.parent.parent.workspace / f"{self.parent.name}.hds", # type: ignore
self.parent.parent.workspace / f"{self.parent.name}.dis.grb", # type: ignore
)

@property
def budget(self):
# TODO support other extensions than .bud (e.g. .cbc)
return open_cbc(
self.parent.parent.path / f"{self.parent.name}.bud",
self.parent.parent.path / f"{self.parent.name}.dis.grb",
self.parent.parent.workspace / f"{self.parent.name}.bud",
self.parent.parent.workspace / f"{self.parent.name}.dis.grb",
)

dis: Dis = field(converter=convert_grid)
Expand Down
Loading
Loading