Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion cldk/analysis/java/codeanalyzer/codeanalyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,10 @@ def __init__(
self.eager_analysis = eager_analysis
self.analysis_level = analysis_level
self.target_files = target_files
self.application = self._init_codeanalyzer(analysis_level=1 if analysis_level == AnalysisLevel.symbol_table else 2)
if self.source_code is None:
self.application = self._init_codeanalyzer(analysis_level=1 if analysis_level == AnalysisLevel.symbol_table else 2)
else:
self.application = self._codeanalyzer_single_file()
# Attributes related the Java code analysis...
if analysis_level == AnalysisLevel.call_graph:
self.call_graph: nx.DiGraph = self._generate_call_graph(using_symbol_table=False)
Expand Down
6 changes: 0 additions & 6 deletions cldk/analysis/python/python_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,17 +30,11 @@ class PythonAnalysis:

def __init__(
self,
eager_analysis: bool,
project_dir: str | Path | None,
source_code: str | None,
analysis_backend_path: str | None,
analysis_json_path: str | Path | None,
) -> None:
self.project_dir = project_dir
self.source_code = source_code
self.analysis_json_path = analysis_json_path
self.analysis_backend_path = analysis_backend_path
self.eager_analysis = eager_analysis
self.analysis_backend: TreesitterPython = TreesitterPython()

def get_methods(self) -> List[PyMethod]:
Expand Down
6 changes: 6 additions & 0 deletions cldk/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from cldk.analysis.c import CAnalysis
from cldk.analysis.java import JavaAnalysis
from cldk.analysis.commons.treesitter import TreesitterJava
from cldk.analysis.python.python_analysis import PythonAnalysis
from cldk.utils.exceptions import CldkInitializationException
from cldk.utils.sanitization.java import TreesitterSanitizer

Expand Down Expand Up @@ -118,6 +119,11 @@ def analysis(
target_files=target_files,
eager_analysis=eager,
)
elif self.language == "python":
return PythonAnalysis(
project_dir=project_path,
source_code=source_code,
)
elif self.language == "c":
return CAnalysis(project_dir=project_path)
else:
Expand Down
4 changes: 4 additions & 0 deletions cldk/models/java/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,11 +363,15 @@ class JCompilationUnit(BaseModel):
"""Represents a compilation unit in Java.

Attributes:
file_path (str): The path to the source file.
package_name (str): The name of the package for the comppilation unit.
comments (List[JComment]): A list of comments in the compilation unit.
imports (List[str]): A list of import statements in the compilation unit.
type_declarations (Dict[str, JType]): A dictionary mapping type names to their corresponding JType representations.
"""

file_path: str
package_name: str
comments: List[JComment]
imports: List[str]
type_declarations: Dict[str, JType]
Expand Down
18 changes: 18 additions & 0 deletions tests/analysis/java/test_java_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,24 @@ def test_get_symbol_table_is_not_null(test_fixture, analysis_json):
)
assert analysis.get_symbol_table() is not None

def test_get_symbol_table_source_code(java_code):
"""Should return a symbol table for source analysis with expected class/method count"""

# Initialize the CLDK object with the project directory, language, and analysis_backend
cldk = CLDK(language="java")
analysis = cldk.analysis(
source_code=java_code,
analysis_backend_path=None,
eager=True,
analysis_level=AnalysisLevel.symbol_table,
)

# assert on expected class name and method count in the symbol table
expected_class_name = "com.acme.modres.WeatherServlet"
assert analysis.get_symbol_table() is not None
assert len(analysis.get_symbol_table().keys()) == 1
assert expected_class_name in analysis.get_methods().keys()
assert len(analysis.get_methods().get(expected_class_name).keys()) == 9

def test_get_imports(test_fixture, analysis_json):
"""Should return NotImplemented for get_imports()"""
Expand Down
30 changes: 15 additions & 15 deletions tests/analysis/python/test_python_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def divide(self, a, b):

