## Preparation

Install the required packages with `pip install "neo4j-graphrag[ollama]"`.

In [None]:
from neo4j import GraphDatabase
from neo4j_graphrag.indexes import create_vector_index, upsert_vectors
from neo4j_graphrag.llm import OllamaLLM
from neo4j_graphrag.embeddings import OllamaEmbeddings
from neo4j_graphrag.types import EntityType
from neo4j_graphrag.retrievers import Text2CypherRetriever, VectorRetriever
import json
import ast
from typing import Any

# Insert path to your STIX ATT&CK file
# The encoding is pinned so that loading the JSON bundle does not depend on
# the platform's default locale encoding.
with open(
    "attack-stix-data/enterprise-attack-17.1.json", "r", encoding="utf-8"
) as file:
    data = json.load(file)

# Insert your Neo4j instance URL and credentials
URI = "neo4j+s://6224f1f3.databases.neo4j.io"
AUTH = ("neo4j", "XXXXX")

# One driver object is shared by every cell below.
driver = GraphDatabase.driver(URI, auth=AUTH)

# ollama run gemma3:27b-it-qat
llm = OllamaLLM(model_name="gemma3:27b-it-qat")

# ollama run rjmalagon/gte-qwen2-7b-instruct:f16
embedder = OllamaEmbeddings(model="rjmalagon/gte-qwen2-7b-instruct:f16")

## KG Construction

In [None]:
def flatten_dict(d) -> dict[Any, str]:
    """Return a new dict with the keys of ``d`` and every value passed through ``str``.

    Only the top level is converted, so a nested list or dict ends up as its
    string representation rather than being expanded.
    """
    stringified = {}
    for key, value in d.items():
        stringified[key] = str(value)
    return stringified


def create_sdo_node(tx, label, props) -> None:
    """Create one node labelled ``SDO`` and ``label`` whose properties are ``props``.

    Args:
        tx: Object exposing ``run(query, **parameters)``; the Cypher statement
            is executed through it.
        label: Additional node label. It is spliced into the query text, so it
            must already be a valid, trusted Cypher identifier.
        props: Property map, sent as the ``$props`` query parameter.
    """
    # Single braces: ``label`` is a real replacement field here, while
    # ``$props`` stays a literal Cypher parameter reference.
    query = f"CREATE (n:SDO:{label} $props)"
    tx.run(query, props=props)


def create_relation(tx, src, tgt, rel_type, props) -> None:
    """Merge a ``rel_type`` relationship from node ``src`` to node ``tgt``.

    Both endpoints are looked up by their ``id`` property (no label is used in
    the match). If either endpoint does not match, the statement creates
    nothing.

    Args:
        tx: Object exposing ``run(query, **parameters)``.
        src: ``id`` property value of the start node.
        tgt: ``id`` property value of the end node.
        rel_type: Relationship type. It is spliced into the query text, so it
            must already be a valid, trusted Cypher identifier.
        props: Map merged into the relationship's properties via ``SET r +=``.
    """
    # Doubled braces render as the literal ``{`` / ``}`` of the Cypher property
    # maps; only ``rel_type`` is substituted by the f-string.
    query = f"""
    MATCH (a {{id: $source_id}})
    MATCH (b {{id: $target_id}})
    MERGE (a)-[r:{rel_type}]->(b)
    SET r += $props
    """
    tx.run(query, source_id=src, target_id=tgt, props=props)


with driver.session() as session:
    # Relationships are written only after every node exists: create_relation
    # MATCHes both endpoints, so a relationship that appears in the bundle
    # before its nodes would match nothing and be dropped without any error.
    pending_relations = []

    # Pass 1: create one node per non-relationship object.
    for obj in data["objects"]:
        flat_obj = flatten_dict(obj)
        if flat_obj["type"] == "relationship":
            pending_relations.append(flat_obj)
            continue

        # Hyphens are replaced so the STIX type can be used as a node label.
        sdo_type = flat_obj["type"].replace("-", "_")
        props = {k: v for k, v in flat_obj.items() if k != "type"}
        session.execute_write(create_sdo_node, sdo_type, props)

    # Pass 2: connect the nodes.
    for flat_obj in pending_relations:
        rel_type = flat_obj["relationship_type"].replace("-", "_")
        # The endpoints and the type are encoded in the relationship itself,
        # so they are not repeated as relationship properties.
        props = {
            k: v
            for k, v in flat_obj.items()
            if k not in ("source_ref", "target_ref", "relationship_type")
        }
        session.execute_write(
            create_relation,
            flat_obj["source_ref"],
            flat_obj["target_ref"],
            rel_type,
            props,
        )

