Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat!: add location to error reports #357

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 48 additions & 27 deletions deptry/core.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import logging
import operator
import sys
from dataclasses import dataclass
from typing import TYPE_CHECKING
Expand All @@ -16,7 +17,7 @@
from deptry.issues_finder.missing import MissingDependenciesFinder
from deptry.issues_finder.obsolete import ObsoleteDependenciesFinder
from deptry.issues_finder.transitive import TransitiveDependenciesFinder
from deptry.module import ModuleBuilder
from deptry.module import ModuleBuilder, ModuleLocations
from deptry.python_file_finder import PythonFileFinder
from deptry.reporters import JSONReporter, TextReporter
from deptry.stdlibs import STDLIBS_PYTHON
Expand All @@ -27,7 +28,6 @@

from deptry.dependency import Dependency
from deptry.dependency_getter.base import DependenciesExtract
from deptry.module import Module
from deptry.violations import Violation


Expand Down Expand Up @@ -68,19 +68,26 @@ def run(self) -> None:
local_modules = self._get_local_modules()
stdlib_modules = self._get_stdlib_modules()

imported_modules = [
ModuleBuilder(
mod,
local_modules,
stdlib_modules,
dependencies_extract.dependencies,
dependencies_extract.dev_dependencies,
).build()
for mod in get_imported_modules_for_list_of_files(all_python_files)
imported_modules_with_locations = [
ModuleLocations(
ModuleBuilder(
module,
local_modules,
stdlib_modules,
dependencies_extract.dependencies,
dependencies_extract.dev_dependencies,
).build(),
locations,
)
for module, locations in get_imported_modules_for_list_of_files(all_python_files).items()
]
imported_modules_with_locations = [
module_with_locations
for module_with_locations in imported_modules_with_locations
if not module_with_locations.module.standard_library
]
imported_modules = [mod for mod in imported_modules if not mod.standard_library]

violations = self._find_violations(imported_modules, dependencies_extract.dependencies)
violations = self._find_violations(imported_modules_with_locations, dependencies_extract.dependencies)
TextReporter(violations).report()

if self.json_output:
Expand All @@ -89,27 +96,41 @@ def run(self) -> None:
self._exit(violations)

def _find_violations(
self, imported_modules: list[Module], dependencies: list[Dependency]
) -> dict[str, list[Violation]]:
result = {}
self, imported_modules_with_locations: list[ModuleLocations], dependencies: list[Dependency]
) -> list[Violation]:
violations = []

if not self.skip_obsolete:
result["obsolete"] = ObsoleteDependenciesFinder(imported_modules, dependencies, self.ignore_obsolete).find()
violations.extend(
ObsoleteDependenciesFinder(imported_modules_with_locations, dependencies, self.ignore_obsolete).find()
)

if not self.skip_missing:
result["missing"] = MissingDependenciesFinder(imported_modules, dependencies, self.ignore_missing).find()
violations.extend(
MissingDependenciesFinder(imported_modules_with_locations, dependencies, self.ignore_missing).find()
)

if not self.skip_transitive:
result["transitive"] = TransitiveDependenciesFinder(
imported_modules, dependencies, self.ignore_transitive
).find()
violations.extend(
TransitiveDependenciesFinder(
imported_modules_with_locations, dependencies, self.ignore_transitive
).find()
)

if not self.skip_misplaced_dev:
result["misplaced_dev"] = MisplacedDevDependenciesFinder(
imported_modules, dependencies, self.ignore_misplaced_dev
).find()
violations.extend(
MisplacedDevDependenciesFinder(
imported_modules_with_locations, dependencies, self.ignore_misplaced_dev
).find()
)

return self._get_sorted_violations(violations)

return result
@staticmethod
def _get_sorted_violations(violations: list[Violation]) -> list[Violation]:
return sorted(
violations, key=operator.attrgetter("location.file", "location.line", "location.column", "error_code")
)

def _get_dependencies(self, dependency_management_format: DependencyManagementFormat) -> DependenciesExtract:
if dependency_management_format is DependencyManagementFormat.POETRY:
Expand Down Expand Up @@ -161,5 +182,5 @@ def _log_config(self) -> None:
logging.debug("")

@staticmethod
def _exit(violations: dict[str, list[Violation]]) -> None:
sys.exit(int(any(violations.values())))
def _exit(violations: list[Violation]) -> None:
sys.exit(bool(violations))
14 changes: 10 additions & 4 deletions deptry/imports/extract.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING

from deptry.imports.extractors import NotebookImportExtractor, PythonImportExtractor
Expand All @@ -10,19 +10,25 @@
from pathlib import Path

