Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 13 additions & 6 deletions docs/examples/array_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,12 @@
constant = data_path / "constant.txt"
external = data_path / "external.txt"
shape = (1000, 100)
dtype = "double"

# Open and load a NumPy array representation

fhandle = open(internal)
imfa = MFArray.load(fhandle, data_path, shape, header=False)
imfa = MFArray.load(fhandle, data_path, shape, type=dtype, header=False)

# Get values

Expand All @@ -87,7 +88,7 @@
plt.colorbar()

fhandle = open(constant)
cmfa = MFArray.load(fhandle, data_path, shape, header=False)
cmfa = MFArray.load(fhandle, data_path, shape, type=dtype, header=False)
cvals = cmfa.value
plt.imshow(cvals[0:100])
plt.colorbar()
Expand All @@ -110,7 +111,7 @@
# External

fhandle = open(external)
emfa = MFArray.load(fhandle, data_path, shape, header=False)
emfa = MFArray.load(fhandle, data_path, shape, type=dtype, header=False)
evals = emfa.value
evals

Expand All @@ -135,7 +136,9 @@

fhandle = open(ilayered)
shape = (3, 1000, 100)
ilmfa = MFArray.load(fhandle, data_path, shape, header=False, layered=True)
ilmfa = MFArray.load(
fhandle, data_path, shape, type=dtype, header=False, layered=True
)
vals = ilmfa.value

ilmfa._value # internal storage
Expand Down Expand Up @@ -182,7 +185,9 @@

fhandle = open(clayered)
shape = (3, 1000, 100)
clmfa = MFArray.load(fhandle, data_path, shape, header=False, layered=True)
clmfa = MFArray.load(
fhandle, data_path, shape, type=dtype, header=False, layered=True
)

clmfa._value

Expand Down Expand Up @@ -235,7 +240,9 @@

fhandle = open(mlayered)
shape = (3, 1000, 100)
mlmfa = MFArray.load(fhandle, data_path, shape, header=False, layered=True)
mlmfa = MFArray.load(
fhandle, data_path, shape, type=dtype, header=False, layered=True
)

mlmfa.how

Expand Down
62 changes: 47 additions & 15 deletions flopy4/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,6 +434,16 @@ def load(cls, f, cwd, shape, header=True, **kwargs):
model_shape = kwargs.pop("model_shape", None)
params = kwargs.pop("blk_params", {})
mempath = kwargs.pop("mempath", None)
atype = kwargs.get("type", None)

if atype is not None:
if atype == "integer":
dtype = np.int32
elif atype == "double":
dtype = np.float64
else:
raise ValueError("array spec type not defined")

if model_shape and isinstance(shape, str):
if shape == "(nodes)":
n = math.prod([x for x in model_shape])
Expand All @@ -446,20 +456,30 @@ def load(cls, f, cwd, shape, header=True, **kwargs):
nlay = params.get("dimensions").get("nlay")
nrow = params.get("dimensions").get("nrow")
ncol = params.get("dimensions").get("ncol")
shape = (nlay, nrow, ncol)
elif "disv" in mempath.split("/"):
nlay = params.get("dimensions").get("nlay")
ncpl = params.get("dimensions").get("ncpl")
nodes = params.get("dimensions").get("nodes")
if nrow and ncol:
shape = (nlay, nrow, ncol)
elif ncpl:
nvert = params.get("dimensions").get("nvert")
if shape == "(ncpl)":
shape = ncpl
elif shape == "(ncpl, nlay)":
shape = (nlay, ncpl)
elif nodes:
elif shape == "(nvert)":
shape = nvert
elif "disu" in mempath.split("/"):
nodes = params.get("dimensions").get("nodes")
nja = params.get("dimensions").get("nja")
if "nodes" in shape:
shape = nodes
elif "nja" in shape:
shape = nja
if layered:
nlay = shape[0]
lshp = shape[1:]
objs = []
for _ in range(nlay):
mfa = cls._load(f, cwd, lshp, name)
mfa = cls._load(f, cwd, lshp, dtype=dtype, name=name)
objs.append(mfa)

