diff --git a/docs/examples/quickstart.py b/docs/examples/quickstart.py index fc6a24cd..cba8044e 100644 --- a/docs/examples/quickstart.py +++ b/docs/examples/quickstart.py @@ -30,6 +30,7 @@ save_budget={"*": "all"}, ) +# sim.write() sim.run(verbose=True) # check CHD diff --git a/flopy4/mf6/codec/__init__.py b/flopy4/mf6/codec/__init__.py index b9d5f05b..83e279fc 100644 --- a/flopy4/mf6/codec/__init__.py +++ b/flopy4/mf6/codec/__init__.py @@ -11,6 +11,7 @@ from flopy4.mf6.codec.converter import ( structure_array, unstructure_array, + unstructure_chd, unstructure_component, unstructure_oc, unstructure_tdis, @@ -40,6 +41,7 @@ def _make_converter() -> Converter: from flopy4.mf6.component import Component + from flopy4.mf6.gwf.chd import Chd from flopy4.mf6.gwf.oc import Oc from flopy4.mf6.tdis import Tdis @@ -47,6 +49,7 @@ def _make_converter() -> Converter: converter.register_unstructure_hook_factory(xattree.has, lambda _: xattree.asdict) converter.register_unstructure_hook(Component, unstructure_component) converter.register_unstructure_hook(Tdis, unstructure_tdis) + converter.register_unstructure_hook(Chd, unstructure_chd) converter.register_unstructure_hook(Oc, unstructure_oc) return converter @@ -54,10 +57,6 @@ def _make_converter() -> Converter: _CONVERTER = _make_converter() -# TODO unstructure arrays into sparse dicts -# TODO combine OC fields into list input as defined in the MF6 dfn - - def loads(data: str) -> Any: # TODO pass diff --git a/flopy4/mf6/codec/converter.py b/flopy4/mf6/codec/converter.py index ee0ad632..8caf360b 100644 --- a/flopy4/mf6/codec/converter.py +++ b/flopy4/mf6/codec/converter.py @@ -3,6 +3,10 @@ import numpy as np import sparse import xattree +from flopy.discretization.grid import Grid +from flopy.discretization.structuredgrid import StructuredGrid +from flopy.discretization.unstructuredgrid import UnstructuredGrid +from flopy.discretization.vertexgrid import VertexGrid from numpy.typing import NDArray from xarray import DataArray from 
xattree import get_xatspec @@ -10,7 +14,7 @@ from flopy4.mf6.component import Component from flopy4.mf6.config import SPARSE_THRESHOLD from flopy4.mf6.constants import FILL_DNODATA -from flopy4.mf6.spec import get_blocks +from flopy4.mf6.spec import get_blocks, is_list_field # TODO: convert to a cattrs structuring hook so we don't have to @@ -87,20 +91,16 @@ def _get_nn(cellid): match len(shape): case 1: set_(a, period, kper) - # a[(kper,)] = period case _: for cellid, v in period.items(): nn = _get_nn(cellid) set_(a, v, kper, nn) - # a[(kper, nn)] = v if kper == "*": break else: for cellid, v in value.items(): nn = _get_nn(cellid) set_(a, v, nn) - # a[(nn,)] = v - return final(a) @@ -109,36 +109,47 @@ def unstructure_array(value: DataArray) -> dict: Convert a dense numpy array or a sparse COO array to a sparse dictionary representation suitable for serialization into the MF6 list-based input format. + + The input array must have a time dimension named 'nper', i.e. + it must be stress period data for some MODFLOW 6 component. 
+ + Returns: + dict: {kper: {spatial indices: value, ...}, ...} """ - # make sure dim 'kper' is present - time_dim = "nper" - if time_dim not in value.dims: + if (time_dim := "nper") not in value.dims: raise ValueError(f"Array must have dimension '{time_dim}'") - if isinstance(value.data, sparse.COO): coords = value.coords data = value.data else: - coords = np.array(np.nonzero(value.data)).T # type: ignore + coords = np.array(np.where(value.data != FILL_DNODATA)).T # type: ignore data = value.data[tuple(coords.T)] # type: ignore if not coords.size: # type: ignore return {} + result = {} match value.ndim: case 1: - return {int(k): v for k, v in zip(coords[:, 0], data)} # type: ignore - case 2: - return {(int(k), int(j)): v for (k, j), v in zip(coords, data)} # type: ignore - case 3: - return {(int(k), int(i), int(j)): v for (k, i, j), v in zip(coords, data)} # type: ignore - return {} + # Only kper, no spatial dims + for kper, v in zip(coords[:, 0], data): + result[int(kper)] = v + case _: + # kper + spatial dims + for row, v in zip(coords, data): + kper = int(row[0]) # type: ignore + spatial = tuple(int(x) for x in row[1:]) # type: ignore + if kper not in result: + result[kper] = {} + # flatten spatial index if only one spatial dim + key = spatial[0] if len(spatial) == 1 else spatial + result[kper][key] = v + return result def unstructure_component(value: Component) -> dict[str, Any]: data = xattree.asdict(value) for block in get_blocks(value.dfn).values(): for field_name, field in block.items(): - # unstructure arrays destined for list-based input - if field["type"] == "recarray" and field["reader"] != "readarray": + if is_list_field(field): data[field_name] = unstructure_array(data[field_name]) return data @@ -148,63 +159,119 @@ def unstructure_tdis(value: Any) -> dict[str, Any]: blocks = get_blocks(value.dfn) for block_name, block in blocks.items(): if block_name == "perioddata": - array_fields = list(block.keys()) - - # Unstructure all arrays and collect all 
unique periods - arrays = {} + arrs_d = {} periods = set() # type: ignore - for field_name in array_fields: - arr = unstructure_array(data.get(field_name, {})) - arrays[field_name] = arr - periods.update(arr.keys()) + for field_name in block.keys(): + arr = data.get(field_name, None) + arr_d = {} if arr is None else unstructure_array(arr) + arrs_d[field_name] = arr_d + periods.update(arr_d.keys()) periods = sorted(periods) # type: ignore - perioddata = {} # type: ignore for kper in periods: line = [] - for arr in arrays.values(): - if kper not in perioddata: - perioddata[kper] = [] # type: ignore - line.append(arr[kper]) + if kper not in perioddata: + perioddata[kper] = [] # type: ignore + for arr_d in arrs_d.values(): + if val := arr_d.get(kper, None): + line.append(val) perioddata[kper] = tuple(line) - data["perioddata"] = perioddata return data +def get_kij(nn: int, nlay: int, nrow: int, ncol: int) -> tuple[int, int, int]: + nodes = nlay * nrow * ncol + if nn < 0 or nn >= nodes: + raise ValueError(f"Node number {nn} is out of bounds (1 to {nodes})") + k = (nn - 1) / (ncol * nrow) + 1 + ij = nn - (k - 1) * ncol * nrow + i = (ij - 1) / ncol + 1 + j = ij - (i - 1) * ncol + return int(k), int(i), int(j) + + +def get_jk(nn: int, ncpl: int) -> tuple[int, int]: + if nn < 0 or nn >= ncpl: + raise ValueError(f"Node number {nn} is out of bounds (1 to {ncpl})") + k = (nn - 1) / ncpl + 1 + j = nn - (k - 1) * ncpl + return int(j), int(k) + + +def get_cellid(nn: int, grid: Grid) -> tuple[int, ...]: + match grid: + case StructuredGrid(): + return get_kij(nn, *grid.shape) + case VertexGrid(): + return get_jk(nn, grid.ncpl) + case UnstructuredGrid(): + return (nn,) + case _: + raise TypeError(f"Unsupported grid type: {type(grid)}") + + +def unstructure_chd(value: Any) -> dict[str, Any]: + if (parent := value.parent) is None: + raise ValueError( + "CHD cannot be unstructured without a parent " + "model and corresponding grid discretization." 
+ ) + grid = parent.grid + data = xattree.asdict(value) + blocks = get_blocks(value.dfn) + for block_name, block in blocks.items(): + if block_name == "period": + arrs_d = {} + periods = set() # type: ignore + for field_name in block.keys(): + arr = data.get(field_name, None) + arr_d = {} if arr is None else unstructure_array(arr) + arrs_d[field_name] = arr_d + periods.update(arr_d.keys()) + periods = sorted(periods) # type: ignore + perioddata = {} # type: ignore + for kper in periods: + line = [] + if kper not in perioddata: + perioddata[kper] = [] # type: ignore + for arr_d in arrs_d.values(): + if val := arr_d.get(kper, None): + for nn, v in val.items(): + cellid = get_cellid(nn, grid) + line.append((*cellid, v)) + perioddata[kper] = tuple(line) + data["period"] = perioddata + return data + + def unstructure_oc(value: Any) -> dict[str, Any]: data = xattree.asdict(value) blocks = get_blocks(value.dfn) for block_name, block in blocks.items(): if block_name == "period": - # Dynamically collect all recarray fields in perioddata block - array_fields = [] + fields = [] for field_name, field in block.items(): - # Try to split field_name into action and kind, e.g. 
save_head -> ("save", "head") action, rtype = field_name.split("_") - array_fields.append((action, rtype, field_name)) - - # Unstructure all arrays and collect all unique periods - arrays = {} + fields.append((action, rtype, field_name)) + arrs_d = {} periods = set() # type: ignore - for action, rtype, field_name in array_fields: - arr = unstructure_array(data.get(field_name, {})) - arrays[(action, rtype)] = arr - periods.update(arr.keys()) + for action, rtype, field_name in fields: + arr = data.get(field_name, None) + arr_d = {} if arr is None else unstructure_array(arr) + arrs_d[(action, rtype)] = arr_d + periods.update(arr_d.keys()) periods = sorted(periods) # type: ignore - perioddata = {} # type: ignore for kper in periods: - for (action, rtype), arr in arrays.items(): - if kper in arr: - if kper not in perioddata: - perioddata[kper] = [] - perioddata[kper].append((action, rtype, arr[kper])) - + if kper not in perioddata: + perioddata[kper] = [] + for (action, rtype), arr_d in arrs_d.items(): + if arr := arr_d.get(kper, None): + perioddata[kper].append((action, rtype, arr)) data["period"] = perioddata else: for field_name, field in block.items(): - # unstructure arrays destined for list-based input - if field["type"] == "recarray" and field["reader"] != "readarray": + if is_list_field(field): data[field_name] = unstructure_array(data[field_name]) return data diff --git a/flopy4/mf6/filters.py b/flopy4/mf6/filters.py index 4b7d3e22..da0a0f2d 100644 --- a/flopy4/mf6/filters.py +++ b/flopy4/mf6/filters.py @@ -7,15 +7,7 @@ from modflow_devtools.dfn import Dfn, Field from numpy.typing import NDArray -from flopy4.mf6.spec import get_blocks - - -def _is_list_block(block: dict) -> bool: - return ( - len(block) == 1 - and (field := next(iter(block.values())))["type"] == "recarray" - and field["reader"] != "readarray" - ) or (all(f["type"] == "recarray" and f["reader"] != "readarray" for f in block.values())) +from flopy4.mf6.spec import get_blocks, is_list_block def 
dict_blocks(dfn: Dfn) -> dict: @@ -28,13 +20,13 @@ def dict_blocks(dfn: Dfn) -> dict: return { block_name: block for block_name, block in get_blocks(dfn).items() - if not _is_list_block(block) + if not is_list_block(block) } def list_blocks(dfn: Dfn) -> dict: return { - block_name: block for block_name, block in get_blocks(dfn).items() if _is_list_block(block) + block_name: block for block_name, block in get_blocks(dfn).items() if is_list_block(block) } diff --git a/flopy4/mf6/gwf/__init__.py b/flopy4/mf6/gwf/__init__.py index 6a417e1e..ca2b515b 100644 --- a/flopy4/mf6/gwf/__init__.py +++ b/flopy4/mf6/gwf/__init__.py @@ -37,16 +37,16 @@ class Output: def head(self) -> xr.DataArray: # TODO support other extensions than .hds (e.g. .hed) return open_hds( - self.parent.parent.path / f"{self.parent.name}.hds", # type: ignore - self.parent.parent.path / f"{self.parent.name}.dis.grb", # type: ignore + self.parent.parent.workspace / f"{self.parent.name}.hds", # type: ignore + self.parent.parent.workspace / f"{self.parent.name}.dis.grb", # type: ignore ) @property def budget(self): # TODO support other extensions than .bud (e.g. 
.cbc) return open_cbc( - self.parent.parent.path / f"{self.parent.name}.bud", - self.parent.parent.path / f"{self.parent.name}.dis.grb", + self.parent.parent.workspace / f"{self.parent.name}.bud", + self.parent.parent.workspace / f"{self.parent.name}.dis.grb", ) dis: Dis = field(converter=convert_grid) diff --git a/flopy4/mf6/gwf/chd.py b/flopy4/mf6/gwf/chd.py index 9a8c90f4..d61a4307 100644 --- a/flopy4/mf6/gwf/chd.py +++ b/flopy4/mf6/gwf/chd.py @@ -7,6 +7,7 @@ from xattree import xattree from flopy4.mf6.codec import structure_array +from flopy4.mf6.constants import FILL_DNODATA from flopy4.mf6.package import Package from flopy4.mf6.spec import array, field @@ -41,6 +42,7 @@ class Steps: ), default=None, converter=Converter(structure_array, takes_self=True, takes_field=True), + reader="urword", ) aux: Optional[NDArray[np.floating]] = array( block="period", @@ -50,6 +52,7 @@ class Steps: ), default=None, converter=Converter(structure_array, takes_self=True, takes_field=True), + reader="urword", ) boundname: Optional[NDArray[np.str_]] = array( block="period", @@ -59,6 +62,7 @@ class Steps: ), default=None, converter=Converter(structure_array, takes_self=True, takes_field=True), + reader="urword", ) steps: Optional[NDArray[np.object_]] = array( Steps, @@ -66,4 +70,17 @@ class Steps: dims=("nper", "nnodes"), default=None, converter=Converter(structure_array, takes_self=True, takes_field=True), + reader="urword", ) + + def __attrs_post_init__(self): + # TODO set up on_setattr hooks for period block + # arrays to update maxbound? for now do it here + # in post init. but this only works when values + # are set in the initializer, not when they are + # set later. 
+ maxhead = len(np.where(self.head != FILL_DNODATA)) if self.head is not None else 0 + maxaux = len(np.where(self.aux != FILL_DNODATA)) if self.aux is not None else 0 + maxboundname = len(np.where(self.boundname != "")) if self.boundname is not None else 0 + # maxsteps = len(np.where(self.steps != None)) if self.steps is not None else 0 + self.maxbound = max(maxhead, maxaux, maxboundname) diff --git a/flopy4/mf6/spec.py b/flopy4/mf6/spec.py index 4a56fe1b..79cb9d1f 100644 --- a/flopy4/mf6/spec.py +++ b/flopy4/mf6/spec.py @@ -5,6 +5,8 @@ import builtins import types +from datetime import datetime +from pathlib import Path from typing import Union, get_args, get_origin import numpy as np @@ -176,6 +178,39 @@ def fields_dict(cls) -> dict[str, Attribute]: return {k: v for k, v in fields.items() if "block" in v.metadata} +def to_dfn_field_type(t: type) -> FieldType: + match t: + case builtins.str | np.str_: + return "string" + case builtins.bool | np.bool: + return "keyword" + case builtins.int | np.integer: + return "integer" # type: ignore + case builtins.float | np.floating: + return "double precision" # type: ignore + case t if t is Path or t is datetime: + return "string" + case t if get_origin(t) in (Union, types.UnionType): + args = get_args(t) + if args[-1] is types.NoneType: + match args[0]: + case builtins.str | np.str_: + return "string" + case builtins.bool | np.bool: + return "keyword" + case builtins.int | np.integer: + return "integer" + case builtins.float | np.floating: + return "double precision" + case tt if tt is Path or tt is datetime: + return "string" + case _: + return "record" + return "keystring" + case _: + return "record" + + def get_dfn_field_type(attribute: Attribute) -> FieldType: """ Get a `xattree` field's type as defined by the MODFLOW 6 input @@ -197,22 +232,9 @@ def get_dfn_field_type(attribute: Attribute) -> FieldType: case "dim": return "integer" case "attr": - match attribute.type: - case builtins.str | np.str_: - return "string" - 
case builtins.bool | np.bool: - return "keyword" - case builtins.int | np.integer: - return "integer" - case builtins.float | np.floating: - return "double precision" - - case t if ( - get_origin(t) in (Union, types.UnionType) and get_args(t)[-1] is types.NoneType - ): - return "union" - case _: - return "record" + if (t := attribute.type) is None: + raise ValueError(f"Attribute {attribute.name} in {attribute.name} has no type.") + return to_dfn_field_type(t) raise ValueError(f"Could not map {attribute.name} to a valid MF6 type.") @@ -248,3 +270,20 @@ def get_blocks(dfn: Dfn) -> dict: key=block_sort_key, ) ) + + +def is_list_field(field: Field) -> bool: + """ + Check if a field is a list field, which is a recarray + field that uses list input. This is determined by the + reader not being "readarray" and the type being "recarray". + """ + return field["type"] == "recarray" and field["reader"] != "readarray" + + +def is_list_block(block: dict) -> bool: + return ( + len(block) == 1 + and (field := next(iter(block.values())))["type"] == "recarray" + and field["reader"] != "readarray" + ) or (all(f["type"] == "recarray" and f["reader"] != "readarray" for f in block.values())) diff --git a/flopy4/mf6/templates/blocks.jinja b/flopy4/mf6/templates/blocks.jinja index d15be65b..ef534445 100644 --- a/flopy4/mf6/templates/blocks.jinja +++ b/flopy4/mf6/templates/blocks.jinja @@ -9,5 +9,5 @@ END {{ block_name.upper() }} {% endfor %} {% for block_name, block_ in (dfn|list_blocks).items() -%} -{{ macros.list(block_name, block_, stress=block_name == "period") }} +{{ macros.list(block_name, block_, single=block_name != "period") }} {%- endfor%} diff --git a/flopy4/mf6/templates/macros.jinja b/flopy4/mf6/templates/macros.jinja index b1dc3f28..b7ea113d 100644 --- a/flopy4/mf6/templates/macros.jinja +++ b/flopy4/mf6/templates/macros.jinja @@ -50,36 +50,43 @@ OPEN/CLOSE {{ value }} {% endif %} {% endmacro %} -{% macro list(block_name, block, stress=False) %} +{% macro list(block_name, 
block, single=False) %} {# from mf6's perspective, a list block (e.g. period data) always has just one variable, whose elements might be -records or unions. where we spin those out into arrays -for each individual leaf field to fit the xarray data -model, we have to combine them back here. - -this macro receives the block definition. from that -it looks up the value of the one variable with the -same name as the block, which custom converter has -made sure exists in a sparse dict representation of -an array. we need to expand this into a block for -each stress period. +records or unions. if records, each record field is a +column in a table, and we spin each column out into a +separate array. if unions, we have a single array of +some record which we have invented to represent the +union in a nicer way. in either case we've done some +munging in conversion to turn the (internal) array +into a sparse dict representation where outermost +keys are stress periods, and stored that as a new +variable with the same name as the block. this macro +receives the block definition, and from that it looks +up the value of this variable. if the 'single' param +is True, we assume the dict's values are tuples to +become rows in a single block, e.g. tdis perioddata. +if False, we assume the dict's values are themselves +dicts, where the keys are cell ids, and the values +are tuples, and each entry becomes a row in a block +indexed by stress period. hope that all made sense.. 
#} {% set d = data[block_name] %} -{% if stress %} +{% if single %} +BEGIN {{ block_name.upper() }} +{% for line in d.values() %} +{{ line|join(" ")|upper }} +{% endfor %} +END {{ block_name.upper() }} +{% else %} {% for kper, value in d.items() %} -BEGIN {{ block_name.upper() }} {{ kper }} +BEGIN {{ block_name.upper() }} {{ kper + 1 }} {% for line in value %} {{ line|join(" ")|upper }} {% endfor %} -END {{ block_name.upper() }} {{ kper }} +END {{ block_name.upper() }} {{ kper + 1 }} {% endfor %} -{% else %} -BEGIN {{ block_name.upper() }} -{% for line in d.values() %} -{{ line|join(" ")|upper }} -{% endfor %} -END {{ block_name.upper() }} {% endif %} {% endmacro %} diff --git a/test/test_codec.py b/test/test_codec.py index 5c43840b..427b18ff 100644 --- a/test/test_codec.py +++ b/test/test_codec.py @@ -13,9 +13,7 @@ def test_dumps_ic(): ) result = dumps(ic) - print() print(result) - print() assert result @@ -51,3 +49,39 @@ def test_dumps_dis(): result = dumps(dis) print(result) assert result + + +def test_dumps_tdis(): + from flopy.discretization.modeltime import ModelTime + + from flopy4.mf6.tdis import Tdis + + tdis = Tdis.from_time(ModelTime(perlen=[1.0, 2.0], nstp=[1, 2])) + tdis.time_units = "days" + + result = dumps(tdis) + print(result) + assert result + + +def test_dumps_chd(): + from flopy4.mf6.gwf import Chd, Dis, Gwf + + dis = Dis(nrow=10, ncol=10) + gwf = Gwf(dis=dis) + chd = Chd( + parent=gwf, + head={ + 0: { + (0, 0, 0): 10.0, + (0, 9, 9): 20.0, + } + }, + save_flows=True, + print_input=True, + dims={"nper": 1}, + ) + + result = dumps(chd) + print(result) + assert result