from deptry.imports.extractors.base import ImportExtractor
from deptry.imports.location import Location


def get_imported_modules_for_list_of_files(list_of_files: list[Path]) -> list[str]:
def get_imported_modules_for_list_of_files(list_of_files: list[Path]) -> dict[str, list[Location]]:
logging.info(f"Scanning {len(list_of_files)} files...")

modules = sorted(set(itertools.chain.from_iterable(get_imported_modules_from_file(file) for file in list_of_files)))
modules: dict[str, list[Location]] = defaultdict(list)

for file in list_of_files:
for module, locations in get_imported_modules_from_file(file).items():
for location in locations:
modules[module].append(location)

logging.debug(f"All imported modules: {modules}\n")

return modules


def get_imported_modules_from_file(path_to_file: Path) -> set[str]:
def get_imported_modules_from_file(path_to_file: Path) -> dict[str, list[Location]]:
logging.debug(f"Scanning {path_to_file}...")

modules = _get_extractor_class(path_to_file)(path_to_file).extract_imports()
Expand Down
17 changes: 11 additions & 6 deletions deptry/imports/extractors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,14 @@

import ast
from abc import ABC, abstractmethod
from collections import defaultdict
from dataclasses import dataclass
from typing import TYPE_CHECKING

import chardet

from deptry.imports.location import Location

if TYPE_CHECKING:
from pathlib import Path

Expand All @@ -20,11 +23,10 @@ class ImportExtractor(ABC):
file: Path

@abstractmethod
def extract_imports(self) -> set[str]:
def extract_imports(self) -> dict[str, list[Location]]:
raise NotImplementedError()

@staticmethod
def _extract_imports_from_ast(tree: ast.AST) -> set[str]:
def _extract_imports_from_ast(self, tree: ast.AST) -> dict[str, list[Location]]:
"""
Given an Abstract Syntax Tree, find the imported top-level modules.
For example, given the source tree of a file with contents:
Expand All @@ -34,13 +36,16 @@ def _extract_imports_from_ast(tree: ast.AST) -> set[str]:
Will return the set {"pandas"}.
"""

imported_modules: set[str] = set()
imported_modules: dict[str, list[Location]] = defaultdict(list)

for node in ast.walk(tree):
if isinstance(node, ast.Import):
imported_modules |= {module.name.split(".")[0] for module in node.names}
for module in node.names:
imported_modules[module.name.split(".")[0]].append(
Location(self.file, node.lineno, node.col_offset)
)
elif isinstance(node, ast.ImportFrom) and node.module and node.level == 0:
imported_modules.add(node.module.split(".")[0])
imported_modules[node.module.split(".")[0]].append(Location(self.file, node.lineno, node.col_offset))

return imported_modules

Expand Down
6 changes: 4 additions & 2 deletions deptry/imports/extractors/notebook_import_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,18 @@
if TYPE_CHECKING:
from pathlib import Path

from deptry.imports.location import Location


@dataclass
class NotebookImportExtractor(ImportExtractor):
"""Extract import statements from a Jupyter notebook."""

def extract_imports(self) -> set[str]:
def extract_imports(self) -> dict[str, list[Location]]:
"""Extract the imported top-level modules from all code cells in the Jupyter Notebook."""
notebook = self._read_ipynb_file(self.file)
if not notebook:
return set()
return {}

Check warning on line 27 in deptry/imports/extractors/notebook_import_extractor.py

View check run for this annotation

Codecov / codecov/patch

deptry/imports/extractors/notebook_import_extractor.py#L27

Added line #L27 was not covered by tests

cells = self._keep_code_cells(notebook)
import_statements = [self._extract_import_statements_from_cell(cell) for cell in cells]
Expand Down
8 changes: 6 additions & 2 deletions deptry/imports/extractors/python_import_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,19 @@
import ast
import logging
from dataclasses import dataclass
from typing import TYPE_CHECKING

from deptry.imports.extractors.base import ImportExtractor

if TYPE_CHECKING:
from deptry.imports.location import Location


@dataclass
class PythonImportExtractor(ImportExtractor):
"""Extract import statements from a Python module."""

def extract_imports(self) -> set[str]:
def extract_imports(self) -> dict[str, list[Location]]:
"""Extract all imported top-level modules from the Python file."""
try:
with open(self.file) as python_file:
Expand All @@ -22,6 +26,6 @@ def extract_imports(self) -> set[str]:
tree = ast.parse(python_file.read(), str(self.file))
except UnicodeDecodeError:
logging.warning(f"Warning: File {self.file} could not be decoded. Skipping...")
return set()
return {}

