Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions flopy4/mf6/codec/writer/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,11 +201,15 @@ def dataset2list(value: xr.Dataset):
if value is None or not any(value.data_vars):
return

first = next(iter(value.data_vars.values()))
is_union = first.dtype.type is np.str_
# special case OC for now.
is_oc = all(
str(v.name).startswith("save_") or str(v.name).startswith("print_")
for v in value.data_vars.values()
)

if first.ndim == 0: # handle scalar
if is_union:
# handle scalar
if (first := next(iter(value.data_vars.values()))).ndim == 0:
if is_oc:
for name in value.data_vars.keys():
val = value[name]
val = val.item() if val.shape == () else val
Expand All @@ -230,7 +234,7 @@ def dataset2list(value: xr.Dataset):
has_spatial_dims = len(spatial_dims) > 0
indices = np.where(combined_mask)
for i in range(len(indices[0])):
if is_union:
if is_oc:
for name in value.data_vars.keys():
val = value[name][tuple(idx[i] for idx in indices)]
val = val.item() if val.shape == () else val
Expand Down
4 changes: 2 additions & 2 deletions flopy4/mf6/codec/writer/templates/macros.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@
{{ inset ~ name.upper() }}{% if "layered" in how %} LAYERED{% endif %}

{% if how == "constant" %}
CONSTANT {{ value|array2const }}
{{ inset }}CONSTANT {{ value|array2const }}
{% elif how == "layered constant" %}
{% for layer in value -%}
CONSTANT {{ layer|array2const }}
{{ inset }}CONSTANT {{ layer|array2const }}
{%- endfor %}
{% elif how == "internal" %}
INTERNAL
Expand Down
104 changes: 41 additions & 63 deletions flopy4/mf6/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,69 +3,17 @@
from pathlib import Path
from typing import Any, ClassVar

import numpy as np
from attrs import fields
from modflow_devtools.dfn import Dfn, Field
from packaging.version import Version
from xattree import asdict as xattree_asdict
from xattree import xattree

from flopy4.mf6.constants import FILL_DNODATA, MF6
from flopy4.mf6.constants import MF6
from flopy4.mf6.spec import field, fields_dict, to_field
from flopy4.mf6.utils.grid_utils import update_maxbound
from flopy4.uio import IO, Loader, Writer


def update_maxbound(instance, attribute, new_value):
"""
Generalized function to update maxbound when period block arrays change.

This function automatically finds all period block arrays in the instance
and calculates maxbound based on the maximum number of non-default values
across all arrays.

Args:
instance: The package instance
attribute: The attribute being set (from attrs on_setattr)
new_value: The new value being set

Returns:
The new_value (unchanged)
"""

period_arrays = []
instance_fields = fields(instance.__class__)
for f in instance_fields:
if (
f.metadata
and f.metadata.get("block") == "period"
and f.metadata.get("xattree", {}).get("dims")
):
period_arrays.append(f.name)

maxbound_values = []
for array_name in period_arrays:
if attribute and attribute.name == array_name:
array_val = new_value
else:
array_val = getattr(instance, array_name, None)

if array_val is not None:
array_data = (
array_val if array_val.data.shape == array_val.shape else array_val.todense()
)

if array_data.dtype.kind in ["U", "S"]: # String arrays
non_default_count = len(np.where(array_data != "")[0])
else: # Numeric arrays
non_default_count = len(np.where(array_data != FILL_DNODATA)[0])

maxbound_values.append(non_default_count)
if maxbound_values:
instance.maxbound = max(maxbound_values)

return new_value


COMPONENTS = {}
"""MF6 component registry."""

Expand All @@ -86,11 +34,14 @@ class Component(ABC, MutableMapping):
_write = IO(Writer) # type: ignore

dfn: ClassVar[Dfn]
"""The component's definition (i.e. specification)."""

filename: str | None = field(default=None)
"""The name of the component's input file."""

@property
def path(self) -> Path:
"""Get the path to the component's input file."""
"""The path to the component's input file."""
self.filename = self.filename or self.default_filename()
return Path.cwd() / self.filename

Expand Down Expand Up @@ -202,18 +153,45 @@ def write(self, format: str = MF6) -> None:
for child in self.children.values(): # type: ignore
child.write(format=format)

