Skip to content

Commit

Permalink
Rewrite trans_derrf to intended behaviour
Browse files Browse the repository at this point in the history
  • Loading branch information
Blunde1 committed Sep 28, 2023
1 parent 664fce7 commit b577c65
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 13 deletions.
42 changes: 34 additions & 8 deletions src/ert/config/gen_kw_config.py
Expand Up @@ -427,6 +427,13 @@ def trans_errf(x: float, arg: List[float]) -> float:
"""
_min, _max, _skew, _width = arg[0], arg[1], arg[2], arg[3]
y = 0.5 * (1 + math.erf((x + _skew) / (_width * math.sqrt(2.0))))
if np.isnan(y):
raise ValueError(
(
"Output is nan, likely from triplet (x, skewness, width) "
"leading to low/high-probability in normal CDF."
)
)
return _min + y * (_max - _min)

@staticmethod
Expand All @@ -439,15 +446,34 @@ def trans_raw(x: float, _: List[float]) -> float:

@staticmethod
def trans_derrf(x: float, arg: List[float]) -> float:
'''Observe that the argument of the shift should be \"+\"'''
_steps, _min, _max, _skew, _width = int(arg[0]), arg[1], arg[2], arg[3], arg[4]
y = math.floor(
_steps
* 0.5
* (1 + math.erf((x + _skew) / (_width * math.sqrt(2.0))))
/ (_steps - 1)
"""
Bin the result of `trans_errf` with `min=0` and `max=1` to closest of `nbins`
linearly spaced values on [0,1]. Finally map [0,1] to [min, max].
"""
_steps, _min, _max, _skew, _width = (
int(arg[0]),
arg[1],
arg[2],
arg[3],
arg[4],
)
return _min + y * (_max - _min)
q_values = np.linspace(start=0, stop=1, num=_steps)
q_checks = np.linspace(start=0, stop=1, num=_steps + 1)[1:]
y = TransferFunction.trans_errf(x, [0, 1, _skew, _width])
bin_index = np.digitize(y, q_checks, right=True)
y_binned = q_values[bin_index]
result = _min + y_binned * (_max - _min)
if result > _max or result < _min:
warnings.warn(
"trans_derff suffered from catastrophic loss of precision, clamping to min,max",
stacklevel=1,
)
return np.clip(result, _min, _max)
if np.isnan(result):
raise ValueError(
"trans_derrf returns nan, check that input arguments are reasonable"
)
return result

@staticmethod
def trans_unif(x: float, arg: List[float]) -> float:
Expand Down
5 changes: 0 additions & 5 deletions tests/unit_tests/config/test_gen_kw_config.py
Expand Up @@ -358,11 +358,6 @@ def test_gen_kw_params_parsing(tmpdir, params, error):
("MYNAME ERRF 1 2 0.1 0.1", 0.3, 1.99996832875816688002),
("MYNAME ERRF 1 2 0.1 0.1", 0.7, 1.99999999999999933387),
("MYNAME ERRF 1 2 0.1 0.1", 1.0, 2.00000000000000000000),
("MYNAME DERRF 10 1 2 0.1 0.1", -1.0, 1.00000000000000000000),
("MYNAME DERRF 10 1 2 0.1 0.1", 0.0, 1.00000000000000000000),
("MYNAME DERRF 10 1 2 0.1 0.1", 0.3, 2.00000000000000000000),
("MYNAME DERRF 10 1 2 0.1 0.1", 0.7, 2.00000000000000000000),
("MYNAME DERRF 10 1 2 0.1 0.1", 1.0, 2.00000000000000000000),
],
)
def test_gen_kw_trans_func(tmpdir, params, xinput, expected):
Expand Down
78 changes: 78 additions & 0 deletions tests/unit_tests/config/test_transfer_functions.py
@@ -1,6 +1,7 @@
import numpy as np
from hypothesis import given
from hypothesis import strategies as st
from scipy.stats import norm

from ert.config import TransferFunction

Expand Down Expand Up @@ -79,3 +80,80 @@ def test_that_truncated_normal_stretches(x, arg):
return
result = TransferFunction.trans_truncated_normal(x, arg)
assert np.isclose(result, expected)


def valid_derrf_parameters():
"""All elements in R, min<max, and width>0"""
steps = st.integers(min_value=2, max_value=1000)
min_max = (
st.tuples(
st.floats(
min_value=-1e6, max_value=1e6, allow_nan=False, allow_infinity=False
),
st.floats(
min_value=-1e6, max_value=1e6, allow_nan=False, allow_infinity=False
),
)
.map(sorted)
.filter(lambda x: x[0] < x[1]) # filter out edge case of equality
)
skew = st.floats(allow_nan=False, allow_infinity=False)
width = st.floats(
min_value=0.01, max_value=1e6, allow_nan=False, allow_infinity=False
)
return min_max.flatmap(
lambda min_max: st.tuples(
steps, st.just(min_max[0]), st.just(min_max[1]), skew, width
)
)


@given(st.floats(allow_nan=False, allow_infinity=False), valid_derrf_parameters())
def test_that_derrf_is_within_bounds(x, arg):
"""The result shold always be between (or equal) min and max"""
result = TransferFunction.trans_derrf(x, arg)
assert arg[1] <= result <= arg[2]


@given(
st.lists(st.floats(allow_nan=False, allow_infinity=False), min_size=2),
valid_derrf_parameters(),
)
def test_that_derrf_creates_at_least_steps_or_less_distinct_values(xlist, arg):
"""derrf cannot create more than steps distinct values"""
res = [TransferFunction.trans_derrf(x, arg) for x in xlist]
assert len(set(res)) <= arg[0]


@given(st.floats(allow_nan=False, allow_infinity=False), valid_derrf_parameters())
def test_that_derrf_corresponds_scaled_binned_normal_cdf(x, arg):
"""Check correspondance to normal cdf with -mu=_skew and sd=_width"""
_steps, _min, _max, _skew, _width = arg
q_values = np.linspace(start=0, stop=1, num=_steps)
q_checks = np.linspace(start=0, stop=1, num=_steps + 1)[1:]
p = norm.cdf(x, loc=-_skew, scale=_width)
bin_index = np.digitize(p, q_checks, right=True)
expected = q_values[bin_index]
# scale and ensure ok numerics
expected = _min + expected * (_max - _min)
if expected > _max or expected < _min:
np.clip(expected, _min, _max)
result = TransferFunction.trans_derrf(x, arg)
assert np.isclose(result, expected)


@given(
st.tuples(
st.floats(allow_nan=False, allow_infinity=False),
st.floats(allow_nan=False, allow_infinity=False),
)
.map(sorted)
.filter(lambda x: x[0] < x[1]),
valid_derrf_parameters(),
)
def test_that_derrf_is_non_strictly_monotone(x_tuple, arg):
"""`derrf` is a non-strict monotone function"""
x1, x2 = x_tuple
assert TransferFunction.trans_derrf(x1, arg) <= TransferFunction.trans_derrf(
x2, arg
)

0 comments on commit b577c65

Please sign in to comment.