return self._extract_imports_from_ast(tree)
19 changes: 19 additions & 0 deletions deptry/imports/location.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from pathlib import Path


@dataclass(frozen=True)
class Location:
file: Path
line: int | None = None
column: int | None = None

def format_for_terminal(self) -> str:
if self.line is not None and self.column is not None:
return f"{self.file}:{self.line}:{self.column}"
return str(self.file)
4 changes: 2 additions & 2 deletions deptry/issues_finder/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,15 @@

if TYPE_CHECKING:
from deptry.dependency import Dependency
from deptry.module import Module
from deptry.module import ModuleLocations
from deptry.violations import Violation


@dataclass
class IssuesFinder(ABC):
"""Base class for all issues finders."""

imported_modules: list[Module]
imported_modules_with_locations: list[ModuleLocations]
dependencies: list[Dependency]
ignored_modules: tuple[str, ...] = ()

Expand Down
7 changes: 5 additions & 2 deletions deptry/issues_finder/misplaced_dev.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,15 @@ def find(self) -> list[Violation]:
logging.debug("\nScanning for incorrect development dependencies...")
misplaced_dev_dependencies: list[Violation] = []

for module in self.imported_modules:
for module_with_locations in self.imported_modules_with_locations:
module = module_with_locations.module

logging.debug(f"Scanning module {module.name}...")
corresponding_package_name = self._get_package_name(module)

if corresponding_package_name and self._is_development_dependency(module, corresponding_package_name):
misplaced_dev_dependencies.append(MisplacedDevDependencyViolation(module))
for location in module_with_locations.locations:
misplaced_dev_dependencies.append(MisplacedDevDependencyViolation(module, location))

return misplaced_dev_dependencies

Expand Down
7 changes: 5 additions & 2 deletions deptry/issues_finder/missing.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,14 @@ def find(self) -> list[Violation]:
logging.debug("\nScanning for missing dependencies...")
missing_dependencies: list[Violation] = []

for module in self.imported_modules:
for module_with_locations in self.imported_modules_with_locations:
module = module_with_locations.module

logging.debug(f"Scanning module {module.name}...")

if self._is_missing(module):
missing_dependencies.append(MissingDependencyViolation(module))
for location in module_with_locations.locations:
missing_dependencies.append(MissingDependencyViolation(module, location))

return missing_dependencies

Expand Down
16 changes: 13 additions & 3 deletions deptry/issues_finder/obsolete.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING

from deptry.imports.location import Location
from deptry.issues_finder.base import IssuesFinder
from deptry.violations import ObsoleteDependencyViolation

Expand Down Expand Up @@ -32,7 +33,9 @@ def find(self) -> list[Violation]:
logging.debug(f"Scanning module {dependency.name}...")

if self._is_obsolete(dependency):
obsolete_dependencies.append(ObsoleteDependencyViolation(dependency))
obsolete_dependencies.append(
ObsoleteDependencyViolation(dependency, Location(dependency.definition_file))
)

return obsolete_dependencies

Expand All @@ -48,12 +51,19 @@ def _is_obsolete(self, dependency: Dependency) -> bool:
return True

def _dependency_found_in_imported_modules(self, dependency: Dependency) -> bool:
return any(module.package == dependency.name for module in self.imported_modules)
return any(
module_with_locations.module.package == dependency.name
for module_with_locations in self.imported_modules_with_locations
)

def _any_of_the_top_levels_imported(self, dependency: Dependency) -> bool:
if not dependency.top_levels:
return False

return any(
any(module.name == top_level for module in self.imported_modules) for top_level in dependency.top_levels
any(
module_with_locations.module.name == top_level
for module_with_locations in self.imported_modules_with_locations
)
for top_level in dependency.top_levels
)
7 changes: 5 additions & 2 deletions deptry/issues_finder/transitive.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,15 @@ def find(self) -> list[Violation]:
logging.debug("\nScanning for transitive dependencies...")
transitive_dependencies: list[Violation] = []

for module in self.imported_modules:
for module_with_locations in self.imported_modules_with_locations:
module = module_with_locations.module

logging.debug(f"Scanning module {module.name}...")

if self._is_transitive(module):
# `self._is_transitive` only returns `True` if the package is not None.
transitive_dependencies.append(TransitiveDependencyViolation(module))
for location in module_with_locations.locations:
transitive_dependencies.append(TransitiveDependencyViolation(module, location))

return transitive_dependencies

Expand Down
Loading