In [7]:
from graphdatascience import GraphDataScience
from torch_geometric.data import Data, download_url
import torch

In [2]:
gds = GraphDataScience("bolt://localhost:7687", auth=('neo4j', 'neo4jneo4j'), database="fb15k-237")

In [3]:
url = ('https://raw.githubusercontent.com/villmow/datasets_knowledge_embedding/master/FB15k-237')
raw_file_names = ['train.txt', 'valid.txt', 'test.txt']
raw_dir = './data_from_url'
for filename in raw_file_names:
    download_url(f'{url}/{filename}', raw_dir)

Downloading https://raw.githubusercontent.com/villmow/datasets_knowledge_embedding/master/FB15k-237/train.txt
Downloading https://raw.githubusercontent.com/villmow/datasets_knowledge_embedding/master/FB15k-237/valid.txt
Downloading https://raw.githubusercontent.com/villmow/datasets_knowledge_embedding/master/FB15k-237/test.txt


In [8]:
def process():
    data_list_, node_dict_, rel_dict_ = [], {}, {}
    for file_name in raw_file_names:
        file_name_path = raw_dir + '/' + file_name
        with open(file_name_path, 'r') as f:
            data = [x.split('\t') for x in f.read().split('\n')[:-1]]

        edge_index = torch.empty((2, len(data)), dtype=torch.long)
        edge_type = torch.empty(len(data), dtype=torch.long)
        for i, (src, rel, dst) in enumerate(data):
            if src not in node_dict_:
                node_dict_[src] = len(node_dict_)
            if dst not in node_dict_:
                node_dict_[dst] = len(node_dict_)
            if rel not in rel_dict_:
                rel_dict_[rel] = len(rel_dict_)

            edge_index[0, i] = node_dict_[src]
            edge_index[1, i] = node_dict_[dst]
            edge_type[i] = rel_dict_[rel]

        data = Data(edge_index=edge_index, edge_type=edge_type)
        data_list_.append(data)

    for data in data_list_:
        data.num_nodes = len(node_dict_)

    return data_list_, node_dict_, rel_dict_

data_list, node_dict, rel_dict = process()

In [9]:
gds.run_cypher("CREATE CONSTRAINT entity_id FOR (e:Entity) REQUIRE e.id IS UNIQUE")

ClientError: {code: Neo.ClientError.Schema.EquivalentSchemaRuleAlreadyExists} {message: An equivalent constraint already exists, 'Constraint( id=4, name='entity_id', type='UNIQUENESS', schema=(:Entity {id}), ownedIndex=3 )'.}

In [10]:
rel_id_to_text_dict = {}
for k in rel_dict:
    text = k
    id = rel_dict[k]
    rel_id_to_text_dict[id] = text

In [11]:
print(data_list)
print(data_list[0].edge_index[0][1].item())
print(data_list[0].edge_type)

[Data(edge_index=[2, 272115], edge_type=[272115], num_nodes=14541), Data(edge_index=[2, 17535], edge_type=[17535], num_nodes=14541), Data(edge_index=[2, 20466], edge_type=[20466], num_nodes=14541)]
2
tensor([  0,   1,   2,  ..., 170,  30,  38])


In [12]:
def write_chunk(chunk_dict):
    gds.run_cypher(
            "UNWIND $nodes AS node CREATE (n:Entity {id: node[1], value: node[0]})",
            params={"nodes": list(chunk_dict.items())},
        )
    print(f"Written {len(chunk_dict)} elements...")

idx = 0
chunk_size = 1000
chunk_dict = {}
for k in node_dict:
    chunk_dict[k] = node_dict[k]
    idx += 1
    if idx % chunk_size == 0:
        write_chunk(chunk_dict)
        chunk_dict = {}
if len(chunk_dict) > 0:
    write_chunk(chunk_dict)
print(f"TOTAL records: {idx} from {len(node_dict)}")

Written 1000 elements...
Written 1000 elements...
Written 1000 elements...
Written 1000 elements...
Written 1000 elements...
Written 1000 elements...
Written 1000 elements...
Written 1000 elements...
Written 1000 elements...
Written 1000 elements...
Written 1000 elements...
Written 1000 elements...
Written 1000 elements...
Written 1000 elements...
Written 541 elements...
TOTAL records: 14541 from 14541


In [13]:
train_data = data_list[0]
val_data = data_list[1]
test_data = data_list[2]

In [14]:
def write_rel_chunk(ll:list, label):
    gds.run_cypher(
            "UNWIND $list AS l MATCH (e_s:Entity {id: l[0]}), (e_t:Entity {id: l[1]}) "+
            "CREATE (e_s)-["+label+" { rel_id: l[2], text: l[3] }]->(e_t)",
            params={"list": ll},
        )
    print(f"Written {len(ll)} elements...")


def create_rels(data:Data, label:str):
    idx = 0
    chunk_size = 1000
    chunk_list = []
    print("Writing " + label + " relationships")
    for i in range(data.num_edges):
        source = data.edge_index[0, i].item()
        target = data.edge_index[1, i].item()
        id = data.edge_type[i].item()
        text = rel_id_to_text_dict[id]
        l = [source, target, id, text]
        chunk_list.append(l)
        idx += 1
        if idx % chunk_size == 0:
            write_rel_chunk(chunk_list, label)
            chunk_list = []
    if len(chunk_list) > 0:
        write_rel_chunk(chunk_list, label)
    print(f"TOTAL records: {idx} from {data.num_edges}")

create_rels(test_data, ":TEST")
create_rels(val_data, ":VAL")
create_rels(train_data, ":TRAIN")

Writing :TEST relationships
Written 1000 elements...
Written 1000 elements...
Written 1000 elements...
Written 1000 elements...
Written 1000 elements...
Written 1000 elements...
Written 1000 elements...
Written 1000 elements...
Written 1000 elements...
Written 1000 elements...
Written 1000 elements...
Written 1000 elements...
Written 1000 elements...
Written 1000 elements...
Written 1000 elements...
Written 1000 elements...
Written 1000 elements...
Written 1000 elements...
Written 1000 elements...
Written 1000 elements...
Written 466 elements...
TOTAL records: 20466 from 20466
Writing :VAL relationships
Written 1000 elements...
Written 1000 elements...
Written 1000 elements...
Written 1000 elements...
Written 1000 elements...
Written 1000 elements...
Written 1000 elements...
Written 1000 elements...
Written 1000 elements...
Written 1000 elements...
Written 1000 elements...
Written 1000 elements...
Written 1000 elements...
Written 1000 elements...
Written 1000 elements...
Written 1000 e