# Knowledge Graph Embedding: DistMult embeddings for the Nations dataset

In [None]:
import os
import time
import warnings
from collections import defaultdict
from neo4j.exceptions import ClientError
from tqdm import tqdm
from graphdatascience import GraphDataScience

In [None]:
# Suppress DeprecationWarning messages emitted after this point.
warnings.simplefilter("ignore", category=DeprecationWarning)

In [None]:
# Connection settings come from the environment, falling back to a local
# Neo4j instance and the "neo4j" database.
NEO4J_URI = os.environ.get("NEO4J_URI", "bolt://localhost:7687")
NEO4J_DB = os.environ.get("NEO4J_DB", "neo4j")

# Credentials are passed only when both user and password are set to
# non-empty values; otherwise the client is created with auth=None.
NEO4J_AUTH = (
    (os.environ["NEO4J_USER"], os.environ["NEO4J_PASSWORD"])
    if os.environ.get("NEO4J_USER") and os.environ.get("NEO4J_PASSWORD")
    else None
)

gds = GraphDataScience(NEO4J_URI, auth=NEO4J_AUTH, database=NEO4J_DB, arrow=True)

In [None]:
# Ensure Entity.id is unique. IF NOT EXISTS turns the statement into a no-op
# when the constraint is already present, so re-running this cell is safe.
try:
    gds.run_cypher("CREATE CONSTRAINT entity_id IF NOT EXISTS FOR (e:Entity) REQUIRE e.id IS UNIQUE")
except ClientError as err:
    # With IF NOT EXISTS an existing constraint no longer raises, so anything
    # caught here is a genuine failure. Report the real cause instead of
    # claiming the constraint already exists, but keep the cell best-effort.
    print(f"Could not create CONSTRAINT entity_id: {err}")

In [None]:
def get_text_to_id_map(data_dir, text_to_id_filename):
    """Load a tab-separated ``text<TAB>id`` file into a dictionary.

    Args:
        data_dir: Directory that contains the mapping file.
        text_to_id_filename: Name of the mapping file inside ``data_dir``.

    Returns:
        A dict mapping each text label to its integer id. An empty file
        yields an empty dict.

    Raises:
        ValueError: If a non-blank line does not consist of exactly two
            tab-separated fields, or if its id field is not an integer.
    """
    text_to_id_map = {}
    with open(os.path.join(data_dir, text_to_id_filename), "r", encoding="utf-8") as f:
        for line in f:
            line = line.rstrip("\n")
            # Skip blank lines so that neither a missing nor an extra trailing
            # newline changes the result (the last record is never dropped).
            if not line:
                continue
            text, raw_id = line.split("\t")
            text_to_id_map[text] = int(raw_id)
    return text_to_id_map

