From 3d97ee51642e2f75478e95fd684b26c4b3d6b7b0 Mon Sep 17 00:00:00 2001 From: bagel897 Date: Wed, 12 Feb 2025 16:01:32 -0800 Subject: [PATCH 1/4] progress support --- src/codegen/extensions/lsp/progress.py | 61 +++++++++++++++++++ src/codegen/extensions/lsp/protocol.py | 11 ++-- src/codegen/sdk/codebase/codebase_context.py | 32 ++++++++-- src/codegen/sdk/codebase/progress/progress.py | 13 ++++ .../sdk/codebase/progress/stub_progress.py | 7 +++ .../sdk/codebase/progress/stub_task.py | 9 +++ src/codegen/sdk/codebase/progress/task.py | 11 ++++ src/codegen/sdk/core/codebase.py | 6 +- uv.lock | 2 +- 9 files changed, 138 insertions(+), 14 deletions(-) create mode 100644 src/codegen/extensions/lsp/progress.py create mode 100644 src/codegen/sdk/codebase/progress/progress.py create mode 100644 src/codegen/sdk/codebase/progress/stub_progress.py create mode 100644 src/codegen/sdk/codebase/progress/stub_task.py create mode 100644 src/codegen/sdk/codebase/progress/task.py diff --git a/src/codegen/extensions/lsp/progress.py b/src/codegen/extensions/lsp/progress.py new file mode 100644 index 000000000..8f1615826 --- /dev/null +++ b/src/codegen/extensions/lsp/progress.py @@ -0,0 +1,61 @@ +import uuid + +from lsprotocol import types +from lsprotocol.types import ProgressToken +from pygls.lsp.server import LanguageServer + +from codegen.sdk.codebase.progress.progress import Progress +from codegen.sdk.codebase.progress.stub_task import StubTask +from codegen.sdk.codebase.progress.task import Task + + +class LSPTask(Task): + count: int | None + + def __init__(self, server: LanguageServer, message: str, token: ProgressToken, count: int | None = None, create_token: bool = True) -> None: + self.token = token + if create_token: + server.work_done_progress.begin(self.token, types.WorkDoneProgressBegin(title=message)) + self.server = server + self.message = message + self.count = count + self.create_token = create_token + + def update(self, message: str, count: int | None = None) -> None: + if self.count is not None and count is not None: + percent = int(count * 100 / self.count) + else: + percent = None + self.server.work_done_progress.report(self.token, types.WorkDoneProgressReport(message=message, percentage=percent)) + + def end(self) -> None: + if self.create_token: + self.server.work_done_progress.end(self.token, value=types.WorkDoneProgressEnd()) + + +class LSPProgress(Progress[LSPTask | StubTask]): + initialized = False + + def __init__(self, server: LanguageServer, initial_token: ProgressToken | None = None): + self.server = server + self.initial_token = initial_token + if initial_token is not None: + self.server.work_done_progress.begin(initial_token, types.WorkDoneProgressBegin(title="Parsing codebase...")) + + def begin_with_token(self, message: str, token: ProgressToken, *, count: int | None = None) -> LSPTask: + return LSPTask(self.server, message, token, count, create_token=False) + + def begin(self, message: str, count: int | None = None) -> LSPTask | StubTask: + if self.initialized: + token = str(uuid.uuid4()) + self.server.work_done_progress.create(token).result() + return LSPTask(self.server, message, token, count, create_token=False) + elif self.initial_token is not None: + return self.begin_with_token(message, self.initial_token, count=None) + else: + return StubTask() + + def finish_initialization(self) -> None: + self.initialized = True + if self.initial_token is not None: + self.server.work_done_progress.end(self.initial_token, value=types.WorkDoneProgressEnd()) diff --git a/src/codegen/extensions/lsp/protocol.py b/src/codegen/extensions/lsp/protocol.py index 7799a56e8..879699df1 100644 --- a/src/codegen/extensions/lsp/protocol.py +++ b/src/codegen/extensions/lsp/protocol.py @@ -2,10 +2,11 @@ from pathlib import Path from typing import TYPE_CHECKING -from lsprotocol.types import INITIALIZE, InitializeParams, InitializeResult, WorkDoneProgressBegin, WorkDoneProgressEnd +from lsprotocol.types import INITIALIZE, InitializeParams, InitializeResult from pygls.protocol import LanguageServerProtocol, lsp_method from codegen.extensions.lsp.io import LSPIO +from codegen.extensions.lsp.progress import LSPProgress from codegen.extensions.lsp.utils import get_path from codegen.sdk.codebase.config import CodebaseConfig from codegen.sdk.core.codebase import Codebase @@ -19,6 +20,7 @@ class CodegenLanguageServerProtocol(LanguageServerProtocol): _server: "CodegenLanguageServer" def _init_codebase(self, params: InitializeParams) -> None: + progress = LSPProgress(self._server, params.work_done_token) if params.root_path: root = Path(params.root_path) elif params.root_uri: @@ -27,15 +29,12 @@ def _init_codebase(self, params: InitializeParams) -> None: root = os.getcwd() config = CodebaseConfig(feature_flags=CodebaseFeatureFlags(full_range_index=True)) io = LSPIO(self.workspace) - self._server.codebase = Codebase(repo_path=str(root), config=config, io=io) + self._server.codebase = Codebase(repo_path=str(root), config=config, io=io, progress=progress) self._server.io = io - if params.work_done_token: - self._server.work_done_progress.end(params.work_done_token, WorkDoneProgressEnd(message="Parsing codebase...")) + progress.finish_initialization() @lsp_method(INITIALIZE) def lsp_initialize(self, params: InitializeParams) -> InitializeResult: ret = super().lsp_initialize(params) - if params.work_done_token: - self._server.work_done_progress.begin(params.work_done_token, WorkDoneProgressBegin(title="Parsing codebase...")) self._init_codebase(params) return ret diff --git a/src/codegen/sdk/codebase/codebase_context.py b/src/codegen/sdk/codebase/codebase_context.py index 8a8f1679b..53664c73e 100644 --- a/src/codegen/sdk/codebase/codebase_context.py +++ b/src/codegen/sdk/codebase/codebase_context.py @@ -16,6 +16,7 @@ from codegen.sdk.codebase.diff_lite import ChangeType, DiffLite from codegen.sdk.codebase.flagging.flags import Flags from codegen.sdk.codebase.io.file_io import FileIO +from codegen.sdk.codebase.progress.stub_progress import StubProgress from codegen.sdk.codebase.transaction_manager import TransactionManager from codegen.sdk.codebase.validation import get_edges, post_reset_validation from codegen.sdk.core.autocommit import AutoCommit, commiter @@ -39,6 +40,7 @@ from codegen.git.repo_operator.repo_operator import RepoOperator from codegen.sdk.codebase.io.io import IO from codegen.sdk.codebase.node_classes.node_classes import NodeClasses + from codegen.sdk.codebase.progress.progress import Progress from codegen.sdk.core.dataclasses.usage import Usage from codegen.sdk.core.expressions import Expression from codegen.sdk.core.external_module import ExternalModule @@ -111,16 +113,19 @@ class CodebaseContext: projects: list[ProjectConfig] unapplied_diffs: list[DiffLite] io: IO + progress: Progress def __init__( self, projects: list[ProjectConfig], config: CodebaseConfig = DefaultConfig, io: IO | None = None, + progress: Progress | None = None, ) -> None: """Initializes codebase graph and TransactionManager""" from codegen.sdk.core.parser import Parser + self.progress = progress or StubProgress() self._graph = PyDiGraph() self.filepath_idx = {} self._ext_module_idx = {} @@ -371,7 +376,6 @@ def _process_diff_files(self, files_to_sync: Mapping[SyncType, list[Path]], incr skip_uncache = incremental and ((len(files_to_sync[SyncType.DELETE]) + len(files_to_sync[SyncType.REPARSE])) == 0) if not skip_uncache: uncache_all() - # Step 0: Start the dependency manager and language engine if they exist # Start the dependency manager. This may or may not run asynchronously, depending on the implementation if self.dependency_manager is not None: @@ -429,17 +433,21 @@ def _process_diff_files(self, files_to_sync: Mapping[SyncType, list[Path]], incr file = self.get_file(file_path) file.remove_internal_edges() + task = self.progress.begin("Reparsing updated files", count=len(files_to_sync[SyncType.REPARSE])) files_to_resolve = [] # Step 4: Reparse updated files - for file_path in files_to_sync[SyncType.REPARSE]: + for idx, file_path in enumerate(files_to_sync[SyncType.REPARSE]): + task.update(f"Reparsing {self.to_relative(file_path)}", count=idx) file = self.get_file(file_path) to_resolve.extend(file.unparse(reparse=True)) to_resolve = list(filter(lambda node: self.has_node(node.node_id) and node is not None, to_resolve)) file.sync_with_file_content() files_to_resolve.append(file) - + task.end() # Step 5: Add new files as nodes to graph (does not yet add edges) - for filepath in files_to_sync[SyncType.ADD]: + task = self.progress.begin("Adding new files", count=len(files_to_sync[SyncType.ADD])) + for idx, filepath in enumerate(files_to_sync[SyncType.ADD]): + task.update(f"Adding {self.to_relative(filepath)}", count=idx) content = self.io.read_text(filepath) # TODO: this is wrong with context changes if filepath.suffix in self.extensions: @@ -447,6 +455,7 @@ def _process_diff_files(self, files_to_sync: Mapping[SyncType, list[Path]], incr new_file = file_cls.from_content(filepath, content, self, sync=False, verify_syntax=False) if new_file is not None: files_to_resolve.append(new_file) + task.end() for file in files_to_resolve: to_resolve.append(file) to_resolve.extend(file.get_nodes()) @@ -474,27 +483,35 @@ def _process_diff_files(self, files_to_sync: Mapping[SyncType, list[Path]], incr self._computing = True try: logger.info(f"> Computing import resolution edges for {counter[NodeType.IMPORT]} imports") + task = self.progress.begin("Resolving imports", count=counter[NodeType.IMPORT]) for node in to_resolve: if node.node_type == NodeType.IMPORT: + task.update(f"Resolving imports in {node.filepath}", count=idx) node._remove_internal_edges(EdgeType.IMPORT_SYMBOL_RESOLUTION) node.add_symbol_resolution_edge() to_resolve.extend(node.symbol_usages) + task.end() if counter[NodeType.EXPORT] > 0: logger.info(f"> Computing export dependencies for {counter[NodeType.EXPORT]} exports") + task = self.progress.begin("Computing export dependencies", count=counter[NodeType.EXPORT]) for node in to_resolve: if node.node_type == NodeType.EXPORT: + task.update(f"Computing export dependencies for {node.filepath}", count=idx) node._remove_internal_edges(EdgeType.EXPORT) node.compute_export_dependencies() to_resolve.extend(node.symbol_usages) + task.end() if counter[NodeType.SYMBOL] > 0: from codegen.sdk.core.interfaces.inherits import Inherits logger.info("> Computing superclass dependencies") + task = self.progress.begin("Computing superclass dependencies", count=counter[NodeType.SYMBOL]) for symbol in to_resolve: if isinstance(symbol, Inherits): + task.update(f"Computing superclass dependencies for {symbol.filepath}", count=idx) symbol._remove_internal_edges(EdgeType.SUBCLASS) symbol.compute_superclass_dependencies() - + task.end() if not skip_uncache: uncache_all() self._compute_dependencies(to_resolve, incremental) @@ -504,10 +521,12 @@ def _process_diff_files(self, files_to_sync: Mapping[SyncType, list[Path]], incr def _compute_dependencies(self, to_update: list[Importable], incremental: bool): seen = set() while to_update: + task = self.progress.begin("Computing dependencies", count=len(to_update)) step = to_update.copy() to_update.clear() logger.info(f"> Incrementally computing dependencies for {len(step)} nodes") - for current in step: + for idx, current in enumerate(step): + task.update(f"Computing dependencies for {current.filepath}", count=idx) if current not in seen: seen.add(current) to_update.extend(current.recompute(incremental)) @@ -515,6 +534,7 @@ def _compute_dependencies(self, to_update: list[Importable], incremental: bool): for node in self._graph.nodes(): if node not in seen: to_update.append(node) + task.end() seen.clear() def build_subgraph(self, nodes: list[NodeId]) -> PyDiGraph[Importable, Edge]: diff --git a/src/codegen/sdk/codebase/progress/progress.py b/src/codegen/sdk/codebase/progress/progress.py new file mode 100644 index 000000000..ec1c8b6e1 --- /dev/null +++ b/src/codegen/sdk/codebase/progress/progress.py @@ -0,0 +1,13 @@ +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Generic, TypeVar + +if TYPE_CHECKING: + from codegen.sdk.codebase.progress.task import Task + +T = TypeVar("T", bound="Task") + + +class Progress(ABC, Generic[T]): + @abstractmethod + def begin(self, message: str, count: int | None = None) -> T: + pass diff --git a/src/codegen/sdk/codebase/progress/stub_progress.py b/src/codegen/sdk/codebase/progress/stub_progress.py new file mode 100644 index 000000000..6c0aac5aa --- /dev/null +++ b/src/codegen/sdk/codebase/progress/stub_progress.py @@ -0,0 +1,7 @@ +from codegen.sdk.codebase.progress.progress import Progress +from codegen.sdk.codebase.progress.stub_task import StubTask + + +class StubProgress(Progress[StubTask]): + def begin(self, message: str, count: int | None = None) -> StubTask: + return StubTask() diff --git a/src/codegen/sdk/codebase/progress/stub_task.py b/src/codegen/sdk/codebase/progress/stub_task.py new file mode 100644 index 000000000..43d25acf7 --- /dev/null +++ b/src/codegen/sdk/codebase/progress/stub_task.py @@ -0,0 +1,9 @@ +from codegen.sdk.codebase.progress.task import Task + + +class StubTask(Task): + def update(self, message: str, count: int | None = None) -> None: + pass + + def end(self) -> None: + pass diff --git a/src/codegen/sdk/codebase/progress/task.py b/src/codegen/sdk/codebase/progress/task.py new file mode 100644 index 000000000..c0814513d --- /dev/null +++ b/src/codegen/sdk/codebase/progress/task.py @@ -0,0 +1,11 @@ +from abc import ABC, abstractmethod + + +class Task(ABC): + @abstractmethod + def update(self, message: str, count: int | None = None) -> None: + pass + + @abstractmethod + def end(self) -> None: + pass diff --git a/src/codegen/sdk/core/codebase.py b/src/codegen/sdk/core/codebase.py index 44242732f..3c0709dc4 100644 --- a/src/codegen/sdk/core/codebase.py +++ b/src/codegen/sdk/core/codebase.py @@ -36,6 +36,7 @@ from codegen.sdk.codebase.flagging.enums import FlagKwargs from codegen.sdk.codebase.flagging.group import Group from codegen.sdk.codebase.io.io import IO +from codegen.sdk.codebase.progress.progress import Progress from codegen.sdk.codebase.span import Span from codegen.sdk.core.assignment import Assignment from codegen.sdk.core.class_definition import Class @@ -132,6 +133,7 @@ def __init__( projects: list[ProjectConfig] | ProjectConfig, config: CodebaseConfig = DefaultConfig, io: IO | None = None, + progress: Progress | None = None, ) -> None: ... @overload @@ -143,6 +145,7 @@ def __init__( projects: None = None, config: CodebaseConfig = DefaultConfig, io: IO | None = None, + progress: Progress | None = None, ) -> None: ... def __init__( @@ -153,6 +156,7 @@ def __init__( projects: list[ProjectConfig] | ProjectConfig | None = None, config: CodebaseConfig = DefaultConfig, io: IO | None = None, + progress: Progress | None = None, ) -> None: # Sanity check inputs if repo_path is not None and projects is not None: @@ -182,7 +186,7 @@ def __init__( self._op = main_project.repo_operator self.viz = VisualizationManager(op=self._op) self.repo_path = Path(self._op.repo_path) - self.ctx = CodebaseContext(projects, config=config, io=io) + self.ctx = CodebaseContext(projects, config=config, io=io, progress=progress) self.console = Console(record=True, soft_wrap=True) @noapidoc diff --git a/uv.lock b/uv.lock index 75f118ea9..ca8557c63 100644 --- a/uv.lock +++ b/uv.lock @@ -537,7 +537,7 @@ wheels = [ [[package]] name = "codegen" -version = "0.9.1.dev2+gd60aa9c7.d20250212" +version = "0.11.2.dev5+gef60333b.d20250212" source = { editable = "." } dependencies = [ { name = "anthropic" }, From 725271baceee84e218249ab5404bc90035c0b0de Mon Sep 17 00:00:00 2001 From: bagel897 Date: Wed, 12 Feb 2025 17:10:35 -0800 Subject: [PATCH 2/4] Add more progress support --- src/codegen/extensions/lsp/lsp.py | 21 +++++++++++++------- src/codegen/extensions/lsp/progress.py | 11 +++++------ src/codegen/extensions/lsp/protocol.py | 1 + src/codegen/extensions/lsp/server.py | 27 ++++++++++++++++---------- 4 files changed, 37 insertions(+), 23 deletions(-) diff --git a/src/codegen/extensions/lsp/lsp.py b/src/codegen/extensions/lsp/lsp.py index 06b227400..50b7dc8ed 100644 --- a/src/codegen/extensions/lsp/lsp.py +++ b/src/codegen/extensions/lsp/lsp.py @@ -71,6 +71,7 @@ def did_close(server: CodegenLanguageServer, params: types.DidCloseTextDocumentP @server.feature( types.TEXT_DOCUMENT_RENAME, + options=types.RenameOptions(work_done_progress=True), ) def rename(server: CodegenLanguageServer, params: types.RenameParams) -> types.RenameResult: symbol = server.get_symbol(params.text_document.uri, params.position) @@ -78,28 +79,38 @@ def rename(server: CodegenLanguageServer, params: types.RenameParams) -> types.R logger.warning(f"No symbol found at {params.text_document.uri}:{params.position}") return logger.info(f"Renaming symbol {symbol.name} to {params.new_name}") + task = server.progress_manager.begin_with_token(f"Renaming symbol {symbol.name} to {params.new_name}", params.work_done_token) symbol.rename(params.new_name) + task.update("Committing changes") server.codebase.commit() + task.end() return server.io.get_workspace_edit() @server.feature( types.TEXT_DOCUMENT_DOCUMENT_SYMBOL, + options=types.DocumentSymbolOptions(work_done_progress=True), ) def document_symbol(server: CodegenLanguageServer, params: types.DocumentSymbolParams) -> types.DocumentSymbolResult: file = server.get_file(params.text_document.uri) symbols = [] - for symbol in file.symbols: + task = server.progress_manager.begin_with_token(f"Getting document symbols for {params.text_document.uri}", params.work_done_token, count=len(file.symbols)) + for idx, symbol in enumerate(file.symbols): + task.update(f"Getting document symbols for {params.text_document.uri}", count=idx) symbols.append(get_document_symbol(symbol)) + task.end() return symbols @server.feature( types.TEXT_DOCUMENT_DEFINITION, + options=types.DefinitionOptions(work_done_progress=True), ) def definition(server: CodegenLanguageServer, params: types.DefinitionParams): node = server.get_node_under_cursor(params.text_document.uri, params.position) + task = server.progress_manager.begin_with_token(f"Getting definition for {params.text_document.uri}", params.work_done_token) resolved = go_to_definition(node, params.text_document.uri, params.position) + task.end() return types.Location( uri=resolved.file.path.as_uri(), range=get_range(resolved), @@ -108,15 +119,11 @@ def definition(server: CodegenLanguageServer, params: types.DefinitionParams): @server.feature( types.TEXT_DOCUMENT_CODE_ACTION, - options=types.CodeActionOptions(resolve_provider=True), + options=types.CodeActionOptions(resolve_provider=True, work_done_progress=True), ) def code_action(server: CodegenLanguageServer, params: types.CodeActionParams) -> types.CodeActionResult: logger.info(f"Received code action: {params}") - if params.context.only: - only = [types.CodeActionKind(kind) for kind in params.context.only] - else: - only = None - actions = server.get_actions_for_range(params.text_document.uri, params.range, only) + actions = server.get_actions_for_range(params) return actions diff --git a/src/codegen/extensions/lsp/progress.py b/src/codegen/extensions/lsp/progress.py index 8f1615826..70b16f568 100644 --- a/src/codegen/extensions/lsp/progress.py +++ b/src/codegen/extensions/lsp/progress.py @@ -42,7 +42,9 @@ def __init__(self, server: LanguageServer, initial_token: ProgressToken | None = if initial_token is not None: self.server.work_done_progress.begin(initial_token, types.WorkDoneProgressBegin(title="Parsing codebase...")) - def begin_with_token(self, message: str, token: ProgressToken, *, count: int | None = None) -> LSPTask: + def begin_with_token(self, message: str, token: ProgressToken | None = None, *, count: int | None = None) -> LSPTask | StubTask: + if token is None: + return StubTask() return LSPTask(self.server, message, token, count, create_token=False) def begin(self, message: str, count: int | None = None) -> LSPTask | StubTask: @@ -50,12 +52,9 @@ def begin(self, message: str, count: int | None = None) -> LSPTask | StubTask: token = str(uuid.uuid4()) self.server.work_done_progress.create(token).result() return LSPTask(self.server, message, token, count, create_token=False) - elif self.initial_token is not None: - return self.begin_with_token(message, self.initial_token, count=None) - else: - return StubTask() + return self.begin_with_token(message, self.initial_token, count=None) def finish_initialization(self) -> None: - self.initialized = True + self.initialized = False # We can't initiate server work during syncs if self.initial_token is not None: self.server.work_done_progress.end(self.initial_token, value=types.WorkDoneProgressEnd()) diff --git a/src/codegen/extensions/lsp/protocol.py b/src/codegen/extensions/lsp/protocol.py index 879699df1..9b96d7f47 100644 --- a/src/codegen/extensions/lsp/protocol.py +++ b/src/codegen/extensions/lsp/protocol.py @@ -30,6 +30,7 @@ def _init_codebase(self, params: InitializeParams) -> None: config = CodebaseConfig(feature_flags=CodebaseFeatureFlags(full_range_index=True)) io = LSPIO(self.workspace) self._server.codebase = Codebase(repo_path=str(root), config=config, io=io, progress=progress) + self._server.progress_manager = progress self._server.io = io progress.finish_initialization() diff --git a/src/codegen/extensions/lsp/server.py b/src/codegen/extensions/lsp/server.py index ce49004b9..90eca9d4d 100644 --- a/src/codegen/extensions/lsp/server.py +++ b/src/codegen/extensions/lsp/server.py @@ -1,5 +1,4 @@ import logging -from collections.abc import Sequence from typing import Any, Optional from lsprotocol import types @@ -8,8 +7,9 @@ from codegen.extensions.lsp.codemods import ACTIONS from codegen.extensions.lsp.codemods.base import CodeAction -from codegen.extensions.lsp.execute import execute_action, get_execute_action +from codegen.extensions.lsp.execute import execute_action from codegen.extensions.lsp.io import LSPIO +from codegen.extensions.lsp.progress import LSPProgress from codegen.extensions.lsp.range import get_tree_sitter_range from codegen.extensions.lsp.utils import get_path from codegen.sdk.core.codebase import Codebase @@ -23,13 +23,14 @@ class CodegenLanguageServer(LanguageServer): codebase: Optional[Codebase] io: Optional[LSPIO] + progress_manager: Optional[LSPProgress] actions: dict[str, CodeAction] def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) self.actions = {action.command_name(): action for action in ACTIONS} - for action in self.actions.values(): - self.command(action.command_name())(get_execute_action(action)) + # for action in self.actions.values(): + # self.command(action.command_name())(get_execute_action(action)) def get_file(self, uri: str) -> SourceFile | File: path = get_path(uri) @@ -68,19 +69,25 @@ def get_node_for_range(self, uri: str, range: Range) -> Editable | None: return node return None - def get_actions_for_range(self, uri: str, range: Range, only: Sequence[types.CodeActionKind] | None = None) -> list[types.CodeAction]: - node = self.get_node_under_cursor(uri, range.start, range.end) + def get_actions_for_range(self, params: types.CodeActionParams) -> list[types.CodeAction]: + if params.context.only is not None: + only = [types.CodeActionKind(kind) for kind in params.context.only] + else: + only = None + node = self.get_node_under_cursor(params.text_document.uri, params.range.start) if node is None: - logger.warning(f"No node found for range {range} in {uri}") + logger.warning(f"No node found for range {params.range} in {params.text_document.uri}") return [] actions = [] - for action in self.actions.values(): + task = self.progress_manager.begin_with_token(f"Getting code actions for {params.text_document.uri}", params.work_done_token, count=len(self.actions)) + for idx, action in enumerate(self.actions.values()): + task.update(f"Checking action {action.name}", idx) if only and action.kind not in only: logger.warning(f"Skipping action {action.kind} because it is not in {only}") continue if action.is_applicable(self, node): - actions.append(action.to_lsp(uri, range)) - + actions.append(action.to_lsp(params.text_document.uri, params.range)) + task.end() return actions def resolve_action(self, action: types.CodeAction) -> types.CodeAction: From 5034b064a62cb626ff9aa2965b0ec7aa40cac507 Mon Sep 17 00:00:00 2001 From: bagel897 Date: Thu, 13 Feb 2025 10:35:59 -0800 Subject: [PATCH 3/4] add test for progress reporting --- src/codegen/extensions/lsp/progress.py | 6 +++--- tests/unit/codegen/extensions/lsp/conftest.py | 10 ++++++++++ 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/src/codegen/extensions/lsp/progress.py b/src/codegen/extensions/lsp/progress.py index 70b16f568..70eb365e5 100644 --- a/src/codegen/extensions/lsp/progress.py +++ b/src/codegen/extensions/lsp/progress.py @@ -42,17 +42,17 @@ def __init__(self, server: LanguageServer, initial_token: ProgressToken | None = if initial_token is not None: self.server.work_done_progress.begin(initial_token, types.WorkDoneProgressBegin(title="Parsing codebase...")) - def begin_with_token(self, message: str, token: ProgressToken | None = None, *, count: int | None = None) -> LSPTask | StubTask: + def begin_with_token(self, message: str, token: ProgressToken | None = None, *, count: int | None = None, create_token: bool = True) -> LSPTask | StubTask: if token is None: return StubTask() - return LSPTask(self.server, message, token, count, create_token=False) + return LSPTask(self.server, message, token, count, create_token=create_token) def begin(self, message: str, count: int | None = None) -> LSPTask | StubTask: if self.initialized: token = str(uuid.uuid4()) self.server.work_done_progress.create(token).result() return LSPTask(self.server, message, token, count, create_token=False) - return self.begin_with_token(message, self.initial_token, count=None) + return self.begin_with_token(message, self.initial_token, count=None, create_token=False) def finish_initialization(self) -> None: self.initialized = False # We can't initiate server work during syncs diff --git a/tests/unit/codegen/extensions/lsp/conftest.py b/tests/unit/codegen/extensions/lsp/conftest.py index 099329327..ef7737c1d 100644 --- a/tests/unit/codegen/extensions/lsp/conftest.py +++ b/tests/unit/codegen/extensions/lsp/conftest.py @@ -13,6 +13,16 @@ from codegen.sdk.core.codebase import Codebase +@pytest_lsp.fixture( + config=ClientServerConfig( + server_command=[sys.executable, "-m", "codegen.extensions.lsp.lsp"], + ), +) +async def lsp_client_uninitialized(lsp_client: LanguageClient): + yield lsp_client + await lsp_client.shutdown_session() + + @pytest_lsp.fixture( config=ClientServerConfig( server_command=[sys.executable, "-m", "codegen.extensions.lsp.lsp"], From 898ea096fb904901bd9f0d1fa13262febf3ad050 Mon Sep 17 00:00:00 2001 From: bagel897 Date: Thu, 13 Feb 2025 10:37:11 -0800 Subject: [PATCH 4/4] add test for progress reporting --- .../codegen/extensions/lsp/test_progress.py | 76 +++++++++++++++++++ 1 file changed, 76 insertions(+) create mode 100644 tests/unit/codegen/extensions/lsp/test_progress.py diff --git a/tests/unit/codegen/extensions/lsp/test_progress.py b/tests/unit/codegen/extensions/lsp/test_progress.py new file mode 100644 index 000000000..7ddc4d801 --- /dev/null +++ b/tests/unit/codegen/extensions/lsp/test_progress.py @@ -0,0 +1,76 @@ +import uuid + +import pytest +from lsprotocol import types +from pytest_lsp import LanguageClient, client_capabilities + +from codegen.sdk.core.codebase import Codebase +from tests.unit.codegen.extensions.lsp.utils import apply_edit + + +def check_ascending(reports: list[types.WorkDoneProgressReport]): + prev = 0 + for report in reports: + if isinstance(report, types.WorkDoneProgressEnd) or report.percentage is None: + continue + assert report.percentage > prev + prev = report.percentage + + +def check_reports(reports: list[types.WorkDoneProgressReport]): + assert isinstance(reports[0], types.WorkDoneProgressBegin) + assert isinstance(reports[-1], types.WorkDoneProgressEnd) + check_ascending(reports) + + +@pytest.mark.parametrize( + "original, expected", + [ + ( + { + "test.py": """ +def hello(): + pass +""".strip(), + }, + { + "test.py": """ +def world(): + pass +""".strip() + }, + ) + ], +) +async def test_progress(lsp_client_uninitialized: LanguageClient, codebase: Codebase, assert_expected, original: dict[str, str]): + token = str(uuid.uuid4()) + assert lsp_client_uninitialized.progress_reports.get(token, None) is None + req = await lsp_client_uninitialized.initialize_session( + types.InitializeParams( + capabilities=client_capabilities("visual-studio-code"), + root_uri="file://" + str(codebase.repo_path.resolve()), + root_path=str(codebase.repo_path.resolve()), + work_done_token=token, + ) + ) + reports = lsp_client_uninitialized.progress_reports.get(token, None) + assert reports is not None + check_reports(reports) + for file in original.keys(): + assert any(file in report.message for report in reports if isinstance(report, types.WorkDoneProgressReport)) + rename_token = str(uuid.uuid4()) + result = await lsp_client_uninitialized.text_document_rename_async( + params=types.RenameParams( + position=types.Position(line=0, character=5), + text_document=types.TextDocumentIdentifier(uri=f"file://{codebase.repo_path}/test.py"), + new_name="world", + work_done_token=rename_token, + ) + ) + if result: + apply_edit(codebase, result) + reports = lsp_client_uninitialized.progress_reports.get(rename_token, None) + assert reports is not None + check_reports(reports) + assert "Renaming" in reports[0].title + assert_expected(codebase)