In [35]:
from graphdatascience import GraphDataScience
from torch_geometric.data import Data, download_url
import torch
import torch.optim as optim
from torch_geometric.nn import TransE

In [2]:
gds = GraphDataScience("bolt://localhost:7687", auth=('neo4j', 'neo4jneo4j'), database="fb15k-237")

In [3]:
url = ('https://raw.githubusercontent.com/villmow/datasets_knowledge_embedding/master/FB15k-237')
raw_file_names = ['train.txt', 'valid.txt', 'test.txt']
raw_dir = './data_from_url'
for filename in raw_file_names:
    download_url(f'{url}/{filename}', raw_dir)

Using existing file train.txt
Using existing file valid.txt
Using existing file test.txt


In [4]:
def process():
    data_list_, node_dict_, rel_dict_ = [], {}, {}
    for file_name in raw_file_names:
        file_name_path = raw_dir + '/' + file_name
        with open(file_name_path, 'r') as f:
            data = [x.split('\t') for x in f.read().split('\n')[:-1]]

        edge_index = torch.empty((2, len(data)), dtype=torch.long)
        edge_type = torch.empty(len(data), dtype=torch.long)
        for i, (src, rel, dst) in enumerate(data):
            if src not in node_dict_:
                node_dict_[src] = len(node_dict_)
            if dst not in node_dict_:
                node_dict_[dst] = len(node_dict_)
            if rel not in rel_dict_:
                rel_dict_[rel] = len(rel_dict_)

            edge_index[0, i] = node_dict_[src]
            edge_index[1, i] = node_dict_[dst]
            edge_type[i] = rel_dict_[rel]

        data = Data(edge_index=edge_index, edge_type=edge_type)
        data_list_.append(data)

    for data in data_list_:
        data.num_nodes = len(node_dict_)

    return data_list_, node_dict_, rel_dict_

data_list, node_dict, rel_dict = process()

In [5]:
gds.run_cypher("CREATE CONSTRAINT entity_id FOR (e:Entity) REQUIRE e.id IS UNIQUE")

ClientError: {code: Neo.ClientError.Schema.EquivalentSchemaRuleAlreadyExists} {message: An equivalent constraint already exists, 'Constraint( id=4, name='entity_id', type='UNIQUENESS', schema=(:Entity {id}), ownedIndex=3 )'.}

In [6]:
rel_id_to_text_dict = {}
for k in rel_dict:
    text = k
    id = rel_dict[k]
    rel_id_to_text_dict[id] = text

In [7]:
print(data_list)
print(data_list[0].edge_index[0][1].item())
print(data_list[0].edge_type)

[Data(edge_index=[2, 272115], edge_type=[272115], num_nodes=14541), Data(edge_index=[2, 17535], edge_type=[17535], num_nodes=14541), Data(edge_index=[2, 20466], edge_type=[20466], num_nodes=14541)]
2
tensor([  0,   1,   2,  ..., 170,  30,  38])


In [8]:
def write_chunk(chunk_dict):
    gds.run_cypher(
            "UNWIND $nodes AS node CREATE (n:Entity {id: node[1], value: node[0]})",
            params={"nodes": list(chunk_dict.items())},
        )
    print(f"Written {len(chunk_dict)} elements...")

idx = 0
chunk_size = 1000
chunk_dict = {}
for k in node_dict:
    chunk_dict[k] = node_dict[k]
    idx += 1
    if idx % chunk_size == 0:
        write_chunk(chunk_dict)
        chunk_dict = {}
if len(chunk_dict) > 0:
    write_chunk(chunk_dict)
print(f"TOTAL records: {idx} from {len(node_dict)}")

Written 1000 elements...
Written 1000 elements...
Written 1000 elements...
Written 1000 elements...
Written 1000 elements...
Written 1000 elements...
Written 1000 elements...
Written 1000 elements...
Written 1000 elements...
Written 1000 elements...
Written 1000 elements...
Written 1000 elements...
Written 1000 elements...
Written 1000 elements...
Written 541 elements...
TOTAL records: 14541 from 14541