In [5]:
def create_contains_technique_relation(tx, tech_id, tac_name) -> None:
    """Merge ``(tactic)-[:contains_technique]->(technique)``.

    Args:
        tx: Object exposing ``run(query, **parameters)``.
        tech_id: ``id`` property of the ``attack_pattern`` node.
        tac_name: ``x_mitre_shortname`` property of the ``x_mitre_tactic`` node.
    """
    # Values travel as query parameters rather than being spliced into the
    # text, so quotes or backslashes inside them cannot alter the statement.
    query = """
    MATCH (a:attack_pattern {id: $tech_id}), (b:x_mitre_tactic {x_mitre_shortname: $tac_name})
    MERGE (b)-[r:contains_technique]->(a)
    """
    tx.run(query, tech_id=tech_id, tac_name=tac_name)


def create_component_of_relation(tx, dc_id, ds_id) -> None:
    """Merge ``(data component)-[:component_of]->(data source)``.

    Args:
        tx: Object exposing ``run(query, **parameters)``.
        dc_id: ``id`` property of the ``x_mitre_data_component`` node.
        ds_id: ``id`` property of the ``x_mitre_data_source`` node.
    """
    # Values travel as query parameters rather than being spliced into the
    # text, so quotes or backslashes inside them cannot alter the statement.
    query = """
    MATCH (a:x_mitre_data_component {id: $dc_id}), (b:x_mitre_data_source {id: $ds_id})
    MERGE (a)-[r:component_of]->(b)
    """
    tx.run(query, dc_id=dc_id, ds_id=ds_id)


# Relations that the bundle encodes as object fields rather than as explicit
# relationship objects: technique -> tactic and data component -> data source.
attack_patterns = [obj for obj in data["objects"] if obj["type"] == "attack-pattern"]
data_components = [
    obj for obj in data["objects"] if obj["type"] == "x-mitre-data-component"
]

with driver.session() as session:
    for ap in attack_patterns:
        ap_id = ap["id"]

        # An attack pattern without kill chain phases simply yields no
        # tactic links instead of aborting the whole run with a KeyError.
        for phase in ap.get("kill_chain_phases", []):
            phase_name = phase["phase_name"]
            session.execute_write(create_contains_technique_relation, ap_id, phase_name)

    for dc in data_components:
        dc_id = dc["id"]
        ds_id = dc.get("x_mitre_data_source_ref")
        if ds_id is None:
            # No data source referenced, so there is nothing to link.
            continue
        session.execute_write(create_component_of_relation, dc_id, ds_id)

## Create and populate Vector Index

In [6]:
INDEX_NAME = "SDOs"

# Vector index over the "embedding" property of nodes labelled SDO.
# NOTE(review): 3584 has to equal the length of the vectors returned by the
# embedder configured above - confirm if the embedding model is changed.
_index_settings = {
    "label": "SDO",
    "embedding_property": "embedding",
    "dimensions": 3584,
    "similarity_fn": "cosine",
}
create_vector_index(driver, INDEX_NAME, **_index_settings)

In [4]:
with driver.session() as session:
    result = session.run("MATCH (n) RETURN n")
    for record in result:
        node = record["n"]

        # Embed "name\n\ndescription"; when one of the two is missing or empty,
        # embed the other alone, so a missing name is never rendered as the
        # text "None". Nodes with neither property are left without a vector.
        text_parts = [
            part for part in (node.get("name"), node.get("description")) if part
        ]
        if not text_parts:
            continue

        vector = embedder.embed_query("\n\n".join(text_parts))
        upsert_vectors(
            driver,
            ids=[node.element_id],
            embedding_property="embedding",
            embeddings=[vector],
            entity_type=EntityType.NODE,
        )

## Similarity Search

Vector Similarity + Text2Cypher

In [None]:
# NOTE(review): the query spells the group "INC Ransum"; presumably a
# deliberate misspelling to exercise similarity search - confirm.
QUERY = "Which tools did the INC Ransum group use?"
VECTOR_RETRIEVAL_TOP_K = 10

print(f"Query: {QUERY}\n")
print("Step 1: Retrieve nodes from the graph via Vector Similarity.\n")

vector_retriever = VectorRetriever(driver, "SDOs", embedder)
result = vector_retriever.search(query_text=QUERY, top_k=VECTOR_RETRIEVAL_TOP_K)

# One "<label> {name: '<name>'}" line per hit; the lines are handed to the
# Text2Cypher prompt as candidate nodes.
node_lines = []
for item in result.items:
    if not item.metadata:
        continue

    # assumes item.content is the string form of a Python dict literal - TODO confirm
    content = ast.literal_eval(item.content)

    # Every node carries the generic "SDO" label next to its type label. Pick
    # the type label by value instead of by position so the result does not
    # depend on the order in which the labels are returned.
    labels = item.metadata["nodeLabels"]
    type_label = next((label for label in labels if label != "SDO"), "SDO")

    node_lines.append(f"{type_label} {{name: '{content['name']}'}}\n")