return MFArray(
Expand All @@ -474,11 +494,17 @@ def load(cls, f, cwd, shape, header=True, **kwargs):
else:
kwargs.pop("layered", None)
return cls._load(
f, cwd, shape, layered=layered, name=name, **kwargs
f,
cwd,
shape,
layered=layered,
dtype=dtype,
name=name,
**kwargs,
)

@classmethod
def _load(cls, f, cwd, shape, layered=False, **kwargs):
def _load(cls, f, cwd, shape, layered=False, dtype=None, **kwargs):
control_line = multi_line_strip(f).split()

if CommonNames.iprn.lower() in control_line:
Expand All @@ -491,25 +517,31 @@ def _load(cls, f, cwd, shape, layered=False, **kwargs):
clpos = 1

if how == MFArrayType.internal:
array = cls.read_array(f)
array = cls.read_array(f, dtype)

elif how == MFArrayType.constant:
array = float(control_line[clpos])
if dtype == np.float64:
array = float(control_line[clpos])
else:
array = int(control_line[clpos])
clpos += 1

elif how == MFArrayType.external:
extpath = Path(control_line[clpos])
fpath = cwd / extpath
with open(fpath) as foo:
array = cls.read_array(foo)
array = cls.read_array(foo, dtype)
clpos += 1

else:
raise NotImplementedError()

factor = None
if len(control_line) > 2:
factor = float(control_line[clpos + 1])
if dtype == np.float64:
factor = float(control_line[clpos + 1])
else:
factor = int(control_line[clpos + 1])

return cls(
shape,
Expand All @@ -521,7 +553,7 @@ def _load(cls, f, cwd, shape, layered=False, **kwargs):
)

@staticmethod
def read_array(f):
def read_array(f, dtype):
"""
Read a MODFLOW 6 array from an open file
into a flat NumPy array representation.
Expand All @@ -532,11 +564,11 @@ def read_array(f):
pos = f.tell()
line = f.readline()
line = line_strip(line)
if not re.match("^[0-9. ]+$", line):
if not re.match("^[-0-9. ]+$", line):
f.seek(pos, 0)
break
astr.append(line)

astr = StringIO(" ".join(astr))
array = np.genfromtxt(astr).ravel()
array = np.genfromtxt(astr, dtype=dtype).ravel()
return array
16 changes: 12 additions & 4 deletions flopy4/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,7 @@ def load(cls, f, **kwargs):
name = None
index = None
found = False
period = False
params = dict()
members = cls.params

Expand All @@ -235,15 +236,20 @@ def load(cls, f, **kwargs):
line = f.readline()
if line == "":
raise ValueError("Early EOF, aborting")
if line == "\n":
if line == "\n" or line.lstrip().startswith("#"):
continue
words = strip(line).lower().split()
key = words[0]
if period:
key = "stress_period_data"
else:
key = words[0]
if key == "begin":
found = True
name = words[1]
if len(words) > 2 and str.isdigit(words[2]):
index = int(words[2])
if name == "period":
period = True
elif key == "end":
break
elif found:
Expand All @@ -268,20 +274,22 @@ def load(cls, f, **kwargs):
# TODO: inject from model somehow?
# and remove special handling here
kwrgs["cwd"] = ""
# kwrgs["type"] = param.type
kwrgs["mempath"] = f"{mempath}/{name}"
if ptype is not MFArray:
if ptype is not MFArray and ptype is not MFList:
kwrgs.pop("model_shape", None)
kwrgs.pop("blk_params", None)

params[param.name] = ptype.load(f, **kwrgs)
period = False

return cls(name=name, index=index, params=params)

def write(self, f):
"""Write the block to file."""
index = self.index if self.index is not None else ""
begin = f"BEGIN {self.name.upper()} {index}\n"
end = f"END {self.name.upper()}\n"
end = f"END {self.name.upper()}\n\n"

f.write(begin)
super().write(f)
Expand Down
Loading
Loading