In [9]:
train_data = data_list[0]
val_data = data_list[1]
test_data = data_list[2]

In [10]:
def write_rel_chunk(ll:list, label):
    gds.run_cypher(
            "UNWIND $list AS l MATCH (e_s:Entity {id: l[0]}), (e_t:Entity {id: l[1]}) "+
            "CREATE (e_s)-["+label+" { rel_id: l[2], text: l[3] }]->(e_t)",
            params={"list": ll},
        )
    print(f"Written {len(ll)} elements...")


def create_rels(data:Data, label:str):
    idx = 0
    chunk_size = 1000
    chunk_list = []
    print("Writing " + label + " relationships")
    for i in range(data.num_edges):
        source = data.edge_index[0, i].item()
        target = data.edge_index[1, i].item()
        id = data.edge_type[i].item()
        text = rel_id_to_text_dict[id]
        l = [source, target, id, text]
        chunk_list.append(l)
        idx += 1
        if idx % chunk_size == 0:
            write_rel_chunk(chunk_list, label)
            chunk_list = []
    if len(chunk_list) > 0:
        write_rel_chunk(chunk_list, label)
    print(f"TOTAL records: {idx} from {data.num_edges}")

create_rels(test_data, ":TEST")
create_rels(val_data, ":VAL")
create_rels(train_data, ":TRAIN")

Writing :TEST relationships
Written 1000 elements...
Written 1000 elements...
Written 1000 elements...
Written 1000 elements...
Written 1000 elements...
Written 1000 elements...
Written 1000 elements...
Written 1000 elements...
Written 1000 elements...
Written 1000 elements...
Written 1000 elements...
Written 1000 elements...
Written 1000 elements...
Written 1000 elements...
Written 1000 elements...
Written 1000 elements...
Written 1000 elements...
Written 1000 elements...
Written 1000 elements...
Written 1000 elements...
Written 466 elements...
TOTAL records: 20466 from 20466
Writing :VAL relationships
Written 1000 elements...
Written 1000 elements...
Written 1000 elements...
Written 1000 elements...
Written 1000 elements...
Written 1000 elements...
Written 1000 elements...
Written 1000 elements...
Written 1000 elements...
Written 1000 elements...
Written 1000 elements...
Written 1000 elements...
Written 1000 elements...
Written 1000 elements...
Written 1000 elements...
Written 1000 e

In [11]:
# Node: (:Entity {id:int, value:str})
# Edge: [:(TRAIN|TEST|VAL) {rel_id:int, text:str}]

In [13]:
def get_data_from_db(edge_label):
    node_projection = {"Entity": {"properties": "id"}}
    relationship_projection = {edge_label : {"orientation": "NATURAL", "properties": "rel_id"}}
    G, result = gds.graph.project(
        "fb15k-graph-t"+edge_label,
        node_projection,
        relationship_projection,
    )
    print(f"The projection took {result['projectMillis']} ms")

    # We can use convenience methods on `G` to check if the projection looks correct
    print(f"Graph '{G.name()}' node count: {G.node_count()}")
    print(f"Graph '{G.name()}' node labels: {G.node_labels()}")
    print(f"Graph '{G.name()}' relationship count: {G.relationship_count()}")

    return G

def get_whole_dataset():
    node_projection = {"Entity": {"properties": "id"}}
    relationship_projection = {
        "TRAIN" : {"orientation": "NATURAL", "properties": "rel_id"},
        "TEST" : {"orientation": "NATURAL", "properties": "rel_id"},
        "VAL" : {"orientation": "NATURAL", "properties": "rel_id"},
    }
    G, result = gds.graph.project(
        "fb15k-graph-whole",
        node_projection,
        relationship_projection,
    )
    print(f"The projection took {result['projectMillis']} ms")

    # We can use convenience methods on `G` to check if the projection looks correct
    print(f"Graph '{G.name()}' node count: {G.node_count()}")
    print(f"Graph '{G.name()}' node labels: {G.node_labels()}")
    print(f"Graph '{G.name()}' relationship count: {G.relationship_count()}")

    return G

