In [52]:
import csv
import pandas as pd
from py2neo import Graph
import json

### Connect to the Neo4j database

In [9]:
PATH_CONNECTION = "../env/neo4j_connection.json"

# Credentials live in a local JSON file rather than in the notebook itself;
# only the Bolt URL and the password are picked out of it.
with open(PATH_CONNECTION) as config_file:
    connection_config = json.load(config_file)
connection_details = {
    field: connection_config[field] for field in ('bolt_url', 'password')
}

In [8]:
# Open a py2neo connection using the URL/password loaded above and the fixed
# user name "neo4j".
graph = Graph(
    connection_details['bolt_url'],
    auth=("neo4j", connection_details['password']),
)

### Write embeddings to the graph

In [14]:
# Text file of node embeddings: one header line, then "<node id> <values...>".
PATH_TO_EMBEDDINGS = "../data/movies.emb"

In [15]:
# Parse the embedding file. The first line is treated as a header and skipped
# (presumably word2vec-style "<count> <dimensions>" — TODO confirm). Every
# other line is "<internal node id> <v1> <v2> ...", whitespace-separated.
params = []
with open(PATH_TO_EMBEDDINGS) as movies_emb:
    next(movies_emb)  # skip the header line
    for line in movies_emb:
        fields = line.split()  # tolerant of repeated/trailing whitespace
        if not fields:
            continue  # ignore blank lines instead of failing on fields[0]
        params.append({
            'id': int(fields[0]),
            'embedding': [float(value) for value in fields[1:]],
        })

# The file is closed before talking to the database. `$params` is the Cypher
# parameter syntax; the legacy `{params}` form was deprecated in Neo4j 3.x and
# is rejected by Neo4j 4.0+. Rows are matched on the internal node id.
graph.run("""
UNWIND $params AS param
MATCH (m:Movie) WHERE id(m) = param.id
SET m.embedding = param.embedding
""", {"params": params})

In [41]:
def count_nodes(graph, node_type='Movie'):
    """Return how many nodes in ``graph`` carry the label ``node_type``.

    Parameters
    ----------
    graph : py2neo.Graph
        Connection the count query is run on.
    node_type : str, default 'Movie'
        Node label to count.

    Returns
    -------
    int
        Number of nodes with that label.
    """
    # Cypher does not allow labels to be passed as query parameters, so the
    # label is inlined. Quoting it in backticks (and doubling any backtick it
    # contains) keeps an arbitrary label string from changing the query shape.
    escaped_label = node_type.replace('`', '``')
    query = "MATCH (n:`{}`) RETURN count(n) AS num_nodes".format(escaped_label)
    # evaluate() returns the first value of the first record: the count.
    return graph.run(query).evaluate()


In [42]:
def do_all_movies_have_embedding(graph):
    """Return True when every ``:Movie`` node has a non-null ``embedding``.

    Parameters
    ----------
    graph : py2neo.Graph
        Connection the check is run on.

    Returns
    -------
    bool
        True if no ``:Movie`` node lacks the ``embedding`` property
        (also True when there are no ``:Movie`` nodes at all).
    """
    # count(m.embedding) skips nulls, so it counts only movies that have the
    # property. Comparing both counts in a single query avoids the deprecated
    # EXISTS(property) form and keeps the two numbers from the same snapshot.
    return bool(graph.run("""
    MATCH (m:Movie)
    RETURN count(m) = count(m.embedding) AS all_have_embedding
    """).evaluate())

In [43]:
# Sanity check: every :Movie node should now carry an embedding.
do_all_movies_have_embedding(graph=graph)

True

All movie nodes now have the `embedding` property (a 100-dimensional array).

### Prepare data (movie, embedding, genres)

