Skip to content

Commit

Permalink
fix: make fix_code respect if TYPE_CHECKING statements
Browse files Browse the repository at this point in the history
  • Loading branch information
lyz-code committed Dec 18, 2020
1 parent b6694e9 commit 8360d6e
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 8 deletions.
63 changes: 55 additions & 8 deletions src/autoimport/model.py
Expand Up @@ -36,20 +36,37 @@ def __init__(self, source_code: str) -> None:
"""Initialize the object."""
self.docstring: List[str] = []
self.imports: List[str] = []
self.typing: List[str] = []
self.code: List[str] = []
self._split_code(source_code)

def _split_code(self, source_code: str) -> None:
"""Split the source code in docstring, import statements and code.
"""Split the source code in the different sections.
* Module Docstring
* Import statements
* Typing statements
* Code.
Args:
source_code: Source code to be corrected.
"""
source_code_lines = source_code.splitlines()

self._extract_docstring(source_code_lines)
self._extract_import_statements(source_code_lines)
self._extract_typing_statements(source_code_lines)
self._extract_code(source_code_lines)

def _extract_docstring(self, source_lines: List[str]) -> None:
"""Save the module docstring from the source code into self.docstring.
Args:
source_lines: A list containing all code lines.
"""
docstring_type: Optional[str] = None

# Extract the module docstring from the code.
for line in source_code_lines:
for line in source_lines:
if re.match(r'"{3}.*"{3}', line):
# Match single line docstrings
self.docstring.append(line)
Expand All @@ -64,11 +81,18 @@ def _split_code(self, source_code: str) -> None:
break
self.docstring.append(line)

# Extract the import lines from the code.
def _extract_import_statements(self, source_lines: List[str]) -> None:
"""Save the import statements from the source code into self.imports.
Args:
source_lines: A list containing all code lines.
"""
import_start_line = len(self.docstring)
multiline_import = False

for line in source_code_lines[import_start_line:]:
for line in source_lines[import_start_line:]:
if re.match(r"^if TYPE_CHECKING:$", line):
break
if (
re.match(r"^\s*(from .*)?import.[^\'\"]*$", line)
or line == ""
Expand All @@ -83,9 +107,31 @@ def _split_code(self, source_code: str) -> None:
else:
break

def _extract_typing_statements(self, source_lines: List[str]) -> None:
"""Save the typing statements from the source code into self.typing.
Args:
source_lines: A list containing all code lines.
"""
typing_start_line = len(self.docstring) + len(self.imports)

if re.match(r"^if TYPE_CHECKING:$", source_lines[typing_start_line]):
self.typing.append(source_lines[typing_start_line])
typing_start_line += 1
for line in source_lines[typing_start_line:]:
if not re.match(r"^\s+.*", line):
break
self.typing.append(line)

def _extract_code(self, source_lines: List[str]) -> None:
"""Save the code from the source code into self.code.
Args:
source_lines: A list containing all code lines.
"""
# Extract the code lines
code_start_line = len(self.docstring) + len(self.imports)
self.code = source_code_lines[code_start_line:]
code_start_line = len(self.docstring) + len(self.imports) + len(self.typing)
self.code = source_lines[code_start_line:]

def _join_code(self) -> str:
"""Join the source code from docstring, import statements and code lines.
Expand All @@ -98,7 +144,8 @@ def _join_code(self) -> str:
# Remove new lines at start and end of each section.
sections = [
"\n".join(section).strip()
for section in (self.docstring, self.imports, self.code)
for section in (self.docstring, self.imports, self.typing, self.code)
if len(section) > 0
]

# Add new lines between existent sections
Expand Down
26 changes: 26 additions & 0 deletions tests/unit/test_services.py
Expand Up @@ -670,3 +670,29 @@ def test_fix_autoimports_objects_defined_in___all__special_variable() -> None:
result = fix_code(source)

assert result == fixed_source


def test_fix_respects_type_checking_import_statements() -> None:
"""
Given: Code with if TYPE_CHECKING imports
When: Fix code is run.
Then: The imports are not moved above the if statement.
"""
source = dedent(
"""\
import os
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from .model import Book
os.getcwd()
def read_book(book: Book):
pass"""
)

result = fix_code(source)

assert result == source

0 comments on commit 8360d6e

Please sign in to comment.