diff --git a/ruff.toml b/ruff.toml index f30d38974..e638ba30c 100644 --- a/ruff.toml +++ b/ruff.toml @@ -57,6 +57,7 @@ extend-generics = [ "codegen.sdk.core.assignment.Assignment", "codegen.sdk.core.class_definition.Class", "codegen.sdk.core.codebase.Codebase", + "codegen.sdk.core.codeowner.CodeOwner", "codegen.sdk.core.dataclasses.usage.Usage", "codegen.sdk.core.dataclasses.usage.UsageType", "codegen.sdk.core.dataclasses.usage.UsageKind", diff --git a/src/codegen/sdk/core/codebase.py b/src/codegen/sdk/core/codebase.py index ca2a6087b..d42be21f4 100644 --- a/src/codegen/sdk/core/codebase.py +++ b/src/codegen/sdk/core/codebase.py @@ -7,6 +7,7 @@ import re from collections.abc import Generator from contextlib import contextmanager +from functools import cached_property from pathlib import Path from typing import TYPE_CHECKING, Generic, Literal, TypeVar, Unpack, overload @@ -36,6 +37,7 @@ from codegen.sdk.codebase.span import Span from codegen.sdk.core.assignment import Assignment from codegen.sdk.core.class_definition import Class +from codegen.sdk.core.codeowner import CodeOwner from codegen.sdk.core.detached_symbols.code_block import CodeBlock from codegen.sdk.core.detached_symbols.parameter import Parameter from codegen.sdk.core.directory import Directory @@ -257,6 +259,17 @@ def files(self, *, extensions: list[str] | Literal["*"] | None = None) -> list[T # Sort files alphabetically return sort_editables(files, alphabetical=True, dedupe=False) + @cached_property + def codeowners(self) -> list["CodeOwner[TSourceFile]"]: + """List all CodeOnwers in the codebase. + + Returns: + list[CodeOwners]: A list of CodeOwners objects in the codebase. + """ + if self.G.codeowners_parser is None: + return [] + return CodeOwner.from_parser(self.G.codeowners_parser, lambda *args, **kwargs: self.files(*args, **kwargs)) + @property def directories(self) -> list[TDirectory]: """List all directories in the codebase. diff --git a/src/codegen/sdk/core/codeowner.py b/src/codegen/sdk/core/codeowner.py new file mode 100644 index 000000000..bb896d83b --- /dev/null +++ b/src/codegen/sdk/core/codeowner.py @@ -0,0 +1,97 @@ +import logging +from collections.abc import Iterable, Iterator +from typing import Callable, Generic, Literal + +from codeowners import CodeOwners as CodeOwnersParser + +from codegen.sdk._proxy import proxy_property +from codegen.sdk.core.interfaces.has_symbols import ( + FilesParam, + HasSymbols, + TClass, + TFile, + TFunction, + TGlobalVar, + TImport, + TImportStatement, + TSymbol, +) +from codegen.sdk.core.utils.cache_utils import cached_generator +from codegen.shared.decorators.docs import apidoc, noapidoc + +logger = logging.getLogger(__name__) + + +@apidoc +class CodeOwner( + HasSymbols[TFile, TSymbol, TImportStatement, TGlobalVar, TClass, TFunction, TImport], + Generic[TFile, TSymbol, TImportStatement, TGlobalVar, TClass, TFunction, TImport], +): + """CodeOwner is a class that represents a code owner in a codebase. + + It is used to iterate over all files that are owned by a specific owner. + + Attributes: + owner_type: The type of the owner (USERNAME, TEAM, EMAIL). + owner_value: The value of the owner. + files_source: A callable that returns an iterable of all files in the codebase. + """ + + owner_type: Literal["USERNAME", "TEAM", "EMAIL"] + owner_value: str + files_source: Callable[FilesParam, Iterable[TFile]] + + def __init__( + self, + files_source: Callable[FilesParam, Iterable[TFile]], + owner_type: Literal["USERNAME", "TEAM", "EMAIL"], + owner_value: str, + ): + self.owner_type = owner_type + self.owner_value = owner_value + self.files_source = files_source + + @classmethod + def from_parser( + cls, + parser: CodeOwnersParser, + file_source: Callable[FilesParam, Iterable[TFile]], + ) -> list["CodeOwner"]: + """Create a list of CodeOwner objects from a CodeOwnersParser. + + Args: + parser (CodeOwnersParser): The CodeOwnersParser to use. + file_source (Callable[FilesParam, Iterable[TFile]]): A callable that returns an iterable of all files in the codebase. + + Returns: + list[CodeOwner]: A list of CodeOwner objects. + """ + codeowners = [] + for _, _, owners, _, _ in parser.paths: + for owner_label, owner_value in owners: + codeowners.append(CodeOwner(file_source, owner_label, owner_value)) + return codeowners + + @cached_generator(maxsize=16) + @noapidoc + def files_generator(self, *args: FilesParam.args, **kwargs: FilesParam.kwargs) -> Iterable[TFile]: + for source_file in self.files_source(*args, **kwargs): + # Filter files by owner value + if self.owner_value in source_file.owners: + yield source_file + + @proxy_property + def files(self, *args: FilesParam.args, **kwargs: FilesParam.kwargs) -> Iterable[TFile]: + """Recursively iterate over all files in the codebase that are owned by the current code owner.""" + return self.files_generator(*args, **kwargs) + + @property + def name(self) -> str: + """The name of the code owner.""" + return self.owner_value + + def __iter__(self) -> Iterator[TFile]: + return iter(self.files_generator()) + + def __repr__(self) -> str: + return f"CodeOwner(owner_type={self.owner_type}, owner_value={self.owner_value})" diff --git a/src/codegen/sdk/core/directory.py b/src/codegen/sdk/core/directory.py index fcbf8711a..115a28062 100644 --- a/src/codegen/sdk/core/directory.py +++ b/src/codegen/sdk/core/directory.py @@ -1,43 +1,30 @@ +import logging import os -from itertools import chain +from collections.abc import Iterator from pathlib import Path -from typing import TYPE_CHECKING, Generic, Self, TypeVar - -from codegen.shared.decorators.docs import apidoc, py_noapidoc - -if TYPE_CHECKING: - from codegen.sdk.core.assignment import Assignment - from codegen.sdk.core.class_definition import Class - from codegen.sdk.core.file import File - from codegen.sdk.core.function import Function - from codegen.sdk.core.import_resolution import Import, ImportStatement - from codegen.sdk.core.symbol import Symbol - from codegen.sdk.typescript.class_definition import TSClass - from codegen.sdk.typescript.export import TSExport - from codegen.sdk.typescript.file import TSFile - from codegen.sdk.typescript.function import TSFunction - from codegen.sdk.typescript.import_resolution import TSImport - from codegen.sdk.typescript.statements.import_statement import TSImportStatement - from codegen.sdk.typescript.symbol import TSSymbol - -import logging +from typing import Generic, Self + +from codegen.sdk.core.interfaces.has_symbols import ( + HasSymbols, + TClass, + TFile, + TFunction, + TGlobalVar, + TImport, + TImportStatement, + TSymbol, +) +from codegen.sdk.core.utils.cache_utils import cached_generator +from codegen.shared.decorators.docs import apidoc, noapidoc logger = logging.getLogger(__name__) -TFile = TypeVar("TFile", bound="File") -TSymbol = TypeVar("TSymbol", bound="Symbol") -TImportStatement = TypeVar("TImportStatement", bound="ImportStatement") -TGlobalVar = TypeVar("TGlobalVar", bound="Assignment") -TClass = TypeVar("TClass", bound="Class") -TFunction = TypeVar("TFunction", bound="Function") -TImport = TypeVar("TImport", bound="Import") - -TSGlobalVar = TypeVar("TSGlobalVar", bound="Assignment") - - @apidoc -class Directory(Generic[TFile, TSymbol, TImportStatement, TGlobalVar, TClass, TFunction, TImport]): +class Directory( + HasSymbols[TFile, TSymbol, TImportStatement, TGlobalVar, TClass, TFunction, TImport], + Generic[TFile, TSymbol, TImportStatement, TGlobalVar, TClass, TFunction, TImport], +): """Directory representation for codebase. GraphSitter abstraction of a file directory that can be used to look for files and symbols within a specific directory. @@ -58,7 +45,7 @@ def __init__(self, path: Path, dirpath: str, parent: Self | None): self.path = path self.dirpath = dirpath self.parent = parent - self.items = dict() + self.items = {} def __iter__(self): return iter(self.items.values()) @@ -126,62 +113,13 @@ def _get_subdirectories(directory: Directory): _get_subdirectories(self) return subdirectories - @property - def symbols(self) -> list[TSymbol]: - """Get a recursive list of all symbols in the directory and its subdirectories.""" - return list(chain.from_iterable(f.symbols for f in self.files)) - - @property - def import_statements(self) -> list[TImportStatement]: - """Get a recursive list of all import statements in the directory and its subdirectories.""" - return list(chain.from_iterable(f.import_statements for f in self.files)) - - @property - def global_vars(self) -> list[TGlobalVar]: - """Get a recursive list of all global variables in the directory and its subdirectories.""" - return list(chain.from_iterable(f.global_vars for f in self.files)) - - @property - def classes(self) -> list[TClass]: - """Get a recursive list of all classes in the directory and its subdirectories.""" - return list(chain.from_iterable(f.classes for f in self.files)) - - @property - def functions(self) -> list[TFunction]: - """Get a recursive list of all functions in the directory and its subdirectories.""" - return list(chain.from_iterable(f.functions for f in self.files)) - - @property - @py_noapidoc - def exports(self: "Directory[TSFile, TSSymbol, TSImportStatement, TSGlobalVar, TSClass, TSFunction, TSImport]") -> "list[TSExport]": - """Get a recursive list of all exports in the directory and its subdirectories.""" - return list(chain.from_iterable(f.exports for f in self.files)) - - @property - def imports(self) -> list[TImport]: - """Get a recursive list of all imports in the directory and its subdirectories.""" - return list(chain.from_iterable(f.imports for f in self.files)) - - def get_symbol(self, name: str) -> TSymbol | None: - """Get a symbol by name in the directory and its subdirectories.""" - return next((s for s in self.symbols if s.name == name), None) - - def get_import_statement(self, name: str) -> TImportStatement | None: - """Get an import statement by name in the directory and its subdirectories.""" - return next((s for s in self.import_statements if s.name == name), None) - - def get_global_var(self, name: str) -> TGlobalVar | None: - """Get a global variable by name in the directory and its subdirectories.""" - return next((s for s in self.global_vars if s.name == name), None) - - def get_class(self, name: str) -> TClass | None: - """Get a class by name in the directory and its subdirectories.""" - return next((s for s in self.classes if s.name == name), None) - - def get_function(self, name: str) -> TFunction | None: - """Get a function by name in the directory and its subdirectories.""" - return next((s for s in self.functions if s.name == name), None) + @noapidoc + @cached_generator() + def files_generator(self) -> Iterator[TFile]: + """Yield files recursively from the directory.""" + yield from self.files + # Directory-specific methods def add_file(self, file: TFile) -> None: """Add a file to the directory.""" rel_path = os.path.relpath(file.file_path, self.dirpath) @@ -202,18 +140,12 @@ def get_file(self, filename: str, ignore_case: bool = False) -> TFile | None: from codegen.sdk.core.file import File if ignore_case: - return next((f for name, f in self.items.items() if name.lower() == filename.lower() and isinstance(f, File)), None) + return next( + (f for name, f in self.items.items() if name.lower() == filename.lower() and isinstance(f, File)), + None, + ) return self.items.get(filename, None) - @py_noapidoc - def get_export(self: "Directory[TSFile, TSSymbol, TSImportStatement, TSGlobalVar, TSClass, TSFunction, TSImport]", name: str) -> "TSExport | None": - """Get an export by name in the directory and its subdirectories (supports only typescript).""" - return next((s for s in self.exports if s.name == name), None) - - def get_import(self, name: str) -> TImport | None: - """Get an import by name in the directory and its subdirectories.""" - return next((s for s in self.imports if s.name == name), None) - def add_subdirectory(self, subdirectory: Self) -> None: """Add a subdirectory to the directory.""" rel_path = os.path.relpath(subdirectory.dirpath, self.dirpath) @@ -230,23 +162,22 @@ def remove_subdirectory_by_path(self, subdirectory_path: str) -> None: del self.items[rel_path] def get_subdirectory(self, subdirectory_name: str) -> Self | None: - """Get a subdirectory by its path relative to the directory.""" + """Get a subdirectory by its name (relative to the directory).""" return self.items.get(subdirectory_name, None) - def remove(self) -> None: - """Remove the directory and all its files and subdirectories.""" - for f in self.files: - f.remove() - def update_filepath(self, new_filepath: str) -> None: - """Update the filepath of the directory.""" + """Update the filepath of the directory and its contained files.""" old_path = self.dirpath new_path = new_filepath - for file in self.files: new_file_path = os.path.join(new_path, os.path.relpath(file.file_path, old_path)) file.update_filepath(new_file_path) + def remove(self) -> None: + """Remove all the files in the files container.""" + for f in self.files: + f.remove() + def rename(self, new_name: str) -> None: """Rename the directory.""" parent_dir, _ = os.path.split(self.dirpath) diff --git a/src/codegen/sdk/core/interfaces/has_symbols.py b/src/codegen/sdk/core/interfaces/has_symbols.py new file mode 100644 index 000000000..7daefa99b --- /dev/null +++ b/src/codegen/sdk/core/interfaces/has_symbols.py @@ -0,0 +1,117 @@ +import logging +from collections.abc import Iterator +from itertools import chain +from typing import TYPE_CHECKING, Generic, ParamSpec, TypeVar + +from codegen.sdk.core.utils.cache_utils import cached_generator +from codegen.shared.decorators.docs import py_noapidoc + +if TYPE_CHECKING: + from codegen.sdk.core.assignment import Assignment + from codegen.sdk.core.class_definition import Class + from codegen.sdk.core.file import SourceFile + from codegen.sdk.core.function import Function + from codegen.sdk.core.import_resolution import Import, ImportStatement + from codegen.sdk.core.symbol import Symbol + from codegen.sdk.typescript.class_definition import TSClass + from codegen.sdk.typescript.export import TSExport + from codegen.sdk.typescript.file import TSFile + from codegen.sdk.typescript.function import TSFunction + from codegen.sdk.typescript.import_resolution import TSImport + from codegen.sdk.typescript.statements.import_statement import TSImportStatement + from codegen.sdk.typescript.symbol import TSSymbol + +logger = logging.getLogger(__name__) + + +TFile = TypeVar("TFile", bound="SourceFile") +TSymbol = TypeVar("TSymbol", bound="Symbol") +TImportStatement = TypeVar("TImportStatement", bound="ImportStatement") +TGlobalVar = TypeVar("TGlobalVar", bound="Assignment") +TClass = TypeVar("TClass", bound="Class") +TFunction = TypeVar("TFunction", bound="Function") +TImport = TypeVar("TImport", bound="Import") +FilesParam = ParamSpec("FilesParam") + +TSGlobalVar = TypeVar("TSGlobalVar", bound="Assignment") + + +class HasSymbols(Generic[TFile, TSymbol, TImportStatement, TGlobalVar, TClass, TFunction, TImport]): + """Abstract interface for files in a codebase. + + Abstract interface for files in a codebase. + """ + + @cached_generator() + def files_generator(self, *args: FilesParam.args, **kwargs: FilesParam.kwargs) -> Iterator[TFile]: + """Generator for yielding files of the current container's scope.""" + msg = "This method should be implemented by the subclass" + raise NotImplementedError(msg) + + @property + def symbols(self) -> list[TSymbol]: + """Get a recursive list of all symbols in files container.""" + return list(chain.from_iterable(f.symbols for f in self.files_generator())) + + @property + def import_statements(self) -> list[TImportStatement]: + """Get a recursive list of all import statements in files container.""" + return list(chain.from_iterable(f.import_statements for f in self.files_generator())) + + @property + def global_vars(self) -> list[TGlobalVar]: + """Get a recursive list of all global variables in files container.""" + return list(chain.from_iterable(f.global_vars for f in self.files_generator())) + + @property + def classes(self) -> list[TClass]: + """Get a recursive list of all classes in files container.""" + return list(chain.from_iterable(f.classes for f in self.files_generator())) + + @property + def functions(self) -> list[TFunction]: + """Get a recursive list of all functions in files container.""" + return list(chain.from_iterable(f.functions for f in self.files_generator())) + + @property + @py_noapidoc + def exports(self) -> "list[TSExport]": + """Get a recursive list of all exports in files container.""" + return list(chain.from_iterable(f.exports for f in self.files_generator())) + + @property + def imports(self) -> list[TImport]: + """Get a recursive list of all imports in files container.""" + return list(chain.from_iterable(f.imports for f in self.files_generator())) + + def get_symbol(self, name: str) -> TSymbol | None: + """Get a symbol by name in files container.""" + return next((s for s in self.symbols if s.name == name), None) + + def get_import_statement(self, name: str) -> TImportStatement | None: + """Get an import statement by name in files container.""" + return next((s for s in self.import_statements if s.name == name), None) + + def get_global_var(self, name: str) -> TGlobalVar | None: + """Get a global variable by name in files container.""" + return next((s for s in self.global_vars if s.name == name), None) + + def get_class(self, name: str) -> TClass | None: + """Get a class by name in files container.""" + return next((s for s in self.classes if s.name == name), None) + + def get_function(self, name: str) -> TFunction | None: + """Get a function by name in files container.""" + return next((s for s in self.functions if s.name == name), None) + + @py_noapidoc + def get_export( + self: "HasSymbols[TSFile, TSSymbol, TSImportStatement, TSGlobalVar, TSClass, TSFunction, TSImport]", + name: str, + ) -> "TSExport | None": + """Get an export by name in files container (supports only typescript).""" + return next((s for s in self.exports if s.name == name), None) + + def get_import(self, name: str) -> TImport | None: + """Get an import by name in files container.""" + return next((s for s in self.imports if s.name == name), None) diff --git a/src/codegen/sdk/core/utils/cache_utils.py b/src/codegen/sdk/core/utils/cache_utils.py new file mode 100644 index 000000000..60f7c4dbf --- /dev/null +++ b/src/codegen/sdk/core/utils/cache_utils.py @@ -0,0 +1,45 @@ +import functools +from collections.abc import Iterator +from typing import Callable, Generic, ParamSpec, TypeVar + +from codegen.sdk.extensions.utils import lru_cache + +ItemType = TypeVar("ItemType") +GenParamSpec = ParamSpec("GenParamSpec") + + +class LazyGeneratorCache(Generic[ItemType]): + """A cache for a generator that is lazily evaluated.""" + + _cache: list[ItemType] + gen: Iterator[ItemType] + + def __init__(self, gen: Iterator[ItemType]): + self._cache = [] + self.gen = gen + + def __iter__(self) -> Iterator[ItemType]: + for item in self._cache: + yield item + + for item in self.gen: + self._cache.append(item) + yield item + + +def cached_generator(maxsize: int = 16, typed: bool = False) -> Callable[[Callable[GenParamSpec, Iterator[ItemType]]], Callable[GenParamSpec, Iterator[ItemType]]]: + """Decorator to cache the output of a generator function. + + The generator's output is fully consumed on the first call and stored as a list. + Subsequent calls with the same arguments yield values from the cached list. + """ + + def decorator(func: Callable[GenParamSpec, Iterator[ItemType]]) -> Callable[GenParamSpec, Iterator[ItemType]]: + @lru_cache(maxsize=maxsize, typed=typed) + @functools.wraps(func) + def wrapper(*args: GenParamSpec.args, **kwargs: GenParamSpec.kwargs) -> Iterator[ItemType]: + return LazyGeneratorCache(func(*args, **kwargs)) + + return wrapper + + return decorator diff --git a/src/codegen/sdk/extensions/utils.pyi b/src/codegen/sdk/extensions/utils.pyi index 67e9d92ea..952bdd0ef 100644 --- a/src/codegen/sdk/extensions/utils.pyi +++ b/src/codegen/sdk/extensions/utils.pyi @@ -1,5 +1,6 @@ from collections.abc import Generator, Iterable -from functools import cached_property +from functools import cached_property as functools_cached_property +from functools import lru_cache as functools_lru_cache from tree_sitter import Node as TSNode @@ -18,7 +19,8 @@ def find_line_start_and_end_nodes(node: TSNode) -> list[tuple[TSNode, TSNode]]: def find_first_descendant(node: TSNode, type_names: list[str], max_depth: int | None = None) -> TSNode | None: ... -cached_property = cached_property +cached_property = functools_cached_property +lru_cache = functools_lru_cache def uncache_all(): ... def is_descendant_of(node: TSNode, possible_parent: TSNode) -> bool: ... diff --git a/src/codegen/sdk/extensions/utils.pyx b/src/codegen/sdk/extensions/utils.pyx index e95b69ef4..992db3663 100644 --- a/src/codegen/sdk/extensions/utils.pyx +++ b/src/codegen/sdk/extensions/utils.pyx @@ -1,6 +1,7 @@ from collections import Counter from collections.abc import Generator, Iterable -from functools import cached_property +from functools import cached_property as functools_cached_property +from functools import lru_cache as functools_lru_cache from tabulate import tabulate from tree_sitter import Node as TSNode @@ -106,10 +107,11 @@ def find_first_descendant(node: TSNode, type_names: list[str], max_depth: int | to_uncache = [] +lru_caches = [] counter = Counter() -class cached_property(cached_property): +class cached_property(functools_cached_property): def __get__(self, instance, owner=None): ret = super().__get__(instance) if instance is not None: @@ -118,6 +120,20 @@ class cached_property(cached_property): return ret +def lru_cache(func=None, *, maxsize=128, typed=False): + """A wrapper around functools.lru_cache that tracks the cached function so that its cache + can be cleared later via uncache_all(). + """ + if func is None: + # return decorator + return lambda f: lru_cache(f, maxsize=maxsize, typed=typed) + + # return decorated + cached_func = functools_lru_cache(maxsize=maxsize, typed=typed)(func) + lru_caches.append(cached_func) + return cached_func + + def uncache_all(): for instance, name in to_uncache: try: @@ -125,6 +141,9 @@ def uncache_all(): except KeyError: pass + for cached_func in lru_caches: + cached_func.cache_clear() + def report(): print(tabulate(counter.most_common(10))) diff --git a/src/codegen/shared/compilation/function_imports.py b/src/codegen/shared/compilation/function_imports.py index c020230b5..f8539926c 100644 --- a/src/codegen/shared/compilation/function_imports.py +++ b/src/codegen/shared/compilation/function_imports.py @@ -28,6 +28,7 @@ def get_generated_imports(): from codegen.sdk.core.codebase import CodebaseType from codegen.sdk.core.codebase import PyCodebaseType from codegen.sdk.core.codebase import TSCodebaseType +from codegen.sdk.core.codeowner import CodeOwner from codegen.sdk.core.dataclasses.usage import Usage from codegen.sdk.core.dataclasses.usage import UsageKind from codegen.sdk.core.dataclasses.usage import UsageType diff --git a/tests/unit/codegen/extensions/test_utils.py b/tests/unit/codegen/extensions/test_utils.py new file mode 100644 index 000000000..395d1905a --- /dev/null +++ b/tests/unit/codegen/extensions/test_utils.py @@ -0,0 +1,43 @@ +from threading import Event + +import pytest + +from codegen.sdk.extensions.utils import lru_cache, uncache_all + + +def test_lru_cache_with_uncache_all(): + event = Event() + + @lru_cache + def cached_function(): + assert not event.is_set() + event.set() + return 42 + + assert cached_function() == 42 + assert cached_function() == 42 + + uncache_all() + + with pytest.raises(AssertionError): + cached_function() + + +def test_lru_cache_args_with_uncache_all(): + event = [Event() for _ in range(2)] + + @lru_cache(maxsize=2) + def cached_function(a): + assert not event[a].is_set() + event[a].set() + return a + + for _ in range(2): + for idx in range(2): + assert cached_function(idx) == idx + + uncache_all() + + for idx in range(2): + with pytest.raises(AssertionError): + cached_function(idx) diff --git a/tests/unit/codegen/sdk/core/interfaces/test_files_interface.py b/tests/unit/codegen/sdk/core/interfaces/test_files_interface.py new file mode 100644 index 000000000..d2e38d94d --- /dev/null +++ b/tests/unit/codegen/sdk/core/interfaces/test_files_interface.py @@ -0,0 +1,153 @@ +from unittest.mock import MagicMock + +import pytest + +from codegen.sdk.core.interfaces.has_symbols import HasSymbols + + +@pytest.fixture +def fake_interface(): + class FakeHasSymbols(HasSymbols): + def __init__(self, files): + self._files = files + + def files_generator(self, *args, **kwargs): + yield from self._files + + # File 1 with its fake attributes. + file1 = MagicMock() + file1.symbols = [MagicMock(), MagicMock()] + file1.symbols[0].name = "symbol1" + file1.symbols[1].name = "symbol2" + file1.import_statements = [MagicMock()] + file1.import_statements[0].name = "import_statement1" + file1.global_vars = [MagicMock()] + file1.global_vars[0].name = "global_variable1" + file1.classes = [MagicMock()] + file1.classes[0].name = "class1" + file1.functions = [MagicMock()] + file1.functions[0].name = "function1" + file1.exports = [MagicMock()] + file1.exports[0].name = "export_item1" + file1.imports = [MagicMock()] + file1.imports[0].name = "import1" + + # File 2 with its fake attributes. + file2 = MagicMock() + file2.symbols = [MagicMock()] + file2.symbols[0].name = "symbol3" + file2.import_statements = [MagicMock()] + file2.import_statements[0].name = "import_statement2" + file2.global_vars = [MagicMock(), MagicMock()] + file2.global_vars[0].name = "global_variable2" + file2.global_vars[1].name = "global_variable3" + file2.classes = [MagicMock()] + file2.classes[0].name = "class2" + file2.functions = [MagicMock()] + file2.functions[0].name = "function2" + file2.exports = [MagicMock(), MagicMock()] + file2.exports[0].name = "export_item2" + file2.exports[1].name = "export_item3" + file2.imports = [MagicMock()] + file2.imports[0].name = "import2" + + fake_files = [file1, file2] + return FakeHasSymbols(fake_files) + + +def test_files_generator_not_implemented(): + # Instantiating HasSymbols directly should cause files_generator to raise NotImplementedError. + fi = HasSymbols() + with pytest.raises(NotImplementedError): + list(fi.files_generator()) + + +def test_symbols_property(fake_interface): + symbols = fake_interface.symbols + names = sorted([item.name for item in symbols]) + assert names == ["symbol1", "symbol2", "symbol3"] + + +def test_import_statements_property(fake_interface): + import_statements = fake_interface.import_statements + names = sorted([item.name for item in import_statements]) + assert names == ["import_statement1", "import_statement2"] + + +def test_global_vars_property(fake_interface): + global_vars = fake_interface.global_vars + names = sorted([item.name for item in global_vars]) + assert names == ["global_variable1", "global_variable2", "global_variable3"] + + +def test_classes_property(fake_interface): + classes = fake_interface.classes + names = sorted([item.name for item in classes]) + assert names == ["class1", "class2"] + + +def test_functions_property(fake_interface): + functions = fake_interface.functions + names = sorted([item.name for item in functions]) + assert names == ["function1", "function2"] + + +def test_exports_property(fake_interface): + exports = fake_interface.exports + names = sorted([item.name for item in exports]) + assert names == ["export_item1", "export_item2", "export_item3"] + + +def test_imports_property(fake_interface): + imports = fake_interface.imports + names = sorted([item.name for item in imports]) + assert names == ["import1", "import2"] + + +def test_get_symbol(fake_interface): + symbol = fake_interface.get_symbol("symbol1") + assert symbol is not None + assert symbol.name == "symbol1" + assert fake_interface.get_symbol("nonexistent") is None + + +def test_get_import_statement(fake_interface): + imp_stmt = fake_interface.get_import_statement("import_statement2") + assert imp_stmt is not None + assert imp_stmt.name == "import_statement2" + assert fake_interface.get_import_statement("nonexistent") is None + + +def test_get_global_var(fake_interface): + global_var = fake_interface.get_global_var("global_variable3") + assert global_var is not None + assert global_var.name == "global_variable3" + assert fake_interface.get_global_var("nonexistent") is None + + +def test_get_class(fake_interface): + cls = fake_interface.get_class("class2") + assert cls is not None + assert cls.name == "class2" + assert fake_interface.get_class("nonexistent") is None + + +def test_get_function(fake_interface): + func = fake_interface.get_function("function2") + assert func is not None + assert func.name == "function2" + assert fake_interface.get_function("nonexistent") is None + + +def test_get_export(fake_interface): + export_item = fake_interface.get_export("export_item3") + assert export_item is not None + assert export_item.name == "export_item3" + assert fake_interface.get_export("nonexistent") is None + + +def test_get_import(fake_interface): + imp = fake_interface.get_import("import1") + assert imp is not None + assert imp.name == "import1" + assert fake_interface.get_import("nonexistent") is None diff --git a/tests/unit/codegen/sdk/core/test_codeowner.py b/tests/unit/codegen/sdk/core/test_codeowner.py new file mode 100644 index 000000000..c075c9fe0 --- /dev/null +++ b/tests/unit/codegen/sdk/core/test_codeowner.py @@ -0,0 +1,84 @@ +from unittest.mock import MagicMock + +import pytest + +from codegen.sdk.core.codeowner import CodeOwner + + +# Dummy file objects used for testing CodeOwner. +@pytest.fixture +def fake_files() -> list[MagicMock]: + file1 = MagicMock() + file1.owners = ["alice", "bob"] + + file2 = MagicMock() + file2.owners = ["charlie"] + + file3 = MagicMock() + file3.owners = ["alice"] + + return [file1, file2, file3] + + +def test_files_generator_returns_correct_files(fake_files): + def file_source(*args, **kwargs): + return fake_files + + codeowner = CodeOwner(file_source, "USERNAME", "alice") + files = list(codeowner.files_generator()) + # file1 and file3 contain "alice" as one of their owners. + assert fake_files[0] in files + assert fake_files[2] in files + assert fake_files[1] not in files + + +def test_files_property(fake_files): + def file_source(*args, **kwargs): + return fake_files + + codeowner = CodeOwner(file_source, "USERNAME", "alice") + files = list(codeowner.files) + # file1 and file3 contain "alice" as one of their owners. + assert fake_files[0] in files + assert fake_files[2] in files + assert fake_files[1] not in files + + assert files == list(codeowner.files()) + + +def test_name_property_and_repr(): + def dummy_source(*args, **kwargs): + return [] + + codeowner = CodeOwner(dummy_source, "TEAM", "dev_team") + assert codeowner.name == "dev_team" + rep = repr(codeowner) + assert "TEAM" in rep and "dev_team" in rep + + +def test_iter_method(fake_files): + def file_source(*args, **kwargs): + return fake_files + + codeowner = CodeOwner(file_source, "USERNAME", "charlie") + iterated_files = list(codeowner) + assert iterated_files == [fake_files[1]] + + +def test_from_parser_method(fake_files): + # Create a fake parser with a paths attribute. + fake_parser = MagicMock() + fake_parser.paths = [ + ("pattern1", "ignored", [("USERNAME", "alice"), ("TEAM", "devs")], "ignored", "ignored"), + ("pattern2", "ignored", [("EMAIL", "bob@example.com")], "ignored", "ignored"), + ] + + def file_source(*args, **kwargs): + return fake_files + + codeowners = CodeOwner.from_parser(fake_parser, file_source) + assert len(codeowners) == 3 + owner_values = [co.owner_value for co in codeowners] + assert "alice" in owner_values + assert "devs" in owner_values + assert "bob@example.com" in owner_values diff --git a/tests/unit/codegen/sdk/core/test_directory.py b/tests/unit/codegen/sdk/core/test_directory.py new file mode 100644 index 000000000..1031730f2 --- /dev/null +++ b/tests/unit/codegen/sdk/core/test_directory.py @@ -0,0 +1,221 @@ +import os +import types +from pathlib import Path +from unittest.mock import MagicMock + +import pytest + +from codegen.sdk.codebase.codebase_graph import CodebaseGraph +from codegen.sdk.codebase.config import CodebaseConfig +from codegen.sdk.core.directory import Directory +from codegen.sdk.core.file import File + + +@pytest.fixture +def mock_codebase_graph(tmp_path): + mock = MagicMock(spec=CodebaseGraph) + mock.transaction_manager = MagicMock() + mock.config = CodebaseConfig() + mock.repo_path = tmp_path + mock.to_absolute = types.MethodType(CodebaseGraph.to_absolute, mock) + mock.to_relative = types.MethodType(CodebaseGraph.to_relative, mock) + return mock + + +@pytest.fixture +def subdir_path(tmp_path): + return tmp_path / "mock_dir" / "subdir" + + +@pytest.fixture +def dir_path(tmp_path): + return tmp_path / "mock_dir" + + +@pytest.fixture +def sub_dir(subdir_path, tmp_path): + return Directory(path=subdir_path.absolute(), dirpath=subdir_path.relative_to(tmp_path), parent=None) + + +@pytest.fixture +def mock_file(dir_path, mock_codebase_graph): + return File(filepath=dir_path / "example.py", G=mock_codebase_graph) + + +@pytest.fixture +def mock_directory(tmp_path, dir_path, sub_dir, mock_file): + directory = Directory(path=dir_path.absolute(), dirpath=dir_path.relative_to(tmp_path), parent=None) + directory.add_file(mock_file) + directory.add_subdirectory(sub_dir) + return directory + + +def test_directory_init(tmp_path, mock_directory): + """Test initialization of Directory object.""" + assert mock_directory.path == tmp_path / "mock_dir" + assert mock_directory.dirpath == Path("mock_dir") + assert mock_directory.parent is None + assert len(mock_directory.items) == 2 + assert mock_directory.items["subdir"] is not None + assert mock_directory.items["example.py"] is not None + + +def test_name_property(mock_directory): + """Test name property returns the basename of the dirpath.""" + assert mock_directory.name == "mock_dir" + + +def test_add_and_file(mock_directory, mock_codebase_graph): + """Test adding a file to the directory.""" + mock_file = File(filepath=Path("mock_dir/example_2.py"), G=mock_codebase_graph) + mock_directory.add_file(mock_file) + rel_path = os.path.relpath(mock_file.file_path, mock_directory.dirpath) + assert rel_path in mock_directory.items + assert mock_directory.items[rel_path] is mock_file + + +def test_remove_file(mock_directory, mock_file): + """Test removing a file from the directory.""" + mock_directory.remove_file(mock_file) + + rel_path = os.path.relpath(mock_file.file_path, mock_directory.dirpath) + assert rel_path not in mock_directory.items + + +def test_remove_file_by_path(mock_directory, mock_file): + """Test removing a file by path.""" + mock_directory.remove_file_by_path(mock_file.file_path) + + rel_path = os.path.relpath(mock_file.file_path, mock_directory.dirpath) + assert rel_path not in mock_directory.items + + +def test_get_file(mock_directory, mock_file): + """Test retrieving a file by name.""" + retrieved_file = mock_directory.get_file("example.py") + assert retrieved_file is mock_file + + # Case-insensitive match + retrieved_file_ci = mock_directory.get_file("EXAMPLE.PY", ignore_case=True) + assert retrieved_file_ci is mock_file + + +def test_get_file_not_found(mock_directory): + """Test retrieving a non-existing file returns None.""" + assert mock_directory.get_file("nonexistent.py") is None + + +def test_add_subdirectory(mock_directory, dir_path): + """Test adding a subdirectory.""" + new_subdir_path = dir_path / "new_subdir" + subdir = Directory(path=new_subdir_path.absolute(), dirpath=new_subdir_path.relative_to(dir_path), parent=mock_directory) + mock_directory.add_subdirectory(subdir) + rel_path = os.path.relpath(subdir.dirpath, mock_directory.dirpath) + assert rel_path in mock_directory.items + assert mock_directory.items[rel_path] is subdir + + +def test_remove_subdirectory(mock_directory, sub_dir): + """Test removing a subdirectory.""" + mock_directory.add_subdirectory(sub_dir) + mock_directory.remove_subdirectory(sub_dir) + + rel_path = os.path.relpath(sub_dir.dirpath, mock_directory.dirpath) + assert rel_path not in mock_directory.items + + +def test_remove_subdirectory_by_path(mock_directory, sub_dir): + """Test removing a subdirectory by path.""" + mock_directory.remove_subdirectory_by_path(sub_dir.dirpath) + + rel_path = os.path.relpath(sub_dir.dirpath, mock_directory.dirpath) + assert rel_path not in mock_directory.items + + +def test_get_subdirectory(mock_directory, sub_dir): + """Test retrieving a subdirectory by name.""" + retrieved_subdir = mock_directory.get_subdirectory("subdir") + assert retrieved_subdir is sub_dir + + +def test_files_property(mock_directory, sub_dir, mock_codebase_graph): + """Test the 'files' property returns all files recursively.""" + all_files = mock_directory.files + assert len(all_files) == 1 + + new_file = File(filepath=Path("mock_dir/example_2.py"), G=mock_codebase_graph) + sub_dir.add_file(new_file) + + all_files = mock_directory.files + assert len(all_files) == 2 + assert new_file in all_files + + gen = mock_directory.files_generator() + files_list = list(gen) + assert len(files_list) == 2 + assert new_file in files_list + + +def test_subdirectories_property(mock_directory, sub_dir): + """Test the 'subdirectories' property returns all directories recursively.""" + all_subdirs = mock_directory.subdirectories + assert len(all_subdirs) == 1 + assert sub_dir in all_subdirs + + new_sub_dir = Directory(path=sub_dir.path / "new_subdir", dirpath=sub_dir.dirpath / "new_subdir", parent=sub_dir) + sub_dir.add_subdirectory(new_sub_dir) + + all_subdirs = mock_directory.subdirectories + assert len(all_subdirs) == 2 + assert new_sub_dir in all_subdirs + + +def test_update_filepath(mock_directory, mock_codebase_graph, mock_file): + """Test updating file paths when the directory path changes.""" + mock_directory.update_filepath("/absolute/new_mock_dir") + + # Verify the files have updated file paths + mock_codebase_graph.transaction_manager.add_file_rename_transaction.assert_called_once_with(mock_file, "/absolute/new_mock_dir/example.py") + + +def test_remove(mock_directory, sub_dir, mock_codebase_graph, mock_file): + mock_directory.remove() + + mock_codebase_graph.transaction_manager.add_file_remove_transaction.assert_called_once_with(mock_file) + + +def test_rename(mock_directory, mock_codebase_graph, mock_file): + """Test renaming the directory.""" + mock_directory.rename("renamed_dir") + # This fails because it is not implemented to rename the directory itself. + # assert mock_directory.dirpath == "/absolute/renamed_dir" + mock_codebase_graph.transaction_manager.add_file_rename_transaction.assert_called_once_with(mock_file, "renamed_dir/example.py") + + +def test_iteration(mock_directory): + """Test iterating over the directory items.""" + items = list(mock_directory) # uses Directory.__iter__ + assert len(items) == 2 + assert mock_directory.items["subdir"] in items + assert mock_directory.items["example.py"] in items + + +def test_contains(mock_directory): + """Test the containment checks using the 'in' operator.""" + assert "subdir" in mock_directory + assert "example.py" in mock_directory + + +def test_len(mock_directory): + """Test the __len__ method returns the number of items.""" + assert len(mock_directory) == 2 + + +def test_get_set_delete_item(mock_directory): + """Test __getitem__, __setitem__, and __delitem__ methods.""" + mock_file = mock_directory.items["example.py"] + mock_directory["example.py"] = mock_file + assert mock_directory["example.py"] == mock_file + + with pytest.raises(KeyError, match="subdir_2"): + del mock_directory["subdir_2"] diff --git a/tests/unit/codegen/sdk/core/utils/test_cache_utils.py b/tests/unit/codegen/sdk/core/utils/test_cache_utils.py new file mode 100644 index 000000000..2075465f1 --- /dev/null +++ b/tests/unit/codegen/sdk/core/utils/test_cache_utils.py @@ -0,0 +1,21 @@ +from threading import Event + +from codegen.sdk.core.utils.cache_utils import cached_generator + + +def test_cached_generator(): + event = Event() + + @cached_generator() + def cached_function(): + assert not event.is_set() + event.set() + yield from range(10) + + # First call + result = cached_function() + assert list(result) == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] + + # Second call + result = cached_function() + assert list(result) == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]