In [160]:
import torch
import pandas as pd
import numpy as np
from torch.nn import Linear
import torch.nn.functional as F
from sentence_transformers import SentenceTransformer
import json

import torch_geometric.transforms as T
from torch_geometric.nn import SAGEConv, to_hetero

from torch_geometric.data import HeteroData
from torch_geometric.transforms import ToUndirected, RandomLinkSplit

# Seed torch's global RNG before any torch-based randomness later in the notebook
# (e.g. RandomLinkSplit, model weight initialisation) so Restart & Run All is reproducible.
SEED = 42
torch.manual_seed(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cpu


In [87]:
from neo4j import GraphDatabase, basic_auth

import os
from getpass import getpass

# Connection settings come from the environment so credentials are never committed
# with the notebook. The password falls back to an interactive prompt when
# NEO4J_PASSWORD is unset.
url = os.environ.get('NEO4J_URI', 'bolt://localhost:7687')
user = os.environ.get('NEO4J_USER', 'neo4j')
password = os.environ.get('NEO4J_PASSWORD') or getpass('Neo4j password: ')

driver = GraphDatabase.driver(url, auth=(user, password))

def fetch_data(query, params=None):
    """Run a Cypher query against Neo4j and return its records as a DataFrame.

    A new session is opened on the module-level ``driver`` for every call.

    Args:
        query: Cypher query string.
        params: optional mapping of query parameters; ``None`` means no parameters.

    Returns:
        pandas.DataFrame with one row per record and the query's result keys as columns.
    """
    # Avoid a shared mutable default: build a fresh dict per call.
    if params is None:
        params = {}
    with driver.session() as session:
        result = session.run(query, params)
        return pd.DataFrame([r.values() for r in result], columns=result.keys())

In [88]:
# gds.graph.project fails if a graph named 'articles' already exists in the catalog,
# so drop any projection left over from a previous run (failIfMissing=false makes this
# a no-op on a fresh database). This keeps the cell re-runnable.
fetch_data("""
CALL gds.graph.drop('articles', false)
""")

# Project Article and Colour nodes joined by IS_COLOUR (treated as undirected) into
# the in-memory graph used by FastRP below.
fetch_data("""
CALL gds.graph.project('articles', ['Article', 'Colour'], 
  {IS_COLOUR: {orientation:'UNDIRECTED'}})
""")

Unnamed: 0,nodeProjection,relationshipProjection,graphName,nodeCount,relationshipCount,projectMillis
0,"{'Article': {'label': 'Article', 'properties':...","{'IS_COLOUR': {'orientation': 'UNDIRECTED', 'a...",articles,7824,15550,118


In [89]:
# Write a 56-dim FastRP embedding to each projected node's `fastrp` property; the
# Article values are later used as model input features. randomSeed fixes the random
# projection so regenerating the features yields the same vectors.
fetch_data("""
CALL gds.fastRP.write('articles', {writeProperty:'fastrp', embeddingDimension:56, randomSeed:42})
""")

Unnamed: 0,nodeCount,nodePropertiesWritten,preProcessingMillis,computeMillis,writeMillis,configuration
0,7824,7824,1,50,631,"{'writeConcurrency': 4, 'nodeSelfInfluence': 0..."


In [90]:
def load_node(cypher, index_col, encoders=None, **kwargs):
    """Load one node type from Neo4j into a PyG-style (features, id mapping) pair.

    Args:
        cypher: Cypher query returning (presumably) one row per node.
        index_col: result column holding each node's external id.
        encoders: optional ``{column: callable}``. Each callable receives that column
            (a pandas Series) and must return a tensor with one row per node; the
            results are concatenated along the last dimension.
        **kwargs: accepted for call-site compatibility; unused.

    Returns:
        (x, mapping): ``x`` is the feature tensor, or ``None`` when no encoders are
        given; ``mapping`` maps each external id to its contiguous row index in ``x``.
    """
    df = fetch_data(cypher).set_index(index_col)
    # Keep only the first row per id so feature rows line up 1:1 with ``mapping``
    # (previously duplicate ids produced more feature rows than mapping entries).
    df = df[~df.index.duplicated(keep='first')]
    mapping = {index: i for i, index in enumerate(df.index)}

    x = None
    if encoders:  # also treats an empty dict as "no features" instead of crashing torch.cat
        xs = [encoder(df[col]) for col, encoder in encoders.items()]
        x = torch.cat(xs, dim=-1)

    return x, mapping

In [91]:
def load_edge(cypher, src_index_col, src_mapping, dst_index_col, dst_mapping,
              encoders=None, **kwargs):
    """Load one edge type from Neo4j as a ``[2, num_edges]`` index plus optional attributes.

    Args:
        cypher: Cypher query returning one row per edge.
        src_index_col / dst_index_col: result columns holding the endpoint ids.
        src_mapping / dst_mapping: external id -> contiguous node index, as produced
            by ``load_node``.
        encoders: optional ``{column: callable}`` producing edge attributes, concatenated
            along the last dimension.
        **kwargs: accepted for call-site compatibility; unused.

    Returns:
        (edge_index, edge_attr): ``edge_index`` is a long tensor of shape [2, E];
        ``edge_attr`` is ``None`` when no encoders are given.

    Raises:
        KeyError: if an endpoint id is missing from its mapping (e.g. the node query
            filtered that node out).
    """
    df = fetch_data(cypher)
    src = [src_mapping[index] for index in df[src_index_col]]
    dst = [dst_mapping[index] for index in df[dst_index_col]]
    # Explicit dtype: with zero edges torch would otherwise infer float32, which PyG
    # cannot use as an index.
    edge_index = torch.tensor([src, dst], dtype=torch.long)

    edge_attr = None
    if encoders:
        # NOTE(review): torch.cat(dim=-1) on several 1-D encoder outputs would join them
        # end-to-end rather than column-wise; fine with the single encoder used here.
        edge_attrs = [encoder(df[col]) for col, encoder in encoders.items()]
        edge_attr = torch.cat(edge_attrs, dim=-1)

    return edge_index, edge_attr

In [92]:
class SequenceEncoder(object):
    """Embed a column of text with a SentenceTransformer model.

    Args:
        model_name: SentenceTransformer checkpoint to load.
        device: forwarded both to the SentenceTransformer constructor and to ``encode``.
    """

    def __init__(self, model_name='all-MiniLM-L6-v2', device=None):
        self.model = SentenceTransformer(model_name, device=device)
        self.device = device

    @torch.no_grad()
    def __call__(self, df):
        """Encode every entry of ``df`` and return the embeddings as a CPU tensor."""
        embeddings = self.model.encode(
            df.values,
            convert_to_tensor=True,
            show_progress_bar=True,
            device=self.device,
        )
        return embeddings.cpu()

In [93]:
class GenresEncoder(object):
    """Multi-hot encode a column of ``sep``-delimited category strings.

    Categories are sorted before being assigned columns. Iterating a ``set`` of
    strings depends on per-process hash randomisation, so the previous unsorted order
    could change between kernels and silently permute the feature columns.

    Args:
        sep: delimiter between categories inside one cell (default ``'|'``).
    """

    def __init__(self, sep='|'):
        self.sep = sep

    def __call__(self, df):
        """Return a float tensor of shape (len(df), n_categories) with 1 where present."""
        rows = [value.split(self.sep) for value in df.values]
        genres = sorted({genre for row in rows for genre in row})
        mapping = {genre: i for i, genre in enumerate(genres)}

        x = torch.zeros(len(df), len(mapping))
        for i, row in enumerate(rows):
            for genre in row:
                x[i, mapping[genre]] = 1
        return x

In [94]:
class IdentityEncoder(object):
    """Convert a column of numbers, or of equal-length numeric sequences, to a tensor as-is.

    Args:
        dtype: optional torch dtype to cast the result to; ``None`` keeps the inferred dtype.
            Now honoured for list columns too (it was previously ignored there).
        is_list: set when each cell holds a sequence (e.g. an embedding vector); the
            per-row tensors are stacked along a new first dimension.
    """

    def __init__(self, dtype=None, is_list=False):
        self.dtype = dtype
        self.is_list = is_list

    def __call__(self, df):
        if self.is_list:
            x = torch.stack([torch.tensor(el) for el in df.values])
        else:
            x = torch.from_numpy(df.values)
        return x if self.dtype is None else x.to(self.dtype)

In [95]:
customer_query = """
MATCH (c:Customer) RETURN c.customer_id AS customerId
"""

# No encoders are passed, so customer_x is None; only the id -> row mapping is used.
customer_x, customer_mapping = load_node(customer_query, 'customerId')

In [96]:
article_query = """
MATCH (a:Article)-[:IN_SECTION]->(section:Section)
WITH a, collect(section.sectionName) AS section_list
RETURN a.article_id AS article_id, a.articleName AS articleName, apoc.text.join(section_list, '|') AS sections, a.fastrp AS fastrp
"""
# NOTE(review): the MATCH drops articles with no IN_SECTION edge; a BOUGHT edge to such an
# article would raise KeyError later in load_edge.

# Article features, concatenated in this column order:
#   articleName -> sentence embedding, sections ('|'-joined) -> multi-hot, fastrp -> raw vector.
article_encoders = {
    'articleName': SequenceEncoder(),
    'sections': GenresEncoder(),
    'fastrp': IdentityEncoder(is_list=True),
}
article_x, article_mapping = load_node(article_query, index_col='article_id', encoders=article_encoders)

Batches:   0%|          | 0/243 [00:00<?, ?it/s]

In [136]:
count_query = """
MATCH (c:Customer)-[b:BOUGHT]->(a:Article) 
RETURN c.customer_id AS customerId, a.article_id AS articleId, b.count AS bCount
"""

# edge_label is the purchase count of each BOUGHT edge, used as the regression target.
edge_index, edge_label = load_edge(
    count_query,
    src_index_col='customerId', src_mapping=customer_mapping,
    dst_index_col='articleId', dst_mapping=article_mapping,
    encoders={'bCount': IdentityEncoder(dtype=torch.long)},
)

In [145]:
data = HeteroData()
# Customers have no attributes, so each one gets a one-hot identity feature vector.
data['customer'].x = torch.eye(len(customer_mapping), device=device)
data['article'].x = article_x

bought = data['customer', 'bought', 'article']
bought.edge_index = edge_index
bought.edge_label = edge_label
data.to(device, non_blocking=True)

print(edge_label)

tensor([1, 1, 1,  ..., 1, 1, 1])


In [146]:
# 1. Add reverse ('article', 'rev_bought', 'customer') edges so messages flow both ways.
#    The reverse edges carry no regression target, so drop their copied edge_label.
data = ToUndirected()(data)
del data['article', 'rev_bought', 'customer'].edge_label

# 2. Perform a link-level split into training, validation, and test edges (80/10/10).
#    neg_sampling_ratio=0.0: only observed purchases become targets, no negatives.
transform = RandomLinkSplit(
    edge_types=[('customer', 'bought', 'article')],
    rev_edge_types=[('article', 'rev_bought', 'customer')],
    num_val=0.1,
    num_test=0.1,
    neg_sampling_ratio=0.0,
)
train_data, val_data, test_data = transform(data)
print(train_data)
print(train_data['customer', 'article'])

HeteroData(
  [1mcustomer[0m={ x=[9867, 9867] },
  [1marticle[0m={ x=[7775, 494] },
  [1m(customer, bought, article)[0m={
    edge_index=[2, 8000],
    edge_label=[8000],
    edge_label_index=[2, 8000]
  },
  [1m(article, rev_bought, customer)[0m={ edge_index=[2, 8000] }
)
{'edge_index': tensor([[2931, 2585, 2449,  ..., 3020, 9426, 9559],
        [7432, 3634, 4517,  ..., 7071,  915, 7088]]), 'edge_label': tensor([1, 1, 1,  ..., 1, 1, 1]), 'edge_label_index': tensor([[2931, 2585, 2449,  ..., 3020, 9426, 9559],
        [7432, 3634, 4517,  ..., 7071,  915, 7088]])}


In [147]:
class GNNEncoder(torch.nn.Module):
    """Two stacked SAGEConv layers with a ReLU between them (none after the last).

    Both convolutions use ``(-1, -1)`` as input size, so source/target feature widths
    are inferred lazily on the first forward pass. This lets ``to_hetero`` reuse the
    module for node types whose feature widths differ.
    """

    def __init__(self, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = SAGEConv((-1, -1), hidden_channels)
        self.conv2 = SAGEConv((-1, -1), out_channels)

    def forward(self, x, edge_index):
        hidden = self.conv1(x, edge_index).relu()
        return self.conv2(hidden, edge_index)


class EdgeDecoder(torch.nn.Module):
    """Score (customer, article) pairs with a 2-layer MLP over concatenated embeddings."""

    def __init__(self, hidden_channels):
        super().__init__()
        # Input is a customer embedding and an article embedding placed side by side.
        self.lin1 = Linear(2 * hidden_channels, hidden_channels)
        self.lin2 = Linear(hidden_channels, 1)

    def forward(self, z_dict, edge_label_index):
        """Return one scalar score per column of ``edge_label_index`` (flattened)."""
        customer_idx, article_idx = edge_label_index
        pair = torch.cat(
            (z_dict['customer'][customer_idx], z_dict['article'][article_idx]),
            dim=-1,
        )
        hidden = self.lin1(pair).relu()
        return self.lin2(hidden).view(-1)

class Model(torch.nn.Module):
    """Heterogeneous GNN encoder plus MLP edge decoder for (customer, article) pairs.

    Args:
        hidden_channels: width of both encoder layers and of the decoder's hidden layer.
        metadata: ``(node_types, edge_types)`` as returned by ``HeteroData.metadata()``.
            Defaults to the notebook-global ``data.metadata()`` (the previous implicit
            behaviour); pass it explicitly to avoid depending on that global.
    """

    def __init__(self, hidden_channels, metadata=None):
        super().__init__()
        if metadata is None:
            metadata = data.metadata()
        encoder = GNNEncoder(hidden_channels, hidden_channels)
        # Replicate the homogeneous encoder per node/edge type; 'sum' combines the
        # messages a node type receives over different edge types.
        self.encoder = to_hetero(encoder, metadata, aggr='sum')
        self.decoder = EdgeDecoder(hidden_channels)

    def forward(self, x_dict, edge_index_dict, edge_label_index):
        """Return one score per column of ``edge_label_index``."""
        z_dict = self.encoder(x_dict, edge_index_dict)
        return self.decoder(z_dict, edge_label_index)

In [148]:
print(train_data['customer', 'article'].edge_label)

# Inverse-frequency weight per integer label value k: weight[k] = max_count / count(k).
# bincount reports 0 for values that never occur in the training split; give those
# weight 0 instead of inf. Values that do occur keep exactly their previous weight.
label_counts = torch.bincount(train_data['customer', 'article'].edge_label)
weight = (label_counts.max() / label_counts.clamp(min=1)) * (label_counts > 0)

def weighted_mse_loss(pred, target, weight=None):
    """Mean squared error, optionally weighted per sample by the value of its target.

    Args:
        pred: float predictions.
        target: integer targets aligned with ``pred``; also used to index ``weight``.
        weight: optional 1-D tensor where ``weight[k]`` weighs samples whose target is
            ``k``; ``None`` gives every sample weight 1.

    Returns:
        Scalar tensor.
    """
    squared_error = (pred - target.to(pred.dtype)) ** 2
    if weight is None:
        return squared_error.mean()
    per_sample_weight = weight[target].to(pred.dtype)
    return (per_sample_weight * squared_error).mean()

tensor([1, 1, 1,  ..., 1, 1, 1])


In [149]:
HIDDEN_CHANNELS = 64  # width of both GNN layers and of the decoder's hidden layer
model = Model(hidden_channels=HIDDEN_CHANNELS)
model = model.to(device)

In [150]:
# The SAGEConv layers were built with input size (-1, -1), i.e. lazily: their weight
# shapes are only known after one forward pass. Run that pass (without tracking
# gradients) before handing model.parameters() to the optimiser.
with torch.no_grad():
    model.encoder(train_data.x_dict, train_data.edge_index_dict)

LEARNING_RATE = 0.01
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)


In [151]:
def train():
    """Run one full-batch optimisation step on ``train_data``; return the loss as a float."""
    model.train()
    optimizer.zero_grad()
    bought = train_data['customer', 'bought', 'article']
    predictions = model(train_data.x_dict, train_data.edge_index_dict,
                        bought.edge_label_index)
    loss = weighted_mse_loss(predictions, bought.edge_label, weight)
    loss.backward()
    optimizer.step()
    return float(loss)


In [152]:
@torch.no_grad()
def test(data, min_value=0, max_value=5):
    """Return the RMSE of the model's predictions on the target edges of ``data``.

    Previously ``torch.no_grad()`` stood on its own line, which is a no-op, so every
    evaluation built an autograd graph. It is now applied as a decorator.

    Args:
        data: a split produced by RandomLinkSplit (train/val/test).
        min_value, max_value: predictions are clamped to this range before scoring.
            NOTE(review): the 0..5 defaults look inherited from a rating setup; the
            targets here are purchase counts (bCount), which may exceed 5 — confirm.

    Returns:
        RMSE as a Python float.
    """
    model.eval()
    bought = data['customer', 'bought', 'article']
    pred = model(data.x_dict, data.edge_index_dict, bought.edge_label_index)
    pred = pred.clamp(min=min_value, max=max_value)
    target = bought.edge_label.float()
    rmse = F.mse_loss(pred, target).sqrt()
    return float(rmse)

In [153]:
# NOTE(review): range(1, 300) runs 299 epochs — confirm 300 was not intended.
# The test split is scored every epoch for visibility only; select models on Val.
for epoch in range(1, 300):
    loss = train()
    train_rmse, val_rmse, test_rmse = (
        test(split) for split in (train_data, val_data, test_data)
    )
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Train: {train_rmse:.4f}, '
          f'Val: {val_rmse:.4f}, Test: {test_rmse:.4f}')

Epoch: 001, Loss: 0.9623, Train: 0.5815, Val: 0.6836, Test: 0.6791
Epoch: 002, Loss: 0.3381, Train: 0.7375, Val: 0.2444, Test: 0.2441
Epoch: 003, Loss: 0.5440, Train: 0.1685, Val: 0.4378, Test: 0.4283
Epoch: 004, Loss: 0.0284, Train: 0.4527, Val: 0.5599, Test: 0.5562
Epoch: 005, Loss: 0.2049, Train: 0.4954, Val: 0.5694, Test: 0.5667
Epoch: 006, Loss: 0.2454, Train: 0.4290, Val: 0.5203, Test: 0.5172
Epoch: 007, Loss: 0.1841, Train: 0.2893, Val: 0.4279, Test: 0.4235
Epoch: 008, Loss: 0.0837, Train: 0.1026, Val: 0.2971, Test: 0.2909
Epoch: 009, Loss: 0.0105, Train: 0.2317, Val: 0.1643, Test: 0.1596
Epoch: 010, Loss: 0.0537, Train: 0.3342, Val: 0.1223, Test: 0.1225
Epoch: 011, Loss: 0.1117, Train: 0.2189, Val: 0.1381, Test: 0.1336
Epoch: 012, Loss: 0.0479, Train: 0.0544, Val: 0.2046, Test: 0.1985
Epoch: 013, Loss: 0.0030, Train: 0.1206, Val: 0.2705, Test: 0.2654
Epoch: 014, Loss: 0.0145, Train: 0.1937, Val: 0.3070, Test: 0.3027
Epoch: 015, Loss: 0.0375, Train: 0.2146, Val: 0.3124, Test: 0.

Epoch: 124, Loss: 0.0000, Train: 0.0010, Val: 0.0358, Test: 0.0343
Epoch: 125, Loss: 0.0000, Train: 0.0010, Val: 0.0358, Test: 0.0343
Epoch: 126, Loss: 0.0000, Train: 0.0009, Val: 0.0359, Test: 0.0344
Epoch: 127, Loss: 0.0000, Train: 0.0008, Val: 0.0361, Test: 0.0346
Epoch: 128, Loss: 0.0000, Train: 0.0008, Val: 0.0364, Test: 0.0349
Epoch: 129, Loss: 0.0000, Train: 0.0008, Val: 0.0365, Test: 0.0350
Epoch: 130, Loss: 0.0000, Train: 0.0008, Val: 0.0365, Test: 0.0351
Epoch: 131, Loss: 0.0000, Train: 0.0008, Val: 0.0364, Test: 0.0350
Epoch: 132, Loss: 0.0000, Train: 0.0007, Val: 0.0363, Test: 0.0348
Epoch: 133, Loss: 0.0000, Train: 0.0007, Val: 0.0361, Test: 0.0346
Epoch: 134, Loss: 0.0000, Train: 0.0007, Val: 0.0359, Test: 0.0345
Epoch: 135, Loss: 0.0000, Train: 0.0007, Val: 0.0359, Test: 0.0344
Epoch: 136, Loss: 0.0000, Train: 0.0007, Val: 0.0360, Test: 0.0345
Epoch: 137, Loss: 0.0000, Train: 0.0007, Val: 0.0361, Test: 0.0347
Epoch: 138, Loss: 0.0000, Train: 0.0007, Val: 0.0363, Test: 0.

Epoch: 247, Loss: 0.0000, Train: 0.0003, Val: 0.0362, Test: 0.0347
Epoch: 248, Loss: 0.0000, Train: 0.0003, Val: 0.0362, Test: 0.0347
Epoch: 249, Loss: 0.0000, Train: 0.0003, Val: 0.0362, Test: 0.0347
Epoch: 250, Loss: 0.0000, Train: 0.0003, Val: 0.0362, Test: 0.0347
Epoch: 251, Loss: 0.0000, Train: 0.0003, Val: 0.0362, Test: 0.0347
Epoch: 252, Loss: 0.0000, Train: 0.0003, Val: 0.0362, Test: 0.0347
Epoch: 253, Loss: 0.0000, Train: 0.0003, Val: 0.0362, Test: 0.0347
Epoch: 254, Loss: 0.0000, Train: 0.0003, Val: 0.0362, Test: 0.0347
Epoch: 255, Loss: 0.0000, Train: 0.0003, Val: 0.0362, Test: 0.0347
Epoch: 256, Loss: 0.0000, Train: 0.0003, Val: 0.0362, Test: 0.0347
Epoch: 257, Loss: 0.0000, Train: 0.0003, Val: 0.0362, Test: 0.0347
Epoch: 258, Loss: 0.0000, Train: 0.0003, Val: 0.0362, Test: 0.0347
Epoch: 259, Loss: 0.0000, Train: 0.0003, Val: 0.0362, Test: 0.0347
Epoch: 260, Loss: 0.0000, Train: 0.0003, Val: 0.0362, Test: 0.0347
Epoch: 261, Loss: 0.0000, Train: 0.0003, Val: 0.0362, Test: 0.

In [161]:
# NOTE(review): `results` is not created by any cell in this notebook, so this raises
# NameError on a fresh kernel (it presumably came from a deleted cell). TODO: define it.
# Serialise before opening so a failure does not leave a truncated convert.txt behind.
results_json = json.dumps(results)
with open('convert.txt', 'w') as convert_file:
    convert_file.write(results_json)

In [230]:
# Persist only the learned parameter tensors (state_dict), not the module object.
torch.save(obj=model.state_dict(), f="model")

In [194]:
# Node feature tensors of the full graph (after ToUndirected), keyed by node type.
torch.save(obj=data.x_dict, f='x_dict.pt')

In [197]:
# Edge indices of the full graph, keyed by (src_type, relation, dst_type).
torch.save(obj=data.edge_index_dict, f='edge_index_dict.pt')

In [226]:
# NOTE(review): `edge_label_index` is not defined by any cell in this notebook, so this
# raises NameError on a fresh kernel — confirm which split's edge_label_index was meant.
torch.save(obj=edge_label_index, f='edge_label_index.pt')

In [199]:
# NOTE(review): `reverse_article_mapping` is not created by any cell in this notebook, so
# this raises NameError on a fresh kernel. TODO: define it (presumably index -> article id).
# Serialise before opening so a failure (e.g. values json cannot encode) does not leave
# a truncated file behind. Note json turns every dict key into a string.
reverse_article_mapping_json = json.dumps(reverse_article_mapping)
with open('reverse_article_mapping.txt', 'w') as convert_file:
    convert_file.write(reverse_article_mapping_json)

In [200]:
# NOTE(review): `reverse_customer_mapping` is not created by any cell in this notebook, so
# this raises NameError on a fresh kernel. TODO: define it (presumably index -> customer id).
# Serialise before opening so a failure (e.g. values json cannot encode) does not leave
# a truncated file behind. Note json turns every dict key into a string.
reverse_customer_mapping_json = json.dumps(reverse_customer_mapping)
with open('reverse_customer_mapping.txt', 'w') as convert_file:
    convert_file.write(reverse_customer_mapping_json)

In [203]:
# Serialise before opening so a failure (e.g. ids json cannot encode) does not leave a
# truncated file behind. Note json turns every dict key into a string.
customer_mapping_json = json.dumps(customer_mapping)
with open('customer_mapping.txt', 'w') as convert_file:
    convert_file.write(customer_mapping_json)

In [217]:
# (node_types, edge_types) of the graph; needed to rebuild the to_hetero encoder.
torch.save(obj=data.metadata(), f='metadata.pt')