Skip to content

Commit

Permalink
Merge pull request #613 from minrk/from-string-none
Browse files Browse the repository at this point in the history
handle allow_none in from_string
  • Loading branch information
Carreau committed Sep 3, 2020
2 parents d648d6d + c378d0c commit efc76b8
Show file tree
Hide file tree
Showing 3 changed files with 118 additions and 61 deletions.
9 changes: 9 additions & 0 deletions traitlets/config/tests/test_application.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,15 @@ class TestApp(Application):
assert app.config.TestApp.value == 'cli'
assert app.value == 'cli'

def test_cli_allow_none(self):
class App(Application):
aliases = {"opt": "App.opt"}
opt = Unicode(allow_none=True, config=True)

app = App()
app.parse_command_line(["--opt=None"])
assert app.opt is None

def test_flags(self):
app = MyApp()
app.parse_command_line(["--disable"])
Expand Down
148 changes: 87 additions & 61 deletions traitlets/tests/test_traitlets.py
Original file line number Diff line number Diff line change
Expand Up @@ -2658,7 +2658,7 @@ def _from_string_test(traittype, s, expected):
if isinstance(traittype, TraitType):
trait = traittype
else:
trait = traittype()
trait = traittype(allow_none=True)
if isinstance(s, list):
cast = trait.from_string_list
else:
Expand All @@ -2673,107 +2673,133 @@ def _from_string_test(traittype, s, expected):


@pytest.mark.parametrize(
"s, expected", [("xyz", "xyz"), ("1", "1"), ('"xx"', "xx"), ("'abc'", "abc"),]
"s, expected",
[("xyz", "xyz"), ("1", "1"), ('"xx"', "xx"), ("'abc'", "abc"), ("None", None)],
)
def test_unicode_from_string(s, expected):
_from_string_test(Unicode, s, expected)


@pytest.mark.parametrize(
"s, expected", [("xyz", "xyz"), ("1", "1"), ('"xx"', "xx"), ("'abc'", "abc"),]
"s, expected",
[("xyz", "xyz"), ("1", "1"), ('"xx"', "xx"), ("'abc'", "abc"), ("None", None)],
)
def test_cunicode_from_string(s, expected):
_from_string_test(CUnicode, s, expected)


@pytest.mark.parametrize('s, expected', [
('x', ValueError),
('1', 1),
('123', 123),
('2.0', ValueError),
])
@pytest.mark.parametrize(
"s, expected",
[("x", ValueError), ("1", 1), ("123", 123), ("2.0", ValueError), ("None", None)],
)
def test_int_from_string(s, expected):
_from_string_test(Integer, s, expected)


@pytest.mark.parametrize('s, expected', [
('x', ValueError),
('1', 1.0),
('123.5', 123.5),
('2.5', 2.5),
])
@pytest.mark.parametrize(
"s, expected",
[("x", ValueError), ("1", 1.0), ("123.5", 123.5), ("2.5", 2.5), ("None", None)],
)
def test_float_from_string(s, expected):
_from_string_test(Float, s, expected)


@pytest.mark.parametrize('s, expected', [
('x', ValueError),
('1', 1.0),
('123.5', 123.5),
('2.5', 2.5),
('1+2j', 1+2j),
])
@pytest.mark.parametrize(
"s, expected",
[
("x", ValueError),
("1", 1.0),
("123.5", 123.5),
("2.5", 2.5),
("1+2j", 1 + 2j),
("None", None),
],
)
def test_complex_from_string(s, expected):
_from_string_test(Complex, s, expected)


@pytest.mark.parametrize('s, expected', [
('true', True),
('TRUE', True),
('1', True),
('0', False),
('False', False),
('false', False),
('1.0', ValueError),
])
@pytest.mark.parametrize(
"s, expected",
[
("true", True),
("TRUE", True),
("1", True),
("0", False),
("False", False),
("false", False),
("1.0", ValueError),
("None", None),
],
)
def test_bool_from_string(s, expected):
_from_string_test(Bool, s, expected)


@pytest.mark.parametrize('s, expected', [
('{}', {}),
('1', TraitError),
('{1: 2}', {1: 2}),
('{"key": "value"}', {"key": "value"}),
('x', TraitError),
])
@pytest.mark.parametrize(
"s, expected",
[
("{}", {}),
("1", TraitError),
("{1: 2}", {1: 2}),
('{"key": "value"}', {"key": "value"}),
("x", TraitError),
("None", None),
],
)
def test_dict_from_string(s, expected):
_from_string_test(Dict, s, expected)


@pytest.mark.parametrize('s, expected', [
('[]', []),
('[1, 2, "x"]', [1, 2, 'x']),
(["1", "x"], ["1", "x"])
])
@pytest.mark.parametrize(
"s, expected",
[
("[]", []),
('[1, 2, "x"]', [1, 2, "x"]),
(["1", "x"], ["1", "x"]),
(["None"], None),
],
)
def test_list_from_string(s, expected):
_from_string_test(List, s, expected)