In [46]:
# One row per movie: its internal node id, its embedding, and the list of
# internal ids of the nodes it points to via IN_GENRE. Movies with no
# IN_GENRE relationship do not match the pattern and are left out.
movie_genres_query = """
MATCH (m:Movie)-[:IN_GENRE]->(genre)
WITH id(m) AS source, m.embedding as embedding, collect(id(genre)) AS target 
RETURN source, embedding, target
"""
movie_genres_cursor = graph.run(movie_genres_query)
movie_genres = movie_genres_cursor.to_data_frame()

In [47]:
# Display the frame: source (movie node id), embedding, target (genre node ids).
movie_genres

Unnamed: 0,source,embedding,target
0,0,"[1.1568369, 4.7216134, -4.829661, 2.67412, 4.1...","[1, 2, 3, 4, 6]"
1,5,"[0.3488815, 6.3227067, -3.6201174, -0.45922923...","[1, 3, 6]"
2,7,"[-0.07164516, 4.885974, -5.827049, -0.01449445...","[4, 9]"
3,8,"[0.10417974, -3.2304595, 3.5939155, 1.1410666,...","[9, 10, 4]"
4,11,"[-1.383354, 1.648171, -3.7612596, -1.3233246, ...",[4]
...,...,...,...
9120,9140,"[0.28404692, 0.21873271, -0.3254155, 1.6711218...","[10, 1, 9]"
9121,9141,"[-0.8980954, -1.6762947, -1.2756481, 0.9909565...","[37, 6, 1, 13]"
9122,9142,"[0.6182944, -0.6889125, -0.38154954, 1.0363413...",[49]
9123,9143,"[1.9886293, 1.4495139, 1.4017138, 0.2425547, 1...",[4]


In [48]:
# Collect all :Genre ids sorted by genre name, so that every movie's vector
# uses the same column layout. Position i is 1 when the movie has an IN_GENRE
# link to the i-th genre and 0 otherwise. A movie can have several 1s, so this
# is multi-hot rather than strictly one-hot.
onehot_query = """\
MATCH (genre:Genre)
WITH genre ORDER BY genre.name
WITH collect(id(genre)) AS genres
MATCH (m:Movie)-[:IN_GENRE]->(genre)
WITH genres, id(m) AS source, m.embedding AS embedding, collect(id(genre)) AS target
RETURN source, embedding, [g in genres | CASE WHEN g in target THEN 1 ELSE 0 END] AS genres
"""
movie_genres_onehot = graph.run(onehot_query)
# Each returned record converts to a {column: value} mapping, i.e. one row.
data = pd.DataFrame(list(map(dict, movie_genres_onehot)))

In [49]:
# Display the training frame: source (movie node id), embedding, genres (0/1 vector).
data

Unnamed: 0,source,embedding,genres
0,0,"[1.1568369, 4.7216134, -4.829661, 2.67412, 4.1...","[0, 0, 1, 1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, ..."
1,5,"[0.3488815, 6.3227067, -3.6201174, -0.45922923...","[0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, ..."
2,7,"[-0.07164516, 4.885974, -5.827049, -0.01449445...","[0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
3,8,"[0.10417974, -3.2304595, 3.5939155, 1.1410666,...","[0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, ..."
4,11,"[-1.383354, 1.648171, -3.7612596, -1.3233246, ...","[0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
...,...,...,...
9120,9140,"[0.28404692, 0.21873271, -0.3254155, 1.6711218...","[0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, ..."
9121,9141,"[-0.8980954, -1.6762947, -1.2756481, 0.9909565...","[0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, ..."
9122,9142,"[0.6182944, -0.6889125, -0.38154954, 1.0363413...","[0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, ..."
9123,9143,"[1.9886293, 1.4495139, 1.4017138, 0.2425547, 1...","[0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."


In [51]:
# pandas writes list-valued cells (embedding, genres) as their string form,
# e.g. "[0, 1, ...]"; parse them back (for instance with ast.literal_eval)
# after reading this CSV.
PATH_TO_DATASET = '../data/data.csv'
data.to_csv(PATH_TO_DATASET, index=False)