diff --git a/flopy4/adapters.py b/flopy4/adapters.py index f4a36ddf..8a3cec3e 100644 --- a/flopy4/adapters.py +++ b/flopy4/adapters.py @@ -142,3 +142,18 @@ def get_cellid(nn: int, grid: Grid) -> tuple[int, ...]: return (nn,) case _: raise TypeError(f"Unsupported grid type: {type(grid)}") + + +def get_nn(cellid, **kwargs): + ndim = len(cellid) + match ndim: + case 1: + return cellid[0] + case 2: + k, j = cellid + return k * kwargs["ncpl"] + j + case 3: + k, i, j = cellid + return k * kwargs["nrow"] * kwargs["ncol"] + i * kwargs["ncol"] + j + case _: + raise ValueError(f"Invalid cellid: {cellid}") diff --git a/flopy4/mf6/codec/converter.py b/flopy4/mf6/codec/converter.py index c2ff04f1..f3077c09 100644 --- a/flopy4/mf6/codec/converter.py +++ b/flopy4/mf6/codec/converter.py @@ -7,7 +7,7 @@ from xarray import DataArray from xattree import get_xatspec -from flopy4.adapters import get_cellid +from flopy4.adapters import get_cellid, get_nn from flopy4.mf6.component import Component from flopy4.mf6.config import SPARSE_THRESHOLD from flopy4.mf6.constants import FILL_DNODATA @@ -66,19 +66,6 @@ def final(arr): arr[arr == FILL_DNODATA] = field.default or FILL_DNODATA return arr - def _get_nn(cellid): - match len(cellid): - case 1: - return cellid[0] - case 2: - k, j = cellid - return k * dims["ncpl"] + j - case 3: - k, i, j = cellid - return k * dims["nrow"] * dims["ncol"] + i * dims["ncol"] + j - case _: - raise ValueError(f"Invalid cellid: {cellid}") - # populate array. TODO: is there a way to do this # without hardcoding awareness of kper and cellid? if "nper" in dims: @@ -90,13 +77,13 @@ def _get_nn(cellid): set_(a, period, kper) case _: for cellid, v in period.items(): - nn = _get_nn(cellid) + nn = get_nn(cellid, **dims) set_(a, v, kper, nn) if kper == "*": break else: for cellid, v in value.items(): - nn = _get_nn(cellid) + nn = get_nn(cellid, **dims) set_(a, v, nn) return final(a)