def to_dict(self, blocks: bool = False) -> dict[str, Any]:
"""Convert the component to a dictionary representation."""
def to_dict(self, blocks: bool = False, strict: bool = False) -> dict[str, Any]:
"""
Convert the component to a dictionary representation.

Parameters
----------
blocks : bool, optional
If True, return a nested dict keyed by block name
with values as dicts of fields. Default is False.
strict : bool, optional
If True, include only fields in the DFN specification.

Returns
-------
dict[str, Any]
Dictionary containing component data, either
in terms of fields (flat) or blocks (nested).
"""
data = xattree_asdict(self)
data.pop("filename")
data.pop("workspace", None) # might be a Context
data.pop("nodes", None) # TODO: find a better way to omit
spec = self.dfn.fields

if strict:
data.pop("filename")
data.pop("workspace", None) # might be a Context

if blocks:
blocks_ = {} # type: ignore
for field_name, field_value in data.items():
block_name = self.dfn.fields[field_name].block
for field_name in spec.keys():
field_value = data[field_name]
block_name = spec[field_name].block
if strict and block_name is None:
continue
if block_name not in blocks_:
blocks_[block_name] = {}
blocks_[block_name][field_name] = field_value
return blocks_
return data
else:
return {
field_name: data[field_name]
for field_name in spec.keys()
if spec[field_name].block or not strict
}
118 changes: 60 additions & 58 deletions flopy4/mf6/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def path_to_tuple(name: str, value: Path, inout: FileInOut) -> tuple[str, ...]:
return tuple(t)


def get_binding_blocks(value: Component) -> dict[str, dict[str, list[tuple[str, ...]]]]:
def make_binding_blocks(value: Component) -> dict[str, dict[str, list[tuple[str, ...]]]]:
if not isinstance(value, Context):
return {}

Expand Down Expand Up @@ -103,13 +103,19 @@ def unstructure_component(value: Component) -> dict[str, Any]:
xatspec = xattree.get_xatspec(type(value))
data = xattree.asdict(value)

blocks.update(binding_blocks := get_binding_blocks(value))
# create child component binding blocks
blocks.update(make_binding_blocks(value))

# process blocks in order, unstructuring fields as needed,
# then slice period data into separate kper-indexed blocks
# each of which contains a dataset indexed for that period.
for block_name, block in blockspec.items():
period_data = {} # type: ignore
period_blocks = {} # type: ignore
period_block_name = None

if block_name not in blocks:
blocks[block_name] = {}
period_data = {}
period_blocks = {} # type: ignore

for field_name in block.keys():
# Skip child components that have been processed as bindings
Expand All @@ -119,82 +125,78 @@ def unstructure_component(value: Component) -> dict[str, Any]:
if child_spec.metadata["block"] == block_name: # type: ignore
continue