In [14]:
train_db_data_G = get_data_from_db("TRAIN")
test_db_data_G = get_data_from_db("TEST")
val_db_data_G = get_data_from_db("VAL")
db_data_G = get_whole_dataset()

The projection took 47 ms
Graph 'fb15k-graph-tTRAIN' node count: 14541
Graph 'fb15k-graph-tTRAIN' node labels: ['Entity']
Graph 'fb15k-graph-tTRAIN' relationship count: 272115
The projection took 9 ms
Graph 'fb15k-graph-tTEST' node count: 14541
Graph 'fb15k-graph-tTEST' node labels: ['Entity']
Graph 'fb15k-graph-tTEST' relationship count: 20466
The projection took 22 ms
Graph 'fb15k-graph-tVAL' node count: 14541
Graph 'fb15k-graph-tVAL' node labels: ['Entity']
Graph 'fb15k-graph-tVAL' relationship count: 17535
The projection took 51 ms
Graph 'fb15k-graph-whole' node count: 14541
Graph 'fb15k-graph-whole' node labels: ['Entity']
Graph 'fb15k-graph-whole' relationship count: 310116


In [None]:
# gds.graph.drop(train_db_data_G)
# gds.graph.drop(test_db_data_G)
# gds.graph.drop(val_db_data_G)

In [15]:
print(db_data_G)

Graph(name=fb15k-graph-whole, node_count=14541, relationship_count=310116)


In [23]:
node_properties = gds.graph.nodeProperties.stream(
    db_data_G,
    ["id"],
    separate_property_columns=True,
)
print(node_properties)

       nodeId    id
0       10000  9695
1       10001  9696
2       10002  9697
3       10003  9698
4       10004  9699
...       ...   ...
14536    9995  9690
14537    9996  9691
14538    9997  9692
14539    9998  9693
14540    9999  9694

[14541 rows x 2 columns]


In [24]:
nodeId_to_id = dict(zip(node_properties.nodeId, node_properties.id))

In [25]:
sample_topology_df = gds.beta.graph.relationships.stream(db_data_G)
# Let's see what we got:
display(sample_topology_df)

Unnamed: 0,sourceNodeId,targetNodeId,relationshipType
0,0,811,VAL
1,0,2664,VAL
2,0,4671,VAL
3,0,9126,VAL
4,0,2,TRAIN
...,...,...,...
310111,9999,3249,TRAIN
310112,9999,5257,TRAIN
310113,9999,6475,TRAIN
310114,9999,11890,TRAIN


In [28]:
rels_tmp = gds.graph.relationshipProperties.stream(db_data_G, ["rel_id"], separate_property_columns=True)
display(rels_tmp)
rels_tmp.rel_id.astype(int)
display(rels_tmp.sourceNodeId.map(lambda x: nodeId_to_id[x]))

Unnamed: 0,sourceNodeId,targetNodeId,relationshipType,rel_id
0,0,811,VAL,118.0
1,0,2664,VAL,195.0
2,0,4671,VAL,0.0
3,0,9126,VAL,23.0
4,0,2,TRAIN,0.0
...,...,...,...,...
310111,9999,3249,TRAIN,15.0
310112,9999,5257,TRAIN,154.0
310113,9999,6475,TRAIN,94.0
310114,9999,11890,TRAIN,62.0


KeyboardInterrupt: 

In [33]:
topology = [rels_tmp.sourceNodeId.map(lambda x: nodeId_to_id[x]), rels_tmp.targetNodeId.map(lambda x: nodeId_to_id[x])]
edge_index = torch.tensor(topology, dtype=torch.long)
edge_type = torch.tensor(rels_tmp.rel_id.astype(int), dtype=torch.long)
display(edge_index)
display(edge_type)
data = Data(edge_index=edge_index, edge_type=edge_type)
data.num_nodes = len(nodeId_to_id)
print(data)

tensor([[    0,     0,     0,  ...,  9694,  9694,  9694],
        [  774,  2550,  4494,  ...,  6225, 11585, 11585]])

tensor([118, 195,   0,  ...,  94,  62, 127])

Data(edge_index=[2, 310116], edge_type=[310116], num_nodes=14541)


