# Knowledge graph embeddings: TransE

<a href="https://colab.research.google.com/github/neo4j/graph-data-science-client/blob/main/examples/kge-distmult.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
from graphdatascience import GraphDataScience
import collections
from tqdm import tqdm
import pandas as pd
from neo4j.exceptions import ClientError

In [None]:
# Connection settings are read from the environment; the defaults target a
# local Bolt endpoint and the "neo4j" database.
NEO4J_URI = os.environ.get("NEO4J_URI", "bolt://localhost:7687")
NEO4J_DB = os.environ.get("NEO4J_DB", "neo4j")

# Credentials are passed only when both NEO4J_USER and NEO4J_PASSWORD are set
# to non-empty values; otherwise the client is created with auth=None.
NEO4J_AUTH = (
    (os.environ.get("NEO4J_USER"), os.environ.get("NEO4J_PASSWORD"))
    if os.environ.get("NEO4J_USER") and os.environ.get("NEO4J_PASSWORD")
    else None
)

gds = GraphDataScience(NEO4J_URI, auth=NEO4J_AUTH, database=NEO4J_DB, arrow=True)

## Downloading and Storing the FB15k-237 Dataset in the Database
Download the FB15k-237 dataset
Extract the required files: train.txt, valid.txt, and test.txt.

Set a uniqueness constraint on the `text` property of `Entity` nodes to speed up data uploads.

In [None]:
# Create the uniqueness constraint on `Entity.text`. `IF NOT EXISTS` turns the
# statement into a no-op when the constraint is already present, so re-running
# the notebook no longer depends on catching an error for that case.
try:
    _ = gds.run_cypher("CREATE CONSTRAINT entity_id IF NOT EXISTS FOR (e:Entity) REQUIRE e.text IS UNIQUE")
except ClientError as err:
    # Stay best-effort (the notebook keeps going), but report the real failure
    # instead of assuming every client error means "already exists".
    print(f"Could not create CONSTRAINT entity_id: {err}")

**Creating Entity Nodes**:
   Create a node with the label `Entity`. This node should have properties `id` and `text`. 
   - Syntax: `(:Entity {id: int, text: str})`

**Creating Relationships for Training with PyG**:
   Based on the training stage, create relationships of type `TRAIN`, `TEST`, or `VALID`. Each of these relationships should have a `rel_id` property.
   - Example Syntax: `[:TRAIN {rel_id: int}]`

**Creating Relationships for Prediction with GDS**:
   For the prediction stage, create relationships of a specific type denoted as `REL_i`. Each of these relationships should have `rel_id` and `text` properties.
   - Example Syntax: `[:REL_7 {rel_id: int, text: str}]`

In [None]:
from collections import defaultdict
from ogb.utils.url import download_url
import os
import zipfile

# Fetch the FB15k-237 archive into a local working folder.
url = "https://download.microsoft.com/download/8/7/0/8700516A-AB3D-4850-B4BB-805C515AECE1/FB15K-237.2.zip"
raw_dir = "./data_from_zip"
download_url(url, raw_dir)

# Only the three split files are unpacked; inside the archive they are stored
# under the "Release/" folder.
raw_file_names = ["train.txt", "valid.txt", "test.txt"]
with zipfile.ZipFile(f"{raw_dir}/{os.path.basename(url)}", "r") as zip_ref:
    zip_ref.extractall(
        path=raw_dir,
        members=[f"Release/{split_file}" for split_file in raw_file_names],
    )
data_dir = f"{raw_dir}/Release"

# Split file name -> relationship type used to mark that split in the database.
rel_types = {"train.txt": "TRAIN", "valid.txt": "VALID", "test.txt": "TEST"}

# Module-level lookup tables; rel_dict and rel_id_to_text_dict are filled in
# by read_data().
rel_dict = {}
rel_id_to_text_dict = {}
rel_type_dict = collections.defaultdict(list)


def read_data():
    """Parse the extracted split files into per-split, per-relation edge lists.

    Reads every file named in ``raw_file_names`` from ``data_dir``. Each
    non-empty line must hold one tab-separated ``source, relation, target``
    triple. Node ids and relation ids are assigned densely in first-seen
    order across all files.

    Side effects:
        Fills the module-level ``rel_dict`` (relation text -> id) and
        ``rel_id_to_text_dict`` (id -> relation text), and prints the node
        count and the number of relationships per split.

    Returns:
        A nested mapping ``dataset[split][rel_type]`` -> list of edge dicts,
        where ``split`` is the ``rel_types`` value for the file, ``rel_type``
        is ``"REL_<relation id>"`` and each edge dict has the keys
        ``source``, ``source_text``, ``target`` and ``target_text``.

    Raises:
        ValueError: if a non-empty line does not split into exactly three
            tab-separated fields.
    """
    node_ids = {}
    dataset = defaultdict(lambda: defaultdict(list))

    for file_name in raw_file_names:
        rel_split = rel_types[file_name]

        with open(f"{data_dir}/{file_name}", "r", encoding="utf-8") as f:
            for raw_line in f:
                line = raw_line.rstrip("\n")
                # Skip blank lines instead of unconditionally dropping the
                # last line: a file without a trailing newline keeps its
                # final triple, and stray empty lines do not break unpacking.
                if not line:
                    continue

                src_text, rel_text, dst_text = line.split("\t")

                # setdefault evaluates len() before inserting, so a new node
                # receives the next free dense id.
                source = node_ids.setdefault(src_text, len(node_ids))
                target = node_ids.setdefault(dst_text, len(node_ids))

                if rel_text not in rel_dict:
                    rel_id = len(rel_dict)
                    rel_dict[rel_text] = rel_id
                    rel_id_to_text_dict[rel_id] = rel_text

                rel_type = "REL_" + str(rel_dict[rel_text])
                dataset[rel_split][rel_type].append(
                    {
                        "source": source,
                        "source_text": src_text,
                        "target": target,
                        "target_text": dst_text,
                    }
                )

    print("Number of nodes: ", len(node_ids))
    for rel_split, edges_by_type in dataset.items():
        print(
            f"Number of relationships of type {rel_split}: ",
            sum(len(edges) for edges in edges_by_type.values()),
        )
    return dataset


