diff --git a/pyproject.toml b/pyproject.toml index 1f1ddaf80..76e2367c3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -102,7 +102,12 @@ keywords = [ codegen = "codegen.cli.cli:main" [project.optional-dependencies] -types = ["types-networkx>=3.2.1.20240918", "types-tabulate>=0.9.0.20240106"] +types = [ + "types-networkx>=3.2.1.20240918", + "types-tabulate>=0.9.0.20240106", + "types-requests>=2.32.0.20241016", + "types-toml>=0.10.8.20240310", +] [tool.uv] cache-keys = [{ git = { commit = true, tags = true } }] dev-dependencies = [ @@ -199,6 +204,7 @@ tmp_path_retention_policy = "failed" requires = ["hatchling>=1.26.3", "hatch-vcs>=0.4.0", "setuptools-scm>=8.0.0"] build-backend = "hatchling.build" + [tool.deptry] extend_exclude = [".*/eval/test_files/.*.py", ".*conftest.py"] pep621_dev_dependency_groups = ["types"] diff --git a/src/codegen/py.typed b/src/codegen/py.typed new file mode 100644 index 000000000..e69de29bb diff --git a/src/codegen/sdk/codebase/flagging/code_flag.py b/src/codegen/sdk/codebase/flagging/code_flag.py index cb10a0057..1b1a92fc5 100644 --- a/src/codegen/sdk/codebase/flagging/code_flag.py +++ b/src/codegen/sdk/codebase/flagging/code_flag.py @@ -1,14 +1,14 @@ from dataclasses import dataclass -from typing import TYPE_CHECKING +from typing import Generic, TypeVar from codegen.sdk.codebase.flagging.enums import MessageType +from codegen.sdk.core.interfaces.editable import Editable -if TYPE_CHECKING: - from codegen.sdk.core.interfaces.editable import Editable +Symbol = TypeVar("Symbol", bound=Editable | None) @dataclass -class CodeFlag[Symbol: Editable | None]: +class CodeFlag(Generic[Symbol]): symbol: Symbol message: str | None = None # a short desc of the code flag/violation. ex: enums should be ordered alphabetically message_type: MessageType = MessageType.GITHUB | MessageType.CODEGEN # where to send the message (either Github or Slack) diff --git a/src/codegen/sdk/codebase/flagging/flags.py b/src/codegen/sdk/codebase/flagging/flags.py index 13288e40c..636d5145a 100644 --- a/src/codegen/sdk/codebase/flagging/flags.py +++ b/src/codegen/sdk/codebase/flagging/flags.py @@ -1,4 +1,5 @@ from dataclasses import dataclass, field +from typing import TypeVar from codegen.sdk.codebase.flagging.code_flag import CodeFlag from codegen.sdk.codebase.flagging.enums import MessageType @@ -6,6 +7,8 @@ from codegen.sdk.core.interfaces.editable import Editable from codegen.shared.decorators.docs import noapidoc +Symbol = TypeVar("Symbol", bound=Editable) + @dataclass class Flags: @@ -13,9 +16,9 @@ class Flags: _find_mode: bool = False _active_group: list[CodeFlag] | None = None - def flag_instance[Symbol: Editable | None]( + def flag_instance( self, - symbol: Symbol = None, + symbol: Symbol | None = None, message: str | None = None, message_type: MessageType = MessageType.GITHUB | MessageType.CODEGEN, message_recipient: str | None = None, diff --git a/src/codegen/sdk/codebase/multigraph.py b/src/codegen/sdk/codebase/multigraph.py index 735fc01d1..2a76fec70 100644 --- a/src/codegen/sdk/codebase/multigraph.py +++ b/src/codegen/sdk/codebase/multigraph.py @@ -1,5 +1,6 @@ from collections import defaultdict from dataclasses import dataclass, field +from typing import Generic, TypeVar from codegen.sdk import TYPE_CHECKING from codegen.sdk.core.detached_symbols.function_call import FunctionCall @@ -7,9 +8,11 @@ if TYPE_CHECKING: from codegen.sdk.core.function import Function +TFunction = TypeVar("TFunction", bound=Function) + @dataclass -class MultiGraph[TFunction: Function]: +class MultiGraph(Generic[TFunction]): """Mapping of API endpoints to their definitions and usages across languages.""" api_definitions: dict[str, TFunction] = field(default_factory=dict) diff --git a/src/codegen/sdk/core/file.py b/src/codegen/sdk/core/file.py index 140b9436b..aaacb5fcd 100644 --- a/src/codegen/sdk/core/file.py +++ b/src/codegen/sdk/core/file.py @@ -587,7 +587,7 @@ def invalidate(self): @classmethod @noapidoc - def from_content(cls, filepath: str, content: str, G: CodebaseGraph, sync: bool = True, verify_syntax: bool = True) -> Self | None: + def from_content(cls, filepath: str | PathLike | Path, content: str, G: CodebaseGraph, sync: bool = True, verify_syntax: bool = True) -> Self | None: """Creates a new file from content and adds it to the graph.""" path = G.to_absolute(filepath) ts_node = parse_file(path, content) @@ -605,7 +605,7 @@ def from_content(cls, filepath: str, content: str, G: CodebaseGraph, sync: bool G.add_single_file(path) return G.get_file(filepath) else: - return cls(ts_node, filepath, G) + return cls(ts_node, Path(filepath), G) @classmethod @noapidoc diff --git a/src/codegen/sdk/core/interfaces/editable.py b/src/codegen/sdk/core/interfaces/editable.py index 30d8c0085..0fc5343ef 100644 --- a/src/codegen/sdk/core/interfaces/editable.py +++ b/src/codegen/sdk/core/interfaces/editable.py @@ -22,7 +22,7 @@ from codegen.shared.decorators.docs import apidoc, noapidoc if TYPE_CHECKING: - from collections.abc import Callable, Generator, Iterable + from collections.abc import Callable, Generator, Iterable, Sequence import rich.repr from rich.console import Console, ConsoleOptions, RenderResult @@ -157,7 +157,7 @@ def __repr__(self) -> str: def __rich_repr__(self) -> rich.repr.Result: yield escape(self.filepath) - __rich_repr__.angular = ANGULAR_STYLE + __rich_repr__.angular = ANGULAR_STYLE # type: ignore def __rich_console__(self, console: Console, options: ConsoleOptions) -> RenderResult: yield Pretty(self, max_string=MAX_STRING_LENGTH) @@ -315,14 +315,14 @@ def extended_source(self, value: str) -> None: @property @reader @noapidoc - def children(self) -> list[Editable]: + def children(self) -> list[Editable[Self]]: """List of Editable instances that are children of this node.""" return [self._parse_expression(child) for child in self.ts_node.named_children] @property @reader @noapidoc - def _anonymous_children(self) -> list[Editable]: + def _anonymous_children(self) -> list[Editable[Self]]: """All anonymous children of an editable.""" return [self._parse_expression(child) for child in self.ts_node.children if not child.is_named] @@ -343,7 +343,7 @@ def next_sibling(self) -> Editable | None: @property @reader @noapidoc - def next_named_sibling(self) -> Editable | None: + def next_named_sibling(self) -> Editable[Parent] | None: if self.ts_node is None: return None @@ -351,12 +351,12 @@ def next_named_sibling(self) -> Editable | None: if next_named_sibling_node is None: return None - return self._parse_expression(next_named_sibling_node) + return self.parent._parse_expression(next_named_sibling_node) @property @reader @noapidoc - def previous_named_sibling(self) -> Editable | None: + def previous_named_sibling(self) -> Editable[Parent] | None: if self.ts_node is None: return None @@ -364,7 +364,7 @@ def previous_named_sibling(self) -> Editable | None: if previous_named_sibling_node is None: return None - return self._parse_expression(previous_named_sibling_node) + return self.parent._parse_expression(previous_named_sibling_node) @property def file(self) -> SourceFile: @@ -377,7 +377,7 @@ def file(self) -> SourceFile: """ if self._file is None: self._file = self.G.get_node(self.file_node_id) - return self._file + return self._file # type: ignore @property def filepath(self) -> str: @@ -391,7 +391,7 @@ def filepath(self) -> str: return self.file.file_path @reader - def find_string_literals(self, strings_to_match: list[str], fuzzy_match: bool = False) -> list[Editable]: + def find_string_literals(self, strings_to_match: list[str], fuzzy_match: bool = False) -> list[Editable[Self]]: """Returns a list of string literals within this node's source that match any of the given strings. @@ -400,19 +400,20 @@ def find_string_literals(self, strings_to_match: list[str], fuzzy_match: bool = fuzzy_match (bool): If True, matches substrings within string literals. If False, only matches exact strings. Defaults to False. Returns: - list[Editable]: A list of Editable objects representing the matching string literals. + list[Editable[Self]]: A list of Editable objects representing the matching string literals. """ - matches = [] + matches: list[Editable[Self]] = [] for node in self.extended_nodes: matches.extend(node._find_string_literals(strings_to_match, fuzzy_match)) return matches @noapidoc @reader - def _find_string_literals(self, strings_to_match: list[str], fuzzy_match: bool = False) -> list[Editable]: + def _find_string_literals(self, strings_to_match: list[str], fuzzy_match: bool = False) -> Sequence[Editable[Self]]: all_string_nodes = find_all_descendants(self.ts_node, type_names={"string"}) editables = [] for string_node in all_string_nodes: + assert string_node.text is not None full_string = string_node.text.strip(b'"').strip(b"'") if fuzzy_match: if not any([str_to_match.encode("utf-8") in full_string for str_to_match in strings_to_match]): @@ -461,7 +462,7 @@ def _replace(self, old: str, new: str, count: int = -1, is_regex: bool = False, if not is_regex: old = re.escape(old) - for match in re.finditer(old.encode("utf-8"), self.ts_node.text): + for match in re.finditer(old.encode("utf-8"), self.ts_node.text): # type: ignore start_byte = self.ts_node.start_byte + match.start() end_byte = self.ts_node.start_byte + match.end() t = EditTransaction( @@ -538,7 +539,7 @@ def _search(self, regex_pattern: str, include_strings: bool = True, include_comm pattern = re.compile(regex_pattern.encode("utf-8")) start_byte_offset = self.ts_node.byte_range[0] - for match in pattern.finditer(string): + for match in pattern.finditer(string): # type: ignore matching_byte_ranges.append((match.start() + start_byte_offset, match.end() + start_byte_offset)) matches: list[Editable] = [] @@ -738,7 +739,7 @@ def should_keep(node: TSNode): # Delete the node t = RemoveTransaction(removed_start_byte, removed_end_byte, self.file, priority=priority, exec_func=exec_func) if self.transaction_manager.add_transaction(t, dedupe=dedupe): - if exec_func: + if exec_func is not None: self.parent._removed_child() # If there are sibling nodes, delete the surrounding whitespace & formatting (commas) @@ -873,11 +874,13 @@ def variable_usages(self) -> list[Editable]: Editable corresponds to a TreeSitter node instance where the variable is referenced. """ - usages = [] + usages: Sequence[Editable[Self]] = [] identifiers = get_all_identifiers(self.ts_node) for identifier in identifiers: # Excludes function names parent = identifier.parent + if parent is None: + continue if parent.type in ["call", "call_expression"]: continue # Excludes local import statements @@ -899,7 +902,7 @@ def variable_usages(self) -> list[Editable]: return usages @reader - def get_variable_usages(self, var_name: str, fuzzy_match: bool = False) -> list[Editable]: + def get_variable_usages(self, var_name: str, fuzzy_match: bool = False) -> Sequence[Editable[Self]]: """Returns Editables for all TreeSitter nodes corresponding to instances of variable usage that matches the given variable name. @@ -917,6 +920,12 @@ def get_variable_usages(self, var_name: str, fuzzy_match: bool = False) -> list[ else: return [usage for usage in self.variable_usages if var_name == usage.source] + @overload + def _parse_expression(self, node: TSNode, **kwargs) -> Expression[Self]: ... + + @overload + def _parse_expression(self, node: TSNode | None, **kwargs) -> Expression[Self] | None: ... + def _parse_expression(self, node: TSNode | None, **kwargs) -> Expression[Self] | None: return self.G.parser.parse_expression(node, self.file_node_id, self.G, self, **kwargs) diff --git a/src/codegen/sdk/types.py b/src/codegen/sdk/types.py index 496df934d..7f070aa0d 100644 --- a/src/codegen/sdk/types.py +++ b/src/codegen/sdk/types.py @@ -1 +1,3 @@ -type JSON = dict[str, JSON] | list[JSON] | str | int | float | bool | None +from typing import TypeAlias + +JSON: TypeAlias = dict[str, "JSON"] | list["JSON"] | str | int | float | bool | None diff --git a/src/codegen/sdk/typescript/statements/switch_case.py b/src/codegen/sdk/typescript/statements/switch_case.py index 01a49d72c..1e93fdc67 100644 --- a/src/codegen/sdk/typescript/statements/switch_case.py +++ b/src/codegen/sdk/typescript/statements/switch_case.py @@ -10,7 +10,7 @@ if TYPE_CHECKING: from codegen.sdk.codebase.codebase_graph import CodebaseGraph - from src.codegen.sdk.typescript.statements.switch_statement import TSSwitchStatement + from codegen.sdk.typescript.statements.switch_statement import TSSwitchStatement @ts_apidoc diff --git a/tests/integration/codemod/conftest.py b/tests/integration/codemod/conftest.py index 20fd0b362..e6bb60271 100644 --- a/tests/integration/codemod/conftest.py +++ b/tests/integration/codemod/conftest.py @@ -2,6 +2,7 @@ import shutil from collections.abc import Generator from pathlib import Path +from typing import TYPE_CHECKING from unittest.mock import MagicMock import filelock @@ -13,12 +14,14 @@ from codegen.git.repo_operator.repo_operator import RepoOperator from codegen.sdk.codebase.config import CodebaseConfig, GSFeatureFlags, ProjectConfig from codegen.sdk.core.codebase import Codebase -from codemods.codemod import Codemod from tests.shared.codemod.constants import DIFF_FILEPATH from tests.shared.codemod.models import BASE_PATH, BASE_TMP_DIR, VERIFIED_CODEMOD_DIFFS, CodemodMetadata, Repo, Size from tests.shared.codemod.test_discovery import find_codemod_test_cases, find_repos, find_verified_codemod_cases from tests.shared.utils.recursion import set_recursion_limit +if TYPE_CHECKING: + from codemods.codemod import Codemod + logger = logging.getLogger(__name__) ONLY_STORE_CHANGED_DIFFS = True @@ -201,7 +204,7 @@ def codemod(raw_codemod: type["Codemod"]): @pytest.fixture -def verified_codemod(codemod_metadata: CodemodMetadata, expected: Path) -> YieldFixture[Codemod]: +def verified_codemod(codemod_metadata: CodemodMetadata, expected: Path) -> YieldFixture["Codemod"]: # write the diff to the file diff_path = expected diff_path.parent.mkdir(parents=True, exist_ok=True) diff --git a/uv.lock b/uv.lock index bc9495866..907483b28 100644 --- a/uv.lock +++ b/uv.lock @@ -401,7 +401,9 @@ dependencies = [ [package.optional-dependencies] types = [ { name = "types-networkx" }, + { name = "types-requests" }, { name = "types-tabulate" }, + { name = "types-toml" }, ] [package.dev-dependencies] @@ -489,7 +491,9 @@ requires-dist = [ { name = "tree-sitter-python", specifier = ">=0.23.4" }, { name = "tree-sitter-typescript", specifier = ">=0.23.2" }, { name = "types-networkx", marker = "extra == 'types'", specifier = ">=3.2.1.20240918" }, + { name = "types-requests", marker = "extra == 'types'", specifier = ">=2.32.0.20241016" }, { name = "types-tabulate", marker = "extra == 'types'", specifier = ">=0.9.0.20240106" }, + { name = "types-toml", marker = "extra == 'types'", specifier = ">=0.10.8.20240310" }, { name = "typing-extensions", specifier = ">=4.12.2" }, { name = "unidiff", specifier = ">=0.7.5" }, { name = "uvicorn", extras = ["standard"], specifier = ">=0.30.0" }, @@ -2575,6 +2579,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/31/c1/d73ff5900c6b462879039ac92f89424ad1eb544b1f6bd77f12f9c3013e20/types_networkx-3.4.2.20241227-py3-none-any.whl", hash = "sha256:adb0e3f0a16c1481a2cfa97772a0b925b220dcf857f0def1c5ab4c4f349e309d", size = 130194 }, ] +[[package]] +name = "types-requests" +version = "2.32.0.20241016" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "urllib3" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/fa/3c/4f2a430c01a22abd49a583b6b944173e39e7d01b688190a5618bd59a2e22/types-requests-2.32.0.20241016.tar.gz", hash = "sha256:0d9cad2f27515d0e3e3da7134a1b6f28fb97129d86b867f24d9c726452634d95", size = 18065 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d7/01/485b3026ff90e5190b5e24f1711522e06c79f4a56c8f4b95848ac072e20f/types_requests-2.32.0.20241016-py3-none-any.whl", hash = "sha256:4195d62d6d3e043a4eaaf08ff8a62184584d2e8684e9d2aa178c7915a7da3747", size = 15836 }, +] + [[package]] name = "types-setuptools" version = "75.8.0.20250110" @@ -2593,6 +2609,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/5e/86/a9ebfd509cbe74471106dffed320e208c72537f9aeb0a55eaa6b1b5e4d17/types_tabulate-0.9.0.20241207-py3-none-any.whl", hash = "sha256:b8dad1343c2a8ba5861c5441370c3e35908edd234ff036d4298708a1d4cf8a85", size = 8307 }, ] +[[package]] +name = "types-toml" +version = "0.10.8.20240310" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/86/47/3e4c75042792bff8e90d7991aa5c51812cc668828cc6cce711e97f63a607/types-toml-0.10.8.20240310.tar.gz", hash = "sha256:3d41501302972436a6b8b239c850b26689657e25281b48ff0ec06345b8830331", size = 4392 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/da/a2/d32ab58c0b216912638b140ab2170ee4b8644067c293b170e19fba340ccc/types_toml-0.10.8.20240310-py3-none-any.whl", hash = "sha256:627b47775d25fa29977d9c70dc0cbab3f314f32c8d8d0c012f2ef5de7aaec05d", size = 4777 }, +] + [[package]] name = "typing-extensions" version = "4.12.2"