## Preparation

Install the required packages with `pip install "neo4j-graphrag[ollama]"`.

In [5]:
import ast
import json
import os
from getpass import getpass
from typing import Any

from neo4j import GraphDatabase
from neo4j_graphrag.embeddings import OllamaEmbeddings
from neo4j_graphrag.indexes import create_vector_index, upsert_vectors
from neo4j_graphrag.llm import OllamaLLM
from neo4j_graphrag.retrievers import Text2CypherRetriever, VectorRetriever
from neo4j_graphrag.types import EntityType

# Insert path to your STIX ATT&CK file
# The encoding is given explicitly so loading does not depend on the platform default.
with open(
    "attack-stix-data/enterprise-attack-17.1.json", "r", encoding="utf-8"
) as file:
    data = json.load(file)

# Neo4j connection settings come from the environment (NEO4J_URI, NEO4J_USERNAME,
# NEO4J_PASSWORD) so that credentials are never stored in the notebook itself.
# When NEO4J_PASSWORD is unset or empty, the password is prompted for interactively.
URI = os.environ.get("NEO4J_URI", "neo4j+s://6224f1f3.databases.neo4j.io")
AUTH = (
    os.environ.get("NEO4J_USERNAME", "neo4j"),
    os.environ.get("NEO4J_PASSWORD") or getpass("Neo4j password: "),
)

driver = GraphDatabase.driver(URI, auth=AUTH)

# ollama run gemma3:27b-it-qat
llm = OllamaLLM(model_name="gemma3:27b-it-qat")

# ollama run rjmalagon/gte-qwen2-7b-instruct:f16
embedder = OllamaEmbeddings(model="rjmalagon/gte-qwen2-7b-instruct:f16")

## KG Construction

In [None]:
def flatten_dict(d) -> dict[Any, str]:
    """Return a new dict with the keys of ``d`` and every value passed through ``str``.

    Only the top level is processed: a nested list or dict value ends up as its
    string representation rather than being expanded.
    """
    stringified: dict[Any, str] = {}
    for key, value in d.items():
        stringified[key] = str(value)
    return stringified


def create_sdo_node(tx, label, props) -> None:
    """Create one node labelled ``SDO`` and ``label`` carrying ``props``.

    Args:
        tx: Transaction-like object exposing ``run(query, **parameters)``.
        label: Second node label. It is interpolated into the query text, so
            it must be a valid label name from a trusted source.
        props: Property map, sent to the database as the ``$props`` parameter.
    """
    # Only the label is part of the query text; the properties stay a parameter.
    query = f"CREATE (n:SDO:{label} $props)"
    tx.run(query, props=props)


def create_relation(tx, src, tgt, rel_type, props) -> None:
    """Merge a ``rel_type`` relationship between the nodes with ids ``src`` and ``tgt``.

    Both endpoints are looked up by their ``id`` property. If either lookup
    finds nothing, the query matches no rows and no relationship is written.

    Args:
        tx: Transaction-like object exposing ``run(query, **parameters)``.
        src: ``id`` property value of the start node.
        tgt: ``id`` property value of the end node.
        rel_type: Relationship type. It is interpolated into the query text,
            so it must be a valid type name from a trusted source.
        props: Properties added to the relationship via ``SET r += $props``.
    """
    # Doubled braces are literal Cypher map braces inside the f-string; only
    # the relationship type is interpolated, all values are query parameters.
    query = f"""
    MATCH (a {{id: $source_id}})
    MATCH (b {{id: $target_id}})
    MERGE (a)-[r:{rel_type}]->(b)
    SET r += $props
    """
    tx.run(query, source_id=src, target_id=tgt, props=props)


# Load the bundle in two passes. create_relation looks its endpoints up with
# MATCH, so a relationship written before both of its nodes exist would be
# dropped without any error. Creating every node first removes that dependency
# on the order of objects in the file.
with driver.session() as session:
    # Pass 1: every non-relationship object becomes a node labelled with its
    # STIX type ("-" replaced by "_").
    for obj in data["objects"]:
        if obj["type"] == "relationship":
            continue
        flat_obj = flatten_dict(obj)
        sdo_type = flat_obj["type"].replace("-", "_")
        props = {k: v for k, v in flat_obj.items() if k != "type"}
        session.execute_write(create_sdo_node, sdo_type, props)

    # Pass 2: relationship objects become relationships between existing nodes.
    for obj in data["objects"]:
        if obj["type"] != "relationship":
            continue
        flat_obj = flatten_dict(obj)
        rel_type = flat_obj["relationship_type"].replace("-", "_")
        # The endpoints and the type are encoded in the graph structure itself,
        # so they are not repeated as relationship properties.
        props = {
            k: v
            for k, v in flat_obj.items()
            if k not in ("source_ref", "target_ref", "relationship_type")
        }
        session.execute_write(
            create_relation,
            flat_obj["source_ref"],
            flat_obj["target_ref"],
            rel_type,
            props,
        )