nodes_str = "".join(node_lines)

print(nodes_str)

# Hand-written textual description of the graph: node labels with their
# properties, relationship types with their properties, and the allowed
# (start)-[type]->(end) patterns. It is passed to the Text2Cypher retriever
# as `neo4j_schema` and substituted into PROMPT's {schema} placeholder.
SCHEMA = """Node properties:
attack_pattern {name: STRING, description: STRING}
campaign {name: STRING, description: STRING}
course_of_action {name: STRING, description: STRING}
identity {name: STRING, description: STRING}
intrusion_set {name: STRING, description: STRING}
malware {name: STRING, description: STRING}
tool {name: STRING, description: STRING}
x_mitre_data_source {name: STRING, description: STRING}
x_mitre_data_component {name: STRING, description: STRING}
x_mitre_tactic {name: STRING, description: STRING}
Relationship properties:
attributed_to {}
component_of {}
contains_technique {}
detects {description: STRING}
mitigates {description: STRING}
subtechnique_of {}
uses {description: STRING}
The relationships:
(:course_of_action)-[:mitigates]->(:attack_pattern)
(:malware)-[:uses]->(:attack_pattern)
(:tool)-[:uses]->(:attack_pattern)
(:x_mitre_tactic)-[:contains_technique]->(:attack_pattern)
(:attack_pattern)-[:subtechnique_of]->(:attack_pattern)
(:x_mitre_data_component)-[:component_of]->(:x_mitre_data_source)
(:x_mitre_data_component)-[:detects]->(:attack_pattern)
(:intrusion_set)-[:uses]->(:malware)
(:intrusion_set)-[:uses]->(:tool)
(:intrusion_set)-[:uses]->(:attack_pattern)
(:campaign)-[:attributed_to]->(:intrusion_set)
(:campaign)-[:uses]->(:tool)
(:campaign)-[:uses]->(:attack_pattern)
(:campaign)-[:uses]->(:malware)
"""

# Text2Cypher prompt template with three placeholders: {schema}, {nodes}
# (the nodes found by vector similarity) and {query_text} (the user question).
# It contains one worked example and asks for a bare Cypher statement.
PROMPT = """Task: Generate a Cypher statement for querying a Neo4j graph database from a user input.

Schema:
{schema}
Retrieved Nodes:
{nodes}
Example:
How to prevent Phishing attacks?
MATCH (co:course_of_action)-[m:mitigates]->(ap:attack_pattern)
WHERE ap.name = 'Phishing'
RETURN co.name, co.description, m.description

Input:
{query_text}

Do not use any properties or relationships not included in the schema.
Retrieved Nodes have been retrieved from the graph in advance by using Vector Similarity. You may use them, but you do not have to. 
Do not include triple backticks ``` or any additional text except the generated Cypher statement in your response.
Return name and description of relevant relationships and nodes. Prioritise relationship information.

Cypher query:
"""

print("Step 2: Generate Cypher query and use it to query the graph.\n")

# Show the fully rendered prompt for inspection.
rendered_prompt = PROMPT.format(schema=SCHEMA, nodes=nodes_str, query_text=QUERY)
print(rendered_prompt)

# later: LLM evaluates each retrieved nodes' relevancy?

txt2cypher_retriever = Text2CypherRetriever(
    driver=driver,
    llm=llm,
    neo4j_schema=SCHEMA,
    custom_prompt=PROMPT,
)

# Values for the template placeholders other than the question itself.
search_params = {"schema": SCHEMA, "nodes": nodes_str}
result = txt2cypher_retriever.search(query_text=QUERY, prompt_params=search_params)

# One line per returned item; this becomes the context for the final answer.
context_str = "".join(f"{item.content}\n" for item in result.items)

# The generated statement is read from the result metadata when there is any.
cypher = ""
if result.metadata:
    cypher = result.metadata.get("cypher")
    print(cypher)

print(context_str)

print("Step 3: Answer the question.\n")

SYSTEM_INSTRUCTION = "Answer the user question using the provided context, which has been retrieved from a graph using the provided Cypher query."

# Final prompt layout: the generated Cypher, the rows it returned, the
# question, a blank line, and a trailing "Answer:" cue ending in a newline.
prompt_sections = (
    "Cypher Query:",
    f"{cypher}",
    "Context:",
    f"{context_str}",
    "Question:",
    f"{QUERY}",
    "",
    "Answer:",
    "",
)
final_prompt = "\n".join(prompt_sections)

print(final_prompt)

answer = llm.invoke(input=final_prompt, system_instruction=SYSTEM_INSTRUCTION)

print(answer.content)


