<img width="8%" alt="Neo4j.png" src="https://raw.githubusercontent.com/jupyter-naas/awesome-notebooks/master/.github/assets/logos/Neo4j.png" style="border-radius: 15%">

# Neo4j - Push Content to Graph Database

**Tags:** #neo4j #abi #knowledgegraph

**Author:** [Florent Ravenel](https://www.linkedin.com/in/florent-ravenel)

**Last update:** 2024-04-08 (Created: 2024-03-25)

**Description:** This notebook pushes content data (posts stored in a Google Sheet) and their related entities to a Neo4j graph database.

## Input

### Import libraries

In [None]:
import naas_data_product
import naas_python as naas
try:
    import neo4j
except:
    !pip install neo4j==5.18.0 --user
    import neo4j
from neo4j import GraphDatabase
from pyvis.network import Network
import pandas as pd
from naas_drivers import gsheet
import os
import re
import random

### Setup variables

In [None]:
# Inputs
# Index of the entity whose saved outputs hold the ABI spreadsheet URL.
entity_index = "0"
# NOTE(review): `pload` is not imported in the visible imports cell — confirm it is
# provided by the naas runtime / naas_data_product. Falls back to "" when falsy.
spreadsheet_url = pload(os.path.join(naas_data_product.OUTPUTS_PATH, "entities", entity_index), "abi_spreadsheet") or ""
# Reference spreadsheet read for the ProfessionalRole, Sentiment and Objective sheets.
spreadsheet_url_ref = "https://docs.google.com/spreadsheets/d/1ofYdsJ8Tq86_FbLBeUuBTB06RAx9cHZ3aea3ScDkPoQ/edit#gid=579458046"
# Sheet of `spreadsheet_url` holding the posts pushed as Content nodes.
sheet_posts = "POSTS"
# Placeholder values treated as "no data" and skipped when building queries.
excludes = ["NA", "TBD", "None", "Not Found"]

# Outputs
# Neo4j connection settings, read from naas secrets.
url = naas.secret.get("NEO4J_URI").value
username = naas.secret.get("NEO4J_USERNAME").value
password = naas.secret.get("NEO4J_PASSWORD").value
# HTML file the Pyvis visualisation is written to.
output_graph = "graph.html"
# Text file each executed Cypher query is appended to.
output_file_path = "cypher_query.txt"

## Model

### Create an empty text file to store the Cypher queries

In [None]:
def create_txt_file(file_path, force_update=True):
    """Create (or empty) the text file used to collect Cypher queries.

    Args:
        file_path: Path of the text file.
        force_update: When True (default) an existing file is truncated;
            when False an existing file is left untouched.
    """
    if os.path.exists(file_path) and not force_update:
        # Keep the existing file as-is.
        return
    with open(file_path, 'w') as handle:
        handle.write("")  # start from an empty file
    print(f"✅ Text file successfully created: '{file_path}'")

# Start the query log from an empty file (force_update defaults to True).
create_txt_file(output_file_path)

### Connect to GraphDatabase

In [None]:
# Open a Neo4j driver with the URI and credentials loaded in the setup cell.
driver = GraphDatabase.driver(url, auth=(username, password))

### Helper functions

In [None]:
def generate_unique_id_from_text(text):
    """Build a throw-away Cypher variable name from the letters of ``text``.

    The name is random: it draws letters of ``text``, with replacement. This is
    presumably so that queries concatenated in the query log rarely re-declare
    the same variable.

    The ``v_`` prefix guarantees the name is never empty and never a Cypher
    keyword such as ``in``, ``as`` or ``on``. Before, a purely numeric id gave
    an empty variable, i.e. an anonymous ``()`` pattern or a syntax error.

    Args:
        text: Any value; non-strings are converted with ``str``.

    Returns:
        A valid Cypher identifier of the form ``v_<lowercase letters>``.
    """
    letters = ''.join(re.findall('[a-zA-Z]+', str(text))).lower()
    if not letters:
        # Nothing to sample from: use a fixed alphabet instead.
        letters = "node"
    return "v_" + ''.join(random.choice(letters) for _ in range(len(letters)))

def match_node_with_type_and_uid(driver, node_type, uid):
    """Return the nodes labelled ``node_type`` whose ``id`` property equals ``uid``.

    ``uid`` is sent as a query parameter, so quotes or other Cypher syntax in
    it can no longer break (or inject into) the query. The label cannot be
    parameterised in Cypher, so it is backtick-quoted instead.

    Args:
        driver: An open Neo4j driver.
        node_type: Node label to match.
        uid: Value of the ``id`` property to look up.

    Returns:
        A list with the first column (the node) of every matching record.
    """
    label = str(node_type).replace('`', '``')
    cypher_query = f"MATCH (n:`{label}` {{id: $uid}}) RETURN n"
    with driver.session() as session:
        result = session.run(cypher_query, {"uid": uid})
        # Consume the result while the session is still open.
        return [record[0] for record in result]
    
def update_txt_file(output_file_path, cypher_query):
    """Append ``cypher_query`` verbatim to the query log at ``output_file_path``.

    The file is opened in append mode. Previously the whole log was read and
    rewritten on every call, which made a full run quadratic in the log size.
    The file is created if it does not exist yet. UTF-8 is used so non-ASCII
    post text can be written regardless of the platform default encoding.

    Args:
        output_file_path: Path of the query log.
        cypher_query: Query text to append (no separator is added).
    """
    with open(output_file_path, 'a', encoding='utf-8') as f:
        f.write(cypher_query)
        
def clean_string_property(text):
    r"""Escape ``text`` for use inside a single-quoted Cypher string literal.

    Surrounding whitespace is stripped first. Trailing newlines are therefore
    removed instead of being turned into a literal ``\n``.

    Backslashes are escaped before quotes and line breaks. Otherwise an
    existing ``\`` could swallow the following quote and end the literal early.

    Args:
        text: Value to escape; non-strings are converted with ``str``.

    Returns:
        The escaped string, without surrounding quotes.
    """
    return (
        str(text)
        .strip()
        .replace("\\", "\\\\")
        .replace("'", "\\'")
        .replace("\r", "\\r")
        .replace("\n", "\\n")
    )

def create_node(
    driver,
    node_type,
    node_id,
    node_id_label="id",
    properties=None,
    output_file_path='cypher_query.txt'
):
    """MERGE one node and SET its properties, then log the Cypher query.

    Args:
        driver: An open Neo4j driver.
        node_type: Node label (backtick-quoted in the query).
        node_id: Value of the identifying property used in the MERGE.
        node_id_label: Name of the identifying property (default ``"id"``).
        properties: Mapping of property name to value to SET. ``None``, NaN and
            values listed in the global ``excludes`` are skipped. Keys
            containing ``date`` are written as Cypher ``datetime(...)``.
            Defaults to no properties.
        output_file_path: Query log the executed query is appended to.

    Notes:
        The query is built outside the transaction function. The print and log
        append run only after ``execute_write`` returns, because the Neo4j
        driver may retry a transaction function; previously a retry could
        append the same query to the log several times.
    """
    import math

    if properties is None:
        properties = {}

    def quote(identifier):
        # Backtick-quote labels / property keys so spaces or symbols stay valid Cypher.
        return "`" + str(identifier).replace("`", "``") + "`"

    # Random variable name for this node in the query.
    v = generate_unique_id_from_text(node_id)

    # Prepare the SET assignments, skipping placeholder / missing values.
    properties_list = []
    for key, value in properties.items():
        if value is None or value in excludes:
            continue
        if isinstance(value, float) and math.isnan(value):
            # NaN has no Cypher literal; "= nan" made the whole query fail.
            continue
        target = f"{v}.{quote(key)}"
        if 'date' in str(key):
            # Assumes "YYYY-MM-DD HH:MM:SS"-like text; the space becomes ISO "T".
            iso_value = clean_string_property(str(value).replace(" ", "T"))
            properties_list.append(f"{target} = datetime('{iso_value}')")
        elif isinstance(value, (int, float)):
            properties_list.append(f"{target} = {value}")
        else:
            properties_list.append(f"{target} = '{clean_string_property(str(value))}'")

    # Create cypher query
    node_id_literal = clean_string_property(str(node_id))
    cypher_query = f"MERGE ({v}:{quote(node_type)} {{{quote(node_id_label)}: '{node_id_literal}'}})"
    if properties_list:
        # A bare "SET" with no assignments is a syntax error.
        cypher_query += " SET " + ", ".join(properties_list)
    cypher_query += " "

    def create_custom_node(tx):
        # Kept side-effect free: the driver may run it more than once.
        tx.run(cypher_query)

    with driver.session() as session:
        session.execute_write(create_custom_node)

    print(f"✔️ Node '{node_id}' created successfully.")
    # Save in txt file
    update_txt_file(output_file_path, cypher_query)
                    
def create_nodes(
    driver,
    node_type,
    data,
    output_file_path='cypher_query.txt'
):
    """MERGE one node per record in ``data`` under a cleaned-up label.

    The label is ``node_type`` with surrounding whitespace and all underscores
    removed. Each record's ``"id"`` value is the MERGE key, and the whole
    record is passed to ``create_node`` as properties. No lookup is done
    beforehand; ``create_node`` uses MERGE, so existing nodes are matched.

    Args:
        driver: An open Neo4j driver.
        node_type: Raw node type name.
        data: List of property dicts, each expected to hold an ``"id"`` key.
        output_file_path: Query log passed through to ``create_node``.

    Returns:
        ``data`` itself, unchanged.
    """
    label = node_type.strip().replace('_', '')

    for record in data:
        create_node(
            driver,
            label,
            record.get("id"),
            node_id_label="id",
            properties=record,
            output_file_path=output_file_path,
        )

    print(f"✅ Nodes type '{label}' successfully created (total: {len(data)})")
    return data

def create_node_from_gsheet(
    driver,
    spreadsheet_url,
    sheet_name,
    output_file_path,
):
    """Read one Google Sheet tab and MERGE a node per row, labelled after the tab.

    When the sheet has no ``id`` column, one is derived from ``name``:
    lower-cased, stripped, and with spaces replaced by underscores.

    Args:
        driver: An open Neo4j driver.
        spreadsheet_url: URL of the Google Sheet.
        sheet_name: Tab to read; also used as the node type.
        output_file_path: Query log passed through to ``create_nodes``.

    Returns:
        The DataFrame read from the sheet (including any derived ``id`` column).
    """
    sheet_df = gsheet.connect(spreadsheet_url).get(sheet_name=sheet_name)

    if "id" not in sheet_df.columns:
        sheet_df["id"] = (
            sheet_df["name"]
            .str.lower()
            .str.strip()
            .str.replace(' ', '_')
        )

    records = sheet_df.to_dict(orient="records")
    create_nodes(driver, sheet_name, records, output_file_path)
    return sheet_df

def create_nodes_from_single_column(df, column, exclude_values=None):
    """Turn the distinct values of ``df[column]`` into node records.

    Missing cells (None/NaN) are dropped. Previously they crashed on
    ``.lower()``, as did non-string values.

    Args:
        df: DataFrame containing ``column``.
        column: Column whose distinct values become nodes.
        exclude_values: Placeholder values to skip. Defaults to
            ``["Not Found", "NA", "TBD"]``, as before.

    Returns:
        A list of ``{"id": <lower-cased value>, "name": <value as text>}``
        dicts, in first-appearance order.
    """
    if exclude_values is None:
        exclude_values = ["Not Found", "NA", "TBD"]
    values = df[column].dropna()
    values = values[~values.isin(exclude_values)]
    return [{"id": str(x).lower(), "name": f"{x}"} for x in values.unique().tolist()]

### ProfessionalRole

In [None]:
# MERGE one ProfessionalRole node per row of the "ProfessionalRole" tab of the reference sheet; keep the rows as a DataFrame.
data_professionalrole = create_node_from_gsheet(driver, spreadsheet_url_ref, "ProfessionalRole", output_file_path)

### Sentiment

In [None]:
# MERGE one Sentiment node per row of the "Sentiment" tab of the reference sheet; keep the rows as a DataFrame.
data_sentiment = create_node_from_gsheet(driver, spreadsheet_url_ref, "Sentiment", output_file_path)

### Objective

In [None]:
# MERGE one Objective node per row of the "Objective" tab of the reference sheet; keep the rows as a DataFrame.
data_objective = create_node_from_gsheet(driver, spreadsheet_url_ref, "Objective", output_file_path)

### Content

In [None]:
# Label used for post nodes
content_node = "Content"

# Read the posts sheet. Column names are lower-cased on a new frame, so the
# original-case frame (df_content) stays available for the relationship step.
df_content = gsheet.connect(spreadsheet_url).get(sheet_name=sheet_posts)
df = df_content.set_axis(df_content.columns.str.lower(), axis=1)

# Every column except the separate date/time columns becomes a node property.
data = df.drop(columns=["date", "time"]).to_dict(orient="records")

data_content = create_nodes(driver, content_node, data, output_file_path)

#### Create Content Type node

In [None]:
# One ContentType node per distinct value of the posts' TYPE column ("Not Found", "NA" and "TBD" are skipped).
data_content_type = create_nodes(driver, "ContentType", create_nodes_from_single_column(df_content, "TYPE"), output_file_path)

#### Create Content relationships with Concept, Sentiment, ProfessionalRole (target), Objective and ContentType

In [None]:
def create_content_relationships(
    text,
    node_label,
    relationship,
    content_id,
    output_file_path
):
    """Link a Content node to the entities listed in a ``name:summary|name:summary`` cell.

    For each ``name:summary`` item in ``text``:

    * the target node (label ``node_label``; id = lower-cased name with spaces
      replaced by underscores) is created if it does not exist yet;
    * ``(Content {id: content_id})-[relationship]->(target)`` is MERGEd, with
      the item's summary stored on the relationship;
    * the executed query is appended to ``output_file_path``.

    Uses the module-level ``driver``.

    The following are skipped; previously they raised ``IndexError``:

    * ``text`` that is empty, listed in ``excludes``, or ``"nan"`` (what
      ``str()`` of a missing pandas cell yields);
    * empty items, e.g. from a trailing ``|``.

    An item without ``:`` is kept, with an empty summary. Ids and summaries
    are escaped and the label / relationship type are backtick-quoted, so
    apostrophes in names no longer break the query.
    """
    cleaned_text = str(text).strip()
    if not cleaned_text or cleaned_text in excludes or cleaned_text.lower() == "nan":
        return

    label = "`" + str(node_label).replace("`", "``") + "`"
    rel_type = "`" + str(relationship).replace("`", "``") + "`"
    content_id_literal = clean_string_property(str(content_id))

    for item in cleaned_text.split("|"):
        # "name:summary" -> (name, summary); summary is "" when there is no ":".
        name, _, summary = item.partition(":")
        name = name.strip()
        if not name:
            continue
        uid = name.lower().replace(' ', '_')

        # Create the target node first if it does not exist yet.
        if len(match_node_with_type_and_uid(driver, node_label, uid)) == 0:
            create_nodes(driver, node_label, [{"id": uid, "name": name}], output_file_path)

        # Fixed prefixes keep the three variables distinct, non-empty and never
        # a Cypher keyword, whatever the random parts turn out to be.
        a = "src_" + generate_unique_id_from_text(content_id)
        b = "dst_" + generate_unique_id_from_text(uid)
        c = "rel_" + generate_unique_id_from_text(relationship)

        # Create relationships
        cypher_query = (
            f"MERGE ({a}:Content {{id: '{content_id_literal}'}}) "
            f"MERGE ({b}:{label} {{id: '{clean_string_property(uid)}'}}) "
            f"MERGE ({a})-[{c}:{rel_type}]->({b}) "
            f"SET {c}.summary = '{clean_string_property(summary)}' "
        )
        with driver.session() as session:
            session.run(cypher_query)

        # Update txt file
        update_txt_file(output_file_path, cypher_query)
    
# Not used in this cell; kept in case other cells reference it.
content_relationships = "ContentRelationsShips"

# Values meaning "no data": the shared placeholders plus "" and "nan"
# (str() of a missing pandas cell).
placeholder_values = set(excludes) | {"", "nan"}

# Create RelationShips
for row in df_content.itertuples():
    uid = str(row.ID)
    concepts = str(row.CONCEPT)
    sentiments = str(row.SENTIMENT)
    targets = str(row.TARGET)
    objectives = str(row.OBJECTIVE)

    create_content_relationships(concepts, "Concept", "DISCUSSES", uid, output_file_path)
    create_content_relationships(sentiments, "Sentiment", "EXPRESSES", uid, output_file_path)
    create_content_relationships(targets, "ProfessionalRole", "TARGETS", uid, output_file_path)
    create_content_relationships(objectives, "Objective", "AIMS_TO_ACHIEVE", uid, output_file_path)

    # Link the post to its ContentType. The shared URL (article link first,
    # else image) is stored on the IS_TYPE relationship.
    content_type = str(row.TYPE)
    article_shared = str(row.CONTENT_URL_SHARED)
    image_shared = str(row.IMAGE_SHARED)
    if article_shared not in placeholder_values:
        url_shared = article_shared
    elif image_shared not in placeholder_values:
        url_shared = image_shared
    else:
        url_shared = ""

    # Placeholder types ("NA", "TBD", ...) get no ContentType node, consistent
    # with the ContentType node creation step.
    if content_type in placeholder_values:
        continue

    # Prefixes keep the variables distinct and valid even if the random parts coincide.
    a = "src_" + generate_unique_id_from_text(uid)
    b = "dst_" + generate_unique_id_from_text(content_type)
    c = "rel_" + generate_unique_id_from_text("IS_TYPE")
    cypher_query = (
        f"MERGE ({a}:Content {{id: '{clean_string_property(uid)}'}}) "
        f"MERGE ({b}:ContentType {{name: '{clean_string_property(content_type)}'}}) "
        f"MERGE ({a})-[{c}:IS_TYPE]->({b}) "
        f"SET {c}.url = '{clean_string_property(url_shared)}' "
    )
    with driver.session() as session:
        session.run(cypher_query)

    # Update txt file
    update_txt_file(output_file_path, cypher_query)

### Close Neo4j connection

In [None]:
# Close the Neo4j driver opened in "Connect to GraphDatabase".
# NOTE(review): the Output section below calls driver.session() afterwards;
# reusing a closed driver is not supported by the neo4j driver — that cell
# needs its own open driver.
driver.close()

## Output

### Create Pyvis

In [None]:
import naas

# Query your graph data
def get_graph_data(tx):
    """Read every (start node, relationship, end node) triple in the database."""
    result = tx.run("MATCH (n)-[r]->(m) RETURN n, r, m")
    return [(record["n"], record["r"], record["m"]) for record in result]


def _node_caption(node, label, max_length=30):
    """Text shown on a graph node: ``title`` for Content nodes, ``name`` otherwise.

    Falls back to the node's element id when that property is missing, which
    previously crashed ``len()``. Longer captions are cut to ``max_length``
    characters, ending in "...".
    """
    value = node.get("title" if label == "Content" else "name")
    caption = node.element_id if value is None else str(value)
    if len(caption) > max_length:
        caption = caption[:max_length - 3] + "..."
    return caption


# The "Close Neo4j connection" cell closes `driver`, so open a fresh driver
# for this read and always release it afterwards.
driver = GraphDatabase.driver(url, auth=(username, password))
try:
    with driver.session() as session:
        graph_data = session.execute_read(get_graph_data)
finally:
    driver.close()

# Create a new PyVis graph
net = Network(
    notebook=True, height="100%", width="100%", bgcolor="#222222", font_color="lightgrey"
)

# Adjust the spring length (edge length)
net.force_atlas_2based(gravity=-50, central_gravity=0.01, spring_length=100, spring_strength=0.08, damping=0.4, overlap=0)

# Define colors for different node labels (unknown labels are drawn gray)
colors = {
    "Content": "lightblue",
    "Concept": "#f24f4f",
    "Sentiment": "#F26C4F",
    "ProfessionalRole": "#741fb5",
    "Objective": "#767dcf",
    "ContentType": "#c476ff",
}

# Add nodes and edges to the PyVis graph; `nodes` collects the labels seen.
nodes = []
for n, r, m in graph_data:
    n_label = next(iter(n.labels)) if n.labels else None
    m_label = next(iter(m.labels)) if m.labels else None
    for node_label in (n_label, m_label):
        if node_label not in nodes:
            nodes.append(node_label)

    net.add_node(n.element_id, label=_node_caption(n, n_label), title=n_label, color=colors.get(n_label, "gray"))
    net.add_node(m.element_id, label=_node_caption(m, m_label), title=m_label, color=colors.get(m_label, "gray"))
    net.add_edge(n.element_id, m.element_id, title=r.type, label=(r.type).capitalize(), arrows={"to": {"enabled": True, "scaleFactor": 2.0}}, font={"size": 10})

# Show the graph
network = net.show(output_graph)
naas.asset.add(output_graph, {"inline": True})