diff --git a/flopy4/mf6/attr_hooks.py b/flopy4/mf6/attr_hooks.py deleted file mode 100644 index a341bfc4..00000000 --- a/flopy4/mf6/attr_hooks.py +++ /dev/null @@ -1,57 +0,0 @@ -"""Attribute hooks for attrs on_setattr callbacks.""" - -import numpy as np -from attrs import fields - -from flopy4.mf6.constants import FILL_DNODATA - - -def update_maxbound(instance, attribute, new_value): - """ - Generalized function to update maxbound when period block arrays change. - - This function automatically finds all period block arrays in the instance - and calculates maxbound based on the maximum number of non-default values - across all arrays. - - Args: - instance: The package instance - attribute: The attribute being set (from attrs on_setattr) - new_value: The new value being set - - Returns: - The new_value (unchanged) - """ - - period_arrays = [] - instance_fields = fields(instance.__class__) - for field in instance_fields: - if ( - field.metadata - and field.metadata.get("block") == "period" - and field.metadata.get("xattree", {}).get("dims") - ): - period_arrays.append(field.name) - - maxbound_values = [] - for array_name in period_arrays: - if attribute and attribute.name == array_name: - array_val = new_value - else: - array_val = getattr(instance, array_name, None) - - if array_val is not None: - array_data = ( - array_val if array_val.data.shape == array_val.shape else array_val.todense() - ) - - if array_data.dtype.kind in ["U", "S"]: # String arrays - non_default_count = len(np.where(array_data != "")[0]) - else: # Numeric arrays - non_default_count = len(np.where(array_data != FILL_DNODATA)[0]) - - maxbound_values.append(non_default_count) - if maxbound_values: - instance.maxbound = max(maxbound_values) - - return new_value diff --git a/flopy4/mf6/component.py b/flopy4/mf6/component.py index 32ed57ad..efffa7c9 100644 --- a/flopy4/mf6/component.py +++ b/flopy4/mf6/component.py @@ -3,12 +3,67 @@ from pathlib import Path from typing import ClassVar +import numpy as np +from attrs import fields from modflow_devtools.dfn import Dfn, Field from xattree import xattree +from flopy4.mf6.constants import FILL_DNODATA from flopy4.mf6.spec import field, fields_dict, to_dfn_field from flopy4.uio import IO, Loader, Writer + +def update_maxbound(instance, attribute, new_value): + """ + Generalized function to update maxbound when period block arrays change. + + This function automatically finds all period block arrays in the instance + and calculates maxbound based on the maximum number of non-default values + across all arrays. + + Args: + instance: The package instance + attribute: The attribute being set (from attrs on_setattr) + new_value: The new value being set + + Returns: + The new_value (unchanged) + """ + + period_arrays = [] + instance_fields = fields(instance.__class__) + for f in instance_fields: + if ( + f.metadata + and f.metadata.get("block") == "period" + and f.metadata.get("xattree", {}).get("dims") + ): + period_arrays.append(f.name) + + maxbound_values = [] + for array_name in period_arrays: + if attribute and attribute.name == array_name: + array_val = new_value + else: + array_val = getattr(instance, array_name, None) + + if array_val is not None: + array_data = ( + array_val if array_val.data.shape == array_val.shape else array_val.todense() + ) + + if array_data.dtype.kind in ["U", "S"]: # String arrays + non_default_count = len(np.where(array_data != "")[0]) + else: # Numeric arrays + non_default_count = len(np.where(array_data != FILL_DNODATA)[0]) + + maxbound_values.append(non_default_count) + if maxbound_values: + instance.maxbound = max(maxbound_values) + + return new_value + + COMPONENTS = {} """MF6 component registry.""" @@ -50,6 +105,36 @@ def default_filename(self) -> str: cls_name = self.__class__.__name__.lower() return f"{name}.{cls_name}" + def __attrs_post_init__(self): + """ + Post-initialization hook for all components. + + Automatically handles common post-init tasks like computing maxbound + for components with period block arrays. + """ + self._update_maxbound_if_needed() + + def _update_maxbound_if_needed(self): + """ + Update maxbound if this component has period block arrays. + + This method checks if the component has any period block arrays defined + and calls update_maxbound if needed. This generalizes the pattern that + was previously repeated in multiple component classes. + """ + # Check if component has a maxbound field and period block arrays + component_fields = fields(self.__class__) + has_maxbound = any(f.name == "maxbound" for f in component_fields) + has_period_arrays = any( + f.metadata + and f.metadata.get("block") == "period" + and f.metadata.get("xattree", {}).get("dims") + for f in component_fields + ) + + if has_maxbound and has_period_arrays: + update_maxbound(self, None, None) + @classmethod def __attrs_init_subclass__(cls): COMPONENTS[cls.__name__.lower()] = cls diff --git a/flopy4/mf6/gwf/chd.py b/flopy4/mf6/gwf/chd.py index 804e0325..de8416c0 100644 --- a/flopy4/mf6/gwf/chd.py +++ b/flopy4/mf6/gwf/chd.py @@ -6,7 +6,7 @@ from numpy.typing import NDArray from xattree import xattree -from flopy4.mf6.attr_hooks import update_maxbound +from flopy4.mf6.component import update_maxbound from flopy4.mf6.converters import dict_to_array from flopy4.mf6.package import Package from flopy4.mf6.spec import array, field @@ -58,7 +58,3 @@ class Chd(Package): reader="urword", on_setattr=update_maxbound, ) - - def __attrs_post_init__(self): - if self.head is not None or self.aux is not None or self.boundname is not None: - update_maxbound(self, None, None) diff --git a/flopy4/mf6/gwf/drn.py b/flopy4/mf6/gwf/drn.py index ea813f0d..b877e302 100644 --- a/flopy4/mf6/gwf/drn.py +++ b/flopy4/mf6/gwf/drn.py @@ -6,7 +6,7 @@ from numpy.typing import NDArray from xattree import xattree -from flopy4.mf6.attr_hooks import update_maxbound +from flopy4.mf6.component import update_maxbound from flopy4.mf6.converters import dict_to_array from flopy4.mf6.package import Package from flopy4.mf6.spec import array, field @@ -65,12 +65,3 @@ class Drn(Package): reader="urword", on_setattr=update_maxbound, ) - - def __attrs_post_init__(self): - if ( - self.elev is not None - or self.cond is not None - or self.aux is not None - or self.boundname is not None - ): - update_maxbound(self, None, None) diff --git a/flopy4/mf6/gwf/rch.py b/flopy4/mf6/gwf/rch.py index 4eb9a176..b38eb64f 100644 --- a/flopy4/mf6/gwf/rch.py +++ b/flopy4/mf6/gwf/rch.py @@ -6,7 +6,7 @@ from numpy.typing import NDArray from xattree import xattree -from flopy4.mf6.attr_hooks import update_maxbound +from flopy4.mf6.component import update_maxbound from flopy4.mf6.converters import dict_to_array from flopy4.mf6.package import Package from flopy4.mf6.spec import array, field @@ -58,7 +58,3 @@ class Rch(Package): reader="urword", on_setattr=update_maxbound, ) - - def __attrs_post_init__(self): - if self.recharge is not None or self.aux is not None or self.boundname is not None: - update_maxbound(self, None, None) diff --git a/flopy4/mf6/gwf/wel.py b/flopy4/mf6/gwf/wel.py index 59d27912..da502b24 100644 --- a/flopy4/mf6/gwf/wel.py +++ b/flopy4/mf6/gwf/wel.py @@ -6,7 +6,7 @@ from numpy.typing import NDArray from xattree import xattree -from flopy4.mf6.attr_hooks import update_maxbound +from flopy4.mf6.component import update_maxbound from flopy4.mf6.converters import dict_to_array from flopy4.mf6.package import Package from flopy4.mf6.spec import array, field @@ -60,7 +60,3 @@ class Wel(Package): reader="urword", on_setattr=update_maxbound, ) - - def __attrs_post_init__(self): - if self.q is not None or self.aux is not None or self.boundname is not None: - update_maxbound(self, None, None)