Query: Which tools did the INC Ransum group use???

Step 1: Retrieve nodes from the graph via Vector Similarity.

intrusion_set {name: 'INC Ransom'}
malware {name: 'INC Ransomware'}
campaign {name: 'C0015'}
intrusion_set {name: 'Rancor'}
malware {name: 'RansomHub'}
intrusion_set {name: 'Cinnamon Tempest'}
malware {name: 'DEATHRANSOM'}
malware {name: 'Regin'}
intrusion_set {name: 'Taidoor'}
intrusion_set {name: 'Indrik Spider'}

Step 2: Generate Cypher query and use it to query the graph.

Task: Generate a Cypher statement for querying a Neo4j graph database from a user input.

Schema:
Node properties:
attack_pattern {name: STRING, description: STRING}
campaign {name: STRING, description: STRING}
course_of_action {name: STRING, description: STRING}
identity {name: STRING, description: STRING}
intrusion_set {name: STRING, description: STRING}
malware {name: STRING, description: STRING}
tool {name: STRING, description: STRING}
x_mitre_data_source {name: STRING, description: STRING}
x_mitr

Naive Text2Cypher

In [18]:
# Same question as in the vector-assisted cell, including the "INC Ransum"
# spelling. NOTE(review): presumably kept identical on purpose so both
# approaches can be compared - confirm.
QUERY = "Which tools did the INC Ransum group use?"


# Textual graph schema given to the Text2Cypher retriever: node labels with
# properties, relationship types with properties, and the allowed patterns.
SCHEMA = """Node properties:
attack_pattern {name: STRING, description: STRING}
campaign {name: STRING, description: STRING}
course_of_action {name: STRING, description: STRING}
identity {name: STRING, description: STRING}
intrusion_set {name: STRING, description: STRING}
malware {name: STRING, description: STRING}
tool {name: STRING, description: STRING}
x_mitre_data_source {name: STRING, description: STRING}
x_mitre_data_component {name: STRING, description: STRING}
x_mitre_tactic {name: STRING, description: STRING}
Relationship properties:
attributed_to {}
component_of {}
contains_technique {}
detects {description: STRING}
mitigates {description: STRING}
subtechnique_of {}
uses {description: STRING}
The relationships:
(:course_of_action)-[:mitigates]->(:attack_pattern)
(:malware)-[:uses]->(:attack_pattern)
(:tool)-[:uses]->(:attack_pattern)
(:x_mitre_tactic)-[:contains_technique]->(:attack_pattern)
(:attack_pattern)-[:subtechnique_of]->(:attack_pattern)
(:x_mitre_data_component)-[:component_of]->(:x_mitre_data_source)
(:x_mitre_data_component)-[:detects]->(:attack_pattern)
(:intrusion_set)-[:uses]->(:malware)
(:intrusion_set)-[:uses]->(:tool)
(:intrusion_set)-[:uses]->(:attack_pattern)
(:campaign)-[:attributed_to]->(:intrusion_set)
(:campaign)-[:uses]->(:tool)
(:campaign)-[:uses]->(:attack_pattern)
(:campaign)-[:uses]->(:malware)
"""

# Text2Cypher prompt template for the naive variant: it has only the {schema}
# and {query_text} placeholders, with no section for pre-retrieved nodes.
PROMPT = """Task: Generate a Cypher statement for querying a Neo4j graph database from a user input.

Schema:
{schema}
Example:
How to prevent Phishing attacks?
MATCH (co:course_of_action)-[m:mitigates]->(ap:attack_pattern)
WHERE ap.name = 'Phishing'
RETURN co.name, co.description, m.description

Input:
{query_text}

Do not use any properties or relationships not included in the schema.
Do not include triple backticks ``` or any additional text except the generated Cypher statement in your response.
Return name and description of relevant relationships and nodes. Prioritise relationship information.

Cypher query:
"""

# Generate Cypher from the question alone (no pre-retrieved nodes) and run it.
txt2cypher_retriever = Text2CypherRetriever(
    driver=driver,
    llm=llm,
    neo4j_schema=SCHEMA,
    custom_prompt=PROMPT,
)

naive_prompt_params = {"schema": SCHEMA}
result = txt2cypher_retriever.search(
    query_text=QUERY, prompt_params=naive_prompt_params
)

# Show the generated statement when the result carries metadata.
result_metadata = result.metadata
if result_metadata:
    cypher = result_metadata["cypher"]
    print(cypher)

# Number of returned items, followed by each item's content on its own line.
returned_items = result.items
print(len(returned_items))

for item in returned_items:
    print(item.content)

MATCH (i:intrusion_set)-[u:uses]->(t:tool)
WHERE i.name = 'INC Ransum'
RETURN t.name, t.description, u.description

0
