Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 0 additions & 11 deletions .vscode/launch.json
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,6 @@
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [
{
"name": "Python Debugger: Current File",
"type": "debugpy",
"request": "launch",
"program": "${file}",
"console": "integratedTerminal",
"purpose": [
"debug-test"
],
"justMyCode": false
},
{
"name": "quickstart",
"type": "debugpy",
Expand Down
3 changes: 1 addition & 2 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@
"editor.formatOnSave": true,
"files.insertFinalNewline": true,
"python.testing.pytestArgs": [
"test",
"-s"
"test"
],
"python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": true,
Expand Down
4 changes: 1 addition & 3 deletions flopy4/mf6/codec/writer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,11 @@
trim_blocks=True,
lstrip_blocks=True,
)
_JINJA_ENV.filters["is_dataset"] = filters.is_dataset
_JINJA_ENV.filters["field_format"] = filters.field_format
_JINJA_ENV.filters["field_type"] = filters.field_type
_JINJA_ENV.filters["array_how"] = filters.array_how
_JINJA_ENV.filters["array_chunks"] = filters.array_chunks
_JINJA_ENV.filters["array2string"] = filters.array2string
_JINJA_ENV.filters["data2list"] = filters.data2list
_JINJA_ENV.filters["data2keystring"] = filters.data2keystring
_JINJA_TEMPLATE_NAME = "blocks.jinja"
_PRINT_OPTIONS = {
"precision": 4,
Expand Down
190 changes: 72 additions & 118 deletions flopy4/mf6/codec/writer/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,59 +4,32 @@

import numpy as np
import xarray as xr
from modflow_devtools.dfn.schema.v2 import FieldType
from numpy.typing import NDArray

from flopy4.mf6.constants import FILL_DNODATA


def _is_keystring_format(dataset: xr.Dataset) -> bool:
"""Check if dataset should use keystring format based on metadata."""
field_metadata = dataset.attrs.get("field_metadata", {})
return any(meta.get("format") == "keystring" for meta in field_metadata.values())
def field_type(value: Any) -> FieldType:
"""Get a value's type according to the MF6 specification."""


def _is_tabular_time_format(dataset: xr.Dataset) -> bool:
"""True if a dataset has multiple columns and only one dimension 'nper'."""
return len(dataset.data_vars) > 1 and all(
"nper" in var.dims and len(var.dims) == 1 for var in dataset.data_vars.values()
)


def is_dataset(value: Any) -> bool:
return isinstance(value, xr.Dataset)


def field_format(value: Any) -> str:
"""
Get a field's formatting type as defined by the MF6 definition language:
https://modflow6.readthedocs.io/en/stable/_dev/dfn.html#variable-types
"""
if isinstance(value, bool):
return "keyword"
if isinstance(value, int):
return "integer"
if isinstance(value, float):
return "double precision"
return "double"
if isinstance(value, str):
return "string"
if isinstance(value, (dict, tuple)):
if isinstance(value, tuple):
return "record"
if isinstance(value, xr.DataArray):
if value.dtype == "object":
return "list"
return "array"
if isinstance(value, (xr.Dataset, list)):
if isinstance(value, xr.Dataset):
if _is_keystring_format(value):
return "keystring"
if _is_tabular_time_format(value):
return "list"
if isinstance(value, (list, dict, xr.Dataset)):
return "list"
return "keystring"


def has_time_dim(value: Any) -> bool:
return isinstance(value, xr.DataArray) and "nper" in value.dims
raise ValueError(f"Unsupported field type: {type(value)}")


def array_how(value: xr.DataArray) -> str:
Expand Down Expand Up @@ -140,38 +113,42 @@ def array2string(value: NDArray) -> str:
return buffer.getvalue().strip()


def nonempty(arr: NDArray | xr.DataArray) -> NDArray:
if isinstance(arr, xr.DataArray):
arr = arr.values
if arr.dtype == "object":
mask = arr != None # noqa: E711
def nonempty(value: NDArray | xr.DataArray) -> NDArray:
"""
Return a boolean mask of non-empty (non-nodata) values in an array.
TODO: don't hardcode FILL_DNODATA, support different fill values
"""
if isinstance(value, xr.DataArray):
value = value.values
if value.dtype == "object":
mask = value != None # noqa: E711
else:
mask = ~np.ma.masked_invalid(arr).mask
mask = mask & (arr != FILL_DNODATA)
mask = ~np.ma.masked_invalid(value).mask
mask = mask & (value != FILL_DNODATA)
return mask


def data2list(value: list | xr.DataArray | xr.Dataset):
def data2list(value: list | tuple | dict | xr.Dataset | xr.DataArray):
"""
Yield record tuples from a list, `DataArray` or `Dataset`.

Yields
------
tuple
Tuples of (*cellid, *values) or (*values) depending on spatial dimensions
Yield records (tuples) from data in a `list`, `dict`, `DataArray` or `Dataset`.
"""

if isinstance(value, list):
for item in value:
yield item
if isinstance(value, (list, tuple)):
for rec in value:
yield rec
return

if isinstance(value, dict):
for name, val in value.values():
yield (name, val)
return

if isinstance(value, xr.Dataset):
yield from dataset2list(value)
return

# handle scalar
if value.ndim == 0:
# otherwise we have a DataArray
if value.ndim == 0: # handle scalar
if not np.isnan(value.item()) and value.item() is not None:
yield (value.item(),)
return
Expand All @@ -184,90 +161,67 @@ def data2list(value: list | xr.DataArray | xr.Dataset):
for i, val in enumerate(values):
if has_spatial_dims:
cellid = tuple(idx[i] + 1 for idx in indices)
result = cellid + (val,)
rec = cellid + (val,)
else:
result = (val,)
yield result
rec = (val,)
yield rec


def dataset2list(value: xr.Dataset):
"""
Yield record tuples from an xarray Dataset. For regular/tabular list-based format.
Yield records (tuples) from an `xarray.Dataset`.

Yields
------
tuple
Tuples of (*cellid, *values) or (*values) depending on spatial dimensions
If the first data variable is a string type, assume all are
string type. Then the dataset represents a keystring; yield
tuples of (name, *value). Otherwise, yield tuples: (*value)
if no spatial dimensions, or (*cellid, *value) when spatial
dimensions are present.
"""
if value is None or not any(value.data_vars):
return

# handle scalar
first_arr = next(iter(value.data_vars.values()))
if first_arr.ndim == 0:
field_vals = []
for field_name in value.data_vars.keys():
field_val = value[field_name]
if hasattr(field_val, "item"):
field_vals.append(field_val.item())
else:
field_vals.append(field_val)
yield tuple(field_vals)
first = next(iter(value.data_vars.values()))
is_union = first.dtype.type is np.str_

if first.ndim == 0: # handle scalar
if is_union:
for name in value.data_vars.keys():
val = value[name]
val = val.item() if val.shape == () else val
yield (*name.split("_"), val)
else:
vals = []
for name in value.data_vars.keys():
val = value[name]
val = val.item() if val.shape == () else val
vals.append(val)
yield tuple(vals)
return

# build mask
combined_mask: Any = None
for field_name, arr in value.data_vars.items():
mask = nonempty(arr)
for name, first in value.data_vars.items():
mask = nonempty(first)
combined_mask = mask if combined_mask is None else combined_mask | mask
if combined_mask is None or not np.any(combined_mask):
return

spatial_dims = [d for d in first_arr.dims if d in ("nlay", "nrow", "ncol", "nodes")]
spatial_dims = [d for d in first.dims if d in ("nlay", "nrow", "ncol", "nodes")]
has_spatial_dims = len(spatial_dims) > 0
indices = np.where(combined_mask)
for i in range(len(indices[0])):
field_vals = []
for field_name in value.data_vars.keys():
field_val = value[field_name][tuple(idx[i] for idx in indices)]
if hasattr(field_val, "item"):
field_vals.append(field_val.item())
else:
field_vals.append(field_val)
if has_spatial_dims:
cellid = tuple(idx[i] + 1 for idx in indices)
yield cellid + tuple(field_vals)
if is_union:
for name in value.data_vars.keys():
val = value[name][tuple(idx[i] for idx in indices)]
val = val.item() if val.shape == () else val
yield (*name.split("_"), val)
else:
yield tuple(field_vals)


def data2keystring(value: dict | xr.Dataset):
"""
Yield record tuples from a dict or dataset. For irregular list-based format, i.e. keystrings.

Yields
------
tuple
Tuples of (field_name, value) for use with record macro
"""
if isinstance(value, dict):
if not value:
return
for field_name, field_val in value.items():
yield (field_name.upper(), field_val)
elif isinstance(value, xr.Dataset):
if value is None or not any(value.data_vars):
return

for field_name in value.data_vars.keys():
name = (
field_name.replace("_", " ").upper()
if np.issubdtype(value.data_vars[field_name].dtype, np.str_)
else field_name.upper()
)
field_val = value[field_name]
if hasattr(field_val, "item"):
val = field_val.item()
vals = []
for name in value.data_vars.keys():
val = value[name][tuple(idx[i] for idx in indices)]
val = val.item() if val.shape == () else val
vals.append(val)
if has_spatial_dims:
cellid = tuple(idx[i] + 1 for idx in indices)
yield cellid + tuple(vals)
else:
val = field_val
yield (name, val)
yield tuple(vals)
50 changes: 18 additions & 32 deletions flopy4/mf6/codec/writer/templates/macros.jinja
Original file line number Diff line number Diff line change
@@ -1,62 +1,48 @@
{% set inset = " " %}
{% set inset = " " %}

{% macro field(name, value) %}
{% set format = value|field_format %}
{% if format in ['keyword', 'integer', 'double precision', 'string'] %}
{% set type = value|field_type %}
{% if type in ['keyword', 'integer', 'double', 'string'] %}
{{ scalar(name, value) }}
{% elif format == 'record' %}
{{ record(name, value) }}
{% elif format == 'keystring' %}
{{ keystring(name, value) }}
{% elif format == 'array' %}
{% elif type == 'record' %}
{{ record(value) }}
{% elif type == 'array' %}
{{ array(name, value, how=value|array_how) }}
{% elif format == 'list' %}
{% elif type == 'list' %}
{{ list(name, value) }}
{% endif %}
{% endmacro %}

{% macro scalar(name, value) %}
{% set format = value|field_format %}
{% if value is not none %}{{ inset ~ name.upper() }}{% if format != 'keyword' %} {{ value }}{% endif %}{% endif %}
{% set type = value|field_type %}
{% if value is not none %}{{ inset ~ name.upper() }}{% if type != 'keyword' %} {{ value }}{% endif %}{% endif %}
{% endmacro %}

{% macro keystring(name, value) %}
{% for option in (value|data2keystring) -%}
{{ record("", option) }}{% if not loop.last %}{{ "\n" }}{% endif %}
{%- endfor %}
{% macro record(value) %}
{{ inset ~ value|join(" ") -}}
{% endmacro %}

{% macro record(name, value) %}
{%- if value is mapping %}
{% for field_name, field_value in value.items() -%}
{{ field_name.upper() }} {{ field(field_value) }}
{% macro list(name, value) %}
{% for row in (value|data2list) %}
{{ record(row) }}{% if not loop.last %}{{ "\n" }}{% endif %}
{%- endfor %}
{% else %}
{{ inset ~ value|join(" ") }}
{%- endif %}
{% endmacro %}

{% macro array(name, value, how="internal") %}
{{ inset ~ name.upper() }}{% if "layered" in how %} LAYERED{% endif %}

{% if how == "constant" %}
CONSTANT {{ value.item() }}
CONSTANT {{ value.item() }}
{% elif how == "layered constant" %}
{% for layer in value -%}
CONSTANT {{ layer.item() }}
CONSTANT {{ layer.item() }}
{%- endfor %}
{% elif how == "internal" %}
INTERNAL
INTERNAL
{% for chunk in value|array_chunks -%}
{{ (2 * inset) ~ chunk|array2string }}
{%- endfor %}
{% elif how == "external" %}
OPEN/CLOSE {{ value }}
OPEN/CLOSE {{ value }}
{% endif %}
{% endmacro %}

{% macro list(name, value) %}
{% for row in (value|data2list) %}
{{ inset ~ row|join(" ") }}{% if not loop.last %}{{ "\n" }}{% endif %}
{%- endfor %}
{% endmacro %}
Loading
Loading