# Knowledge Graph Embedding: TransE embeddings for a constructed dataset

In [None]:
import os
import time
import warnings
from neo4j.exceptions import ClientError
from graphdatascience import GraphDataScience

In [None]:
# Silence DeprecationWarnings for the rest of the notebook session.
warnings.simplefilter(action="ignore", category=DeprecationWarning)

In [None]:
# Connection settings; each one can be overridden through an environment variable.
NEO4J_URI = os.environ.get("NEO4J_URI", "bolt://localhost:7687")
NEO4J_DB = os.environ.get("NEO4J_DB", "neo4j")

# Send credentials only when both a user name and a password are set and
# non-empty; otherwise connect with auth=None.
_credentials = (os.environ.get("NEO4J_USER"), os.environ.get("NEO4J_PASSWORD"))
NEO4J_AUTH = _credentials if all(_credentials) else None
del _credentials

gds = GraphDataScience(NEO4J_URI, auth=NEO4J_AUTH, database=NEO4J_DB, arrow=True)

In [None]:
# Ensure `Entity.id` values are unique. With `IF NOT EXISTS` the statement does
# nothing when the constraint is already present, so re-running the cell is
# safe. Unrelated ClientErrors (bad syntax, missing privileges, ...) are no
# longer caught and misreported as "already exists".
_ = gds.run_cypher(
    "CREATE CONSTRAINT entity_id IF NOT EXISTS "
    "FOR (e:Entity) REQUIRE e.id IS UNIQUE"
)

In [None]:
import pandas

# Toy node table: six nodes with non-contiguous ids, labelled A/B/C, each
# carrying an integer property and a float property.
nodes = pandas.DataFrame(
    data=dict(
        nodeId=[0, 1, 2, 3, 7, 10],
        labels=["A", "B", "C", "A", "B", "C"],
        prop1=[42, 1337, 8, 0, 1, 2],
        otherProperty=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
    )
)

# Toy relationship table: four weighted edges split across two types.
relationships = pandas.DataFrame(
    data=dict(
        sourceNodeId=[0, 1, 2, 7],
        targetNodeId=[1, 2, 3, 10],
        relationshipType=["REL1", "REL1", "REL2", "REL2"],
        weight=[0.0, 0.0, 0.1, 42.0],
    )
)

# Remove any existing in-memory graph with this name (no error if absent)
# so the cell can be re-run, then build it from the two dataframes.
gds.graph.drop("my-graph", failIfMissing=False)
G_train = gds.graph.construct("my-graph", nodes, relationships)

# Sanity check: both relationship types are present in the constructed graph.
for rel_type in ("REL1", "REL2"):
    assert rel_type in G_train.relationship_types()

In [None]:
# A notebook cell auto-displays only its final expression, so a bare
# `G_train.relationship_types()` line here would be evaluated and discarded.
# Print both summaries explicitly so each one is shown.
print(G_train.relationship_types())
print(G_train.node_labels())

In [None]:
# Compute-cluster address for KGE training. Like the Neo4j settings, it can be
# overridden via the environment and defaults to "localhost".
gds.set_compute_cluster_ip(os.environ.get("KGE_COMPUTE_CLUSTER_IP", "localhost"))

# Timestamp suffix so repeated runs use distinct model names.
model_name = "dummyModelName_" + str(time.time())

# Train a TransE model on G_train. NOTE(review): epochs_per_checkpoint=0 presumably
# disables checkpointing, and split_ratios presumably holds out 25% of the
# relationships for testing; confirm against the gds.kge docs.
gds.kge.model.train(
    G_train,
    model_name=model_name,
    scoring_function="transe",
    num_epochs=1,
    embedding_dimension=16,
    epochs_per_checkpoint=0,
    split_ratios={"TRAIN": 0.75, "TEST": 0.25},
)

In [None]:
# Query the trained model for its top 3 predictions for the given nodes over
# both relationship types. NOTE(review): presumably top_k applies per node and
# relationship type; confirm.
predict_result = gds.kge.model.predict(
    rel_types=["REL1", "REL2"],
    node_ids=[1, 2, 0, 10, 7],
    top_k=3,
    model_name=model_name,
)

# Render the full result rather than a truncated preview.
prediction_table = predict_result.to_string()
print(prediction_table)