# Dataset Construction

In this notebook, we construct the data set that will be basis for our experiments. As mentioned in the paper, we are working with the medical domain, more specifically the relations between medical conditions (diseases) and treatments (drugs). We query [Wikidata's SPARQL endpoints](https://query.wikidata.org) using the [SPARQLWrapper library](https://rdflib.github.io/sparqlwrapper/) and process the data with a combintation of [rdflib](https://rdflib.readthedocs.io/en/stable/) and [networkx](http://networkx.github.io).

In [1]:
import rdflib.graph
import networkx as nx
import os.path
import hashlib
import json
import SPARQLWrapper
import datetime as dt

First, we set up a connection to the SPARQL endpoint that we will use throughout. Furhtermore, the helper function `query` executes a SPARQL query, converts it to an `rdflib` graph and caches the result. It it only possible to execute a [limitted number of queries to Wikidata](https://www.mediawiki.org/wiki/Wikidata_Query_Service/User_Manual#Query_limits) so this caching functionality saves as many requests to the endpoint as possible (with an invalidation period of one day).

In [2]:
sparql = SPARQLWrapper.SPARQLWrapper("https://query.wikidata.org/sparql")

cache_dir = "_query_cache"
cache_dict = f"{cache_dir}/_dict.json"

cache_invalidation_periode = dt.timedelta(days=1)


def query(query: str, force=False) -> rdflib.Graph:
    # Ensure that the cache directory exists.
    if not os.path.isfile(cache_dict):
        os.makedirs(cache_dir)
        json.dump({}, open(cache_dict, "w"))

    # Load the cache.
    cache = json.load(open(cache_dict))

    # The MD5 of the query acts as the cache key.
    key = hashlib.md5(query.encode("utf-8")).hexdigest()

    # Is the key is present in the cache, load and return the cached data.
    if not force and key in cache:
        cache_entry = cache[key]
        created_at = dt.datetime.fromtimestamp(cache_entry["timestamp"])

        # If the cache entry is not invalidated, return the cached data.
        if dt.datetime.now() < created_at + cache_invalidation_periode:
            return rdflib.graph.Graph().parse(cache_entry["path"])

    # Execute SPARLQL query.
    sparql.setQuery(query)
    result = sparql.queryAndConvert()

    # Save the response to the cache.
    dest = f"{cache_dir}/{key}.xml"
    result.serialize(dest)
    cache[key] = {"timestamp": dt.datetime.now().timestamp(), "path": dest}
    json.dump(cache, open(cache_dict, "w"))

    return result

We construct our data as a directed graph using the networkx class `DiGraph`.

In [3]:
graph = nx.DiGraph()

## Diseases

The following two queries fetch [instances (P31)](https://www.wikidata.org/wiki/Property:P31) and [subclasses (P279)](https://www.wikidata.org/wiki/Property:P279) of the [disease (Q12136)](https://www.wikidata.org/wiki/Q12136) entity. This will give the graph a hierarcical structure that suits the HAKE link prediction technique. First, we fetch entites that are direct disease instances and add them to the graph with a "is-a" edge to the "disease" node.

In [4]:
disease_instances_query = """
    CONSTRUCT
    WHERE {
        ?d wdt:P31 wd:Q12136.
    }
"""

disease_instances_result = query(disease_instances_query)

In [5]:
for subj, obj in disease_instances_result.subject_objects():
    graph.add_edge(subj, "disease", type="is-a")

Next, we get all instances of entites that are subclasses of the dieases entity. These are added to the graph with a "is-a" edge to the subclass which in turn has a "subclass-of" edge to the "disease" node.

In [6]:
disease_subclasses_query = """
    CONSTRUCT {
      ?disease wdt:P31 ?subclass.
    }
    WHERE {
      ?subclass wdt:P279 wd:Q12136.
      ?disease wdt:P31 ?subclass.
}
"""

disease_subclasses_result = query(disease_subclasses_query)

In [7]:
for subj, obj in disease_subclasses_result.subject_objects():
    graph.add_edge(subj, obj, type="is-a")
    graph.add_edge(obj, "disease", type="subclass-of")

## Medication

In the same way as above, we create a subclass-instance hierarchy for medications based on the [medication (Q12140)](https://www.wikidata.org/wiki/Q12140) entity.


In [8]:
medication_instances_query = """
    CONSTRUCT
    WHERE {
        ?d wdt:P31 wd:Q12140.
    }
"""

medication_instances_result = query(medication_instances_query)

In [9]:
for subj, obj in medication_instances_result.subject_objects():
    graph.add_edge(subj, "meddication", type="is-a")

In [10]:
medication_subclasses_query = """
    CONSTRUCT {
      ?disease wdt:P31 ?subclass.
    }
    WHERE {
      ?subclass wdt:P279 wd:Q12140.
      ?disease wdt:P31 ?subclass.
}
"""

medication_subclasses_result = query(medication_subclasses_query)

In [11]:
for subj, obj in medication_subclasses_result.subject_objects():
    graph.add_edge(subj, obj, type="is-a")
    graph.add_edge(obj, "medication", type="subclass-of")

## Treaments

Now we add the relations between diseases and medications starting with the [drug used for treatment (P2176)](https://www.wikidata.org/wiki/Property:P2176) predicate. If, for some reason, we encounter a medication or disease that is not present in the graph, we ignore it.

In [12]:
treatments_query = """
    CONSTRUCT 
    WHERE {
      ?disease wdt:P2176 ?medication.
    }
"""

treatments_result = query(treatments_query)

In [13]:
for subj, obj in treatments_result.subject_objects():
    if subj not in graph.nodes() or obj not in graph.nodes():
        continue

    graph.add_edge(subj, obj, type="treated-with")

## Symptoms

We repeat the process above for symptoms querying for the [symptoms (P780)](https://www.wikidata.org/wiki/Property:P780) predicate. Since this predicate is not strictly related to diseases in Wikidata, we skip the triples where the subject is not already in the graph.

In [14]:
symptoms_query = """
    CONSTRUCT
    WHERE {
        ?d wdt:P780 ?s.
    }
"""

symptoms_result = query(symptoms_query)

In [15]:
for subj, obj in symptoms_result.subject_objects():
    if subj not in graph.nodes():
        continue

    graph.add_edge(subj, obj, type="has-symptom")
    graph.add_edge(obj, "symptom", type="is-a")

### Causes

Using the [has cause (P828)](https://www.wikidata.org/wiki/Property:P828) predicate, we do the same as above.

In [16]:
causes_query = """
    CONSTRUCT
    WHERE {
        ?d wdt:P828 ?s.
    }
"""

causes_result = query(causes_query)

The result is added to the graph in the same way as before

In [17]:
for subj, obj in causes_result.subject_objects():
    if subj not in graph.nodes():
        continue

    graph.add_edge(subj, obj, type="has-cause")
    graph.add_edge(obj, "cause", type="is-a")

## Graph report

In [18]:
print(f"{graph.number_of_nodes()} nodes")
    
print(f"{graph.number_of_edges()} edges")

for type in set(nx.get_edge_attributes(graph, 'type').values()):
    number = len(
        [(head, tail) for head, tail, data in graph.edges(data=True) if data["type"] == type]
    )
    print(f"{number} {type} edges")

20422 nodes
34216 edges
1001 has-cause edges
26676 is-a edges
1433 has-symptom edges
87 subclass-of edges
5019 treated-with edges


Now, we can save the graph to the file system and load it up elsewhere.

In [19]:
nx.write_gml(graph, "data.gml")

## Conversion to a HAKE-Friendly Format

The HAKE link prediction model expects a certain data format. First of all, the data must be represented as `(head, relation, tail)` triples and be split up into a training, validation, and testing dataset. Second, the entities and relations must be defined in a dictionary structure that maps ther string representation to an integer.

In [20]:
graph = nx.read_gml("data.gml")

In [21]:
entities = set(graph.nodes())
entity_dict = {entity: n for n, entity in enumerate(entities)}

relations = set([data["type"] for _head, _tail, data in graph.edges(data=True)])
relation_dict = {relation: n for n, relation in enumerate(relations)}

In [22]:
triples = [(head, data["type"], tail) for head, tail, data in graph.edges(data=True)]

Now, to split the data, we define a proportion of the triples to be used for validation and testing data. The remaining triples will be used for training.

In [23]:
validation_proportion = 0.1
testing_proportion = 0.1

In [24]:
validation_set_size = round(len(triples) * validation_proportion)
test_set_size = round(len(triples) * testing_proportion)
training_set_size = len(triples) - validation_set_size - test_set_size

training_set = triples[:training_set_size]
validation_set = triples[training_set_size:training_set_size + validation_set_size]
test_set = triples[training_set_size + validation_set_size:]

Finally, we write the entity and relation dictionaries as well as the datasets to disk. The resulting files can be passed directly into the HAKE model.

In [25]:
directory = "hake_data"

if not os.path.isdir(directory):
    os.makedirs(directory)
    
def write_triples(triples, file):
    with open(file, "w") as f:
        for head, relation, tail in triples:
            string = "{}\t{}\t{}\n".format(head, relation, tail)
            f.write(string)

def write_dict(dictionary, file):
    with open(file, "w") as f:
        for key, value in dictionary.items():
            string = "{}\t{}\n".format(value, key)
            f.write(string)
    
write_triples(training_set, f"{directory}/train.txt")
write_triples(validation_set, f"{directory}/valid.txt")
write_triples(test_set, f"{directory}/test.txt")
write_dict(entity_dict, f"{directory}/entities.dict")
write_dict(relation_dict, f"{directory}/relations.dict")