diff --git a/cldk/analysis/java/codeanalyzer/codeanalyzer.py b/cldk/analysis/java/codeanalyzer/codeanalyzer.py index 207a3481..355c7ec9 100644 --- a/cldk/analysis/java/codeanalyzer/codeanalyzer.py +++ b/cldk/analysis/java/codeanalyzer/codeanalyzer.py @@ -12,16 +12,11 @@ from networkx import DiGraph +from cldk.analysis import AnalysisLevel from cldk.analysis.java.treesitter import JavaSitter from cldk.models.java import JGraphEdges -from cldk.models.java.models import ( - JApplication, - JCallable, - JField, - JMethodDetail, - JType, - JCompilationUnit, -) +from cldk.models.java.models import JApplication, JCallable, JField, JMethodDetail, JType, JCompilationUnit, \ + JGraphEdgesST from typing import Dict, List, Tuple from typing import Union @@ -75,15 +70,12 @@ def __init__( self.analysis_level = analysis_level self.target_files = target_files self.application = self._init_codeanalyzer( - analysis_level=1 if analysis_level == "symbol_table" else 2 - ) + analysis_level=1 if analysis_level == AnalysisLevel.symbol_table else 2) # Attributes related the Java code analysis... - if analysis_level == "symbol_table": - self.call_graph: DiGraph | None = None + if analysis_level == AnalysisLevel.call_graph: + self.call_graph: DiGraph = self._generate_call_graph(using_symbol_table=False) else: - self.call_graph: DiGraph = self._generate_call_graph( - using_symbol_table=False - ) + self.call_graph: DiGraph | None = None @staticmethod def _download_or_update_code_analyzer(filepath: Path) -> str: @@ -112,18 +104,12 @@ def _download_or_update_code_analyzer(filepath: Path) -> str: if match: datetime_str = match.group(0) else: - raise Exception( - f"Release URL {download_url} does not contain a datetime pattern." - ) + raise Exception(f"Release URL {download_url} does not contain a datetime pattern.") # Look for codeanalyzer.YYYYMMDDTHHMMSS.jar in the filepath - current_codeanalyzer_jars = [ - jarfile for jarfile in filepath.glob("*.jar") - ] + current_codeanalyzer_jars = [jarfile for jarfile in filepath.glob("*.jar")] if not any(current_codeanalyzer_jars): - logger.info( - f"Codeanalzyer jar is not found. Downloading the latest version." - ) + logger.info(f"Codeanalzyer jar is not found. Downloading the latest version.") filename = filepath / f"codeanalyzer.{datetime_str}.jar" urlretrieve(download_url, filename) return filename.__str__() @@ -133,12 +119,9 @@ def _download_or_update_code_analyzer(filepath: Path) -> str: if match: current_datetime_str = match.group(0) - if datetime.strptime( - datetime_str, date_format - ) > datetime.strptime(current_datetime_str, date_format): - logger.info( - f"Codeanalzyer jar is outdated. Downloading the latest version." - ) + if datetime.strptime(datetime_str, date_format) > datetime.strptime(current_datetime_str, + date_format): + logger.info(f"Codeanalzyer jar is outdated. Downloading the latest version.") # Remove the older codeanalyzer jar for jarfile in current_codeanalyzer_jars: jarfile.unlink() @@ -147,17 +130,13 @@ def _download_or_update_code_analyzer(filepath: Path) -> str: urlretrieve(download_url, filename) else: filename = current_codeanalyzer_jar_name - logger.info( - f"Codeanalzyer jar is already at the latest version." - ) + logger.info(f"Codeanalzyer jar is already at the latest version.") else: filename = current_codeanalyzer_jar_name return filename.__str__() else: - raise Exception( - f"Failed to fetch release warn: {response.status_code} {response.text}" - ) + raise Exception(f"Failed to fetch release warn: {response.status_code} {response.text}") def _get_application(self) -> JApplication: """ @@ -191,29 +170,21 @@ def _get_codeanalyzer_exec(self) -> List[str]: if self.use_graalvm_binary: with resources.as_file( - resources.files("cldk.analysis.java.codeanalyzer.bin") / "codeanalyzer" - ) as codeanalyzer_bin_path: + resources.files("cldk.analysis.java.codeanalyzer.bin") / "codeanalyzer") as codeanalyzer_bin_path: codeanalyzer_exec = shlex.split(codeanalyzer_bin_path.__str__()) else: + if self.analysis_backend_path: analysis_backend_path = Path(self.analysis_backend_path) logger.info(f"Using codeanalyzer.jar from {analysis_backend_path}") - codeanalyzer_exec = shlex.split( - f"java -jar {analysis_backend_path / 'codeanalyzer.jar'}" - ) + codeanalyzer_exec = shlex.split(f"java -jar {analysis_backend_path / 'codeanalyzer.jar'}") else: # Since the path to codeanalyzer.jar was not provided, we'll download the latest version from GitHub. - with resources.as_file( - resources.files("cldk.analysis.java.codeanalyzer.jar") - ) as codeanalyzer_jar_path: + with resources.as_file(resources.files("cldk.analysis.java.codeanalyzer.jar")) as codeanalyzer_jar_path: # Download the codeanalyzer jar if it doesn't exist, update if it's outdated, # do nothing if it's up-to-date. - codeanalyzer_jar_file = self._download_or_update_code_analyzer( - codeanalyzer_jar_path - ) - codeanalyzer_exec = shlex.split( - f"java -jar {codeanalyzer_jar_file}" - ) + codeanalyzer_jar_file = self._download_or_update_code_analyzer(codeanalyzer_jar_path) + codeanalyzer_exec = shlex.split(f"java -jar {codeanalyzer_jar_file}") return codeanalyzer_exec def _init_codeanalyzer(self, analysis_level=1) -> JApplication: @@ -256,9 +227,7 @@ def _init_codeanalyzer(self, analysis_level=1) -> JApplication: raise CodeanalyzerExecutionException(str(e)) from e else: - analysis_json_path_file = Path(self.analysis_json_path).joinpath( - "analysis.json" - ) + analysis_json_path_file = Path(self.analysis_json_path).joinpath("analysis.json") if not analysis_json_path_file.exists() or self.eager_analysis: # If the analysis file does not exist, we'll run the analysis. Alternately, if the eager_analysis # flag is set, we'll run the analysis every time the object is created. This will happen regradless @@ -277,9 +246,7 @@ def _init_codeanalyzer(self, analysis_level=1) -> JApplication: f"-i {Path(self.project_dir)} --analysis-level={analysis_level} -o {self.analysis_json_path}" ) try: - logger.info( - f"Running codeanalyzer subprocess with args {codeanalyzer_args}" - ) + logger.info(f"Running codeanalyzer subprocess with args {codeanalyzer_args}") subprocess.run( codeanalyzer_args, capture_output=True, @@ -287,9 +254,7 @@ def _init_codeanalyzer(self, analysis_level=1) -> JApplication: check=True, ) if not analysis_json_path_file.exists(): - raise CodeanalyzerExecutionException( - "Codeanalyzer did not generate the analysis file." - ) + raise CodeanalyzerExecutionException("Codeanalyzer did not generate the analysis file.") except Exception as e: raise CodeanalyzerExecutionException(str(e)) from e @@ -313,9 +278,8 @@ def _codeanalyzer_single_file(self): try: print(f"Running {' '.join(codeanalyzer_cmd)}") logger.info(f"Running {' '.join(codeanalyzer_cmd)}") - console_out: CompletedProcess[str] = subprocess.run( - codeanalyzer_cmd, capture_output=True, text=True, check=True - ) + console_out: CompletedProcess[str] = subprocess.run(codeanalyzer_cmd, capture_output=True, text=True, + check=True) if console_out.returncode != 0: raise CodeanalyzerExecutionException(console_out.stderr) return JApplication(**json.loads(console_out.stdout)) @@ -378,9 +342,7 @@ def _generate_call_graph(self, using_symbol_table) -> DiGraph: """ cg = nx.DiGraph() if using_symbol_table: - NotImplementedError( - "Call graph generation using symbol table is not implemented yet." - ) + NotImplementedError("Call graph generation using symbol table is not implemented yet.") else: sdg = self.get_system_dependency_graph() tsu = JavaSitter() @@ -391,9 +353,7 @@ def _generate_call_graph(self, using_symbol_table) -> DiGraph: { "type": jge.type, "weight": jge.weight, - "calling_lines": tsu.get_calling_lines( - jge.source.method.code, jge.target.method.signature - ), + "calling_lines": tsu.get_calling_lines(jge.source.method.code, jge.target.method.signature), }, ) for jge in sdg @@ -445,22 +405,16 @@ def get_call_graph_json(self) -> str: for edge in edges: callgraph_dict = {} callgraph_dict["source_method_signature"] = edge[0][0] - callgraph_dict["source_method_body"] = self.call_graph.nodes[edge[0]][ - "method_detail" - ].method.code + callgraph_dict["source_method_body"] = self.call_graph.nodes[edge[0]]["method_detail"].method.code callgraph_dict["source_class"] = edge[0][1] callgraph_dict["target_method_signature"] = edge[1][0] - callgraph_dict["target_method_body"] = self.call_graph.nodes[edge[1]][ - "method_detail" - ].method.code + callgraph_dict["target_method_body"] = self.call_graph.nodes[edge[1]]["method_detail"].method.code callgraph_dict["target_class"] = edge[1][1] callgraph_dict["calling_lines"] = edge[2] callgraph_list.append(callgraph_dict) return json.dumps(callgraph_list) - def get_all_callers( - self, target_class_name: str, target_method_signature: str - ) -> Dict: + def get_all_callers(self, target_class_name: str, target_method_signature: str, using_symbol_table: bool) -> Dict: """ Get all the caller details for a given java method. @@ -471,10 +425,16 @@ def get_all_callers( """ caller_detail_dict = {} - if (target_method_signature, target_class_name) not in self.call_graph.nodes(): + call_graph = None + if using_symbol_table: + call_graph = self.__raw_call_graph_using_symbol_table_target_method(target_class_name=target_class_name, + target_method_signature=target_method_signature) + else: + call_graph = self.call_graph + if (target_method_signature, target_class_name) not in call_graph.nodes(): return caller_detail_dict - in_edge_view = self.call_graph.in_edges( + in_edge_view = call_graph.in_edges( nbunch=( target_method_signature, target_class_name, @@ -482,21 +442,16 @@ def get_all_callers( data=True, ) caller_detail_dict["caller_details"] = [] - caller_detail_dict["target_method"] = self.call_graph.nodes[ - (target_method_signature, target_class_name) - ]["method_detail"] + caller_detail_dict["target_method"] = call_graph.nodes[(target_method_signature, target_class_name)][ + "method_detail"] for source, target, data in in_edge_view: - cm = { - "caller_method": self.call_graph.nodes[source]["method_detail"], - "calling_lines": data["calling_lines"], - } + cm = {"caller_method": call_graph.nodes[source]["method_detail"], + "calling_lines": data["calling_lines"]} caller_detail_dict["caller_details"].append(cm) return caller_detail_dict - def get_all_callees( - self, source_class_name: str, source_method_signature: str - ) -> Dict: + def get_all_callees(self, source_class_name: str, source_method_signature: str, using_symbol_table: bool) -> Dict: """ Get all the callee details for a given java method. @@ -506,19 +461,22 @@ def get_all_callees( Callee details in a dictionary. """ callee_detail_dict = {} - if (source_method_signature, source_class_name) not in self.call_graph.nodes(): + call_graph = None + if using_symbol_table: + call_graph = self.__call_graph_using_symbol_table(qualified_class_name=source_class_name, + method_signature=source_method_signature) + else: + call_graph = self.call_graph + if (source_method_signature, source_class_name) not in call_graph.nodes(): return callee_detail_dict - out_edge_view = self.call_graph.out_edges( - nbunch=(source_method_signature, source_class_name), data=True - ) + out_edge_view = call_graph.out_edges(nbunch=(source_method_signature, source_class_name), data=True) callee_detail_dict["callee_details"] = [] - callee_detail_dict["source_method"] = self.call_graph.nodes[ - (source_method_signature, source_class_name) - ]["method_detail"] + callee_detail_dict["source_method"] = call_graph.nodes[(source_method_signature, source_class_name)][ + "method_detail"] for source, target, data in out_edge_view: - cm = {"callee_method": self.call_graph.nodes[target]["method_detail"]} + cm = {"callee_method": call_graph.nodes[target]["method_detail"]} cm["calling_lines"] = data["calling_lines"] callee_detail_dict["callee_details"].append(cm) return callee_detail_dict @@ -655,11 +613,7 @@ def get_all_methods_in_class(self, qualified_class_name) -> Dict[str, JCallable] ci = self.get_class(qualified_class_name) if ci is None: return {} - methods = { - k: v - for (k, v) in ci.callable_declarations.items() - if v.is_constructor is False - } + methods = {k: v for (k, v) in ci.callable_declarations.items() if v.is_constructor is False} return methods def get_all_constructors(self, qualified_class_name) -> Dict[str, JCallable]: @@ -679,11 +633,7 @@ def get_all_constructors(self, qualified_class_name) -> Dict[str, JCallable]: ci = self.get_class(qualified_class_name) if ci is None: return {} - constructors = { - k: v - for (k, v) in ci.callable_declarations.items() - if v.is_constructor is True - } + constructors = {k: v for (k, v) in ci.callable_declarations.items() if v.is_constructor is True} return constructors def get_all_sub_classes(self, qualified_class_name) -> Dict[str, JType]: @@ -700,10 +650,8 @@ def get_all_sub_classes(self, qualified_class_name) -> Dict[str, JType]: all_classes = self.get_all_classes() sub_classes = {} for cls in all_classes: - if ( - qualified_class_name in all_classes[cls].implements_list - or qualified_class_name in all_classes[cls].extends_list - ): + if qualified_class_name in all_classes[cls].implements_list or qualified_class_name in all_classes[ + cls].extends_list: sub_classes[cls] = all_classes[cls] return sub_classes @@ -723,9 +671,7 @@ def get_all_fields(self, qualified_class_name) -> List[JField]: """ ci = self.get_class(qualified_class_name) if ci is None: - logging.warning( - f"Class {qualified_class_name} not found in the application view." - ) + logging.warning(f"Class {qualified_class_name} not found in the application view.") return list() return ci.field_declarations @@ -745,14 +691,10 @@ def get_all_nested_classes(self, qualified_class_name) -> List[JType]: """ ci = self.get_class(qualified_class_name) if ci is None: - logging.warning( - f"Class {qualified_class_name} not found in the application view." - ) + logging.warning(f"Class {qualified_class_name} not found in the application view.") return list() nested_classes = ci.nested_type_declerations - return [ - self.get_class(c) for c in nested_classes - ] # Assuming qualified nested class names + return [self.get_class(c) for c in nested_classes] # Assuming qualified nested class names def get_extended_classes(self, qualified_class_name) -> List[str]: """ @@ -770,9 +712,7 @@ def get_extended_classes(self, qualified_class_name) -> List[str]: """ ci = self.get_class(qualified_class_name) if ci is None: - logging.warning( - f"Class {qualified_class_name} not found in the application view." - ) + logging.warning(f"Class {qualified_class_name} not found in the application view.") return list() return ci.extends_list @@ -792,15 +732,239 @@ def get_implemented_interfaces(self, qualified_class_name) -> List[str]: """ ci = self.get_class(qualified_class_name) if ci is None: - logging.warning( - f"Class {qualified_class_name} not found in the application view." - ) + logging.warning(f"Class {qualified_class_name} not found in the application view.") return list() return ci.implements_list - def get_class_call_graph( - self, qualified_class_name: str, method_name: str | None = None - ) -> List[Tuple[JMethodDetail, JMethodDetail]]: + def get_class_call_graph_using_symbol_table(self, qualified_class_name: str, + method_signature: str | None = None) -> ( + List)[Tuple[JMethodDetail, JMethodDetail]]: + """ + Returns call graph using symbol table. The analysis will not be + complete as symbol table has known limitation of resolving types + Args: + qualified_class_name: qualified name of the class + method_signature: method signature of the starting point of the call graph + + Returns: List[Tuple[JMethodDetail, JMethodDetail]] + List of edges + """ + call_graph = self.__call_graph_using_symbol_table(qualified_class_name, method_signature) + if method_signature is None: + filter_criteria = {node for node in call_graph.nodes if node[1] == qualified_class_name} + else: + filter_criteria = {node for node in call_graph.nodes if + tuple(node) == (method_signature, qualified_class_name)} + + graph_edges: List[Tuple[JMethodDetail, JMethodDetail]] = list() + for edge in call_graph.edges(nbunch=filter_criteria): + source: JMethodDetail = call_graph.nodes[edge[0]]["method_detail"] + target: JMethodDetail = call_graph.nodes[edge[1]]["method_detail"] + graph_edges.append((source, target)) + return graph_edges + + def __call_graph_using_symbol_table(self, + qualified_class_name: str, + method_signature: str, is_target_method: bool = False)-> DiGraph: + """ + Generate call graph using symbol table + Args: + qualified_class_name: qualified class name + method_signature: method signature + is_target_method: is the input method is a target method. By default, it is the source method + + Returns: + DiGraph: call graph + """ + cg = nx.DiGraph() + sdg = None + if is_target_method: + sdg = None + else: + sdg = self.__raw_call_graph_using_symbol_table(qualified_class_name=qualified_class_name, + method_signature=method_signature) + tsu = JavaSitter() + edge_list = [ + ( + (jge.source.method.signature, jge.source.klass), + (jge.target.method.signature, jge.target.klass), + { + "type": jge.type, + "weight": jge.weight, + "calling_lines": tsu.get_calling_lines(jge.source.method.code, jge.target.method.signature), + }, + ) + for jge in sdg + ] + for jge in sdg: + cg.add_node( + (jge.source.method.signature, jge.source.klass), + method_detail=jge.source, + ) + cg.add_node( + (jge.target.method.signature, jge.target.klass), + method_detail=jge.target, + ) + cg.add_edges_from(edge_list) + return cg + + def __raw_call_graph_using_symbol_table_target_method(self, + target_class_name: str, + target_method_signature: str, + cg=None) -> list[JGraphEdgesST]: + """ + Generates call graph using symbol table information given the target method and target class + Args: + qualified_class_name: qualified class name + method_signature: source method signature + cg: call graph + + Returns: + list[JGraphEdgesST]: list of call edges + """ + if cg is None: + cg = [] + target_method_details = self.get_method(qualified_class_name=target_class_name, + method_signature=target_method_signature) + for class_name in self.get_all_classes(): + for method in self.get_all_methods_in_class(qualified_class_name=class_name): + method_details = self.get_method(qualified_class_name=class_name, + method_signature=method) + for call_site in method_details.call_sites: + source_method_details = None + source_class = '' + callee_signature = '' + if call_site.callee_signature != '': + pattern = r'\b(?:[a-zA-Z_][\w\.]*\.)+([a-zA-Z_][\w]*)\b|<[^>]*>' + + # Find the part within the parentheses + start = call_site.callee_signature.find('(') + 1 + end = call_site.callee_signature.rfind(')') + + # Extract the elements inside the parentheses + elements = call_site.callee_signature[start:end].split(',') + + # Apply the regex to each element + simplified_elements = [re.sub(pattern, r'\1', element.strip()) for element in elements] + + # Reconstruct the string with simplified elements + callee_signature = f"{call_site.callee_signature[:start]}{', '.join(simplified_elements)}{call_site.callee_signature[end:]}" + + if call_site.receiver_type != "": + # call to any class + if self.get_class(qualified_class_name=call_site.receiver_type): + if callee_signature==target_method_signature and call_site.receiver_type == target_class_name: + source_method_details = self.get_method(method_signature=method, + qualified_class_name=class_name) + source_class = class_name + else: + # check if any method exists with the signature in the class even if the receiver type is blank + if callee_signature == target_method_signature and class_name == target_class_name: + source_method_details = self.get_method(method_signature=method, + qualified_class_name=class_name) + source_class = class_name + + if source_class != '' and source_method_details is not None: + source: JMethodDetail + target: JMethodDetail + type: str + weight: str + call_edge = JGraphEdgesST( + source=JMethodDetail(method_declaration=source_method_details.declaration, + klass=source_class, + method=source_method_details), + target=JMethodDetail(method_declaration=target_method_details.declaration, + klass=target_class_name, + method=target_method_details), + type='CALL_DEP', + weight='1') + if call_edge not in cg: + cg.append(call_edge) + return cg + + def __raw_call_graph_using_symbol_table(self, + qualified_class_name: str, + method_signature: str, + cg=None) -> list[JGraphEdgesST]: + """ + Generates call graph using symbol table information + Args: + qualified_class_name: qualified class name + method_signature: source method signature + cg: call graph + + Returns: + list[JGraphEdgesST]: list of call edges + """ + if cg is None: + cg = [] + source_method_details = self.get_method(qualified_class_name=qualified_class_name, + method_signature=method_signature) + # If the provided classname and method signature combination do not exist + if source_method_details is None: + return cg + for call_site in source_method_details.call_sites: + target_method_details = None + target_class = '' + callee_signature = '' + if call_site.callee_signature != '': + # Currently the callee signature returns the fully qualified type, whereas + # the key for JCallable does not. The below logic converts the fully qualified signature + # to the desider format. Only limitation is the nested generic type. + pattern = r'\b(?:[a-zA-Z_][\w\.]*\.)+([a-zA-Z_][\w]*)\b|<[^>]*>' + + # Find the part within the parentheses + start = call_site.callee_signature.find('(') + 1 + end = call_site.callee_signature.rfind(')') + + # Extract the elements inside the parentheses + elements = call_site.callee_signature[start:end].split(',') + + # Apply the regex to each element + simplified_elements = [re.sub(pattern, r'\1', element.strip()) for element in elements] + + # Reconstruct the string with simplified elements + callee_signature = f"{call_site.callee_signature[:start]}{', '.join(simplified_elements)}{call_site.callee_signature[end:]}" + + if call_site.receiver_type != "": + # call to any class + if self.get_class(qualified_class_name=call_site.receiver_type): + tmd = self.get_method(method_signature=callee_signature, + qualified_class_name=call_site.receiver_type) + if tmd is not None: + target_method_details = tmd + target_class = call_site.receiver_type + else: + # check if any method exists with the signature in the class even if the receiver type is blank + tmd = self.get_method(method_signature=callee_signature, + qualified_class_name=qualified_class_name) + if tmd is not None: + target_method_details = tmd + target_class = qualified_class_name + + if target_class != '' and target_method_details is not None: + source: JMethodDetail + target: JMethodDetail + type: str + weight: str + call_edge = JGraphEdgesST( + source=JMethodDetail(method_declaration=source_method_details.declaration, + klass=qualified_class_name, + method=source_method_details), + target=JMethodDetail(method_declaration=target_method_details.declaration, + klass=target_class, + method=target_method_details), + type='CALL_DEP', + weight='1') + if call_edge not in cg: + cg.append(call_edge) + cg = self.__raw_call_graph_using_symbol_table(qualified_class_name=target_class, + method_signature=target_method_details.signature, + cg=cg) + return cg + + def get_class_call_graph(self, qualified_class_name: str, method_name: str | None = None) -> List[ + Tuple[JMethodDetail, JMethodDetail]]: """ A call graph for a given class and (optionally) filtered by a given method. @@ -827,17 +991,10 @@ def get_class_call_graph( # If the method name is not provided, we'll get the call graph for the entire class. if method_name is None: - filter_criteria = { - node - for node in self.call_graph.nodes - if node[1] == qualified_class_name - } + filter_criteria = {node for node in self.call_graph.nodes if node[1] == qualified_class_name} else: - filter_criteria = { - node - for node in self.call_graph.nodes - if tuple(node) == (method_name, qualified_class_name) - } + filter_criteria = {node for node in self.call_graph.nodes if + tuple(node) == (method_name, qualified_class_name)} graph_edges: List[Tuple[JMethodDetail, JMethodDetail]] = list() for edge in self.call_graph.edges(nbunch=filter_criteria): @@ -862,11 +1019,8 @@ def get_all_entry_point_methods(self) -> Dict[str, Dict[str, JCallable]]: class_method_dict = {} class_dict = self.get_all_classes() for k, v in class_dict.items(): - entry_point_methods = { - method_name: callable_decl - for (method_name, callable_decl) in v.callable_declarations.items() - if callable_decl.is_entry_point is True - } + entry_point_methods = {method_name: callable_decl for (method_name, callable_decl) in + v.callable_declarations.items() if callable_decl.is_entry_point is True} class_method_dict[k] = entry_point_methods return class_method_dict @@ -883,9 +1037,5 @@ def get_all_entry_point_classes(self) -> Dict[str, JType]: class_dict = {} symtab = self.get_symbol_table() for val in symtab.values(): - class_dict.update( - (k, v) - for k, v in val.type_declarations.items() - if v.is_entry_point is True - ) - return class_dict + class_dict.update((k, v) for k, v in val.type_declarations.items() if v.is_entry_point is True) + return class_dict \ No newline at end of file diff --git a/cldk/analysis/java/java.py b/cldk/analysis/java/java.py index 5d295643..789224bd 100644 --- a/cldk/analysis/java/java.py +++ b/cldk/analysis/java/java.py @@ -88,10 +88,14 @@ def get_variables(self, **kwargs): raise NotImplementedError(f"Support for this functionality has not been implemented yet.") def get_service_entry_point_classes(self, **kwargs): - raise NotImplementedError(f"Support for this functionality has not been implemented yet.") + if self.analysis_backend in [AnalysisEngine.CODEQL, AnalysisEngine.TREESITTER]: + raise NotImplementedError(f"Support for this functionality has not been implemented yet.") + return self.backend.get_all_entry_point_classes() def get_service_entry_point_methods(self, **kwargs): - raise NotImplementedError(f"Support for this functionality has not been implemented yet.") + if self.analysis_backend in [AnalysisEngine.CODEQL, AnalysisEngine.TREESITTER]: + raise NotImplementedError(f"Support for this functionality has not been implemented yet.") + return self.backend.get_all_entry_point_methods() def get_application_view(self) -> JApplication: """ @@ -490,33 +494,33 @@ def get_class_call_graph(self, qualified_class_name: str, method_signature: str raise NotImplementedError(f"Support for this functionality has not been implemented yet.") return self.backend.get_class_call_graph(qualified_class_name, method_signature) - def get_entry_point_classes(self) -> Dict[str, JType]: - """ - Returns a dictionary of all entry point classes in the Java code. - - Returns: - -------- - Dict[str, JType] - A dict of all entry point classes in the Java code, with qualified class names as keys - """ - if self.analysis_backend in [AnalysisEngine.CODEQL, AnalysisEngine.TREESITTER]: - raise NotImplementedError(f"Support for this functionality has not been implemented yet.") - return self.backend.get_all_entry_point_classes() - - def get_entry_point_methods(self) -> Dict[str, Dict[str, JCallable]]: - """ - Returns a dictionary of all entry point methods in the Java code with - qualified class name as key and dictionary of methods in that class - as value - - Returns: - -------- - Dict[str, Dict[str, JCallable]]: - A dictionary of dictionaries of entry point methods in the Java code. - """ - if self.analysis_backend in [AnalysisEngine.CODEQL, AnalysisEngine.TREESITTER]: - raise NotImplementedError(f"Support for this functionality has not been implemented yet.") - return self.backend.get_all_entry_point_methods() + # def get_entry_point_classes(self) -> Dict[str, JType]: + # """ + # Returns a dictionary of all entry point classes in the Java code. + # + # Returns: + # -------- + # Dict[str, JType] + # A dict of all entry point classes in the Java code, with qualified class names as keys + # """ + # if self.analysis_backend in [AnalysisEngine.CODEQL, AnalysisEngine.TREESITTER]: + # raise NotImplementedError(f"Support for this functionality has not been implemented yet.") + # return self.backend.get_all_entry_point_classes() + # + # def get_entry_point_methods(self) -> Dict[str, Dict[str, JCallable]]: + # """ + # Returns a dictionary of all entry point methods in the Java code with + # qualified class name as key and dictionary of methods in that class + # as value + # + # Returns: + # -------- + # Dict[str, Dict[str, JCallable]]: + # A dictionary of dictionaries of entry point methods in the Java code. + # """ + # if self.analysis_backend in [AnalysisEngine.CODEQL, AnalysisEngine.TREESITTER]: + # raise NotImplementedError(f"Support for this functionality has not been implemented yet.") + # return self.backend.get_all_entry_point_methods() def remove_all_comments(self) -> str: """ diff --git a/cldk/analysis/python/treesitter/python_sitter.py b/cldk/analysis/python/treesitter/python_sitter.py index ec3ba656..df5dcc3f 100644 --- a/cldk/analysis/python/treesitter/python_sitter.py +++ b/cldk/analysis/python/treesitter/python_sitter.py @@ -197,6 +197,7 @@ def get_all_classes(self, module: str) -> List[PyClass]: full_signature=class_full_signature, methods=methods, super_classes=super_classes, + class_name=klass_name, is_test_class=is_test_class)) return classes diff --git a/cldk/models/python/models.py b/cldk/models/python/models.py index ab08b94c..1f398f8a 100644 --- a/cldk/models/python/models.py +++ b/cldk/models/python/models.py @@ -45,6 +45,7 @@ class PyClass(BaseModel): full_signature: str super_classes: List[str] is_test_class: bool + class_name: str = None methods: List[PyMethod] diff --git a/tests/example.py b/tests/example.py deleted file mode 100644 index ad612e0a..00000000 --- a/tests/example.py +++ /dev/null @@ -1,98 +0,0 @@ -"""Example: Use CLDK to build a code summarization model -""" - -from cldk import CLDK -from cldk.analysis.java import JavaAnalysis - -# Initialize the Codellm-DevKit object with the project directory, language, and analysis_backend. -ns = CLDK( - project_dir="/Users/rajupavuluri/development/sample.daytrader8/", - language="java", - analysis_json_path="/Users/rkrsn/Downloads/sample.daytrader8/", -) - - -# Get the java application view for the project. -java_analysis: JavaAnalysis = ns.get_analysis() - -classes_dict = ns.preprocessing.get_classes() -# print(classes_dict) -entry_point_classes_dict = ns.preprocessing.get_entry_point_classes() -print(entry_point_classes_dict) - -entry_point_methods_dict = ns.preprocessing.get_entry_point_methods() -print(entry_point_methods_dict) - - -# ##get the first class in this dictionary for testing purposes -test_class_name = next(iter(classes_dict)) -print(test_class_name) -test_class = classes_dict[test_class_name] -# print(test_class) -# print(test_class.is_entry_point) - -# constructors = ns.preprocessing.get_all_constructors(test_class_name) -# print(constructors) - -# fields = ns.preprocessing.get_all_fields(test_class_name) -# print("fields :", fields) - -# methods = ns.preprocessing.get_all_methods_in_class(test_class_name) -# # print("number of methods in class ",test_class_name, ": ",len(methods)) -# nested_classes = ns.preprocessing.get_all_nested_classes(test_class_name) -# # print("nested_classes: ",nested_classes) -# extended_classes = ns.preprocessing.get_extended_classes(test_class_name) -# # print("extended_classes: ",extended_classes) -# implemented_interfaces = ns.preprocessing.get_implemented_interfaces( -# test_class_name -# ) -# # print("implemented_interfaces: ",implemented_interfaces) -# class_result = ns.preprocessing.get_class(test_class_name) -# print("class_result: ", class_result) -# java_file_name = ns.preprocessing.get_java_file(test_class_name) -# # print("java_file_name ",java_file_name) -# all_methods = ns.preprocessing.get_all_methods_in_application() -# # print(all_methods) -# method = ns.preprocessing.get_method( -# "com.ibm.websphere.samples.daytrader.util.Log", -# "public static void trace(String message)", -# ) -# print(method) -# # Get the call graph. - -# cg = ns.preprocessing.get_call_graph() -# print(cg) -# # print(ns.preprocessing.get_call_graph_json()) - -# # print(cg.edges) -# # d = ns.preprocessing.get_all_callers("com.ibm.websphere.samples.daytrader.util.Log","public static void trace(String message)") -# # print("caller details::") -# # print(d) -# # v = ns.preprocessing.get_all_callees("com.ibm.websphere.samples.daytrader.impl.ejb3.MarketSummarySingleton","private void updateMarketSummary()") -# # print("callee details::") -# # print(v) - -# """ -# # Get the user specified method. -# method: JCallable = app.get_method("com.example.foo.Bar.baz") # <- User specified method. - -# # Get the slices that contain the method. -# slices: nx.Generator = ns.preprocessing.get_slices_containing_method(method, sdg=app.sdg) - -# # Optional: Get samples for RAG from (say) elasticsearch -# few_shot_samples: List[str] = ns.prompting.rag( -# database={"hostname": "https://localhost:9200", "index": "summarization"} -# ).retrive_few_shot_samples(method=method, slices=slices) - -# # Natively we'll support PDL as the prompting engine to get summaries from the LLM. - -# summaries: List[str] = ns.prompting(engine="pdl").summarize(method, context=slices, few_shot_samples=few_shot_samples) - -# # Optionally, we will also support other open-source engines such as LMQL, Guidance, user defined Jinja, etc. -# summaries: List[str] = ns.prompting(engine="lmql").summarize(slices=slices, few_shot_samples=few_shot_samples) -# summaries: List[str] = ns.prompting(engine="guidance").summarize(slices=slices, few_shot_samples=few_shot_samples) -# summaries: List[str] = ns.prompting(engine="jinja", template="