In [5]:
import getpass
import os

from tree_sitter import Parser, Language
import tree_sitter_java as tsj
from neo4j import GraphDatabase

class Neo4jConnection:
    """Owns a Neo4j driver; usable directly or as a context manager.

    Example::

        with Neo4jConnection(uri, user, password) as conn:
            conn.driver.execute_query("RETURN 1")
    """

    def __init__(self, uri, user, password):
        """Create a driver for ``uri`` authenticating with the ``(user, password)`` tuple."""
        self.driver = GraphDatabase.driver(uri, auth=(user, password))

    def close(self):
        """Close the underlying driver."""
        self.driver.close()

    def __enter__(self):
        """Return this connection so it can be bound with ``with ... as conn``."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Always close the driver, even if the with-body raised; returning False
        # lets any exception propagate to the caller.
        self.close()
        return False

In [6]:
# Wrap the grammar exported by tree_sitter_java and attach it to a new parser.
language = Language(tsj.language())
parser = Parser()
parser.language = language
parser

<tree_sitter.Parser at 0x106463ba0>

In [7]:
# Java source file (path relative to this notebook) whose AST is loaded into the graph.
JAVA_SOURCE_PATH = (
    "../data/java/employee-crud/src/main/java/"
    "com/java_ast_knowledge_graph_poc/employee_crud/controller/EmployeeController.java"
)

with open(JAVA_SOURCE_PATH, "r", encoding="utf-8") as file:
    source_code = file.read()

source_code

'package com.java_ast_knowledge_graph_poc.employee_crud.controller;\n\nimport java.util.List;\n\nimport org.springframework.beans.factory.annotation.Autowired;\nimport org.springframework.http.HttpStatus;\nimport org.springframework.http.ResponseEntity;\nimport org.springframework.web.bind.annotation.DeleteMapping;\nimport org.springframework.web.bind.annotation.GetMapping;\nimport org.springframework.web.bind.annotation.PathVariable;\nimport org.springframework.web.bind.annotation.PostMapping;\nimport org.springframework.web.bind.annotation.PutMapping;\nimport org.springframework.web.bind.annotation.RequestBody;\nimport org.springframework.web.bind.annotation.RequestMapping;\nimport org.springframework.web.bind.annotation.RestController;\n\nimport com.java_ast_knowledge_graph_poc.employee_crud.dto.EmployeeDTO;\nimport com.java_ast_knowledge_graph_poc.employee_crud.service.EmployeeService;\n\n@RestController\n@RequestMapping("/api/employees")\npublic class EmployeeController {\n\n    p

In [8]:
# tree-sitter parses raw bytes, so hand it the UTF-8 encoding of the source text.
tree = parser.parse(source_code.encode("utf-8"))
root_node = tree.root_node

In [9]:
# Read connection settings from the environment so credentials are not committed with
# the notebook. URI and user fall back to the original local values; the password is
# prompted for when NEO4J_PASSWORD is unset instead of being hardcoded.
graphDBConn = Neo4jConnection(
    uri=os.environ.get("NEO4J_URI", "neo4j://localhost:7687"),
    user=os.environ.get("NEO4J_USER", "neo4j"),
    password=os.environ.get("NEO4J_PASSWORD") or getpass.getpass("Neo4j password: "),
)

In [10]:
# Locate top-level declarations by node type instead of by position: a leading
# licence/Javadoc comment would otherwise be taken for the package declaration, and a
# trailing comment for the class declaration.
children_nodes = root_node.children
package_declaration_node = next(
    (node for node in children_nodes if node.type == "package_declaration"), None
)
# NOTE: the misspelled name is kept because later cells refer to it.
class_declration_node = next(
    (node for node in children_nodes if node.type == "class_declaration"), None
)
if package_declaration_node is None or class_declration_node is None:
    raise ValueError(
        "expected a package_declaration and a class_declaration at the top level of the source file"
    )

Create (or reuse) the Package node for the file's package declaration.

In [59]:
# The package name is the (scoped) identifier child; search by type rather than
# assuming it is children[1], since package annotations may precede the keyword.
package_name_node = next(
    child
    for child in package_declaration_node.children
    if child.type in ("scoped_identifier", "identifier")
)
package_name = package_name_node.text.decode("utf-8")
package_content = package_declaration_node.text.decode("utf-8")
print(f"name: {package_name}")
print(f"content: {package_content}")

# MERGE on the name only and SET the content: the import cell below merges packages by
# name, so keying on content as well would create a duplicate Package node for the same
# package.
package_node = graphDBConn.driver.execute_query(
    "MERGE (p:Package {name: $name}) SET p.content = $content RETURN p",
    name=package_name,
    content=package_content,
)

package_node

name: com.java_ast_knowledge_graph_poc.employee_crud.controller
content: package com.java_ast_knowledge_graph_poc.employee_crud.controller;


EagerResult(records=[<Record p=<Node element_id='4:006ce47f-db97-446a-8d48-00e2d825b218:9' labels=frozenset({'Package'}) properties={'name': 'com.java_ast_knowledge_graph_poc.employee_crud.controller', 'content': 'package com.java_ast_knowledge_graph_poc.employee_crud.controller;'}>>], summary=<neo4j._work.summary.ResultSummary object at 0x106ec0ef0>, keys=['p'])

In [14]:
# Display the Package node held in column "p" of the MERGE result's first record.
package_node.records[0].get("p", None)

<Node element_id='4:006ce47f-db97-446a-8d48-00e2d825b218:0' labels=frozenset({'Package'}) properties={'name': 'com.java_ast_knowledge_graph_poc.employee_crud.controller', 'content': 'package com.java_ast_knowledge_graph_poc.employee_crud.controller;'}>

Create the Class node and link it to its package with a BELONGS_TO relationship.

In [60]:
def _request_mapping_path(class_declaration_node):
    """Return the path of the class-level ``@RequestMapping`` annotation, or None if absent.

    Accepts ``@RequestMapping("/x")`` as well as ``@RequestMapping(value = "/x")`` and
    ``@RequestMapping(path = "/x")``; the annotation name may be simple or fully qualified.
    """

    def _literal_text(literal_node):
        # Join the literal's fragments (and escape sequences) without the surrounding quotes.
        return "".join(child.text.decode("utf-8") for child in literal_node.named_children)

    modifiers = next(
        (child for child in class_declaration_node.children if child.type == "modifiers"),
        None,
    )
    if modifiers is None:
        return None
    for annotation in modifiers.children:
        # Argument-less annotations such as @RestController are "marker_annotation" nodes.
        if annotation.type != "annotation":
            continue
        name_node = annotation.child_by_field_name("name")
        if name_node is None or name_node.text.decode("utf-8").rsplit(".", 1)[-1] != "RequestMapping":
            continue
        arguments = annotation.child_by_field_name("arguments")
        for argument in (arguments.named_children if arguments is not None else []):
            if argument.type == "string_literal":
                return _literal_text(argument)
            if argument.type == "element_value_pair":
                key = argument.child_by_field_name("key")
                value = argument.child_by_field_name("value")
                if (
                    key is not None
                    and value is not None
                    and value.type == "string_literal"
                    and key.text.decode("utf-8") in ("value", "path")
                ):
                    return _literal_text(value)
        return None
    return None


# None when the class has no @RequestMapping; SET then leaves rest_path unset on the node.
rest_path = _request_mapping_path(class_declration_node)
# Use the grammar's "name" field: positional indexing breaks with extends/implements clauses.
class_name = class_declration_node.child_by_field_name("name").text.decode("utf-8")
# Header text: every child except the last one, which is the class body.
content = " ".join(child.text.decode("utf-8") for child in class_declration_node.children[:-1])
# Get the package node ID from the previous query
package_id = package_node.records[0].get("p").element_id

# Create the class (keyed by name, like the import cell's MERGE) and its package link.
class_node = graphDBConn.driver.execute_query(
    """
    MATCH (pkg:Package) WHERE elementId(pkg) = $package_id
    MERGE (c:Class {name: $name})
    SET c.rest_path = $rest_path, c.content = $content
    MERGE (c)-[:BELONGS_TO]->(pkg)
    RETURN c
    """,
    package_id=package_id,
    name=class_name,
    rest_path=rest_path,
    content=content
)
if not class_node.records:
    # MATCH found no package, so nothing was created; fail loudly instead of with IndexError.
    raise RuntimeError(f"Package {package_id} not found; class {class_name} was not created")

class_node.records[0].get("c")

<Node element_id='4:006ce47f-db97-446a-8d48-00e2d825b218:10' labels=frozenset({'Class'}) properties={'name': 'EmployeeController', 'content': '@RestController\n@RequestMapping("/api/employees")\npublic class EmployeeController', 'rest_path': '/api/employees'}>

Process the import declarations, adding an IMPORTS relationship from the class to each imported class.

In [61]:
# Select imports by node type rather than by position: comments between the package
# declaration and the class, or extra top-level declarations, would otherwise be
# treated as imports.
import_declaration_nodes = [
    node for node in children_nodes if node.type == "import_declaration"
]

In [62]:
def _split_import(import_declaration_node):
    """Return ``(package_name, class_name)`` of the type an import refers to, or None.

    ``import a.b.C;`` -> ("a.b", "C"); ``import static a.b.C.m;`` and
    ``import static a.b.C.*;`` -> ("a.b", "C"). On-demand package imports
    (``import a.b.*;``) and single-segment names have no single package/class pair,
    so they return None.
    """
    children = import_declaration_node.children
    child_types = {child.type for child in children}
    name_node = next(
        (child for child in children if child.type in ("scoped_identifier", "identifier")),
        None,
    )
    if name_node is None:
        return None
    # Remove any whitespace the source may contain inside the qualified name.
    segments = "".join(name_node.text.decode("utf-8").split()).split(".")
    is_static = "static" in child_types
    is_wildcard = "asterisk" in child_types
    if is_static and not is_wildcard:
        segments = segments[:-1]  # last segment is the imported member, not the type
    elif is_wildcard and not is_static:
        return None
    if len(segments) < 2:
        return None
    return ".".join(segments[:-1]), segments[-1]


for import_declaration_node in import_declaration_nodes:
    import_statement = import_declaration_node.text.decode('utf-8')
    parsed_import = _split_import(import_declaration_node)
    if parsed_import is None:
        # TODO: model on-demand (wildcard) package imports; report them instead of dropping silently.
        print(f"skipping import without a single target class: {import_statement}")
        continue
    import_package_name, import_class_name = parsed_import
    print(f"import_class_name: {import_class_name}")
    print(f"import_package_name: {import_package_name}")
    print(f"import_statement: {import_statement}")
    import_node = graphDBConn.driver.execute_query(
        """
        MATCH (c1:Class {name: $class_name})
        MERGE (pkg:Package {name: $import_package_name})
        MERGE (c:Class {name: $import_class_name})
        MERGE (c)-[:BELONGS_TO]->(pkg)
        MERGE (c1)-[:IMPORTS {content: $import_statement}]->(c)
        RETURN c
        """,
        import_package_name=import_package_name,
        import_class_name=import_class_name,
        class_name=class_name,
        import_statement=import_statement
    )


import_class_name: List
import_package_name: java.util
import_statement: import java.util.List;
import_class_name: Autowired
import_package_name: org.springframework.beans.factory.annotation
import_statement: import org.springframework.beans.factory.annotation.Autowired;
import_class_name: HttpStatus
import_package_name: org.springframework.http
import_statement: import org.springframework.http.HttpStatus;
import_class_name: ResponseEntity
import_package_name: org.springframework.http
import_statement: import org.springframework.http.ResponseEntity;
import_class_name: DeleteMapping
import_package_name: org.springframework.web.bind.annotation
import_statement: import org.springframework.web.bind.annotation.DeleteMapping;
import_class_name: GetMapping
import_package_name: org.springframework.web.bind.annotation
import_statement: import org.springframework.web.bind.annotation.GetMapping;
import_class_name: PathVariable
import_package_name: org.springframework.web.bind.annotation
import_st