Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 80 additions & 0 deletions flopy4/mf6/converters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
from typing import Any

import numpy as np
import sparse
from numpy.typing import NDArray
from xattree import get_xatspec

from flopy4.adapters import get_nn
from flopy4.mf6.config import SPARSE_THRESHOLD
from flopy4.mf6.constants import FILL_DNODATA


def dict_to_array(value, self_, field) -> NDArray:
"""
Convert a sparse dictionary representation of an array to a
dense numpy array or a sparse COO array.
"""

if not isinstance(value, dict):
# if not a dict, assume it's a numpy array
# and let xarray deal with it if it isn't
return value

spec = get_xatspec(type(self_)).flat
field = spec[field.name]
if not field.dims:
raise ValueError(f"Field {field} missing dims")

# resolve dims
explicit_dims = self_.__dict__.get("dims", {})
inherited_dims = dict(self_.parent.data.dims) if self_.parent else {}
dims = inherited_dims | explicit_dims
shape = [dims.get(d, d) for d in field.dims]
unresolved = [d for d in shape if isinstance(d, str)]
if any(unresolved):
raise ValueError(f"Couldn't resolve dims: {unresolved}")

if np.prod(shape) > SPARSE_THRESHOLD:
a: dict[tuple[Any, ...], Any] = dict()

def set_(arr, val, *ind):
arr[tuple(ind)] = val

def final(arr):
coords = np.array(list(map(list, zip(*arr.keys()))))
return sparse.COO(
coords,
list(arr.values()),
shape=shape,
fill_value=field.default or FILL_DNODATA,
)
else:
a = np.full(shape, FILL_DNODATA, dtype=field.dtype) # type: ignore

def set_(arr, val, *ind):
arr[ind] = val

def final(arr):
arr[arr == FILL_DNODATA] = field.default or FILL_DNODATA
return arr

if "nper" in dims:
for kper, period in value.items():
if kper == "*":
kper = 0
match len(shape):
case 1:
set_(a, period, kper)
case _:
for cellid, v in period.items():
nn = get_nn(cellid, **dims)
set_(a, v, kper, nn)
if kper == "*":
break
else:
for cellid, v in value.items():
nn = get_nn(cellid, **dims)
set_(a, v, nn)

return final(a)
10 changes: 6 additions & 4 deletions flopy4/mf6/gwf/chd.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
from typing import ClassVar, Optional

import numpy as np
from attrs import Converter
from numpy.typing import NDArray
from xattree import dict_to_array_converter, xattree
from xattree import xattree

from flopy4.mf6.constants import FILL_DNODATA
from flopy4.mf6.converters import dict_to_array
from flopy4.mf6.package import Package
from flopy4.mf6.spec import array, field

Expand All @@ -31,7 +33,7 @@ class Chd(Package):
"nnodes",
),
default=None,
converter=dict_to_array_converter,
converter=Converter(dict_to_array, takes_self=True, takes_field=True),
reader="urword",
)
aux: Optional[NDArray[np.float64]] = array(
Expand All @@ -41,7 +43,7 @@ class Chd(Package):
"nnodes",
),
default=None,
converter=dict_to_array_converter,
converter=Converter(dict_to_array, takes_self=True, takes_field=True),
reader="urword",
)
boundname: Optional[NDArray[np.str_]] = array(
Expand All @@ -51,7 +53,7 @@ class Chd(Package):
"nnodes",
),
default=None,
converter=dict_to_array_converter,
converter=Converter(dict_to_array, takes_self=True, takes_field=True),
reader="urword",
)

Expand Down
14 changes: 8 additions & 6 deletions flopy4/mf6/gwf/dis.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import numpy as np
from attrs import Converter
from flopy.discretization.structuredgrid import StructuredGrid
from numpy.typing import NDArray
from xattree import dict_to_array_converter, xattree
from xattree import xattree

from flopy4.mf6.converters import dict_to_array
from flopy4.mf6.package import Package
from flopy4.mf6.spec import array, dim, field