In [5]:
def create_contains_technique_relation(tx, tech_id, tac_name) -> None:
    """Merge ``(tactic)-[:contains_technique]->(technique)``.

    Args:
        tx: Transaction-like object exposing ``run(query, **parameters)``.
        tech_id: ``id`` property of the ``attack_pattern`` node.
        tac_name: ``x_mitre_shortname`` property of the ``x_mitre_tactic`` node.
    """
    # Values are passed as query parameters instead of being pasted into the
    # query text, so quotes or backslashes in them cannot alter the statement.
    query = """
    MATCH (a:attack_pattern {id: $tech_id}), (b:x_mitre_tactic {x_mitre_shortname: $tac_name})
    MERGE (b)-[r:contains_technique]->(a)
    """
    tx.run(query, tech_id=tech_id, tac_name=tac_name)


def create_component_of_relation(tx, dc_id, ds_id) -> None:
    """Merge ``(data component)-[:component_of]->(data source)``.

    Args:
        tx: Transaction-like object exposing ``run(query, **parameters)``.
        dc_id: ``id`` property of the ``x_mitre_data_component`` node.
        ds_id: ``id`` property of the ``x_mitre_data_source`` node.
    """
    # Values are passed as query parameters instead of being pasted into the
    # query text, so quotes or backslashes in them cannot alter the statement.
    query = """
    MATCH (a:x_mitre_data_component {id: $dc_id}), (b:x_mitre_data_source {id: $ds_id})
    MERGE (a)-[r:component_of]->(b)
    """
    tx.run(query, dc_id=dc_id, ds_id=ds_id)


# Relationships that the bundle expresses through object fields rather than
# through explicit "relationship" objects.
attack_patterns = [obj for obj in data["objects"] if obj["type"] == "attack-pattern"]
data_components = [
    obj for obj in data["objects"] if obj["type"] == "x-mitre-data-component"
]

with driver.session() as session:
    # tactic -[:contains_technique]-> technique, one per kill chain phase.
    for ap in attack_patterns:
        ap_id = ap["id"]

        # An attack pattern without "kill_chain_phases" simply belongs to no
        # tactic; it must not abort the whole load with a KeyError.
        for phase in ap.get("kill_chain_phases", []):
            phase_name = phase["phase_name"]
            session.execute_write(create_contains_technique_relation, ap_id, phase_name)

    # data component -[:component_of]-> data source.
    for dc in data_components:
        dc_id = dc["id"]
        ds_id = dc.get("x_mitre_data_source_ref")
        if ds_id is None:
            # No parent data source recorded for this component: nothing to link.
            continue
        session.execute_write(create_component_of_relation, dc_id, ds_id)

## Create and populate Vector Index

In [6]:
INDEX_NAME = "SDOs"

# Settings of the vector index over the `embedding` property of nodes labelled SDO.
# NOTE(review): 3584 presumably equals the output size of the embedding model
# configured in the preparation cell — confirm if the model is changed.
vector_index_settings = {
    "label": "SDO",
    "embedding_property": "embedding",
    "dimensions": 3584,
    "similarity_fn": "cosine",
}

create_vector_index(driver, INDEX_NAME, **vector_index_settings)

In [4]:
# Compute an embedding for every node that has a name and/or a description and
# store it in the node's `embedding` property.
with driver.session() as session:
    result = session.run("MATCH (n) RETURN n")
    for record in result:
        node = record["n"]

        # Embed "<name>\n\n<description>", or whichever of the two is present.
        # Joining only the non-empty parts keeps a missing name from being
        # embedded as the literal text "None".
        text_parts = [
            str(part) for part in (node.get("name"), node.get("description")) if part
        ]
        if not text_parts:
            # Nothing to embed for this node.
            continue

        vector = embedder.embed_query("\n\n".join(text_parts))
        upsert_vectors(
            driver,
            ids=[node.element_id],
            embedding_property="embedding",
            embeddings=[vector],
            entity_type=EntityType.NODE,
        )

