Skip to content

Commit 46e5850

Browse files
committed
Fix: built-in star import suggestion
1 parent 842b3af commit 46e5850

2 files changed

Lines changed: 28 additions & 19 deletions

File tree

tests/test_refactor.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -551,6 +551,24 @@ def test_star_import_2(self):
551551
self.session.refactor(action),
552552
)
553553

554+
def test_two_suggestion(self):
555+
action = (
556+
"from time import *\n"
557+
"from os import *\n"
558+
"time() # Function from time module.\n"
559+
"path.join()\n"
560+
)
561+
expected = (
562+
"from time import time\n"
563+
"from os import path\n"
564+
"time() # Function from time module.\n"
565+
"path.join()\n"
566+
)
567+
self.assertEqual(
568+
expected,
569+
self.session.refactor(action),
570+
)
571+
554572

555573
class TestImportError(RefactorTestCase):
556574
"""

unimport/scan.py

Lines changed: 10 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
import contextlib
88
import functools
99
import importlib
10-
import inspect
1110
import io
1211
import re
1312
import sys
@@ -317,24 +316,16 @@ def get_names(self) -> Iterator[Name]:
317316
def get_suggestion_modules(self, imp: ImportFrom) -> List[str]:
318317
if imp.module is None:
319318
return []
320-
scanner = self.__class__()
321-
try:
322-
source = inspect.getsource(imp.module)
323-
except (OSError, TypeError):
324-
return []
325-
else:
326-
scanner.scan(source)
327-
objects = scanner.classes + scanner.functions + scanner.names
328-
from_all_name = {
329-
obj.name.split(".")[0] for obj in objects
330-
} # from module
331-
to_names = { # current
332-
to_cfv.name
333-
for to_cfv in self.names
334-
if to_cfv.name not in self.ignore_import_names
335-
}
336-
suggestion_modules = sorted(from_all_name & to_names)
337-
return suggestion_modules
319+
current_names = { # current
320+
to_cfv.name
321+
for to_cfv in self.names
322+
if to_cfv.name not in self.ignore_import_names
323+
}
324+
modules = {
325+
module for module in dir(imp.module) if not module.startswith("_")
326+
}
327+
suggestion_modules = sorted(modules & current_names)
328+
return suggestion_modules
338329

339330
def get_unused_imports(self) -> Iterator[Union[Import, ImportFrom]]:
340331
for imp in self.imports:

0 commit comments

Comments
 (0)