In [34]:
def create_tensor(graph):
    rels_tmp = gds.graph.relationshipProperties.stream(graph, ["rel_id"], separate_property_columns=True)
    topology = [rels_tmp.sourceNodeId.map(lambda x: nodeId_to_id[x]), rels_tmp.targetNodeId.map(lambda x: nodeId_to_id[x])]
    edge_index = torch.tensor(topology, dtype=torch.long)
    edge_type = torch.tensor(rels_tmp.rel_id.astype(int), dtype=torch.long)
    data = Data(edge_index=edge_index, edge_type=edge_type)
    data.num_nodes = len(nodeId_to_id)
    display(data)
    return data

train_tensor = create_tensor(train_db_data_G)
test_tensor = create_tensor(test_db_data_G)
val_tensor = create_tensor(val_db_data_G)

Data(edge_index=[2, 272115], edge_type=[272115], num_nodes=14541)

Data(edge_index=[2, 20466], edge_type=[20466], num_nodes=14541)

Data(edge_index=[2, 17535], edge_type=[17535], num_nodes=14541)

In [39]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

model = TransE(
    num_nodes=train_tensor.num_nodes,
    num_relations=train_tensor.num_edge_types,
    hidden_channels=50,
).to(device)

loader = model.loader(
    head_index=train_tensor.edge_index[0],
    rel_type=train_tensor.edge_type,
    tail_index=train_tensor.edge_index[1],
    batch_size=1000,
    shuffle=True,
)

optimizer = optim.Adam(model.parameters(), lr=0.01)


def train():
    model.train()
    total_loss = total_examples = 0
    for head_index, rel_type, tail_index in loader:
        optimizer.zero_grad()
        loss = model.loss(head_index, rel_type, tail_index)
        loss.backward()
        optimizer.step()
        total_loss += float(loss) * head_index.numel()
        total_examples += head_index.numel()
    return total_loss / total_examples


@torch.no_grad()
def test(data):
    model.eval()
    return model.test(
        head_index=data.edge_index[0],
        rel_type=data.edge_type,
        tail_index=data.edge_index[1],
        batch_size=20000,
        k=10,
    )


for epoch in range(1, 501):
    loss = train()
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}')
    if epoch % 75 == 0:
        rank, hits = test(val_tensor)
        print(f'Epoch: {epoch:03d}, Val Mean Rank: {rank:.2f}, '
              f'Val Hits@10: {hits:.4f}')

print(model)
idx = torch.LongTensor([1])
print(model.rel_emb(idx))
rank, hits_at_10 = test(test_tensor)
print(f'Test Mean Rank: {rank:.2f}, Test Hits@10: {hits_at_10:.4f}')

Epoch: 001, Loss: 0.7623
Epoch: 002, Loss: 0.5531
Epoch: 003, Loss: 0.4303
Epoch: 004, Loss: 0.3451
Epoch: 005, Loss: 0.2911
Epoch: 006, Loss: 0.2590
Epoch: 007, Loss: 0.2375
Epoch: 008, Loss: 0.2226
Epoch: 009, Loss: 0.2116
Epoch: 010, Loss: 0.2028
Epoch: 011, Loss: 0.1965
Epoch: 012, Loss: 0.1890
Epoch: 013, Loss: 0.1841
Epoch: 014, Loss: 0.1798
Epoch: 015, Loss: 0.1763
Epoch: 016, Loss: 0.1721
Epoch: 017, Loss: 0.1685
Epoch: 018, Loss: 0.1664
Epoch: 019, Loss: 0.1640
Epoch: 020, Loss: 0.1609
Epoch: 021, Loss: 0.1586
Epoch: 022, Loss: 0.1569
Epoch: 023, Loss: 0.1537
Epoch: 024, Loss: 0.1526
Epoch: 025, Loss: 0.1515
Epoch: 026, Loss: 0.1492
Epoch: 027, Loss: 0.1467
Epoch: 028, Loss: 0.1466
Epoch: 029, Loss: 0.1445
Epoch: 030, Loss: 0.1434
Epoch: 031, Loss: 0.1422
Epoch: 032, Loss: 0.1408
Epoch: 033, Loss: 0.1403
Epoch: 034, Loss: 0.1377
Epoch: 035, Loss: 0.1367
Epoch: 036, Loss: 0.1356
Epoch: 037, Loss: 0.1335
Epoch: 038, Loss: 0.1318
Epoch: 039, Loss: 0.1317
Epoch: 040, Loss: 0.1309