# Parse the split files once; the cells below reuse `dataset`.
dataset = read_data()

In [None]:
def put_data_in_db(data):
    """Upload parsed triples to the database, one query per relationship type.

    For every edge, the two ``Entity`` nodes are merged on their ``text``
    property and receive their numeric ``id`` when first created. Then one
    relationship of type ``rel_type`` and one of type ``rel_split`` are merged
    between them.

    Args:
        data: nested mapping ``data[rel_split][rel_type]`` -> list of edge
            dicts with the keys ``source``, ``source_text``, ``target`` and
            ``target_text`` (the structure built by ``read_data``).
    """
    for rel_split in tqdm(data, desc="Relationship"):
        for rel_type in tqdm(data[rel_split], mininterval=1, leave=False):
            edges = data[rel_split][rel_type]

            # The relationship types are interpolated into the query text;
            # both values are generated by this notebook, not user input.
            # `ON CREATE SET` stores the numeric id so that the documented
            # (:Entity {id, text}) shape holds and later `n.id` reads work.
            gds.run_cypher(
                f"""
                UNWIND $ll as l
                MERGE (n:Entity {{text:l.source_text}})
                  ON CREATE SET n.id = l.source
                MERGE (m:Entity {{text:l.target_text}})
                  ON CREATE SET m.id = l.target
                MERGE (n)-[:{rel_type}]->(m)
                MERGE (n)-[:{rel_split}]->(m)
                """,
                params={"ll": edges},
            )


# Upload every split of the parsed dataset to the database.
put_data_in_db(dataset)

Project all data in the graph to get the mapping between `id` and the internal `nodeId` field from the database.

In [None]:
# Relationship types ("REL_<id>") that occur in the training split.
ALL_RELS = dataset["TRAIN"].keys()

# Project the node pairs that are linked by a TRAIN relationship and, in the
# same direction, by one of the REL_<id> relationships.
# NOTE(review): no `relationshipType` is passed to gds.graph.project here, so
# the projected relationships do not carry their REL_<id> type — confirm this
# is what the training and prediction steps expect.
G, result = gds.graph.cypher.project(
    """
    MATCH (n:Entity)-[:TRAIN]->(m:Entity)<-[:"""
    + "|".join(ALL_RELS)
    + """]-(n:Entity)
    RETURN gds.graph.project($graph_name, n, m, {
        sourceNodeLabels: $label,
        targetNodeLabels: $label
    })
    """,  #  Cypher query
    database=NEO4J_DB,  #  Target database: the same one the client was configured with
    graph_name="G_full",  #  Query parameter
    label="Entity",  #  Query parameter
)

In [None]:
def inspect_graph(G):
    """Print a short summary of a projected graph object.

    Args:
        G: object exposing the zero-argument methods ``name``,
            ``node_count``, ``relationship_count``, ``node_labels`` and
            ``relationship_types``.

    Each value is printed on its own line as ``===<method>===: <value>``.
    """
    summary_methods = (
        "name",
        "node_count",
        "relationship_count",
        "node_labels",
        "relationship_types",
    )
    for method_name in summary_methods:
        accessor = getattr(G, method_name)
        print(f"==={method_name}===: {accessor()}")


# Print name, node/relationship counts, labels and relationship types of the projection.
inspect_graph(G)

In [None]:
# Pass the compute cluster address to the client; "localhost" presumably means
# the service runs on this machine — NOTE(review): confirm for your setup.
gds.set_compute_cluster_ip("localhost")

In [None]:
import time

# The time.time() suffix gives each run of this cell a different model name.
model_name = f"fb15k-TransE-128-model-{time.time()}"

# Training settings, collected in one place and passed on unchanged.
train_config = {
    "scoring_function": "TransE",
    "embedding_dimension": 128,
    "num_epochs": 100,
    "filtered_metrics": False,
    "batch_size": 32_768,
    "optimizer": "Adam",
    "optimizer_kwargs": {"lr": 0.0003},
    # Validation and test passes are switched off for this run.
    "epochs_per_val": 0,
    "do_validation": False,
    "do_test": False,
}

