Skip to content

Commit

Permalink
Fix bug in f-strings not working when whitespaces surround the variab…
Browse files Browse the repository at this point in the history
…le. (Issue #6629, PR #6679)

# Description

Fix a bug in f-strings where surrounding a variable with spaces would result in a "variable not found" error.

closes #6629

# Self Check:

Strike through any lines that are not applicable (`~~line~~`) then check the box

- [ ] Attached issue to pull request
- [ ] Changelog entry
- [ ] Type annotations are present
- [ ] Code is clear and sufficiently documented
- [ ] No (preventable) type errors (check using make mypy or make mypy-diff)
- [ ] Sufficient test cases (reproduces the bug/tests the requested feature)
- [ ] Correct, in line with design
- [ ] End user documentation is included or an issue is created for end-user documentation (add ref to issue here: )
- [ ] If this PR fixes a race condition in the test suite, also push the fix to the relevant stable branch(es) (see [test-fixes](https://internal.inmanta.com/development/core/tasks/build-master.html#test-fixes) for more info)
  • Loading branch information
Hugo-Inmanta authored and inmantaci committed Nov 8, 2023
1 parent 891ec95 commit 998564a
Show file tree
Hide file tree
Showing 5 changed files with 111 additions and 41 deletions.
@@ -0,0 +1,6 @@
description: Fix bug in f-strings not working when whitespaces surround the variable.
issue-nr: 6629
change-type: patch
destination-branches: [master, iso6]
sections:
bugfix: "{{description}}"
20 changes: 13 additions & 7 deletions src/inmanta/ast/statements/assign.py
Expand Up @@ -560,7 +560,7 @@ class FormattedString(ReferenceStatement):
This class is an abstraction around a string containing references to variables.
"""

__slots__ = ("_format_string", "_variables")
__slots__ = ("_format_string",)

def __init__(self, format_string: str, variables: abc.Sequence["Reference"]) -> None:
super().__init__(variables)
Expand All @@ -578,7 +578,7 @@ class StringFormat(FormattedString):
Create a new string by doing a string interpolation
"""

__slots__ = ()
__slots__ = ("_variables",)

def __init__(self, format_string: str, variables: abc.Sequence[Tuple["Reference", str]]) -> None:
super().__init__(format_string, [k for (k, _) in variables])
Expand Down Expand Up @@ -614,14 +614,21 @@ def get_field(self, key: str, args: abc.Sequence[object], kwds: abc.Mapping[str,
class StringFormatV2(FormattedString):
"""
Create a new string by using Python's built-in formatting
"""

__slots__ = ()
__slots__ = ("_variables",)

def __init__(self, format_string: str, variables: abc.Sequence[typing.Tuple["Reference", str]]) -> None:
"""
:param format_string: The string on which to perform substitution
:param variables: Sequence of tuples, each holding a normalized reference (i.e. stripped of any surrounding whitespace) to a
variable to substitute in the format_string, and the raw full name of this variable (i.e. including any surrounding
whitespace).
"""
only_refs: abc.Sequence["Reference"] = [k for (k, _) in variables]
super().__init__(format_string, only_refs)
self._variables = only_refs
self._variables: abc.Mapping[Reference, str] = {ref: full_name for (ref, full_name) in variables}

def execute(self, requires: typing.Dict[object, object], resolver: Resolver, queue: QueueScheduler) -> object:
super().execute(requires, resolver, queue)
Expand All @@ -630,14 +637,13 @@ def execute(self, requires: typing.Dict[object, object], resolver: Resolver, que
# We can't cache the formatter because it has no ability to cache the parsed string

kwargs = {}
for _var in self._variables:
for _var, full_name in self._variables.items():
value = _var.execute(requires, resolver, queue)
if isinstance(value, Unknown):
return Unknown(self)
if isinstance(value, float) and (value - int(value)) == 0:
value = int(value)

kwargs[_var.full_name] = value
kwargs[full_name] = value

result_string = formatter.vformat(self._format_string, args=[], kwargs=kwargs)

Expand Down
2 changes: 1 addition & 1 deletion src/inmanta/parser/plyInmantaLex.py
Expand Up @@ -67,7 +67,7 @@

def t_FSTRING(t: lex.LexToken) -> lex.LexToken: # noqa: N802
r"f(\"([^\\\"\n]|\\.)*\")|f(\'([^\\\'\n]|\\.)*\')"
t.value = t.value[2:-1]
t.value = safe_decode(token=t, warning_message="Invalid escape sequence in f-string.", start=2, end=-1)
lexer = t.lexer

end = lexer.lexpos - lexer.linestart + 1
Expand Down
68 changes: 46 additions & 22 deletions src/inmanta/parser/plyInmantaParser.py
Expand Up @@ -21,8 +21,7 @@
import string
from collections import abc
from dataclasses import dataclass
from itertools import accumulate
from typing import Iterable, Iterator, List, Optional, Tuple, Union
from typing import Iterable, List, Optional, Tuple, Union

import ply.yacc as yacc
from ply.yacc import YaccProduction
Expand Down Expand Up @@ -893,12 +892,14 @@ def p_constant_fstring(p: YaccProduction) -> None:

locatable_matches: List[Tuple[str, LocatableString]] = []

def locate_match(match: Tuple[str, Optional[str], Optional[str], Optional[str]]) -> None:
def locate_match(
match: Tuple[str, Optional[str], Optional[str], Optional[str]], start_char_pos: int, end_char: int
) -> None:
"""
Associates a parsed field name with a locatable string
"""
range: Range = Range(p[1].location.file, start_lnr, start_char_pos, start_lnr, end_char)
assert match[1] # make mypy happy
range: Range = Range(p[1].location.file, start_lnr, start_char_pos, start_lnr, end_char)
locatable_string = LocatableString(match[1], range, p[1].lexpos, p[1].namespace)
locatable_matches.append((match[1], locatable_string))

Expand All @@ -912,7 +913,7 @@ def locate_match(match: Tuple[str, Optional[str], Optional[str], Optional[str]])
start_char_pos += literal_text_len + brackets_length
end_char = start_char_pos + field_name_len

locate_match(match)
locate_match(match, start_char_pos, end_char)
start_char_pos += field_name_len

if match[2]:
Expand All @@ -929,7 +930,7 @@ def locate_match(match: Tuple[str, Optional[str], Optional[str], Optional[str]])
start_char_pos += literal_text_len + inner_brackets_len
end_char = start_char_pos + inner_field_name_len

locate_match(submatch)
locate_match(submatch, start_char_pos, end_char)
start_char_pos += inner_field_name_len + inner_brackets_len

start_char_pos += brackets_length
Expand Down Expand Up @@ -997,35 +998,58 @@ def convert_to_references(variables: List[Tuple[str, LocatableString]]) -> List[
(ex. LocatableString("a.b", range(a.b), lexpos, namespace))
For f-strings:
- The string is the plain variable name without brackets (ex: 'a.b')
- The string is the plain variable name without brackets (ex: 'a.b'), including any surrounding whitespace.
- The LocatableString is the same as for regular string interpolation
:returns: A tuple where all LocatableString have been converted to Reference. The matching str holding the variable
name is left untouched
:returns: A tuple where all LocatableString have been converted to Reference. These references are stripped of any
surrounding whitespace. The matching str holding the variable name is left untouched, i.e. it still contains any
surrounding whitespace.
"""

def normalize(variable: str, locatable: LocatableString, offset: int = 0) -> LocatableString:
    """
    Trim the whitespace surrounding a variable name and build the locatable string for the trimmed name.

    :param variable: Raw text of a plain variable, or of one dot-separated part of a composite variable, possibly
        padded with whitespace.
    :param locatable: LocatableString of the full variable this text belongs to.
    :param offset: Position at which `variable` starts relative to the start of `locatable`. Only non-zero when
        normalizing a sub-part of a composite variable (e.g. 'b' or 'c' in 'a.b.c').
    :returns: A LocatableString holding the trimmed name, whose range is narrowed by the number of leading and
        trailing whitespace characters that were removed.
    """
    location = locatable.location
    raw_start: int = location.start_char + offset
    raw_end: int = raw_start + len(variable)

    # Trim the left side first and the right side second: for an all-whitespace input every character is then
    # counted as leading whitespace and none as trailing.
    without_leading: str = variable.lstrip()
    stripped_name: str = without_leading.rstrip()
    leading_count: int = len(variable) - len(without_leading)
    trailing_count: int = len(without_leading) - len(stripped_name)

    trimmed_range: Range = Range(
        location.file,
        location.lnr,
        raw_start + leading_count,
        location.lnr,
        raw_end - trailing_count,
    )
    return LocatableString(stripped_name, trimmed_range, locatable.lexpos, locatable.namespace)

assert namespace
_vars: List[Tuple[Reference, str]] = []
for match, var in variables:
var_name: str = str(var)
var_parts: List[str] = var_name.split(".")
start_char = var.location.start_char
end_char = start_char + len(var_parts[0])
range: Range = Range(var.location.file, var.location.lnr, start_char, var.location.lnr, end_char)
ref_locatable_string = LocatableString(var_parts[0], range, var.lexpos, var.namespace)

ref_locatable_string: LocatableString = normalize(var_parts[0], var)

ref = Reference(ref_locatable_string)
ref.location = ref_locatable_string.location
ref.namespace = namespace
if len(var_parts) > 1:
attribute_offsets: Iterator[int] = accumulate(
var_parts[1:], lambda acc, part: acc + len(part) + 1, initial=end_char + 1
)
for attr, char_offset in zip(var_parts[1:], attribute_offsets):
range_attr: Range = Range(
var.location.file, var.location.lnr, char_offset, var.location.lnr, char_offset + len(attr)
)
attr_locatable_string: LocatableString = LocatableString(attr, range_attr, var.lexpos, var.namespace)
offset = len(var_parts[0]) + 1
for attr in var_parts[1:]:
attr_locatable_string: LocatableString = normalize(attr, var, offset=offset)
ref = AttributeReference(ref, attr_locatable_string)
ref.location = range_attr
ref.location = attr_locatable_string.location
ref.namespace = namespace
offset += len(attr) + 1
# For a composite variable e.g. 'a.b.c', we only add the reference to the innermost attribute (e.g. 'c')
_vars.append((ref, match))
else:
Expand Down
56 changes: 45 additions & 11 deletions tests/compiler/test_strings.py
Expand Up @@ -29,7 +29,7 @@ def test_multiline_string_interpolation(snippetcompiler):
snippetcompiler.setup_for_snippet(
"""
var = 42
str = \"\"\"var == {{var}}\"\"\"
str = \"\"\"var == {{ var }}\"\"\"
""",
)
(_, scopes) = compiler.do_compile()
Expand Down Expand Up @@ -182,10 +182,12 @@ def test_fstring_float_formatting(snippetcompiler, capsys):
@pytest.mark.parametrize(
"f_string,expected_output",
[
(r"f'{ arg }'", "123\n"),
(r"f'{arg}'", "123\n"),
(r"f'{arg}{arg}{arg}'", "123123123\n"),
(r"f'{arg:@>5}'", "@@123\n"),
(r"f'{arg:^5}'", " 123 \n"),
(r"f' { \t\narg \n } '", " 123 \n"),
],
)
def test_fstring_formatting(snippetcompiler, capsys, f_string, expected_output):
Expand Down Expand Up @@ -235,7 +237,7 @@ def test_fstring_relations(snippetcompiler, capsys):
b = B(c=c)
c = C()
std::print(f"{a.b.c.n_c}")
std::print(f"{ a .b . c . n_c }")
"""
)

Expand All @@ -245,22 +247,58 @@ def test_fstring_relations(snippetcompiler, capsys):
assert out == expected_output


def check_range(variable: Union[Reference, AttributeReference], start: int, end: int):
    """
    Assert that the location of the given reference starts at `start` and ends at `end`.

    :param variable: The reference whose location is checked.
    :param start: Expected value of `variable.location.start_char`.
    :param end: Expected value of `variable.location.end_char`.
    """
    location = variable.location
    actual_start = location.start_char
    assert actual_start == start, f"{variable=} expected {start=} got {variable.location.start_char=}"
    actual_end = location.end_char
    assert actual_end == end, f"{variable=} expected {end=} got {variable.location.end_char=}"


def test_fstring_numbering_logic():
"""
Check that variable ranges in f-strings are correctly computed
"""
statements = parse_code(
"""
std::print(f"---{s}{mm} - {sub.attr}")
# 10 20 30 40 50 60 70 80
# | | | | | | | |
std::print(f"---{s}{mm} - {sub.attr} - { padded } - { \tpadded.sub.attr }")
# | | | | |
# [-][--] [----] [------] [----] <--- expected ranges
"""
)

def check_range(variable: Union[Reference, AttributeReference], start: int, end: int):
assert variable.location.start_char == start
assert variable.location.end_char == end

# Ranges are 1-indexed [start:end[
ranges = [
(len('std::print(f"---{s'), len('std::print(f"---{s}')),
(len('std::print(f"---{s}{m'), len('std::print(f"---{s}{mm}')),
(len('std::print(f"---{s}{mm} - {sub.a'), len('std::print(f"---{s}{mm} - {sub.attr}')),
(len('std::print(f"---{s}{mm} - {sub.attr} - { p'), len('std::print(f"---{s}{mm} - {sub.attr} - { padded ')),
(
len('std::print(f"---{s}{mm} - {sub.attr} - { padded } - { \tpadded.sub.a'),
len('std::print(f"---{s}{mm} - {sub.attr} - { padded } - { \tpadded.sub.attr '),
),
]
variables = statements[0].children[0]._variables

for var, range in zip(variables, ranges):
check_range(var, *range)


def test_fstring_numbering_logic_multiple_refs():
"""
Check that variable ranges in f-strings are correctly computed
"""
statements = parse_code(
"""
std::print(f"---{s}----{s}")
# | |
# [-] [-] <--- expected ranges
"""
)

# Ranges are 1-indexed [start:end[
ranges = [
(len('std::print(f"---{s'), len('std::print(f"---{s}')),
(len('std::print(f"---{s}----{s'), len('std::print(f"---{s}----{s}')),
]
variables = statements[0].children[0]._variables

Expand Down Expand Up @@ -305,10 +343,6 @@ def test_fstring_numbering_logic_complex():
"""
)

def check_range(variable: Union[Reference, AttributeReference], start: int, end: int):
assert variable.location.start_char == start, print(variable)
assert variable.location.end_char == end, print(variable)

# Ranges are 1-indexed [start:end[
ranges = [
(len('std::print(f"-{a'), len('std::print(f"-{arg:')),
Expand Down

0 comments on commit 998564a

Please sign in to comment.