100%|██████████| 17535/17535 [03:00<00:00, 97.29it/s] 


Epoch: 075, Val Mean Rank: 337.85, Val Hits@10: 0.3734
Epoch: 076, Loss: 0.1081
Epoch: 077, Loss: 0.1064
Epoch: 078, Loss: 0.1056
Epoch: 079, Loss: 0.1063
Epoch: 080, Loss: 0.1054
Epoch: 081, Loss: 0.1052
Epoch: 082, Loss: 0.1060
Epoch: 083, Loss: 0.1043
Epoch: 084, Loss: 0.1051
Epoch: 085, Loss: 0.1050
Epoch: 086, Loss: 0.1038
Epoch: 087, Loss: 0.1035
Epoch: 088, Loss: 0.1033
Epoch: 089, Loss: 0.1031
Epoch: 090, Loss: 0.1027
Epoch: 091, Loss: 0.1018
Epoch: 092, Loss: 0.1019
Epoch: 093, Loss: 0.1014
Epoch: 094, Loss: 0.1012
Epoch: 095, Loss: 0.1012
Epoch: 096, Loss: 0.1016
Epoch: 097, Loss: 0.1012
Epoch: 098, Loss: 0.1005
Epoch: 099, Loss: 0.0996
Epoch: 100, Loss: 0.1004
Epoch: 101, Loss: 0.0996
Epoch: 102, Loss: 0.0995
Epoch: 103, Loss: 0.0992
Epoch: 104, Loss: 0.0991
Epoch: 105, Loss: 0.0993
Epoch: 106, Loss: 0.0982
Epoch: 107, Loss: 0.0983
Epoch: 108, Loss: 0.0982
Epoch: 109, Loss: 0.0971
Epoch: 110, Loss: 0.0975
Epoch: 111, Loss: 0.0973
Epoch: 112, Loss: 0.0975
Epoch: 113, Loss: 0.

100%|██████████| 17535/17535 [02:59<00:00, 97.49it/s] 


Epoch: 150, Val Mean Rank: 289.27, Val Hits@10: 0.3764
Epoch: 151, Loss: 0.0908
Epoch: 152, Loss: 0.0914
Epoch: 153, Loss: 0.0913
Epoch: 154, Loss: 0.0918
Epoch: 155, Loss: 0.0917
Epoch: 156, Loss: 0.0915
Epoch: 157, Loss: 0.0906
Epoch: 158, Loss: 0.0905
Epoch: 159, Loss: 0.0902
Epoch: 160, Loss: 0.0897
Epoch: 161, Loss: 0.0901
Epoch: 162, Loss: 0.0903
Epoch: 163, Loss: 0.0903
Epoch: 164, Loss: 0.0891
Epoch: 165, Loss: 0.0894
Epoch: 166, Loss: 0.0897
Epoch: 167, Loss: 0.0906
Epoch: 168, Loss: 0.0892
Epoch: 169, Loss: 0.0882
Epoch: 170, Loss: 0.0900
Epoch: 171, Loss: 0.0897
Epoch: 172, Loss: 0.0890
Epoch: 173, Loss: 0.0895
Epoch: 174, Loss: 0.0887
Epoch: 175, Loss: 0.0880
Epoch: 176, Loss: 0.0889
Epoch: 177, Loss: 0.0891
Epoch: 178, Loss: 0.0886
Epoch: 179, Loss: 0.0882
Epoch: 180, Loss: 0.0875
Epoch: 181, Loss: 0.0887
Epoch: 182, Loss: 0.0884
Epoch: 183, Loss: 0.0889
Epoch: 184, Loss: 0.0874
Epoch: 185, Loss: 0.0885
Epoch: 186, Loss: 0.0881
Epoch: 187, Loss: 0.0877
Epoch: 188, Loss: 0.