In [None]:
def read_data():
    """Download (if missing) and parse the Nations knowledge-graph files.

    The train/valid/test triple files and the entity/relation id mappings are
    fetched into ``./Nations`` when they are not already present. Each line of
    a triple file is expected to be ``source<TAB>relation<TAB>target``.

    Returns:
        A nested mapping ``dataset[rel_split][rel_type]`` holding a list of
        edge dicts, where ``rel_split`` is ``"TRAIN"``, ``"VALID"`` or
        ``"TEST"`` and ``rel_type`` is ``"REL_"`` plus the upper-cased
        relation text. Every edge dict has the keys ``source``,
        ``source_text``, ``target``, ``target_text``, ``rel_type``,
        ``rel_id``, ``rel_split`` and ``rel_split_id``.

    Raises:
        OSError: If a file cannot be downloaded or opened (this includes
            ``urllib.error.URLError``).
        KeyError: If a triple references an entity or relation that is not in
            the id mapping files.
        ValueError: If a non-blank triple line does not have exactly three
            tab-separated fields.
    """
    # Function-scope import: downloading with the standard library removes the
    # dependency on an external ``wget`` binary and makes failures raise
    # instead of being silently ignored.
    from urllib.request import urlretrieve

    rel_types = {
        "train.txt": "TRAIN",
        "valid.txt": "VALID",
        "test.txt": "TEST",
    }
    url = "https://raw.githubusercontent.com/ZhenfengLei/KGDatasets/master/Nations"
    data_dir = "./Nations"

    raw_file_names = list(rel_types)
    node_id_filename = "entity2id.txt"
    rel_id_filename = "relation2id.txt"

    os.makedirs(data_dir, exist_ok=True)
    for file in raw_file_names + [node_id_filename, rel_id_filename]:
        file_path = os.path.join(data_dir, file)
        if not os.path.exists(file_path):
            # Download to a temporary name first so an interrupted transfer
            # never leaves a truncated file that would pass the exists() check
            # on the next run.
            partial_path = file_path + ".part"
            urlretrieve(f"{url}/{file}", partial_path)
            os.replace(partial_path, file_path)

    node_map = get_text_to_id_map(data_dir, node_id_filename)
    rel_map = get_text_to_id_map(data_dir, rel_id_filename)
    dataset = defaultdict(lambda: defaultdict(list))

    rel_split_id = {"TRAIN": 0, "VALID": 1, "TEST": 2}

    for file_name in raw_file_names:
        rel_split = rel_types[file_name]

        with open(os.path.join(data_dir, file_name), "r", encoding="utf-8") as f:
            for line in f:
                line = line.rstrip("\n")
                # Skip blank lines; this also keeps the final triple when the
                # file does not end with a newline.
                if not line:
                    continue
                src_text, rel_text, dst_text = line.split("\t")
                rel_type = "REL_" + rel_text.upper()

                dataset[rel_split][rel_type].append(
                    {
                        "source": node_map[src_text],
                        "source_text": src_text,
                        "target": node_map[dst_text],
                        "target_text": dst_text,
                        "rel_type": rel_type,
                        "rel_id": rel_map[rel_text],
                        "rel_split": rel_split,
                        "rel_split_id": rel_split_id[rel_split],
                    }
                )

    print("Number of nodes: ", len(node_map))
    for rel_split in dataset:
        print(
            f"Number of relationships of type {rel_split}: ",
            sum(len(edges) for edges in dataset[rel_split].values()),
        )
    return dataset


dataset = read_data()

In [None]:
def put_data_in_db():
    """Load the Nations dataset into Neo4j unless the database already has nodes.

    For every split and relationship type, the edges are sent in one batch
    that MERGEs the two ``Entity`` nodes and a relationship of the given type
    carrying ``split`` (the split id) and ``rel_id`` properties. Afterwards
    the number of stored relationships per split is printed.

    Returns:
        None. The function only writes to the database and prints progress.
    """
    res = gds.run_cypher("MATCH (m) RETURN count(m) as num_nodes")
    num_nodes = res["num_nodes"].values[0]
    if num_nodes > 0:
        print("Data already in db, number of nodes: ", num_nodes)
        return
    dataset = read_data()
    total_edges = sum(len(edges) for edges_by_type in dataset.values() for edges in edges_by_type.values())

    # The context manager closes the progress bar even if a query fails.
    with tqdm(desc="Putting data in db", total=total_edges) as pbar:
        for rel_split in dataset:
            for rel_type, edges in dataset[rel_split].items():
                # Relationship types cannot be passed as Cypher parameters, so
                # rel_type is interpolated; it is built by read_data() from
                # the dataset files, not from user input.
                gds.run_cypher(
                    f"""
                UNWIND $ll as l
                MERGE (n:Entity {{id:l.source, text:l.source_text}})
                MERGE (m:Entity {{id:l.target, text:l.target_text}})
                MERGE (n)-[:{rel_type} {{split: l.rel_split_id, rel_id: l.rel_id}}]->(m)
                """,
                    params={"ll": edges},
                )
                pbar.update(len(edges))

    for rel_split, edges_by_type in dataset.items():
        # The split is stored as the numeric `split` property on REL_*
        # relationships, not as a relationship type, so count by property.
        # Every list in the dataset is non-empty because entries are only
        # created when an edge is appended.
        split_id = next(iter(edges_by_type.values()))[0]["rel_split_id"]
        res = gds.run_cypher(
            """
            MATCH ()-[r]->()
            WHERE r.split = $split_id
            RETURN COUNT(r) AS numberOfRelationships
            """,
            params={"split_id": split_id},
        )
        print(
            f"Number of relationships of type {rel_split} in db: ",
            res["numberOfRelationships"].values[0],
        )


