Skip to content

Commit

Permalink
feat!: add location to error reports (#357)
Browse files Browse the repository at this point in the history
* feat(imports): add `Location` class

* feat(imports): extract locations

* feat(violations): add error information and location

* feat: display locations on errors

* feat: sort violations per location and error code

* docs(usage): update JSON output example
  • Loading branch information
mkniewallner committed May 7, 2023
1 parent fe84fc5 commit eaf9546
Show file tree
Hide file tree
Showing 38 changed files with 1,369 additions and 402 deletions.
75 changes: 48 additions & 27 deletions deptry/core.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import logging
import operator
import sys
from dataclasses import dataclass
from typing import TYPE_CHECKING
Expand All @@ -16,7 +17,7 @@
from deptry.issues_finder.missing import MissingDependenciesFinder
from deptry.issues_finder.obsolete import ObsoleteDependenciesFinder
from deptry.issues_finder.transitive import TransitiveDependenciesFinder
from deptry.module import ModuleBuilder
from deptry.module import ModuleBuilder, ModuleLocations
from deptry.python_file_finder import PythonFileFinder
from deptry.reporters import JSONReporter, TextReporter
from deptry.stdlibs import STDLIBS_PYTHON
Expand All @@ -27,7 +28,6 @@

from deptry.dependency import Dependency
from deptry.dependency_getter.base import DependenciesExtract
from deptry.module import Module
from deptry.violations import Violation


Expand Down Expand Up @@ -68,19 +68,26 @@ def run(self) -> None:
local_modules = self._get_local_modules()
stdlib_modules = self._get_stdlib_modules()

imported_modules = [
ModuleBuilder(
mod,
local_modules,
stdlib_modules,
dependencies_extract.dependencies,
dependencies_extract.dev_dependencies,
).build()
for mod in get_imported_modules_for_list_of_files(all_python_files)
imported_modules_with_locations = [
ModuleLocations(
ModuleBuilder(
module,
local_modules,
stdlib_modules,
dependencies_extract.dependencies,
dependencies_extract.dev_dependencies,
).build(),
locations,
)
for module, locations in get_imported_modules_for_list_of_files(all_python_files).items()
]
imported_modules_with_locations = [
module_with_locations
for module_with_locations in imported_modules_with_locations
if not module_with_locations.module.standard_library
]
imported_modules = [mod for mod in imported_modules if not mod.standard_library]

violations = self._find_violations(imported_modules, dependencies_extract.dependencies)
violations = self._find_violations(imported_modules_with_locations, dependencies_extract.dependencies)
TextReporter(violations).report()

if self.json_output:
Expand All @@ -89,27 +96,41 @@ def run(self) -> None:
self._exit(violations)

def _find_violations(
self, imported_modules: list[Module], dependencies: list[Dependency]
) -> dict[str, list[Violation]]:
result = {}
self, imported_modules_with_locations: list[ModuleLocations], dependencies: list[Dependency]
) -> list[Violation]:
violations = []

if not self.skip_obsolete:
result["obsolete"] = ObsoleteDependenciesFinder(imported_modules, dependencies, self.ignore_obsolete).find()
violations.extend(
ObsoleteDependenciesFinder(imported_modules_with_locations, dependencies, self.ignore_obsolete).find()
)

if not self.skip_missing:
result["missing"] = MissingDependenciesFinder(imported_modules, dependencies, self.ignore_missing).find()
violations.extend(
MissingDependenciesFinder(imported_modules_with_locations, dependencies, self.ignore_missing).find()
)

if not self.skip_transitive:
result["transitive"] = TransitiveDependenciesFinder(
imported_modules, dependencies, self.ignore_transitive
).find()
violations.extend(
TransitiveDependenciesFinder(
imported_modules_with_locations, dependencies, self.ignore_transitive
).find()
)

if not self.skip_misplaced_dev:
result["misplaced_dev"] = MisplacedDevDependenciesFinder(
imported_modules, dependencies, self.ignore_misplaced_dev
).find()
violations.extend(
MisplacedDevDependenciesFinder(
imported_modules_with_locations, dependencies, self.ignore_misplaced_dev
).find()
)

return self._get_sorted_violations(violations)

return result
@staticmethod
def _get_sorted_violations(violations: list[Violation]) -> list[Violation]:
return sorted(
violations, key=operator.attrgetter("location.file", "location.line", "location.column", "error_code")
)

def _get_dependencies(self, dependency_management_format: DependencyManagementFormat) -> DependenciesExtract:
if dependency_management_format is DependencyManagementFormat.POETRY:
Expand Down Expand Up @@ -161,5 +182,5 @@ def _log_config(self) -> None:
logging.debug("")

@staticmethod
def _exit(violations: dict[str, list[Violation]]) -> None:
sys.exit(int(any(violations.values())))
def _exit(violations: list[Violation]) -> None:
sys.exit(bool(violations))
14 changes: 10 additions & 4 deletions deptry/imports/extract.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING

from deptry.imports.extractors import NotebookImportExtractor, PythonImportExtractor
Expand All @@ -10,19 +10,25 @@
from pathlib import Path

from deptry.imports.extractors.base import ImportExtractor
from deptry.imports.location import Location


def get_imported_modules_for_list_of_files(list_of_files: list[Path]) -> list[str]:
def get_imported_modules_for_list_of_files(list_of_files: list[Path]) -> dict[str, list[Location]]:
logging.info(f"Scanning {len(list_of_files)} files...")

modules = sorted(set(itertools.chain.from_iterable(get_imported_modules_from_file(file) for file in list_of_files)))
modules: dict[str, list[Location]] = defaultdict(list)

for file in list_of_files:
for module, locations in get_imported_modules_from_file(file).items():
for location in locations:
modules[module].append(location)

logging.debug(f"All imported modules: {modules}\n")

return modules


def get_imported_modules_from_file(path_to_file: Path) -> set[str]:
def get_imported_modules_from_file(path_to_file: Path) -> dict[str, list[Location]]:
logging.debug(f"Scanning {path_to_file}...")

modules = _get_extractor_class(path_to_file)(path_to_file).extract_imports()
Expand Down
17 changes: 11 additions & 6 deletions deptry/imports/extractors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,14 @@

import ast
from abc import ABC, abstractmethod
from collections import defaultdict
from dataclasses import dataclass
from typing import TYPE_CHECKING

import chardet

from deptry.imports.location import Location

if TYPE_CHECKING:
from pathlib import Path

Expand All @@ -20,11 +23,10 @@ class ImportExtractor(ABC):
file: Path

@abstractmethod
def extract_imports(self) -> set[str]:
def extract_imports(self) -> dict[str, list[Location]]:
raise NotImplementedError()

@staticmethod
def _extract_imports_from_ast(tree: ast.AST) -> set[str]:
def _extract_imports_from_ast(self, tree: ast.AST) -> dict[str, list[Location]]:
"""
Given an Abstract Syntax Tree, find the imported top-level modules.
For example, given the source tree of a file with contents:
Expand All @@ -34,13 +36,16 @@ def _extract_imports_from_ast(tree: ast.AST) -> set[str]:
Will return the set {"pandas"}.
"""

imported_modules: set[str] = set()
imported_modules: dict[str, list[Location]] = defaultdict(list)

for node in ast.walk(tree):
if isinstance(node, ast.Import):
imported_modules |= {module.name.split(".")[0] for module in node.names}
for module in node.names:
imported_modules[module.name.split(".")[0]].append(
Location(self.file, node.lineno, node.col_offset)
)
elif isinstance(node, ast.ImportFrom) and node.module and node.level == 0:
imported_modules.add(node.module.split(".")[0])
imported_modules[node.module.split(".")[0]].append(Location(self.file, node.lineno, node.col_offset))

return imported_modules

Expand Down
6 changes: 4 additions & 2 deletions deptry/imports/extractors/notebook_import_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,18 @@
if TYPE_CHECKING:
from pathlib import Path

from deptry.imports.location import Location


@dataclass
class NotebookImportExtractor(ImportExtractor):
"""Extract import statements from a Jupyter notebook."""

def extract_imports(self) -> set[str]:
def extract_imports(self) -> dict[str, list[Location]]:
"""Extract the imported top-level modules from all code cells in the Jupyter Notebook."""
notebook = self._read_ipynb_file(self.file)
if not notebook:
return set()
return {}

cells = self._keep_code_cells(notebook)
import_statements = [self._extract_import_statements_from_cell(cell) for cell in cells]
Expand Down
8 changes: 6 additions & 2 deletions deptry/imports/extractors/python_import_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,19 @@
import ast
import logging
from dataclasses import dataclass
from typing import TYPE_CHECKING

from deptry.imports.extractors.base import ImportExtractor

if TYPE_CHECKING:
from deptry.imports.location import Location


@dataclass
class PythonImportExtractor(ImportExtractor):
"""Extract import statements from a Python module."""

def extract_imports(self) -> set[str]:
def extract_imports(self) -> dict[str, list[Location]]:
"""Extract all imported top-level modules from the Python file."""
try:
with open(self.file) as python_file:
Expand All @@ -22,6 +26,6 @@ def extract_imports(self) -> set[str]:
tree = ast.parse(python_file.read(), str(self.file))
except UnicodeDecodeError:
logging.warning(f"Warning: File {self.file} could not be decoded. Skipping...")
return set()
return {}

return self._extract_imports_from_ast(tree)
19 changes: 19 additions & 0 deletions deptry/imports/location.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from pathlib import Path


@dataclass(frozen=True)
class Location:
file: Path
line: int | None = None
column: int | None = None

def format_for_terminal(self) -> str:
if self.line is not None and self.column is not None:
return f"{self.file}:{self.line}:{self.column}"
return str(self.file)
4 changes: 2 additions & 2 deletions deptry/issues_finder/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,15 @@

if TYPE_CHECKING:
from deptry.dependency import Dependency
from deptry.module import Module
from deptry.module import ModuleLocations
from deptry.violations import Violation


@dataclass
class IssuesFinder(ABC):
"""Base class for all issues finders."""

imported_modules: list[Module]
imported_modules_with_locations: list[ModuleLocations]
dependencies: list[Dependency]
ignored_modules: tuple[str, ...] = ()

Expand Down
7 changes: 5 additions & 2 deletions deptry/issues_finder/misplaced_dev.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,15 @@ def find(self) -> list[Violation]:
logging.debug("\nScanning for incorrect development dependencies...")
misplaced_dev_dependencies: list[Violation] = []

for module in self.imported_modules:
for module_with_locations in self.imported_modules_with_locations:
module = module_with_locations.module

logging.debug(f"Scanning module {module.name}...")
corresponding_package_name = self._get_package_name(module)

if corresponding_package_name and self._is_development_dependency(module, corresponding_package_name):
misplaced_dev_dependencies.append(MisplacedDevDependencyViolation(module))
for location in module_with_locations.locations:
misplaced_dev_dependencies.append(MisplacedDevDependencyViolation(module, location))

return misplaced_dev_dependencies

Expand Down
7 changes: 5 additions & 2 deletions deptry/issues_finder/missing.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,14 @@ def find(self) -> list[Violation]:
logging.debug("\nScanning for missing dependencies...")
missing_dependencies: list[Violation] = []

for module in self.imported_modules:
for module_with_locations in self.imported_modules_with_locations:
module = module_with_locations.module

logging.debug(f"Scanning module {module.name}...")

if self._is_missing(module):
missing_dependencies.append(MissingDependencyViolation(module))
for location in module_with_locations.locations:
missing_dependencies.append(MissingDependencyViolation(module, location))

return missing_dependencies

Expand Down
16 changes: 13 additions & 3 deletions deptry/issues_finder/obsolete.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING

from deptry.imports.location import Location
from deptry.issues_finder.base import IssuesFinder
from deptry.violations import ObsoleteDependencyViolation

Expand Down Expand Up @@ -32,7 +33,9 @@ def find(self) -> list[Violation]:
logging.debug(f"Scanning module {dependency.name}...")

if self._is_obsolete(dependency):
obsolete_dependencies.append(ObsoleteDependencyViolation(dependency))
obsolete_dependencies.append(
ObsoleteDependencyViolation(dependency, Location(dependency.definition_file))
)

return obsolete_dependencies

Expand All @@ -48,12 +51,19 @@ def _is_obsolete(self, dependency: Dependency) -> bool:
return True

def _dependency_found_in_imported_modules(self, dependency: Dependency) -> bool:
return any(module.package == dependency.name for module in self.imported_modules)
return any(
module_with_locations.module.package == dependency.name
for module_with_locations in self.imported_modules_with_locations
)

def _any_of_the_top_levels_imported(self, dependency: Dependency) -> bool:
if not dependency.top_levels:
return False

return any(
any(module.name == top_level for module in self.imported_modules) for top_level in dependency.top_levels
any(
module_with_locations.module.name == top_level
for module_with_locations in self.imported_modules_with_locations
)
for top_level in dependency.top_levels
)
7 changes: 5 additions & 2 deletions deptry/issues_finder/transitive.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,15 @@ def find(self) -> list[Violation]:
logging.debug("\nScanning for transitive dependencies...")
transitive_dependencies: list[Violation] = []

for module in self.imported_modules:
for module_with_locations in self.imported_modules_with_locations:
module = module_with_locations.module

logging.debug(f"Scanning module {module.name}...")

if self._is_transitive(module):
# `self._is_transitive` only returns `True` if the package is not None.
transitive_dependencies.append(TransitiveDependencyViolation(module))
for location in module_with_locations.locations:
transitive_dependencies.append(TransitiveDependencyViolation(module, location))

return transitive_dependencies

Expand Down
Loading

0 comments on commit eaf9546

Please sign in to comment.