100%|██████████| 17535/17535 [03:00<00:00, 97.36it/s] 


Epoch: 225, Val Mean Rank: 272.78, Val Hits@10: 0.3770
Epoch: 226, Loss: 0.0854
Epoch: 227, Loss: 0.0859
Epoch: 228, Loss: 0.0858
Epoch: 229, Loss: 0.0854
Epoch: 230, Loss: 0.0851
Epoch: 231, Loss: 0.0857
Epoch: 232, Loss: 0.0849
Epoch: 233, Loss: 0.0856
Epoch: 234, Loss: 0.0858
Epoch: 235, Loss: 0.0851
Epoch: 236, Loss: 0.0848
Epoch: 237, Loss: 0.0858
Epoch: 238, Loss: 0.0850
Epoch: 239, Loss: 0.0852
Epoch: 240, Loss: 0.0859
Epoch: 241, Loss: 0.0853
Epoch: 242, Loss: 0.0845
Epoch: 243, Loss: 0.0854
Epoch: 244, Loss: 0.0853
Epoch: 245, Loss: 0.0863
Epoch: 246, Loss: 0.0847
Epoch: 247, Loss: 0.0849
Epoch: 248, Loss: 0.0848
Epoch: 249, Loss: 0.0843
Epoch: 250, Loss: 0.0840
Epoch: 251, Loss: 0.0854
Epoch: 252, Loss: 0.0842
Epoch: 253, Loss: 0.0853
Epoch: 254, Loss: 0.0852
Epoch: 255, Loss: 0.0840
Epoch: 256, Loss: 0.0846
Epoch: 257, Loss: 0.0847
Epoch: 258, Loss: 0.0852
Epoch: 259, Loss: 0.0844
Epoch: 260, Loss: 0.0846
Epoch: 261, Loss: 0.0848
Epoch: 262, Loss: 0.0850
Epoch: 263, Loss: 0.

100%|██████████| 17535/17535 [03:00<00:00, 97.25it/s] 


Epoch: 300, Val Mean Rank: 271.31, Val Hits@10: 0.3601
Epoch: 301, Loss: 0.0826
Epoch: 302, Loss: 0.0825
Epoch: 303, Loss: 0.0831
Epoch: 304, Loss: 0.0829
Epoch: 305, Loss: 0.0832
Epoch: 306, Loss: 0.0816
Epoch: 307, Loss: 0.0832
Epoch: 308, Loss: 0.0838
Epoch: 309, Loss: 0.0826
Epoch: 310, Loss: 0.0823
Epoch: 311, Loss: 0.0816
Epoch: 312, Loss: 0.0833
Epoch: 313, Loss: 0.0822
Epoch: 314, Loss: 0.0820
Epoch: 315, Loss: 0.0828
Epoch: 316, Loss: 0.0826
Epoch: 317, Loss: 0.0832
Epoch: 318, Loss: 0.0824
Epoch: 319, Loss: 0.0826
Epoch: 320, Loss: 0.0827
Epoch: 321, Loss: 0.0816
Epoch: 322, Loss: 0.0832
Epoch: 323, Loss: 0.0825
Epoch: 324, Loss: 0.0821
Epoch: 325, Loss: 0.0823
Epoch: 326, Loss: 0.0823
Epoch: 327, Loss: 0.0818
Epoch: 328, Loss: 0.0824
Epoch: 329, Loss: 0.0820
Epoch: 330, Loss: 0.0822
Epoch: 331, Loss: 0.0817
Epoch: 332, Loss: 0.0824
Epoch: 333, Loss: 0.0820
Epoch: 334, Loss: 0.0815
Epoch: 335, Loss: 0.0820
Epoch: 336, Loss: 0.0815
Epoch: 337, Loss: 0.0828
Epoch: 338, Loss: 0.

100%|██████████| 17535/17535 [02:54<00:00, 100.26it/s]


