Skip to content

Commit

Permalink
Add support for TOML config files (#289)
Browse files Browse the repository at this point in the history
Co-authored-by: Andrey Klochkov <aklochkov@liftoff.io>
  • Loading branch information
diggerk and andrey-klochkov-liftoff committed Oct 4, 2023
1 parent 77e4ff5 commit bccb312
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 0 deletions.
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
# "torch",
],
"yaml": ["pyyaml"],
"toml": ["tomli", "tomli_w"],
}
extras_require["all"] = list(set(sum(extras_require.values(), [])))

Expand Down
15 changes: 15 additions & 0 deletions simple_parsing/helpers/serialization/serializable.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,20 @@ def dump(self, obj: Any, io: IO, **kwargs) -> None:
return torch.save(obj, io, **kwargs)


class TOMLExtension(FormatExtension):
binary: bool = True

def load(self, io: IO) -> Any:
import tomli

return tomli.load(io)

def dump(self, obj: Any, io: IO, **kwargs) -> None:
import tomli_w

return tomli_w.dump(obj, io, **kwargs)


json_extension = JSONExtension()
yaml_extension = YamlExtension()

Expand All @@ -136,6 +150,7 @@ def dump(self, obj: Any, io: IO, **kwargs) -> None:
".yml": YamlExtension(),
".npy": NumpyExtension(),
".pth": TorchExtension(),
".toml": TOMLExtension(),
}


Expand Down
9 changes: 9 additions & 0 deletions test/helpers/test_save.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,12 @@ def test_save_torch(tmpdir: Path):

_hparams = HyperParameters.load(tmp_path)
assert hparams == _hparams


def test_save_toml(tmpdir: Path):
hparams = HyperParameters.setup("")
tmp_path = Path(tmpdir / "temp.toml")
hparams.save(tmp_path)

_hparams = HyperParameters.load(tmp_path)
assert hparams == _hparams

0 comments on commit bccb312

Please sign in to comment.