# Knowledge Graph Embedding: DistMult embedding for the Nations dataset

In this notebook, we will use the DistMult embedding model to make predictions on the Nations dataset.
The Nations dataset is a simple dataset that contains relationships between countries.

The dataset consists of three split files, `train.txt`, `valid.txt`, and `test.txt`, plus two id-mapping files.
Each split file contains tab-separated triplets of the form `source_country relation target_country`.
The `entity2id.txt` file maps country names to ids, and the `relation2id.txt` file maps relation names to ids.

## Setup

We start by importing our dependencies and setting up our GDS client connection to the database.

In [None]:
import os
import time
import warnings
from collections import defaultdict
from neo4j.exceptions import ClientError
from tqdm import tqdm
from graphdatascience import GraphDataScience

In [None]:
# Suppress every DeprecationWarning for the rest of the session so cell output stays readable.
warnings.filterwarnings(action="ignore", category=DeprecationWarning)

In [None]:
# Connection settings come from the environment, falling back to a local default.
NEO4J_URI = os.environ.get("NEO4J_URI", "bolt://localhost:7687")
NEO4J_DB = os.environ.get("NEO4J_DB", "neo4j")

# Only authenticate when both a user name and a password are set (non-empty).
_neo4j_user = os.environ.get("NEO4J_USER")
_neo4j_password = os.environ.get("NEO4J_PASSWORD")
NEO4J_AUTH = (_neo4j_user, _neo4j_password) if _neo4j_user and _neo4j_password else None

gds = GraphDataScience(NEO4J_URI, auth=NEO4J_AUTH, database=NEO4J_DB, arrow=True)

Create constraints to ensure that the `Entity` nodes have unique `text` properties.

In [None]:
# Ensure that `Entity` nodes have unique `text` properties.
# `IF NOT EXISTS` makes re-running this cell a no-op rather than an error, so a
# ClientError that still occurs is a genuine problem; print it instead of
# assuming the constraint already exists.
try:
    _ = gds.run_cypher("CREATE CONSTRAINT entity_id IF NOT EXISTS FOR (e:Entity) REQUIRE e.text IS UNIQUE")
except ClientError as e:
    print(f"Could not create CONSTRAINT entity_id: {e}")

## Download and read the data

Let's download the Nations dataset and read the data.

In [None]:
def get_text_to_id_map(data_dir, text_to_id_filename):
    """Read a tab-separated ``text<TAB>id`` file and return a ``{text: id}`` dict.

    Args:
        data_dir: Directory containing the mapping file.
        text_to_id_filename: Name of the mapping file inside ``data_dir``.

    Returns:
        Dict mapping each text to its integer id. If a text occurs more than
        once, the last occurrence wins.

    Raises:
        ValueError: If a non-blank line does not have exactly two tab-separated
            fields, or its id is not an integer.
    """
    path = os.path.join(data_dir, text_to_id_filename)
    text_to_id_map = {}
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            # Skip blank lines (such as a trailing newline) rather than slicing
            # off the last line, which silently dropped the final entry when
            # the file did not end with a newline.
            line = line.rstrip("\r\n")
            if not line.strip():
                continue
            text, id_str = line.split("\t")
            text_to_id_map[text] = int(id_str)
    return text_to_id_map


def _download_missing_files(url, data_dir, file_names):
    """Download each of ``file_names`` from ``url`` into ``data_dir`` unless it already exists.

    Uses ``urllib`` (no external ``wget`` required). Each file is written to a
    ``.part`` file first and renamed only once the download completes, so an
    interrupted download is retried on the next run instead of leaving a
    truncated file behind. Download errors raise immediately.
    """
    import urllib.request

    os.makedirs(data_dir, exist_ok=True)
    for file_name in file_names:
        path = os.path.join(data_dir, file_name)
        if os.path.exists(path):
            continue
        partial_path = path + ".part"
        urllib.request.urlretrieve(f"{url}/{file_name}", partial_path)
        os.replace(partial_path, path)


def _read_triplets(path):
    """Yield ``(source_text, relation_text, target_text)`` for each non-blank line of a tab-separated file."""
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.rstrip("\r\n")
            # Skip blank lines instead of dropping the last line unconditionally.
            if not line.strip():
                continue
            src_text, rel_text, dst_text = line.split("\t")
            yield src_text, rel_text, dst_text


