Skip to content

Commit

Permalink
Fix block scalar mangling bug #231
Browse files Browse the repository at this point in the history
The regex based parsing for fixing comments was breaking block scalars.
By using the ruyaml round trip handler, instead the comment formatting
now can correctly identify block-scalars and avoid mangling them.
  • Loading branch information
wrouesnel committed Apr 6, 2023
1 parent d75a141 commit 803493d
Show file tree
Hide file tree
Showing 4 changed files with 213 additions and 29 deletions.
165 changes: 136 additions & 29 deletions src/yamlfix/adapters.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""Define adapter / helper classes to hide unrelated functionality in."""

import io
import logging
import re
from functools import partial
Expand All @@ -12,6 +12,7 @@
from ruyaml.tokens import CommentToken

from yamlfix.model import YamlfixConfig, YamlNodeStyle
from yamlfix.util import walk_object

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -334,18 +335,11 @@ def __init__(self, yaml: Yaml, config: Optional[YamlfixConfig]) -> None:
self.yaml = yaml.yaml
self.config = config or YamlfixConfig()

def fix(self, source_code: str) -> str:
"""Run all yaml source code fixers.
Args:
source_code: Source code to be corrected.
Returns:
Corrected source code.
"""
log.debug("Running source code fixers...")

fixers = [
# The list of fixers to run, and the index of the currently called fixer. This allows
# fixers which might need to reinvoke other fixers to reinvoke the previous fixers
# (currently fix comments does this as it needs to re-emit the YAML).
self._fixer_idx = -1
self._fixers = [
self._fix_truthy_strings,
self._fix_jinja_variables,
self._ruamel_yaml_fixer,
Expand All @@ -359,8 +353,39 @@ def fix(self, source_code: str) -> str:
self._add_newline_at_end_of_file,
]

for fixer in fixers:
def _rerun_previous_fixers(self, source_code):
"""Run all the previous source code fixers except the currently running one.
This is re-entrant safe and will correctly restore the fixer idx once it's complete.
Args:
source_code: Source code to be corrected.
Returns:
Corrected source code.
"""
cur_fixer_idx = self._fixer_idx
for fixer_idx, fixer in enumerate(self._fixers[:cur_fixer_idx]):
self._fixer_idx = fixer_idx
source_code = fixer(source_code)
# Restore fixer idx
self._fixer_idx = cur_fixer_idx
return source_code

def fix(self, source_code: str) -> str:
"""Run all yaml source code fixers.
Args:
source_code: Source code to be corrected.
Returns:
Corrected source code.
"""
log.debug("Running source code fixers...")

# Just use the re-run system do it.
self._fixer_idx = len(self._fixers)
source_code = self._rerun_previous_fixers(source_code)

return source_code

Expand Down Expand Up @@ -590,24 +615,106 @@ def _restore_truthy_strings(source_code: str) -> str:
def _fix_comments(self, source_code: str) -> str:
log.debug("Fixing comments...")
config = self.config
comment_start = " " * config.comments_min_spaces_from_content + "#"

fixed_source_lines = []
# We need the source lines for the comment fixers to analyze whitespace easily
source_lines = source_code.splitlines()

for line in source_code.splitlines():
# Comment at the start of the line
if config.comments_require_starting_space and re.search(r"(^|\s)#\w", line):
line = line.replace("#", "# ")
# Comment in the middle of the line, but it's not part of a string
if (
config.comments_min_spaces_from_content > 1
and " #" in line
and line[-1] not in ["'", '"']
):
line = re.sub(r"(.+\S)(\s+?)#", rf"\1{comment_start}", line)
fixed_source_lines.append(line)
yaml = YAML(typ="rt")
# Hijack config options from the regular fixer
yaml.explicit_start = self.yaml.explicit_start
yaml.width = self.yaml.width
# preserve_quotes however must always be true, otherwise we change output unexpectedly.
yaml.preserve_quotes = True
yaml_documents = list(yaml.load_all(source_code))

return "\n".join(fixed_source_lines)
handled_comments = []

def _comment_token_cb(o: Any, key: Optional[Any] = None):
if not isinstance(o, CommentToken):
return
if any(o is e for e in handled_comments):
# This comment was handled at a higher level already.
return
if o.value is None:
return
comment_lines = o.value.split("\n")
fixed_comment_lines = []
for line in comment_lines:
if config.comments_require_starting_space and re.search(
r"(^|\s)#\w", line
):
line = line.replace("#", "# ")
fixed_comment_lines.append(line)

# Update the comment with the fixed lines
o.value = "\n".join(fixed_comment_lines)

if config.comments_min_spaces_from_content > 1:
# It's hard to reconstruct exactly where the content is, but since we have the line numbers
# what we do is lookup the literal source line here and check where the whitespace is compared
# to where we know the comment starts.
source_line = source_lines[o.start_mark.line]
content_part = source_line[0 : o.start_mark.column]
# Find the non-whitespace position in the content part
m = re.match(r"^.*\S", content_part)
if (
m is not None
): # If no match then nothing to do - no content to be away from
content_start, content_end = m.span()
# If there's less than min-spaces from content, we're going to add some.
if (
o.start_mark.column - content_end
< config.comments_min_spaces_from_content
):
# Handled
o.start_mark.column = (
content_end + config.comments_min_spaces_from_content
)
# Some ruyaml objects will return attached comments at multiple levels (but not all).
# Keep track of which comments we've already processed to avoid double processing them
# (important because we use raw source lines to determine content position above).
handled_comments.append(o)

def _comment_fixer(o: Any, key: Optional[Any] = None):
"""
This function is the callback for walk_object
walk_object calls it for every object it finds, and then will walk the mapping/sequence subvalues and
call this function on those too. This gives us direct access to all round tripped comments.
"""
if not hasattr(o, "ca"):
# Scalar or other object with no comment parameter.
return
# Find all comment tokens and fix them
walk_object(o.ca.comment, _comment_token_cb)
walk_object(o.ca.end, _comment_token_cb)
walk_object(o.ca.items, _comment_token_cb)
walk_object(o.ca.pre, _comment_token_cb)

# Walk the object and invoke the comment fixer
walk_object(yaml_documents, _comment_fixer)

# Dump out the YAML documents
stream = io.StringIO()
yaml.dump_all(yaml_documents, stream=stream)

# Scan the source lines for a leading "---" separator. If it's found, add it.
found_leading_document_separator = False
for line in source_lines:
stripped_line = line.strip()
if stripped_line.startswith("#"):
continue
if line.rstrip() == "---":
found_leading_document_separator = True
break
if stripped_line != "":
# Found non-comment content.
break

fixed_source_code = stream.getvalue()
# Reinvoke the previous fixers to ensure we fix the new output we just created.
fixed_source_code = self._rerun_previous_fixers(fixed_source_code)
return fixed_source_code

def _fix_whitelines(self, source_code: str) -> str:
"""Fixes number of consecutive whitelines.
Expand Down
24 changes: 24 additions & 0 deletions src/yamlfix/util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from typing import Any, Callable, Iterable, Mapping, Optional

