
Commit

Improved version call normalizer
frthjf committed Mar 29, 2024
1 parent cf122bb commit 0df728c
Showing 4 changed files with 76 additions and 1 deletion.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -4,6 +4,7 @@

 # Unreleased

+- Improved version call normalizer
 - Prevent recursions in self.future() calls

 # v4.10.1
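For context, a "version call" is a version string of the form ~method(...) that machinable normalizes before use. A minimal sketch of the normalizer's effect, grounded in the test cases added by this commit (the foo names are illustrative):

    from machinable.utils import norm_version_call

    # Equivalent spellings collapse to one canonical form:
    norm_version_call("~foo( 1, bar = 2 )")   # -> "~foo(1,bar=2)"
    norm_version_call("~foo(1,bar=2)")        # -> "~foo(1,bar=2)"

    # Unlike plain whitespace stripping, string literals survive intact:
    norm_version_call("~foo(bar= ' world')")  # -> "~foo(bar=' world')"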
3 changes: 2 additions & 1 deletion src/machinable/element.py
@@ -30,6 +30,7 @@
 from machinable.utils import (
     Jsonable,
     id_from_uuid,
+    norm_version_call,
     sentinel,
     serialize,
     unflatten_dict,

@@ -100,7 +101,7 @@ def _norm(item):
             OmegaConf.to_container(OmegaConf.create(item)), recursive=False
         )
         if isinstance(item, str) and "~" in item:
-            return item.strip().replace("\n", "").replace(" ", "")
+            return norm_version_call(item)

         return item
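Aside (not part of the diff): the replaced one-liner stripped every space and newline, including whitespace inside quoted arguments. A quick sketch of the old behavior for comparison:

    # Old normalization (illustrative): removes all spaces, mangling literals
    old_norm = lambda item: item.strip().replace("\n", "").replace(" ", "")
    old_norm("~foo(bar= ' world',)")  # -> "~foo(bar='world',)" (literal altered)

The token-based norm_version_call introduced below preserves STRING tokens verbatim; the new test case ("~foo(bar= ' world',)", "~foo(bar=' world',)") covers exactly this.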
57 changes: 57 additions & 0 deletions src/machinable/utils.py
@@ -4,6 +4,8 @@

 import stat
 import sys
+import tokenize
+from io import BytesIO

 if sys.version_info >= (3, 11):
     from typing import Self

@@ -457,6 +459,61 @@ def update_dict(
     return d


+def norm_version_call(version: str):
+    code = version.replace("~", "").strip()
+    tokens = tokenize.tokenize(BytesIO(code.encode("utf-8")).readline)
+    normalized_tokens = []
+
+    last_token_type = None
+    for toknum, tokval, _, _, _ in tokens:
+        if toknum in [
+            tokenize.ENCODING,
+            tokenize.ENDMARKER,
+            tokenize.NEWLINE,
+            tokenize.NL,
+        ]:
+            # skip encoding, endmarker, newlines
+            continue
+        if toknum == tokenize.STRING:
+            # string literals are preserved
+            normalized_tokens.append(tokval)
+        elif toknum == tokenize.OP:
+            # remove spaces before and after operators
+            if (
+                tokval == "="
+                and normalized_tokens
+                and normalized_tokens[-1] == " "
+            ):
+                normalized_tokens.pop()
+            normalized_tokens.append(tokval)
+        elif toknum in [tokenize.NAME, tokenize.NUMBER]:
+            if last_token_type not in [
+                None,
+                tokenize.OP,
+                tokenize.NL,
+                tokenize.NEWLINE,
+            ] and not (
+                last_token_type == tokenize.OP
+                and normalized_tokens[-1] in ["(", ","]
+            ):
+                normalized_tokens.append(" ")
+            normalized_tokens.append(tokval)
+        last_token_type = toknum
+
+    if "~" in version:
+        normalized_tokens = ["~"] + normalized_tokens
+
+    normalized_code = "".join(normalized_tokens)
+    normalized_code = (
+        normalized_code.replace(" (", "(")
+        .replace("( ", "(")
+        .replace(" )", ")")
+        .replace(", ", ",")
+        .replace(" ,", ",")
+    )
+    return normalized_code
+
+
 def dot_splitter(flat_key):
     if not isinstance(flat_key, str):
         raise ValueError(
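For reference (not part of the commit), this is roughly the token stream the loop above iterates over for a simple call, using the standard-library tokenize module; output shown as comments:

    import tokenize
    from io import BytesIO

    for tok in tokenize.tokenize(BytesIO(b"foo( 1, bar = 2 )").readline):
        print(tokenize.tok_name[tok.type], repr(tok.string))
    # ENCODING  'utf-8'
    # NAME      'foo'
    # OP        '('
    # NUMBER    '1'
    # OP        ','
    # NAME      'bar'
    # OP        '='
    # NUMBER    '2'
    # OP        ')'
    # NEWLINE   ''
    # ENDMARKER ''

ENCODING, NEWLINE/NL, and ENDMARKER are skipped, operators are appended without surrounding spaces, and STRING tokens pass through untouched, which is what keeps the normalization literal-safe.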
16 changes: 16 additions & 0 deletions tests/test_utils.py
@@ -212,3 +212,19 @@ def test_file_hash(tmp_path):
     assert utils.file_hash(tmp_path / "test.txt") == "a71079d42853"
     with pytest.raises(FileNotFoundError):
         utils.file_hash("not-existing")
+
+
+@pytest.mark.parametrize(
+    "input_code, expected",
+    [
+        ("foo(1, bar=2)", "foo(1,bar=2)"),
+        ("foo( 1, bar = 2 )", "foo(1,bar=2)"),
+        ("\nfoo(\n1,\nbar = 2\n)\n", "foo(1,bar=2)"),
+        ("foo(' hello ', bar=2)", "foo(' hello ',bar=2)"),
+        ("foo(bar= 'world', baz =3)", "foo(bar='world',baz=3)"),
+        ("~foo(bar= ' world',)", "~foo(bar=' world',)"),
+        (" ~foo(bar= 'world')", "~foo(bar='world')"),
+    ],
+)
+def test_norm_version_call(input_code, expected):
+    assert utils.norm_version_call(input_code) == expected
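The parametrized cases can be run in isolation with, for example:

    pytest tests/test_utils.py -k test_norm_version_call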