def read_data():
    """Download (if needed) and parse the Nations dataset.

    Returns:
        dataset: nested mapping ``dataset[split][rel_type] -> list of dicts``,
            where ``split`` is "TRAIN", "VALID" or "TEST" and ``rel_type`` is
            ``"REL_" + relation.upper()``. Each dict describes one triplet:
            source/target ids and texts, relation type and id, and the split
            name and numeric split id.
        node_map: mapping from entity text to its integer id.
    """
    rel_types = {
        "train.txt": "TRAIN",
        "valid.txt": "VALID",
        "test.txt": "TEST",
    }
    url = "https://raw.githubusercontent.com/ZhenfengLei/KGDatasets/master/Nations"
    data_dir = "./Nations"

    raw_file_names = ["train.txt", "valid.txt", "test.txt"]
    node_id_filename = "entity2id.txt"
    rel_id_filename = "relation2id.txt"

    _download_missing_files(url, data_dir, raw_file_names + [node_id_filename, rel_id_filename])

    node_map = get_text_to_id_map(data_dir, node_id_filename)
    rel_map = get_text_to_id_map(data_dir, rel_id_filename)
    dataset = defaultdict(lambda: defaultdict(list))

    # Numeric id for each split, stored on every triplet as `rel_split_id`.
    rel_split_id = {"TRAIN": 0, "VALID": 1, "TEST": 2}

    for file_name in raw_file_names:
        rel_split = rel_types[file_name]
        for src_text, rel_text, dst_text in _read_triplets(os.path.join(data_dir, file_name)):
            rel_type = "REL_" + rel_text.upper()
            dataset[rel_split][rel_type].append(
                {
                    "source": node_map[src_text],
                    "source_text": src_text,
                    "target": node_map[dst_text],
                    "target_text": dst_text,
                    "rel_type": rel_type,
                    "rel_id": rel_map[rel_text],
                    "rel_split": rel_split,
                    "rel_split_id": rel_split_id[rel_split],
                }
            )

    print("Number of nodes: ", len(node_map))
    for rel_split, rels in dataset.items():
        print(
            f"Number of relationships of type {rel_split}: ",
            sum(len(edges) for edges in rels.values()),
        )
    return dataset, node_map


dataset, node_map = read_data()

## Put data in the database

We will put the data in the database, creating `Entity` nodes and relationships between them.

Each node will have a `text` property. We will use `text` to identify the node later.

Each relationship will have a `split` property indicating whether it belongs to the training (`0`), validation (`1`), or test (`2`) set.

In [None]:
def put_data_in_db():
    """Load the global ``dataset`` into Neo4j as ``Entity`` nodes and ``REL_*`` relationships.

    Does nothing (beyond printing the node count) if the database already
    contains any node. Each relationship is given a ``split`` property (the
    triplet's ``rel_split_id``) and a ``rel_id`` property. After loading, prints
    how many relationships of each split are in the database.
    """
    res = gds.run_cypher("MATCH (m) RETURN count(m) as num_nodes")
    num_nodes = res["num_nodes"].values[0]
    if num_nodes > 0:
        print("Data already in db, number of nodes: ", num_nodes)
        return

    total = sum(len(edges) for rels in dataset.values() for edges in rels.values())
    with tqdm(desc="Putting data in db", total=total) as pbar:
        for rels in dataset.values():
            for rel_type, edges in rels.items():
                # Relationship types cannot be query parameters, so the type is
                # interpolated; backtick-quoting keeps unusual names valid Cypher.
                quoted_rel_type = "`" + rel_type.replace("`", "``") + "`"
                gds.run_cypher(
                    f"""
                    UNWIND $ll as l
                    MERGE (n:Entity {{id:l.source, text:l.source_text}})
                    MERGE (m:Entity {{id:l.target, text:l.target_text}})
                    MERGE (n)-[:{quoted_rel_type} {{split: l.rel_split_id, rel_id: l.rel_id}}]->(m)
                    """,
                    params={"ll": edges},
                )
                pbar.update(len(edges))

    # Relationships are typed by relation (REL_*), not by split, so count them
    # by their numeric `split` property instead of by relationship type.
    for rel_split, rels in dataset.items():
        split_id = next((edge["rel_split_id"] for edges in rels.values() for edge in edges), None)
        if split_id is None:
            continue
        res = gds.run_cypher(
            """
            MATCH ()-[r]->()
            WHERE r.split = $split_id
            RETURN count(r) AS numberOfRelationships
            """,
            params={"split_id": split_id},
        )
        print(f"Number of relationships of type {rel_split} in db: ", res["numberOfRelationships"].values[0])


