Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/examples/quickstart.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
workspace = Path(__file__).parent / name
time = ModelTime(perlen=[1.0], nstp=[1])
grid = StructuredGrid(nlay=1, nrow=10, ncol=10)
sim = Simulation(name=name, path=workspace, tdis=time)
sim = Simulation(name=name, workspace=workspace, tdis=time)
ims = Ims(parent=sim)
gwf_name = "mymodel"
gwf = Gwf(parent=sim, name=gwf_name, save_flows=True, dis=grid)
Expand Down
2 changes: 1 addition & 1 deletion docs/examples/quickstart_expanded.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@

ws = "./mymodel"
name = "mymodel"
sim = Simulation(name=name, path=ws, exe="mf6")
sim = Simulation(name=name, workspace=ws, exe="mf6")
tdis = Tdis(sim)
gwf = Gwf(sim, name=name, save_flows=True)
dis = Dis(gwf, nrow=10, ncol=10)
Expand Down
30 changes: 4 additions & 26 deletions flopy4/mf6/__init__.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,8 @@
from pathlib import Path

from flopy4.mf6.codec import dump, load
from flopy4.mf6.component import Component
from flopy4.uio import DEFAULT_REGISTRY


def _default_filename(component: Component) -> str:
"""Default path for a component, based on its name."""
if hasattr(component, "filename") and component.filename is not None:
return component.filename
name = component.name # type: ignore
cls_name = component.__class__.__name__.lower()
return f"{name}.{cls_name}"


def _path(component: Component) -> str:
"""Default path for a component, based on its name."""
if hasattr(component, "path") and component.path is not None:
path = Path(component.path).expanduser().resolve()
if path.is_dir():
return str(path / _default_filename(component))
return str(path)
return _default_filename(component)


DEFAULT_REGISTRY.register_loader(Component, "ascii", lambda component: load(_path(component)))
DEFAULT_REGISTRY.register_writer(
Component, "ascii", lambda component: dump(component, _path(component))
)
# register io methods
# TODO: call this "mf6" or something? since it might include binary files
DEFAULT_REGISTRY.register_loader(Component, "ascii", lambda c: load(c.path))
DEFAULT_REGISTRY.register_writer(Component, "ascii", lambda c: dump(c, c.path))
32 changes: 26 additions & 6 deletions flopy4/mf6/component.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,20 @@
from abc import ABC
from collections.abc import MutableMapping
from pathlib import Path

from modflow_devtools.dfn import Dfn, Field
from xattree import xattree

from flopy4.mf6.spec import fields_dict, to_dfn_field
from flopy4.mf6.spec import field, fields_dict, to_dfn_field
from flopy4.uio import IO, Loader, Writer

COMPONENTS = {}
"""MF6 component registry."""


@xattree
# kw_only=True necessary so we can define optional fields here
# and required fields in subclasses. attrs complains otherwise
@xattree(kw_only=True)
class Component(ABC, MutableMapping):
"""
Base class for MF6 components.
Expand All @@ -22,16 +25,32 @@ class Component(ABC, MutableMapping):

We use the `children` attribute provided by `xattree`. We know
children are also `Component`s, but mypy does not. TODO: fix??
Then we can remove the `# type: ignore` comments.
"""

_load = IO(Loader) # type: ignore
_write = IO(Writer) # type: ignore

filename: str = field(default=None)

@property
def path(self) -> Path:
return Path.cwd() / self.filename

def _default_filename(self) -> str:
name = self.name # type: ignore
cls_name = self.__class__.__name__.lower()
return f"{name}.{cls_name}"

@classmethod
def __attrs_init_subclass__(cls):
COMPONENTS[cls.__name__.lower()] = cls
cls.dfn = cls.get_dfn()

def __attrs_post_init__(self):
if not self.filename:
self.filename = self._default_filename()

def __getitem__(self, key):
return self.children[key] # type: ignore

Expand All @@ -49,13 +68,14 @@ def __len__(self):

@classmethod
def get_dfn(cls) -> Dfn:
"""Generate the component's MODFLOW 6 definition."""
fields = {field_name: to_dfn_field(field) for field_name, field in fields_dict(cls).items()}
blocks: dict[str, dict[str, Field]] = {}
for field_name, field in fields.items():
if (block := field.get("block", None)) is not None:
blocks.setdefault(block, {})[field_name] = field
for field_name, field_ in fields.items():
if (block := field_.get("block", None)) is not None:
blocks.setdefault(block, {})[field_name] = field_
else:
blocks[field_name] = field
blocks[field_name] = field_

