Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve time discretization #257

Merged
merged 27 commits into from
Sep 4, 2023
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
15d19e1
improve time discretization:
dbrakenhoff Aug 28, 2023
78b620c
update STO pkg
dbrakenhoff Aug 28, 2023
ec31953
update recharge pkg for new ds time discretization
dbrakenhoff Aug 28, 2023
f239cb2
update get_tdis_perioddata
dbrakenhoff Aug 28, 2023
36cebae
update tests for new time discretization
dbrakenhoff Aug 28, 2023
06c4e3c
black
dbrakenhoff Aug 28, 2023
c5ffc29
codacy
dbrakenhoff Aug 28, 2023
6082bd4
process @OnnoEbbens comments
dbrakenhoff Aug 30, 2023
21dd905
process @OnnoEbbens comments
dbrakenhoff Aug 30, 2023
505a985
remove commented code
dbrakenhoff Aug 31, 2023
361c509
pin pandas version < 2.1.0
dbrakenhoff Aug 31, 2023
9ff2a84
add pin to ci not RTD...
dbrakenhoff Aug 31, 2023
5bc7da3
process comments @rubencalje
dbrakenhoff Aug 31, 2023
54160a4
Add perlen and default value for start
rubencalje Aug 31, 2023
f5c5da2
minor docstring update
rubencalje Aug 31, 2023
3bfc28e
Allow time to be a single value as well
rubencalje Aug 31, 2023
d32c937
remove default value of start
rubencalje Aug 31, 2023
3322919
Fix tests
rubencalje Aug 31, 2023
d620084
Update notebooks
rubencalje Aug 31, 2023
fb17b73
Make sure time is converted to an iterable a bit earlier
rubencalje Aug 31, 2023
d796995
Add knmi bugfix
rubencalje Sep 1, 2023
f270efd
Fix new warning in pandas 2.1.0
rubencalje Sep 1, 2023
f7159c9
Fix other problems in notebooks
rubencalje Sep 1, 2023
cfa3d3a
Fix last notebook bugs
rubencalje Sep 1, 2023
e4d009f
Remove start_date_time check in modpath
rubencalje Sep 1, 2023
652196e
codacy + json error nb11
dbrakenhoff Sep 4, 2023
658b4ef
update log message
dbrakenhoff Sep 4, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 122 additions & 2 deletions nlmod/dims/time.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
import datetime as dt
import logging
import warnings

import numpy as np
import pandas as pd
import xarray as xr
from xarray import IndexVariable

logger = logging.getLogger(__name__)