## Similarity Search

Vector Similarity + Text2Cypher

In [15]:
# NOTE(review): "Ransum" looks like an intentional misspelling — the recorded
# outputs show vector retrieval still surfacing 'INC Ransom' while the plain
# Text2Cypher run returns nothing. Confirm before "correcting" it.
QUERY = "What tools did the INC Ransum group use?"
VECTOR_RETRIEVAL_TOP_K = 10

vector_retriever = VectorRetriever(driver, "SDOs", embedder)
result = vector_retriever.search(query_text=QUERY, top_k=VECTOR_RETRIEVAL_TOP_K)

# Turn every hit into a Cypher-like hint line, e.g. "tool {name: 'Net'}".
node_hints = []
for item in result.items:
    if not item.metadata:
        continue

    # item.content is parsed back into a Python literal to read the node name.
    content = ast.literal_eval(item.content)

    # Every node carries the generic SDO label next to its type label; select
    # the type label by value instead of relying on its position in the list.
    type_label = next(
        (label for label in item.metadata["nodeLabels"] if label != "SDO"), None
    )
    name = content.get("name")
    if type_label is None or name is None:
        # Without a type label or a name there is no usable hint.
        continue

    node_hints.append(f"{type_label} {{name: '{name}'}}\n")

nodes_str = "".join(node_hints)

# Textual graph schema given to the Text2Cypher retriever and inserted into the
# prompt: node labels with their properties, relationship types with their
# properties, and the (source)-[type]->(target) patterns the model may use.
# Only `name` and `description` are listed as node properties here.
SCHEMA = """
Node properties:
attack_pattern {name: STRING, description: STRING}
campaign {name: STRING, description: STRING}
course_of_action {name: STRING, description: STRING}
identity {name: STRING, description: STRING}
intrusion_set {name: STRING, description: STRING}
malware {name: STRING, description: STRING}
tool {name: STRING, description: STRING}
x_mitre_data_source {name: STRING, description: STRING}
x_mitre_data_component {name: STRING, description: STRING}
x_mitre_tactic {name: STRING, description: STRING}
Relationship properties:
attributed_to {}
component_of {}
contains_technique {}
detects {description: STRING}
mitigates {description: STRING}
subtechnique_of {}
uses {description: STRING}
The relationships:
(:course_of_action)-[:mitigates]->(:attack_pattern)
(:malware)-[:uses]->(:attack_pattern)
(:tool)-[:uses]->(:attack_pattern)
(:x_mitre_tactic)-[:contains_technique]->(:attack_pattern)
(:attack_pattern)-[:subtechnique_of]->(:attack_pattern)
(:x_mitre_data_component)-[:component_of]->(:x_mitre_data_source)
(:x_mitre_data_component)-[:detects]->(:attack_pattern)
(:intrusion_set)-[:uses]->(:malware)
(:intrusion_set)-[:uses]->(:tool)
(:intrusion_set)-[:uses]->(:attack_pattern)
(:campaign)-[:attributed_to]->(:intrusion_set)
(:campaign)-[:uses]->(:tool)
(:campaign)-[:uses]->(:attack_pattern)
(:campaign)-[:uses]->(:malware)
"""

# Custom prompt template passed to Text2CypherRetriever. `{schema}` and
# `{nodes}` are supplied through `prompt_params` in the search call below;
# `{query_text}` is presumably filled in by the retriever from the search
# query — confirm against the neo4j-graphrag documentation.
PROMPT = """
Task: Generate a Cypher statement for querying a Neo4j graph database from a user input.

Schema:
{schema}

Retrieved Nodes:
{nodes}
Input:
{query_text}

Do not use any properties or relationships not included in the schema.
Retrieved Nodes have been retrieved from the graph in advance by using Vector Similarity. You may use them, but you do not have to. 
Do not include triple backticks ``` or any additional text except the generated Cypher statement in your response.

Cypher query:
"""

# Show the node hints that are about to be inserted into the prompt.
print(nodes_str)

# later: LLM evaluates each retrieved nodes' relevancy?

# Values for the {schema} and {nodes} placeholders of the custom prompt.
prompt_params = {"schema": SCHEMA, "nodes": nodes_str}

