Skip to content

Commit

Permalink
Merge pull request #7464 from alvarosg/string-func-parser
Browse files Browse the repository at this point in the history
[MRG+2] ENH: _StringFuncParser to get numerical functions callables from strings
  • Loading branch information
efiring committed Dec 15, 2016
2 parents 424d3b0 + c143e75 commit b73cedc
Show file tree
Hide file tree
Showing 2 changed files with 317 additions and 0 deletions.
256 changes: 256 additions & 0 deletions lib/matplotlib/cbook.py
Expand Up @@ -2637,3 +2637,259 @@ def __exit__(self, exc_type, exc_value, traceback):
os.rmdir(path)
except OSError:
pass


class _FuncInfo(object):
"""
Class used to store a function.
"""

def __init__(self, function, inverse, bounded_0_1=True, check_params=None):
"""
Parameters
----------
function : callable
A callable implementing the function receiving the variable as
first argument and any additional parameters in a list as second
argument.
inverse : callable
A callable implementing the inverse function receiving the variable
as first argument and any additional parameters in a list as
second argument. It must satisfy 'inverse(function(x, p), p) == x'.
bounded_0_1: bool or callable
A boolean indicating whether the function is bounded in the [0,1]
interval, or a callable taking a list of values for the additional
parameters, and returning a boolean indicating whether the function
is bounded in the [0,1] interval for that combination of
parameters. Default True.
check_params: callable or None
A callable taking a list of values for the additional parameters
and returning a boolean indicating whether that combination of
parameters is valid. It is only required if the function has
additional parameters and some of them are restricted.
Default None.
"""

self.function = function
self.inverse = inverse

if callable(bounded_0_1):
self._bounded_0_1 = bounded_0_1
else:
self._bounded_0_1 = lambda x: bounded_0_1

if check_params is None:
self._check_params = lambda x: True
elif callable(check_params):
self._check_params = check_params
else:
raise ValueError("Invalid 'check_params' argument.")

def is_bounded_0_1(self, params=None):
"""
Returns a boolean indicating if the function is bounded in the [0,1]
interval for a particular set of additional parameters.
Parameters
----------
params : list
The list of additional parameters. Default None.
Returns
-------
out : bool
True if the function is bounded in the [0,1] interval for
parameters 'params'. Otherwise False.
"""

return self._bounded_0_1(params)

def check_params(self, params=None):
"""
Returns a boolean indicating if the set of additional parameters is
valid.
Parameters
----------
params : list
The list of additional parameters. Default None.
Returns
-------
out : bool
True if 'params' is a valid set of additional parameters for the
function. Otherwise False.
"""

return self._check_params(params)


class _StringFuncParser(object):
"""
A class used to convert predefined strings into
_FuncInfo objects, or to directly obtain _FuncInfo
properties.
"""

_funcs = {}
_funcs['linear'] = _FuncInfo(lambda x: x,
lambda x: x,
True)
_funcs['quadratic'] = _FuncInfo(np.square,
np.sqrt,
True)
_funcs['cubic'] = _FuncInfo(lambda x: x**3,
lambda x: x**(1. / 3),
True)
_funcs['sqrt'] = _FuncInfo(np.sqrt,
np.square,
True)
_funcs['cbrt'] = _FuncInfo(lambda x: x**(1. / 3),
lambda x: x**3,
True)
_funcs['log10'] = _FuncInfo(np.log10,
lambda x: (10**(x)),
False)
_funcs['log'] = _FuncInfo(np.log,
np.exp,
False)
_funcs['log2'] = _FuncInfo(np.log2,
lambda x: (2**x),
False)
_funcs['x**{p}'] = _FuncInfo(lambda x, p: x**p[0],
lambda x, p: x**(1. / p[0]),
True)
_funcs['root{p}(x)'] = _FuncInfo(lambda x, p: x**(1. / p[0]),
lambda x, p: x**p,
True)
_funcs['log{p}(x)'] = _FuncInfo(lambda x, p: (np.log(x) /
np.log(p[0])),
lambda x, p: p[0]**(x),
False,
lambda p: p[0] > 0)
_funcs['log10(x+{p})'] = _FuncInfo(lambda x, p: np.log10(x + p[0]),
lambda x, p: 10**x - p[0],
lambda p: p[0] > 0)
_funcs['log(x+{p})'] = _FuncInfo(lambda x, p: np.log(x + p[0]),
lambda x, p: np.exp(x) - p[0],
lambda p: p[0] > 0)
_funcs['log{p}(x+{p})'] = _FuncInfo(lambda x, p: (np.log(x + p[1]) /
np.log(p[0])),
lambda x, p: p[0]**(x) - p[1],
lambda p: p[1] > 0,
lambda p: p[0] > 0)

