Skip to content

Commit

Permalink
Fix multiline import statements (proposal for issue #8) (#63)
Browse files Browse the repository at this point in the history
* Add skeleton for fixing multiline from imports

* Rename "Contination" to "PendingFix"

* Add fix for the multiline from import problem

* Add fix for multiline imports without from

* Add fix for multiline relative imports

* Fix other edge cases of multiline import

Such as:
- leading/trailing commas/line continuation
- try/except
- semicolon
- empty imports as result

In some cases the approach of "not changing anything if it is too risky"
was adopted (as already happens in other parts of the code).

* Deal with CLI options in FilterMultilineImport

* Simplify multiline imports by ignoring comments

Unfortunately, handling inline comments inside imports increases the
complexity of the implementation.
This change reduces complexity by refusing to change multiline import
statements that contain comments.

This approach is also used in other parts of the code.

* Simplify logic in multiline imports

... By using the existing code as a template.

* Fix docstrings according to pydocstyle

* Remove unused itertools

* Remove solved TODO comment

* Add examples from issue #8 and fix edge cases
  • Loading branch information
abravalheri committed Aug 23, 2020
1 parent 22b0a69 commit e14b5c3
Show file tree
Hide file tree
Showing 2 changed files with 705 additions and 26 deletions.
225 changes: 202 additions & 23 deletions autoflake.py
Expand Up @@ -35,6 +35,7 @@
import os
import re
import signal
import string
import sys
import tokenize

Expand Down Expand Up @@ -251,10 +252,6 @@ def multiline_import(line, previous_line=''):
if symbol in line:
return True

# Ignore doctests.
if line.lstrip().startswith('>'):
return True

return multiline_statement(line, previous_line)


Expand All @@ -272,6 +269,182 @@ def multiline_statement(line, previous_line=''):
return True


class PendingFix(object):
    """Base class for rewrite operations that need more than one line.

    When a helper in the main rewrite loop hands back a ``PendingFix``
    instead of a string, the loop calls that object with the next line and
    keeps doing so until a string comes back.
    """

    def __init__(self, line):
        """Start the accumulator with the first line of the operation."""
        lines = collections.deque()
        lines.append(line)
        self.accumulator = lines

    def __call__(self, line):
        """Consume one more line.

        Subclasses return ``self`` while more lines are needed, or a string
        holding the final result for all the lines processed so far.
        """
        raise NotImplementedError("Abstract method needs to be overwritten")


def _valid_char_in_line(char, line):
    """Tell whether ``char`` occurs in ``line`` outside of a comment.

    Only the first occurrence of ``char`` and the first ``#`` are compared:
    the result is True when ``char`` is present and either there is no ``#``
    or the ``#`` comes strictly after that first occurrence.
    """
    char_position = line.find(char)
    if char_position < 0:
        # The character is not in the line at all.
        return False

    comment_position = line.find('#')
    return comment_position < 0 or comment_position > char_position


def _top_module(module_name):
    """Return the name of the top level module in the hierarchy.

    A name that starts with a dot (a relative import) is mapped to the
    placeholder ``'%LOCAL_MODULE%'``.  Any other name yields the part before
    its first dot.  An empty name yields an empty string instead of raising
    ``IndexError``.
    """
    # ``startswith`` is safe on an empty string, unlike ``module_name[0]``.
    if module_name.startswith('.'):
        return '%LOCAL_MODULE%'
    return module_name.partition('.')[0]


def _modules_to_remove(unused_modules, safe_to_remove=SAFE_IMPORTS):
    """Keep only the unused modules whose top level module is safe to drop.

    The result is a new list, in the original order, of every entry of
    ``unused_modules`` whose ``_top_module`` is found in ``safe_to_remove``.
    """
    removable = []
    for module_name in unused_modules:
        if _top_module(module_name) in safe_to_remove:
            removable.append(module_name)
    return removable


def _segment_module(segment):
    """Return the module identifier contained in an import segment.

    Surrounding whitespace, commas, backslashes and parentheses are stripped.
    A segment may hold no module at all (for instance just a parenthesis or
    a line continuation plus whitespace); it is then returned unchanged.
    Such characters are not valid in identifiers, so the segment can never
    match an entry in the list of unused modules.
    """
    identifier = segment.strip(string.whitespace + ',\\()')
    if identifier:
        return identifier
    return segment


class FilterMultilineImport(PendingFix):
    """Strip unused names from an import statement spanning several lines.

    Both "from imports" and "direct imports" are supported.  The statement is
    handed back unchanged whenever it looks risky to rewrite: when a line
    contains ``;``, ``:`` or ``#``, or when the line preceding the statement
    contains a backslash.
    """

    # Splits the statement into what precedes ``import`` and what follows it.
    IMPORT_RE = re.compile(r'\bimport\b\s*')
    # Leading whitespace of the statement, reused for the ``pass`` fallback.
    INDENTATION_RE = re.compile(r'^\s*')
    # Captures the base module of a ``from <base> import ...`` statement.
    BASE_RE = re.compile(r'\bfrom\s+([^ ]+)')
    # One segment = module (optionally ``as`` alias) + comma + following
    # space (including new line, continuation and closing parenthesis).
    SEGMENT_RE = re.compile(
        r'([^,\s]+(?:[\s\\]+as[\s\\]+[^,\s]+)?[,\s\\)]*)', re.M)
    IDENTIFIER_RE = re.compile(r'[^,\s]+')

    def __init__(self, line, unused_module=(), remove_all_unused_imports=False,
                 safe_to_remove=SAFE_IMPORTS, previous_line=''):
        """Receive the same parameters as ``filter_unused_import``.

        ``line`` is the first line of the statement.  Everything after the
        ``import`` keyword becomes the first entry of the accumulator, and
        everything before it is kept in ``self.from_``.
        """
        self.remove = unused_module
        self.parenthesized = '(' in line
        self.from_, imports = self.IMPORT_RE.split(line, maxsplit=1)
        base_match = self.BASE_RE.search(self.from_)
        self.base = base_match.group(1) if base_match else None
        self.give_up = False

        if not remove_all_unused_imports:
            unsafe_base = (
                self.base and _top_module(self.base) not in safe_to_remove
            )
            if unsafe_base:
                self.give_up = True
            else:
                self.remove = _modules_to_remove(self.remove, safe_to_remove)

        if '\\' in previous_line:
            # Ignore tricky things like "try: \<new line> import" ...
            self.give_up = True

        self.analyze(line)

        PendingFix.__init__(self, imports)

    def _rebuild(self, imports_text):
        """Glue the statement prefix, ``import`` and the given names."""
        return self.from_ + 'import ' + imports_text

    def is_over(self, line=None):
        """Return True if the multiline import statement is over.

        When ``line`` is empty or missing, the last accumulated line is
        inspected.  A parenthesized statement ends on an uncommented ``)``;
        any other statement ends on a line with no uncommented backslash.
        """
        current = line or self.accumulator[-1]

        if not self.parenthesized:
            return not _valid_char_in_line('\\', current)

        return _valid_char_in_line(')', current)

    def analyze(self, line):
        """Decide if the statement will be fixed or left unchanged."""
        if ';' in line or ':' in line or '#' in line:
            self.give_up = True

    def fix(self, accumulated):
        """Given a collection of accumulated lines, fix the entire import.

        Return the rewritten statement as a string.  If every name turns out
        to be unused, an indented ``pass`` statement is returned instead.
        """
        old_imports = ''.join(accumulated)
        ending = get_line_ending(old_imports)
        # Each segment holds a module name + comma + whitespace and
        # eventual <newline> \ ( ) chars.
        segments = list(filter(None, self.SEGMENT_RE.findall(old_imports)))
        modules = list(map(_segment_module, segments))
        keep = _filter_imports(modules, self.base, self.remove)

        # Nothing was discarded: hand the statement back as it was.
        if len(keep) == len(segments):
            return self._rebuild(''.join(accumulated))

        fixed = ''
        if keep:
            # Dealing with every line break and continuation is hard, so the
            # existing layout is reused: the surviving identifiers are
            # written into the first N-1 segments plus the last segment.
            # The last segment is kept because it might contain important
            # chars like `)`.
            templates = list(zip(modules, segments))
            templates = templates[:len(keep) - 1] + templates[-1:]
            pieces = []
            for survivor, (module, template) in zip(keep, templates):
                pieces.append(template.replace(module, survivor))
            fixed = ''.join(pieces)

            # Edge case: inline parenthesis + just one surviving import.
            balanced = '(' in fixed and ')' in fixed
            if self.parenthesized and not balanced:
                fixed = fixed.strip(string.whitespace + '()') + ending

        # An import left with no names is replaced by a "pass" statement.
        if not fixed.strip(string.whitespace + '\\(),'):
            indentation = self.INDENTATION_RE.search(self.from_).group(0)
            return indentation + 'pass' + ending

        return self._rebuild(fixed)

    def __call__(self, line=None):
        """Accumulate all the lines in the import and then trigger the fix.

        Return ``self`` while the statement is not over yet.  Once it is,
        return the original text if fixing was ruled out, or the fixed
        statement otherwise.
        """
        if line:
            self.accumulator.append(line)
            self.analyze(line)

        if not self.is_over(line):
            return self

        if self.give_up:
            return self._rebuild(''.join(self.accumulator))

        return self.fix(self.accumulator)


def _filter_imports(imports, parent=None, unused_module=()):
    """Return the entries of ``imports`` that are not listed as unused.

    Each name is qualified with ``parent`` before the lookup, unless
    ``parent`` is None.  A dot is inserted between the two unless ``parent``
    already ends with one, as a relative base such as ``..`` does.
    """
    # We compare full module name (``a.module`` not `module`) to
    # guarantee the exact same module as detected from pyflakes.
    if parent is None:
        prefix = ''
    elif parent and parent[-1] == '.':
        prefix = parent
    else:
        prefix = parent + '.'

    return [name for name in imports if prefix + name not in unused_module]


def filter_from_import(line, unused_module):
"""Parse and filter ``from something import a, b, c``.
Expand All @@ -283,15 +456,8 @@ def filter_from_import(line, unused_module):
base_module = re.search(pattern=r'\bfrom\s+([^ ]+)',
string=indentation).group(1)

# Create an imported module list with base module name
# ex ``from a import b, c as d`` -> ``['a.b', 'a.c as d']``
imports = re.split(pattern=r',', string=imports.strip())
imports = [base_module + '.' + x.strip() for x in imports]

# We compare full module name (``a.module`` not `module`) to
# guarantee the exact same module as detected from pyflakes.
filtered_imports = [x.replace(base_module + '.', '')
for x in imports if x not in unused_module]
imports = re.split(pattern=r'\s*,\s*', string=imports.strip())
filtered_imports = _filter_imports(imports, base_module, unused_module)

# All of the import in this statement is unused
if not filtered_imports:
Expand Down Expand Up @@ -388,26 +554,32 @@ def filter_code(source, additional_imports=None,

sio = io.StringIO(source)
previous_line = ''
result = None
for line_number, line in enumerate(sio.readlines(), start=1):
if '#' in line:
yield line
if isinstance(result, PendingFix):
result = result(line)
elif '#' in line:
result = line
elif line_number in marked_import_line_numbers:
yield filter_unused_import(
result = filter_unused_import(
line,
unused_module=marked_unused_module[line_number],
remove_all_unused_imports=remove_all_unused_imports,
imports=imports,
previous_line=previous_line)
elif line_number in marked_variable_line_numbers:
yield filter_unused_variable(line)
result = filter_unused_variable(line)
elif line_number in marked_key_line_numbers:
yield filter_duplicate_key(line, line_messages[line_number],
line_number, marked_key_line_numbers,
source)
result = filter_duplicate_key(line, line_messages[line_number],
line_number, marked_key_line_numbers,
source)
elif line_number in marked_star_import_line_numbers:
yield filter_star_import(line, undefined_names)
result = filter_star_import(line, undefined_names)
else:
yield line
result = line

if not isinstance(result, PendingFix):
yield result

previous_line = line

Expand All @@ -429,9 +601,16 @@ def filter_star_import(line, marked_star_import_undefined_name):
def filter_unused_import(line, unused_module, remove_all_unused_imports,
imports, previous_line=''):
"""Return line if used, otherwise return None."""
if multiline_import(line, previous_line):
# Ignore doctests.
if line.lstrip().startswith('>'):
return line

if multiline_import(line, previous_line):
filt = FilterMultilineImport(line, unused_module,
remove_all_unused_imports,
imports, previous_line)
return filt()

is_from_import = line.lstrip().startswith('from')

if ',' in line and not is_from_import:
Expand Down

0 comments on commit e14b5c3

Please sign in to comment.