Epoch: 375, Val Mean Rank: 265.51, Val Hits@10: 0.3694
Epoch: 376, Loss: 0.0811
Epoch: 377, Loss: 0.0811
Epoch: 378, Loss: 0.0813
Epoch: 379, Loss: 0.0813
Epoch: 380, Loss: 0.0808
Epoch: 381, Loss: 0.0818
Epoch: 382, Loss: 0.0808
Epoch: 383, Loss: 0.0813
Epoch: 384, Loss: 0.0813
Epoch: 385, Loss: 0.0814
Epoch: 386, Loss: 0.0814
Epoch: 387, Loss: 0.0806
Epoch: 388, Loss: 0.0809
Epoch: 389, Loss: 0.0811
Epoch: 390, Loss: 0.0807
Epoch: 391, Loss: 0.0800
Epoch: 392, Loss: 0.0806
Epoch: 393, Loss: 0.0806
Epoch: 394, Loss: 0.0808
Epoch: 395, Loss: 0.0806
Epoch: 396, Loss: 0.0808
Epoch: 397, Loss: 0.0804
Epoch: 398, Loss: 0.0802
Epoch: 399, Loss: 0.0808
Epoch: 400, Loss: 0.0801
Epoch: 401, Loss: 0.0797
Epoch: 402, Loss: 0.0803
Epoch: 403, Loss: 0.0811
Epoch: 404, Loss: 0.0804
Epoch: 405, Loss: 0.0803
Epoch: 406, Loss: 0.0809
Epoch: 407, Loss: 0.0805
Epoch: 408, Loss: 0.0809
Epoch: 409, Loss: 0.0815
Epoch: 410, Loss: 0.0804
Epoch: 411, Loss: 0.0810
Epoch: 412, Loss: 0.0812
Epoch: 413, Loss: 0.

100%|██████████| 17535/17535 [02:57<00:00, 98.73it/s]


Epoch: 450, Val Mean Rank: 263.85, Val Hits@10: 0.3717
Epoch: 451, Loss: 0.0803
Epoch: 452, Loss: 0.0799
Epoch: 453, Loss: 0.0792
Epoch: 454, Loss: 0.0793
Epoch: 455, Loss: 0.0805
Epoch: 456, Loss: 0.0790
Epoch: 457, Loss: 0.0797
Epoch: 458, Loss: 0.0795
Epoch: 459, Loss: 0.0800
Epoch: 460, Loss: 0.0797
Epoch: 461, Loss: 0.0806
Epoch: 462, Loss: 0.0793
Epoch: 463, Loss: 0.0800
Epoch: 464, Loss: 0.0798
Epoch: 465, Loss: 0.0801
Epoch: 466, Loss: 0.0792
Epoch: 467, Loss: 0.0804
Epoch: 468, Loss: 0.0793
Epoch: 469, Loss: 0.0790
Epoch: 470, Loss: 0.0798
Epoch: 471, Loss: 0.0803
Epoch: 472, Loss: 0.0794
Epoch: 473, Loss: 0.0799
Epoch: 474, Loss: 0.0801
Epoch: 475, Loss: 0.0793
Epoch: 476, Loss: 0.0800
Epoch: 477, Loss: 0.0795
Epoch: 478, Loss: 0.0797
Epoch: 479, Loss: 0.0793
Epoch: 480, Loss: 0.0789
Epoch: 481, Loss: 0.0796
Epoch: 482, Loss: 0.0791
Epoch: 483, Loss: 0.0796
Epoch: 484, Loss: 0.0787
Epoch: 485, Loss: 0.0793
Epoch: 486, Loss: 0.0798
Epoch: 487, Loss: 0.0791
Epoch: 488, Loss: 0.

100%|██████████| 20466/20466 [03:31<00:00, 96.58it/s] 


Test Mean Rank: 271.21, Test Hits@10: 0.3561


In [41]:
torch.save(model, "./model_501.pt")

In [42]:
torch.save(model.state_dict(), "./model_501_st_dict")

In [46]:
rel_123 = 123
print(rel_id_to_text_dict[rel_123])
t = model.rel_emb(torch.LongTensor([rel_123]))
print(t)
rel_123_emb = t[0].tolist()
print(rel_123_emb)
print(len(rel_123_emb))