field_value = data[field_name]
# convert:
# filter out empty values and false keywords, and convert:
# - paths to records
# - datetime to ISO format
# - auxiliary fields to tuples
# - xarray DataArrays with 'nper' dimension to kper-sliced datasets
# (and split the period data into separate kper-indexed blocks)
# - datetimes to ISO format
# - filter out false keywords
# - 'auxiliary' fields to tuples
# - xarray DataArrays with 'nper' dim to dict of kper-sliced datasets
# - other values to their original form
if isinstance(field_value, Path):
field_spec = xatspec.attrs[field_name]
field_meta = getattr(field_spec, "metadata", {})
t = path_to_tuple(field_name, field_value, inout=field_meta.get("inout", "fileout"))
# name may have changed e.g dropping '_file' suffix
blocks[block_name][t[0]] = t
elif isinstance(field_value, datetime):
blocks[block_name][field_name] = field_value.isoformat()
elif (
field_name == "auxiliary"
and hasattr(field_value, "values")
and field_value is not None
):
blocks[block_name][field_name] = tuple(field_value.values.tolist())
elif isinstance(field_value, xr.DataArray) and "nper" in field_value.dims:
has_spatial_dims = any(
dim in field_value.dims for dim in ["nlay", "nrow", "ncol", "nodes"]
)
if has_spatial_dims:
field_value = _hack_structured_grid_dims(
field_value,
structured_grid_dims=value.parent.data.dims, # type: ignore
match field_value := data[field_name]:
case None:
continue
case bool():
if field_value:
blocks[block_name][field_name] = field_value
case Path():
field_spec = xatspec.attrs[field_name]
field_meta = getattr(field_spec, "metadata", {})
t = path_to_tuple(
field_name, field_value, inout=field_meta.get("inout", "fileout")
)

period_data[field_name] = {
kper: field_value.isel(nper=kper)
for kper in range(field_value.sizes["nper"])
}
else:
# TODO why not putting in block here but doing below? how does this even work
if np.issubdtype(field_value.dtype, np.str_):
# name may have changed e.g dropping '_file' suffix
blocks[block_name][t[0]] = t
case datetime():
blocks[block_name][field_name] = field_value.isoformat()
case t if (
field_name == "auxiliary"
and hasattr(field_value, "values")
and field_value is not None
):
blocks[block_name][field_name] = tuple(field_value.values.tolist())
case xr.DataArray() if "nper" in field_value.dims:
has_spatial_dims = any(
dim in field_value.dims for dim in ["nlay", "nrow", "ncol", "nodes"]
)
if has_spatial_dims:
field_value = _hack_structured_grid_dims(
field_value,
structured_grid_dims=value.parent.data.dims, # type: ignore
)
if "period" in block_name:
period_block_name = block_name
period_data[field_name] = {
kper: field_value[kper] for kper in range(field_value.sizes["nper"])
kper: field_value.isel(nper=kper)
for kper in range(field_value.sizes["nper"])
}
else:
if block_name not in period_data:
period_data[block_name] = {}
period_data[block_name][field_name] = field_value # type: ignore
else:
if field_value is not None:
if isinstance(field_value, bool):
if field_value:
blocks[block_name][field_name] = field_value
else:
blocks[block_name][field_name] = field_value

if block_name in period_data and isinstance(period_data[block_name], dict):
dataset = xr.Dataset(period_data[block_name])
blocks[block_name] = {block_name: dataset}
del period_data[block_name]
case _:
blocks[block_name][field_name] = field_value

# invert key order, (arr_name, kper) -> (kper, arr_name)
for arr_name, periods in period_data.items():
for kper, arr in periods.items():
if kper not in period_blocks:
period_blocks[kper] = {}
period_blocks[kper][arr_name] = arr

# setup indexed period blocks, combine arrays into datasets
for kper, block in period_blocks.items():
dataset = xr.Dataset(block)
blocks[f"{block_name} {kper + 1}"] = {block_name: dataset}
assert isinstance(period_block_name, str)
blocks[f"{period_block_name} {kper + 1}"] = {
period_block_name: xr.Dataset(block, coords=block[arr_name].coords)
}

# total temporary hack! manually set solutiongroup 1. still need to support multiple..
# total temporary hack! manually set solutiongroup 1.
# TODO still need to support multiple..
if "solutiongroup" in blocks:
sg = blocks["solutiongroup"]
blocks["solutiongroup 1"] = sg
del blocks["solutiongroup"]

return {name: block for name, block in blocks.items() if name != "period"}
return {name: block for name, block in blocks.items() if name != period_block_name}


def _make_converter() -> Converter:
Expand Down
2 changes: 1 addition & 1 deletion flopy4/mf6/gwf/chd.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@
from numpy.typing import NDArray
from xattree import xattree

from flopy4.mf6.component import update_maxbound
from flopy4.mf6.constants import LENBOUNDNAME
from flopy4.mf6.converter import dict_to_array
from flopy4.mf6.package import Package
from flopy4.mf6.spec import array, field, path
from flopy4.mf6.utils.grid_utils import update_maxbound
from flopy4.utils import to_path


Expand Down
2 changes: 1 addition & 1 deletion flopy4/mf6/gwf/drn.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@
from numpy.typing import NDArray
from xattree import xattree

from flopy4.mf6.component import update_maxbound
from flopy4.mf6.constants import LENBOUNDNAME
from flopy4.mf6.converter import dict_to_array
from flopy4.mf6.package import Package
from flopy4.mf6.spec import array, field, path
from flopy4.mf6.utils.grid_utils import update_maxbound
from flopy4.utils import to_path


Expand Down
Loading
Loading