put_data_in_db()

## Project graphs

We project the full graph with all `REL_*` relationship types between `Entity` nodes, keeping the `split` property on every relationship so that train/validation/test membership is available in the projection.

In [None]:
def project_graphs():
    """Project all ``REL_*`` relationship types between ``Entity`` nodes into the GDS graph ``fullGraph``.

    The ``split`` relationship property is included in the projection. Any
    existing in-memory graph named ``fullGraph`` is dropped first, so the
    projection does not collide with one from a previous run.

    Returns:
        The projected ``fullGraph`` graph object.
    """
    rel_type_rows = gds.run_cypher(
        """
            CALL db.relationshipTypes() YIELD relationshipType
        """
    )
    rel_projection = {}
    for rel_type in rel_type_rows["relationshipType"].to_list():
        if not rel_type.startswith("REL_"):
            continue
        rel_projection[rel_type] = {"properties": "split"}

    gds.graph.drop("fullGraph", failIfMissing=False)
    full_graph, _ = gds.graph.project("fullGraph", ["Entity"], rel_projection)
    return full_graph


G = project_graphs()

In [None]:
# Display the relationship types contained in the projected graph `G`.
G.relationship_types()

We will train a knowledge graph embedding model using the Graph Data Science library. The model will be trained on the `G` graph.

We will use the DistMult scoring function and set the embedding dimension to 64. The model will be trained for 30 epochs with a split ratio of 80% for training, 10% for validation, and 10% for testing.

After training the model, we will use it to make predictions for three specific nodes: "brazil", "uk", and "jordan". For each of these nodes and each of the relationship types `REL_RELDIPLOMACY` and `REL_RELNGO`, we will predict the 3 highest-scoring target nodes and print the results.

Finally, we will create new relationships in the graph based on the predicted relationships. For each predicted relationship, we will create a new relationship between the corresponding nodes.

In [None]:
gds.set_compute_cluster_ip("localhost")

# A timestamp suffix gives each run a distinct model name.
model_name = f"dummyModelName_{time.time()}"

gds.kge.model.train(
    G,
    model_name=model_name,
    scoring_function="DistMult",
    num_epochs=30,
    embedding_dimension=64,
    split_ratios={"TRAIN": 0.8, "VALID": 0.1, "TEST": 0.1},
)

# Resolve the node ids of the query entities from their `text` property.
query_node_ids = [gds.find_node_id(["Entity"], {"text": country}) for country in ("brazil", "uk", "jordan")]

predict_result = gds.kge.model.predict(
    model_name=model_name,
    top_k=3,
    node_ids=query_node_ids,
    rel_types=["REL_RELDIPLOMACY", "REL_RELNGO"],
)

print(predict_result.to_string())

In the next cell we will add these top-scored relationships to the database.

In [None]:
# Write every predicted (head, rel, tails) row back to the database as
# NEW_REL_<rel> relationships from the head node to each predicted tail node.
# The head id is passed as a query parameter rather than formatted into the
# query text; int() guards against numpy scalars from the DataFrame, which the
# driver may not serialize. Only the relationship type, which Cypher cannot
# parameterize, is interpolated, and it is backtick-quoted.
for head, rel, tails in zip(predict_result["head"], predict_result["rel"], predict_result["tail"]):
    new_rel_type = "`" + f"NEW_REL_{rel}".replace("`", "``") + "`"
    gds.run_cypher(
        f"""
        UNWIND $tt as t
        MATCH (a:Entity WHERE id(a) = $h)
        MATCH (b:Entity WHERE id(b) = t)
        MERGE (a)-[:{new_rel_type}]->(b)
        """,
        params={"h": int(head), "tt": tails},
    )

There is also an API that can be used to score a list of triplets. In the next cell we use it to score the triplets `(brazil, REL_RELNGO, uk)` and `(brazil, REL_RELDIPLOMACY, jordan)`.

In [None]:
# Resolve the node ids of the entities used in the example triplets, in order.
brazil_node, uk_node, jordan_node = (
    gds.find_node_id(["Entity"], {"text": country}) for country in ("brazil", "uk", "jordan")
)

# Each triplet is (head node id, relationship type, tail node id).
triplets = [
    (brazil_node, "REL_RELNGO", uk_node),
    (brazil_node, "REL_RELDIPLOMACY", jordan_node),
]

scores = gds.kge.model.score_triplets(model_name=model_name, triplets=triplets)

print(scores)