gds.kge.model.train(G, model_name=model_name, **train_config)

Look up the database node ids of a few source entities and ask the trained model for the top-3 `/film/film/genre` predictions.

In [None]:
# Entities, identified by their `text` value, used as prediction sources.
source_node_list = ["/m/07l450", "/m/0ds2l81", "/m/0jvt9"]

# Relationship type under which "/film/film/genre" was stored: "REL_<relation id>".
rel_label_to_predict = f"REL_{rel_dict['/film/film/genre']}"

# Translate the entity texts into database node ids.
source_ids_df = gds.run_cypher(
    "UNWIND $node_text_list AS t MATCH (n:Entity) WHERE n.text=t RETURN id(n) as nodeId",
    params=dict(node_text_list=source_node_list),
)
node_ids = source_ids_df.nodeId.to_list()

# Ask the trained model for its top-3 predictions for the chosen relationship type.
predict_result = gds.kge.model.predict(
    model_name=model_name,
    top_k=3,
    node_ids=node_ids,
    rel_types=[rel_label_to_predict],
)

# to_string() prints the complete frame.
print(predict_result.to_string())

Retrieve the embedding for the selected relationship from the PyG model. Then, create a GDS TransE model using the graph, node embeddings property, and the embedding for the relationship to be predicted.

In [None]:
# Resolve the example entities (given by `text`) to database node ids; the
# trailing expression shows the ids as the cell output.
source_node_list = ["/m/07l450", "/m/0ds2l81", "/m/0jvt9"]
source_ids_df = gds.run_cypher(
    "UNWIND $node_text_list AS t MATCH (n:Entity) WHERE n.text=t RETURN id(n) as nodeId",
    params=dict(node_text_list=source_node_list),
)
source_ids_df.nodeId.to_list()

Now, we can use the model to make predictions.

In [None]:
# Stream the top-3 predicted targets among `Entity` nodes for each source node.
# NOTE(review): `transe_model` is not assigned in any of the preceding cells —
# confirm where it is created before running this cell.
result = transe_model.predict_stream(
    source_node_filter=source_ids_df["nodeId"],
    target_node_filter="Entity",
    relationship_type=rel_label_to_predict,
    top_k=3,
    concurrency=4,
)
print(result)

Augment the predicted result with node identifiers and their text values.

In [None]:
# Every node id that occurs in the predictions, as source or as target.
ids_in_result = pd.unique(pd.concat([result["sourceNodeId"], result["targetNodeId"]]))

# Fetch the `text` and `id` properties of those nodes.
ids_to_text = gds.run_cypher(
    "UNWIND $ids AS id MATCH (n:Entity) WHERE id(n)=id RETURN id(n) AS nodeId, n.text AS tag, n.id AS id",
    params=dict(ids=ids_in_result),
)

# nodeId -> text and nodeId -> id lookup tables.
nodeId_to_text_res = dict(zip(ids_to_text["nodeId"], ids_to_text["tag"]))
nodeId_to_id_res = dict(zip(ids_to_text["nodeId"], ids_to_text["id"]))

# Insert positions account for the columns added before them.
# Assumes `result` starts with sourceNodeId, targetNodeId — TODO confirm.
result.insert(1, "sourceTag", result["sourceNodeId"].map(lambda node_id: nodeId_to_text_res[node_id]))
result.insert(2, "sourceId", result["sourceNodeId"].map(lambda node_id: nodeId_to_id_res[node_id]))
result.insert(4, "targetTag", result["targetNodeId"].map(lambda node_id: nodeId_to_text_res[node_id]))
result.insert(5, "targetId", result["targetNodeId"].map(lambda node_id: nodeId_to_id_res[node_id]))

print(result)

## Using Write Mode

Write mode allows you to write results directly to the database as a new relationship type. This approach helps to avoid mapping from `nodeId` to `id`.

In [None]:
# Predictions are written back as relationships of type "PREDICTED_REL_<id>",
# with the score stored in the `transe_score` property.
# NOTE(review): `transe_model` is not assigned in any of the preceding cells —
# confirm where it is created before running this cell.
write_relationship_type = f"PREDICTED_{rel_label_to_predict}"
result_write = transe_model.predict_write(
    source_node_filter=source_ids_df["nodeId"],
    target_node_filter="Entity",
    relationship_type=rel_label_to_predict,
    write_relationship_type=write_relationship_type,
    write_property="transe_score",
    top_k=3,
    concurrency=4,
)

Extract the result from the database.

In [None]:
# Read the written predictions back, together with the id/text of both end
# nodes; the call's return value is the cell output.
gds.run_cypher(
    f"MATCH (n)-[r:{write_relationship_type}]->(m) "
    "RETURN n.id AS sourceId, n.text AS sourceTag, m.id AS targetId, m.text AS targetTag, r.transe_score AS score"
)

In [None]:
# Drop the projected graph. The only graph handle this notebook creates is
# `G` (projected as "G_full"); `G_test` is never defined.
gds.graph.drop(G)