def set_ds_time(
def set_ds_time_deprecated(
ds,
time=None,
steady_state=False,
Expand Down Expand Up @@ -71,6 +73,13 @@ def set_ds_time(
ds : xarray.Dataset
dataset with time variant model data
"""

warnings.warn(
"this function is deprecated and will eventually be removed, "
"please use nlmod.time.set_ds_time() in the future.",
DeprecationWarning,
)

# checks
if time_units.lower() != "days":
raise NotImplementedError()
Expand Down Expand Up @@ -127,6 +136,99 @@ def set_ds_time(
return ds


def set_ds_time(ds, time, start, steady=True, time_units="DAYS", nstp=1, tsmult=1.0):
"""Set time discretisation for model dataset.

Parameters
----------
ds : xarray.Dataset
model dataset
time : array-like
array-like of floats (indicating elapsed time) or timestamps corresponding to
the end of each stress period in the model.
start : str or pandas.Timestamp, optional
model start datetime as string or pandas Timestamp, if None, defaults to
1 january 2000.
steady : arraylike or bool, optional
arraylike indicating which stress periods are steady-state, by default True,
which sets all stress periods to steady-state.
time_units : str, optional
time units, by default "DAYS"
nstp : int or array-like, optional
number of steps per stress period, stored in ds.attrs, default is 1
tsmult : float, optional
timestep multiplier within stress periods, stored in ds.attrs, default is 1.0

Returns
-------
ds : xarray.Dataset
model dataset with added time coordinate

"""
logger.info(
"This is the new version of set_ds_time()."
" If you're looking for the old behavior,"
"use `nlmod.time.set_ds_time_deprecated()`."
)

# parse start datetime
if isinstance(start, str):
start = pd.Timestamp(start)
elif isinstance(start, (pd.Timestamp, np.datetime64)):
pass
else:
raise TypeError("Cannot parse start datetime.")

# convert time to Timestamps
# calculate time idx
if isinstance(time[0], (int, np.integer, float)):
time = pd.Timestamp(start) + pd.to_timedelta(time, time_units)
elif isinstance(time[0], str):
time = pd.to_datetime(time)
elif isinstance(time[0], (pd.Timestamp, np.datetime64, xr.core.variable.Variable)):
pass
else:
raise TypeError("Cannot process 'time' argument. Datatype not understood.")

# set steady
if isinstance(steady, bool):
steady = steady * np.ones(len(time))

ds = ds.assign_coords(coords={"time": time})
if time_units == "D":
time_units = "DAYS"
ds.time.attrs["time_units"] = time_units
ds.time.attrs["start"] = str(start)
ds.time.attrs["steady"] = steady
dbrakenhoff marked this conversation as resolved.
Show resolved Hide resolved
ds.time.attrs["nstp"] = nstp
ds.time.attrs["tsmult"] = tsmult
return ds


def ds_time_idx_from_tdis_settings(start, perlen, nstp=1, tsmult=1.0, time_units="D"):
dbrakenhoff marked this conversation as resolved.
Show resolved Hide resolved
deltlist = []
for kper, delt in enumerate(perlen):
if not isinstance(nstp, int):
kstpkper = nstp[kper]
else:
kstpkper = nstp

if not isinstance(tsmult, float):
tsm = tsmult[kper]
else:
tsm = tsmult

if tsm > 1.0:
delt0 = delt * (tsm - 1) / (tsm**kstpkper - 1)
delt = delt0 * tsm ** np.arange(kstpkper)
else:
delt = np.ones(kstpkper) * delt / kstpkper
deltlist.append(delt)

dt_arr = np.cumsum(np.concatenate(deltlist))
return ds_time_idx(dt_arr, start_datetime=start, time_units=time_units)


def estimate_nstp(
forcing, perlen=1, tsmult=1.1, nstp_min=1, nstp_max=25, return_dt_arr=False
):
Expand Down Expand Up @@ -214,6 +316,15 @@ def estimate_nstp(


def ds_time_from_model(gwf):
warnings.warn(
"this function was renamed to `ds_time_idx_from_model`. "
"Please use the new function name.",
DeprecationWarning,
)
return ds_time_idx_from_model(gwf)


def ds_time_idx_from_model(gwf):
"""Get time index variable from model (gwf or gwt).

Parameters
Expand All @@ -227,10 +338,19 @@ def ds_time_from_model(gwf):
time coordinate for xarray data-array or dataset
"""

return ds_time_from_modeltime(gwf.modeltime)
return ds_time_idx_from_modeltime(gwf.modeltime)


def ds_time_from_modeltime(modeltime):
warnings.warn(
"this function was renamed to `ds_time_idx_from_model`. "
"Please use the new function name.",
DeprecationWarning,
)
return ds_time_idx_from_modeltime(modeltime)


def ds_time_idx_from_modeltime(modeltime):
"""Get time index variable from modeltime object.

Parameters
Expand Down
13 changes: 5 additions & 8 deletions nlmod/gwf/gwf.py
Original file line number Diff line number Diff line change
Expand Up @@ -578,18 +578,15 @@ def sto(
"""
logger.info("creating mf6 STO")

if ds.time.steady_state:
if ds.time.steady_state.all():
logger.warning("Model is steady-state, no STO package created.")
return None
else:
if ds.time.steady_start:
sts_spd = {0: True}
trn_spd = {1: True}
else:
sts_spd = None
trn_spd = {0: True}
sts_spd = {iper: bool(b) for iper, b in enumerate(ds.time.steady)}
trn_spd = {iper: not bool(b) for iper, b in enumerate(ds.time.steady)}

sy = _get_value_from_ds_datavar(ds, "sy", sy, default=0.2)
ss = _get_value_from_ds_datavar(ds, "ss", ss, default=0.000001)
ss = _get_value_from_ds_datavar(ds, "ss", ss, default=1e-5)

sto = flopy.mf6.ModflowGwfsto(
gwf,
Expand Down
8 changes: 4 additions & 4 deletions nlmod/gwf/recharge.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def ds_to_rch(gwf, ds, mask=None, pname="rch", **kwargs):
raise ValueError("please remove nan values in recharge data array")

# get stress period data
if ds.time.steady_state:
if ds.time.steady.all():
recharge = "recharge"
if "time" in ds["recharge"].dims:
mask = ds["recharge"].isel(time=0) != 0
Expand Down Expand Up @@ -69,7 +69,7 @@ def ds_to_rch(gwf, ds, mask=None, pname="rch", **kwargs):
**kwargs,
)

if ds.time.steady_state:
if ds.time.steady.all():
return rch

# create timeseries packages
Expand Down Expand Up @@ -128,7 +128,7 @@ def ds_to_evt(gwf, ds, pname="evt", nseg=1, surface=None, depth=None, **kwargs):
raise ValueError("please remove nan values in evaporation data array")

# get stress period data
if ds.time.steady_state:
if ds.time.steady.all():
if "time" in ds["evaporation"].dims:
mask = ds["evaporation"].isel(time=0) != 0
else:
Expand Down Expand Up @@ -163,7 +163,7 @@ def ds_to_evt(gwf, ds, pname="evt", nseg=1, surface=None, depth=None, **kwargs):
**kwargs,
)

if ds.time.steady_state:
if ds.time.steady.all():
return evt

# create timeseries packages
Expand Down
1 change: 1 addition & 0 deletions nlmod/plot/plotutil.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from ..dims.resample import get_affine_mod_to_world
from ..epsg28992 import EPSG_28992


def get_patches(ds, rotated=False):
"""Get the matplotlib patches for a vertex grid."""
assert "icell2d" in ds.dims
Expand Down
24 changes: 15 additions & 9 deletions nlmod/sim/sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def write_and_run(sim, ds, write_ds=True, script_path=None, silent=False):
ds.attrs["model_ran_on"] = dt.datetime.now().strftime("%Y%m%d_%H:%M:%S")


def get_tdis_perioddata(ds):
def get_tdis_perioddata(ds, nstp=None, tsmult=None):
dbrakenhoff marked this conversation as resolved.
Show resolved Hide resolved
"""Get tdis_perioddata from ds.

Parameters
Expand Down Expand Up @@ -92,15 +92,21 @@ def get_tdis_perioddata(ds):
if len(ds["time"]) > 1:
perlen.extend(np.diff(ds["time"]) / deltat)

if "nstp" in ds:
nstp = ds["nstp"].values
else:
nstp = [ds.time.nstp] * len(perlen)
nstp = util._get_value_from_ds_datavar(
ds, "nstp", nstp, return_da=False, warn=False
)
nstp = util._get_value_from_ds_attr(ds.time, "nstp", nstp)

if isinstance(nstp, (int, np.integer)):
nstp = [nstp] * len(perlen)

nstp = util._get_value_from_ds_datavar(
dbrakenhoff marked this conversation as resolved.
Show resolved Hide resolved
ds, "nstp", nstp, return_da=False, warn=False
)
tsmult = util._get_value_from_ds_attr(ds.time, "tsmult", value=tsmult)

if "tsmult" in ds:
tsmult = ds["tsmult"].values
else:
tsmult = [ds.time.tsmult] * len(perlen)
if isinstance(tsmult, float):
tsmult = [tsmult] * len(perlen)

tdis_perioddata = list(zip(perlen, nstp, tsmult))

Expand Down
Loading