Expand Down Expand Up @@ -43,31 +45,31 @@ class Dis(Package):
block="griddata",
default=1.0,
dims=("ncol",),
converter=dict_to_array_converter,
converter=Converter(dict_to_array, takes_self=True, takes_field=True),
)
delc: NDArray[np.float64] = array(
block="griddata",
default=1.0,
dims=("nrow",),
converter=dict_to_array_converter,
converter=Converter(dict_to_array, takes_self=True, takes_field=True),
)
top: NDArray[np.float64] = array(
block="griddata",
default=1.0,
dims=("nrow", "ncol"),
converter=dict_to_array_converter,
converter=Converter(dict_to_array, takes_self=True, takes_field=True),
)
botm: NDArray[np.float64] = array(
block="griddata",
default=0.0,
dims=("nlay", "nrow", "ncol"),
converter=dict_to_array_converter,
converter=Converter(dict_to_array, takes_self=True, takes_field=True),
)
idomain: NDArray[np.int32] = array(
block="griddata",
default=1,
dims=("nlay", "nrow", "ncol"),
converter=dict_to_array_converter,
converter=Converter(dict_to_array, takes_self=True, takes_field=True),
)
nnodes: int = dim(
coord="node",
Expand Down
14 changes: 8 additions & 6 deletions flopy4/mf6/gwf/drn.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,19 @@
from typing import ClassVar, Optional

import numpy as np
from attrs import Converter
from numpy.typing import NDArray
from xattree import dict_to_array_converter
from xattree import xattree

from flopy4.mf6.constants import FILL_DNODATA
from flopy4.mf6.converters import dict_to_array
from flopy4.mf6.package import Package
from flopy4.mf6.spec import array, field


@xattree
class Drn(Package):
multi_package: ClassVar[bool] = True

auxiliary: Optional[list[str]] = array(block="options", default=None)
auxmultname: Optional[str] = field(block="options", default=None)
auxdepthname: Optional[str] = field(block="options", default=None)
Expand All @@ -29,14 +31,14 @@ class Drn(Package):
block="period",
dims=("nper", "nnodes"),
default=None,
converter=dict_to_array_converter,
converter=Converter(dict_to_array, takes_self=True, takes_field=True),
reader="urword",
)
cond: Optional[NDArray[np.float64]] = array(
block="period",
dims=("nper", "nnodes"),
default=None,
converter=dict_to_array_converter,
converter=Converter(dict_to_array, takes_self=True, takes_field=True),
reader="urword",
)
aux: Optional[NDArray[np.float64]] = array(
Expand All @@ -46,7 +48,7 @@ class Drn(Package):
"nnodes",
),
default=None,
converter=dict_to_array_converter,
converter=Converter(dict_to_array, takes_self=True, takes_field=True),
reader="urword",
)
boundname: Optional[NDArray[np.str_]] = array(
Expand All @@ -56,7 +58,7 @@ class Drn(Package):
"nnodes",
),
default=None,
converter=dict_to_array_converter,
converter=Converter(dict_to_array, takes_self=True, takes_field=True),
reader="urword",
)

Expand Down
6 changes: 4 additions & 2 deletions flopy4/mf6/gwf/ic.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import numpy as np
from attrs import Converter
from numpy.typing import NDArray
from xattree import dict_to_array_converter, xattree
from xattree import xattree

from flopy4.mf6.converters import dict_to_array
from flopy4.mf6.package import Package
from flopy4.mf6.spec import array, field

Expand All @@ -12,7 +14,7 @@ class Ic(Package):
block="packagedata",
dims=("nnodes",),
default=1.0,
converter=dict_to_array_converter,
converter=Converter(dict_to_array, takes_self=True, takes_field=True),
)
export_array_ascii: bool = field(block="options", default=False)
export_array_netcdf: bool = field(block="options", default=False)
21 changes: 11 additions & 10 deletions flopy4/mf6/gwf/npf.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@
from typing import Optional