def __init__(self, str_func):
"""
Parameters
----------
str_func : string
String to be parsed.
"""

if not isinstance(str_func, six.string_types):
raise ValueError("'%s' must be a string." % str_func)
self._str_func = six.text_type(str_func)
self._key, self._params = self._get_key_params()
self._func = self._parse_func()

def _parse_func(self):
"""
Parses the parameters to build a new _FuncInfo object,
replacing the relevant parameters if necessary in the lambda
functions.
"""

func = self._funcs[self._key]

if not self._params:
func = _FuncInfo(func.function, func.inverse,
func.is_bounded_0_1())
else:
m = func.function
function = (lambda x, m=m: m(x, self._params))

m = func.inverse
inverse = (lambda x, m=m: m(x, self._params))

is_bounded_0_1 = func.is_bounded_0_1(self._params)

func = _FuncInfo(function, inverse,
is_bounded_0_1)
return func

@property
def func_info(self):
"""
Returns the _FuncInfo object.
"""
return self._func

@property
def function(self):
"""
Returns the callable for the direct function.
"""
return self._func.function

@property
def inverse(self):
"""
Returns the callable for the inverse function.
"""
return self._func.inverse

@property
def is_bounded_0_1(self):
"""
Returns a boolean indicating if the function is bounded
in the [0-1 interval].
"""
return self._func.is_bounded_0_1()

def _get_key_params(self):
str_func = self._str_func
# Checking if it comes with parameters
regex = '\{(.*?)\}'
params = re.findall(regex, str_func)

for i, param in enumerate(params):
try:
params[i] = float(param)
except ValueError:
raise ValueError("Parameter %i is '%s', which is "
"not a number." %
(i, param))

str_func = re.sub(regex, '{p}', str_func)

try:
func = self._funcs[str_func]
except (ValueError, KeyError):
raise ValueError("'%s' is an invalid string. The only strings "
"recognized as functions are %s." %
(str_func, list(self._funcs)))

# Checking that the parameters are valid
if not func.check_params(params):
raise ValueError("%s are invalid values for the parameters "
"in %s." %
(params, str_func))

return str_func, params
61 changes: 61 additions & 0 deletions lib/matplotlib/tests/test_cbook.py
Expand Up @@ -515,3 +515,64 @@ def test_flatiter():

assert 0 == next(it)
assert 1 == next(it)


class TestFuncParser(object):
x_test = np.linspace(0.01, 0.5, 3)
validstrings = ['linear', 'quadratic', 'cubic', 'sqrt', 'cbrt',
'log', 'log10', 'log2', 'x**{1.5}', 'root{2.5}(x)',
'log{2}(x)',
'log(x+{0.5})', 'log10(x+{0.1})', 'log{2}(x+{0.1})',
'log{2}(x+{0})']
results = [(lambda x: x),
np.square,
(lambda x: x**3),
np.sqrt,
(lambda x: x**(1. / 3)),
np.log,
np.log10,
np.log2,
(lambda x: x**1.5),
(lambda x: x**(1 / 2.5)),
(lambda x: np.log2(x)),
(lambda x: np.log(x + 0.5)),
(lambda x: np.log10(x + 0.1)),
(lambda x: np.log2(x + 0.1)),
(lambda x: np.log2(x))]

bounded_list = [True, True, True, True, True,
False, False, False, True, True,
False,
True, True, True,
False]

@pytest.mark.parametrize("string, func",
zip(validstrings, results),
ids=validstrings)
def test_values(self, string, func):
func_parser = cbook._StringFuncParser(string)
f = func_parser.function
assert_array_almost_equal(f(self.x_test), func(self.x_test))

@pytest.mark.parametrize("string", validstrings, ids=validstrings)
def test_inverse(self, string):
func_parser = cbook._StringFuncParser(string)
f = func_parser.func_info
fdir = f.function
finv = f.inverse
assert_array_almost_equal(finv(fdir(self.x_test)), self.x_test)

@pytest.mark.parametrize("string", validstrings, ids=validstrings)
def test_get_inverse(self, string):
func_parser = cbook._StringFuncParser(string)
finv1 = func_parser.inverse
finv2 = func_parser.func_info.inverse
assert_array_almost_equal(finv1(self.x_test), finv2(self.x_test))

@pytest.mark.parametrize("string, bounded",
zip(validstrings, bounded_list),
ids=validstrings)
def test_bounded(self, string, bounded):
func_parser = cbook._StringFuncParser(string)
b = func_parser.is_bounded_0_1
assert_array_equal(b, bounded)

0 comments on commit b73cedc

Please sign in to comment.