From 47768c0a7303fe7a8543b8fd1dd6e31c371d71dc Mon Sep 17 00:00:00 2001 From: wpbonelli Date: Wed, 8 Oct 2025 13:30:15 -0400 Subject: [PATCH] consolidate converter modules, fix dim name nnodes -> nodes --- flopy4/mf6/adapters.py | 6 +-- flopy4/mf6/codec/writer/filters.py | 4 +- flopy4/mf6/converter.py | 81 +++++++++++++++++++++++++++++- flopy4/mf6/converters.py | 80 ----------------------------- flopy4/mf6/gwf/chd.py | 8 +-- flopy4/mf6/gwf/dis.py | 6 +-- flopy4/mf6/gwf/drn.py | 10 ++-- flopy4/mf6/gwf/ic.py | 4 +- flopy4/mf6/gwf/npf.py | 18 +++---- flopy4/mf6/gwf/oc.py | 2 +- flopy4/mf6/gwf/rch.py | 8 +-- flopy4/mf6/gwf/sto.py | 8 +-- flopy4/mf6/gwf/wel.py | 8 +-- flopy4/mf6/indexes.py | 4 -- flopy4/mf6/tdis.py | 2 +- test/test_component.py | 4 +- test/test_interface.py | 8 +-- 17 files changed, 127 insertions(+), 134 deletions(-) delete mode 100644 flopy4/mf6/converters.py diff --git a/flopy4/mf6/adapters.py b/flopy4/mf6/adapters.py index 913c1726..452fd530 100644 --- a/flopy4/mf6/adapters.py +++ b/flopy4/mf6/adapters.py @@ -319,7 +319,7 @@ def data_type(self): case "ndarray": if "nper" in self._data.dims: if self._data.ndim == 2: - if "nnodes" in self._data.dims: + if "nodes" in self._data.dims: return DataType.transient2d # nodes? if self._data.ndim == 3: return DataType.transient3d # ncpl? @@ -327,7 +327,7 @@ def data_type(self): return DataType.transient2d # nodes? else: if self._data.ndim == 1: - if "nnodes" in self._data.dims: + if "nodes" in self._data.dims: return DataType.array3d if self._data.ndim == 2: return DataType.array2d @@ -351,7 +351,7 @@ def dtype(self): @property def array(self): if self._spec.type.__name__ == "ndarray": - if "nnodes" in self._data.dims: + if "nodes" in self._data.dims: if "nper" in self._data.dims: shape = ( self._time.nper, diff --git a/flopy4/mf6/codec/writer/filters.py b/flopy4/mf6/codec/writer/filters.py index 95a05395..573117b9 100644 --- a/flopy4/mf6/codec/writer/filters.py +++ b/flopy4/mf6/codec/writer/filters.py @@ -176,7 +176,7 @@ def data2list(value: list | xr.DataArray | xr.Dataset): yield (value.item(),) return - spatial_dims = [d for d in value.dims if d in ("nlay", "nrow", "ncol", "nnodes")] + spatial_dims = [d for d in value.dims if d in ("nlay", "nrow", "ncol", "nodes")] has_spatial_dims = len(spatial_dims) > 0 mask = nonempty(value) indices = np.where(mask) @@ -223,7 +223,7 @@ def dataset2list(value: xr.Dataset): if combined_mask is None or not np.any(combined_mask): return - spatial_dims = [d for d in first_arr.dims if d in ("nlay", "nrow", "ncol", "nnodes")] + spatial_dims = [d for d in first_arr.dims if d in ("nlay", "nrow", "ncol", "nodes")] has_spatial_dims = len(spatial_dims) > 0 indices = np.where(combined_mask) for i in range(len(indices[0])): diff --git a/flopy4/mf6/converter.py b/flopy4/mf6/converter.py index 63e91cbf..0ae1fc9f 100644 --- a/flopy4/mf6/converter.py +++ b/flopy4/mf6/converter.py @@ -3,12 +3,19 @@ from pathlib import Path from typing import Any +import numpy as np +import sparse import xarray as xr import xattree from attrs import define from cattrs import Converter +from numpy.typing import NDArray +from xattree import get_xatspec +from flopy4.adapters import get_nn from flopy4.mf6.component import Component +from flopy4.mf6.config import SPARSE_THRESHOLD +from flopy4.mf6.constants import FILL_DNODATA from flopy4.mf6.context import Context from flopy4.mf6.exchange import Exchange from flopy4.mf6.model import Model @@ -148,7 +155,7 @@ def unstructure_component(value: Component) -> dict[str, Any]: blocks[block_name][field_name] = tuple(field_value.values.tolist()) elif isinstance(field_value, xr.DataArray) and "nper" in field_value.dims: has_spatial_dims = any( - dim in field_value.dims for dim in ["nlay", "nrow", "ncol", "nnodes"] + dim in field_value.dims for dim in ["nlay", "nrow", "ncol", "nodes"] ) if has_spatial_dims: # terrible hack to convert flat nodes dimension to 3d structured dims. @@ -156,7 +163,7 @@ def unstructure_component(value: Component) -> dict[str, Any]: # should then have access to all dimensions needed. dims_ = set(field_value.dims).copy() dims_.remove("nper") - if dims_ == {"nnodes"}: + if dims_ == {"nodes"}: parent = value.parent # type: ignore field_value = xr.DataArray( field_value.data.reshape( @@ -228,3 +235,73 @@ def _make_converter() -> Converter: COMPONENT_CONVERTER = _make_converter() + + +def dict_to_array(value, self_, field) -> NDArray: + """ + Convert a sparse dictionary representation of an array to a + dense numpy array or a sparse COO array. + """ + + if not isinstance(value, dict): + # if not a dict, assume it's a numpy array + # and let xarray deal with it if it isn't + return value + + spec = get_xatspec(type(self_)).flat + field = spec[field.name] + if not field.dims: + raise ValueError(f"Field {field} missing dims") + + # resolve dims + explicit_dims = self_.__dict__.get("dims", {}) + inherited_dims = dict(self_.parent.data.dims) if self_.parent else {} + dims = inherited_dims | explicit_dims + shape = [dims.get(d, d) for d in field.dims] + unresolved = [d for d in shape if isinstance(d, str)] + if any(unresolved): + raise ValueError(f"Couldn't resolve dims: {unresolved}") + + if np.prod(shape) > SPARSE_THRESHOLD: + a: dict[tuple[Any, ...], Any] = dict() + + def set_(arr, val, *ind): + arr[tuple(ind)] = val + + def final(arr): + coords = np.array(list(map(list, zip(*arr.keys())))) + return sparse.COO( + coords, + list(arr.values()), + shape=shape, + fill_value=field.default or FILL_DNODATA, + ) + else: + a = np.full(shape, FILL_DNODATA, dtype=field.dtype) # type: ignore + + def set_(arr, val, *ind): + arr[ind] = val + + def final(arr): + arr[arr == FILL_DNODATA] = field.default or FILL_DNODATA + return arr + + if "nper" in dims: + for kper, period in value.items(): + if kper == "*": + kper = 0 + match len(shape): + case 1: + set_(a, period, kper) + case _: + for cellid, v in period.items(): + nn = get_nn(cellid, **dims) + set_(a, v, kper, nn) + if kper == "*": + break + else: + for cellid, v in value.items(): + nn = get_nn(cellid, **dims) + set_(a, v, nn) + + return final(a) diff --git a/flopy4/mf6/converters.py b/flopy4/mf6/converters.py deleted file mode 100644 index 1c4e5fe8..00000000 --- a/flopy4/mf6/converters.py +++ /dev/null @@ -1,80 +0,0 @@ -from typing import Any - -import numpy as np -import sparse -from numpy.typing import NDArray -from xattree import get_xatspec - -from flopy4.adapters import get_nn -from flopy4.mf6.config import SPARSE_THRESHOLD -from flopy4.mf6.constants import FILL_DNODATA - - -def dict_to_array(value, self_, field) -> NDArray: - """ - Convert a sparse dictionary representation of an array to a - dense numpy array or a sparse COO array. - """ - - if not isinstance(value, dict): - # if not a dict, assume it's a numpy array - # and let xarray deal with it if it isn't - return value - - spec = get_xatspec(type(self_)).flat - field = spec[field.name] - if not field.dims: - raise ValueError(f"Field {field} missing dims") - - # resolve dims - explicit_dims = self_.__dict__.get("dims", {}) - inherited_dims = dict(self_.parent.data.dims) if self_.parent else {} - dims = inherited_dims | explicit_dims - shape = [dims.get(d, d) for d in field.dims] - unresolved = [d for d in shape if isinstance(d, str)] - if any(unresolved): - raise ValueError(f"Couldn't resolve dims: {unresolved}") - - if np.prod(shape) > SPARSE_THRESHOLD: - a: dict[tuple[Any, ...], Any] = dict() - - def set_(arr, val, *ind): - arr[tuple(ind)] = val - - def final(arr): - coords = np.array(list(map(list, zip(*arr.keys())))) - return sparse.COO( - coords, - list(arr.values()), - shape=shape, - fill_value=field.default or FILL_DNODATA, - ) - else: - a = np.full(shape, FILL_DNODATA, dtype=field.dtype) # type: ignore - - def set_(arr, val, *ind): - arr[ind] = val - - def final(arr): - arr[arr == FILL_DNODATA] = field.default or FILL_DNODATA - return arr - - if "nper" in dims: - for kper, period in value.items(): - if kper == "*": - kper = 0 - match len(shape): - case 1: - set_(a, period, kper) - case _: - for cellid, v in period.items(): - nn = get_nn(cellid, **dims) - set_(a, v, kper, nn) - if kper == "*": - break - else: - for cellid, v in value.items(): - nn = get_nn(cellid, **dims) - set_(a, v, nn) - - return final(a) diff --git a/flopy4/mf6/gwf/chd.py b/flopy4/mf6/gwf/chd.py index de8416c0..3fc97532 100644 --- a/flopy4/mf6/gwf/chd.py +++ b/flopy4/mf6/gwf/chd.py @@ -7,7 +7,7 @@ from xattree import xattree from flopy4.mf6.component import update_maxbound -from flopy4.mf6.converters import dict_to_array +from flopy4.mf6.converter import dict_to_array from flopy4.mf6.package import Package from flopy4.mf6.spec import array, field @@ -29,7 +29,7 @@ class Chd(Package): block="period", dims=( "nper", - "nnodes", + "nodes", ), default=None, converter=Converter(dict_to_array, takes_self=True, takes_field=True), @@ -40,7 +40,7 @@ class Chd(Package): block="period", dims=( "nper", - "nnodes", + "nodes", ), default=None, converter=Converter(dict_to_array, takes_self=True, takes_field=True), @@ -51,7 +51,7 @@ class Chd(Package): block="period", dims=( "nper", - "nnodes", + "nodes", ), default=None, converter=Converter(dict_to_array, takes_self=True, takes_field=True), diff --git a/flopy4/mf6/gwf/dis.py b/flopy4/mf6/gwf/dis.py index d39b8f4d..273a1e14 100644 --- a/flopy4/mf6/gwf/dis.py +++ b/flopy4/mf6/gwf/dis.py @@ -6,7 +6,7 @@ from numpy.typing import NDArray from xattree import xattree -from flopy4.mf6.converters import dict_to_array +from flopy4.mf6.converter import dict_to_array from flopy4.mf6.package import Package from flopy4.mf6.spec import array, dim, field @@ -70,14 +70,14 @@ class Dis(Package): dims=("nlay", "nrow", "ncol"), converter=Converter(dict_to_array, takes_self=True, takes_field=True), ) - nnodes: int = dim( + nodes: int = dim( coord="node", scope="gwf", init=False, ) def __attrs_post_init__(self): - self.nnodes = self.ncol * self.nrow * self.nlay + self.nodes = self.ncol * self.nrow * self.nlay super().__attrs_post_init__() def to_grid(self) -> StructuredGrid: diff --git a/flopy4/mf6/gwf/drn.py b/flopy4/mf6/gwf/drn.py index b877e302..6a01c2d8 100644 --- a/flopy4/mf6/gwf/drn.py +++ b/flopy4/mf6/gwf/drn.py @@ -7,7 +7,7 @@ from xattree import xattree from flopy4.mf6.component import update_maxbound -from flopy4.mf6.converters import dict_to_array +from flopy4.mf6.converter import dict_to_array from flopy4.mf6.package import Package from flopy4.mf6.spec import array, field @@ -29,7 +29,7 @@ class Drn(Package): maxbound: Optional[int] = field(block="dimensions", default=None, init=False) elev: Optional[NDArray[np.float64]] = array( block="period", - dims=("nper", "nnodes"), + dims=("nper", "nodes"), default=None, converter=Converter(dict_to_array, takes_self=True, takes_field=True), reader="urword", @@ -37,7 +37,7 @@ class Drn(Package): ) cond: Optional[NDArray[np.float64]] = array( block="period", - dims=("nper", "nnodes"), + dims=("nper", "nodes"), default=None, converter=Converter(dict_to_array, takes_self=True, takes_field=True), reader="urword", @@ -47,7 +47,7 @@ class Drn(Package): block="period", dims=( "nper", - "nnodes", + "nodes", ), default=None, converter=Converter(dict_to_array, takes_self=True, takes_field=True), @@ -58,7 +58,7 @@ class Drn(Package): block="period", dims=( "nper", - "nnodes", + "nodes", ), default=None, converter=Converter(dict_to_array, takes_self=True, takes_field=True), diff --git a/flopy4/mf6/gwf/ic.py b/flopy4/mf6/gwf/ic.py index 44a0f292..e5190bab 100644 --- a/flopy4/mf6/gwf/ic.py +++ b/flopy4/mf6/gwf/ic.py @@ -3,7 +3,7 @@ from numpy.typing import NDArray from xattree import xattree -from flopy4.mf6.converters import dict_to_array +from flopy4.mf6.converter import dict_to_array from flopy4.mf6.package import Package from flopy4.mf6.spec import array, field @@ -14,7 +14,7 @@ class Ic(Package): export_array_netcdf: bool = field(block="options", default=False) strt: NDArray[np.float64] = array( block="griddata", - dims=("nnodes",), + dims=("nodes",), default=1.0, converter=Converter(dict_to_array, takes_self=True, takes_field=True), ) diff --git a/flopy4/mf6/gwf/npf.py b/flopy4/mf6/gwf/npf.py index 4828fdef..370d666d 100644 --- a/flopy4/mf6/gwf/npf.py +++ b/flopy4/mf6/gwf/npf.py @@ -6,7 +6,7 @@ from numpy.typing import NDArray from xattree import xattree -from flopy4.mf6.converters import dict_to_array +from flopy4.mf6.converter import dict_to_array from flopy4.mf6.package import Package from flopy4.mf6.spec import array, field @@ -49,49 +49,49 @@ class Xt3dOptions: dev_omega: Optional[float] = field(block="options", default=None) icelltype: NDArray[np.integer] = array( block="griddata", - dims=("nnodes",), + dims=("nodes",), default=0, converter=Converter(dict_to_array, takes_self=True, takes_field=True), ) k: NDArray[np.float64] = array( block="griddata", - dims=("nnodes",), + dims=("nodes",), default=1.0, converter=Converter(dict_to_array, takes_self=True, takes_field=True), ) k22: Optional[NDArray[np.float64]] = array( block="griddata", - dims=("nnodes",), + dims=("nodes",), default=None, converter=Converter(dict_to_array, takes_self=True, takes_field=True), ) k33: Optional[NDArray[np.float64]] = array( block="griddata", - dims=("nnodes",), + dims=("nodes",), default=None, converter=Converter(dict_to_array, takes_self=True, takes_field=True), ) angle1: Optional[NDArray[np.float64]] = array( block="griddata", - dims=("nnodes",), + dims=("nodes",), default=None, converter=Converter(dict_to_array, takes_self=True, takes_field=True), ) angle2: Optional[NDArray[np.float64]] = array( block="griddata", - dims=("nnodes",), + dims=("nodes",), default=None, converter=Converter(dict_to_array, takes_self=True, takes_field=True), ) angle3: Optional[NDArray[np.float64]] = array( block="griddata", - dims=("nnodes",), + dims=("nodes",), default=None, converter=Converter(dict_to_array, takes_self=True, takes_field=True), ) wetdry: Optional[NDArray[np.float64]] = array( block="griddata", - dims=("nnodes",), + dims=("nodes",), default=None, converter=Converter(dict_to_array, takes_self=True, takes_field=True), ) diff --git a/flopy4/mf6/gwf/oc.py b/flopy4/mf6/gwf/oc.py index b50f715d..0c47499a 100644 --- a/flopy4/mf6/gwf/oc.py +++ b/flopy4/mf6/gwf/oc.py @@ -6,7 +6,7 @@ from numpy.typing import NDArray from xattree import xattree -from flopy4.mf6.converters import dict_to_array +from flopy4.mf6.converter import dict_to_array from flopy4.mf6.package import Package from flopy4.mf6.spec import array, field from flopy4.utils import to_path diff --git a/flopy4/mf6/gwf/rch.py b/flopy4/mf6/gwf/rch.py index b38eb64f..b7cc131d 100644 --- a/flopy4/mf6/gwf/rch.py +++ b/flopy4/mf6/gwf/rch.py @@ -7,7 +7,7 @@ from xattree import xattree from flopy4.mf6.component import update_maxbound -from flopy4.mf6.converters import dict_to_array +from flopy4.mf6.converter import dict_to_array from flopy4.mf6.package import Package from flopy4.mf6.spec import array, field @@ -29,7 +29,7 @@ class Rch(Package): block="period", dims=( "nper", - "nnodes", + "nodes", ), default=None, converter=Converter(dict_to_array, takes_self=True, takes_field=True), @@ -40,7 +40,7 @@ class Rch(Package): block="period", dims=( "nper", - "nnodes", + "nodes", ), default=None, converter=Converter(dict_to_array, takes_self=True, takes_field=True), @@ -51,7 +51,7 @@ class Rch(Package): block="period", dims=( "nper", - "nnodes", + "nodes", ), default=None, converter=Converter(dict_to_array, takes_self=True, takes_field=True), diff --git a/flopy4/mf6/gwf/sto.py b/flopy4/mf6/gwf/sto.py index 3e375cda..e5482be0 100644 --- a/flopy4/mf6/gwf/sto.py +++ b/flopy4/mf6/gwf/sto.py @@ -6,7 +6,7 @@ from numpy.typing import NDArray from xattree import xattree -from flopy4.mf6.converters import dict_to_array +from flopy4.mf6.converter import dict_to_array from flopy4.mf6.package import Package from flopy4.mf6.spec import array, field @@ -23,19 +23,19 @@ class Sto(Package): dev_oldstorageformulation: bool = field(block="options", default=False) iconvert: NDArray[np.int32] = array( block="griddata", - dims=("nnodes",), + dims=("nodes",), default=0, converter=Converter(dict_to_array, takes_self=True, takes_field=True), ) ss: NDArray[np.float64] = array( block="griddata", - dims=("nnodes",), + dims=("nodes",), default=1e-5, converter=Converter(dict_to_array, takes_self=True, takes_field=True), ) sy: NDArray[np.float64] = array( block="griddata", - dims=("nnodes",), + dims=("nodes",), default=0.15, converter=Converter(dict_to_array, takes_self=True, takes_field=True), ) diff --git a/flopy4/mf6/gwf/wel.py b/flopy4/mf6/gwf/wel.py index da502b24..880fe567 100644 --- a/flopy4/mf6/gwf/wel.py +++ b/flopy4/mf6/gwf/wel.py @@ -7,7 +7,7 @@ from xattree import xattree from flopy4.mf6.component import update_maxbound -from flopy4.mf6.converters import dict_to_array +from flopy4.mf6.converter import dict_to_array from flopy4.mf6.package import Package from flopy4.mf6.spec import array, field @@ -31,7 +31,7 @@ class Wel(Package): block="period", dims=( "nper", - "nnodes", + "nodes", ), default=None, converter=Converter(dict_to_array, takes_self=True, takes_field=True), @@ -42,7 +42,7 @@ class Wel(Package): block="period", dims=( "nper", - "nnodes", + "nodes", ), default=None, converter=Converter(dict_to_array, takes_self=True, takes_field=True), @@ -53,7 +53,7 @@ class Wel(Package): block="period", dims=( "nper", - "nnodes", + "nodes", ), default=None, converter=Converter(dict_to_array, takes_self=True, takes_field=True), diff --git a/flopy4/mf6/indexes.py b/flopy4/mf6/indexes.py index 2bd0329d..0ef5f3bc 100644 --- a/flopy4/mf6/indexes.py +++ b/flopy4/mf6/indexes.py @@ -46,13 +46,9 @@ def sel(self, labels): def grid_index(dataset: xr.Dataset) -> MetaIndex: return MetaIndex( { - # TODO add 'per' (stress period) "lay": alias(dataset, "nlay", "lay"), "col": alias(dataset, "ncol", "col"), "row": alias(dataset, "nrow", "row"), - # "node": alias(dataset, "nnodes", "node"), - # TODO: adding node breaks the other three. - # and just having node by itself works. why? } ) diff --git a/flopy4/mf6/tdis.py b/flopy4/mf6/tdis.py index ee54337f..db2164ec 100644 --- a/flopy4/mf6/tdis.py +++ b/flopy4/mf6/tdis.py @@ -7,7 +7,7 @@ from numpy.typing import NDArray from xattree import ROOT, xattree -from flopy4.mf6.converters import dict_to_array +from flopy4.mf6.converter import dict_to_array from flopy4.mf6.package import Package from flopy4.mf6.spec import array, dim, field diff --git a/test/test_component.py b/test/test_component.py index 6020cea3..46a17162 100644 --- a/test/test_component.py +++ b/test/test_component.py @@ -36,7 +36,7 @@ def test_init_gwf_explicit_dims(): "nlay": grid.nlay, "nrow": grid.nrow, "ncol": grid.ncol, - "nnodes": grid.nnodes, + "nodes": grid.nnodes, } dis = Dis(dims=dims) ic = Ic(dims=dims) @@ -164,7 +164,7 @@ def test_init_sim_explicit_dims(): } dis = Dis(**dims) dims["nper"] = time.nper - dims["nnodes"] = grid.nnodes + dims["nodes"] = grid.nnodes ic = Ic(dims=dims) oc = Oc(dims=dims) npf = Npf(dims=dims) diff --git a/test/test_interface.py b/test/test_interface.py index b1793f23..66d0a889 100644 --- a/test/test_interface.py +++ b/test/test_interface.py @@ -56,7 +56,7 @@ def test_flopy3_model(tmp_path): dis.yorigin = 0.0 dims["nper"] = time.nper - dims["nnodes"] = grid.nnodes + dims["nodes"] = grid.nnodes # ims = Ims(dims=dims) ims = Ims() @@ -182,7 +182,7 @@ def test_flopy3_package(tmp_path): dis.botm = botm dims["nper"] = time.nper - dims["nnodes"] = grid.nnodes + dims["nodes"] = grid.nnodes gwf = Gwf( dis=dis, @@ -257,7 +257,7 @@ def norun_test_flopy3_cbd_small(tmp_path): } dis = Dis(**dims) dims["nper"] = time.nper - dims["nnodes"] = cbd_small.nnodes + dims["nodes"] = cbd_small.nnodes gwf = Gwf( dis=dis, dims=dims, @@ -298,7 +298,7 @@ def test_flopy3_grid2(tmp_path): dis.botm = botm dis.idomain = idomain dims["nper"] = time.nper - dims["nnodes"] = nlay * nrow * ncol + dims["nodes"] = nlay * nrow * ncol gwf = Gwf( dis=dis, dims=dims,