txt2cypher_retriever = Text2CypherRetriever(
    driver=driver, llm=llm, neo4j_schema=SCHEMA, custom_prompt=PROMPT
)
result = txt2cypher_retriever.search(query_text=QUERY, prompt_params=prompt_params)

# Print the generated Cypher statement when the retriever reports it.
if result.metadata:
    cypher = result.metadata["cypher"]
    print(cypher)

# Print every record the generated statement returned.
for item in result.items:
    print(item.content)


intrusion_set {name: 'INC Ransom'}
intrusion_set {name: 'Rancor'}
campaign {name: 'C0015'}
malware {name: 'RansomHub'}
intrusion_set {name: 'Cinnamon Tempest'}
malware {name: 'Regin'}
intrusion_set {name: 'Group5'}
malware {name: 'InnaputRAT'}
malware {name: 'Action RAT'}
attack_pattern {name: 'Tool'}

MATCH (i:intrusion_set {name: 'INC Ransom'})-[:uses]->(t:tool)
RETURN t

<Record t=<Node element_id='4:501f467e-76d1-4670-90ee-68e680883505:937' labels=frozenset({'tool', 'SDO'}) properties={'object_marking_refs': "['marking-definition--fa42a846-8d90-4e51-bc29-71d5b4802168']", 'x_mitre_contributors': "['David Ferguson, CyberSponse']", 'spec_version': '2.1', 'created': '2017-05-31T21:32:31.601Z', 'x_mitre_deprecated': 'False', 'description': 'The [Net](https://attack.mitre.org/software/S0039) utility is a component of the Windows operating system. It is used in command-line operations for control of users, groups, services, and network connections. (Citation: Microsoft Net Utility)\n\n[Ne

Naive Text2Cypher

In [16]:
# Same question as in the vector-assisted run, including the "Ransum" spelling.
# NOTE(review): the misspelling looks intentional — the recorded output of this
# cell shows the plain Text2Cypher query returning 0 rows — confirm before
# changing it.
QUERY = "What tools did the INC Ransum group use?"

# Textual graph schema given to the Text2Cypher retriever: node labels with
# their properties, relationship types with their properties, and the
# (source)-[type]->(target) patterns the model may use.
SCHEMA = """
Node properties:
attack_pattern {name: STRING, description: STRING}
campaign {name: STRING, description: STRING}
course_of_action {name: STRING, description: STRING}
identity {name: STRING, description: STRING}
intrusion_set {name: STRING, description: STRING}
malware {name: STRING, description: STRING}
tool {name: STRING, description: STRING}
x_mitre_data_source {name: STRING, description: STRING}
x_mitre_data_component {name: STRING, description: STRING}
x_mitre_tactic {name: STRING, description: STRING}
Relationship properties:
attributed_to {}
component_of {}
contains_technique {}
detects {description: STRING}
mitigates {description: STRING}
subtechnique_of {}
uses {description: STRING}
The relationships:
(:course_of_action)-[:mitigates]->(:attack_pattern)
(:malware)-[:uses]->(:attack_pattern)
(:tool)-[:uses]->(:attack_pattern)
(:x_mitre_tactic)-[:contains_technique]->(:attack_pattern)
(:attack_pattern)-[:subtechnique_of]->(:attack_pattern)
(:x_mitre_data_component)-[:component_of]->(:x_mitre_data_source)
(:x_mitre_data_component)-[:detects]->(:attack_pattern)
(:intrusion_set)-[:uses]->(:malware)
(:intrusion_set)-[:uses]->(:tool)
(:intrusion_set)-[:uses]->(:attack_pattern)
(:campaign)-[:attributed_to]->(:intrusion_set)
(:campaign)-[:uses]->(:tool)
(:campaign)-[:uses]->(:attack_pattern)
(:campaign)-[:uses]->(:malware)
"""

# Plain Text2Cypher: the model sees only the schema and the question, with no
# node hints from vector retrieval.
retriever = Text2CypherRetriever(driver=driver, llm=llm, neo4j_schema=SCHEMA)
result = retriever.search(query_text=QUERY)

# Print the generated Cypher statement when the retriever reports it.
if result.metadata:
    cypher = result.metadata["cypher"]
    print(cypher)

# Number of records returned, followed by the records themselves.
retrieved_items = result.items
print(len(retrieved_items))

for item in retrieved_items:
    print(item.content)

MATCH (c:campaign {name: 'INC Ransum'})-[:uses]->(t:tool)
RETURN t.name

0
