diff --git a/src/yamlfix/adapters.py b/src/yamlfix/adapters.py index 189294c..2453ad7 100644 --- a/src/yamlfix/adapters.py +++ b/src/yamlfix/adapters.py @@ -1,5 +1,5 @@ """Define adapter / helper classes to hide unrelated functionality in.""" - +import io import logging import re from functools import partial @@ -12,6 +12,7 @@ from ruyaml.tokens import CommentToken from yamlfix.model import YamlfixConfig, YamlNodeStyle +from yamlfix.util import walk_object log = logging.getLogger(__name__) @@ -334,18 +335,11 @@ def __init__(self, yaml: Yaml, config: Optional[YamlfixConfig]) -> None: self.yaml = yaml.yaml self.config = config or YamlfixConfig() - def fix(self, source_code: str) -> str: - """Run all yaml source code fixers. - - Args: - source_code: Source code to be corrected. - - Returns: - Corrected source code. - """ - log.debug("Running source code fixers...") - - fixers = [ + # The list of fixers to run, and the index of the currently called fixer. This allows + # fixers which might need to reinvoke other fixers to reinvoke the previous fixers + # (currently fix comments does this as it needs to re-emit the YAML). + self._fixer_idx = -1 + self._fixers = [ self._fix_truthy_strings, self._fix_jinja_variables, self._ruamel_yaml_fixer, @@ -359,8 +353,39 @@ def fix(self, source_code: str) -> str: self._add_newline_at_end_of_file, ] - for fixer in fixers: + def _rerun_previous_fixers(self, source_code): + """Run all the previous source code fixers except the currently running one. + + This is re-entrant safe and will correctly restore the fixer idx once it's complete. + + Args: + source_code: Source code to be corrected. + + Returns: + Corrected source code. + """ + cur_fixer_idx = self._fixer_idx + for fixer_idx, fixer in enumerate(self._fixers[:cur_fixer_idx]): + self._fixer_idx = fixer_idx source_code = fixer(source_code) + # Restore fixer idx + self._fixer_idx = cur_fixer_idx + return source_code + + def fix(self, source_code: str) -> str: + """Run all yaml source code fixers. + + Args: + source_code: Source code to be corrected. + + Returns: + Corrected source code. + """ + log.debug("Running source code fixers...") + + # Just use the re-run system do it. + self._fixer_idx = len(self._fixers) + source_code = self._rerun_previous_fixers(source_code) return source_code @@ -590,24 +615,106 @@ def _restore_truthy_strings(source_code: str) -> str: def _fix_comments(self, source_code: str) -> str: log.debug("Fixing comments...") config = self.config - comment_start = " " * config.comments_min_spaces_from_content + "#" - fixed_source_lines = [] + # We need the source lines for the comment fixers to analyze whitespace easily + source_lines = source_code.splitlines() - for line in source_code.splitlines(): - # Comment at the start of the line - if config.comments_require_starting_space and re.search(r"(^|\s)#\w", line): - line = line.replace("#", "# ") - # Comment in the middle of the line, but it's not part of a string - if ( - config.comments_min_spaces_from_content > 1 - and " #" in line - and line[-1] not in ["'", '"'] - ): - line = re.sub(r"(.+\S)(\s+?)#", rf"\1{comment_start}", line) - fixed_source_lines.append(line) + yaml = YAML(typ="rt") + # Hijack config options from the regular fixer + yaml.explicit_start = self.yaml.explicit_start + yaml.width = self.yaml.width + # preserve_quotes however must always be true, otherwise we change output unexpectedly. + yaml.preserve_quotes = True + yaml_documents = list(yaml.load_all(source_code)) - return "\n".join(fixed_source_lines) + handled_comments = [] + + def _comment_token_cb(o: Any, key: Optional[Any] = None): + if not isinstance(o, CommentToken): + return + if any(o is e for e in handled_comments): + # This comment was handled at a higher level already. + return + if o.value is None: + return + comment_lines = o.value.split("\n") + fixed_comment_lines = [] + for line in comment_lines: + if config.comments_require_starting_space and re.search( + r"(^|\s)#\w", line + ): + line = line.replace("#", "# ") + fixed_comment_lines.append(line) + + # Update the comment with the fixed lines + o.value = "\n".join(fixed_comment_lines) + + if config.comments_min_spaces_from_content > 1: + # It's hard to reconstruct exactly where the content is, but since we have the line numbers + # what we do is lookup the literal source line here and check where the whitespace is compared + # to where we know the comment starts. + source_line = source_lines[o.start_mark.line] + content_part = source_line[0 : o.start_mark.column] + # Find the non-whitespace position in the content part + m = re.match(r"^.*\S", content_part) + if ( + m is not None + ): # If no match then nothing to do - no content to be away from + content_start, content_end = m.span() + # If there's less than min-spaces from content, we're going to add some. + if ( + o.start_mark.column - content_end + < config.comments_min_spaces_from_content + ): + # Handled + o.start_mark.column = ( + content_end + config.comments_min_spaces_from_content + ) + # Some ruyaml objects will return attached comments at multiple levels (but not all). + # Keep track of which comments we've already processed to avoid double processing them + # (important because we use raw source lines to determine content position above). + handled_comments.append(o) + + def _comment_fixer(o: Any, key: Optional[Any] = None): + """ + This function is the callback for walk_object + + walk_object calls it for every object it finds, and then will walk the mapping/sequence subvalues and + call this function on those too. This gives us direct access to all round tripped comments. + """ + if not hasattr(o, "ca"): + # Scalar or other object with no comment parameter. + return + # Find all comment tokens and fix them + walk_object(o.ca.comment, _comment_token_cb) + walk_object(o.ca.end, _comment_token_cb) + walk_object(o.ca.items, _comment_token_cb) + walk_object(o.ca.pre, _comment_token_cb) + + # Walk the object and invoke the comment fixer + walk_object(yaml_documents, _comment_fixer) + + # Dump out the YAML documents + stream = io.StringIO() + yaml.dump_all(yaml_documents, stream=stream) + + # Scan the source lines for a leading "---" separator. If it's found, add it. + found_leading_document_separator = False + for line in source_lines: + stripped_line = line.strip() + if stripped_line.startswith("#"): + continue + if line.rstrip() == "---": + found_leading_document_separator = True + break + if stripped_line != "": + # Found non-comment content. + break + + fixed_source_code = stream.getvalue() + # Reinvoke the previous fixers to ensure we fix the new output we just created. + fixed_source_code = self._rerun_previous_fixers(fixed_source_code) + return fixed_source_code def _fix_whitelines(self, source_code: str) -> str: """Fixes number of consecutive whitelines. diff --git a/src/yamlfix/util.py b/src/yamlfix/util.py new file mode 100644 index 0000000..1f8cf54 --- /dev/null +++ b/src/yamlfix/util.py @@ -0,0 +1,24 @@ +from typing import Any, Callable, Iterable, Mapping, Optional + +from typing_extensions import Protocol + + +class ObjectCallback(Protocol): + def __call__(self, value: Any, key: Optional[Any] = None) -> None: + ... + + +def walk_object(o: Any, fn: ObjectCallback): + """Walk a YAML/JSON-like object and call a function on all values""" + + # Call the callback and whatever we received. + fn(o) + + if isinstance(o, Mapping): + # Map type + for key, value in o.items(): + walk_object(value, fn) + elif isinstance(o, Iterable) and not isinstance(o, (bytes, str)): + # List type + for idx, value in enumerate(o): + walk_object(value, fn) diff --git a/tests/unit/test_adapter_yaml.py b/tests/unit/test_adapter_yaml.py index 24d66b8..4ee7795 100644 --- a/tests/unit/test_adapter_yaml.py +++ b/tests/unit/test_adapter_yaml.py @@ -830,3 +830,48 @@ def test_section_whitelines_begin_no_explicit_start(self) -> None: result = fix_code(source, config) assert result == fixed_source + + def test_block_scalar_whitespace_is_preserved(self) -> None: + source = dedent( + """\ + --- + addn_doc_key: |- + ####################################### + # This would also be broken # + ####################################### + --- + #Comment above the key + key: |- + ########################################### + # Value with lots of whitespace # + # Some More Whitespace # + ########################################### + #Comment below + + #Comment with some whitespace below + """ + ) + + fixed_source = dedent( + """\ + --- + addn_doc_key: |- + ####################################### + # This would also be broken # + ####################################### + --- + # Comment above the key + key: |- + ########################################### + # Value with lots of whitespace # + # Some More Whitespace # + ########################################### + # Comment below + + # Comment with some whitespace below + """ + ) + + config = YamlfixConfig() + result = fix_code(source, config) + assert result == fixed_source diff --git a/tests/unit/test_services.py b/tests/unit/test_services.py index 125b633..6053658 100644 --- a/tests/unit/test_services.py +++ b/tests/unit/test_services.py @@ -481,6 +481,14 @@ def test_fix_code_functions_emit_debug_logs( "Restoring jinja2 variables...", "Restoring double exclamations...", "Fixing comments...", + # Fixing comments causes a re-run of fixers, so we get duplicates from here + "Fixing truthy strings...", + "Fixing jinja2 variables...", + "Running ruamel yaml fixer...", + "Restoring truthy strings...", + "Restoring jinja2 variables...", + "Restoring double exclamations...", + # End fixing comments duplicates "Fixing top level lists...", "Fixing flow-style lists...", ]