return Dfn(
name=cls.__name__.lower(),
Expand Down
20 changes: 20 additions & 0 deletions flopy4/mf6/context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from abc import ABC
from pathlib import Path

from xattree import xattree

from flopy4.mf6.component import Component
from flopy4.mf6.spec import field


@xattree
class Context(Component, ABC):
workspace: Path = field(default=None)

def __attrs_post_init__(self):
if self.workspace is None:
self.workspace = Path.cwd()

@property
def path(self) -> Path:
return self.workspace / self.filename
9 changes: 6 additions & 3 deletions flopy4/mf6/exchange.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from abc import ABC
from pathlib import Path
from typing import Optional

Expand All @@ -7,8 +8,10 @@


@xattree
class Exchange(Package):
exgtype: type = field()
exgfile: Path = field()
class Exchange(Package, ABC):
# mypy doesn't understand that kw_only=True on the
# Component means we can have required fields here
exgtype: type = field() # type: ignore
exgfile: Path = field() # type: ignore
exgmnamea: Optional[str] = field(default=None)
exgmnameb: Optional[str] = field(default=None)
3 changes: 2 additions & 1 deletion flopy4/mf6/interface/flopy3.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import re
from typing import Optional
from warnings import warn

import numpy as np
from flopy.datbase import DataInterface, DataListInterface, DataType
Expand Down Expand Up @@ -332,7 +333,7 @@ def data_type(self):
return DataType.array3d
# TODO: boundname, auxvar arrays of strings?
case _:
raise Exception(f"UNMATCHED data_type {self._name}: {self._spec.type.__name__}")
warn(f"UNMATCHED data_type {self._name}: {self._spec.type.__name__}", UserWarning)

@property
def dtype(self):
Expand Down
40 changes: 23 additions & 17 deletions flopy4/mf6/simulation.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
from os import PathLike
from pathlib import Path
from typing import ClassVar
from warnings import warn

from flopy.discretization.modeltime import ModelTime
from modflow_devtools.misc import cd, run_cmd
from xattree import field, xattree
from xattree import xattree

from flopy4.mf6.component import Component
from flopy4.mf6.context import Context
from flopy4.mf6.exchange import Exchange
from flopy4.mf6.model import Model
from flopy4.mf6.solution import Solution
from flopy4.mf6.spec import field
from flopy4.mf6.tdis import Tdis


Expand All @@ -22,36 +22,42 @@ def convert_time(value):


@xattree
class Simulation(Component):
class Simulation(Context):
models: dict[str, Model] = field()
exchanges: dict[str, Exchange] = field()
solutions: dict[str, Solution] = field()
tdis: Tdis = field(converter=convert_time)
# TODO: decorator for components bound
# to some directory or file path?
path: Path = field(default=None)
filename: ClassVar[str] = "mfsim.nam"
filename: str = field(default="mfsim.nam", init=False)

def __attrs_post_init__(self):
if self.filename != "mfsim.nam":
warn(
"Simulation filename must be 'mfsim.nam'.",
UserWarning,
)
self.filename = "mfsim.nam"

@property
def time(self) -> ModelTime:
"""Return the simulation time discretization."""
return self.tdis.to_time()

def run(self, exe: str | PathLike = "mf6", verbose: bool = False) -> None:
"""Run the simulation using the given executable."""
if self.path is None:
if self.workspace is None:
raise ValueError(f"Simulation {self.name} has no workspace path.")
with cd(self.path):
with cd(self.workspace):
stdout, stderr, retcode = run_cmd(exe, verbose=verbose)
if retcode != 0:
raise RuntimeError(
f"Simulation {self.name}: {exe} failed to run with returncode " # type: ignore
f"{retcode}, and error message:\n\n{stdout + stderr} "
)

def load(self, format):
with cd(self.path):
super().load(format)
def load(self, format="ascii"):
"""Load the simulation in the specified format."""
super().load(format)

def write(self, format):
with cd(self.path):
super().write(format)
def write(self, format="ascii"):
"""Write the simulation in the specified format."""
super().write(format)
Loading
Loading