put_data_in_db()

In [None]:
def project_graphs():
    """Build the train, validation and test graph projections.

    All relationship types whose name starts with ``REL_`` are projected,
    together with their ``split`` property and the ``Entity`` nodes, into a
    temporary ``fullGraph``. That graph is then filtered on ``r.split`` into
    ``trainGraph``, ``validGraph`` and ``testGraph``, and ``fullGraph`` is
    dropped again. Projections with these names are dropped beforehand if
    they exist.

    Returns:
        A tuple ``(G_train, G_valid, G_test)`` of the filtered graph objects.
    """
    rel_type_rows = gds.run_cypher(
        """
            CALL db.relationshipTypes() YIELD relationshipType
        """
    )

    # Keep only the dataset relationships and project their `split` property.
    rel_projection = {}
    for rel_name in rel_type_rows["relationshipType"].to_list():
        if rel_name.startswith("REL_"):
            rel_projection[rel_name] = {"properties": "split"}

    # Clear projections left over from an earlier run.
    for stale_name in ("fullGraph", "trainGraph", "validGraph", "testGraph"):
        gds.graph.drop(stale_name, failIfMissing=False)

    full_graph, _ = gds.graph.project("fullGraph", ["Entity"], rel_projection)

    split_filters = (
        ("trainGraph", "r.split = 0.0"),
        ("validGraph", "r.split = 1.0"),
        ("testGraph", "r.split = 2.0"),
    )
    split_graphs = []
    for graph_name, rel_predicate in split_filters:
        subgraph, _ = gds.graph.filter(graph_name, full_graph, "*", rel_predicate)
        split_graphs.append(subgraph)

    # The full projection is only needed as the source for the filters.
    gds.graph.drop("fullGraph", failIfMissing=False)

    train_graph, valid_graph, test_graph = split_graphs
    return train_graph, valid_graph, test_graph


G_train, G_valid, G_test = project_graphs()

In [None]:
# Train a DistMult model on the training graph, then ask it for the top-3
# predictions for a few entities and relationship types.
gds.set_compute_cluster_ip("localhost")

# The timestamp suffix gives each run its own model name.
model_name = f"dummyModelName_{time.time()}"

train_settings = {
    "scoring_function": "distmult",
    "num_epochs": 1,
    "embedding_dimension": 10,
    "epochs_per_checkpoint": 0,
    "epochs_per_val": 0,
}
gds.kge.model.train(G_train, model_name=model_name, **train_settings)

# Resolve the entities of interest to database node ids.
query_node_ids = [gds.find_node_id(["Entity"], {"text": entity_text}) for entity_text in ("brazil", "uk", "jordan")]

predict_result = gds.kge.model.predict(
    model_name=model_name,
    top_k=3,
    node_ids=query_node_ids,
    rel_types=["REL_RELDIPLOMACY", "REL_RELNGO"],
)

print(predict_result.to_string())

# NOTE(review): the disabled examples below use entity texts and relationship
# types ("/m/016wzw", "REL_1", ...) that appear to come from a different
# dataset -- confirm and adapt them before re-enabling.
#
# gds.kge.model.predict_tail(
#     G_train,
#     model_name=model_name,
#     top_k=10,
#     node_ids=[gds.find_node_id(["Entity"], {"text": "/m/016wzw"}), gds.find_node_id(["Entity"], {"id": 2})],
#     rel_types=["REL_1", "REL_2"],
# )
#
# gds.kge.model.score_triples(
#     G_train,
#     model_name=model_name,
#     triples=[
#         (gds.find_node_id(["Entity"], {"text": "/m/016wzw"}), "REL_1", gds.find_node_id(["Entity"], {"id": 2})),
#         (gds.find_node_id(["Entity"], {"id": 0}), "REL_123", gds.find_node_id(["Entity"], {"id": 3})),
#     ],
# )

In [None]:
# Create the dictionary