@pytest.mark.parametrize('s, expected, value_trait', [
(["1", "2", "3"], [1, 2, 3], Integer()),
(["x"], ValueError, Integer()),
(["1", "x"], ["1", "x"], Unicode())
])
def test_list_from_string(s, expected, value_trait):
@pytest.mark.parametrize(
"s, expected, value_trait",
[
(["1", "2", "3"], [1, 2, 3], Integer()),
(["x"], ValueError, Integer()),
(["1", "x"], ["1", "x"], Unicode()),
(["None"], [None], Unicode(allow_none=True)),
],
)
def test_list_items_from_string(s, expected, value_trait):
_from_string_test(List(value_trait), s, expected)


@pytest.mark.parametrize('s, expected', [
('x', 'x'),
('mod.submod', 'mod.submod'),
('not an identifier', TraitError),
('1', '1'),
])
@pytest.mark.parametrize(
"s, expected",
[
("x", "x"),
("mod.submod", "mod.submod"),
("not an identifier", TraitError),
("1", "1"),
("None", None),
],
)
def test_object_from_string(s, expected):
_from_string_test(DottedObjectName, s, expected)


@pytest.mark.parametrize('s, expected', [
('127.0.0.1:8000', ('127.0.0.1', 8000)),
('host.tld:80', ('host.tld', 80)),
('host:notaport', ValueError),
('127.0.0.1', ValueError),
])
@pytest.mark.parametrize(
"s, expected",
[
("127.0.0.1:8000", ("127.0.0.1", 8000)),
("host.tld:80", ("host.tld", 80)),
("host:notaport", ValueError),
("127.0.0.1", ValueError),
("None", None),
],
)
def test_tcp_from_string(s, expected):
_from_string_test(TCPAddress, s, expected)
22 changes: 22 additions & 0 deletions traitlets/traitlets.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,6 +493,8 @@ def from_string(self, s):
.. versionadded:: 5.0
"""
if self.allow_none and s == 'None':
return None
return s

def default(self, obj=None):
Expand Down Expand Up @@ -2038,6 +2040,8 @@ def validate(self, obj, value):
return _validate_bounds(self, obj, value)

def from_string(self, s):
if self.allow_none and s == 'None':
return None
return int(s)


Expand Down Expand Up @@ -2076,6 +2080,8 @@ def validate(self, obj, value):
return _validate_bounds(self, obj, value)

def from_string(self, s):
if self.allow_none and s == 'None':
return None
return float(s)


Expand Down Expand Up @@ -2104,6 +2110,8 @@ def validate(self, obj, value):
self.error(obj, value)

def from_string(self, s):
if self.allow_none and s == 'None':
return None
return complex(s)


Expand Down Expand Up @@ -2131,6 +2139,8 @@ def validate(self, obj, value):
self.error(obj, value)

def from_string(self, s):
if self.allow_none and s == 'None':
return None
return s.encode('utf8')


Expand Down Expand Up @@ -2162,6 +2172,8 @@ def validate(self, obj, value):
self.error(obj, value)

def from_string(self, s):
if self.allow_none and s == 'None':
return None
s = os.path.expanduser(s)
if len(s) >= 2:
# handle deprecated "1"
Expand Down Expand Up @@ -2203,6 +2215,8 @@ def validate(self, obj, value):
self.error(obj, value)

def from_string(self, s):
if self.allow_none and s == 'None':
return None
return s

class DottedObjectName(ObjectName):
Expand Down Expand Up @@ -2233,6 +2247,8 @@ def validate(self, obj, value):
self.error(obj, value)

def from_string(self, s):
if self.allow_none and s == 'None':
return None
s = s.lower()
if s in {'true', '1'}:
return True
Expand Down Expand Up @@ -2487,6 +2503,8 @@ def from_string_list(self, s_list):
if len(s_list) == 1:
# check for deprecated --Class.trait="['a', 'b', 'c']"
r = s_list[0]
if r == "None" and self.allow_none:
return None
if (
(r[0] == '[' and r[-1] == ']') or
(r[0] == '(' and r[-1] == ')')
Expand Down Expand Up @@ -2897,6 +2915,8 @@ def from_string_list(self, s_list):
This is where we parse CLI configuration
"""
if len(s_list) == 1 and s_list[0] == "None" and self.allow_none:
return None
if (
len(s_list) == 1
and s_list[0].startswith("{")
Expand Down Expand Up @@ -2965,6 +2985,8 @@ def validate(self, obj, value):
self.error(obj, value)

def from_string(self, s):
if self.allow_none and s == 'None':
return None
if ':' not in s:
raise ValueError('Require `ip:port`, got %r' % s)
ip, port = s.split(':', 1)
Expand Down

0 comments on commit efc76b8

Please sign in to comment.