diff --git a/torchx/specs/api.py b/torchx/specs/api.py index a83104991..c28212bc5 100644 --- a/torchx/specs/api.py +++ b/torchx/specs/api.py @@ -951,11 +951,11 @@ def cfg_from_str(self, cfg_str: str) -> Dict[str, CfgVal]: def _cast_to_type(value: str, opt_type: Type[CfgVal]) -> CfgVal: if opt_type == bool: return value.lower() == "true" - elif opt_type == List[str]: + elif opt_type in (List[str], list[str]): # lists may be ; or , delimited # also deal with trailing "," by removing empty strings return [v for v in value.replace(";", ",").split(",") if v] - elif opt_type == Dict[str, str]: + elif opt_type in (Dict[str, str], dict[str, str]): return { s.split(":", 1)[0]: s.split(":", 1)[1] for s in value.replace(";", ",").split(",") diff --git a/torchx/specs/test/api_test.py b/torchx/specs/test/api_test.py index 9a251c7d7..e02898943 100644 --- a/torchx/specs/test/api_test.py +++ b/torchx/specs/test/api_test.py @@ -566,6 +566,41 @@ def test_cfg_from_str(self) -> None: {"E": {"f": "b", "F": "B"}}, opts.cfg_from_str("E=f:b,F:B") ) + def test_cfg_from_str_builtin_generic_types(self) -> None: + # basically a repeat of "test_cfg_from_str()" but with + # list[str] and dict[str, str] instead of List[str] and Dict[str, str] + opts = runopts() + opts.add("K", type_=list[str], help="a list opt", default=[]) + opts.add("J", type_=str, help="a str opt", required=True) + opts.add("E", type_=dict[str, str], help="a dict opt", default=[]) + + self.assertDictEqual({}, opts.cfg_from_str("")) + self.assertDictEqual({}, opts.cfg_from_str("UNKWN=b")) + self.assertDictEqual({"K": ["a"], "J": "b"}, opts.cfg_from_str("K=a,J=b")) + self.assertDictEqual({"K": ["a"]}, opts.cfg_from_str("K=a,UNKWN=b")) + self.assertDictEqual({"K": ["a", "b"]}, opts.cfg_from_str("K=a,b")) + self.assertDictEqual({"K": ["a", "b"]}, opts.cfg_from_str("K=a;b")) + self.assertDictEqual({"K": ["a", "b"]}, opts.cfg_from_str("K=a,b")) + self.assertDictEqual({"K": ["a", "b"]}, opts.cfg_from_str("K=a,b;")) + self.assertDictEqual( + {"K": ["a", "b"], "J": "d"}, opts.cfg_from_str("K=a,b,J=d") + ) + self.assertDictEqual( + {"K": ["a", "b"], "J": "d"}, opts.cfg_from_str("K=a,b;J=d") + ) + self.assertDictEqual( + {"K": ["a", "b"], "J": "d"}, opts.cfg_from_str("K=a;b,J=d") + ) + self.assertDictEqual( + {"K": ["a", "b"], "J": "d"}, opts.cfg_from_str("K=a;b;J=d") + ) + self.assertDictEqual( + {"K": ["a"], "J": "d"}, opts.cfg_from_str("J=d,K=a,UNKWN=e") + ) + self.assertDictEqual( + {"E": {"f": "b", "F": "B"}}, opts.cfg_from_str("E=f:b,F:B") + ) + def test_resolve_from_str(self) -> None: opts = runopts() opts.add("foo", type_=str, default="", help="")