In [1]:
from neo4j import GraphDatabase
from neo4j.exceptions import ClientError
from sklearn.manifold import TSNE

import numpy as np
import altair as alt
import pandas as pd
import os

from module.neo4j.graph_db import GraphDB
driver = GraphDatabase.driver("bolt://localhost:7687", auth=("neo4j", "erclab"))

In [2]:
result = {"label": [], "count": []}
with driver.session() as session:
    for row in session.run("CALL db.labels()"):
        label = row["label"]
        query = f"MATCH (:`{label}`) RETURN count(*) as count"
        count = session.run(query).single()["count"]
        result["label"].append(label)
        result["count"].append(count)
nodes_df = pd.DataFrame(data=result)
nodes_df.sort_values("count")

result = {"relType": [], "count": []}
with driver.session() as session:
    for row in session.run("CALL db.relationshipTypes()"):
        relationship_type = row["relationshipType"]
        query = f"MATCH ()-[:`{relationship_type}`]->() RETURN count(*) as count"
        count = session.run(query).single()["count"]
        result["relType"].append(relationship_type)
        result["count"].append(count)
rels_df = pd.DataFrame(data=result)
rels_df.sort_values("count")

print(nodes_df)
print(rels_df)

        label  count
0        User    541
1  Restaurant     20
2      Review    555
3        City     14
4     Country      2
5        Attr    106
6      Aspect     40
7        Menu     82
          relType  count
0      HAS_FRIEND     14
1      HAS_REVIEW    555
2    WRITE_REVIEW    555
3           VISIT    555
4            RATE    551
5      LOCATED_IN     34
6      HAS_ASPECT     40
7              IS    130
8        HAS_MENU     82
9           ENJOY      1
10          ORDER     13
11           LOVE      2
12           WANT      1
13       PAID_FOR      1
14           ROLL      1
15           LIKE      2
16       COME_FOR      1
17          BRING      1
18        PROMISE      1
19            EAT      1
20    DISASSEMBLE      1
21         EXPECT      1
22          CHECK      1
23  OBSESSED_WITH      1
24        WENT_TO      1
25        ENJOYED      1
26           SPOT      1


In [3]:
with driver.session() as session:
    result = session.run("""
    CALL gds.alpha.node2vec.stream({
       nodeProjection: "Restaurant",
       relationshipProjection: {
         has_menu: {
           type: "HAS_MENU",
           orientation: "UNDIRECTED"
        }
       },
       embeddingSize: 10,
       iterations: 10,
       walkLength: 10
    })
    YIELD nodeId, embedding
    RETURN gds.util.asNode(nodeId).name AS restaurant, embedding
    """)

    embeddings_df = pd.DataFrame([dict(record) for record in result])
embeddings_df.head(20)

Unnamed: 0,restaurant,embedding
0,Silver Spoon,"[0.703555166721344, 0.8387281894683838, -0.235..."
1,Silver Spoon,"[0.0020898347720503807, -0.4304209053516388, 0..."
2,Silver Spoon,"[0.843287467956543, 0.12618207931518555, 0.472..."
3,Silver Spoon,"[-0.17489485442638397, -0.7488159537315369, 0...."
4,Spice South,"[-0.06723348796367645, 0.8675433993339539, 0.1..."
5,Bombay Buffet Indian Cuisine,"[-0.34533876180648804, 0.08820705860853195, -0..."
6,Canbe Foods,"[-0.3411579132080078, 0.3910974860191345, -0.9..."
7,Red Chillez,"[-0.7224925756454468, 0.23916402459144592, -0...."
8,Lena's Roti & Doubles,"[-0.4543846845626831, -0.21689669787883759, -0..."
9,Tangra Villa Hakka Chinese Cuisine,"[0.47712406516075134, 0.2918386459350586, 0.77..."


In [4]:
with driver.session() as session:
    result = session.run("""
    CALL gds.alpha.node2vec.write({
       nodeProjection: "Restaurant",
       relationshipProjection: {
         has_menu: {
           type: "HAS_MENU",
           orientation: "UNDIRECTED"
        }
       },
       embeddingSize: 10,
       iterations: 10,
       walkLength: 10,
       writeProperty: $embeddingProperty
    })
    """, {"embeddingProperty": "embeddingNode2vec"})

    embeddings_df = pd.DataFrame([dict(record) for record in result])
embeddings_df


Unnamed: 0,nodeCount,nodePropertiesWritten,createMillis,computeMillis,writeMillis,configuration
0,20,20,2,10,62,"{'initialLearningRate': 0.025, 'writeConcurren..."


In [5]:
with driver.session() as session:
    result = session.run("""
    MATCH (rest:Restaurant)-[:HAS_MENU]->(menu:Menu)
    WHERE menu.name IN $menus
    RETURN rest.name AS restaurant, rest.embeddingNode2vec AS embedding, menu.name AS menu
    """, {"menus": ["Naan", "Chicken Biryani", "Butter Chicken"]})
    X = pd.DataFrame([dict(record) for record in result])
X.head(20)

Unnamed: 0,restaurant,embedding,menu
0,Silver Spoon,"[0.5603567957878113, -0.4897201359272003, 0.00...",Chicken Biryani
1,Silver Spoon,"[0.15289141237735748, -0.8653870224952698, 0.9...",Naan
2,Silver Spoon,"[0.15289141237735748, -0.8653870224952698, 0.9...",Butter Chicken
3,Silver Spoon,"[0.15289141237735748, -0.8653870224952698, 0.9...",Chicken Biryani
4,Silver Spoon,"[-0.8681433200836182, 0.09290854632854462, -0....",Naan
5,Bombay Buffet Indian Cuisine,"[-0.08530911803245544, -0.19394470751285553, -...",Chicken Biryani
6,Bombay Buffet Indian Cuisine,"[-0.08530911803245544, -0.19394470751285553, -...",Butter Chicken
7,Bombay Buffet Indian Cuisine,"[-0.08530911803245544, -0.19394470751285553, -...",Naan
8,Soma Grill,"[0.3156459331512451, -0.7007921934127808, 0.23...",Naan
9,Watan Kabob,"[0.5198988914489746, -0.9050036668777466, 0.82...",Naan


In [6]:
X_embedded = TSNE(n_components=2, random_state=6).fit_transform(list(X.embedding))
restaurants = list(X.restaurant)
df = pd.DataFrame(data = {
    "restaurant": restaurants,
    "menu": X.menu,
    "x": [value[0] for value in list(X_embedded)],
    "y": [value[1] for value in list(X_embedded)]
})
df.head(20)

Unnamed: 0,restaurant,menu,x,y
0,Silver Spoon,Chicken Biryani,23.985943,-90.52182
1,Silver Spoon,Naan,196.981461,-133.632812
2,Silver Spoon,Butter Chicken,42.126427,311.980164
3,Silver Spoon,Chicken Biryani,-120.773132,-58.216709
4,Silver Spoon,Naan,-236.107742,92.292038
5,Bombay Buffet Indian Cuisine,Chicken Biryani,63.579052,-249.224533
6,Bombay Buffet Indian Cuisine,Butter Chicken,209.350006,190.939804
7,Bombay Buffet Indian Cuisine,Naan,113.392563,23.721502
8,Soma Grill,Naan,-64.235985,71.692924
9,Watan Kabob,Naan,-114.598076,-222.779633


In [7]:
chart = alt.Chart(df).mark_circle(size=60).encode(
    x='x',
    y='y',
    tooltip=['restaurant', 'menu']
).properties(width=700, height=400)
chart.save('node2vec.json')
chart