diff --git a/codeflash/cli_cmds/console.py b/codeflash/cli_cmds/console.py index aa5f8e91f..ecf7d3e6d 100644 --- a/codeflash/cli_cmds/console.py +++ b/codeflash/cli_cmds/console.py @@ -80,9 +80,16 @@ def paneled_text( console.print(panel) -def code_print(code_str: str, file_name: Optional[str] = None, function_name: Optional[str] = None) -> None: +def code_print( + code_str: str, + file_name: Optional[str] = None, + function_name: Optional[str] = None, + lsp_message_id: Optional[str] = None, +) -> None: if is_LSP_enabled(): - lsp_log(LspCodeMessage(code=code_str, file_name=file_name, function_name=function_name)) + lsp_log( + LspCodeMessage(code=code_str, file_name=file_name, function_name=function_name, message_id=lsp_message_id) + ) return """Print code with syntax highlighting.""" from rich.syntax import Syntax diff --git a/codeflash/lsp/beta.py b/codeflash/lsp/beta.py index d04c387d1..e2faff2bd 100644 --- a/codeflash/lsp/beta.py +++ b/codeflash/lsp/beta.py @@ -353,7 +353,7 @@ def cleanup_optimizer(_params: any) -> dict[str, str]: @server.feature("initializeFunctionOptimization") def initialize_function_optimization(params: FunctionOptimizationInitParams) -> dict[str, str]: - with execution_context(task_id=params.task_id): + with execution_context(task_id=getattr(params, "task_id", None)): document_uri = params.textDocument.uri document = server.workspace.get_text_document(document_uri) file_path = Path(document.path) @@ -423,7 +423,7 @@ def initialize_function_optimization(params: FunctionOptimizationInitParams) -> @server.feature("performFunctionOptimization") async def perform_function_optimization(params: FunctionOptimizationParams) -> dict[str, str]: - with execution_context(task_id=params.task_id): + with execution_context(task_id=getattr(params, "task_id", None)): loop = asyncio.get_running_loop() cancel_event = threading.Event() diff --git a/codeflash/lsp/lsp_logger.py b/codeflash/lsp/lsp_logger.py index 50491c54c..14af72ea3 100644 --- a/codeflash/lsp/lsp_logger.py +++ b/codeflash/lsp/lsp_logger.py @@ -3,13 +3,15 @@ import logging import sys from dataclasses import dataclass -from typing import Any, Callable +from typing import Any, Callable, Optional from codeflash.lsp.helpers import is_LSP_enabled -from codeflash.lsp.lsp_message import LspTextMessage, message_delimiter +from codeflash.lsp.lsp_message import LSPMessageId, LspTextMessage, message_delimiter root_logger = None +message_id_prefix = "id:" + @dataclass class LspMessageTags: @@ -18,6 +20,7 @@ class LspMessageTags: lsp: bool = False # lsp (lsp only) force_lsp: bool = False # force_lsp (you can use this to force a message to be sent to the LSP even if the level is not supported) loading: bool = False # loading (you can use this to indicate that the message is a loading message) + message_id: Optional[LSPMessageId] = None # example: id:best_candidate highlight: bool = False # highlight (you can use this to highlight the message by wrapping it in ``) h1: bool = False # h1 h2: bool = False # h2 @@ -52,24 +55,27 @@ def extract_tags(msg: str) -> tuple[LspMessageTags, str]: tags = {tag.strip() for tag in tags_str.split(",")} message_tags = LspMessageTags() # manually check and set to avoid repeated membership tests - if "lsp" in tags: - message_tags.lsp = True - if "!lsp" in tags: - message_tags.not_lsp = True - if "force_lsp" in tags: - message_tags.force_lsp = True - if "loading" in tags: - message_tags.loading = True - if "highlight" in tags: - message_tags.highlight = True - if "h1" in tags: - message_tags.h1 = True - if "h2" in tags: - message_tags.h2 = True - if "h3" in tags: - message_tags.h3 = True - if "h4" in tags: - message_tags.h4 = True + for tag in tags: + if tag.startswith(message_id_prefix): + message_tags.message_id = LSPMessageId(tag[len(message_id_prefix) :]).value + elif tag == "lsp": + message_tags.lsp = True + elif tag == "!lsp": + message_tags.not_lsp = True + elif tag == "force_lsp": + message_tags.force_lsp = True + elif tag == "loading": + message_tags.loading = True + elif tag == "highlight": + message_tags.highlight = True + elif tag == "h1": + message_tags.h1 = True + elif tag == "h2": + message_tags.h2 = True + elif tag == "h3": + message_tags.h3 = True + elif tag == "h4": + message_tags.h4 = True return message_tags, content return LspMessageTags(), msg @@ -118,7 +124,7 @@ def enhanced_log( if is_normal_text_message: clean_msg = add_heading_tags(clean_msg, tags) clean_msg = add_highlight_tags(clean_msg, tags) - clean_msg = LspTextMessage(text=clean_msg, takes_time=tags.loading).serialize() + clean_msg = LspTextMessage(text=clean_msg, takes_time=tags.loading, message_id=tags.message_id).serialize() actual_log_fn(clean_msg, *args, **kwargs) diff --git a/codeflash/lsp/lsp_message.py b/codeflash/lsp/lsp_message.py index 492535735..d9c6e13bd 100644 --- a/codeflash/lsp/lsp_message.py +++ b/codeflash/lsp/lsp_message.py @@ -1,5 +1,6 @@ from __future__ import annotations +import enum import json from dataclasses import asdict, dataclass from pathlib import Path @@ -14,10 +15,17 @@ message_delimiter = "\u241f" +# allow the client to know which message it is receiving +class LSPMessageId(enum.Enum): + BEST_CANDIDATE = "best_candidate" + CANDIDATE = "candidate" + + @dataclass class LspMessage: # to show a loading indicator if the operation is taking time like generating candidates or tests takes_time: bool = False + message_id: Optional[str] = None def _loop_through(self, obj: Any) -> Any: # noqa: ANN401 if isinstance(obj, list): diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 54dfbd41f..093ff0966 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -66,7 +66,7 @@ from codeflash.discovery.functions_to_optimize import was_function_previously_optimized from codeflash.either import Failure, Success, is_successful from codeflash.lsp.helpers import is_LSP_enabled, report_to_markdown_table, tree_to_markdown -from codeflash.lsp.lsp_message import LspCodeMessage, LspMarkdownMessage +from codeflash.lsp.lsp_message import LspCodeMessage, LspMarkdownMessage, LSPMessageId from codeflash.models.ExperimentMetadata import ExperimentMetadata from codeflash.models.models import ( BestOptimization, @@ -510,7 +510,11 @@ def determine_best_candidate( get_run_tmp_file(Path(f"test_return_values_{candidate_index}.bin")).unlink(missing_ok=True) get_run_tmp_file(Path(f"test_return_values_{candidate_index}.sqlite")).unlink(missing_ok=True) logger.info(f"h3|Optimization candidate {candidate_index}/{processor.candidate_len}:") - code_print(candidate.source_code.flat, file_name=f"candidate_{candidate_index}.py") + code_print( + candidate.source_code.flat, + file_name=f"candidate_{candidate_index}.py", + lsp_message_id=LSPMessageId.CANDIDATE.value, + ) # map ast normalized code to diff len, unnormalized code # map opt id to the shortest unnormalized code try: @@ -1291,6 +1295,7 @@ def find_and_process_best_optimization( best_optimization.candidate.source_code.flat, file_name="best_candidate.py", function_name=self.function_to_optimize.function_name, + lsp_message_id=LSPMessageId.BEST_CANDIDATE.value, ) processed_benchmark_info = None if self.args.benchmark: