In [114]:
import os

from tree_sitter import Parser, Language
import tree_sitter_java as tsj
from neo4j import GraphDatabase

class Neo4jConnection:
    """Thin wrapper around a Neo4j driver that sends every query to one database.

    Usable as a context manager so the driver is closed even if the body raises::

        with Neo4jConnection(uri, user, password) as conn:
            conn.execute_query("RETURN 1")
    """

    def __init__(self, uri, user, password, database=None):
        """Create the driver for ``uri`` using basic auth ``(user, password)``.

        ``database`` is stored and forwarded unchanged as ``database_`` on
        every ``execute_query`` call.
        """
        self.driver = GraphDatabase.driver(uri, auth=(user, password))
        self.database = database

    def __enter__(self):
        """Return the connection itself for use inside a ``with`` block."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Close the driver on exit; exceptions from the body are not suppressed."""
        self.close()
        return False

    def execute_query(self, query, **parameters):
        """Run ``query`` with ``parameters`` and return the driver's result object."""
        return self.driver.execute_query(query, database_=self.database, **parameters)

    def close(self):
        """Close the underlying driver."""
        self.driver.close()

In [115]:
# Build a tree-sitter parser configured with the Java grammar.
language = Language(tsj.language())
parser = Parser(language=language)
# Bare expression: shows the parser's repr as the cell output.
parser

<tree_sitter.Parser at 0x106d9fd80>

In [151]:
# Java source file that the following cells parse and load into the graph.
JAVA_SOURCE_PATH = "../data/java/employee-crud/src/main/java/com/java_ast_knowledge_graph_poc/employee_crud/controller/EmployeeController.java"

with open(JAVA_SOURCE_PATH, mode="r", encoding="utf-8") as file:
    source_code = file.read()

# Bare expression: echoes the file contents as the cell output.
source_code

'package com.java_ast_knowledge_graph_poc.employee_crud.controller;\n\nimport java.util.List;\n\nimport org.springframework.beans.factory.annotation.Autowired;\nimport org.springframework.http.HttpStatus;\nimport org.springframework.http.ResponseEntity;\nimport org.springframework.web.bind.annotation.DeleteMapping;\nimport org.springframework.web.bind.annotation.GetMapping;\nimport org.springframework.web.bind.annotation.PathVariable;\nimport org.springframework.web.bind.annotation.PostMapping;\nimport org.springframework.web.bind.annotation.PutMapping;\nimport org.springframework.web.bind.annotation.RequestBody;\nimport org.springframework.web.bind.annotation.RequestMapping;\nimport org.springframework.web.bind.annotation.RestController;\n\nimport com.java_ast_knowledge_graph_poc.employee_crud.dto.EmployeeDTO;\nimport com.java_ast_knowledge_graph_poc.employee_crud.service.EmployeeService;\n\n@RestController\n@RequestMapping("/api/employees")\npublic class EmployeeController {\n\n    p

In [152]:
# Parse the UTF-8 encoded source and keep a handle on the root of the syntax tree.
tree = parser.parse(source_code.encode("utf8"))
root_node = tree.root_node

In [118]:
# Connection settings are read from the environment so real credentials do not
# have to be written into the notebook; the fallbacks reproduce the previous
# local-development values.
graphDBConn = Neo4jConnection(
    uri=os.environ.get("NEO4J_URI", "neo4j://localhost:7687"),
    user=os.environ.get("NEO4J_USER", "neo4j"),
    password=os.environ.get("NEO4J_PASSWORD", "password"),
)

In [153]:
# Direct children of the root node (package, imports, type declarations, ...).
children_nodes = root_node.children

# Pick the declarations by node type rather than by position: a leading comment
# (e.g. a licence header) or a trailing node would otherwise be mistaken for the
# package or class. The old positional choice is kept as the fallback so inputs
# without a matching node behave exactly as before.
class_declration_node = next(
    (node for node in children_nodes if node.type == "class_declaration"),
    children_nodes[-1],
)
package_declaration_node = next(
    (node for node in children_nodes if node.type == "package_declaration"),
    children_nodes[0],
)

create package node

In [212]:
# Decode once and reuse. children[1] is taken as the package name node
# (the child right after the `package` keyword in the parsed sample).
package_name = package_declaration_node.children[1].text.decode('utf-8')
package_content = package_declaration_node.text.decode('utf-8')

print(f"name: {package_name}")
print(f"content: {package_content}")

# MERGE on the name only and SET the content afterwards. The import-processing
# cell merges packages as `Package {name: ...}`; including `content` in the
# MERGE key here would not match such a node and would create a duplicate
# Package with the same name.
package_node = graphDBConn.execute_query(
    "MERGE (p:Package {name: $name}) SET p.content = $content RETURN p",
    name=package_name,
    content=package_content
)

package_node

name: com.java_ast_knowledge_graph_poc.employee_crud.controller
content: package com.java_ast_knowledge_graph_poc.employee_crud.controller;


EagerResult(records=[<Record p=<Node element_id='4:006ce47f-db97-446a-8d48-00e2d825b218:0' labels=frozenset({'Package'}) properties={'name': 'com.java_ast_knowledge_graph_poc.employee_crud.controller', 'content': 'package com.java_ast_knowledge_graph_poc.employee_crud.controller;'}>>], summary=<neo4j._work.summary.ResultSummary object at 0x1073c7930>, keys=['p'])

In [121]:
# Inspect the Package node carried under key "p" in the first returned record.
package_node.records[0].get("p")

<Node element_id='4:006ce47f-db97-446a-8d48-00e2d825b218:0' labels=frozenset({'Package'}) properties={'name': 'com.java_ast_knowledge_graph_poc.employee_crud.controller', 'content': 'package com.java_ast_knowledge_graph_poc.employee_crud.controller;'}>

create class node

In [213]:
def _request_mapping_path(class_node):
    """Return the first string literal of the class-level @RequestMapping, or None.

    Searches the class modifiers for an annotation named ``RequestMapping`` and
    returns the text of the first string literal found in its arguments, without
    the surrounding quotes. Returns None when the class has no such annotation.
    """
    modifiers = next((child for child in class_node.children if child.type == "modifiers"), None)
    if modifiers is None:
        return None
    for annotation in modifiers.children:
        if annotation.type != "annotation":
            continue
        annotation_name = annotation.child_by_field_name("name")
        if annotation_name is None or annotation_name.text.decode('utf-8') != "RequestMapping":
            continue
        arguments = annotation.child_by_field_name("arguments")
        # Depth-first search in document order; also covers `value = "..."` forms.
        pending = list(arguments.children) if arguments is not None else []
        while pending:
            candidate = pending.pop(0)
            if candidate.type == "string_literal":
                return candidate.text.decode('utf-8')[1:-1]
            pending[0:0] = candidate.children
    return None


rest_path = _request_mapping_path(class_declration_node)

# Prefer the grammar's `name` field: children[-2] is only the identifier when the
# class has no `extends` / `implements` clause. Fall back to the old position
# if the field is unavailable.
class_name_node = class_declration_node.child_by_field_name("name")
if class_name_node is None:
    class_name_node = class_declration_node.children[-2]
class_name = class_name_node.text.decode('utf-8')

# Class header: every child except the last one (the body), joined by spaces.
content = " ".join(child.text.decode('utf-8') for child in class_declration_node.children[:-1])

# Get the package node ID from the previous query
package_id = package_node.records[0].get("p").element_id

# Create class and relationship in one query.
# The Class is merged by name only (the same key the import-processing cell uses)
# and its other properties are SET afterwards, so a Class node created earlier
# as an import target is updated instead of duplicated. SET also tolerates a
# null rest_path, which a MERGE property map would reject.
class_node = graphDBConn.execute_query(
    """
    MATCH (pkg:Package) WHERE elementId(pkg) = $package_id
    MERGE (c:Class {name: $name})
    MERGE (c)-[:BELONGS_TO]->(pkg)
    SET c.rest_path = $rest_path, c.content = $content
    RETURN c
    """,
    package_id=package_id,
    name=class_name,
    rest_path=rest_path,
    content=content
)

class_node.records[0].get("c")

<Node element_id='4:006ce47f-db97-446a-8d48-00e2d825b218:1' labels=frozenset({'Class'}) properties={'name': 'EmployeeController', 'content': '@RestController\n@RequestMapping("/api/employees")\npublic class EmployeeController', 'rest_path': '/api/employees'}>

### Process import declarations

In [214]:
# Select import declarations by node type instead of slicing [1:-1], so comments
# or other top-level nodes between the package and the class are not treated as imports.
import_declaration_nodes = [node for node in children_nodes if node.type == "import_declaration"]

In [215]:
# Collects one result record per processed import; the field-processing cell scans it for wildcard ("*") imports.
import_nodes = []

In [216]:
for import_declaration_node in import_declaration_nodes:
    child_types = [child.type for child in import_declaration_node.children]

    # Static imports name a member, not a class, and the `static` keyword shifts
    # the child positions; they are not modelled yet, so skip them.
    if "static" in child_types:
        continue

    # Locate the imported name by node type rather than by index. Nodes without
    # such a child (e.g. a comment that slipped into the list) are skipped.
    imported_name_node = next(
        (
            child for child in import_declaration_node.children
            if child.type in ("scoped_identifier", "identifier")
        ),
        None,
    )
    if imported_name_node is None:
        continue

    if "asterisk" in child_types:
        # `import a.b.*;` -> the whole name is the package, the class is "*".
        import_class_name = "*"
        import_package_name = imported_name_node.text.decode('utf-8')
    elif imported_name_node.children:
        # `import a.b.C;` -> first child is the scope (a.b), last child the class (C).
        import_class_name = imported_name_node.children[-1].text.decode('utf-8')
        import_package_name = imported_name_node.children[0].text.decode('utf-8')
    else:
        # Single-segment name without a package part: nothing sensible to link.
        continue

    import_statement = import_declaration_node.text.decode('utf-8')
    # Go through the wrapper (not graphDBConn.driver) so the configured database
    # is used, like every other query in this notebook.
    # NOTE(review): Class nodes are merged by name only, so the "*" class and
    # same-named classes from different packages share one node; confirm intended.
    import_node = graphDBConn.execute_query(
        """
        MATCH (c1:Class {name: $class_name})
        MERGE (pkg:Package {name: $import_package_name})
        MERGE (c:Class {name: $import_class_name})
        MERGE (c)-[:BELONGS_TO]->(pkg)
        MERGE (c1)-[:IMPORTS {content: $import_statement}]->(c)
        RETURN c
        """,
        import_package_name=import_package_name,
        import_class_name=import_class_name,
        class_name=class_name,
        import_statement=import_statement
    )
    import_nodes.append(import_node.records[0])


process class body

In [217]:
# Members of the class: the last child of the class declaration is taken as the body,
# and the [1:-1] slice drops its first and last children (presumably the braces; confirm).
class_body_nodes = class_declration_node.children[-1].children[1:-1]

process fields

In [218]:
# Keep only the class-body members whose node type is 'field_declaration'.
field_declaration_nodes = [node for node in class_body_nodes if node.type == 'field_declaration']
# Bare expression: displays the selected nodes as the cell output.
field_declaration_nodes

[<Node type=field_declaration, start_point=(23, 4), end_point=(23, 50)>,
 <Node type=field_declaration, start_point=(24, 4), end_point=(24, 26)>]

In [219]:
def _field_or_child(node, field_name, index):
    """Return the child under grammar field ``field_name``, else ``node.children[index]``."""
    child = node.child_by_field_name(field_name)
    return child if child is not None else node.children[index]


def _merge_field(class_node_id, field_name, field_content, field_type, type_class_id=None):
    """Merge a Field node and link it to its owning class and to its type class.

    When ``type_class_id`` is given the type class is matched by element id;
    otherwise a Class node is merged using ``field_type`` as its name.
    Returns the first record of the query result (key ``f``).
    """
    if type_class_id is not None:
        result = graphDBConn.execute_query(
            """
            MATCH(parentClass:Class) WHERE elementId(parentClass) = $class_node_id
            MATCH(typeClass:Class) WHERE elementId(typeClass) = $type_class_id
            MERGE(f:Field {name: $field_name, content: $field_content})
            MERGE(f)-[:MEMBER_OF]->(parentClass)
            MERGE(f)-[:TYPE]->(typeClass)
            RETURN f
            """,
            class_node_id=class_node_id,
            type_class_id=type_class_id,
            field_name=field_name,
            field_content=field_content
        )
    else:
        result = graphDBConn.execute_query(
            """
            MATCH(parentClass:Class) WHERE elementId(parentClass) = $class_node_id
            MERGE(typeClass:Class{name: $type_class_name})
            MERGE(f:Field {name: $field_name, content: $field_content})
            MERGE(f)-[:MEMBER_OF]->(parentClass)
            MERGE(f)-[:TYPE]->(typeClass)
            RETURN f
            """,
            class_node_id=class_node_id,
            type_class_name=field_type,
            field_name=field_name,
            field_content=field_content
        )
    return result.records[0]


class_node_id = class_node.records[0].get("c").element_id

# True when at least one processed import was a wildcard ("*") import.
has_wildcard_import = any(
    record.get("c").get("name") == "*" for record in import_nodes if record.get("c") is not None
)

field_nodes = []
for node in field_declaration_nodes:
    # Use grammar field names so a field without modifiers is read correctly;
    # the previous child positions remain as fallback.
    # Only the first declarator is read (`int a, b;` yields `a`).
    field_type = _field_or_child(node, "type", 1).text.decode('utf-8')
    declarator_node = _field_or_child(node, "declarator", 2)
    field_name = _field_or_child(declarator_node, "name", 0).text.decode('utf-8')
    field_content = node.text.decode('utf-8')

    # 1) Type explicitly imported by this class.
    direct_import = graphDBConn.execute_query(
        """
        MATCH (c:Class{name: $field_type})<-[:IMPORTS]-(:Class{name: $class_name}) RETURN c
        """,
        field_type=field_type,
        class_name=class_name
    )

    type_class_id = None
    if len(direct_import.records) > 0:
        type_class_id = direct_import.records[0].get("c").element_id
    elif has_wildcard_import:
        # 2) Type known in a package that this class imports via wildcard.
        wildcard_match = graphDBConn.execute_query(
            """
            MATCH (wildcard:Class{name:"*"})-[:BELONGS_TO]->(p:Package)
            MATCH (parentClass:Class) WHERE elementId(parentClass) = $class_node_id
            MATCH (parentClass)-[:IMPORTS]->(wildcard)
            MATCH (type:Class{name:$field_type})-[:BELONGS_TO]->(p)
            RETURN type
            """,
            field_type=field_type,
            class_node_id=class_node_id
        )
        if len(wildcard_match.records) > 0:
            # The query returns the class under key "type"; reading "c" here
            # yielded None and crashed on `.element_id`.
            type_class_id = wildcard_match.records[0].get("type").element_id

    # 3) Otherwise (type_class_id is None) the type class is merged by name.
    field_nodes.append(
        _merge_field(class_node_id, field_name, field_content, field_type, type_class_id)
    )
field_nodes

[<Record f=<Node element_id='4:006ce47f-db97-446a-8d48-00e2d825b218:22' labels=frozenset({'Field'}) properties={'name': 'employeeService', 'content': 'private final EmployeeService employeeService;'}>>,
 <Record f=<Node element_id='4:006ce47f-db97-446a-8d48-00e2d825b218:24' labels=frozenset({'Field'}) properties={'name': 'object', 'content': 'private Object object;'}>>]

In [220]:
# Element id of the Class node returned by the class-creation query; the method-processing
# code below reads this module-level name to scope its queries to the class.
class_node_id = class_node.records[0].get("c").element_id
# Bare expression: displays the id as the cell output.
class_node_id

'4:006ce47f-db97-446a-8d48-00e2d825b218:1'

process methods

In [221]:
# Keep only the class-body members whose node type is 'method_declaration'.
method_declaration_nodes =  [node for node in class_body_nodes if node.type == 'method_declaration']
# Bare expression: displays the selected nodes as the cell output.
method_declaration_nodes

[<Node type=method_declaration, start_point=(31, 4), end_point=(35, 5)>,
 <Node type=method_declaration, start_point=(37, 4), end_point=(41, 5)>,
 <Node type=method_declaration, start_point=(43, 4), end_point=(46, 5)>,
 <Node type=method_declaration, start_point=(48, 4), end_point=(53, 5)>,
 <Node type=method_declaration, start_point=(55, 4), end_point=(59, 5)>,
 <Node type=method_declaration, start_point=(61, 4), end_point=(63, 5)>]

In [230]:
# NOTE: this cell calls traverse_method_body, so that function must already be
# defined in the kernel when the cell runs.
method_nodes = []
for method_declaration_node in method_declaration_nodes:
    # Read name / return type / body through the grammar's field names. The old
    # type filter listed node types by hand and raised StopIteration for any
    # return type it missed; bodiless methods raised for the same reason.
    method_name_node = method_declaration_node.child_by_field_name("name")
    method_return_type_node = method_declaration_node.child_by_field_name("type")
    method_body = method_declaration_node.child_by_field_name("body")
    if method_name_node is None or method_return_type_node is None:
        # Malformed declaration (e.g. a parse error): nothing reliable to store.
        continue

    # Create-or-update in a single statement. The MERGE pattern includes the
    # MEMBER_OF relationship, so a method is only reused when it belongs to this
    # class; content and return type are refreshed either way.
    method_result = graphDBConn.execute_query(
        """
        MATCH (c:Class) WHERE elementId(c) = $class_node_id
        MERGE (m:Method {name: $method_name})-[:MEMBER_OF]->(c)
        SET m.content = $content,
            m.return_type = $return_type
        RETURN m
        """,
        class_node_id=class_node_id,
        method_name=method_name_node.text.decode('utf-8'),
        return_type=method_return_type_node.text.decode('utf-8'),
        content=method_declaration_node.text.decode('utf-8')
    )
    method_node = method_result.records[0]
    method_nodes.append(method_node)
    method_node_id = method_node.get("m").element_id
    # traverse_method_body returns immediately when the body is None.
    traverse_method_body(method_body, method_node_id)
method_nodes

Node type: block
Node type: {
Node type: local_variable_declaration
Node type: type_identifier
Node type: variable_declarator
Node type: identifier
Node type: =
Node type: method_invocation
Method name: b'employeeService.saveEmployee(employeeDTO)'
processing external method
Node type: identifier
Node type: .
Node type: identifier
Node type: argument_list
Node type: (
Node type: identifier
Node type: )
Node type: ;
Node type: return_statement
Node type: return
Node type: object_creation_expression
Node type: new
Node type: generic_type
Node type: type_identifier
Node type: type_arguments
Node type: <
Node type: >
Node type: argument_list
Node type: (
Node type: identifier
Node type: ,
Node type: field_access
Node type: identifier
Node type: .
Node type: identifier
Node type: )
Node type: ;
Node type: }
Node type: block
Node type: {
Node type: local_variable_declaration
Node type: type_identifier
Node type: variable_declarator
Node type: identifier
Node type: =
Node type: method_invocati

[<Record m=<Node element_id='4:006ce47f-db97-446a-8d48-00e2d825b218:25' labels=frozenset({'Method'}) properties={'return_type': 'ResponseEntity<EmployeeDTO>', 'name': 'createEmployee', 'content': '@PostMapping\n    public ResponseEntity<EmployeeDTO> createEmployee(@RequestBody EmployeeDTO employeeDTO) {\n        EmployeeDTO savedEmployee = employeeService.saveEmployee(employeeDTO);\n        return new ResponseEntity<>(savedEmployee, HttpStatus.CREATED);\n    }'}>>,
 <Record m=<Node element_id='4:006ce47f-db97-446a-8d48-00e2d825b218:26' labels=frozenset({'Method'}) properties={'return_type': 'ResponseEntity<EmployeeDTO>', 'name': 'getEmployeeById', 'content': '@GetMapping("/{id}")\n    public ResponseEntity<EmployeeDTO> getEmployeeById(@PathVariable Long id) {\n        EmployeeDTO employee = employeeService.getEmployeeById(id);\n        return sendOk(employee);\n    }'}>>,
 <Record m=<Node element_id='4:006ce47f-db97-446a-8d48-00e2d825b218:27' labels=frozenset({'Method'}) properties={'r

In [180]:
# Exploratory probe: walks hard-coded child positions inside the second method declaration
# and displays the children of the node reached; the result is not assigned to anything.
method_declaration_nodes[1].children[-1].children[1].children[1].children[-1].children[-1].children[1].children

[]

In [208]:
class ASTProcessingError(Exception):
    """Base class for errors raised while processing the AST."""


class NodeNotFoundError(ASTProcessingError):
    """Signals that an AST node the processing step requires could not be found."""

In [223]:
def traverse_method_body(node, method_node_id, verbose=True):
    """Visit ``node`` and all its descendants in pre-order and process method calls.

    Every node whose type is ``method_invocation`` is handed to
    ``process_method_invocation(node, method_node_id)``.

    Args:
        node: Root of the subtree to walk; a falsy value makes this a no-op.
        method_node_id: Passed through unchanged to ``process_method_invocation``.
        verbose: When True (default, the previous behaviour) print each visited
            node type and the text of each invocation; pass False to silence
            the trace.

    The walk uses an explicit stack instead of recursion, so deeply nested
    expressions cannot hit Python's recursion limit.
    """
    pending = [node]
    while pending:
        current = pending.pop()
        if not current:
            continue
        if verbose:
            print("Node type:", current.type)
        if current.type == "method_invocation":
            if verbose:
                # Decode so the trace shows the source text rather than a bytes repr.
                print("Method name:", current.text.decode('utf-8'))
            process_method_invocation(current, method_node_id)
        # Push children reversed so they are popped (visited) in source order.
        pending.extend(reversed(current.children))

In [228]:
def _link_method_call(method_node_id, calling_method_node_id):
    """Merge a CALLS relationship from one existing Method node to another."""
    graphDBConn.execute_query("""
        MATCH (m:Method) WHERE elementId(m) = $method_node_id
        MATCH (cm:Method) WHERE elementId(cm) = $calling_method_node_id
        MERGE (m)-[c:CALLS]->(cm) RETURN m,c,cm
        """,
        method_node_id=method_node_id,
        calling_method_node_id=calling_method_node_id
    )


def process_method_invocation(node, method_node_id, owner_class_id=None):
    """Record a CALLS relationship for one ``method_invocation`` AST node.

    Two shapes are handled:

    * ``field.method(...)``: the receiver is a plain identifier naming a field
      of the owning class. The called method is looked up on the field's type
      class and created there if missing.
    * ``method(...)``: a call without receiver, looked up among the methods of
      the owning class.

    Any other receiver (``this``, ``super``, chained calls, qualified names, ...)
    prints "path not implemented" and is skipped.

    Args:
        node: The ``method_invocation`` node.
        method_node_id: Element id of the Method node containing the call.
        owner_class_id: Element id of the class owning that method. Defaults to
            the notebook-level ``class_node_id`` for backward compatibility.

    Raises:
        NodeNotFoundError: The receiver field exists but has no TYPE relationship.
    """
    if owner_class_id is None:
        owner_class_id = class_node_id

    # Use the grammar's field names instead of picking identifiers by position:
    # for chained or qualified calls the first identifier child is not the receiver.
    name_node = node.child_by_field_name("name")
    object_node = node.child_by_field_name("object")
    if name_node is None:
        return
    invoked_method_name = name_node.text.decode('utf-8')

    # consider method chains
    if object_node is not None:
        print("processing external method")
        if object_node.type != "identifier":
            print("path not implemented")
            # Handle other cases
            # this
            # super
            # import
            # import with wildcard
            # fully qualified name
            # nested class
            return

        field_node = graphDBConn.execute_query(
            """
            MATCH (c:Class) WHERE elementId(c) = $class_node_id
            MATCH (c)<-[:MEMBER_OF]-(f:Field{name:$field_name}) RETURN f
            """,
            class_node_id=owner_class_id,
            field_name=object_node.text.decode('utf-8')
        )
        if len(field_node.records) == 0:
            # Receiver is not a known field (local variable, parameter, static class, ...).
            print("path not implemented")
            return

        field_node_id = field_node.records[0].get("f").element_id
        type_node = graphDBConn.execute_query(
            """
            MATCH (f:Field) WHERE elementId(f) = $field_node_id
            MATCH (f) -[:TYPE]-> (c:Class) RETURN c
            """,
            field_node_id=field_node_id
        )
        if len(type_node.records) == 0:
            raise NodeNotFoundError("Field type node not found")

        type_node_id = type_node.records[0].get("c").element_id
        calling_method_node = graphDBConn.execute_query(
            """
            MATCH (c:Class) WHERE elementId(c) = $type_node_id
            MATCH (c)<-[:MEMBER_OF]-(m:Method{name:$method_name}) RETURN m
            """,
            type_node_id=type_node_id,
            method_name=invoked_method_name
        )
        if len(calling_method_node.records) > 0:
            _link_method_call(
                method_node_id,
                calling_method_node.records[0].get("m").element_id
            )
        else:
            # The target method is not in the graph yet: create a placeholder
            # on the type class and link to it.
            graphDBConn.execute_query("""
                MATCH (m:Method) WHERE elementId(m) = $method_node_id
                MATCH (t:Class) WHERE elementId(t) = $type_node_id
                CREATE (cm:Method{name:$calling_method_name})
                MERGE (cm) -[:MEMBER_OF]->(t)
                MERGE (m)-[c:CALLS]->(cm) RETURN m,c,cm
                """,
                method_node_id=method_node_id,
                calling_method_name=invoked_method_name,
                type_node_id=type_node_id
            )
    else:
        calling_method_node = graphDBConn.execute_query(
            """
            MATCH (c:Class) WHERE elementId(c) = $class_node_id
            MATCH (m:Method{name:$method_name}) -[:MEMBER_OF]->(c)
            RETURN m
            """,
            method_name=invoked_method_name,
            class_node_id=owner_class_id
        )
        if len(calling_method_node.records) > 0:
            _link_method_call(
                method_node_id,
                calling_method_node.records[0].get("m").element_id
            )
        else:
            print("processing static import")
        #need to check if it is a static import or not

In [183]:
# Bare expression: displays the method records collected by the method-processing loop.
method_nodes

[<Record m=<Node element_id='4:006ce47f-db97-446a-8d48-00e2d825b218:25' labels=frozenset({'Method'}) properties={'return_type': 'ResponseEntity<EmployeeDTO>', 'name': 'createEmployee', 'content': '@PostMapping\n    public ResponseEntity<EmployeeDTO> createEmployee(@RequestBody EmployeeDTO employeeDTO) {\n        EmployeeDTO savedEmployee = employeeService.saveEmployee(employeeDTO);\n        return new ResponseEntity<>(savedEmployee, HttpStatus.CREATED);\n    }'}>>,
 <Record m=<Node element_id='4:006ce47f-db97-446a-8d48-00e2d825b218:26' labels=frozenset({'Method'}) properties={'return_type': 'ResponseEntity<EmployeeDTO>', 'name': 'getEmployeeById', 'content': '@GetMapping("/{id}")\n    public ResponseEntity<EmployeeDTO> getEmployeeById(@PathVariable Long id) {\n        EmployeeDTO employee = employeeService.getEmployeeById(id);\n        return sendOk(employee);\n    }'}>>,
 <Record m=<Node element_id='4:006ce47f-db97-446a-8d48-00e2d825b218:27' labels=frozenset({'Method'}) properties={'r