Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
299 changes: 134 additions & 165 deletions codeflash/discovery/discover_unit_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from pathlib import Path
from typing import TYPE_CHECKING, Callable, Optional

import jedi
import pytest
from pydantic.dataclasses import dataclass
from rich.panel import Panel
Expand Down Expand Up @@ -288,192 +289,160 @@ def discover_parameters_unittest(function_name: str) -> tuple[bool, str, str | N
return False, function_name, None


def process_test_files(
file_to_test_map: dict[Path, list[TestsInFile]], cfg: TestConfig
) -> dict[str, list[FunctionCalledInTest]]:
import jedi

def _process_single_test_file(
test_file: Path,
functions: list[TestsInFile],
cfg: TestConfig,
jedi_project: jedi.Project,
function_to_test_map: defaultdict,
) -> None:
project_root_path = cfg.project_root_path
test_framework = cfg.test_framework

function_to_test_map = defaultdict(set)
jedi_project = jedi.Project(path=project_root_path)
goto_cache = {}
tests_cache = TestsCache()
try:
script = jedi.Script(path=test_file, project=jedi_project)
test_functions = set()

with test_files_progress_bar(total=len(file_to_test_map), description="Processing test files") as (
progress,
task_id,
):
for test_file, functions in file_to_test_map.items():
file_hash = TestsCache.compute_file_hash(test_file)
cached_tests = tests_cache.get_tests_for_file(str(test_file), file_hash)
if cached_tests:
self_cur = tests_cache.cur
self_cur.execute(
"SELECT qualified_name_with_modules_from_root FROM discovered_tests WHERE file_path = ? AND file_hash = ?",
(str(test_file), file_hash),
)
qualified_names = [row[0] for row in self_cur.fetchall()]
for cached, qualified_name in zip(cached_tests, qualified_names):
function_to_test_map[qualified_name].add(cached)
progress.advance(task_id)
continue
all_names = script.get_names(all_scopes=True, references=True)
all_defs = script.get_names(all_scopes=True, definitions=True)
all_names_top = script.get_names(all_scopes=True)

try:
script = jedi.Script(path=test_file, project=jedi_project)
test_functions = set()
top_level_functions = {name.name: name for name in all_names_top if name.type == "function"}
top_level_classes = {name.name: name for name in all_names_top if name.type == "class"}
except Exception as e:
logger.debug(f"Failed to get jedi script for {test_file}: {e}")
return

if test_framework == "pytest":
for function in functions:
if "[" in function.test_function:
function_name = PYTEST_PARAMETERIZED_TEST_NAME_REGEX.split(function.test_function)[0]
parameters = PYTEST_PARAMETERIZED_TEST_NAME_REGEX.split(function.test_function)[1]
if function_name in top_level_functions:
test_functions.add(TestFunction(function_name, function.test_class, parameters, function.test_type))
elif function.test_function in top_level_functions:
test_functions.add(TestFunction(function.test_function, function.test_class, None, function.test_type))
elif UNITTEST_PARAMETERIZED_TEST_NAME_REGEX.match(function.test_function):
base_name = UNITTEST_STRIP_NUMBERED_SUFFIX_REGEX.sub("", function.test_function)
if base_name in top_level_functions:
test_functions.add(
TestFunction(
function_name=base_name,
test_class=function.test_class,
parameters=function.test_function,
test_type=function.test_type,
)
)

all_names = script.get_names(all_scopes=True, references=True)
all_defs = script.get_names(all_scopes=True, definitions=True)
all_names_top = script.get_names(all_scopes=True)
elif test_framework == "unittest":
functions_to_search = [elem.test_function for elem in functions]
test_suites = {elem.test_class for elem in functions}

top_level_functions = {name.name: name for name in all_names_top if name.type == "function"}
top_level_classes = {name.name: name for name in all_names_top if name.type == "class"}
except Exception as e:
logger.debug(f"Failed to get jedi script for {test_file}: {e}")
progress.advance(task_id)
continue
matching_names = test_suites & top_level_classes.keys()
for matched_name in matching_names:
for def_name in all_defs:
if (
def_name.type == "function"
and def_name.full_name is not None
and f".{matched_name}." in def_name.full_name
):
for function in functions_to_search:
(is_parameterized, new_function, parameters) = discover_parameters_unittest(function)

if test_framework == "pytest":
for function in functions:
if "[" in function.test_function:
function_name = PYTEST_PARAMETERIZED_TEST_NAME_REGEX.split(function.test_function)[0]
parameters = PYTEST_PARAMETERIZED_TEST_NAME_REGEX.split(function.test_function)[1]
if function_name in top_level_functions:
if is_parameterized and new_function == def_name.name:
test_functions.add(
TestFunction(function_name, function.test_class, parameters, function.test_type)
TestFunction(
function_name=def_name.name,
test_class=matched_name,
parameters=parameters,
test_type=functions[0].test_type,
)
)
elif function.test_function in top_level_functions:
test_functions.add(
TestFunction(function.test_function, function.test_class, None, function.test_type)
)
elif UNITTEST_PARAMETERIZED_TEST_NAME_REGEX.match(function.test_function):
base_name = UNITTEST_STRIP_NUMBERED_SUFFIX_REGEX.sub("", function.test_function)
if base_name in top_level_functions:
elif function == def_name.name:
test_functions.add(
TestFunction(
function_name=base_name,
test_class=function.test_class,
parameters=function.test_function,
test_type=function.test_type,
function_name=def_name.name,
test_class=matched_name,
parameters=None,
test_type=functions[0].test_type,
)
)

elif test_framework == "unittest":
functions_to_search = [elem.test_function for elem in functions]
test_suites = {elem.test_class for elem in functions}

matching_names = test_suites & top_level_classes.keys()
for matched_name in matching_names:
for def_name in all_defs:
if (
def_name.type == "function"
and def_name.full_name is not None
and f".{matched_name}." in def_name.full_name
):
for function in functions_to_search:
(is_parameterized, new_function, parameters) = discover_parameters_unittest(function)

if is_parameterized and new_function == def_name.name:
test_functions.add(
TestFunction(
function_name=def_name.name,
test_class=matched_name,
parameters=parameters,
test_type=functions[0].test_type,
)
)
elif function == def_name.name:
test_functions.add(
TestFunction(
function_name=def_name.name,
test_class=matched_name,
parameters=None,
test_type=functions[0].test_type,
)
)

test_functions_list = list(test_functions)
test_functions_raw = [elem.function_name for elem in test_functions_list]

test_functions_by_name = defaultdict(list)
for i, func_name in enumerate(test_functions_raw):
test_functions_by_name[func_name].append(i)

for name in all_names:
if name.full_name is None:
continue
m = FUNCTION_NAME_REGEX.search(name.full_name)
if not m:
continue

scope = m.group(1)
if scope not in test_functions_by_name:
continue

cache_key = (name.full_name, name.module_name)
try:
if cache_key in goto_cache:
definition = goto_cache[cache_key]
else:
definition = name.goto(follow_imports=True, follow_builtin_imports=False)
goto_cache[cache_key] = definition
except Exception as e:
logger.debug(str(e))
continue

if not definition or definition[0].type != "function":
continue

definition_path = str(definition[0].module_path)
if (
definition_path.startswith(str(project_root_path) + os.sep)
and definition[0].module_name != name.module_name
and definition[0].full_name is not None
):
for index in test_functions_by_name[scope]:
scope_test_function = test_functions_list[index].function_name
scope_test_class = test_functions_list[index].test_class
scope_parameters = test_functions_list[index].parameters
test_type = test_functions_list[index].test_type

if scope_parameters is not None:
if test_framework == "pytest":
scope_test_function += "[" + scope_parameters + "]"
if test_framework == "unittest":
scope_test_function += "_" + scope_parameters

full_name_without_module_prefix = definition[0].full_name.replace(
definition[0].module_name + ".", "", 1
)
qualified_name_with_modules_from_root = f"{module_name_from_file_path(definition[0].module_path, project_root_path)}.{full_name_without_module_prefix}"
test_functions_list = list(test_functions)
test_functions_raw = [elem.function_name for elem in test_functions_list]

test_functions_by_name = defaultdict(list)
for i, func_name in enumerate(test_functions_raw):
test_functions_by_name[func_name].append(i)

tests_cache.insert_test(
file_path=str(test_file),
file_hash=file_hash,
qualified_name_with_modules_from_root=qualified_name_with_modules_from_root,
function_name=scope,
for name in all_names:
if name.full_name is None:
continue
m = FUNCTION_NAME_REGEX.search(name.full_name)
if not m:
continue

scope = m.group(1)
if scope not in test_functions_by_name:
continue

try:
definition = name.goto(follow_imports=True, follow_builtin_imports=False)
except Exception as e:
logger.debug(str(e))
continue

if not definition or definition[0].type != "function":
continue

definition_path = str(definition[0].module_path)
if (
definition_path.startswith(str(project_root_path) + os.sep)
and definition[0].module_name != name.module_name
and definition[0].full_name is not None
):
for index in test_functions_by_name[scope]:
scope_test_function = test_functions_list[index].function_name
scope_test_class = test_functions_list[index].test_class
scope_parameters = test_functions_list[index].parameters
test_type = test_functions_list[index].test_type

if scope_parameters is not None:
if test_framework == "pytest":
scope_test_function += "[" + scope_parameters + "]"
if test_framework == "unittest":
scope_test_function += "_" + scope_parameters

full_name_without_module_prefix = definition[0].full_name.replace(
definition[0].module_name + ".", "", 1
)
qualified_name_with_modules_from_root = f"{module_name_from_file_path(definition[0].module_path, project_root_path)}.{full_name_without_module_prefix}"

function_to_test_map[qualified_name_with_modules_from_root].add(
FunctionCalledInTest(
tests_in_file=TestsInFile(
test_file=test_file,
test_class=scope_test_class,
test_function=scope_test_function,
test_type=test_type,
line_number=name.line,
col_number=name.column,
)
),
position=CodePosition(line_no=name.line, col_no=name.column),
)
)

function_to_test_map[qualified_name_with_modules_from_root].add(
FunctionCalledInTest(
tests_in_file=TestsInFile(
test_file=test_file,
test_class=scope_test_class,
test_function=scope_test_function,
test_type=test_type,
),
position=CodePosition(line_no=name.line, col_no=name.column),
)
)

def process_test_files(
file_to_test_map: dict[Path, list[TestsInFile]], cfg: TestConfig
) -> dict[str, list[FunctionCalledInTest]]:
function_to_test_map = defaultdict(set)
jedi_project = jedi.Project(path=cfg.project_root_path)

with test_files_progress_bar(total=len(file_to_test_map), description="Processing test files") as (
progress,
task_id,
):
for test_file, functions in file_to_test_map.items():
_process_single_test_file(test_file, functions, cfg, jedi_project, function_to_test_map)
progress.advance(task_id)

tests_cache.close()
return {function: list(tests) for function, tests in function_to_test_map.items()}
Loading