diff --git a/code_to_optimize/few_formatting_errors.py b/code_to_optimize/few_formatting_errors.py new file mode 100644 index 000000000..27ed71b44 --- /dev/null +++ b/code_to_optimize/few_formatting_errors.py @@ -0,0 +1,47 @@ +import os + +class UnformattedExampleClass(object): + def __init__( + self, + name, + age= None, + email= None, + phone=None, + address=None, + city=None, + state=None, + zip_code=None, + ): + self.name = name + self.age = age + self.email = email + self.phone = phone + self. address = address + self.city = city + self.state = state + self.zip_code = zip_code + self.data = {"name": name, "age": age, "email": email} + + def get_info(self): + return f"Name: {self.name}, Age: {self.age}" + + def update_data(self, **kwargs): + for key, value in kwargs.items(): + if hasattr(self, key): + setattr(self, key, value) + self.data.update(kwargs) + + +def process_data( + data_list, filter_func=None, transform_func=None, sort_key=None, reverse=False +): + if not data_list: + return [] + if filter_func: + data_list = [ item for item in data_list if filter_func(item)] + if transform_func: + data_list = [transform_func(item) for item in data_list] + if sort_key: + data_list = sorted(data_list, key=sort_key, reverse=reverse) + return data_list + diff --git a/code_to_optimize/many_formatting_errors.py b/code_to_optimize/many_formatting_errors.py new file mode 100644 index 000000000..79cfc825d --- /dev/null +++ b/code_to_optimize/many_formatting_errors.py @@ -0,0 +1,147 @@ +import os,sys,json,datetime,math,random;import requests;from collections import defaultdict,OrderedDict +from typing import List,Dict,Optional,Union,Tuple,Any;import numpy as np;import pandas as pd + +# This is a poorly formatted Python file with many style violations + +class UnformattedExampleClass( object ): + def __init__(self,name,age=None,email=None,phone=None,address=None,city=None,state=None,zip_code=None): + self.name=name;self.age=age;self.email=email;self.phone=phone + self.address=address;self.city=city;self.state=state;self.zip_code=zip_code + self.data={"name":name,"age":age,"email":email} + + def get_info(self ): + return f"Name: {self.name}, Age: {self.age}" + + def update_data(self,**kwargs): + for key,value in kwargs.items(): + if hasattr(self,key):setattr(self,key,value) + self.data.update(kwargs) + +def process_data(data_list,filter_func=None,transform_func=None,sort_key=None,reverse=False): + if not data_list:return[] + if filter_func:data_list=[item for item in data_list if filter_func(item)] + if transform_func:data_list=[transform_func(item)for item in data_list] + if sort_key:data_list=sorted(data_list,key=sort_key,reverse=reverse) + return data_list + +def calculate_statistics(numbers): + if not numbers:return None + mean=sum(numbers)/len(numbers); median=sorted(numbers)[len(numbers)//2] + variance=sum((x-mean)**2 for x in numbers)/len(numbers);std_dev=math.sqrt(variance) + return {"mean":mean,"median":median,"variance":variance,"std_dev":std_dev,"min":min(numbers),"max":max(numbers)} + +def complex_nested_function(x,y,z): + def inner_function_1(a,b): + def deeply_nested(c,d): + return c*d+a*b + return deeply_nested(a+1,b-1)+deeply_nested(a-1,b+1) + def inner_function_2 (a,b,c): + result=[] + for i in range(a): + for j in range(b): + for k in range(c): + if i*j*k>0:result.append(i*j*k) + elif i+j+k==0:result.append(-1) + else :result.append(0) + return result + return inner_function_1(x,y)+sum(inner_function_2(x,y,z)) + +# Long lines and poor dictionary formatting +user_data={"users":[{"id":1,"name":"John Doe","email":"john@example.com","preferences":{"theme":"dark","notifications":True,"language":"en"},"metadata":{"created_at":"2023-01-01","last_login":"2024-01-01","login_count":150}},{"id":2,"name":"Jane Smith","email":"jane@example.com","preferences":{"theme":"light","notifications":False,"language":"es"},"metadata":{"created_at":"2023-02-15","last_login":"2024-01-15","login_count":89}}]} + +# Poor list formatting and string concatenation +long_list_of_items=['item_1','item_2','item_3','item_4','item_5','item_6','item_7','item_8','item_9','item_10','item_11','item_12','item_13','item_14','item_15','item_16','item_17','item_18','item_19','item_20'] + +def generate_report(data,include_stats=True,include_charts=False,format_type='json',output_file=None): + if not data:raise ValueError("Data cannot be empty") + report={'timestamp':datetime.datetime.now().isoformat(),'data_count':len(data),'summary':{}} + + # Bad formatting in loops and conditionals + for i,item in enumerate(data): + if isinstance(item,dict): + for key,value in item.items(): + if key not in report['summary']:report['summary'][key]=[] + report['summary'][key].append(value) + elif isinstance(item,(int,float)): + if 'numbers' not in report['summary']:report['summary']['numbers']=[] + report['summary']['numbers'].append(item) + else: + if 'other' not in report['summary']:report['summary']['other']=[] + report['summary']['other'].append(str(item)) + + if include_stats and 'numbers' in report['summary']: + numbers=report['summary']['numbers'] + report['statistics']=calculate_statistics(numbers) + + # Long conditional chain with poor formatting + if format_type=='json':result=json.dumps(report,indent=None,separators=(',',':')) + elif format_type=='pretty_json':result=json.dumps(report,indent=2) + elif format_type=='string':result=str(report) + else:result=report + + if output_file: + with open(output_file,'w')as f:f.write(result if isinstance(result,str)else json.dumps(result)) + + return result + +class DataProcessor ( UnformattedExampleClass ) : + def __init__(self,data_source,config=None,debug=False): + super().__init__("DataProcessor") + self.data_source=data_source;self.config=config or{};self.debug=debug + self.processed_data=[];self.errors=[];self.warnings=[] + + def load_data ( self ) : + try: + if isinstance(self.data_source,str): + if self.data_source.endswith('.json'): + with open(self.data_source,'r')as f:data=json.load(f) + elif self.data_source.endswith('.csv'):data=pd.read_csv(self.data_source).to_dict('records') + else:raise ValueError(f"Unsupported file type: {self.data_source}") + elif isinstance(self.data_source,list):data=self.data_source + else:data=[self.data_source] + return data + except Exception as e: + self.errors.append(str(e));return[] + + def validate_data(self,data): + valid_items=[];invalid_items=[] + for item in data: + if isinstance(item,dict)and'id'in item and'name'in item:valid_items.append(item) + else:invalid_items.append(item) + if invalid_items:self.warnings.append(f"Found {len(invalid_items)} invalid items") + return valid_items + + def process(self): + data=self.load_data() + if not data:return{"success":False,"error":"No data loaded"} + + validated_data=self.validate_data(data) + processed_result=process_data(validated_data, + filter_func=lambda x:x.get('active',True), + transform_func=lambda x:{**x,'processed_at':datetime.datetime.now().isoformat()}, + sort_key=lambda x:x.get('name','')) + + self.processed_data=processed_result + return{"success":True,"count":len(processed_result),"data":processed_result} +if __name__=="__main__": + sample_data=[{"id":1,"name":"Alice","active":True},{"id":2,"name":"Bob","active":False},{"id":3,"name":"Charlie","active":True}] + + processor=DataProcessor(sample_data,config={"debug":True}) + result=processor.process() + + if result["success"]: + print(f"Successfully processed {result['count']} items") + for item in result["data"][:3]:print(f"- {item['name']} (ID: {item['id']})") + else:print(f"Processing failed: {result.get('error','Unknown error')}") + + # Generate report with poor formatting + report=generate_report(sample_data,include_stats=True,format_type='pretty_json') + print("Generated report:",report[:100]+"..."if len(report)>100 else report) + + # Complex calculation with poor spacing + numbers=[random.randint(1,100)for _ in range(50)] + stats=calculate_statistics(numbers) + complex_result=complex_nested_function(5,3,2) + + print(f"Statistics: mean={stats['mean']:.2f}, std_dev={stats['std_dev']:.2f}") + print(f"Complex calculation result: {complex_result}") diff --git a/code_to_optimize/no_formatting_errors.py b/code_to_optimize/no_formatting_errors.py new file mode 100644 index 000000000..3d32bc94c --- /dev/null +++ b/code_to_optimize/no_formatting_errors.py @@ -0,0 +1,71 @@ +import os, sys, json, datetime, math, random +import requests +from collections import defaultdict, OrderedDict +from typing import List, Dict, Optional, Union, Tuple, Any +import numpy as np +import pandas as pd + +# This is a poorly formatted Python file with many style violations + + +class UnformattedExampleClass(object): + def __init__( + self, + name, + age=None, + email=None, + phone=None, + address=None, + city=None, + state=None, + zip_code=None, + ): + self.name = name + self.age = age + self.email = email + self.phone = phone + self.address = address + self.city = city + self.state = state + self.zip_code = zip_code + self.data = {"name": name, "age": age, "email": email} + + def get_info(self): + return f"Name: {self.name}, Age: {self.age}" + + def update_data(self, **kwargs): + for key, value in kwargs.items(): + if hasattr(self, key): + setattr(self, key, value) + self.data.update(kwargs) + + +def process_data( + data_list, filter_func=None, transform_func=None, sort_key=None, reverse=False +): + if not data_list: + return [] + if filter_func: + data_list = [item for item in data_list if filter_func(item)] + if transform_func: + data_list = [transform_func(item) for item in data_list] + if sort_key: + data_list = sorted(data_list, key=sort_key, reverse=reverse) + return data_list + + +def calculate_statistics(numbers): + if not numbers: + return None + mean = sum(numbers) / len(numbers) + median = sorted(numbers)[len(numbers) // 2] + variance = sum((x - mean) ** 2 for x in numbers) / len(numbers) + std_dev = math.sqrt(variance) + return { + "mean": mean, + "median": median, + "variance": variance, + "std_dev": std_dev, + "min": min(numbers), + "max": max(numbers), + } diff --git a/codeflash/code_utils/formatter.py b/codeflash/code_utils/formatter.py index 927a4d4cb..6188e8649 100644 --- a/codeflash/code_utils/formatter.py +++ b/codeflash/code_utils/formatter.py @@ -3,7 +3,7 @@ import os import shlex import subprocess -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional import isort @@ -13,14 +13,53 @@ from pathlib import Path +def get_diff_output_by_black(filepath: str, unformatted_content: str) -> Optional[str]: + try: + from black import Mode, format_file_contents, output, report + + formatted_content = format_file_contents(src_contents=unformatted_content, fast=True, mode=Mode()) + return output.diff(unformatted_content, formatted_content, a_name=filepath, b_name=filepath) + except (ImportError, report.NothingChanged): + return None + + +def get_diff_lines_count(diff_output: str) -> int: + lines = diff_output.split("\n") + + def is_diff_line(line: str) -> bool: + return line.startswith(("+", "-")) and not line.startswith(("+++", "---")) + + diff_lines = [line for line in lines if is_diff_line(line)] + return len(diff_lines) + + +def is_safe_to_format(filepath: str, content: str, max_diff_lines: int = 100) -> bool: + diff_changes_str = None + + diff_changes_str = get_diff_output_by_black(filepath, unformatted_content=content) + + if diff_changes_str is None: + logger.warning("Looks like black formatter not found, make sure it is installed.") + return False + + diff_lines_count = get_diff_lines_count(diff_changes_str) + if diff_lines_count > max_diff_lines: + logger.debug(f"Skipping formatting {filepath}: {diff_lines_count} lines would change (max: {max_diff_lines})") + return False + + return True + + def format_code(formatter_cmds: list[str], path: Path, print_status: bool = True) -> str: # noqa # TODO: Only allow a particular whitelist of formatters here to prevent arbitrary code execution formatter_name = formatter_cmds[0].lower() if not path.exists(): msg = f"File {path} does not exist. Cannot format the file." raise FileNotFoundError(msg) - if formatter_name == "disabled": - return path.read_text(encoding="utf8") + file_content = path.read_text(encoding="utf8") + if formatter_name == "disabled" or not is_safe_to_format(filepath=str(path), content=file_content): + return file_content + file_token = "$file" # noqa: S105 for command in formatter_cmds: formatter_cmd_list = shlex.split(command, posix=os.name != "nt") @@ -29,7 +68,7 @@ def format_code(formatter_cmds: list[str], path: Path, print_status: bool = True result = subprocess.run(formatter_cmd_list, capture_output=True, check=False) if result.returncode == 0: if print_status: - console.rule(f"Formatted Successfully with: {formatter_name.replace('$file', path.name)}") + console.rule(f"Formatted Successfully with: {command.replace('$file', path.name)}") else: logger.error(f"Failed to format code with {' '.join(formatter_cmd_list)}") except FileNotFoundError as e: diff --git a/codeflash/discovery/functions_to_optimize.py b/codeflash/discovery/functions_to_optimize.py index 3f0d72bcd..41d99ec2c 100644 --- a/codeflash/discovery/functions_to_optimize.py +++ b/codeflash/discovery/functions_to_optimize.py @@ -476,12 +476,14 @@ def filter_functions( if blocklist_funcs: functions_tmp = [] for function in _functions: - if not ( + if ( function.file_path.name in blocklist_funcs and function.qualified_name in blocklist_funcs[function.file_path.name] ): + # This function is in blocklist, we can skip it blocklist_funcs_removed_count += 1 continue + # This function is NOT in blocklist. we can keep it functions_tmp.append(function) _functions = functions_tmp diff --git a/poetry.lock b/poetry.lock index 04cfeae09..ab3e6054b 100644 --- a/poetry.lock +++ b/poetry.lock @@ -73,6 +73,53 @@ files = [ {file = "backoff-1.11.1.tar.gz", hash = "sha256:ccb962a2378418c667b3c979b504fdeb7d9e0d29c0579e3b13b86467177728cb"}, ] +[[package]] +name = "black" +version = "25.1.0" +description = "The uncompromising code formatter." +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "black-25.1.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:759e7ec1e050a15f89b770cefbf91ebee8917aac5c20483bc2d80a6c3a04df32"}, + {file = "black-25.1.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:0e519ecf93120f34243e6b0054db49c00a35f84f195d5bce7e9f5cfc578fc2da"}, + {file = "black-25.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:055e59b198df7ac0b7efca5ad7ff2516bca343276c466be72eb04a3bcc1f82d7"}, + {file = "black-25.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:db8ea9917d6f8fc62abd90d944920d95e73c83a5ee3383493e35d271aca872e9"}, + {file = "black-25.1.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:a39337598244de4bae26475f77dda852ea00a93bd4c728e09eacd827ec929df0"}, + {file = "black-25.1.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:96c1c7cd856bba8e20094e36e0f948718dc688dba4a9d78c3adde52b9e6c2299"}, + {file = "black-25.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:bce2e264d59c91e52d8000d507eb20a9aca4a778731a08cfff7e5ac4a4bb7096"}, + {file = "black-25.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:172b1dbff09f86ce6f4eb8edf9dede08b1fce58ba194c87d7a4f1a5aa2f5b3c2"}, + {file = "black-25.1.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:4b60580e829091e6f9238c848ea6750efed72140b91b048770b64e74fe04908b"}, + {file = "black-25.1.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:1e2978f6df243b155ef5fa7e558a43037c3079093ed5d10fd84c43900f2d8ecc"}, + {file = "black-25.1.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:3b48735872ec535027d979e8dcb20bf4f70b5ac75a8ea99f127c106a7d7aba9f"}, + {file = "black-25.1.0-cp312-cp312-win_amd64.whl", hash = "sha256:ea0213189960bda9cf99be5b8c8ce66bb054af5e9e861249cd23471bd7b0b3ba"}, + {file = "black-25.1.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:8f0b18a02996a836cc9c9c78e5babec10930862827b1b724ddfe98ccf2f2fe4f"}, + {file = "black-25.1.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:afebb7098bfbc70037a053b91ae8437c3857482d3a690fefc03e9ff7aa9a5fd3"}, + {file = "black-25.1.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:030b9759066a4ee5e5aca28c3c77f9c64789cdd4de8ac1df642c40b708be6171"}, + {file = "black-25.1.0-cp313-cp313-win_amd64.whl", hash = "sha256:a22f402b410566e2d1c950708c77ebf5ebd5d0d88a6a2e87c86d9fb48afa0d18"}, + {file = "black-25.1.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:a1ee0a0c330f7b5130ce0caed9936a904793576ef4d2b98c40835d6a65afa6a0"}, + {file = "black-25.1.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:f3df5f1bf91d36002b0a75389ca8663510cf0531cca8aa5c1ef695b46d98655f"}, + {file = "black-25.1.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d9e6827d563a2c820772b32ce8a42828dc6790f095f441beef18f96aa6f8294e"}, + {file = "black-25.1.0-cp39-cp39-win_amd64.whl", hash = "sha256:bacabb307dca5ebaf9c118d2d2f6903da0d62c9faa82bd21a33eecc319559355"}, + {file = "black-25.1.0-py3-none-any.whl", hash = "sha256:95e8176dae143ba9097f351d174fdaf0ccd29efb414b362ae3fd72bf0f710717"}, + {file = "black-25.1.0.tar.gz", hash = "sha256:33496d5cd1222ad73391352b4ae8da15253c5de89b93a80b3e2c8d9a19ec2666"}, +] + +[package.dependencies] +click = ">=8.0.0" +mypy-extensions = ">=0.4.3" +packaging = ">=22.0" +pathspec = ">=0.9.0" +platformdirs = ">=2" +tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""} +typing-extensions = {version = ">=4.0.1", markers = "python_version < \"3.11\""} + +[package.extras] +colorama = ["colorama (>=0.4.3)"] +d = ["aiohttp (>=3.10)"] +jupyter = ["ipython (>=7.8.0)", "tokenize-rt (>=3.2.0)"] +uvloop = ["uvloop (>=0.15.2)"] + [[package]] name = "blessed" version = "1.21.0" @@ -1025,8 +1072,11 @@ files = [ {file = "lxml-5.4.0-cp36-cp36m-win_amd64.whl", hash = "sha256:7ce1a171ec325192c6a636b64c94418e71a1964f56d002cc28122fceff0b6121"}, {file = "lxml-5.4.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:795f61bcaf8770e1b37eec24edf9771b307df3af74d1d6f27d812e15a9ff3872"}, {file = "lxml-5.4.0-cp37-cp37m-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:29f451a4b614a7b5b6c2e043d7b64a15bd8304d7e767055e8ab68387a8cacf4e"}, + {file = "lxml-5.4.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:891f7f991a68d20c75cb13c5c9142b2a3f9eb161f1f12a9489c82172d1f133c0"}, {file = "lxml-5.4.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4aa412a82e460571fad592d0f93ce9935a20090029ba08eca05c614f99b0cc92"}, + {file = "lxml-5.4.0-cp37-cp37m-manylinux_2_28_aarch64.whl", hash = "sha256:ac7ba71f9561cd7d7b55e1ea5511543c0282e2b6450f122672a2694621d63b7e"}, {file = "lxml-5.4.0-cp37-cp37m-manylinux_2_28_x86_64.whl", hash = "sha256:c5d32f5284012deaccd37da1e2cd42f081feaa76981f0eaa474351b68df813c5"}, + {file = "lxml-5.4.0-cp37-cp37m-musllinux_1_2_aarch64.whl", hash = "sha256:ce31158630a6ac85bddd6b830cffd46085ff90498b397bd0a259f59d27a12188"}, {file = "lxml-5.4.0-cp37-cp37m-musllinux_1_2_x86_64.whl", hash = "sha256:31e63621e073e04697c1b2d23fcb89991790eef370ec37ce4d5d469f40924ed6"}, {file = "lxml-5.4.0-cp37-cp37m-win32.whl", hash = "sha256:be2ba4c3c5b7900246a8f866580700ef0d538f2ca32535e991027bdaba944063"}, {file = "lxml-5.4.0-cp37-cp37m-win_amd64.whl", hash = "sha256:09846782b1ef650b321484ad429217f5154da4d6e786636c38e434fa32e94e49"}, @@ -1344,6 +1394,18 @@ files = [ qa = ["flake8 (==5.0.4)", "mypy (==0.971)", "types-setuptools (==67.2.0.1)"] testing = ["docopt", "pytest"] +[[package]] +name = "pathspec" +version = "0.12.1" +description = "Utility library for gitignore style pattern matching of file paths." +optional = false +python-versions = ">=3.8" +groups = ["main"] +files = [ + {file = "pathspec-0.12.1-py3-none-any.whl", hash = "sha256:a0d503e138a4c123b27490a4f7beda6a01c6f288df0e4a8b79c7eb0dc7b4cc08"}, + {file = "pathspec-0.12.1.tar.gz", hash = "sha256:a482d51503a1ab33b1c67a6c3813a26953dbdc71c31dacaef9a838c4e29f5712"}, +] + [[package]] name = "pexpect" version = "4.9.0" @@ -2686,4 +2748,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.1" python-versions = ">=3.9" -content-hash = "1a73e9db33e3884cf1cc6e3371816aebd20831845ef9bf671be315e659480e86" +content-hash = "1ba28119bcc2b572133da8f243eea42fc8f732b6255afac7c2c7e616e2c68677" diff --git a/pyproject.toml b/pyproject.toml index c3e48f889..6a5c4904a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -93,6 +93,7 @@ crosshair-tool = ">=0.0.78" coverage = ">=7.6.4" line_profiler=">=4.2.0" #this is the minimum version which supports python 3.13 platformdirs = ">=4.3.7" +black = "^25.1.0" [tool.poetry.group.dev] optional = true diff --git a/tests/test_formatter.py b/tests/test_formatter.py index 5c0a91c38..baf5b8079 100644 --- a/tests/test_formatter.py +++ b/tests/test_formatter.py @@ -1,12 +1,17 @@ +import argparse import os import tempfile from pathlib import Path import pytest +import shutil from codeflash.code_utils.config_parser import parse_config_file from codeflash.code_utils.formatter import format_code, sort_imports +from codeflash.discovery.functions_to_optimize import FunctionToOptimize +from codeflash.optimization.function_optimizer import FunctionOptimizer +from codeflash.verification.verification_utils import TestConfig def test_remove_duplicate_imports(): """Test that duplicate imports are removed when should_sort_imports is True.""" @@ -209,3 +214,72 @@ def foo(): tmp_path = tmp.name with pytest.raises(FileNotFoundError): format_code(formatter_cmds=["exit 1"], path=Path(tmp_path)) + + +def _run_formatting_test(source_filename: str, should_content_change: bool): + try: + import ruff # type: ignore + except ImportError: + pytest.skip("ruff is not installed") + with tempfile.TemporaryDirectory() as test_dir_str: + test_dir = Path(test_dir_str) + this_file = Path(__file__).resolve() + repo_root_dir = this_file.parent.parent + source_file = repo_root_dir / "code_to_optimize" / source_filename + + original = source_file.read_text() + target_path = test_dir / "target.py" + + shutil.copy2(source_file, target_path) + + function_to_optimize = FunctionToOptimize( + function_name="process_data", + parents=[], + file_path=target_path + ) + + test_cfg = TestConfig( + tests_root=test_dir, + project_root_path=test_dir, + test_framework="pytest", + tests_project_rootdir=test_dir, + ) + + args = argparse.Namespace( + disable_imports_sorting=False, + formatter_cmds=[ + "ruff check --exit-zero --fix $file", + "ruff format $file" + ], + ) + + optimizer = FunctionOptimizer( + function_to_optimize=function_to_optimize, + test_cfg=test_cfg, + args=args, + ) + + content, _ = optimizer.reformat_code_and_helpers( + helper_functions=[], + path=target_path, + original_code=optimizer.function_to_optimize_source_code, + ) + + if should_content_change: + assert content != original, f"Expected content to change for {source_filename}" + else: + assert content == original, f"Expected content to remain unchanged for {source_filename}" + + +def test_formatting_file_with_many_diffs(): + """Test that files with many formatting errors are skipped (content unchanged).""" + _run_formatting_test("many_formatting_errors.py", should_content_change=False) + + +def test_formatting_file_with_few_diffs(): + """Test that files with few formatting errors are formatted (content changed).""" + _run_formatting_test("few_formatting_errors.py", should_content_change=True) + +def test_formatting_file_with_no_diffs(): + """Test that files with no formatting errors are unchanged.""" + _run_formatting_test("no_formatting_errors.py", should_content_change=False)