Skip to content

Commit

Permalink
Enable custom config to overwrite common_statements
Browse files Browse the repository at this point in the history
  • Loading branch information
Jasha10 committed Nov 21, 2021
1 parent 4bd3bfd commit 02b969e
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 9 deletions.
19 changes: 13 additions & 6 deletions src/autoimport/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import inspect
import os
import re
from typing import Dict, List, Optional
from typing import Any, Dict, List, Optional

import autoflake
from pyflakes.messages import UndefinedExport, UndefinedName, UnusedImport
Expand Down Expand Up @@ -37,12 +37,15 @@
class SourceCode: # noqa: R090
"""Python source code entity."""

def __init__(self, source_code: str) -> None:
def __init__(
self, source_code: str, config: Optional[Dict[str, Any]] = None
) -> None:
"""Initialize the object."""
self.header: List[str] = []
self.imports: List[str] = []
self.typing: List[str] = []
self.code: List[str] = []
self.config: Dict[str, Any] = config if config else {}
self._trailing_newline = False
self._split_code(source_code)

Expand Down Expand Up @@ -356,8 +359,7 @@ def _find_package_in_typing(name: str) -> Optional[str]:
except KeyError:
return None

@staticmethod
def _find_package_in_common_statements(name: str) -> Optional[str]:
def _find_package_in_common_statements(self, name: str) -> Optional[str]:
"""Search in the common statements the object name.
Args:
Expand All @@ -366,8 +368,13 @@ def _find_package_in_common_statements(name: str) -> Optional[str]:
Returns:
import_string
"""
if name in common_statements:
return common_statements[name]
if "common_statements" in self.config:
local_common_statements = self.config["common_statements"]
else:
local_common_statements = common_statements

if name in local_common_statements:
return local_common_statements[name]

return None

Expand Down
6 changes: 3 additions & 3 deletions src/autoimport/services.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
and handlers to achieve the program's purpose.
"""

from typing import Optional, Tuple
from typing import Any, Dict, Optional, Tuple

from _io import TextIOWrapper

Expand Down Expand Up @@ -49,7 +49,7 @@ def fix_files(files: Tuple[TextIOWrapper]) -> Optional[str]:
return None


def fix_code(original_source_code: str) -> str:
def fix_code(original_source_code: str, config: Optional[Dict[str, Any]] = None) -> str:
"""Fix python source code to correct import statements.
It corrects these errors:
Expand All @@ -64,4 +64,4 @@ def fix_code(original_source_code: str) -> str:
Returns:
Corrected source code.
"""
return SourceCode(original_source_code).fix()
return SourceCode(original_source_code, config=config).fix()
28 changes: 28 additions & 0 deletions tests/unit/test_services.py
Original file line number Diff line number Diff line change
Expand Up @@ -894,3 +894,31 @@ def test_file_with_common_statement() -> None:
result = fix_code(source)

assert result == desired_source


def test_file_with_custom_common_statement() -> None:
"""
Given: Code that uses an undefined object called `FooBar`.
When:
Fix code is run and a `config` dict is passed specifying `FooBar` as a common
statement.
Then:
The appropriate import statement from the common_statements dict is added.
"""
source = dedent(
"""\
FooBar
"""
)
custom_config = {"common_statements": {"FooBar": "from baz_qux import FooBar"}}
desired_source = dedent(
"""\
from baz_qux import FooBar
FooBar
"""
)

result = fix_code(source, config=custom_config)

assert result == desired_source

0 comments on commit 02b969e

Please sign in to comment.