def test_get_methods():
"""Should return all of the methods"""
python_analysis = PythonAnalysis(eager_analysis=True, project_dir=None, source_code=PYTHON_CODE, analysis_backend_path=None, analysis_json_path=None)
python_analysis = PythonAnalysis(project_dir=None, source_code=PYTHON_CODE)

all_methods = python_analysis.get_methods()
assert all_methods is not None
Expand All @@ -79,7 +79,7 @@ def test_get_methods():

def test_get_functions():
"""Should return all of the functions"""
python_analysis = PythonAnalysis(eager_analysis=True, project_dir=None, source_code=PYTHON_CODE, analysis_backend_path=None, analysis_json_path=None)
python_analysis = PythonAnalysis(project_dir=None, source_code=PYTHON_CODE)

all_functions = python_analysis.get_functions()
assert all_functions is not None
Expand All @@ -91,7 +91,7 @@ def test_get_functions():

def test_get_all_modules(tmp_path):
"""Should return all of the modules"""
python_analysis = PythonAnalysis(eager_analysis=True, project_dir=tmp_path, source_code=None, analysis_backend_path=None, analysis_json_path=None)
python_analysis = PythonAnalysis(project_dir=tmp_path, source_code=None)

# set up some temporary modules
temp_file_path = os.path.join(tmp_path, "hello.py")
Expand All @@ -111,7 +111,7 @@ def test_get_all_modules(tmp_path):

def test_get_method_details():
"""Should return the method details"""
python_analysis = PythonAnalysis(eager_analysis=True, project_dir=None, source_code=PYTHON_CODE, analysis_backend_path=None, analysis_json_path=None)
python_analysis = PythonAnalysis(project_dir=None, source_code=PYTHON_CODE)

method_details = python_analysis.get_method_details("add(self, a, b)")
assert method_details is not None
Expand All @@ -121,7 +121,7 @@ def test_get_method_details():

def test_is_parsable():
"""Should be able to parse the code"""
python_analysis = PythonAnalysis(eager_analysis=True, project_dir=None, source_code=PYTHON_CODE, analysis_backend_path=None, analysis_json_path=None)
python_analysis = PythonAnalysis(project_dir=None, source_code=PYTHON_CODE)

code = "def is_parsable(self, code: str) -> bool: return True"
is_parsable = python_analysis.is_parsable(code)
Expand All @@ -134,7 +134,7 @@ def test_is_parsable():

def test_get_raw_ast():
"""Should return the raw AST"""
python_analysis = PythonAnalysis(eager_analysis=True, project_dir=None, source_code=PYTHON_CODE, analysis_backend_path=None, analysis_json_path=None)
python_analysis = PythonAnalysis(project_dir=None, source_code=PYTHON_CODE)

raw_ast = python_analysis.get_raw_ast(PYTHON_CODE)
assert raw_ast is not None
Expand All @@ -144,7 +144,7 @@ def test_get_raw_ast():

def test_get_imports():
"""Should return all of the imports"""
python_analysis = PythonAnalysis(eager_analysis=True, project_dir=None, source_code=PYTHON_CODE, analysis_backend_path=None, analysis_json_path=None)
python_analysis = PythonAnalysis(project_dir=None, source_code=PYTHON_CODE)

all_imports = python_analysis.get_imports()
assert all_imports is not None
Expand All @@ -156,7 +156,7 @@ def test_get_imports():

def test_get_variables():
"""Should return all of the variables"""
python_analysis = PythonAnalysis(eager_analysis=True, project_dir=None, source_code=PYTHON_CODE, analysis_backend_path=None, analysis_json_path=None)
python_analysis = PythonAnalysis(project_dir=None, source_code=PYTHON_CODE)

with pytest.raises(NotImplementedError) as except_info:
python_analysis.get_variables()
Expand All @@ -165,7 +165,7 @@ def test_get_variables():

def test_get_classes():
"""Should return all of the classes"""
python_analysis = PythonAnalysis(eager_analysis=True, project_dir=None, source_code=PYTHON_CODE, analysis_backend_path=None, analysis_json_path=None)
python_analysis = PythonAnalysis(project_dir=None, source_code=PYTHON_CODE)

