From c75bbf65e2f9785c7aaed11ee2540805e1adc6f5 Mon Sep 17 00:00:00 2001 From: mohammed Date: Tue, 3 Jun 2025 11:56:09 +0300 Subject: [PATCH 01/16] check large diffs with black, and skipp formatting in such case (after optimizing) --- code_to_optimize/few_formatting_errors.py | 47 +++++++ code_to_optimize/many_formatting_errors.py | 147 +++++++++++++++++++++ codeflash/code_utils/formatter.py | 39 +++++- tests/test_formatter.py | 69 ++++++++++ 4 files changed, 300 insertions(+), 2 deletions(-) create mode 100644 code_to_optimize/few_formatting_errors.py create mode 100644 code_to_optimize/many_formatting_errors.py diff --git a/code_to_optimize/few_formatting_errors.py b/code_to_optimize/few_formatting_errors.py new file mode 100644 index 000000000..905be2b39 --- /dev/null +++ b/code_to_optimize/few_formatting_errors.py @@ -0,0 +1,47 @@ +import os + +class BadlyFormattedClass(object): + def __init__( + self, + name, + age= None, + email= None, + phone=None, + address=None, + city=None, + state=None, + zip_code=None, + ): + self.name = name + self.age = age + self.email = email + self.phone = phone + self. address = address + self.city = city + self.state = state + self.zip_code = zip_code + self.data = {"name": name, "age": age, "email": email} + + def get_info(self): + return f"Name: {self.name}, Age: {self.age}" + + def update_data(self, **kwargs): + for key, value in kwargs.items(): + if hasattr(self, key): + setattr(self, key, value) + self.data.update(kwargs) + + +def process_data( + data_list, filter_func=None, transform_func=None, sort_key=None, reverse=False +): + if not data_list: + return [] + if filter_func: + data_list = [ item for item in data_list if filter_func(item)] + if transform_func: + data_list = [transform_func(item) for item in data_list] + if sort_key: + data_list = sorted(data_list, key=sort_key, reverse=reverse) + return data_list + diff --git a/code_to_optimize/many_formatting_errors.py b/code_to_optimize/many_formatting_errors.py new file mode 100644 index 000000000..bd792e3d3 --- /dev/null +++ b/code_to_optimize/many_formatting_errors.py @@ -0,0 +1,147 @@ +import os,sys,json,datetime,math,random;import requests;from collections import defaultdict,OrderedDict +from typing import List,Dict,Optional,Union,Tuple,Any;import numpy as np;import pandas as pd + +# This is a poorly formatted Python file with many style violations + +class BadlyFormattedClass( object ): + def __init__(self,name,age=None,email=None,phone=None,address=None,city=None,state=None,zip_code=None): + self.name=name;self.age=age;self.email=email;self.phone=phone + self.address=address;self.city=city;self.state=state;self.zip_code=zip_code + self.data={"name":name,"age":age,"email":email} + + def get_info(self ): + return f"Name: {self.name}, Age: {self.age}" + + def update_data(self,**kwargs): + for key,value in kwargs.items(): + if hasattr(self,key):setattr(self,key,value) + self.data.update(kwargs) + +def process_data(data_list,filter_func=None,transform_func=None,sort_key=None,reverse=False): + if not data_list:return[] + if filter_func:data_list=[item for item in data_list if filter_func(item)] + if transform_func:data_list=[transform_func(item)for item in data_list] + if sort_key:data_list=sorted(data_list,key=sort_key,reverse=reverse) + return data_list + +def calculate_statistics(numbers): + if not numbers:return None + mean=sum(numbers)/len(numbers); median=sorted(numbers)[len(numbers)//2] + variance=sum((x-mean)**2 for x in numbers)/len(numbers);std_dev=math.sqrt(variance) + return {"mean":mean,"median":median,"variance":variance,"std_dev":std_dev,"min":min(numbers),"max":max(numbers)} + +def complex_nested_function(x,y,z): + def inner_function_1(a,b): + def deeply_nested(c,d): + return c*d+a*b + return deeply_nested(a+1,b-1)+deeply_nested(a-1,b+1) + def inner_function_2 (a,b,c): + result=[] + for i in range(a): + for j in range(b): + for k in range(c): + if i*j*k>0:result.append(i*j*k) + elif i+j+k==0:result.append(-1) + else :result.append(0) + return result + return inner_function_1(x,y)+sum(inner_function_2(x,y,z)) + +# Long lines and poor dictionary formatting +user_data={"users":[{"id":1,"name":"John Doe","email":"john@example.com","preferences":{"theme":"dark","notifications":True,"language":"en"},"metadata":{"created_at":"2023-01-01","last_login":"2024-01-01","login_count":150}},{"id":2,"name":"Jane Smith","email":"jane@example.com","preferences":{"theme":"light","notifications":False,"language":"es"},"metadata":{"created_at":"2023-02-15","last_login":"2024-01-15","login_count":89}}]} + +# Poor list formatting and string concatenation +long_list_of_items=['item_1','item_2','item_3','item_4','item_5','item_6','item_7','item_8','item_9','item_10','item_11','item_12','item_13','item_14','item_15','item_16','item_17','item_18','item_19','item_20'] + +def generate_report(data,include_stats=True,include_charts=False,format_type='json',output_file=None): + if not data:raise ValueError("Data cannot be empty") + report={'timestamp':datetime.datetime.now().isoformat(),'data_count':len(data),'summary':{}} + + # Bad formatting in loops and conditionals + for i,item in enumerate(data): + if isinstance(item,dict): + for key,value in item.items(): + if key not in report['summary']:report['summary'][key]=[] + report['summary'][key].append(value) + elif isinstance(item,(int,float)): + if 'numbers' not in report['summary']:report['summary']['numbers']=[] + report['summary']['numbers'].append(item) + else: + if 'other' not in report['summary']:report['summary']['other']=[] + report['summary']['other'].append(str(item)) + + if include_stats and 'numbers' in report['summary']: + numbers=report['summary']['numbers'] + report['statistics']=calculate_statistics(numbers) + + # Long conditional chain with poor formatting + if format_type=='json':result=json.dumps(report,indent=None,separators=(',',':')) + elif format_type=='pretty_json':result=json.dumps(report,indent=2) + elif format_type=='string':result=str(report) + else:result=report + + if output_file: + with open(output_file,'w')as f:f.write(result if isinstance(result,str)else json.dumps(result)) + + return result + +class DataProcessor ( BadlyFormattedClass ) : + def __init__(self,data_source,config=None,debug=False): + super().__init__("DataProcessor") + self.data_source=data_source;self.config=config or{};self.debug=debug + self.processed_data=[];self.errors=[];self.warnings=[] + + def load_data ( self ) : + try: + if isinstance(self.data_source,str): + if self.data_source.endswith('.json'): + with open(self.data_source,'r')as f:data=json.load(f) + elif self.data_source.endswith('.csv'):data=pd.read_csv(self.data_source).to_dict('records') + else:raise ValueError(f"Unsupported file type: {self.data_source}") + elif isinstance(self.data_source,list):data=self.data_source + else:data=[self.data_source] + return data + except Exception as e: + self.errors.append(str(e));return[] + + def validate_data(self,data): + valid_items=[];invalid_items=[] + for item in data: + if isinstance(item,dict)and'id'in item and'name'in item:valid_items.append(item) + else:invalid_items.append(item) + if invalid_items:self.warnings.append(f"Found {len(invalid_items)} invalid items") + return valid_items + + def process(self): + data=self.load_data() + if not data:return{"success":False,"error":"No data loaded"} + + validated_data=self.validate_data(data) + processed_result=process_data(validated_data, + filter_func=lambda x:x.get('active',True), + transform_func=lambda x:{**x,'processed_at':datetime.datetime.now().isoformat()}, + sort_key=lambda x:x.get('name','')) + + self.processed_data=processed_result + return{"success":True,"count":len(processed_result),"data":processed_result} +if __name__=="__main__": + sample_data=[{"id":1,"name":"Alice","active":True},{"id":2,"name":"Bob","active":False},{"id":3,"name":"Charlie","active":True}] + + processor=DataProcessor(sample_data,config={"debug":True}) + result=processor.process() + + if result["success"]: + print(f"Successfully processed {result['count']} items") + for item in result["data"][:3]:print(f"- {item['name']} (ID: {item['id']})") + else:print(f"Processing failed: {result.get('error','Unknown error')}") + + # Generate report with poor formatting + report=generate_report(sample_data,include_stats=True,format_type='pretty_json') + print("Generated report:",report[:100]+"..."if len(report)>100 else report) + + # Complex calculation with poor spacing + numbers=[random.randint(1,100)for _ in range(50)] + stats=calculate_statistics(numbers) + complex_result=complex_nested_function(5,3,2) + + print(f"Statistics: mean={stats['mean']:.2f}, std_dev={stats['std_dev']:.2f}") + print(f"Complex calculation result: {complex_result}") \ No newline at end of file diff --git a/codeflash/code_utils/formatter.py b/codeflash/code_utils/formatter.py index 927a4d4cb..0b673ae28 100644 --- a/codeflash/code_utils/formatter.py +++ b/codeflash/code_utils/formatter.py @@ -13,14 +13,49 @@ from pathlib import Path +def should_format_file(filepath, max_lines_changed=50): + try: + # check if black is installed + subprocess.run(['black', '--version'], check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) + + result = subprocess.run( + ['black', '--diff', filepath], + capture_output=True, + text=True + ) + + if result.returncode == 0 and not result.stdout: + return False + + diff_lines = [line for line in result.stdout.split('\n') + if line.startswith(('+', '-')) and not line.startswith(('+++', '---'))] + + changes_count = len(diff_lines) + + if changes_count > max_lines_changed: + logger.debug(f"Skipping {filepath}: {changes_count} lines would change (max: {max_lines_changed})") + return False + + return True + + except subprocess.CalledProcessError: + logger.warning(f"black command failed for {filepath}") + return False + except FileNotFoundError: + logger.warning("black is not installed. Skipping formatting check.") + return False + + + def format_code(formatter_cmds: list[str], path: Path, print_status: bool = True) -> str: # noqa # TODO: Only allow a particular whitelist of formatters here to prevent arbitrary code execution formatter_name = formatter_cmds[0].lower() if not path.exists(): msg = f"File {path} does not exist. Cannot format the file." raise FileNotFoundError(msg) - if formatter_name == "disabled": + if formatter_name == "disabled" or not should_format_file(path): return path.read_text(encoding="utf8") + file_token = "$file" # noqa: S105 for command in formatter_cmds: formatter_cmd_list = shlex.split(command, posix=os.name != "nt") @@ -29,7 +64,7 @@ def format_code(formatter_cmds: list[str], path: Path, print_status: bool = True result = subprocess.run(formatter_cmd_list, capture_output=True, check=False) if result.returncode == 0: if print_status: - console.rule(f"Formatted Successfully with: {formatter_name.replace('$file', path.name)}") + console.rule(f"Formatted Successfully with: {command.replace('$file', path.name)}") else: logger.error(f"Failed to format code with {' '.join(formatter_cmd_list)}") except FileNotFoundError as e: diff --git a/tests/test_formatter.py b/tests/test_formatter.py index 5c0a91c38..14f6789e1 100644 --- a/tests/test_formatter.py +++ b/tests/test_formatter.py @@ -1,12 +1,17 @@ +import argparse import os import tempfile from pathlib import Path import pytest +import shutil from codeflash.code_utils.config_parser import parse_config_file from codeflash.code_utils.formatter import format_code, sort_imports +from codeflash.discovery.functions_to_optimize import FunctionToOptimize +from codeflash.optimization.function_optimizer import FunctionOptimizer +from codeflash.verification.verification_utils import TestConfig def test_remove_duplicate_imports(): """Test that duplicate imports are removed when should_sort_imports is True.""" @@ -209,3 +214,67 @@ def foo(): tmp_path = tmp.name with pytest.raises(FileNotFoundError): format_code(formatter_cmds=["exit 1"], path=Path(tmp_path)) + + +def _run_formatting_test(source_filename: str, should_content_change: bool): + """Helper function to run formatting tests with common setup and teardown.""" + with tempfile.TemporaryDirectory() as test_dir_str: + test_dir = Path(test_dir_str) + this_file = Path(__file__).resolve() + repo_root_dir = this_file.parent.parent + source_file = repo_root_dir / "code_to_optimize" / source_filename + + original = source_file.read_text() + target_path = test_dir / "target.py" + + shutil.copy2(source_file, target_path) + + function_to_optimize = FunctionToOptimize( + function_name="process_data", + parents=[], + file_path=target_path + ) + + test_cfg = TestConfig( + tests_root=test_dir, + project_root_path=test_dir, + test_framework="pytest", + tests_project_rootdir=test_dir, + ) + + args = argparse.Namespace( + disable_imports_sorting=False, + formatter_cmds=[ + "ruff check --exit-zero --fix $file", + "ruff format $file" + ], + ) + + optimizer = FunctionOptimizer( + function_to_optimize=function_to_optimize, + test_cfg=test_cfg, + args=args, + ) + + optimizer.reformat_code_and_helpers( + helper_functions=[], + path=target_path, + original_code=optimizer.function_to_optimize_source_code, + ) + + content = target_path.read_text() + + if should_content_change: + assert content != original, f"Expected content to change for {source_filename}" + else: + assert content == original, f"Expected content to remain unchanged for {source_filename}" + + +def test_formatting_file_with_many_diffs(): + """Test that files with many formatting errors are skipped (content unchanged).""" + _run_formatting_test("many_formatting_errors.py", should_content_change=False) + + +def test_formatting_file_with_few_diffs(): + """Test that files with few formatting errors are formatted (content changed).""" + _run_formatting_test("few_formatting_errors.py", should_content_change=True) \ No newline at end of file From 5cd13ad1caeb98fcc3b0c39f69b98d9abb84ed8f Mon Sep 17 00:00:00 2001 From: mohammed Date: Tue, 3 Jun 2025 11:58:55 +0300 Subject: [PATCH 02/16] new line --- tests/test_formatter.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/test_formatter.py b/tests/test_formatter.py index 14f6789e1..3f45460eb 100644 --- a/tests/test_formatter.py +++ b/tests/test_formatter.py @@ -11,6 +11,7 @@ from codeflash.discovery.functions_to_optimize import FunctionToOptimize from codeflash.optimization.function_optimizer import FunctionOptimizer +from codeflash.optimization.function_optimizer import FunctionSource from codeflash.verification.verification_utils import TestConfig def test_remove_duplicate_imports(): @@ -257,7 +258,9 @@ def _run_formatting_test(source_filename: str, should_content_change: bool): ) optimizer.reformat_code_and_helpers( - helper_functions=[], + helper_functions=[ + FunctionSource() + ], path=target_path, original_code=optimizer.function_to_optimize_source_code, ) @@ -277,4 +280,4 @@ def test_formatting_file_with_many_diffs(): def test_formatting_file_with_few_diffs(): """Test that files with few formatting errors are formatted (content changed).""" - _run_formatting_test("few_formatting_errors.py", should_content_change=True) \ No newline at end of file + _run_formatting_test("few_formatting_errors.py", should_content_change=True) From 152222726c19b5abb28d983180334f0708bc5476 Mon Sep 17 00:00:00 2001 From: mohammed Date: Tue, 3 Jun 2025 12:02:53 +0300 Subject: [PATCH 03/16] better log messages --- codeflash/code_utils/formatter.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/codeflash/code_utils/formatter.py b/codeflash/code_utils/formatter.py index 0b673ae28..3d6eff6cd 100644 --- a/codeflash/code_utils/formatter.py +++ b/codeflash/code_utils/formatter.py @@ -39,10 +39,10 @@ def should_format_file(filepath, max_lines_changed=50): return True except subprocess.CalledProcessError: - logger.warning(f"black command failed for {filepath}") + logger.warning(f"black --diff command failed for {filepath}") return False except FileNotFoundError: - logger.warning("black is not installed. Skipping formatting check.") + logger.warning("black formatter is not installed. Skipping formatting diff check.") return False From d3ca1cbf94e464d0cbecd0c234cb20885b7bf517 Mon Sep 17 00:00:00 2001 From: mohammed Date: Tue, 3 Jun 2025 12:10:43 +0300 Subject: [PATCH 04/16] remove unnecessary check --- codeflash/code_utils/formatter.py | 3 --- tests/test_formatter.py | 5 +---- 2 files changed, 1 insertion(+), 7 deletions(-) diff --git a/codeflash/code_utils/formatter.py b/codeflash/code_utils/formatter.py index 3d6eff6cd..f301bd013 100644 --- a/codeflash/code_utils/formatter.py +++ b/codeflash/code_utils/formatter.py @@ -23,9 +23,6 @@ def should_format_file(filepath, max_lines_changed=50): capture_output=True, text=True ) - - if result.returncode == 0 and not result.stdout: - return False diff_lines = [line for line in result.stdout.split('\n') if line.startswith(('+', '-')) and not line.startswith(('+++', '---'))] diff --git a/tests/test_formatter.py b/tests/test_formatter.py index 3f45460eb..7b0a43b42 100644 --- a/tests/test_formatter.py +++ b/tests/test_formatter.py @@ -11,7 +11,6 @@ from codeflash.discovery.functions_to_optimize import FunctionToOptimize from codeflash.optimization.function_optimizer import FunctionOptimizer -from codeflash.optimization.function_optimizer import FunctionSource from codeflash.verification.verification_utils import TestConfig def test_remove_duplicate_imports(): @@ -258,9 +257,7 @@ def _run_formatting_test(source_filename: str, should_content_change: bool): ) optimizer.reformat_code_and_helpers( - helper_functions=[ - FunctionSource() - ], + helper_functions=[], path=target_path, original_code=optimizer.function_to_optimize_source_code, ) From dcb084ad12df7e01b82593e3a5f47a8b15a534e3 Mon Sep 17 00:00:00 2001 From: mohammed Date: Tue, 3 Jun 2025 12:14:59 +0300 Subject: [PATCH 05/16] new line --- code_to_optimize/many_formatting_errors.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/code_to_optimize/many_formatting_errors.py b/code_to_optimize/many_formatting_errors.py index bd792e3d3..702539f70 100644 --- a/code_to_optimize/many_formatting_errors.py +++ b/code_to_optimize/many_formatting_errors.py @@ -144,4 +144,4 @@ def process(self): complex_result=complex_nested_function(5,3,2) print(f"Statistics: mean={stats['mean']:.2f}, std_dev={stats['std_dev']:.2f}") - print(f"Complex calculation result: {complex_result}") \ No newline at end of file + print(f"Complex calculation result: {complex_result}") From 689a2d97af6e617407f1075da5e85ec9d67b8097 Mon Sep 17 00:00:00 2001 From: mohammed Date: Tue, 3 Jun 2025 13:00:22 +0300 Subject: [PATCH 06/16] remove unused comment --- tests/test_formatter.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_formatter.py b/tests/test_formatter.py index 7b0a43b42..3106ee330 100644 --- a/tests/test_formatter.py +++ b/tests/test_formatter.py @@ -217,7 +217,6 @@ def foo(): def _run_formatting_test(source_filename: str, should_content_change: bool): - """Helper function to run formatting tests with common setup and teardown.""" with tempfile.TemporaryDirectory() as test_dir_str: test_dir = Path(test_dir_str) this_file = Path(__file__).resolve() From 44c0f85b6f7c1b4b047528426e6157fae852e681 Mon Sep 17 00:00:00 2001 From: mohammed Date: Tue, 3 Jun 2025 13:55:10 +0300 Subject: [PATCH 07/16] the max lines for formatting changes to 100 --- codeflash/code_utils/formatter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/codeflash/code_utils/formatter.py b/codeflash/code_utils/formatter.py index f301bd013..13b330746 100644 --- a/codeflash/code_utils/formatter.py +++ b/codeflash/code_utils/formatter.py @@ -13,7 +13,7 @@ from pathlib import Path -def should_format_file(filepath, max_lines_changed=50): +def should_format_file(filepath, max_lines_changed=100): try: # check if black is installed subprocess.run(['black', '--version'], check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) From 73ef51821ac8c3ec5daafe4234a5bf1d518f30f2 Mon Sep 17 00:00:00 2001 From: mohammed Date: Tue, 3 Jun 2025 19:39:15 +0300 Subject: [PATCH 08/16] refactoring --- code_to_optimize/few_formatting_errors.py | 2 +- code_to_optimize/many_formatting_errors.py | 4 +- codeflash/code_utils/formatter.py | 89 ++++++++++++++-------- tests/test_formatter.py | 7 ++ 4 files changed, 66 insertions(+), 36 deletions(-) diff --git a/code_to_optimize/few_formatting_errors.py b/code_to_optimize/few_formatting_errors.py index 905be2b39..27ed71b44 100644 --- a/code_to_optimize/few_formatting_errors.py +++ b/code_to_optimize/few_formatting_errors.py @@ -1,6 +1,6 @@ import os -class BadlyFormattedClass(object): +class UnformattedExampleClass(object): def __init__( self, name, diff --git a/code_to_optimize/many_formatting_errors.py b/code_to_optimize/many_formatting_errors.py index 702539f70..79cfc825d 100644 --- a/code_to_optimize/many_formatting_errors.py +++ b/code_to_optimize/many_formatting_errors.py @@ -3,7 +3,7 @@ # This is a poorly formatted Python file with many style violations -class BadlyFormattedClass( object ): +class UnformattedExampleClass( object ): def __init__(self,name,age=None,email=None,phone=None,address=None,city=None,state=None,zip_code=None): self.name=name;self.age=age;self.email=email;self.phone=phone self.address=address;self.city=city;self.state=state;self.zip_code=zip_code @@ -84,7 +84,7 @@ def generate_report(data,include_stats=True,include_charts=False,format_type='js return result -class DataProcessor ( BadlyFormattedClass ) : +class DataProcessor ( UnformattedExampleClass ) : def __init__(self,data_source,config=None,debug=False): super().__init__("DataProcessor") self.data_source=data_source;self.config=config or{};self.debug=debug diff --git a/codeflash/code_utils/formatter.py b/codeflash/code_utils/formatter.py index 13b330746..94b5c7dc5 100644 --- a/codeflash/code_utils/formatter.py +++ b/codeflash/code_utils/formatter.py @@ -3,7 +3,7 @@ import os import shlex import subprocess -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional import isort @@ -12,37 +12,60 @@ if TYPE_CHECKING: from pathlib import Path - -def should_format_file(filepath, max_lines_changed=100): - try: - # check if black is installed - subprocess.run(['black', '--version'], check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) - - result = subprocess.run( - ['black', '--diff', filepath], - capture_output=True, - text=True - ) - - diff_lines = [line for line in result.stdout.split('\n') - if line.startswith(('+', '-')) and not line.startswith(('+++', '---'))] - - changes_count = len(diff_lines) - - if changes_count > max_lines_changed: - logger.debug(f"Skipping {filepath}: {changes_count} lines would change (max: {max_lines_changed})") - return False - - return True - - except subprocess.CalledProcessError: - logger.warning(f"black --diff command failed for {filepath}") - return False - except FileNotFoundError: - logger.warning("black formatter is not installed. Skipping formatting diff check.") - return False - - +def get_diff_lines_output_by_black(filepath: str) -> Optional[str]: + try: + subprocess.run(['black', '--version'], check=True, + stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) + result = subprocess.run( + ['black', '--diff', filepath], + capture_output=True, + text=True + ) + return result.stdout.strip() if result.stdout else None + except (FileNotFoundError): + return None + + +def get_diff_lines_output_by_ruff(filepath: str) -> Optional[str]: + try: + subprocess.run(['ruff', '--version'], check=True, + stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) + result = subprocess.run( + ['ruff', "format", '--diff', filepath], + capture_output=True, + text=True + ) + return result.stdout.strip() if result.stdout else None + except (FileNotFoundError): + return None + + +def get_diff_lines_count(diff_output: str) -> int: + diff_lines = [line for line in diff_output.split('\n') + if line.startswith(('+', '-')) and not line.startswith(('+++', '---'))] + return len(diff_lines) + +def is_safe_to_format(filepath: str, max_diff_lines: int = 100) -> bool: + diff_changes_stdout = None + + diff_changes_stdout = get_diff_lines_output_by_black(filepath) + + if diff_changes_stdout is None: + logger.warning(f"black formatter not found, trying ruff instead...") + diff_changes_stdout = get_diff_lines_output_by_ruff(filepath) + if diff_changes_stdout is None: + msg = f"Both ruff, black formatters not found, skipping formatting diff check." + logger.warning(msg) + raise FileNotFoundError(msg) + + diff_lines_count = get_diff_lines_count(diff_changes_stdout) + + if diff_lines_count > max_diff_lines: + logger.debug(f"Skipping {filepath}: {diff_lines_count} lines would change (max: {max_diff_lines})") + return False + else: + return True + def format_code(formatter_cmds: list[str], path: Path, print_status: bool = True) -> str: # noqa # TODO: Only allow a particular whitelist of formatters here to prevent arbitrary code execution @@ -50,7 +73,7 @@ def format_code(formatter_cmds: list[str], path: Path, print_status: bool = True if not path.exists(): msg = f"File {path} does not exist. Cannot format the file." raise FileNotFoundError(msg) - if formatter_name == "disabled" or not should_format_file(path): + if formatter_name == "disabled" or not is_safe_to_format(path): # few -> False, large -> True return path.read_text(encoding="utf8") file_token = "$file" # noqa: S105 diff --git a/tests/test_formatter.py b/tests/test_formatter.py index 3106ee330..ed2d7233a 100644 --- a/tests/test_formatter.py +++ b/tests/test_formatter.py @@ -268,12 +268,19 @@ def _run_formatting_test(source_filename: str, should_content_change: bool): else: assert content == original, f"Expected content to remain unchanged for {source_filename}" +def _ruff_or_black_installed() -> bool: + return shutil.which("black") is not None or shutil.which("ruff") is not None + def test_formatting_file_with_many_diffs(): """Test that files with many formatting errors are skipped (content unchanged).""" + if not _ruff_or_black_installed(): + pytest.skip("Neither black nor ruff is installed, skipping formatting tests.") _run_formatting_test("many_formatting_errors.py", should_content_change=False) def test_formatting_file_with_few_diffs(): """Test that files with few formatting errors are formatted (content changed).""" + if not _ruff_or_black_installed(): + pytest.skip("Neither black nor ruff is installed, skipping formatting tests.") _run_formatting_test("few_formatting_errors.py", should_content_change=True) From a5343fd9454eebf471fc893ad587a33a2e75b705 Mon Sep 17 00:00:00 2001 From: mohammed Date: Tue, 3 Jun 2025 23:37:18 +0300 Subject: [PATCH 09/16] refactoring and improvements --- codeflash/code_utils/formatter.py | 61 ++++++++++++++++++------------- tests/test_formatter.py | 3 +- 2 files changed, 37 insertions(+), 27 deletions(-) diff --git a/codeflash/code_utils/formatter.py b/codeflash/code_utils/formatter.py index 94b5c7dc5..3d5b587c6 100644 --- a/codeflash/code_utils/formatter.py +++ b/codeflash/code_utils/formatter.py @@ -4,7 +4,6 @@ import shlex import subprocess from typing import TYPE_CHECKING, Optional - import isort from codeflash.cli_cmds.console import console, logger @@ -12,37 +11,48 @@ if TYPE_CHECKING: from pathlib import Path -def get_diff_lines_output_by_black(filepath: str) -> Optional[str]: +def get_nth_line(text: str, n: int) -> str | None: + for i, line in enumerate(text.splitlines(), start=1): + if i == n: + return line + return None + +def get_diff_output(cmd: list[str]) -> Optional[str]: try: - subprocess.run(['black', '--version'], check=True, - stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) - result = subprocess.run( - ['black', '--diff', filepath], - capture_output=True, - text=True - ) - return result.stdout.strip() if result.stdout else None - except (FileNotFoundError): + result = subprocess.run(cmd, capture_output=True, text=True, check=True) + return result.stdout.strip() or None + except (FileNotFoundError, subprocess.CalledProcessError) as e: + if isinstance(e, subprocess.CalledProcessError): + # ruff returns 1 when the file needs formatting, and 0 when it is already formatted + is_ruff = cmd[0] == "ruff" + if e.returncode == 0 and is_ruff: + return "" + elif e.returncode == 1 and is_ruff: + return e.stdout.strip() or None return None +def get_diff_lines_output_by_black(filepath: str) -> Optional[str]: + try: + import black # type: ignore + return get_diff_output(['black', '--diff', filepath]) + except ImportError: + return None + def get_diff_lines_output_by_ruff(filepath: str) -> Optional[str]: try: - subprocess.run(['ruff', '--version'], check=True, - stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) - result = subprocess.run( - ['ruff', "format", '--diff', filepath], - capture_output=True, - text=True - ) - return result.stdout.strip() if result.stdout else None - except (FileNotFoundError): + import ruff # type: ignore + return get_diff_output(['ruff', 'format', '--diff', filepath]) + except ImportError: + print("can't import ruff") return None def get_diff_lines_count(diff_output: str) -> int: - diff_lines = [line for line in diff_output.split('\n') - if line.startswith(('+', '-')) and not line.startswith(('+++', '---'))] + lines = diff_output.split('\n') + def is_diff_line(line: str) -> bool: + return line.startswith(('+', '-')) and not line.startswith(('+++', '---')) + diff_lines = [line for line in lines if is_diff_line(line)] return len(diff_lines) def is_safe_to_format(filepath: str, max_diff_lines: int = 100) -> bool: @@ -54,9 +64,8 @@ def is_safe_to_format(filepath: str, max_diff_lines: int = 100) -> bool: logger.warning(f"black formatter not found, trying ruff instead...") diff_changes_stdout = get_diff_lines_output_by_ruff(filepath) if diff_changes_stdout is None: - msg = f"Both ruff, black formatters not found, skipping formatting diff check." - logger.warning(msg) - raise FileNotFoundError(msg) + logger.warning(f"Both ruff, black formatters not found, skipping formatting diff check.") + return False diff_lines_count = get_diff_lines_count(diff_changes_stdout) @@ -73,7 +82,7 @@ def format_code(formatter_cmds: list[str], path: Path, print_status: bool = True if not path.exists(): msg = f"File {path} does not exist. Cannot format the file." raise FileNotFoundError(msg) - if formatter_name == "disabled" or not is_safe_to_format(path): # few -> False, large -> True + if formatter_name == "disabled" or not is_safe_to_format(str(path)): return path.read_text(encoding="utf8") file_token = "$file" # noqa: S105 diff --git a/tests/test_formatter.py b/tests/test_formatter.py index ed2d7233a..c2e7864e6 100644 --- a/tests/test_formatter.py +++ b/tests/test_formatter.py @@ -217,6 +217,8 @@ def foo(): def _run_formatting_test(source_filename: str, should_content_change: bool): + if shutil.which("ruff") is None: + pytest.skip("ruff is not installed, skipping.") with tempfile.TemporaryDirectory() as test_dir_str: test_dir = Path(test_dir_str) this_file = Path(__file__).resolve() @@ -262,7 +264,6 @@ def _run_formatting_test(source_filename: str, should_content_change: bool): ) content = target_path.read_text() - if should_content_change: assert content != original, f"Expected content to change for {source_filename}" else: From 395855d5c214c963d0c4784ccb3a42926074b6df Mon Sep 17 00:00:00 2001 From: mohammed Date: Tue, 3 Jun 2025 23:50:44 +0300 Subject: [PATCH 10/16] added black as dev dependency --- poetry.lock | 68 +++++++++++++++++++++++++++++++++++++++-- pyproject.toml | 1 + tests/test_formatter.py | 13 +++----- 3 files changed, 70 insertions(+), 12 deletions(-) diff --git a/poetry.lock b/poetry.lock index 04cfeae09..b80c86387 100644 --- a/poetry.lock +++ b/poetry.lock @@ -73,6 +73,53 @@ files = [ {file = "backoff-1.11.1.tar.gz", hash = "sha256:ccb962a2378418c667b3c979b504fdeb7d9e0d29c0579e3b13b86467177728cb"}, ] +[[package]] +name = "black" +version = "25.1.0" +description = "The uncompromising code formatter." +optional = false +python-versions = ">=3.9" +groups = ["dev"] +files = [ + {file = "black-25.1.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:759e7ec1e050a15f89b770cefbf91ebee8917aac5c20483bc2d80a6c3a04df32"}, + {file = "black-25.1.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:0e519ecf93120f34243e6b0054db49c00a35f84f195d5bce7e9f5cfc578fc2da"}, + {file = "black-25.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:055e59b198df7ac0b7efca5ad7ff2516bca343276c466be72eb04a3bcc1f82d7"}, + {file = "black-25.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:db8ea9917d6f8fc62abd90d944920d95e73c83a5ee3383493e35d271aca872e9"}, + {file = "black-25.1.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:a39337598244de4bae26475f77dda852ea00a93bd4c728e09eacd827ec929df0"}, + {file = "black-25.1.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:96c1c7cd856bba8e20094e36e0f948718dc688dba4a9d78c3adde52b9e6c2299"}, + {file = "black-25.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:bce2e264d59c91e52d8000d507eb20a9aca4a778731a08cfff7e5ac4a4bb7096"}, + {file = "black-25.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:172b1dbff09f86ce6f4eb8edf9dede08b1fce58ba194c87d7a4f1a5aa2f5b3c2"}, + {file = "black-25.1.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:4b60580e829091e6f9238c848ea6750efed72140b91b048770b64e74fe04908b"}, + {file = "black-25.1.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:1e2978f6df243b155ef5fa7e558a43037c3079093ed5d10fd84c43900f2d8ecc"}, + {file = "black-25.1.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:3b48735872ec535027d979e8dcb20bf4f70b5ac75a8ea99f127c106a7d7aba9f"}, + {file = "black-25.1.0-cp312-cp312-win_amd64.whl", hash = "sha256:ea0213189960bda9cf99be5b8c8ce66bb054af5e9e861249cd23471bd7b0b3ba"}, + {file = "black-25.1.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:8f0b18a02996a836cc9c9c78e5babec10930862827b1b724ddfe98ccf2f2fe4f"}, + {file = "black-25.1.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:afebb7098bfbc70037a053b91ae8437c3857482d3a690fefc03e9ff7aa9a5fd3"}, + {file = "black-25.1.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:030b9759066a4ee5e5aca28c3c77f9c64789cdd4de8ac1df642c40b708be6171"}, + {file = "black-25.1.0-cp313-cp313-win_amd64.whl", hash = "sha256:a22f402b410566e2d1c950708c77ebf5ebd5d0d88a6a2e87c86d9fb48afa0d18"}, + {file = "black-25.1.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:a1ee0a0c330f7b5130ce0caed9936a904793576ef4d2b98c40835d6a65afa6a0"}, + {file = "black-25.1.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:f3df5f1bf91d36002b0a75389ca8663510cf0531cca8aa5c1ef695b46d98655f"}, + {file = "black-25.1.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d9e6827d563a2c820772b32ce8a42828dc6790f095f441beef18f96aa6f8294e"}, + {file = "black-25.1.0-cp39-cp39-win_amd64.whl", hash = "sha256:bacabb307dca5ebaf9c118d2d2f6903da0d62c9faa82bd21a33eecc319559355"}, + {file = "black-25.1.0-py3-none-any.whl", hash = "sha256:95e8176dae143ba9097f351d174fdaf0ccd29efb414b362ae3fd72bf0f710717"}, + {file = "black-25.1.0.tar.gz", hash = "sha256:33496d5cd1222ad73391352b4ae8da15253c5de89b93a80b3e2c8d9a19ec2666"}, +] + +[package.dependencies] +click = ">=8.0.0" +mypy-extensions = ">=0.4.3" +packaging = ">=22.0" +pathspec = ">=0.9.0" +platformdirs = ">=2" +tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""} +typing-extensions = {version = ">=4.0.1", markers = "python_version < \"3.11\""} + +[package.extras] +colorama = ["colorama (>=0.4.3)"] +d = ["aiohttp (>=3.10)"] +jupyter = ["ipython (>=7.8.0)", "tokenize-rt (>=3.2.0)"] +uvloop = ["uvloop (>=0.15.2)"] + [[package]] name = "blessed" version = "1.21.0" @@ -248,7 +295,7 @@ version = "8.1.8" description = "Composable command line interface toolkit" optional = false python-versions = ">=3.7" -groups = ["main"] +groups = ["main", "dev"] files = [ {file = "click-8.1.8-py3-none-any.whl", hash = "sha256:63c132bbbed01578a06712a2d1f497bb62d9c1c0d329b7903a866228027263b2"}, {file = "click-8.1.8.tar.gz", hash = "sha256:ed53c9d8990d83c2a27deae68e4ee337473f6330c040a31d4225c9574d16096a"}, @@ -264,11 +311,11 @@ description = "Cross-platform colored terminal text." optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" groups = ["main", "dev"] +markers = "sys_platform == \"win32\" or platform_system == \"Windows\"" files = [ {file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"}, {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"}, ] -markers = {main = "sys_platform == \"win32\" or platform_system == \"Windows\"", dev = "sys_platform == \"win32\""} [[package]] name = "coverage" @@ -1025,8 +1072,11 @@ files = [ {file = "lxml-5.4.0-cp36-cp36m-win_amd64.whl", hash = "sha256:7ce1a171ec325192c6a636b64c94418e71a1964f56d002cc28122fceff0b6121"}, {file = "lxml-5.4.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:795f61bcaf8770e1b37eec24edf9771b307df3af74d1d6f27d812e15a9ff3872"}, {file = "lxml-5.4.0-cp37-cp37m-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:29f451a4b614a7b5b6c2e043d7b64a15bd8304d7e767055e8ab68387a8cacf4e"}, + {file = "lxml-5.4.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:891f7f991a68d20c75cb13c5c9142b2a3f9eb161f1f12a9489c82172d1f133c0"}, {file = "lxml-5.4.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4aa412a82e460571fad592d0f93ce9935a20090029ba08eca05c614f99b0cc92"}, + {file = "lxml-5.4.0-cp37-cp37m-manylinux_2_28_aarch64.whl", hash = "sha256:ac7ba71f9561cd7d7b55e1ea5511543c0282e2b6450f122672a2694621d63b7e"}, {file = "lxml-5.4.0-cp37-cp37m-manylinux_2_28_x86_64.whl", hash = "sha256:c5d32f5284012deaccd37da1e2cd42f081feaa76981f0eaa474351b68df813c5"}, + {file = "lxml-5.4.0-cp37-cp37m-musllinux_1_2_aarch64.whl", hash = "sha256:ce31158630a6ac85bddd6b830cffd46085ff90498b397bd0a259f59d27a12188"}, {file = "lxml-5.4.0-cp37-cp37m-musllinux_1_2_x86_64.whl", hash = "sha256:31e63621e073e04697c1b2d23fcb89991790eef370ec37ce4d5d469f40924ed6"}, {file = "lxml-5.4.0-cp37-cp37m-win32.whl", hash = "sha256:be2ba4c3c5b7900246a8f866580700ef0d538f2ca32535e991027bdaba944063"}, {file = "lxml-5.4.0-cp37-cp37m-win_amd64.whl", hash = "sha256:09846782b1ef650b321484ad429217f5154da4d6e786636c38e434fa32e94e49"}, @@ -1344,6 +1394,18 @@ files = [ qa = ["flake8 (==5.0.4)", "mypy (==0.971)", "types-setuptools (==67.2.0.1)"] testing = ["docopt", "pytest"] +[[package]] +name = "pathspec" +version = "0.12.1" +description = "Utility library for gitignore style pattern matching of file paths." +optional = false +python-versions = ">=3.8" +groups = ["dev"] +files = [ + {file = "pathspec-0.12.1-py3-none-any.whl", hash = "sha256:a0d503e138a4c123b27490a4f7beda6a01c6f288df0e4a8b79c7eb0dc7b4cc08"}, + {file = "pathspec-0.12.1.tar.gz", hash = "sha256:a482d51503a1ab33b1c67a6c3813a26953dbdc71c31dacaef9a838c4e29f5712"}, +] + [[package]] name = "pexpect" version = "4.9.0" @@ -2686,4 +2748,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.1" python-versions = ">=3.9" -content-hash = "1a73e9db33e3884cf1cc6e3371816aebd20831845ef9bf671be315e659480e86" +content-hash = "d0b959755aad4882df502f8ba219b865df472ba1830d5adf8e757aa6436bc3df" diff --git a/pyproject.toml b/pyproject.toml index c3e48f889..dd38137ee 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -123,6 +123,7 @@ types-pexpect = "^4.9.0.20241208" types-unidiff = "^0.7.0.20240505" uv = ">=0.6.2" pre-commit = "^4.2.0" +black = "^25.1.0" [tool.poetry.build] script = "codeflash/update_license_version.py" diff --git a/tests/test_formatter.py b/tests/test_formatter.py index c2e7864e6..b6c87b190 100644 --- a/tests/test_formatter.py +++ b/tests/test_formatter.py @@ -217,8 +217,10 @@ def foo(): def _run_formatting_test(source_filename: str, should_content_change: bool): - if shutil.which("ruff") is None: - pytest.skip("ruff is not installed, skipping.") + try: + import ruff # type: ignore + except ImportError: + pytest.skip("ruff is not installed") with tempfile.TemporaryDirectory() as test_dir_str: test_dir = Path(test_dir_str) this_file = Path(__file__).resolve() @@ -269,19 +271,12 @@ def _run_formatting_test(source_filename: str, should_content_change: bool): else: assert content == original, f"Expected content to remain unchanged for {source_filename}" -def _ruff_or_black_installed() -> bool: - return shutil.which("black") is not None or shutil.which("ruff") is not None - def test_formatting_file_with_many_diffs(): """Test that files with many formatting errors are skipped (content unchanged).""" - if not _ruff_or_black_installed(): - pytest.skip("Neither black nor ruff is installed, skipping formatting tests.") _run_formatting_test("many_formatting_errors.py", should_content_change=False) def test_formatting_file_with_few_diffs(): """Test that files with few formatting errors are formatted (content changed).""" - if not _ruff_or_black_installed(): - pytest.skip("Neither black nor ruff is installed, skipping formatting tests.") _run_formatting_test("few_formatting_errors.py", should_content_change=True) From 822d6cc015d1a5dc3e6c28bea4a1ef599cb19a05 Mon Sep 17 00:00:00 2001 From: mohammed Date: Tue, 3 Jun 2025 23:57:55 +0300 Subject: [PATCH 11/16] made some refactor changes that codeflash suggested --- codeflash/code_utils/formatter.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/codeflash/code_utils/formatter.py b/codeflash/code_utils/formatter.py index 3d5b587c6..e1d269aa7 100644 --- a/codeflash/code_utils/formatter.py +++ b/codeflash/code_utils/formatter.py @@ -27,7 +27,7 @@ def get_diff_output(cmd: list[str]) -> Optional[str]: is_ruff = cmd[0] == "ruff" if e.returncode == 0 and is_ruff: return "" - elif e.returncode == 1 and is_ruff: + if e.returncode == 1 and is_ruff: return e.stdout.strip() or None return None @@ -61,10 +61,10 @@ def is_safe_to_format(filepath: str, max_diff_lines: int = 100) -> bool: diff_changes_stdout = get_diff_lines_output_by_black(filepath) if diff_changes_stdout is None: - logger.warning(f"black formatter not found, trying ruff instead...") + logger.warning("black formatter not found, trying ruff instead...") diff_changes_stdout = get_diff_lines_output_by_ruff(filepath) if diff_changes_stdout is None: - logger.warning(f"Both ruff, black formatters not found, skipping formatting diff check.") + logger.warning("Both ruff, black formatters not found, skipping formatting diff check.") return False diff_lines_count = get_diff_lines_count(diff_changes_stdout) @@ -72,8 +72,8 @@ def is_safe_to_format(filepath: str, max_diff_lines: int = 100) -> bool: if diff_lines_count > max_diff_lines: logger.debug(f"Skipping {filepath}: {diff_lines_count} lines would change (max: {max_diff_lines})") return False - else: - return True + + return True def format_code(formatter_cmds: list[str], path: Path, print_status: bool = True) -> str: # noqa From ce1502284a07e2adcf0c5a0ec080ff360ab81eab Mon Sep 17 00:00:00 2001 From: mohammed Date: Wed, 4 Jun 2025 00:42:40 +0300 Subject: [PATCH 12/16] remove unused function --- codeflash/code_utils/formatter.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/codeflash/code_utils/formatter.py b/codeflash/code_utils/formatter.py index e1d269aa7..ec077f444 100644 --- a/codeflash/code_utils/formatter.py +++ b/codeflash/code_utils/formatter.py @@ -11,12 +11,6 @@ if TYPE_CHECKING: from pathlib import Path -def get_nth_line(text: str, n: int) -> str | None: - for i, line in enumerate(text.splitlines(), start=1): - if i == n: - return line - return None - def get_diff_output(cmd: list[str]) -> Optional[str]: try: result = subprocess.run(cmd, capture_output=True, text=True, check=True) From d2a87116ec4702fefdd240f25d88a0073ef7ea0d Mon Sep 17 00:00:00 2001 From: mohammed Date: Wed, 4 Jun 2025 02:27:24 +0300 Subject: [PATCH 13/16] formatting & using internal black dep --- codeflash/code_utils/formatter.py | 65 +++++++++++-------------------- poetry.lock | 10 ++--- pyproject.toml | 2 +- tests/test_formatter.py | 3 +- 4 files changed, 30 insertions(+), 50 deletions(-) diff --git a/codeflash/code_utils/formatter.py b/codeflash/code_utils/formatter.py index ec077f444..3144416e1 100644 --- a/codeflash/code_utils/formatter.py +++ b/codeflash/code_utils/formatter.py @@ -4,6 +4,7 @@ import shlex import subprocess from typing import TYPE_CHECKING, Optional + import isort from codeflash.cli_cmds.console import console, logger @@ -11,64 +12,43 @@ if TYPE_CHECKING: from pathlib import Path -def get_diff_output(cmd: list[str]) -> Optional[str]: - try: - result = subprocess.run(cmd, capture_output=True, text=True, check=True) - return result.stdout.strip() or None - except (FileNotFoundError, subprocess.CalledProcessError) as e: - if isinstance(e, subprocess.CalledProcessError): - # ruff returns 1 when the file needs formatting, and 0 when it is already formatted - is_ruff = cmd[0] == "ruff" - if e.returncode == 0 and is_ruff: - return "" - if e.returncode == 1 and is_ruff: - return e.stdout.strip() or None - return None - -def get_diff_lines_output_by_black(filepath: str) -> Optional[str]: +def get_diff_output_by_black(filepath: str, unformatted_content: str) -> Optional[str]: try: - import black # type: ignore - return get_diff_output(['black', '--diff', filepath]) - except ImportError: - return None + import black -def get_diff_lines_output_by_ruff(filepath: str) -> Optional[str]: - try: - import ruff # type: ignore - return get_diff_output(['ruff', 'format', '--diff', filepath]) + formatted_content = black.format_file_contents(src_contents=unformatted_content, fast=True, mode=black.Mode()) + return black.diff(unformatted_content, formatted_content, a_name=filepath, b_name=filepath) except ImportError: - print("can't import ruff") return None def get_diff_lines_count(diff_output: str) -> int: - lines = diff_output.split('\n') + lines = diff_output.split("\n") + def is_diff_line(line: str) -> bool: - return line.startswith(('+', '-')) and not line.startswith(('+++', '---')) + return line.startswith(("+", "-")) and not line.startswith(("+++", "---")) + diff_lines = [line for line in lines if is_diff_line(line)] return len(diff_lines) -def is_safe_to_format(filepath: str, max_diff_lines: int = 100) -> bool: - diff_changes_stdout = None - diff_changes_stdout = get_diff_lines_output_by_black(filepath) +def is_safe_to_format(filepath: str, content: str, max_diff_lines: int = 100) -> bool: + diff_changes_str = None + + diff_changes_str = get_diff_output_by_black(filepath, unformatted_content=content) - if diff_changes_stdout is None: - logger.warning("black formatter not found, trying ruff instead...") - diff_changes_stdout = get_diff_lines_output_by_ruff(filepath) - if diff_changes_stdout is None: - logger.warning("Both ruff, black formatters not found, skipping formatting diff check.") - return False - - diff_lines_count = get_diff_lines_count(diff_changes_stdout) - + if diff_changes_str is None: + logger.warning("Looks like black formatter not found, make sure it is installed.") + return False + + diff_lines_count = get_diff_lines_count(diff_changes_str) if diff_lines_count > max_diff_lines: - logger.debug(f"Skipping {filepath}: {diff_lines_count} lines would change (max: {max_diff_lines})") + logger.debug(f"Skipping formatting {filepath}: {diff_lines_count} lines would change (max: {max_diff_lines})") return False return True - + def format_code(formatter_cmds: list[str], path: Path, print_status: bool = True) -> str: # noqa # TODO: Only allow a particular whitelist of formatters here to prevent arbitrary code execution @@ -76,8 +56,9 @@ def format_code(formatter_cmds: list[str], path: Path, print_status: bool = True if not path.exists(): msg = f"File {path} does not exist. Cannot format the file." raise FileNotFoundError(msg) - if formatter_name == "disabled" or not is_safe_to_format(str(path)): - return path.read_text(encoding="utf8") + file_content = path.read_text(encoding="utf8") + if formatter_name == "disabled" or not is_safe_to_format(filepath=str(path), content=file_content): + return file_content file_token = "$file" # noqa: S105 for command in formatter_cmds: diff --git a/poetry.lock b/poetry.lock index b80c86387..ab3e6054b 100644 --- a/poetry.lock +++ b/poetry.lock @@ -79,7 +79,7 @@ version = "25.1.0" description = "The uncompromising code formatter." optional = false python-versions = ">=3.9" -groups = ["dev"] +groups = ["main"] files = [ {file = "black-25.1.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:759e7ec1e050a15f89b770cefbf91ebee8917aac5c20483bc2d80a6c3a04df32"}, {file = "black-25.1.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:0e519ecf93120f34243e6b0054db49c00a35f84f195d5bce7e9f5cfc578fc2da"}, @@ -295,7 +295,7 @@ version = "8.1.8" description = "Composable command line interface toolkit" optional = false python-versions = ">=3.7" -groups = ["main", "dev"] +groups = ["main"] files = [ {file = "click-8.1.8-py3-none-any.whl", hash = "sha256:63c132bbbed01578a06712a2d1f497bb62d9c1c0d329b7903a866228027263b2"}, {file = "click-8.1.8.tar.gz", hash = "sha256:ed53c9d8990d83c2a27deae68e4ee337473f6330c040a31d4225c9574d16096a"}, @@ -311,11 +311,11 @@ description = "Cross-platform colored terminal text." optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" groups = ["main", "dev"] -markers = "sys_platform == \"win32\" or platform_system == \"Windows\"" files = [ {file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"}, {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"}, ] +markers = {main = "sys_platform == \"win32\" or platform_system == \"Windows\"", dev = "sys_platform == \"win32\""} [[package]] name = "coverage" @@ -1400,7 +1400,7 @@ version = "0.12.1" description = "Utility library for gitignore style pattern matching of file paths." optional = false python-versions = ">=3.8" -groups = ["dev"] +groups = ["main"] files = [ {file = "pathspec-0.12.1-py3-none-any.whl", hash = "sha256:a0d503e138a4c123b27490a4f7beda6a01c6f288df0e4a8b79c7eb0dc7b4cc08"}, {file = "pathspec-0.12.1.tar.gz", hash = "sha256:a482d51503a1ab33b1c67a6c3813a26953dbdc71c31dacaef9a838c4e29f5712"}, @@ -2748,4 +2748,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.1" python-versions = ">=3.9" -content-hash = "d0b959755aad4882df502f8ba219b865df472ba1830d5adf8e757aa6436bc3df" +content-hash = "1ba28119bcc2b572133da8f243eea42fc8f732b6255afac7c2c7e616e2c68677" diff --git a/pyproject.toml b/pyproject.toml index dd38137ee..6a5c4904a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -93,6 +93,7 @@ crosshair-tool = ">=0.0.78" coverage = ">=7.6.4" line_profiler=">=4.2.0" #this is the minimum version which supports python 3.13 platformdirs = ">=4.3.7" +black = "^25.1.0" [tool.poetry.group.dev] optional = true @@ -123,7 +124,6 @@ types-pexpect = "^4.9.0.20241208" types-unidiff = "^0.7.0.20240505" uv = ">=0.6.2" pre-commit = "^4.2.0" -black = "^25.1.0" [tool.poetry.build] script = "codeflash/update_license_version.py" diff --git a/tests/test_formatter.py b/tests/test_formatter.py index b6c87b190..b500bbb4f 100644 --- a/tests/test_formatter.py +++ b/tests/test_formatter.py @@ -259,13 +259,12 @@ def _run_formatting_test(source_filename: str, should_content_change: bool): args=args, ) - optimizer.reformat_code_and_helpers( + content, _ = optimizer.reformat_code_and_helpers( helper_functions=[], path=target_path, original_code=optimizer.function_to_optimize_source_code, ) - content = target_path.read_text() if should_content_change: assert content != original, f"Expected content to change for {source_filename}" else: From f46b3683b1391517cd13d3b666fdcf10fb382861 Mon Sep 17 00:00:00 2001 From: mohammed Date: Wed, 4 Jun 2025 03:01:51 +0300 Subject: [PATCH 14/16] fix black import issue --- codeflash/code_utils/formatter.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/codeflash/code_utils/formatter.py b/codeflash/code_utils/formatter.py index 3144416e1..afbced761 100644 --- a/codeflash/code_utils/formatter.py +++ b/codeflash/code_utils/formatter.py @@ -15,10 +15,10 @@ def get_diff_output_by_black(filepath: str, unformatted_content: str) -> Optional[str]: try: - import black + from black import Mode, format_file_contents, output - formatted_content = black.format_file_contents(src_contents=unformatted_content, fast=True, mode=black.Mode()) - return black.diff(unformatted_content, formatted_content, a_name=filepath, b_name=filepath) + formatted_content = format_file_contents(src_contents=unformatted_content, fast=True, mode=Mode()) + return output.diff(unformatted_content, formatted_content, a_name=filepath, b_name=filepath) except ImportError: return None From 6504cc4cc92725ca1dace7a57759dfe0c124fb0d Mon Sep 17 00:00:00 2001 From: mohammed Date: Wed, 4 Jun 2025 03:51:02 +0300 Subject: [PATCH 15/16] handle formatting files with no formatting issues --- code_to_optimize/no_formatting_errors.py | 71 ++++++++++++++++++++++++ codeflash/code_utils/formatter.py | 4 +- tests/test_formatter.py | 4 ++ 3 files changed, 77 insertions(+), 2 deletions(-) create mode 100644 code_to_optimize/no_formatting_errors.py diff --git a/code_to_optimize/no_formatting_errors.py b/code_to_optimize/no_formatting_errors.py new file mode 100644 index 000000000..3d32bc94c --- /dev/null +++ b/code_to_optimize/no_formatting_errors.py @@ -0,0 +1,71 @@ +import os, sys, json, datetime, math, random +import requests +from collections import defaultdict, OrderedDict +from typing import List, Dict, Optional, Union, Tuple, Any +import numpy as np +import pandas as pd + +# This is a poorly formatted Python file with many style violations + + +class UnformattedExampleClass(object): + def __init__( + self, + name, + age=None, + email=None, + phone=None, + address=None, + city=None, + state=None, + zip_code=None, + ): + self.name = name + self.age = age + self.email = email + self.phone = phone + self.address = address + self.city = city + self.state = state + self.zip_code = zip_code + self.data = {"name": name, "age": age, "email": email} + + def get_info(self): + return f"Name: {self.name}, Age: {self.age}" + + def update_data(self, **kwargs): + for key, value in kwargs.items(): + if hasattr(self, key): + setattr(self, key, value) + self.data.update(kwargs) + + +def process_data( + data_list, filter_func=None, transform_func=None, sort_key=None, reverse=False +): + if not data_list: + return [] + if filter_func: + data_list = [item for item in data_list if filter_func(item)] + if transform_func: + data_list = [transform_func(item) for item in data_list] + if sort_key: + data_list = sorted(data_list, key=sort_key, reverse=reverse) + return data_list + + +def calculate_statistics(numbers): + if not numbers: + return None + mean = sum(numbers) / len(numbers) + median = sorted(numbers)[len(numbers) // 2] + variance = sum((x - mean) ** 2 for x in numbers) / len(numbers) + std_dev = math.sqrt(variance) + return { + "mean": mean, + "median": median, + "variance": variance, + "std_dev": std_dev, + "min": min(numbers), + "max": max(numbers), + } diff --git a/codeflash/code_utils/formatter.py b/codeflash/code_utils/formatter.py index afbced761..6188e8649 100644 --- a/codeflash/code_utils/formatter.py +++ b/codeflash/code_utils/formatter.py @@ -15,11 +15,11 @@ def get_diff_output_by_black(filepath: str, unformatted_content: str) -> Optional[str]: try: - from black import Mode, format_file_contents, output + from black import Mode, format_file_contents, output, report formatted_content = format_file_contents(src_contents=unformatted_content, fast=True, mode=Mode()) return output.diff(unformatted_content, formatted_content, a_name=filepath, b_name=filepath) - except ImportError: + except (ImportError, report.NothingChanged): return None diff --git a/tests/test_formatter.py b/tests/test_formatter.py index b500bbb4f..baf5b8079 100644 --- a/tests/test_formatter.py +++ b/tests/test_formatter.py @@ -279,3 +279,7 @@ def test_formatting_file_with_many_diffs(): def test_formatting_file_with_few_diffs(): """Test that files with few formatting errors are formatted (content changed).""" _run_formatting_test("few_formatting_errors.py", should_content_change=True) + +def test_formatting_file_with_no_diffs(): + """Test that files with no formatting errors are unchanged.""" + _run_formatting_test("no_formatting_errors.py", should_content_change=False) From d924b3140c86dbccac2259e013a4701db47ead95 Mon Sep 17 00:00:00 2001 From: Saga4 Date: Thu, 5 Jun 2025 01:31:53 +0530 Subject: [PATCH 16/16] fix_duplication_suggestion_issue --- codeflash/discovery/functions_to_optimize.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/codeflash/discovery/functions_to_optimize.py b/codeflash/discovery/functions_to_optimize.py index 3f0d72bcd..41d99ec2c 100644 --- a/codeflash/discovery/functions_to_optimize.py +++ b/codeflash/discovery/functions_to_optimize.py @@ -476,12 +476,14 @@ def filter_functions( if blocklist_funcs: functions_tmp = [] for function in _functions: - if not ( + if ( function.file_path.name in blocklist_funcs and function.qualified_name in blocklist_funcs[function.file_path.name] ): + # This function is in blocklist, we can skip it blocklist_funcs_removed_count += 1 continue + # This function is NOT in blocklist. we can keep it functions_tmp.append(function) _functions = functions_tmp