From b5a24b670eba2095ee9650cb5883ddb14ee338a7 Mon Sep 17 00:00:00 2001 From: Rahul Krishna Date: Wed, 2 Oct 2024 22:32:36 -0400 Subject: [PATCH] Now reads slim JSON Signed-off-by: Rahul Krishna --- .../java/codeanalyzer/codeanalyzer.py | 277 +++++++----------- cldk/models/java/models.py | 111 ++++--- tests/conftest.py | 11 + tests/models/java/test_java_models.py | 11 + 4 files changed, 190 insertions(+), 220 deletions(-) diff --git a/cldk/analysis/java/codeanalyzer/codeanalyzer.py b/cldk/analysis/java/codeanalyzer/codeanalyzer.py index fd4ff9d5..7e96fa56 100644 --- a/cldk/analysis/java/codeanalyzer/codeanalyzer.py +++ b/cldk/analysis/java/codeanalyzer/codeanalyzer.py @@ -15,8 +15,7 @@ from cldk.analysis import AnalysisLevel from cldk.analysis.java.treesitter import JavaSitter from cldk.models.java import JGraphEdges -from cldk.models.java.models import JApplication, JCallable, JField, JMethodDetail, JType, JCompilationUnit, \ - JGraphEdgesST +from cldk.models.java.models import JApplication, JCallable, JField, JMethodDetail, JType, JCompilationUnit, JGraphEdgesST from typing import Dict, List, Tuple from typing import Union @@ -28,13 +27,13 @@ class JCodeanalyzer: - """ A class for building the application view of a Java application using Codeanalyzer. + """A class for building the application view of a Java application using Codeanalyzer. Args: project_dir (str or Path): The path to the root of the Java project. source_code (str, optional): The source code of a single Java file to analyze. Defaults to None. analysis_backend_path (str or Path, optional): The path to the analysis backend. Defaults to None. - analysis_json_path (str or Path, optional): The path to save the intermediate code analysis outputs. + analysis_json_path (str or Path, optional): The path to save the intermediate code analysis outputs. If None, the analysis will be read from the pipe. analysis_level (str): The level of analysis ('symbol_table' or 'call_graph'). use_graalvm_binary (bool): If True, the GraalVM binary will be used instead of the codeanalyzer jar. @@ -76,15 +75,15 @@ class JCodeanalyzer: """ def __init__( - self, - project_dir: Union[str, Path], - source_code: str | None, - analysis_backend_path: Union[str, Path, None], - analysis_json_path: Union[str, Path, None], - analysis_level: str, - use_graalvm_binary: bool, - eager_analysis: bool, - target_files: List[str] | None + self, + project_dir: Union[str, Path], + source_code: str | None, + analysis_backend_path: Union[str, Path, None], + analysis_json_path: Union[str, Path, None], + analysis_level: str, + use_graalvm_binary: bool, + eager_analysis: bool, + target_files: List[str] | None, ) -> None: self.project_dir = project_dir self.source_code = source_code @@ -94,8 +93,7 @@ def __init__( self.eager_analysis = eager_analysis self.analysis_level = analysis_level self.target_files = target_files - self.application = self._init_codeanalyzer( - analysis_level=1 if analysis_level == AnalysisLevel.symbol_table else 2) + self.application = self._init_codeanalyzer(analysis_level=1 if analysis_level == AnalysisLevel.symbol_table else 2) # Attributes related the Java code analysis... if analysis_level == AnalysisLevel.call_graph: self.call_graph: DiGraph = self._generate_call_graph(using_symbol_table=False) @@ -104,7 +102,7 @@ def __init__( @staticmethod def _download_or_update_code_analyzer(filepath: Path) -> str: - """ Downloads the codeanalyzer jar from the latest release on GitHub. + """Downloads the codeanalyzer jar from the latest release on GitHub. Args: filepath (Path): The path to save the codeanalyzer jar. @@ -139,8 +137,7 @@ def _download_or_update_code_analyzer(filepath: Path) -> str: if match: current_datetime_str = match.group(0) - if datetime.strptime(datetime_str, date_format) > datetime.strptime(current_datetime_str, - date_format): + if datetime.strptime(datetime_str, date_format) > datetime.strptime(current_datetime_str, date_format): logger.info(f"Codeanalzyer jar is outdated. Downloading the latest version.") # Remove the older codeanalyzer jar for jarfile in current_codeanalyzer_jars: @@ -159,7 +156,7 @@ def _download_or_update_code_analyzer(filepath: Path) -> str: raise Exception(f"Failed to fetch release warn: {response.status_code} {response.text}") def _get_application(self) -> JApplication: - """ Returns the application view of the Java code. + """Returns the application view of the Java code. Returns: JApplication: The application view of the Java code. @@ -169,7 +166,7 @@ def _get_application(self) -> JApplication: return self.application def _get_codeanalyzer_exec(self) -> List[str]: - """ Returns the executable command for codeanalyzer. + """Returns the executable command for codeanalyzer. Returns: List[str]: The executable command for codeanalyzer. @@ -181,8 +178,7 @@ def _get_codeanalyzer_exec(self) -> List[str]: """ if self.use_graalvm_binary: - with resources.as_file( - resources.files("cldk.analysis.java.codeanalyzer.bin") / "codeanalyzer") as codeanalyzer_bin_path: + with resources.as_file(resources.files("cldk.analysis.java.codeanalyzer.bin") / "codeanalyzer") as codeanalyzer_bin_path: codeanalyzer_exec = shlex.split(codeanalyzer_bin_path.__str__()) else: @@ -198,7 +194,7 @@ def _get_codeanalyzer_exec(self) -> List[str]: codeanalyzer_jar_file = self._download_or_update_code_analyzer(codeanalyzer_jar_path) codeanalyzer_exec = shlex.split(f"java -jar {codeanalyzer_jar_file}") return codeanalyzer_exec - + def init_japplication(self, data: str) -> JApplication: """Return JApplication giving the stringified JSON as input. Returns @@ -207,9 +203,9 @@ def init_japplication(self, data: str) -> JApplication: The application view of the Java code with the analysis results. """ return JApplication(**json.loads(data)) - + def _init_codeanalyzer(self, analysis_level=1) -> JApplication: - """ Initializes the Codeanalyzer. + """Initializes the Codeanalyzer. Args: analysis_level (int): The level of analysis to be performed (1 for symbol table, 2 for call graph). @@ -221,19 +217,15 @@ def _init_codeanalyzer(self, analysis_level=1) -> JApplication: CodeanalyzerExecutionException: If there is an error running Codeanalyzer. """ codeanalyzer_exec = self._get_codeanalyzer_exec() - codeanalyzer_args = '' + codeanalyzer_args = "" if self.analysis_json_path is None: logger.info("Reading analysis from the pipe.") # If target file is provided, the input is merged into a single string and passed to codeanalyzer if self.target_files: - target_file_options = ' -t '.join([s.strip() for s in self.target_files]) - codeanalyzer_args = codeanalyzer_exec + shlex.split( - f"-i {Path(self.project_dir)} --analysis-level={analysis_level} -t {target_file_options}" - ) + target_file_options = " -t ".join([s.strip() for s in self.target_files]) + codeanalyzer_args = codeanalyzer_exec + shlex.split(f"-i {Path(self.project_dir)} --analysis-level={analysis_level} -t {target_file_options}") else: - codeanalyzer_args = codeanalyzer_exec + shlex.split( - f"-i {Path(self.project_dir)} --analysis-level={analysis_level}" - ) + codeanalyzer_args = codeanalyzer_exec + shlex.split(f"-i {Path(self.project_dir)} --analysis-level={analysis_level}") try: logger.info(f"Running codeanalyzer: {' '.join(codeanalyzer_args)}") console_out: CompletedProcess[str] = subprocess.run( @@ -245,17 +237,15 @@ def _init_codeanalyzer(self, analysis_level=1) -> JApplication: return JApplication(**json.loads(console_out.stdout)) except Exception as e: raise CodeanalyzerExecutionException(str(e)) from e - else: # Check if the code analyzer needs to be run is_run_code_analyzer = False analysis_json_path_file = Path(self.analysis_json_path).joinpath("analysis.json") # If target file is provided, the input is merged into a single string and passed to codeanalyzer if self.target_files: - target_file_options = ' -t '.join([s.strip() for s in self.target_files]) + target_file_options = " -t ".join([s.strip() for s in self.target_files]) codeanalyzer_args = codeanalyzer_exec + shlex.split( - f"-i {Path(self.project_dir)} --analysis-level={analysis_level}" - f" -o {self.analysis_json_path} -t {target_file_options}" + f"-i {Path(self.project_dir)} --analysis-level={analysis_level}" f" -o {self.analysis_json_path} -t {target_file_options}" ) is_run_code_analyzer = True else: @@ -264,9 +254,7 @@ def _init_codeanalyzer(self, analysis_level=1) -> JApplication: # flag is set, we'll run the analysis every time the object is created. This will happen regradless # of the existence of the analysis file. # Create the executable command for codeanalyzer. - codeanalyzer_args = codeanalyzer_exec + shlex.split( - f"-i {Path(self.project_dir)} --analysis-level={analysis_level} -o {self.analysis_json_path}" - ) + codeanalyzer_args = codeanalyzer_exec + shlex.split(f"-i {Path(self.project_dir)} --analysis-level={analysis_level} -o {self.analysis_json_path}") is_run_code_analyzer = True if is_run_code_analyzer: @@ -288,7 +276,7 @@ def _init_codeanalyzer(self, analysis_level=1) -> JApplication: return JApplication(**data) def _codeanalyzer_single_file(self): - """ Invokes codeanalyzer in a single file mode. + """Invokes codeanalyzer in a single file mode. Returns: JApplication: The application view of the Java code with the analysis results. @@ -299,8 +287,7 @@ def _codeanalyzer_single_file(self): try: print(f"Running {' '.join(codeanalyzer_cmd)}") logger.info(f"Running {' '.join(codeanalyzer_cmd)}") - console_out: CompletedProcess[str] = subprocess.run(codeanalyzer_cmd, capture_output=True, text=True, - check=True) + console_out: CompletedProcess[str] = subprocess.run(codeanalyzer_cmd, capture_output=True, text=True, check=True) if console_out.returncode != 0: raise CodeanalyzerExecutionException(console_out.stderr) return JApplication(**json.loads(console_out.stdout)) @@ -308,7 +295,7 @@ def _codeanalyzer_single_file(self): raise CodeanalyzerExecutionException(str(e)) from e def get_symbol_table(self) -> Dict[str, JCompilationUnit]: - """ Returns the symbol table of the Java code. + """Returns the symbol table of the Java code. Returns: Dict[str, JCompilationUnit]: The symbol table of the Java code. @@ -318,7 +305,7 @@ def get_symbol_table(self) -> Dict[str, JCompilationUnit]: return self.application.symbol_table def get_application_view(self) -> JApplication: - """ Returns the application view of the Java code. + """Returns the application view of the Java code. Returns: JApplication: The application view of the Java code. @@ -333,7 +320,7 @@ def get_application_view(self) -> JApplication: return self.application def get_system_dependency_graph(self) -> list[JGraphEdges]: - """ Runs the codeanalyzer to get the system dependency graph. + """Runs the codeanalyzer to get the system dependency graph. Returns: list[JGraphEdges]: The system dependency graph. @@ -344,7 +331,7 @@ def get_system_dependency_graph(self) -> list[JGraphEdges]: return self.application.system_dependency_graph def _generate_call_graph(self, using_symbol_table) -> DiGraph: - """ Generates the call graph of the Java code. + """Generates the call graph of the Java code. Args: using_symbol_table (bool): Whether to use the symbol table for generating the call graph. @@ -384,14 +371,14 @@ def _generate_call_graph(self, using_symbol_table) -> DiGraph: return cg def get_class_hierarchy(self) -> DiGraph: - """ Returns the class hierarchy of the Java code. + """Returns the class hierarchy of the Java code. Returns: DiGraph: The class hierarchy of the Java code. """ def get_call_graph(self) -> DiGraph: - """ Returns the call graph of the Java code. + """Returns the call graph of the Java code. Returns: DiGraph: The call graph of the Java code. @@ -403,11 +390,11 @@ def get_call_graph(self) -> DiGraph: return self.call_graph def get_call_graph_json(self) -> str: - """ Get call graph in serialized json format. + """Get call graph in serialized json format. Returns: str: Call graph in json. - """ + """ callgraph_list = [] edges = list(self.call_graph.edges.data("calling_lines")) for edge in edges: @@ -423,7 +410,7 @@ def get_call_graph_json(self) -> str: return json.dumps(callgraph_list) def get_all_callers(self, target_class_name: str, target_method_signature: str, using_symbol_table: bool) -> Dict: - """ Get all the caller details for a given Java method. + """Get all the caller details for a given Java method. Args: target_class_name (str): The qualified class name of the target method. @@ -437,9 +424,7 @@ def get_all_callers(self, target_class_name: str, target_method_signature: str, caller_detail_dict = {} call_graph = None if using_symbol_table: - call_graph = self.__call_graph_using_symbol_table(qualified_class_name=target_class_name, - method_signature=target_method_signature, - is_target_method=True) + call_graph = self.__call_graph_using_symbol_table(qualified_class_name=target_class_name, method_signature=target_method_signature, is_target_method=True) else: call_graph = self.call_graph if (target_method_signature, target_class_name) not in call_graph.nodes(): @@ -453,17 +438,15 @@ def get_all_callers(self, target_class_name: str, target_method_signature: str, data=True, ) caller_detail_dict["caller_details"] = [] - caller_detail_dict["target_method"] = call_graph.nodes[(target_method_signature, target_class_name)][ - "method_detail"] + caller_detail_dict["target_method"] = call_graph.nodes[(target_method_signature, target_class_name)]["method_detail"] for source, target, data in in_edge_view: - cm = {"caller_method": call_graph.nodes[source]["method_detail"], - "calling_lines": data["calling_lines"]} + cm = {"caller_method": call_graph.nodes[source]["method_detail"], "calling_lines": data["calling_lines"]} caller_detail_dict["caller_details"].append(cm) return caller_detail_dict def get_all_callees(self, source_class_name: str, source_method_signature: str, using_symbol_table: bool) -> Dict: - """ Get all the callee details for a given Java method. + """Get all the callee details for a given Java method. Args: source_class_name (str): The qualified class name of the source method. @@ -476,8 +459,7 @@ def get_all_callees(self, source_class_name: str, source_method_signature: str, callee_detail_dict = {} call_graph = None if using_symbol_table: - call_graph = self.__call_graph_using_symbol_table(qualified_class_name=source_class_name, - method_signature=source_method_signature) + call_graph = self.__call_graph_using_symbol_table(qualified_class_name=source_class_name, method_signature=source_method_signature) else: call_graph = self.call_graph if (source_method_signature, source_class_name) not in call_graph.nodes(): @@ -486,8 +468,7 @@ def get_all_callees(self, source_class_name: str, source_method_signature: str, out_edge_view = call_graph.out_edges(nbunch=(source_method_signature, source_class_name), data=True) callee_detail_dict["callee_details"] = [] - callee_detail_dict["source_method"] = call_graph.nodes[(source_method_signature, source_class_name)][ - "method_detail"] + callee_detail_dict["source_method"] = call_graph.nodes[(source_method_signature, source_class_name)]["method_detail"] for source, target, data in out_edge_view: cm = {"callee_method": call_graph.nodes[target]["method_detail"]} cm["calling_lines"] = data["calling_lines"] @@ -495,7 +476,7 @@ def get_all_callees(self, source_class_name: str, source_method_signature: str, return callee_detail_dict def get_all_methods_in_application(self) -> Dict[str, Dict[str, JCallable]]: - """ Returns a dictionary of all methods in the Java code with qualified class name as the key + """Returns a dictionary of all methods in the Java code with qualified class name as the key and a dictionary of methods in that class as the value. Returns: @@ -509,7 +490,7 @@ def get_all_methods_in_application(self) -> Dict[str, Dict[str, JCallable]]: return class_method_dict def get_all_classes(self) -> Dict[str, JType]: - """ Returns a dictionary of all classes in the Java code. + """Returns a dictionary of all classes in the Java code. Returns: Dict[str, JType]: A dictionary of all classes in the Java code, with qualified class names as keys. @@ -522,21 +503,21 @@ def get_all_classes(self) -> Dict[str, JType]: return class_dict def get_class(self, qualified_class_name) -> JType: - """ Returns a class given the qualified class name. + """Returns a class given the qualified class name. Args: qualified_class_name (str): The qualified name of the class. Returns: JType: A class for the given qualified class name. - """ + """ symtab = self.get_symbol_table() for _, v in symtab.items(): if qualified_class_name in v.type_declarations.keys(): return v.type_declarations.get(qualified_class_name) def get_method(self, qualified_class_name, method_signature) -> JCallable: - """ Returns a method given the qualified method name. + """Returns a method given the qualified method name. Args: qualified_class_name (str): The qualified name of the class. @@ -554,7 +535,7 @@ def get_method(self, qualified_class_name, method_signature) -> JCallable: return ci.callable_declarations[cd] def get_java_file(self, qualified_class_name) -> str: - """ Returns java file name given the qualified class name. + """Returns java file name given the qualified class name. Args: qualified_class_name (str): The qualified name of the class. @@ -568,7 +549,7 @@ def get_java_file(self, qualified_class_name) -> str: return k def get_java_compilation_unit(self, file_path: str) -> JCompilationUnit: - """ Given the path of a Java source file, returns the compilation unit object from the symbol table. + """Given the path of a Java source file, returns the compilation unit object from the symbol table. Args: file_path (str): Absolute path to the Java source file. @@ -582,7 +563,7 @@ def get_java_compilation_unit(self, file_path: str) -> JCompilationUnit: return self.application.symbol_table[file_path] def get_all_methods_in_class(self, qualified_class_name) -> Dict[str, JCallable]: - """ Returns a dictionary of all methods in the given class. + """Returns a dictionary of all methods in the given class. Args: qualified_class_name (str): The qualified name of the class. @@ -597,7 +578,7 @@ def get_all_methods_in_class(self, qualified_class_name) -> Dict[str, JCallable] return methods def get_all_constructors(self, qualified_class_name) -> Dict[str, JCallable]: - """ Returns a dictionary of all constructors of the given class. + """Returns a dictionary of all constructors of the given class. Args: qualified_class_name (str): The qualified name of the class. @@ -612,7 +593,7 @@ def get_all_constructors(self, qualified_class_name) -> Dict[str, JCallable]: return constructors def get_all_sub_classes(self, qualified_class_name) -> Dict[str, JType]: - """ Returns a dictionary of all sub-classes of the given class. + """Returns a dictionary of all sub-classes of the given class. Args: qualified_class_name (str): The qualified name of the class. @@ -624,13 +605,12 @@ def get_all_sub_classes(self, qualified_class_name) -> Dict[str, JType]: all_classes = self.get_all_classes() sub_classes = {} for cls in all_classes: - if qualified_class_name in all_classes[cls].implements_list or qualified_class_name in all_classes[ - cls].extends_list: + if qualified_class_name in all_classes[cls].implements_list or qualified_class_name in all_classes[cls].extends_list: sub_classes[cls] = all_classes[cls] return sub_classes def get_all_fields(self, qualified_class_name) -> List[JField]: - """ Returns a list of all fields of the given class. + """Returns a list of all fields of the given class. Args: qualified_class_name (str): The qualified name of the class. @@ -645,7 +625,7 @@ def get_all_fields(self, qualified_class_name) -> List[JField]: return ci.field_declarations def get_all_nested_classes(self, qualified_class_name) -> List[JType]: - """ Returns a list of all nested classes for the given class. + """Returns a list of all nested classes for the given class. Args: qualified_class_name (str): The qualified name of the class. @@ -661,7 +641,7 @@ def get_all_nested_classes(self, qualified_class_name) -> List[JType]: return [self.get_class(c) for c in nested_classes] # Assuming qualified nested class names def get_extended_classes(self, qualified_class_name) -> List[str]: - """ Returns a list of all extended classes for the given class. + """Returns a list of all extended classes for the given class. Args: qualified_class_name (str): The qualified name of the class. @@ -676,7 +656,7 @@ def get_extended_classes(self, qualified_class_name) -> List[str]: return ci.extends_list def get_implemented_interfaces(self, qualified_class_name) -> List[str]: - """ Returns a list of all implemented interfaces for the given class. + """Returns a list of all implemented interfaces for the given class. Args: qualified_class_name (str): The qualified name of the class. @@ -690,10 +670,8 @@ def get_implemented_interfaces(self, qualified_class_name) -> List[str]: return list() return ci.implements_list - def get_class_call_graph_using_symbol_table(self, qualified_class_name: str, - method_signature: str | None = None) -> ( - List)[Tuple[JMethodDetail, JMethodDetail]]: - """ Returns call graph using symbol table. The analysis will not be + def get_class_call_graph_using_symbol_table(self, qualified_class_name: str, method_signature: str | None = None) -> (List)[Tuple[JMethodDetail, JMethodDetail]]: + """Returns call graph using symbol table. The analysis will not be complete as symbol table has known limitation of resolving types Args: qualified_class_name: qualified name of the class @@ -706,8 +684,7 @@ def get_class_call_graph_using_symbol_table(self, qualified_class_name: str, if method_signature is None: filter_criteria = {node for node in call_graph.nodes if node[1] == qualified_class_name} else: - filter_criteria = {node for node in call_graph.nodes if - tuple(node) == (method_signature, qualified_class_name)} + filter_criteria = {node for node in call_graph.nodes if tuple(node) == (method_signature, qualified_class_name)} graph_edges: List[Tuple[JMethodDetail, JMethodDetail]] = list() for edge in call_graph.edges(nbunch=filter_criteria): @@ -716,10 +693,8 @@ def get_class_call_graph_using_symbol_table(self, qualified_class_name: str, graph_edges.append((source, target)) return graph_edges - def __call_graph_using_symbol_table(self, - qualified_class_name: str, - method_signature: str, is_target_method: bool = False)-> DiGraph: - """ Generate call graph using symbol table + def __call_graph_using_symbol_table(self, qualified_class_name: str, method_signature: str, is_target_method: bool = False) -> DiGraph: + """Generate call graph using symbol table Args: qualified_class_name: qualified class name method_signature: method signature @@ -731,11 +706,9 @@ def __call_graph_using_symbol_table(self, cg = nx.DiGraph() sdg = None if is_target_method: - sdg = self.__raw_call_graph_using_symbol_table_target_method(target_class_name=qualified_class_name, - target_method_signature=method_signature) + sdg = self.__raw_call_graph_using_symbol_table_target_method(target_class_name=qualified_class_name, target_method_signature=method_signature) else: - sdg = self.__raw_call_graph_using_symbol_table(qualified_class_name=qualified_class_name, - method_signature=method_signature) + sdg = self.__raw_call_graph_using_symbol_table(qualified_class_name=qualified_class_name, method_signature=method_signature) tsu = JavaSitter() edge_list = [ ( @@ -761,11 +734,8 @@ def __call_graph_using_symbol_table(self, cg.add_edges_from(edge_list) return cg - def __raw_call_graph_using_symbol_table_target_method(self, - target_class_name: str, - target_method_signature: str, - cg=None) -> list[JGraphEdgesST]: - """ Generates call graph using symbol table information given the target method and target class + def __raw_call_graph_using_symbol_table_target_method(self, target_class_name: str, target_method_signature: str, cg=None) -> list[JGraphEdgesST]: + """Generates call graph using symbol table information given the target method and target class Args: qualified_class_name: qualified class name method_signature: source method signature @@ -776,28 +746,26 @@ def __raw_call_graph_using_symbol_table_target_method(self, """ if cg is None: cg = [] - target_method_details = self.get_method(qualified_class_name=target_class_name, - method_signature=target_method_signature) + target_method_details = self.get_method(qualified_class_name=target_class_name, method_signature=target_method_signature) for class_name in self.get_all_classes(): for method in self.get_all_methods_in_class(qualified_class_name=class_name): - method_details = self.get_method(qualified_class_name=class_name, - method_signature=method) + method_details = self.get_method(qualified_class_name=class_name, method_signature=method) for call_site in method_details.call_sites: source_method_details = None - source_class = '' - callee_signature = '' - if call_site.callee_signature != '': - pattern = r'\b(?:[a-zA-Z_][\w\.]*\.)+([a-zA-Z_][\w]*)\b|<[^>]*>' + source_class = "" + callee_signature = "" + if call_site.callee_signature != "": + pattern = r"\b(?:[a-zA-Z_][\w\.]*\.)+([a-zA-Z_][\w]*)\b|<[^>]*>" # Find the part within the parentheses - start = call_site.callee_signature.find('(') + 1 - end = call_site.callee_signature.rfind(')') + start = call_site.callee_signature.find("(") + 1 + end = call_site.callee_signature.rfind(")") # Extract the elements inside the parentheses - elements = call_site.callee_signature[start:end].split(',') + elements = call_site.callee_signature[start:end].split(",") # Apply the regex to each element - simplified_elements = [re.sub(pattern, r'\1', element.strip()) for element in elements] + simplified_elements = [re.sub(pattern, r"\1", element.strip()) for element in elements] # Reconstruct the string with simplified elements callee_signature = f"{call_site.callee_signature[:start]}{', '.join(simplified_elements)}{call_site.callee_signature[end:]}" @@ -806,39 +774,31 @@ def __raw_call_graph_using_symbol_table_target_method(self, # call to any class if self.get_class(qualified_class_name=call_site.receiver_type): if callee_signature == target_method_signature and call_site.receiver_type == target_class_name: - source_method_details = self.get_method(method_signature=method, - qualified_class_name=class_name) + source_method_details = self.get_method(method_signature=method, qualified_class_name=class_name) source_class = class_name else: # check if any method exists with the signature in the class even if the receiver type is blank if callee_signature == target_method_signature and class_name == target_class_name: - source_method_details = self.get_method(method_signature=method, - qualified_class_name=class_name) + source_method_details = self.get_method(method_signature=method, qualified_class_name=class_name) source_class = class_name - if source_class != '' and source_method_details is not None: + if source_class != "" and source_method_details is not None: source: JMethodDetail target: JMethodDetail type: str weight: str call_edge = JGraphEdgesST( - source=JMethodDetail(method_declaration=source_method_details.declaration, - klass=source_class, - method=source_method_details), - target=JMethodDetail(method_declaration=target_method_details.declaration, - klass=target_class_name, - method=target_method_details), - type='CALL_DEP', - weight='1') + source=JMethodDetail(method_declaration=source_method_details.declaration, klass=source_class, method=source_method_details), + target=JMethodDetail(method_declaration=target_method_details.declaration, klass=target_class_name, method=target_method_details), + type="CALL_DEP", + weight="1", + ) if call_edge not in cg: cg.append(call_edge) return cg - def __raw_call_graph_using_symbol_table(self, - qualified_class_name: str, - method_signature: str, - cg=None) -> list[JGraphEdgesST]: - """ Generates a call graph using symbol table information. + def __raw_call_graph_using_symbol_table(self, qualified_class_name: str, method_signature: str, cg=None) -> list[JGraphEdgesST]: + """Generates a call graph using symbol table information. Args: qualified_class_name (str): The qualified class name. @@ -850,30 +810,29 @@ def __raw_call_graph_using_symbol_table(self, """ if cg is None: cg = [] - source_method_details = self.get_method(qualified_class_name=qualified_class_name, - method_signature=method_signature) + source_method_details = self.get_method(qualified_class_name=qualified_class_name, method_signature=method_signature) # If the provided classname and method signature combination do not exist if source_method_details is None: return cg for call_site in source_method_details.call_sites: target_method_details = None - target_class = '' - callee_signature = '' - if call_site.callee_signature != '': + target_class = "" + callee_signature = "" + if call_site.callee_signature != "": # Currently the callee signature returns the fully qualified type, whereas # the key for JCallable does not. The below logic converts the fully qualified signature # to the desider format. Only limitation is the nested generic type. - pattern = r'\b(?:[a-zA-Z_][\w\.]*\.)+([a-zA-Z_][\w]*)\b|<[^>]*>' + pattern = r"\b(?:[a-zA-Z_][\w\.]*\.)+([a-zA-Z_][\w]*)\b|<[^>]*>" # Find the part within the parentheses - start = call_site.callee_signature.find('(') + 1 - end = call_site.callee_signature.rfind(')') + start = call_site.callee_signature.find("(") + 1 + end = call_site.callee_signature.rfind(")") # Extract the elements inside the parentheses - elements = call_site.callee_signature[start:end].split(',') + elements = call_site.callee_signature[start:end].split(",") # Apply the regex to each element - simplified_elements = [re.sub(pattern, r'\1', element.strip()) for element in elements] + simplified_elements = [re.sub(pattern, r"\1", element.strip()) for element in elements] # Reconstruct the string with simplified elements callee_signature = f"{call_site.callee_signature[:start]}{', '.join(simplified_elements)}{call_site.callee_signature[end:]}" @@ -881,43 +840,35 @@ def __raw_call_graph_using_symbol_table(self, if call_site.receiver_type != "": # call to any class if self.get_class(qualified_class_name=call_site.receiver_type): - tmd = self.get_method(method_signature=callee_signature, - qualified_class_name=call_site.receiver_type) + tmd = self.get_method(method_signature=callee_signature, qualified_class_name=call_site.receiver_type) if tmd is not None: target_method_details = tmd target_class = call_site.receiver_type else: # check if any method exists with the signature in the class even if the receiver type is blank - tmd = self.get_method(method_signature=callee_signature, - qualified_class_name=qualified_class_name) + tmd = self.get_method(method_signature=callee_signature, qualified_class_name=qualified_class_name) if tmd is not None: target_method_details = tmd target_class = qualified_class_name - if target_class != '' and target_method_details is not None: + if target_class != "" and target_method_details is not None: source: JMethodDetail target: JMethodDetail type: str weight: str call_edge = JGraphEdgesST( - source=JMethodDetail(method_declaration=source_method_details.declaration, - klass=qualified_class_name, - method=source_method_details), - target=JMethodDetail(method_declaration=target_method_details.declaration, - klass=target_class, - method=target_method_details), - type='CALL_DEP', - weight='1') + source=JMethodDetail(method_declaration=source_method_details.declaration, klass=qualified_class_name, method=source_method_details), + target=JMethodDetail(method_declaration=target_method_details.declaration, klass=target_class, method=target_method_details), + type="CALL_DEP", + weight="1", + ) if call_edge not in cg: cg.append(call_edge) - cg = self.__raw_call_graph_using_symbol_table(qualified_class_name=target_class, - method_signature=target_method_details.signature, - cg=cg) + cg = self.__raw_call_graph_using_symbol_table(qualified_class_name=target_class, method_signature=target_method_details.signature, cg=cg) return cg - def get_class_call_graph(self, qualified_class_name: str, method_name: str | None = None) -> List[ - Tuple[JMethodDetail, JMethodDetail]]: - """ Generates a call graph for a given class and (optionally) filters by a given method. + def get_class_call_graph(self, qualified_class_name: str, method_name: str | None = None) -> List[Tuple[JMethodDetail, JMethodDetail]]: + """Generates a call graph for a given class and (optionally) filters by a given method. Args: qualified_class_name (str): The qualified name of the class. @@ -935,8 +886,7 @@ def get_class_call_graph(self, qualified_class_name: str, method_name: str | Non if method_name is None: filter_criteria = {node for node in self.call_graph.nodes if node[1] == qualified_class_name} else: - filter_criteria = {node for node in self.call_graph.nodes if - tuple(node) == (method_name, qualified_class_name)} + filter_criteria = {node for node in self.call_graph.nodes if tuple(node) == (method_name, qualified_class_name)} graph_edges: List[Tuple[JMethodDetail, JMethodDetail]] = list() for edge in self.call_graph.edges(nbunch=filter_criteria): @@ -947,7 +897,7 @@ def get_class_call_graph(self, qualified_class_name: str, method_name: str | Non return graph_edges def get_all_entry_point_methods(self) -> Dict[str, Dict[str, JCallable]]: - """ Returns a dictionary of all entry point methods in the Java code with + """Returns a dictionary of all entry point methods in the Java code with qualified class name as the key and a dictionary of methods in that class as the value. Returns: @@ -958,13 +908,12 @@ def get_all_entry_point_methods(self) -> Dict[str, Dict[str, JCallable]]: class_method_dict = {} class_dict = self.get_all_classes() for k, v in class_dict.items(): - entry_point_methods = {method_name: callable_decl for (method_name, callable_decl) in - v.callable_declarations.items() if callable_decl.is_entry_point is True} + entry_point_methods = {method_name: callable_decl for (method_name, callable_decl) in v.callable_declarations.items() if callable_decl.is_entry_point is True} class_method_dict[k] = entry_point_methods return class_method_dict def get_all_entry_point_classes(self) -> Dict[str, JType]: - """ Returns a dictionary of all entry point classes in the Java code. + """Returns a dictionary of all entry point classes in the Java code. Returns: Dict[str, JType]: A dictionary of all entry point classes in the Java code, diff --git a/cldk/models/java/models.py b/cldk/models/java/models.py index a4cd7e7c..7bc51274 100644 --- a/cldk/models/java/models.py +++ b/cldk/models/java/models.py @@ -7,10 +7,11 @@ constants = ConstantsNamespace() context_concrete_class = ContextVar("context_concrete_class") # context var to store class concreteness +_CALLABLES_LOOKUP_TABLE = dict() class JField(BaseModel): - """ Represents a field in a Java class or interface. + """Represents a field in a Java class or interface. Attributes: comment (str): The comment associated with the field. @@ -33,7 +34,7 @@ class JField(BaseModel): class JCallableParameter(BaseModel): - """ Represents a parameter of a Java callable. + """Represents a parameter of a Java callable. Attributes: name (str): The name of the parameter. @@ -42,7 +43,6 @@ class JCallableParameter(BaseModel): modifiers (List[str]): The modifiers applied to the parameter. """ - name: str type: str annotations: List[str] @@ -50,18 +50,19 @@ class JCallableParameter(BaseModel): class JEnumConstant(BaseModel): - """ Represents a constant in an enumeration. + """Represents a constant in an enumeration. Attributes: name (str): The name of the enum constant. arguments (List[str]): The arguments associated with the enum constant. """ + name: str arguments: List[str] class JCallSite(BaseModel): - """ Represents a call site. + """Represents a call site. Attributes: method_name (str): The name of the method called at the call site. @@ -77,7 +78,6 @@ class JCallSite(BaseModel): end_column (int): The ending column of the call site. """ - method_name: str receiver_expr: str = "" receiver_type: str @@ -97,7 +97,7 @@ class JCallSite(BaseModel): class JVariableDeclaration(BaseModel): - """ Represents a variable declaration. + """Represents a variable declaration. Attributes: name (str): The name of the variable. @@ -119,7 +119,7 @@ class JVariableDeclaration(BaseModel): class JCallable(BaseModel): - """ Represents a callable entity such as a method or constructor in Java. + """Represents a callable entity such as a method or constructor in Java. Attributes: signature (str): The signature of the callable. @@ -170,35 +170,30 @@ def detect_entrypoint_method(self): # check first if the class in which this method exists is concrete or not, by looking at the context var if context_concrete_class.get(): # convert annotations to the form GET, POST even if they are @GET or @GET('/ID') etc. - annotations_cleaned = [match for annotation in self.annotations for match in - re.findall(r'@(.*?)(?:\(|$)', annotation)] + annotations_cleaned = [match for annotation in self.annotations for match in re.findall(r"@(.*?)(?:\(|$)", annotation)] param_type_list = [val.type for val in self.parameters] # check the param types against known servlet param types - if any(substring in string for substring in param_type_list for string in - constants.ENTRY_POINT_METHOD_SERVLET_PARAM_TYPES): + if any(substring in string for substring in param_type_list for string in constants.ENTRY_POINT_METHOD_SERVLET_PARAM_TYPES): # check if this method is over-riding (only methods that override doGet / doPost etc. will be flagged as first level entry points) - if 'Override' in annotations_cleaned: + if "Override" in annotations_cleaned: self.is_entry_point = True return self # now check the cleaned annotations against known javax ws annotations - if any(substring in string for substring in annotations_cleaned for string in - constants.ENTRY_POINT_METHOD_JAVAX_WS_ANNOTATIONS): + if any(substring in string for substring in annotations_cleaned for string in constants.ENTRY_POINT_METHOD_JAVAX_WS_ANNOTATIONS): self.is_entry_point = True return self # check the cleaned annotations against known spring rest method annotations - if any(substring in string for substring in annotations_cleaned for string in - constants.ENTRY_POINT_METHOD_SPRING_ANNOTATIONS): + if any(substring in string for substring in annotations_cleaned for string in constants.ENTRY_POINT_METHOD_SPRING_ANNOTATIONS): self.is_entry_point = True return self return self class JType(BaseModel): - - """ Represents a Java class or interface. + """Represents a Java class or interface. Attributes: is_interface (bool): A flag indicating whether the object is an interface. @@ -265,16 +260,13 @@ def check_concrete_class(cls, values): def check_concrete_entry_point(self): """Detects if the class is entry point based on its properties.""" if self.is_concrete_class: - if any(substring in string for substring in (self.extends_list + self.implements_list) - for string in constants.ENTRY_POINT_SERVLET_CLASSES): + if any(substring in string for substring in (self.extends_list + self.implements_list) for string in constants.ENTRY_POINT_SERVLET_CLASSES): self.is_entry_point = True return self # Handle spring classes # clean annotations - take out @ and any paranehesis along with info in them. - annotations_cleaned = [match for annotation in self.annotations for match in - re.findall(r'@(.*?)(?:\(|$)', annotation)] - if any(substring in string for substring in annotations_cleaned - for string in constants.ENTRY_POINT_CLASS_SPRING_ANNOTATIONS): + annotations_cleaned = [match for annotation in self.annotations for match in re.findall(r"@(.*?)(?:\(|$)", annotation)] + if any(substring in string for substring in annotations_cleaned for string in constants.ENTRY_POINT_CLASS_SPRING_ANNOTATIONS): self.is_entry_point = True return self # context_concrete.reset() @@ -282,7 +274,7 @@ def check_concrete_entry_point(self): class JCompilationUnit(BaseModel): - """ Represents a compilation unit in Java. + """Represents a compilation unit in Java. Attributes: comment (str): A comment associated with the compilation unit. @@ -297,13 +289,14 @@ class JCompilationUnit(BaseModel): class JMethodDetail(BaseModel): - """ Represents details about a method in a Java class. + """Represents details about a method in a Java class. Attributes: method_declaration (str): The declaration string of the method. klass (str): The name of the class containing the method. 'class' is a reserved keyword in Python. method (JCallable): An instance of JCallable representing the callable details of the method. """ + method_declaration: str # class is a reserved keyword in python. we'll use klass. klass: str @@ -317,7 +310,7 @@ def __hash__(self): class JGraphEdgesST(BaseModel): - """ Represents an edge in a graph structure for method dependencies. + """Represents an edge in a graph structure for method dependencies. Attributes: source (JMethodDetail): The source method of the edge. @@ -327,6 +320,7 @@ class JGraphEdgesST(BaseModel): source_kind (Optional[str]): The kind of the source method. Default is None. destination_kind (Optional[str]): The kind of the target method. Default is None. """ + source: JMethodDetail target: JMethodDetail type: str @@ -336,16 +330,6 @@ class JGraphEdgesST(BaseModel): class JGraphEdges(BaseModel): - """ Represents an edge in a graph structure for method dependencies. - - Attributes: - source (JMethodDetail): The source method of the edge. - target (JMethodDetail): The target method of the edge. - type (str): The type of the edge. - weight (str): The weight of the edge, indicating the strength or significance of the connection. - source_kind (Optional[str]): The kind of the source method. Default is None. - destination_kind (Optional[str]): The kind of the target method. Default is None. - """ source: JMethodDetail target: JMethodDetail type: str @@ -356,19 +340,18 @@ class JGraphEdges(BaseModel): @field_validator("source", "target", mode="before") @classmethod def validate_source(cls, value) -> JMethodDetail: - """ Validates the source and target methods by parsing the input JSON string. - - Args: - value (str): A JSON string containing details about the method. - - Returns: - JMethodDetail: An instance of JMethodDetail representing the method details. - """ - - callable_dict = json.loads(value) - j_callable = JCallable(**json.loads(callable_dict["callable"])) # parse the value which is a quoted string - class_name = callable_dict["class_interface_declarations"] - method_decl = j_callable.declaration + if isinstance(value, str): + callable_dict = json.loads(value) + j_callable = JCallable(**json.loads(callable_dict["callable"])) # parse the value which is a quoted string + class_name = callable_dict["class_interface_declarations"] + method_decl = j_callable.declaration + elif isinstance(value, dict): + file_path, type_declaration, callable_declaration = value["file_path"], value["type_declaration"], value["callable_declaration"] + j_callable = _CALLABLES_LOOKUP_TABLE.get((file_path, type_declaration, callable_declaration), None) + if j_callable is None: + raise ValueError(f"Callable not found in lookup table: {file_path}, {type_declaration}, {callable_declaration}") + class_name = type_declaration + method_decl = j_callable.declaration mc = JMethodDetail(method_declaration=method_decl, klass=class_name, method=j_callable) return mc @@ -377,11 +360,27 @@ def __hash__(self): class JApplication(BaseModel): - """Represents a Java application. - - Attributes: - symbol_table (Dict[str, JCompilationUnit]): The symbol table representation containing compilation units. - system_dependency_graph (List[JGraphEdges], optional): The edges of the system dependency graph. Defaults to None. """ + Represents a Java application. + + Parameters + ---------- + symbol_table : List[JCompilationUnit] + The symbol table representation + system_dependency : List[JGraphEdges] + The edges of the system dependency graph. Default None. + """ + symbol_table: Dict[str, JCompilationUnit] system_dependency_graph: List[JGraphEdges] = None + + @field_validator("symbol_table", mode="after") + @classmethod + def validate_source(cls, symbol_table): + from ipdb import set_trace + + # Populate the lookup table for callables + for file_path, j_compulation_unit in symbol_table.items(): + for type_declaration, jtype in j_compulation_unit.type_declarations.items(): + for callable_declaration, j_callable in jtype.callable_declarations.items(): + _CALLABLES_LOOKUP_TABLE[(file_path, type_declaration, callable_declaration)] = j_callable diff --git a/tests/conftest.py b/tests/conftest.py index bd69ef18..391a140c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,6 +6,17 @@ from urllib.request import urlretrieve +@pytest.fixture(scope="session", autouse=True) +def analysis_json_fixture(): + # Path to your pyproject.toml + pyproject_path = Path(__file__).parent.parent / "pyproject.toml" + + # Load the configuration + config = toml.load(pyproject_path) + + return config["tool"]["cldk"]["testing"]["sample-application-analysis-json"] + + @pytest.fixture(scope="session", autouse=True) def test_fixture(): """ diff --git a/tests/models/java/test_java_models.py b/tests/models/java/test_java_models.py index e69de29b..cfabe05b 100644 --- a/tests/models/java/test_java_models.py +++ b/tests/models/java/test_java_models.py @@ -0,0 +1,11 @@ +from typing import List, Tuple +from cldk import CLDK + + +def test_get_class_call_graph(analysis_json_fixture): + # Initialize the CLDK object with the project directory, language, and analysis_backend. + cldk = CLDK(language="java") + + analysis = cldk.analysis( + project_path=analysis_json_fixture, analysis_backend="codeanalyzer", analysis_json_path=analysis_json_fixture, eager=False, analysis_level="call-graph" + )