import numpy as np
from attrs import define
from attrs import Converter, define
from numpy.typing import NDArray
from xattree import dict_to_array_converter, xattree
from xattree import xattree

from flopy4.mf6.converters import dict_to_array
from flopy4.mf6.package import Package
from flopy4.mf6.spec import array, field

Expand Down Expand Up @@ -50,47 +51,47 @@ class Xt3dOptions:
block="griddata",
dims=("nnodes",),
default=0,
converter=dict_to_array_converter,
converter=Converter(dict_to_array, takes_self=True, takes_field=True),
)
k: NDArray[np.float64] = array(
block="griddata",
dims=("nnodes",),
default=1.0,
converter=dict_to_array_converter,
converter=Converter(dict_to_array, takes_self=True, takes_field=True),
)
k22: Optional[NDArray[np.float64]] = array(
block="griddata",
dims=("nnodes",),
default=None,
converter=dict_to_array_converter,
converter=Converter(dict_to_array, takes_self=True, takes_field=True),
)
k33: Optional[NDArray[np.float64]] = array(
block="griddata",
dims=("nnodes",),
default=None,
converter=dict_to_array_converter,
converter=Converter(dict_to_array, takes_self=True, takes_field=True),
)
angle1: Optional[NDArray[np.float64]] = array(
block="griddata",
dims=("nnodes",),
default=None,
converter=dict_to_array_converter,
converter=Converter(dict_to_array, takes_self=True, takes_field=True),
)
angle2: Optional[NDArray[np.float64]] = array(
block="griddata",
dims=("nnodes",),
default=None,
converter=dict_to_array_converter,
converter=Converter(dict_to_array, takes_self=True, takes_field=True),
)
angle3: Optional[NDArray[np.float64]] = array(
block="griddata",
dims=("nnodes",),
default=None,
converter=dict_to_array_converter,
converter=Converter(dict_to_array, takes_self=True, takes_field=True),
)
wetdry: Optional[NDArray[np.float64]] = array(
block="griddata",
dims=("nnodes",),
default=None,
converter=dict_to_array_converter,
converter=Converter(dict_to_array, takes_self=True, takes_field=True),
)
24 changes: 7 additions & 17 deletions flopy4/mf6/gwf/oc.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@
from typing import Literal, Optional

import numpy as np
from attrs import define
from attrs import Converter, define
from numpy.typing import NDArray
from xattree import dict_to_array_converter, xattree
from xattree import xattree

from flopy4.mf6.converters import dict_to_array
from flopy4.mf6.package import Package
from flopy4.mf6.spec import array, field
from flopy4.utils import to_path
Expand Down Expand Up @@ -54,41 +55,30 @@ class Period:
block="period",
default="all",
dims=("nper",),
converter=dict_to_array_converter,
converter=Converter(dict_to_array, takes_self=True, takes_field=True),
reader="urword",
)
save_budget: Optional[NDArray[np.object_]] = array(
Steps,
block="period",
default="all",
dims=("nper",),
converter=dict_to_array_converter,
converter=Converter(dict_to_array, takes_self=True, takes_field=True),
reader="urword",
)
print_head: Optional[NDArray[np.object_]] = array(
Steps,
block="period",
default="all",
dims=("nper",),
converter=dict_to_array_converter,
converter=Converter(dict_to_array, takes_self=True, takes_field=True),
reader="urword",
)
print_budget: Optional[NDArray[np.object_]] = array(
Steps,
block="period",
default="all",
dims=("nper",),
converter=dict_to_array_converter,
converter=Converter(dict_to_array, takes_self=True, takes_field=True),
reader="urword",
)

# original DFN
# @classmethod
# def get_dfn(cls) -> Dfn:
# """Generate the component's MODFLOW 6 definition."""
# dfn = super().get_dfn()
# for field_name in list(dfn["perioddata"].keys()):
# dfn["perioddata"].pop(field_name)
# dfn["perioddata"]["saverecord"] = _oc_action_field("save")
# dfn["perioddata"]["printrecord"] = _oc_action_field("print")
# return dfn
Loading
Loading