diff --git a/setup.py b/setup.py index 87cd64d8..186af475 100644 --- a/setup.py +++ b/setup.py @@ -25,6 +25,7 @@ # "torch", ], "yaml": ["pyyaml"], + "toml": ["tomli", "tomli_w"], } extras_require["all"] = list(set(sum(extras_require.values(), []))) diff --git a/simple_parsing/helpers/serialization/serializable.py b/simple_parsing/helpers/serialization/serializable.py index e935ed89..64487c92 100644 --- a/simple_parsing/helpers/serialization/serializable.py +++ b/simple_parsing/helpers/serialization/serializable.py @@ -125,6 +125,20 @@ def dump(self, obj: Any, io: IO, **kwargs) -> None: return torch.save(obj, io, **kwargs) +class TOMLExtension(FormatExtension): + binary: bool = True + + def load(self, io: IO) -> Any: + import tomli + + return tomli.load(io) + + def dump(self, obj: Any, io: IO, **kwargs) -> None: + import tomli_w + + return tomli_w.dump(obj, io, **kwargs) + + json_extension = JSONExtension() yaml_extension = YamlExtension() @@ -136,6 +150,7 @@ def dump(self, obj: Any, io: IO, **kwargs) -> None: ".yml": YamlExtension(), ".npy": NumpyExtension(), ".pth": TorchExtension(), + ".toml": TOMLExtension(), } diff --git a/test/helpers/test_save.py b/test/helpers/test_save.py index 4d43b7dc..89b4acdc 100644 --- a/test/helpers/test_save.py +++ b/test/helpers/test_save.py @@ -63,3 +63,12 @@ def test_save_torch(tmpdir: Path): _hparams = HyperParameters.load(tmp_path) assert hparams == _hparams + + +def test_save_toml(tmpdir: Path): + hparams = HyperParameters.setup("") + tmp_path = Path(tmpdir / "temp.toml") + hparams.save(tmp_path) + + _hparams = HyperParameters.load(tmp_path) + assert hparams == _hparams