Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions flopy4/mf6/adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,15 +319,15 @@ def data_type(self):
case "ndarray":
if "nper" in self._data.dims:
if self._data.ndim == 2:
if "nnodes" in self._data.dims:
if "nodes" in self._data.dims:
return DataType.transient2d # nodes?
if self._data.ndim == 3:
return DataType.transient3d # ncpl?
if self._data.ndim == 4:
return DataType.transient2d # nodes?
else:
if self._data.ndim == 1:
if "nnodes" in self._data.dims:
if "nodes" in self._data.dims:
return DataType.array3d
if self._data.ndim == 2:
return DataType.array2d
Expand All @@ -351,7 +351,7 @@ def dtype(self):
@property
def array(self):
if self._spec.type.__name__ == "ndarray":
if "nnodes" in self._data.dims:
if "nodes" in self._data.dims:
if "nper" in self._data.dims:
shape = (
self._time.nper,
Expand Down
4 changes: 2 additions & 2 deletions flopy4/mf6/codec/writer/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def data2list(value: list | xr.DataArray | xr.Dataset):
yield (value.item(),)
return

spatial_dims = [d for d in value.dims if d in ("nlay", "nrow", "ncol", "nnodes")]
spatial_dims = [d for d in value.dims if d in ("nlay", "nrow", "ncol", "nodes")]
has_spatial_dims = len(spatial_dims) > 0
mask = nonempty(value)
indices = np.where(mask)
Expand Down Expand Up @@ -223,7 +223,7 @@ def dataset2list(value: xr.Dataset):
if combined_mask is None or not np.any(combined_mask):
return

spatial_dims = [d for d in first_arr.dims if d in ("nlay", "nrow", "ncol", "nnodes")]
spatial_dims = [d for d in first_arr.dims if d in ("nlay", "nrow", "ncol", "nodes")]
has_spatial_dims = len(spatial_dims) > 0
indices = np.where(combined_mask)
for i in range(len(indices[0])):
Expand Down
81 changes: 79 additions & 2 deletions flopy4/mf6/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,19 @@
from pathlib import Path
from typing import Any

import numpy as np
import sparse
import xarray as xr
import xattree
from attrs import define
from cattrs import Converter
from numpy.typing import NDArray
from xattree import get_xatspec

from flopy4.adapters import get_nn
from flopy4.mf6.component import Component
from flopy4.mf6.config import SPARSE_THRESHOLD
from flopy4.mf6.constants import FILL_DNODATA
from flopy4.mf6.context import Context
from flopy4.mf6.exchange import Exchange
from flopy4.mf6.model import Model
Expand Down Expand Up @@ -148,15 +155,15 @@ def unstructure_component(value: Component) -> dict[str, Any]:
blocks[block_name][field_name] = tuple(field_value.values.tolist())
elif isinstance(field_value, xr.DataArray) and "nper" in field_value.dims:
has_spatial_dims = any(
dim in field_value.dims for dim in ["nlay", "nrow", "ncol", "nnodes"]
dim in field_value.dims for dim in ["nlay", "nrow", "ncol", "nodes"]
)
if has_spatial_dims:
# terrible hack to convert flat nodes dimension to 3d structured dims.
# long term solution for this is to use a custom xarray index. filters
# should then have access to all dimensions needed.
dims_ = set(field_value.dims).copy()
dims_.remove("nper")
if dims_ == {"nnodes"}:
if dims_ == {"nodes"}:
parent = value.parent # type: ignore
field_value = xr.DataArray(
field_value.data.reshape(
Expand Down Expand Up @@ -228,3 +235,73 @@ def _make_converter() -> Converter:


COMPONENT_CONVERTER = _make_converter()


def dict_to_array(value, self_, field) -> NDArray:
"""
Convert a sparse dictionary representation of an array to a
dense numpy array or a sparse COO array.
"""

if not isinstance(value, dict):
# if not a dict, assume it's a numpy array
# and let xarray deal with it if it isn't
return value

spec = get_xatspec(type(self_)).flat
field = spec[field.name]
if not field.dims:
raise ValueError(f"Field {field} missing dims")

# resolve dims
explicit_dims = self_.__dict__.get("dims", {})
inherited_dims = dict(self_.parent.data.dims) if self_.parent else {}
dims = inherited_dims | explicit_dims
shape = [dims.get(d, d) for d in field.dims]
unresolved = [d for d in shape if isinstance(d, str)]
if any(unresolved):
raise ValueError(f"Couldn't resolve dims: {unresolved}")

if np.prod(shape) > SPARSE_THRESHOLD:
a: dict[tuple[Any, ...], Any] = dict()

def set_(arr, val, *ind):
arr[tuple(ind)] = val

def final(arr):
coords = np.array(list(map(list, zip(*arr.keys()))))
return sparse.COO(
coords,
list(arr.values()),
shape=shape,
fill_value=field.default or FILL_DNODATA,
)
else:
a = np.full(shape, FILL_DNODATA, dtype=field.dtype) # type: ignore

def set_(arr, val, *ind):
arr[ind] = val

def final(arr):
arr[arr == FILL_DNODATA] = field.default or FILL_DNODATA
return arr

if "nper" in dims:
for kper, period in value.items():
if kper == "*":
kper = 0
match len(shape):
case 1:
set_(a, period, kper)
case _:
for cellid, v in period.items():
nn = get_nn(cellid, **dims)
set_(a, v, kper, nn)
if kper == "*":
break
else:
for cellid, v in value.items():
nn = get_nn(cellid, **dims)
set_(a, v, nn)

return final(a)
80 changes: 0 additions & 80 deletions flopy4/mf6/converters.py

This file was deleted.

8 changes: 4 additions & 4 deletions flopy4/mf6/gwf/chd.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from xattree import xattree

from flopy4.mf6.component import update_maxbound
from flopy4.mf6.converters import dict_to_array
from flopy4.mf6.converter import dict_to_array
from flopy4.mf6.package import Package
from flopy4.mf6.spec import array, field

Expand All @@ -29,7 +29,7 @@ class Chd(Package):
block="period",
dims=(
"nper",
"nnodes",
"nodes",
),
default=None,
converter=Converter(dict_to_array, takes_self=True, takes_field=True),
Expand All @@ -40,7 +40,7 @@ class Chd(Package):
block="period",
dims=(
"nper",
"nnodes",
"nodes",
),
default=None,
converter=Converter(dict_to_array, takes_self=True, takes_field=True),
Expand All @@ -51,7 +51,7 @@ class Chd(Package):
block="period",
dims=(
"nper",
"nnodes",
"nodes",
),
default=None,
converter=Converter(dict_to_array, takes_self=True, takes_field=True),
Expand Down
6 changes: 3 additions & 3 deletions flopy4/mf6/gwf/dis.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from numpy.typing import NDArray
from xattree import xattree

from flopy4.mf6.converters import dict_to_array
from flopy4.mf6.converter import dict_to_array
from flopy4.mf6.package import Package
from flopy4.mf6.spec import array, dim, field

Expand Down Expand Up @@ -70,14 +70,14 @@ class Dis(Package):
dims=("nlay", "nrow", "ncol"),
converter=Converter(dict_to_array, takes_self=True, takes_field=True),
)
nnodes: int = dim(
nodes: int = dim(
coord="node",
scope="gwf",
init=False,
)

def __attrs_post_init__(self):
self.nnodes = self.ncol * self.nrow * self.nlay
self.nodes = self.ncol * self.nrow * self.nlay
super().__attrs_post_init__()

def to_grid(self) -> StructuredGrid:
Expand Down
10 changes: 5 additions & 5 deletions flopy4/mf6/gwf/drn.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from xattree import xattree

from flopy4.mf6.component import update_maxbound
from flopy4.mf6.converters import dict_to_array
from flopy4.mf6.converter import dict_to_array
from flopy4.mf6.package import Package
from flopy4.mf6.spec import array, field

Expand All @@ -29,15 +29,15 @@ class Drn(Package):
maxbound: Optional[int] = field(block="dimensions", default=None, init=False)
elev: Optional[NDArray[np.float64]] = array(
block="period",
dims=("nper", "nnodes"),
dims=("nper", "nodes"),
default=None,
converter=Converter(dict_to_array, takes_self=True, takes_field=True),
reader="urword",
on_setattr=update_maxbound,
)
cond: Optional[NDArray[np.float64]] = array(
block="period",
dims=("nper", "nnodes"),
dims=("nper", "nodes"),
default=None,
converter=Converter(dict_to_array, takes_self=True, takes_field=True),
reader="urword",
Expand All @@ -47,7 +47,7 @@ class Drn(Package):
block="period",
dims=(
"nper",
"nnodes",
"nodes",
),
default=None,
converter=Converter(dict_to_array, takes_self=True, takes_field=True),
Expand All @@ -58,7 +58,7 @@ class Drn(Package):
block="period",
dims=(
"nper",
"nnodes",
"nodes",
),
default=None,
converter=Converter(dict_to_array, takes_self=True, takes_field=True),
Expand Down
4 changes: 2 additions & 2 deletions flopy4/mf6/gwf/ic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from numpy.typing import NDArray
from xattree import xattree

from flopy4.mf6.converters import dict_to_array
from flopy4.mf6.converter import dict_to_array
from flopy4.mf6.package import Package
from flopy4.mf6.spec import array, field

Expand All @@ -14,7 +14,7 @@ class Ic(Package):
export_array_netcdf: bool = field(block="options", default=False)
strt: NDArray[np.float64] = array(
block="griddata",
dims=("nnodes",),
dims=("nodes",),
default=1.0,
converter=Converter(dict_to_array, takes_self=True, takes_field=True),
)
Loading
Loading