/base/popstra/celebrity/canoodled./base/popstra/canoodled/participant
tensor([[-0.1675,  0.0308, -0.1782,  0.1345, -0.0045,  0.0340,  0.0159, -0.3096,
         -0.0745, -0.0816, -0.1337, -0.0214,  0.2474,  0.0208,  0.0202, -0.2063,
          0.2327,  0.1317, -0.0562,  0.2252,  0.2233, -0.0494,  0.0689, -0.1421,
         -0.1578, -0.2542, -0.3715, -0.0021, -0.1611,  0.2791,  0.4206, -0.0711,
         -0.1338, -0.0318, -0.0076, -0.0900, -0.0040,  0.0135, -0.2083,  0.3131,
         -0.2107, -0.1513, -0.2384, -0.0324,  0.2208, -0.3213,  0.1846,  0.0350,
          0.0129, -0.3476]], grad_fn=<EmbeddingBackward0>)
[-0.1674661487340927, 0.030775610357522964, -0.17823708057403564, 0.13445991277694702, -0.004469443578273058, 0.03396214172244072, 0.015884431079030037, -0.30958986282348633, -0.07452814280986786, -0.08162254840135574, -0.1337374746799469, -0.021421784535050392, 0.24736957252025604, 0.020815536379814148, 0.020227091386914253, -0.2062523514032364, 0.23273737728595734, 0.1317358613014

In [56]:
print(model.node_emb.weight[777].tolist())


[-0.3172686696052551, 0.1578812301158905, -0.08087439835071564, -0.004596684593707323, -19.778369903564453, 0.026383597403764725, 0.13941125571727753, 0.12102761119604111, -12.489096641540527, 0.010141335427761078, -0.030836760997772217, 0.07849875837564468, 0.00281003350391984, -15.232320785522461, -0.15866008400917053, 0.19668158888816833, -0.19905561208724976, -0.5277314782142639, -0.08227643370628357, -0.17938721179962158, 0.015473323874175549, -0.06681554019451141, 0.11557577550411224, 0.05651519075036049, 0.2853202223777771, -0.05212356895208359, -0.1935097575187683, 0.27248385548591614, 0.0724930614233017, 0.0949799194931984, 0.08945981413125992, 0.04437510669231415, -0.1985391527414322, -18.364181518554688, -0.048563919961452484, -0.08124208450317383, -0.25457608699798584, -0.03215474635362625, -0.48660868406295776, -0.04439295083284378, 0.584082841873169, 0.403334379196167, 0.058653656393289566, 10.785226821899414, 0.1640355885028839, -0.1641729325056076, -0.07735151797533035,

In [59]:
# write embeddings to graph
for i in range(0, len(nodeId_to_id)):
    if i % 100 == 0:
        print(i)
    gds.run_cypher(
            "MATCH (n:Entity {id: $i}) SET n.emb = $EMBEDDING",
            params={
                "i": i,
                "EMBEDDING": model.node_emb.weight[i].tolist()
            },
        )

0
100
200
300
400
500
600
700
800
900
1000
1100
1200
1300
1400
1500
1600
1700
1800
1900
2000
2100
2200
2300
2400
2500
2600
2700
2800
2900
3000
3100
3200
3300
3400
3500
3600
3700
3800
3900
4000
4100
4200
4300
4400
4500
4600
4700
4800
4900
5000
5100
5200
5300
5400
5500
5600
5700
5800
5900
6000
6100
6200
6300
6400
6500
6600
6700
6800
6900
7000
7100
7200
7300
7400
7500
7600
7700
7800
7900
8000
8100
8200
8300
8400
8500
8600
8700
8800
8900
9000
9100
9200
9300
9400
9500
9600
9700
9800
9900
10000
10100
10200
10300
10400
10500
10600
10700
10800
10900
11000
11100
11200
11300
11400
11500
11600
11700
11800
11900
12000
12100
12200
12300
12400
12500
12600
12700
12800
12900
13000
13100
13200
13300
13400
13500
13600
13700
13800
13900
14000
14100
14200
14300
14400
14500
