diff --git a/flopy4/mf6/codec/writer/filters.py b/flopy4/mf6/codec/writer/filters.py index 25c89e88..2719379a 100644 --- a/flopy4/mf6/codec/writer/filters.py +++ b/flopy4/mf6/codec/writer/filters.py @@ -201,11 +201,15 @@ def dataset2list(value: xr.Dataset): if value is None or not any(value.data_vars): return - first = next(iter(value.data_vars.values())) - is_union = first.dtype.type is np.str_ + # special case OC for now. + is_oc = all( + str(v.name).startswith("save_") or str(v.name).startswith("print_") + for v in value.data_vars.values() + ) - if first.ndim == 0: # handle scalar - if is_union: + # handle scalar + if (first := next(iter(value.data_vars.values()))).ndim == 0: + if is_oc: for name in value.data_vars.keys(): val = value[name] val = val.item() if val.shape == () else val @@ -230,7 +234,7 @@ def dataset2list(value: xr.Dataset): has_spatial_dims = len(spatial_dims) > 0 indices = np.where(combined_mask) for i in range(len(indices[0])): - if is_union: + if is_oc: for name in value.data_vars.keys(): val = value[name][tuple(idx[i] for idx in indices)] val = val.item() if val.shape == () else val diff --git a/flopy4/mf6/codec/writer/templates/macros.jinja b/flopy4/mf6/codec/writer/templates/macros.jinja index 29c1369a..45069913 100644 --- a/flopy4/mf6/codec/writer/templates/macros.jinja +++ b/flopy4/mf6/codec/writer/templates/macros.jinja @@ -32,10 +32,10 @@ {{ inset ~ name.upper() }}{% if "layered" in how %} LAYERED{% endif %} {% if how == "constant" %} -CONSTANT {{ value|array2const }} +{{ inset }}CONSTANT {{ value|array2const }} {% elif how == "layered constant" %} {% for layer in value -%} -CONSTANT {{ layer|array2const }} +{{ inset }}CONSTANT {{ layer|array2const }} {%- endfor %} {% elif how == "internal" %} INTERNAL diff --git a/flopy4/mf6/component.py b/flopy4/mf6/component.py index 8a748965..09ef9c1b 100644 --- a/flopy4/mf6/component.py +++ b/flopy4/mf6/component.py @@ -3,69 +3,17 @@ from pathlib import Path from typing import Any, ClassVar -import numpy as np from attrs import fields from modflow_devtools.dfn import Dfn, Field from packaging.version import Version from xattree import asdict as xattree_asdict from xattree import xattree -from flopy4.mf6.constants import FILL_DNODATA, MF6 +from flopy4.mf6.constants import MF6 from flopy4.mf6.spec import field, fields_dict, to_field +from flopy4.mf6.utils.grid_utils import update_maxbound from flopy4.uio import IO, Loader, Writer - -def update_maxbound(instance, attribute, new_value): - """ - Generalized function to update maxbound when period block arrays change. - - This function automatically finds all period block arrays in the instance - and calculates maxbound based on the maximum number of non-default values - across all arrays. - - Args: - instance: The package instance - attribute: The attribute being set (from attrs on_setattr) - new_value: The new value being set - - Returns: - The new_value (unchanged) - """ - - period_arrays = [] - instance_fields = fields(instance.__class__) - for f in instance_fields: - if ( - f.metadata - and f.metadata.get("block") == "period" - and f.metadata.get("xattree", {}).get("dims") - ): - period_arrays.append(f.name) - - maxbound_values = [] - for array_name in period_arrays: - if attribute and attribute.name == array_name: - array_val = new_value - else: - array_val = getattr(instance, array_name, None) - - if array_val is not None: - array_data = ( - array_val if array_val.data.shape == array_val.shape else array_val.todense() - ) - - if array_data.dtype.kind in ["U", "S"]: # String arrays - non_default_count = len(np.where(array_data != "")[0]) - else: # Numeric arrays - non_default_count = len(np.where(array_data != FILL_DNODATA)[0]) - - maxbound_values.append(non_default_count) - if maxbound_values: - instance.maxbound = max(maxbound_values) - - return new_value - - COMPONENTS = {} """MF6 component registry.""" @@ -86,11 +34,14 @@ class Component(ABC, MutableMapping): _write = IO(Writer) # type: ignore dfn: ClassVar[Dfn] + """The component's definition (i.e. specification).""" + filename: str | None = field(default=None) + """The name of the component's input file.""" @property def path(self) -> Path: - """Get the path to the component's input file.""" + """The path to the component's input file.""" self.filename = self.filename or self.default_filename() return Path.cwd() / self.filename @@ -202,18 +153,45 @@ def write(self, format: str = MF6) -> None: for child in self.children.values(): # type: ignore child.write(format=format) - def to_dict(self, blocks: bool = False) -> dict[str, Any]: - """Convert the component to a dictionary representation.""" + def to_dict(self, blocks: bool = False, strict: bool = False) -> dict[str, Any]: + """ + Convert the component to a dictionary representation. + + Parameters + ---------- + blocks : bool, optional + If True, return a nested dict keyed by block name + with values as dicts of fields. Default is False. + strict : bool, optional + If True, include only fields in the DFN specification. + + Returns + ------- + dict[str, Any] + Dictionary containing component data, either + in terms of fields (flat) or blocks (nested). + """ data = xattree_asdict(self) - data.pop("filename") - data.pop("workspace", None) # might be a Context - data.pop("nodes", None) # TODO: find a better way to omit + spec = self.dfn.fields + + if strict: + data.pop("filename") + data.pop("workspace", None) # might be a Context + if blocks: blocks_ = {} # type: ignore - for field_name, field_value in data.items(): - block_name = self.dfn.fields[field_name].block + for field_name in spec.keys(): + field_value = data[field_name] + block_name = spec[field_name].block + if strict and block_name is None: + continue if block_name not in blocks_: blocks_[block_name] = {} blocks_[block_name][field_name] = field_value return blocks_ - return data + else: + return { + field_name: data[field_name] + for field_name in spec.keys() + if spec[field_name].block or not strict + } diff --git a/flopy4/mf6/converter.py b/flopy4/mf6/converter.py index d0bb79a4..ee16db5f 100644 --- a/flopy4/mf6/converter.py +++ b/flopy4/mf6/converter.py @@ -31,7 +31,7 @@ def path_to_tuple(name: str, value: Path, inout: FileInOut) -> tuple[str, ...]: return tuple(t) -def get_binding_blocks(value: Component) -> dict[str, dict[str, list[tuple[str, ...]]]]: +def make_binding_blocks(value: Component) -> dict[str, dict[str, list[tuple[str, ...]]]]: if not isinstance(value, Context): return {} @@ -103,13 +103,19 @@ def unstructure_component(value: Component) -> dict[str, Any]: xatspec = xattree.get_xatspec(type(value)) data = xattree.asdict(value) - blocks.update(binding_blocks := get_binding_blocks(value)) + # create child component binding blocks + blocks.update(make_binding_blocks(value)) + # process blocks in order, unstructuring fields as needed, + # then slice period data into separate kper-indexed blocks + # each of which contains a dataset indexed for that period. for block_name, block in blockspec.items(): + period_data = {} # type: ignore + period_blocks = {} # type: ignore + period_block_name = None + if block_name not in blocks: blocks[block_name] = {} - period_data = {} - period_blocks = {} # type: ignore for field_name in block.keys(): # Skip child components that have been processed as bindings @@ -119,82 +125,78 @@ def unstructure_component(value: Component) -> dict[str, Any]: if child_spec.metadata["block"] == block_name: # type: ignore continue - field_value = data[field_name] - # convert: + # filter out empty values and false keywords, and convert: # - paths to records - # - datetime to ISO format - # - auxiliary fields to tuples - # - xarray DataArrays with 'nper' dimension to kper-sliced datasets - # (and split the period data into separate kper-indexed blocks) + # - datetimes to ISO format + # - filter out false keywords + # - 'auxiliary' fields to tuples + # - xarray DataArrays with 'nper' dim to dict of kper-sliced datasets # - other values to their original form - if isinstance(field_value, Path): - field_spec = xatspec.attrs[field_name] - field_meta = getattr(field_spec, "metadata", {}) - t = path_to_tuple(field_name, field_value, inout=field_meta.get("inout", "fileout")) - # name may have changed e.g dropping '_file' suffix - blocks[block_name][t[0]] = t - elif isinstance(field_value, datetime): - blocks[block_name][field_name] = field_value.isoformat() - elif ( - field_name == "auxiliary" - and hasattr(field_value, "values") - and field_value is not None - ): - blocks[block_name][field_name] = tuple(field_value.values.tolist()) - elif isinstance(field_value, xr.DataArray) and "nper" in field_value.dims: - has_spatial_dims = any( - dim in field_value.dims for dim in ["nlay", "nrow", "ncol", "nodes"] - ) - if has_spatial_dims: - field_value = _hack_structured_grid_dims( - field_value, - structured_grid_dims=value.parent.data.dims, # type: ignore + match field_value := data[field_name]: + case None: + continue + case bool(): + if field_value: + blocks[block_name][field_name] = field_value + case Path(): + field_spec = xatspec.attrs[field_name] + field_meta = getattr(field_spec, "metadata", {}) + t = path_to_tuple( + field_name, field_value, inout=field_meta.get("inout", "fileout") ) - - period_data[field_name] = { - kper: field_value.isel(nper=kper) - for kper in range(field_value.sizes["nper"]) - } - else: - # TODO why not putting in block here but doing below? how does this even work - if np.issubdtype(field_value.dtype, np.str_): + # name may have changed e.g dropping '_file' suffix + blocks[block_name][t[0]] = t + case datetime(): + blocks[block_name][field_name] = field_value.isoformat() + case t if ( + field_name == "auxiliary" + and hasattr(field_value, "values") + and field_value is not None + ): + blocks[block_name][field_name] = tuple(field_value.values.tolist()) + case xr.DataArray() if "nper" in field_value.dims: + has_spatial_dims = any( + dim in field_value.dims for dim in ["nlay", "nrow", "ncol", "nodes"] + ) + if has_spatial_dims: + field_value = _hack_structured_grid_dims( + field_value, + structured_grid_dims=value.parent.data.dims, # type: ignore + ) + if "period" in block_name: + period_block_name = block_name period_data[field_name] = { - kper: field_value[kper] for kper in range(field_value.sizes["nper"]) + kper: field_value.isel(nper=kper) + for kper in range(field_value.sizes["nper"]) } - else: - if block_name not in period_data: - period_data[block_name] = {} - period_data[block_name][field_name] = field_value # type: ignore - else: - if field_value is not None: - if isinstance(field_value, bool): - if field_value: - blocks[block_name][field_name] = field_value else: blocks[block_name][field_name] = field_value - if block_name in period_data and isinstance(period_data[block_name], dict): - dataset = xr.Dataset(period_data[block_name]) - blocks[block_name] = {block_name: dataset} - del period_data[block_name] + case _: + blocks[block_name][field_name] = field_value + # invert key order, (arr_name, kper) -> (kper, arr_name) for arr_name, periods in period_data.items(): for kper, arr in periods.items(): if kper not in period_blocks: period_blocks[kper] = {} period_blocks[kper][arr_name] = arr + # setup indexed period blocks, combine arrays into datasets for kper, block in period_blocks.items(): - dataset = xr.Dataset(block) - blocks[f"{block_name} {kper + 1}"] = {block_name: dataset} + assert isinstance(period_block_name, str) + blocks[f"{period_block_name} {kper + 1}"] = { + period_block_name: xr.Dataset(block, coords=block[arr_name].coords) + } - # total temporary hack! manually set solutiongroup 1. still need to support multiple.. + # total temporary hack! manually set solutiongroup 1. + # TODO still need to support multiple.. if "solutiongroup" in blocks: sg = blocks["solutiongroup"] blocks["solutiongroup 1"] = sg del blocks["solutiongroup"] - return {name: block for name, block in blocks.items() if name != "period"} + return {name: block for name, block in blocks.items() if name != period_block_name} def _make_converter() -> Converter: diff --git a/flopy4/mf6/gwf/chd.py b/flopy4/mf6/gwf/chd.py index 2bc1b33a..3896b066 100644 --- a/flopy4/mf6/gwf/chd.py +++ b/flopy4/mf6/gwf/chd.py @@ -6,11 +6,11 @@ from numpy.typing import NDArray from xattree import xattree -from flopy4.mf6.component import update_maxbound from flopy4.mf6.constants import LENBOUNDNAME from flopy4.mf6.converter import dict_to_array from flopy4.mf6.package import Package from flopy4.mf6.spec import array, field, path +from flopy4.mf6.utils.grid_utils import update_maxbound from flopy4.utils import to_path diff --git a/flopy4/mf6/gwf/drn.py b/flopy4/mf6/gwf/drn.py index a6d1cfc3..9d89a3a4 100644 --- a/flopy4/mf6/gwf/drn.py +++ b/flopy4/mf6/gwf/drn.py @@ -6,11 +6,11 @@ from numpy.typing import NDArray from xattree import xattree -from flopy4.mf6.component import update_maxbound from flopy4.mf6.constants import LENBOUNDNAME from flopy4.mf6.converter import dict_to_array from flopy4.mf6.package import Package from flopy4.mf6.spec import array, field, path +from flopy4.mf6.utils.grid_utils import update_maxbound from flopy4.utils import to_path diff --git a/flopy4/mf6/gwf/oc.py b/flopy4/mf6/gwf/oc.py index 47e81187..1e5b0974 100644 --- a/flopy4/mf6/gwf/oc.py +++ b/flopy4/mf6/gwf/oc.py @@ -49,28 +49,28 @@ class Period: save_head: Optional[NDArray[np.str_]] = array( dtype=np.dtypes.StringDType(), block="period", - default="all", + default=None, dims=("nper",), converter=Converter(dict_to_array, takes_self=True, takes_field=True), ) save_budget: Optional[NDArray[np.str_]] = array( dtype=np.dtypes.StringDType(), block="period", - default="all", + default=None, dims=("nper",), converter=Converter(dict_to_array, takes_self=True, takes_field=True), ) print_head: Optional[NDArray[np.str_]] = array( dtype=np.dtypes.StringDType(), block="period", - default="all", + default=None, dims=("nper",), converter=Converter(dict_to_array, takes_self=True, takes_field=True), ) print_budget: Optional[NDArray[np.str_]] = array( dtype=np.dtypes.StringDType(), block="period", - default="all", + default=None, dims=("nper",), converter=Converter(dict_to_array, takes_self=True, takes_field=True), ) diff --git a/flopy4/mf6/gwf/rch.py b/flopy4/mf6/gwf/rch.py index 0dd3d0dd..6e62cd45 100644 --- a/flopy4/mf6/gwf/rch.py +++ b/flopy4/mf6/gwf/rch.py @@ -6,11 +6,11 @@ from numpy.typing import NDArray from xattree import xattree -from flopy4.mf6.component import update_maxbound from flopy4.mf6.constants import LENBOUNDNAME from flopy4.mf6.converter import dict_to_array from flopy4.mf6.package import Package from flopy4.mf6.spec import array, field, path +from flopy4.mf6.utils.grid_utils import update_maxbound from flopy4.utils import to_path diff --git a/flopy4/mf6/gwf/wel.py b/flopy4/mf6/gwf/wel.py index 66b414dd..e04df63d 100644 --- a/flopy4/mf6/gwf/wel.py +++ b/flopy4/mf6/gwf/wel.py @@ -6,11 +6,11 @@ from numpy.typing import NDArray from xattree import xattree -from flopy4.mf6.component import update_maxbound from flopy4.mf6.constants import LENBOUNDNAME from flopy4.mf6.converter import dict_to_array from flopy4.mf6.package import Package from flopy4.mf6.spec import array, field, path +from flopy4.mf6.utils.grid_utils import update_maxbound from flopy4.utils import to_path diff --git a/flopy4/mf6/simulation.py b/flopy4/mf6/simulation.py index 8ee73d6e..7c0df905 100644 --- a/flopy4/mf6/simulation.py +++ b/flopy4/mf6/simulation.py @@ -23,11 +23,13 @@ def convert_time(value): @xattree class Simulation(Context): - tdis: Tdis = field(converter=convert_time, block="timing") + tdis: Tdis = field(block="timing", converter=convert_time) models: dict[str, Model] = field(block="models") exchanges: dict[str, Exchange] = field(block="exchanges") solutions: dict[str, Solution] = field(block="solutiongroup") - filename: str = field(default="mfsim.nam", init=False) + + def default_filename(self) -> str: + return "mfsim.nam" def __attrs_post_init__(self): super().__attrs_post_init__() diff --git a/flopy4/mf6/solution.py b/flopy4/mf6/solution.py index ba444e24..6895b6b9 100644 --- a/flopy4/mf6/solution.py +++ b/flopy4/mf6/solution.py @@ -1,9 +1,8 @@ from abc import ABC -from pathlib import Path -from typing import ClassVar, Optional +from typing import ClassVar import attrs -from xattree import field, xattree +from xattree import xattree from flopy4.mf6.package import Package @@ -11,9 +10,7 @@ @xattree class Solution(Package, ABC): slntype: ClassVar[str] = "sln" - - slnfname: Optional[Path] = field(default=None) # type: ignore models: list[str] = attrs.field(default=attrs.Factory(list)) def default_filename(self) -> str: - return str(self.slnfname) if self.slnfname else f"solution.{self.slntype.lower()}" + return f"solution.{self.slntype.lower()}" diff --git a/flopy4/mf6/utils/grid_utils.py b/flopy4/mf6/utils/grid_utils.py index a2c857b0..4bd3fb5b 100644 --- a/flopy4/mf6/utils/grid_utils.py +++ b/flopy4/mf6/utils/grid_utils.py @@ -2,8 +2,11 @@ from typing import Any import numpy as np +from attrs import fields from flopy.discretization import StructuredGrid +from flopy4.mf6.constants import FILL_DNODATA + def get_coords(grid: StructuredGrid) -> dict[str, Any]: # unpack tuples @@ -31,3 +34,54 @@ def get_coords(grid: StructuredGrid) -> dict[str, Any]: coords["dy"] = ("y", dy) coords["layer"] = np.arange(1, grid.nlay + 1) return coords + + +def update_maxbound(instance, attribute, new_value): + """ + Generalized function to update maxbound when period block arrays change. + + This function automatically finds all period block arrays in the instance + and calculates maxbound based on the maximum number of non-default values + across all arrays. + + Args: + instance: The package instance + attribute: The attribute being set (from attrs on_setattr) + new_value: The new value being set + + Returns: + The new_value (unchanged) + """ + + period_arrays = [] + instance_fields = fields(instance.__class__) + for f in instance_fields: + if ( + f.metadata + and f.metadata.get("block") == "period" + and f.metadata.get("xattree", {}).get("dims") + ): + period_arrays.append(f.name) + + maxbound_values = [] + for array_name in period_arrays: + if attribute and attribute.name == array_name: + array_val = new_value + else: + array_val = getattr(instance, array_name, None) + + if array_val is not None: + array_data = ( + array_val if array_val.data.shape == array_val.shape else array_val.todense() + ) + + if array_data.dtype.kind in ["U", "S"]: # String arrays + non_default_count = len(np.where(array_data != "")[0]) + else: # Numeric arrays + non_default_count = len(np.where(array_data != FILL_DNODATA)[0]) + + maxbound_values.append(non_default_count) + if maxbound_values: + instance.maxbound = max(maxbound_values) + + return new_value diff --git a/flopy4/spec.py b/flopy4/spec.py index 305b9f35..e80f44dc 100644 --- a/flopy4/spec.py +++ b/flopy4/spec.py @@ -1,8 +1,6 @@ """ Wrap `xattree` and `attrs` specification utilities. These include field decorators and introspection functions. -TODO: add `derived` option to dims? or more generic option -to any field indicating it is not part of the formal spec? """ import numpy as np diff --git a/test/test_codec.py b/test/test_codec.py index 8cc8d07f..f0747efb 100644 --- a/test/test_codec.py +++ b/test/test_codec.py @@ -1,7 +1,5 @@ from pprint import pprint -import pytest - from flopy4.mf6.codec import dumps, loads from flopy4.mf6.converter import COMPONENT_CONVERTER @@ -55,7 +53,6 @@ def test_dumps_ic(): pprint(loaded) -@pytest.mark.xfail(reason="TODO") def test_dumps_oc(): from flopy4.mf6.gwf import Oc @@ -70,13 +67,8 @@ def test_dumps_oc(): dumped = dumps(COMPONENT_CONVERTER.unstructure(oc)) print("OC dump:") print(dumped) - # TODO these are getting truncated, need to specify string length like