all_classes = python_analysis.get_classes()
assert all_classes is not None
Expand All @@ -178,7 +178,7 @@ def test_get_classes():

def test_get_classes_by_criteria():
"""Should return all of the classes that match the criteria"""
python_analysis = PythonAnalysis(eager_analysis=True, project_dir=None, source_code=PYTHON_CODE, analysis_backend_path=None, analysis_json_path=None)
python_analysis = PythonAnalysis(project_dir=None, source_code=PYTHON_CODE)

with pytest.raises(NotImplementedError) as except_info:
python_analysis.get_classes_by_criteria()
Expand All @@ -187,7 +187,7 @@ def test_get_classes_by_criteria():

def test_get_sub_classes():
"""Should return all of the subclasses"""
python_analysis = PythonAnalysis(eager_analysis=True, project_dir=None, source_code=PYTHON_CODE, analysis_backend_path=None, analysis_json_path=None)
python_analysis = PythonAnalysis(project_dir=None, source_code=PYTHON_CODE)

with pytest.raises(NotImplementedError) as except_info:
python_analysis.get_sub_classes()
Expand All @@ -196,7 +196,7 @@ def test_get_sub_classes():

def test_get_nested_classes():
"""Should return all of the nested classes"""
python_analysis = PythonAnalysis(eager_analysis=True, project_dir=None, source_code=PYTHON_CODE, analysis_backend_path=None, analysis_json_path=None)
python_analysis = PythonAnalysis(project_dir=None, source_code=PYTHON_CODE)

with pytest.raises(NotImplementedError) as except_info:
python_analysis.get_nested_classes()
Expand All @@ -205,7 +205,7 @@ def test_get_nested_classes():

def test_get_constructors():
"""Should return all of the constructors"""
python_analysis = PythonAnalysis(eager_analysis=True, project_dir=None, source_code=PYTHON_CODE, analysis_backend_path=None, analysis_json_path=None)
python_analysis = PythonAnalysis(project_dir=None, source_code=PYTHON_CODE)

with pytest.raises(NotImplementedError) as except_info:
python_analysis.get_constructors()
Expand All @@ -214,7 +214,7 @@ def test_get_constructors():

def test_get_methods_in_class():
"""Should return all of the methods in the class"""
python_analysis = PythonAnalysis(eager_analysis=True, project_dir=None, source_code=PYTHON_CODE, analysis_backend_path=None, analysis_json_path=None)
python_analysis = PythonAnalysis(project_dir=None, source_code=PYTHON_CODE)

with pytest.raises(NotImplementedError) as except_info:
python_analysis.get_methods_in_class()
Expand All @@ -223,7 +223,7 @@ def test_get_methods_in_class():

def test_get_fields():
"""Should return all of the fields in the class"""
python_analysis = PythonAnalysis(eager_analysis=True, project_dir=None, source_code=PYTHON_CODE, analysis_backend_path=None, analysis_json_path=None)
python_analysis = PythonAnalysis(project_dir=None, source_code=PYTHON_CODE)

with pytest.raises(NotImplementedError) as except_info:
python_analysis.get_fields()
Expand Down
21 changes: 21 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,3 +190,24 @@ def test_fixture_binutils():
for directory in Path(test_data_path).iterdir():
if directory.exists() and directory.is_dir():
shutil.rmtree(directory)

@pytest.fixture(scope="session", autouse=True)
def java_code() -> str:
"""
Returns sample Java source code for analysis.

Yields:
str : Java code to be analyzed.
"""
# ----------------------------------[ SETUP ]----------------------------------
# Path to your pyproject.toml
pyproject_path = Path(__file__).parent.parent / "pyproject.toml"

# Load the configuration
config = toml.load(pyproject_path)

# Access the test data path
test_data_path = config["tool"]["cldk"]["testing"]["sample-application"]
javafile = Path(test_data_path).absolute() / ("WeatherServlet.java")
with open(javafile) as f:
return f.read()
Loading