diff --git a/pyproject.toml b/pyproject.toml index 08b79e0b1..bf4da2bb1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -115,6 +115,7 @@ types = [ "types-requests>=2.32.0.20241016", "types-toml>=0.10.8.20240310", ] +lsp = ["pygls>=2.0.0a2", "lsprotocol==2024.0.0b1"] [tool.uv] cache-keys = [{ git = { commit = true, tags = true } }] dev-dependencies = [ @@ -149,11 +150,12 @@ dev-dependencies = [ "isort>=5.13.2", "emoji>=2.14.0", "pytest-benchmark[histogram]>=5.1.0", - "pytest-asyncio<1.0.0,>=0.21.1", + "pytest-asyncio>=0.21.1,<1.0.0", "loguru>=0.7.3", "httpx<0.28.2,>=0.28.1", "jupyterlab>=4.3.5", "modal>=0.73.25", + "pytest-lsp>=1.0.0b1", ] @@ -212,6 +214,8 @@ xfail_strict = true junit_duration_report = "call" junit_logging = "all" tmp_path_retention_policy = "failed" +asyncio_mode = "auto" +asyncio_default_fixture_loop_scope = "function" [build-system] requires = ["hatchling>=1.26.3", "hatch-vcs>=0.4.0", "setuptools-scm>=8.0.0"] build-backend = "hatchling.build" diff --git a/src/codegen/extensions/lsp/completion.py b/src/codegen/extensions/lsp/completion.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/codegen/extensions/lsp/definition.py b/src/codegen/extensions/lsp/definition.py new file mode 100644 index 000000000..318587c0d --- /dev/null +++ b/src/codegen/extensions/lsp/definition.py @@ -0,0 +1,36 @@ +import logging + +from lsprotocol.types import Position + +from codegen.sdk.core.assignment import Assignment +from codegen.sdk.core.detached_symbols.function_call import FunctionCall +from codegen.sdk.core.expressions.chained_attribute import ChainedAttribute +from codegen.sdk.core.expressions.expression import Expression +from codegen.sdk.core.expressions.name import Name +from codegen.sdk.core.interfaces.editable import Editable +from codegen.sdk.core.interfaces.has_name import HasName + +logger = logging.getLogger(__name__) + + +def go_to_definition(node: Editable | None, uri: str, position: Position) -> Editable | None: + if node is None or not isinstance(node, (Expression)): + logger.warning(f"No node found at {uri}:{position}") + return None + if isinstance(node, Name) and isinstance(node.parent, ChainedAttribute) and node.parent.attribute == node: + node = node.parent + if isinstance(node.parent, FunctionCall) and node.parent.get_name() == node: + node = node.parent + logger.info(f"Resolving definition for {node}") + if isinstance(node, FunctionCall): + resolved = node.function_definition + else: + resolved = node.resolved_value + if resolved is None: + logger.warning(f"No resolved value found for {node.name} at {uri}:{position}") + return None + if isinstance(resolved, (HasName,)): + resolved = resolved.get_name() + if isinstance(resolved.parent, Assignment) and resolved.parent.value == resolved: + resolved = resolved.parent.get_name() + return resolved diff --git a/src/codegen/extensions/lsp/document_symbol.py b/src/codegen/extensions/lsp/document_symbol.py new file mode 100644 index 000000000..01000755a --- /dev/null +++ b/src/codegen/extensions/lsp/document_symbol.py @@ -0,0 +1,26 @@ +from lsprotocol.types import DocumentSymbol + +from codegen.extensions.lsp.kind import get_kind +from codegen.extensions.lsp.range import get_range +from codegen.sdk.core.class_definition import Class +from codegen.sdk.core.interfaces.editable import Editable +from codegen.sdk.extensions.sort import sort_editables + + +def get_document_symbol(node: Editable) -> DocumentSymbol: + children = [] + nodes = [] + if isinstance(node, Class): + nodes.extend(node.methods) + nodes.extend(node.attributes) + nodes.extend(node.nested_classes) + nodes = sort_editables(nodes) + for child in nodes: + children.append(get_document_symbol(child)) + return DocumentSymbol( + name=node.name, + kind=get_kind(node), + range=get_range(node), + selection_range=get_range(node.get_name()), + children=children, + ) diff --git a/src/codegen/extensions/lsp/io.py b/src/codegen/extensions/lsp/io.py new file mode 100644 index 000000000..cc6152e24 --- /dev/null +++ b/src/codegen/extensions/lsp/io.py @@ -0,0 +1,71 @@ +import logging +from pathlib import Path + +from lsprotocol import types +from lsprotocol.types import Position, Range, TextEdit +from pygls.workspace import TextDocument, Workspace + +from codegen.sdk.codebase.io.file_io import FileIO +from codegen.sdk.codebase.io.io import IO + +logger = logging.getLogger(__name__) + + +class LSPIO(IO): + base_io: FileIO + workspace: Workspace + changes: dict[str, TextEdit] = {} + + def __init__(self, workspace: Workspace): + self.workspace = workspace + self.base_io = FileIO() + + def _get_doc(self, path: Path) -> TextDocument | None: + uri = path.as_uri() + logger.info(f"Getting document for {uri}") + return self.workspace.get_text_document(uri) + + def read_bytes(self, path: Path) -> bytes: + if self.changes.get(path.as_uri()): + return self.changes[path.as_uri()].new_text.encode("utf-8") + if doc := self._get_doc(path): + return doc.source.encode("utf-8") + return self.base_io.read_bytes(path) + + def write_bytes(self, path: Path, content: bytes) -> None: + logger.info(f"Writing bytes to {path}") + start = Position(line=0, character=0) + if doc := self._get_doc(path): + end = Position(line=len(doc.source), character=len(doc.source)) + else: + end = Position(line=0, character=0) + self.changes[path.as_uri()] = TextEdit(range=Range(start=start, end=end), new_text=content.decode("utf-8")) + + def save_files(self, files: set[Path] | None = None) -> None: + self.base_io.save_files(files) + + def check_changes(self) -> None: + self.base_io.check_changes() + + def delete_file(self, path: Path) -> None: + self.base_io.delete_file(path) + + def file_exists(self, path: Path) -> bool: + if doc := self._get_doc(path): + try: + doc.source + except FileNotFoundError: + return False + return True + return self.base_io.file_exists(path) + + def untrack_file(self, path: Path) -> None: + self.base_io.untrack_file(path) + + def get_document_changes(self) -> list[types.TextDocumentEdit]: + ret = [] + for uri, change in self.changes.items(): + id = types.OptionalVersionedTextDocumentIdentifier(uri=uri) + ret.append(types.TextDocumentEdit(text_document=id, edits=[change])) + self.changes = {} + return ret diff --git a/src/codegen/extensions/lsp/kind.py b/src/codegen/extensions/lsp/kind.py new file mode 100644 index 000000000..609885164 --- /dev/null +++ b/src/codegen/extensions/lsp/kind.py @@ -0,0 +1,31 @@ +from lsprotocol.types import SymbolKind + +from codegen.sdk.core.assignment import Assignment +from codegen.sdk.core.class_definition import Class +from codegen.sdk.core.file import File +from codegen.sdk.core.function import Function +from codegen.sdk.core.interface import Interface +from codegen.sdk.core.interfaces.editable import Editable +from codegen.sdk.core.statements.attribute import Attribute +from codegen.sdk.typescript.namespace import TSNamespace + +kinds = { + File: SymbolKind.File, + Class: SymbolKind.Class, + Function: SymbolKind.Function, + Assignment: SymbolKind.Variable, + Interface: SymbolKind.Interface, + TSNamespace: SymbolKind.Namespace, + Attribute: SymbolKind.Variable, +} + + +def get_kind(node: Editable) -> SymbolKind: + if isinstance(node, Function): + if node.is_method: + return SymbolKind.Method + for kind in kinds: + if isinstance(node, kind): + return kinds[kind] + msg = f"No kind found for {node}, {type(node)}" + raise ValueError(msg) diff --git a/src/codegen/extensions/lsp/lsp.py b/src/codegen/extensions/lsp/lsp.py new file mode 100644 index 000000000..2a7cd7aec --- /dev/null +++ b/src/codegen/extensions/lsp/lsp.py @@ -0,0 +1,109 @@ +import logging + +from lsprotocol import types + +import codegen +from codegen.extensions.lsp.definition import go_to_definition +from codegen.extensions.lsp.document_symbol import get_document_symbol +from codegen.extensions.lsp.protocol import CodegenLanguageServerProtocol +from codegen.extensions.lsp.range import get_range +from codegen.extensions.lsp.server import CodegenLanguageServer +from codegen.extensions.lsp.utils import get_path +from codegen.sdk.codebase.diff_lite import ChangeType, DiffLite +from codegen.sdk.core.file import SourceFile + +version = getattr(codegen, "__version__", "v0.1") +server = CodegenLanguageServer("codegen", version, protocol_cls=CodegenLanguageServerProtocol) +logger = logging.getLogger(__name__) + + +@server.feature(types.TEXT_DOCUMENT_DID_OPEN) +def did_open(server: CodegenLanguageServer, params: types.DidOpenTextDocumentParams) -> None: + """Handle document open notification.""" + logger.info(f"Document opened: {params.text_document.uri}") + # The document is automatically added to the workspace by pygls + # We can perform any additional processing here if needed + path = get_path(params.text_document.uri) + file = server.codebase.get_file(str(path), optional=True) + if not isinstance(file, SourceFile) and path.suffix in server.codebase.ctx.extensions: + sync = DiffLite(change_type=ChangeType.Added, path=path) + server.codebase.ctx.apply_diffs([sync]) + + +@server.feature(types.TEXT_DOCUMENT_DID_CHANGE) +def did_change(server: CodegenLanguageServer, params: types.DidChangeTextDocumentParams) -> None: + """Handle document change notification.""" + logger.info(f"Document changed: {params.text_document.uri}") + # The document is automatically updated in the workspace by pygls + # We can perform any additional processing here if needed + path = get_path(params.text_document.uri) + sync = DiffLite(change_type=ChangeType.Modified, path=path) + server.codebase.ctx.apply_diffs([sync]) + + +@server.feature(types.WORKSPACE_TEXT_DOCUMENT_CONTENT) +def workspace_text_document_content(server: CodegenLanguageServer, params: types.TextDocumentContentParams) -> types.TextDocumentContentResult: + """Handle workspace text document content notification.""" + logger.debug(f"Workspace text document content: {params.uri}") + path = get_path(params.uri) + if not server.io.file_exists(path): + logger.warning(f"File does not exist: {path}") + return types.TextDocumentContentResult( + text="", + ) + content = server.io.read_text(path) + return types.TextDocumentContentResult( + text=content, + ) + + +@server.feature(types.TEXT_DOCUMENT_DID_CLOSE) +def did_close(server: CodegenLanguageServer, params: types.DidCloseTextDocumentParams) -> None: + """Handle document close notification.""" + logger.info(f"Document closed: {params.text_document.uri}") + # The document is automatically removed from the workspace by pygls + # We can perform any additional cleanup here if needed + + +@server.feature( + types.TEXT_DOCUMENT_RENAME, +) +def rename(server: CodegenLanguageServer, params: types.RenameParams) -> types.RenameResult: + symbol = server.get_symbol(params.text_document.uri, params.position) + if symbol is None: + logger.warning(f"No symbol found at {params.text_document.uri}:{params.position}") + return + logger.info(f"Renaming symbol {symbol.name} to {params.new_name}") + symbol.rename(params.new_name) + server.codebase.commit() + return types.WorkspaceEdit( + document_changes=server.io.get_document_changes(), + ) + + +@server.feature( + types.TEXT_DOCUMENT_DOCUMENT_SYMBOL, +) +def document_symbol(server: CodegenLanguageServer, params: types.DocumentSymbolParams) -> types.DocumentSymbolResult: + file = server.get_file(params.text_document.uri) + symbols = [] + for symbol in file.symbols: + symbols.append(get_document_symbol(symbol)) + return symbols + + +@server.feature( + types.TEXT_DOCUMENT_DEFINITION, +) +def definition(server: CodegenLanguageServer, params: types.DefinitionParams): + node = server.get_node_under_cursor(params.text_document.uri, params.position) + resolved = go_to_definition(node, params.text_document.uri, params.position) + return types.Location( + uri=resolved.file.path.as_uri(), + range=get_range(resolved), + ) + + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO) + server.start_io() diff --git a/src/codegen/extensions/lsp/protocol.py b/src/codegen/extensions/lsp/protocol.py new file mode 100644 index 000000000..78e9bd945 --- /dev/null +++ b/src/codegen/extensions/lsp/protocol.py @@ -0,0 +1,51 @@ +import os +import threading +from pathlib import Path +from typing import TYPE_CHECKING + +from lsprotocol.types import INITIALIZE, INITIALIZED, InitializedParams, InitializeParams, InitializeResult +from pygls.protocol import LanguageServerProtocol, lsp_method + +from codegen.extensions.lsp.io import LSPIO +from codegen.extensions.lsp.utils import get_path +from codegen.sdk.codebase.config import CodebaseConfig, GSFeatureFlags +from codegen.sdk.core.codebase import Codebase + +if TYPE_CHECKING: + from codegen.extensions.lsp.server import CodegenLanguageServer + + +class CodegenLanguageServerProtocol(LanguageServerProtocol): + _server: "CodegenLanguageServer" + + def _init_codebase(self, params: InitializeParams) -> None: + if params.root_path: + root = Path(params.root_path) + elif params.root_uri: + root = get_path(params.root_uri) + else: + root = os.getcwd() + config = CodebaseConfig(feature_flags=GSFeatureFlags(full_range_index=True)) + io = LSPIO(self.workspace) + self._server.codebase = Codebase(repo_path=str(root), config=config, io=io) + self._server.io = io + + @lsp_method(INITIALIZE) + def lsp_initialize(self, params: InitializeParams) -> InitializeResult: + if params.root_path: + root = Path(params.root_path) + elif params.root_uri: + root = get_path(params.root_uri) + else: + root = os.getcwd() + config = CodebaseConfig(feature_flags=GSFeatureFlags(full_range_index=True)) + ret = super().lsp_initialize(params) + + self._worker = threading.Thread(target=self._init_codebase, args=(params,)) + self._worker.start() + return ret + + @lsp_method(INITIALIZED) + def lsp_initialized(self, params: InitializedParams) -> None: + self._worker.join() + super().lsp_initialized(params) diff --git a/src/codegen/extensions/lsp/range.py b/src/codegen/extensions/lsp/range.py new file mode 100644 index 000000000..9762e9d00 --- /dev/null +++ b/src/codegen/extensions/lsp/range.py @@ -0,0 +1,32 @@ +import tree_sitter +from lsprotocol.types import Position, Range +from pygls.workspace import TextDocument + +from codegen.sdk.core.interfaces.editable import Editable + + +def get_range(node: Editable) -> Range: + start_point = node.start_point + end_point = node.end_point + for extended_node in node.extended_nodes: + if extended_node.start_point.row < start_point.row: + start_point = extended_node.start_point + if extended_node.end_point.row > end_point.row: + end_point = extended_node.end_point + return Range( + start=Position(line=start_point.row, character=start_point.column), + end=Position(line=end_point.row, character=end_point.column), + ) + + +def get_tree_sitter_range(range: Range, document: TextDocument) -> tree_sitter.Range: + start_pos = tree_sitter.Point(row=range.start.line, column=range.start.character) + end_pos = tree_sitter.Point(row=range.end.line, column=range.end.character) + start_byte = document.offset_at_position(range.start) + end_byte = document.offset_at_position(range.end) + return tree_sitter.Range( + start_point=start_pos, + end_point=end_pos, + start_byte=start_byte, + end_byte=end_byte, + ) diff --git a/src/codegen/extensions/lsp/server.py b/src/codegen/extensions/lsp/server.py new file mode 100644 index 000000000..bec090f89 --- /dev/null +++ b/src/codegen/extensions/lsp/server.py @@ -0,0 +1,55 @@ +import logging +from typing import Any, Optional + +from lsprotocol.types import Position, Range +from pygls.lsp.server import LanguageServer + +from codegen.extensions.lsp.io import LSPIO +from codegen.extensions.lsp.range import get_tree_sitter_range +from codegen.extensions.lsp.utils import get_path +from codegen.sdk.codebase.flagging.code_flag import Symbol +from codegen.sdk.core.codebase import Codebase +from codegen.sdk.core.file import File, SourceFile +from codegen.sdk.core.interfaces.editable import Editable + +logger = logging.getLogger(__name__) + + +class CodegenLanguageServer(LanguageServer): + codebase: Optional[Codebase] + io: Optional[LSPIO] + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + + def get_file(self, uri: str) -> SourceFile | File: + path = get_path(uri) + return self.codebase.get_file(str(path)) + + def get_symbol(self, uri: str, position: Position) -> Symbol | None: + node = self.get_node_under_cursor(uri, position) + if node is None: + return None + return node.parent_symbol + + def get_node_under_cursor(self, uri: str, position: Position) -> Editable | None: + file = self.get_file(uri) + resolved_uri = file.path.absolute().as_uri() + logger.info(f"Getting node under cursor for {resolved_uri} at {position}") + document = self.workspace.get_text_document(resolved_uri) + candidates = [] + target_byte = document.offset_at_position(position) + for node in file._range_index.nodes: + if node.start_byte <= target_byte and node.end_byte >= target_byte: + candidates.append(node) + if not candidates: + return None + return min(candidates, key=lambda node: abs(node.end_byte - node.start_byte)) + + def get_node_for_range(self, uri: str, range: Range) -> Editable | None: + file = self.get_file(uri) + document = self.workspace.get_text_document(uri) + ts_range = get_tree_sitter_range(range, document) + for node in file._range_index.get_all_for_range(ts_range): + return node + return None diff --git a/src/codegen/extensions/lsp/utils.py b/src/codegen/extensions/lsp/utils.py new file mode 100644 index 000000000..3dce5f751 --- /dev/null +++ b/src/codegen/extensions/lsp/utils.py @@ -0,0 +1,7 @@ +from pathlib import Path + +from pygls.uris import to_fs_path + + +def get_path(uri: str) -> Path: + return Path(to_fs_path(uri)).absolute() diff --git a/src/codegen/sdk/codebase/codebase_context.py b/src/codegen/sdk/codebase/codebase_context.py index 6214c4227..806225f96 100644 --- a/src/codegen/sdk/codebase/codebase_context.py +++ b/src/codegen/sdk/codebase/codebase_context.py @@ -115,6 +115,7 @@ def __init__( self, projects: list[ProjectConfig], config: CodebaseConfig = DefaultConfig, + io: IO | None = None, ) -> None: """Initializes codebase graph and TransactionManager""" from codegen.sdk.core.parser import Parser @@ -134,7 +135,7 @@ def __init__( # =====[ __init__ attributes ]===== self.projects = projects - self.io = FileIO() + self.io = io or FileIO() context = projects[0] self.node_classes = get_node_classes(context.programming_language) self.config = config diff --git a/src/codegen/sdk/codebase/io/io.py b/src/codegen/sdk/codebase/io/io.py index 3321f072b..710474aab 100644 --- a/src/codegen/sdk/codebase/io/io.py +++ b/src/codegen/sdk/codebase/io/io.py @@ -18,10 +18,6 @@ def write_file(self, path: Path, content: str | bytes | None) -> None: def write_text(self, path: Path, content: str) -> None: self.write_bytes(path, content.encode("utf-8")) - @abstractmethod - def untrack_file(self, path: Path) -> None: - pass - @abstractmethod def write_bytes(self, path: Path, content: bytes) -> None: pass diff --git a/src/codegen/sdk/core/codebase.py b/src/codegen/sdk/core/codebase.py index ce7e33d0f..d3b572043 100644 --- a/src/codegen/sdk/core/codebase.py +++ b/src/codegen/sdk/core/codebase.py @@ -35,6 +35,7 @@ from codegen.sdk.codebase.flagging.code_flag import CodeFlag from codegen.sdk.codebase.flagging.enums import FlagKwargs from codegen.sdk.codebase.flagging.group import Group +from codegen.sdk.codebase.io.io import IO from codegen.sdk.codebase.span import Span from codegen.sdk.core.assignment import Assignment from codegen.sdk.core.class_definition import Class @@ -129,6 +130,7 @@ def __init__( programming_language: None = None, projects: list[ProjectConfig] | ProjectConfig, config: CodebaseConfig = DefaultConfig, + io: IO | None = None, ) -> None: ... @overload @@ -139,6 +141,7 @@ def __init__( programming_language: ProgrammingLanguage | None = None, projects: None = None, config: CodebaseConfig = DefaultConfig, + io: IO | None = None, ) -> None: ... def __init__( @@ -148,6 +151,7 @@ def __init__( programming_language: ProgrammingLanguage | None = None, projects: list[ProjectConfig] | ProjectConfig | None = None, config: CodebaseConfig = DefaultConfig, + io: IO | None = None, ) -> None: # Sanity check inputs if repo_path is not None and projects is not None: @@ -177,7 +181,7 @@ def __init__( self._op = main_project.repo_operator self.viz = VisualizationManager(op=self._op) self.repo_path = Path(self._op.repo_path) - self.ctx = CodebaseContext(projects, config=config) + self.ctx = CodebaseContext(projects, config=config, io=io) self.console = Console(record=True, soft_wrap=True) @noapidoc @@ -505,6 +509,8 @@ def get_file_from_path(path: Path) -> File | None: if file is not None: return file absolute_path = self.ctx.to_absolute(filepath) + if absolute_path.suffix in self.ctx.extensions: + return None if self.ctx.io.file_exists(absolute_path): return get_file_from_path(absolute_path) elif ignore_case: diff --git a/tests/unit/codegen/sdk/conftest.py b/tests/unit/codegen/conftest.py similarity index 68% rename from tests/unit/codegen/sdk/conftest.py rename to tests/unit/codegen/conftest.py index 03162cc4a..3877426ae 100644 --- a/tests/unit/codegen/sdk/conftest.py +++ b/tests/unit/codegen/conftest.py @@ -28,14 +28,17 @@ def codebase(tmp_path, original: dict[str, str], programming_language: Programmi @pytest.fixture def assert_expected(expected: dict[str, str], tmp_path): - def assert_expected(codebase: Codebase) -> None: - codebase.commit() + def assert_expected(codebase: Codebase, check_codebase: bool = True) -> None: + if check_codebase: + codebase.commit() for file in expected: assert tmp_path.joinpath(file).exists() assert tmp_path.joinpath(file).read_text() == expected[file] - assert codebase.get_file(file).content.strip() == expected[file].strip() - for file in codebase.files: - if file.file.path.exists(): - assert file.filepath in expected + if check_codebase: + assert codebase.get_file(file).content.strip() == expected[file].strip() + if check_codebase: + for file in codebase.files: + if file.file.path.exists(): + assert file.filepath in expected return assert_expected diff --git a/tests/unit/codegen/extensions/lsp/conftest.py b/tests/unit/codegen/extensions/lsp/conftest.py new file mode 100644 index 000000000..099329327 --- /dev/null +++ b/tests/unit/codegen/extensions/lsp/conftest.py @@ -0,0 +1,34 @@ +import sys + +import pytest_lsp +from lsprotocol.types import ( + InitializeParams, +) +from pytest_lsp import ( + ClientServerConfig, + LanguageClient, + client_capabilities, +) + +from codegen.sdk.core.codebase import Codebase + + +@pytest_lsp.fixture( + config=ClientServerConfig( + server_command=[sys.executable, "-m", "codegen.extensions.lsp.lsp"], + ), +) +async def client(lsp_client: LanguageClient, codebase: Codebase): + # Setup + response = await lsp_client.initialize_session( + InitializeParams( + capabilities=client_capabilities("visual-studio-code"), + root_uri="file://" + str(codebase.repo_path.resolve()), + root_path=str(codebase.repo_path.resolve()), + ) + ) + + yield + + # Teardown + await lsp_client.shutdown_session() diff --git a/tests/unit/codegen/extensions/lsp/test_definition.py b/tests/unit/codegen/extensions/lsp/test_definition.py new file mode 100644 index 000000000..955e9041d --- /dev/null +++ b/tests/unit/codegen/extensions/lsp/test_definition.py @@ -0,0 +1,144 @@ +import pytest +from lsprotocol.types import ( + DefinitionParams, + Location, + Position, + Range, + TextDocumentIdentifier, +) +from pytest_lsp import LanguageClient + +from codegen.sdk.core.codebase import Codebase + + +@pytest.mark.parametrize( + "original, position, expected_location", + [ + ( + { + "test.py": """ +def example_function(): + pass + +def main(): + example_function() + """.strip(), + }, + Position(line=4, character=4), # Position of example_function call + Location( + uri="file://{workspaceFolder}/test.py", + range=Range( + start=Position(line=0, character=4), + end=Position(line=0, character=20), + ), + ), + ), + ( + { + "test.py": """ +class MyClass: + def method(self): + pass + +obj = MyClass() +obj.method() + """.strip(), + }, + Position(line=5, character=4), # Position of method call + Location( + uri="file://{workspaceFolder}/test.py", + range=Range( + start=Position(line=1, character=8), + end=Position(line=1, character=14), + ), + ), + ), + ( + { + "module/utils.py": """ +def utility_function(): + pass + """.strip(), + "test.py": """ +from module.utils import utility_function + +def main(): + utility_function() + """.strip(), + }, + Position(line=3, character=4), # Position of utility_function call in test.py + Location( + uri="file://{workspaceFolder}/module/utils.py", + range=Range( + start=Position(line=0, character=4), + end=Position(line=0, character=20), # Adjusted to end before () + ), + ), + ), + ( + { + "models.py": """ +class DatabaseModel: + def save(self): + pass + """.strip(), + "test.py": """ +from models import DatabaseModel + +def main(): + model = DatabaseModel() + model.save() + """.strip(), + }, + Position(line=4, character=10), # Position of save() call in test.py + Location( + uri="file://{workspaceFolder}/models.py", + range=Range( + start=Position(line=1, character=8), + end=Position(line=1, character=12), # Adjusted to end before () + ), + ), + ), + ( + { + "module/__init__.py": """ +from .constants import DEFAULT_TIMEOUT + """.strip(), + "module/constants.py": """ +DEFAULT_TIMEOUT = 30 + """.strip(), + "test.py": """ +from module import DEFAULT_TIMEOUT + +def main(): + timeout = DEFAULT_TIMEOUT + """.strip(), + }, + Position(line=3, character=14), # Position of DEFAULT_TIMEOUT reference in test.py + Location( + uri="file://{workspaceFolder}/module/constants.py", + range=Range( + start=Position(line=0, character=0), + end=Position(line=0, character=15), # Adjusted to end before = + ), + ), + ), + ], +) +async def test_go_to_definition( + client: LanguageClient, + codebase: Codebase, + original: dict, + position: Position, + expected_location: Location, +): + result = await client.text_document_definition_async( + params=DefinitionParams( + text_document=TextDocumentIdentifier(uri="file://test.py"), + position=position, + ) + ) + + assert isinstance(result, Location) + assert result.uri == expected_location.uri.format(workspaceFolder=str(codebase.repo_path)) + assert result.range == expected_location.range diff --git a/tests/unit/codegen/extensions/lsp/test_document_symbols.py b/tests/unit/codegen/extensions/lsp/test_document_symbols.py new file mode 100644 index 000000000..4340c2594 --- /dev/null +++ b/tests/unit/codegen/extensions/lsp/test_document_symbols.py @@ -0,0 +1,239 @@ +from collections.abc import Sequence +from typing import cast + +import pytest +from lsprotocol.types import ( + DocumentSymbol, + DocumentSymbolParams, + Position, + Range, + SymbolKind, + TextDocumentIdentifier, +) +from pytest_lsp import LanguageClient + +from codegen.sdk.core.codebase import Codebase + + +@pytest.mark.parametrize( + "original, expected_symbols", + [ + ( + { + "test.py": """ +class TestClass: + def test_method(self): + pass + +def top_level_function(): + pass + """.strip(), + }, + [ + DocumentSymbol( + name="TestClass", + kind=SymbolKind.Class, + range=Range( + start=Position(line=0, character=0), + end=Position(line=2, character=12), + ), + selection_range=Range( + start=Position(line=0, character=6), + end=Position(line=0, character=15), + ), + children=[ + DocumentSymbol( + name="test_method", + kind=SymbolKind.Method, + range=Range( + start=Position(line=1, character=4), + end=Position(line=2, character=12), + ), + selection_range=Range( + start=Position(line=1, character=8), + end=Position(line=1, character=19), + ), + children=[], + ) + ], + ), + DocumentSymbol( + name="top_level_function", + kind=SymbolKind.Function, + range=Range( + start=Position(line=4, character=0), + end=Position(line=5, character=8), + ), + selection_range=Range( + start=Position(line=4, character=4), + end=Position(line=4, character=22), + ), + children=[], + ), + ], + ), + ( + { + "test.py": """ +@decorator +class OuterClass: + class InnerClass: + @property + def inner_method(self): + pass + + async def outer_method(self): + pass + +@decorator +async def async_function(): + pass + """.strip(), + }, + [ + DocumentSymbol( + name="OuterClass", + kind=SymbolKind.Class, + range=Range( + start=Position(line=0, character=0), + end=Position(line=8, character=12), + ), + selection_range=Range( + start=Position(line=1, character=6), + end=Position(line=1, character=16), + ), + children=[ + DocumentSymbol( + name="InnerClass", + kind=SymbolKind.Class, + range=Range( + start=Position(line=2, character=4), + end=Position(line=5, character=16), + ), + selection_range=Range( + start=Position(line=2, character=10), + end=Position(line=2, character=20), + ), + children=[ + DocumentSymbol( + name="inner_method", + kind=SymbolKind.Method, + range=Range( + start=Position(line=3, character=8), + end=Position(line=5, character=16), + ), + selection_range=Range( + start=Position(line=4, character=12), + end=Position(line=4, character=24), + ), + children=[], + ) + ], + ), + DocumentSymbol( + name="outer_method", + kind=SymbolKind.Method, + range=Range( + start=Position(line=7, character=4), + end=Position(line=8, character=12), + ), + selection_range=Range( + start=Position(line=7, character=14), + end=Position(line=7, character=26), + ), + children=[], + ), + ], + ), + DocumentSymbol( + name="async_function", + kind=SymbolKind.Function, + range=Range( + start=Position(line=10, character=0), + end=Position(line=12, character=8), + ), + selection_range=Range( + start=Position(line=11, character=10), + end=Position(line=11, character=24), + ), + children=[], + ), + ], + ), + ( + { + "test.py": """ +def function_with_args(arg1: str, arg2: int = 42): + pass + +class ClassWithDocstring: + \"\"\"This is a docstring.\"\"\" + def method_with_docstring(self): + \"\"\"Method docstring.\"\"\" + pass + """.strip(), + }, + [ + DocumentSymbol( + name="function_with_args", + kind=SymbolKind.Function, + range=Range( + start=Position(line=0, character=0), + end=Position(line=1, character=8), + ), + selection_range=Range( + start=Position(line=0, character=4), + end=Position(line=0, character=22), + ), + children=[], + ), + DocumentSymbol( + name="ClassWithDocstring", + kind=SymbolKind.Class, + range=Range( + start=Position(line=3, character=0), + end=Position(line=7, character=12), + ), + selection_range=Range( + start=Position(line=3, character=6), + end=Position(line=3, character=24), + ), + children=[ + DocumentSymbol( + name="method_with_docstring", + kind=SymbolKind.Method, + range=Range( + start=Position(line=5, character=4), + end=Position(line=7, character=12), + ), + selection_range=Range( + start=Position(line=5, character=8), + end=Position(line=5, character=29), + ), + children=[], + ), + ], + ), + ], + ), + ], +) +async def test_document_symbols( + client: LanguageClient, + codebase: Codebase, + original: dict, + expected_symbols: list[DocumentSymbol], +): + result = await client.text_document_document_symbol_async(params=DocumentSymbolParams(text_document=TextDocumentIdentifier(uri="file://test.py"))) + + assert result is not None + symbols = cast(Sequence[DocumentSymbol], result) + assert len(symbols) == len(expected_symbols) + for actual, expected in zip(symbols, expected_symbols): + assert actual.name == expected.name + assert actual.kind == expected.kind + assert actual.range == expected.range + assert actual.selection_range == expected.selection_range + assert actual.children == expected.children + assert actual == expected + assert symbols == expected_symbols diff --git a/tests/unit/codegen/extensions/lsp/test_rename.py b/tests/unit/codegen/extensions/lsp/test_rename.py new file mode 100644 index 000000000..7c75ea8b9 --- /dev/null +++ b/tests/unit/codegen/extensions/lsp/test_rename.py @@ -0,0 +1,42 @@ +import pytest +from lsprotocol.types import ( + Position, + RenameParams, + TextDocumentIdentifier, +) +from pytest_lsp import ( + LanguageClient, +) + +from codegen.sdk.core.codebase import Codebase + + +@pytest.mark.parametrize( + "original, expected", + [ + ( + { + "test.py": """ +def hello(): + pass + """.strip(), + }, + { + "test.py": """ +def world(): + pass + """.strip(), + }, + ) + ], +) +async def test_rename(client: LanguageClient, codebase: Codebase, assert_expected): + result = await client.text_document_rename_async( + params=RenameParams( + position=Position(line=0, character=0), + text_document=TextDocumentIdentifier(uri="file://test.py"), + new_name="world", + ) + ) + + assert_expected(codebase, check_codebase=False) diff --git a/tests/unit/codegen/extensions/lsp/test_workspace_sync.py b/tests/unit/codegen/extensions/lsp/test_workspace_sync.py new file mode 100644 index 000000000..20839a0ef --- /dev/null +++ b/tests/unit/codegen/extensions/lsp/test_workspace_sync.py @@ -0,0 +1,250 @@ +import pytest +from lsprotocol.types import ( + DidChangeTextDocumentParams, + DidCloseTextDocumentParams, + DidOpenTextDocumentParams, + Position, + Range, + RenameParams, + TextDocumentContentChangeEvent, + TextDocumentContentChangePartial, + TextDocumentContentParams, + TextDocumentIdentifier, + TextDocumentItem, + VersionedTextDocumentIdentifier, +) +from pytest_lsp import LanguageClient + +from codegen.sdk.core.codebase import Codebase + + +@pytest.fixture() +def document_uri(codebase: Codebase, request) -> str: + return request.param.format(workspaceFolder=str(codebase.repo_path)) + + +@pytest.mark.parametrize( + "original, document_uri", + [ + ( + { + "test.py": """ +def example_function(): + pass + """.strip(), + }, + "file://{workspaceFolder}/test.py", + ), + ], + indirect=True, +) +async def test_did_open( + client: LanguageClient, + codebase: Codebase, + original: dict, + document_uri: str, +): + # Send didOpen notification + client.text_document_did_open( + params=DidOpenTextDocumentParams( + text_document=TextDocumentItem( + uri=document_uri, + language_id="python", + version=1, + text=original["test.py"], + ) + ) + ) + + # Verify the file is in the workspace + document = await client.workspace_text_document_content_async(TextDocumentContentParams(uri=document_uri)) + assert document is not None + assert document.text == original["test.py"] + + +@pytest.mark.parametrize( + "original, document_uri, changes, expected_text", + [ + ( + { + "test.py": """ +def example_function(): + pass + """.strip(), + }, + "file://{workspaceFolder}/test.py", + [ + TextDocumentContentChangePartial( + range=Range( + start=Position(line=1, character=4), + end=Position(line=1, character=8), + ), + text="return True", + ), + ], + """ +def example_function(): + return True + """.strip(), + ), + ], + indirect=["document_uri", "original"], +) +async def test_did_change( + client: LanguageClient, + codebase: Codebase, + original: dict, + document_uri: str, + changes: list[TextDocumentContentChangeEvent], + expected_text: str, +): + # First open the document + client.text_document_did_open( + params=DidOpenTextDocumentParams( + text_document=TextDocumentItem( + uri=document_uri, + language_id="python", + version=1, + text=original["test.py"], + ) + ) + ) + + # Send didChange notification + client.text_document_did_change( + params=DidChangeTextDocumentParams( + text_document=VersionedTextDocumentIdentifier( + uri=document_uri, + version=2, + ), + content_changes=changes, + ) + ) + + # Verify the changes were applied + document = await client.workspace_text_document_content_async(TextDocumentContentParams(uri=document_uri)) + assert document is not None + assert document.text == expected_text + + +@pytest.mark.parametrize( + "original, document_uri", + [ + ( + { + "test.py": """ +def example_function(): + pass + """.strip(), + }, + "file://{worskpaceFolder}test.py", + ), + ], +) +async def test_did_close( + client: LanguageClient, + codebase: Codebase, + original: dict, + document_uri: str, +): + document_uri = document_uri.format(worskpaceFolder=str(codebase.repo_path)) + # First open the document + client.text_document_did_open( + params=DidOpenTextDocumentParams( + text_document=TextDocumentItem( + uri=document_uri, + language_id="python", + version=1, + text=original["test.py"], + ) + ) + ) + + # Send didClose notification + client.text_document_did_close(params=DidCloseTextDocumentParams(text_document=TextDocumentIdentifier(uri=document_uri))) + + # Verify the document is removed from the workspace + document = await client.workspace_text_document_content_async(TextDocumentContentParams(uri=document_uri)) + assert document.text == "" + + +@pytest.mark.parametrize( + "original, document_uri, position, new_name, expected_text", + [ + ( + { + "test.py": """ +def example_function(): + pass + +def main(): + example_function() + """.strip(), + }, + "file://{workspaceFolder}/test.py", + Position(line=0, character=0), # Position of 'example_function' + "renamed_function", + """ +def renamed_function(): + pass # modified + +def main(): + renamed_function() + """.strip(), + ), + ], + indirect=["document_uri", "original"], +) +async def test_rename_after_sync( + client: LanguageClient, + codebase: Codebase, + original: dict, + document_uri: str, + position: Position, + new_name: str, + expected_text: str, +): + # First open the document + client.text_document_did_open( + params=DidOpenTextDocumentParams( + text_document=TextDocumentItem( + uri=document_uri, + language_id="python", + version=1, + text=original["test.py"], + ) + ) + ) + + # Make a change to the document + client.text_document_did_change( + params=DidChangeTextDocumentParams( + text_document=VersionedTextDocumentIdentifier( + uri=document_uri, + version=2, + ), + content_changes=[ + TextDocumentContentChangePartial( + range=Range( + start=Position(line=1, character=4), + end=Position(line=1, character=8), + ), + text="pass # modified", + ), + ], + ) + ) + + # Perform rename operation + result = await client.text_document_rename_async( + params=RenameParams( + text_document=TextDocumentIdentifier(uri=document_uri), + position=position, + new_name=new_name, + ) + ) + + # Verify the rename was successful + document = await client.workspace_text_document_content_async(TextDocumentContentParams(uri=document_uri)) + assert document is not None + assert document.text == expected_text diff --git a/uv.lock b/uv.lock index a2649cda7..5094f7eaf 100644 --- a/uv.lock +++ b/uv.lock @@ -604,6 +604,10 @@ dependencies = [ ] [package.optional-dependencies] +lsp = [ + { name = "lsprotocol" }, + { name = "pygls" }, +] types = [ { name = "types-networkx" }, { name = "types-requests" }, @@ -638,6 +642,7 @@ dev = [ { name = "pytest-asyncio" }, { name = "pytest-benchmark", extra = ["histogram"] }, { name = "pytest-cov" }, + { name = "pytest-lsp" }, { name = "pytest-mock" }, { name = "pytest-timeout" }, { name = "pytest-xdist" }, @@ -669,6 +674,7 @@ requires-dist = [ { name = "langchain-core" }, { name = "langchain-openai" }, { name = "lazy-object-proxy", specifier = ">=0.0.0" }, + { name = "lsprotocol", marker = "extra == 'lsp'", specifier = "==2024.0.0b1" }, { name = "mini-racer", specifier = ">=0.12.4" }, { name = "networkx", specifier = ">=3.4.1" }, { name = "numpy", specifier = ">=2.2.2" }, @@ -682,6 +688,7 @@ requires-dist = [ { name = "pydantic-settings", specifier = ">=2.0.0" }, { name = "pygit2", specifier = ">=1.16.0" }, { name = "pygithub", specifier = "==2.5.0" }, + { name = "pygls", marker = "extra == 'lsp'", specifier = ">=2.0.0a2" }, { name = "pyinstrument", specifier = ">=5.0.0" }, { name = "pyjson5", specifier = "==1.6.8" }, { name = "pyright", specifier = ">=1.1.372,<2.0.0" }, @@ -745,6 +752,7 @@ dev = [ { name = "pytest-asyncio", specifier = ">=0.21.1,<1.0.0" }, { name = "pytest-benchmark", extras = ["histogram"], specifier = ">=5.1.0" }, { name = "pytest-cov", specifier = ">=6.0.0,<6.0.1" }, + { name = "pytest-lsp", specifier = ">=1.0.0b1" }, { name = "pytest-mock", specifier = ">=3.14.0,<4.0.0" }, { name = "pytest-timeout", specifier = ">=2.3.1" }, { name = "pytest-xdist", specifier = ">=3.6.1,<4.0.0" }, @@ -2041,15 +2049,15 @@ wheels = [ [[package]] name = "lsprotocol" -version = "2023.0.1" +version = "2024.0.0b1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "attrs" }, { name = "cattrs" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/9d/f6/6e80484ec078d0b50699ceb1833597b792a6c695f90c645fbaf54b947e6f/lsprotocol-2023.0.1.tar.gz", hash = "sha256:cc5c15130d2403c18b734304339e51242d3018a05c4f7d0f198ad6e0cd21861d", size = 69434 } +sdist = { url = "https://files.pythonhosted.org/packages/dc/21/0282716d19591e573d20564ee4df65cb5cd8911bfdff35fcde1de2b54072/lsprotocol-2024.0.0b1.tar.gz", hash = "sha256:d3667fb70894d361aa6c495c5c8a1b2e6a44be65ff84c21a9cbb67ebfb4830fd", size = 75358 } wheels = [ - { url = "https://files.pythonhosted.org/packages/8d/37/2351e48cb3309673492d3a8c59d407b75fb6630e560eb27ecd4da03adc9a/lsprotocol-2023.0.1-py3-none-any.whl", hash = "sha256:c75223c9e4af2f24272b14c6375787438279369236cd568f596d4951052a60f2", size = 70826 }, + { url = "https://files.pythonhosted.org/packages/4d/1b/526af91cd43eba22ac7d9dbdec729dd9d91c2ad335085a61dd42307a7b35/lsprotocol-2024.0.0b1-py3-none-any.whl", hash = "sha256:93785050ac155ae2be16b1ebfbd74c214feb3d3ef77b10399ce941e5ccef6ebd", size = 76600 }, ] [[package]] @@ -2837,15 +2845,15 @@ wheels = [ [[package]] name = "pygls" -version = "1.3.1" +version = "2.0.0a2" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "cattrs" }, { name = "lsprotocol" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/86/b9/41d173dad9eaa9db9c785a85671fc3d68961f08d67706dc2e79011e10b5c/pygls-1.3.1.tar.gz", hash = "sha256:140edceefa0da0e9b3c533547c892a42a7d2fd9217ae848c330c53d266a55018", size = 45527 } +sdist = { url = "https://files.pythonhosted.org/packages/68/a9/2110bbc90fde62ab7b8f21164caacb5288c06d98486cc569526ec6c0c9ca/pygls-2.0.0a2.tar.gz", hash = "sha256:03e00634ed8d989918268aaa4b4a0c3ab857ea2d4ee94514a52efa5ddd6d5d9f", size = 46279 } wheels = [ - { url = "https://files.pythonhosted.org/packages/11/19/b74a10dd24548e96e8c80226cbacb28b021bc3a168a7d2709fb0d0185348/pygls-1.3.1-py3-none-any.whl", hash = "sha256:6e00f11efc56321bdeb6eac04f6d86131f654c7d49124344a9ebb968da3dd91e", size = 56031 }, + { url = "https://files.pythonhosted.org/packages/f8/47/7d7b3911fbd27153ee38a1a15e3977c72733a41ee8d7f6ce6dca65843fe9/pygls-2.0.0a2-py3-none-any.whl", hash = "sha256:b202369321409343aa6440d73111d9fa0c22e580466ff1c7696b8358bb91f243", size = 58504 }, ] [[package]] @@ -3036,6 +3044,21 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/36/3b/48e79f2cd6a61dbbd4807b4ed46cb564b4fd50a76166b1c4ea5c1d9e2371/pytest_cov-6.0.0-py3-none-any.whl", hash = "sha256:eee6f1b9e61008bd34975a4d5bab25801eb31898b032dd55addc93e96fcaaa35", size = 22949 }, ] +[[package]] +name = "pytest-lsp" +version = "1.0.0b2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "packaging" }, + { name = "pygls" }, + { name = "pytest" }, + { name = "pytest-asyncio" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a9/47/1207bf70218c9cbb6e8a184a1957f699c35d9bf8b43dfa2be5885d35c283/pytest_lsp-1.0.0b2.tar.gz", hash = "sha256:459f62d578d700b63c4ea0b500b5a621461eb2c60d0fd941c3583b0d7930a1ea", size = 26634 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c1/cc/2f46f5a3db66e50e813cba64da0fed2c517c28b80877585461534c953f22/pytest_lsp-1.0.0b2-py3-none-any.whl", hash = "sha256:d989c69e134ac66e297f0e0eae5edb13470059d7028e50fb06c01674b067fc14", size = 24115 }, +] + [[package]] name = "pytest-mock" version = "3.14.0"