from typing_extensions import Protocol


class ObjectCallback(Protocol):
def __call__(self, value: Any, key: Optional[Any] = None) -> None:
...


def walk_object(o: Any, fn: ObjectCallback):
"""Walk a YAML/JSON-like object and call a function on all values"""

# Call the callback and whatever we received.
fn(o)

if isinstance(o, Mapping):
# Map type
for key, value in o.items():
walk_object(value, fn)
elif isinstance(o, Iterable) and not isinstance(o, (bytes, str)):
# List type
for idx, value in enumerate(o):
walk_object(value, fn)
45 changes: 45 additions & 0 deletions tests/unit/test_adapter_yaml.py
Original file line number Diff line number Diff line change
Expand Up @@ -830,3 +830,48 @@ def test_section_whitelines_begin_no_explicit_start(self) -> None:
result = fix_code(source, config)

assert result == fixed_source

def test_block_scalar_whitespace_is_preserved(self) -> None:
source = dedent(
"""\
---
addn_doc_key: |-
#######################################
# This would also be broken #
#######################################
---
#Comment above the key
key: |-
###########################################
# Value with lots of whitespace #
# Some More Whitespace #
###########################################
#Comment below
#Comment with some whitespace below
"""
)

fixed_source = dedent(
"""\
---
addn_doc_key: |-
#######################################
# This would also be broken #
#######################################
---
# Comment above the key
key: |-
###########################################
# Value with lots of whitespace #
# Some More Whitespace #
###########################################
# Comment below
# Comment with some whitespace below
"""
)

config = YamlfixConfig()
result = fix_code(source, config)
assert result == fixed_source
8 changes: 8 additions & 0 deletions tests/unit/test_services.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,6 +481,14 @@ def test_fix_code_functions_emit_debug_logs(
"Restoring jinja2 variables...",
"Restoring double exclamations...",
"Fixing comments...",
# Fixing comments causes a re-run of fixers, so we get duplicates from here
"Fixing truthy strings...",
"Fixing jinja2 variables...",
"Running ruamel yaml fixer...",
"Restoring truthy strings...",
"Restoring jinja2 variables...",
"Restoring double exclamations...",
# End fixing comments duplicates
"Fixing top level lists...",
"Fixing flow-style lists...",
]
Expand Down

0 comments on commit 803493d

Please sign in to comment.