<img width="8%" alt="Neo4j.png" src="https://raw.githubusercontent.com/jupyter-naas/awesome-notebooks/master/.github/assets/logos/Neo4j.png" style="border-radius: 15%">

# Neo4j - Push Content to Graph Database

**Tags:** #neo4j #abi #knowledgegraph

**Author:** [Florent Ravenel](https://www.linkedin.com/in/florent-ravenel)

**Last update:** 2024-04-08 (Created: 2024-03-25)

**Description:** This notebook creates the ABI Knowledge Graph in Neo4j from the content dataset stored in Google Sheets.

## Input

### Import libraries

In [None]:
import naas
try:
    import neo4j
except:
    !pip install neo4j==5.18.0 --user
    import neo4j
from neo4j import GraphDatabase
from pyvis.network import Network
import pandas as pd
from naas_drivers import gsheet
import os

### Setup variables

In [None]:
# Inputs
# NOTE(review): `pload`, `naas_data_product` and `entity_index` are not imported
# or defined anywhere in the visible notebook — confirm the runtime provides
# them, otherwise this cell fails on a fresh kernel.
# Falls back to an empty string when the lookup returns a falsy value.
spreadsheet_url = pload(os.path.join(naas_data_product.OUTPUTS_PATH, "entities", entity_index), "abi_spreadsheet") or ""
# Name of the sheet that holds the posts turned into Content nodes.
sheet_posts = "DATASET_POSTS"
# Placeholder strings treated as "no value" by the helper functions and the
# relationship cell below.
excludes = ["NA", "TBD", "None", "Not Found"]

# Outputs
# Neo4j connection URI.
# NOTE(review): instance-specific address is hard-coded — consider reading it
# from a secret like the credentials below.
url = "neo4j+s://d3eeda32.databases.neo4j.io:7687"
# Credentials are read from naas secrets rather than stored in the notebook.
username = naas.secret.get("NEO4J_USERNAME")
password = naas.secret.get("NEO4J_PASSWORD")
# Path of the HTML file the PyVis graph is written to in the Output section.
output_graph = os.path.join("..", "outputs", "graph.html")

## Model

### Helper Functions

In [None]:
def set_properties(tx, node_label, d, properties):
    """Merge a node by id and set the listed properties on it.

    Args:
        tx: Neo4j transaction (anything exposing ``run(query, **params)``).
        node_label: Label of the node to merge. Interpolated into the query,
            so it must be a trusted, valid Cypher label.
        d: Mapping holding the node data; must contain an ``"id"`` entry.
        properties: Keys of ``d`` to write as node properties. Property keys
            are interpolated into the query and must be valid identifiers.

    Raises:
        ValueError: If ``d`` has no ``"id"`` value.

    Properties whose key contains ``"date"`` and whose value is a string not
    listed in the module-level ``excludes`` are stored as Cypher ``datetime``.
    """
    raw_id = d.get("id")
    if raw_id is None:
        raise ValueError("Node data must contain an 'id' value")

    # Same normalisation as before (trim, spaces -> underscores, apostrophes
    # dropped) so nodes written by earlier runs keep matching. str() lets
    # numeric ids through instead of failing on .strip().
    uid = str(raw_id).strip().replace(' ', '_').replace("'", "")

    # Labels and property keys cannot be Cypher parameters; every value
    # (including the id) is passed as a parameter so quotes or backslashes in
    # the data can neither break nor alter the query.
    merge = "MERGE (a:" + node_label + " {id: $uid})"

    if not properties:
        # Still create the node when there is nothing to set on it.
        tx.run(merge, uid=uid)
        return

    for prop in properties:
        value = d.get(prop)
        set_property = "SET a." + prop + " = $value"
        # Only string values can be reshaped into ISO-8601; anything else
        # (None, numbers) is written as-is rather than crashing on .replace().
        if "date" in prop and isinstance(value, str) and value not in excludes:
            value = value.replace(" ", "T")
            set_property = "SET a." + prop + " = datetime($value)"
        tx.run(
            f"{merge} {set_property}",
            uid=uid,
            value=value
        )
                    
def create_nodes(
    driver,
    node_label,
    data,
    properties=None,
):
    """Upsert one node per record in ``data`` and return ``data`` unchanged.

    Args:
        driver: Neo4j driver used to open the session.
        node_label: Label for the nodes; surrounding whitespace and
            underscores are removed before use.
        data: Sequence of dicts, each describing one node.
        properties: Keys of each record to write as node properties.
            Defaults to ``["name"]``.

    Returns:
        The ``data`` argument, as received.
    """
    if properties is None:
        properties = ["name"]

    # Cleaning
    node_label = node_label.strip().replace('_', '')

    # One session for the whole batch; each record still runs in its own
    # managed write transaction, so a failure does not roll back earlier rows.
    with driver.session() as session:
        for d in data:
            session.execute_write(set_properties, node_label, d, properties)
    print(f"✅ Nodes '{node_label}' successfully created (total: {len(data)})")
    return data

def create_node_from_gsheet(
    driver,
    spreadsheet_url,
    sheet_name,
):
    """Read a sheet and upsert one node per row, labelled after the sheet.

    When the sheet has no ``id`` column, one is derived from ``name``
    (lower-cased, trimmed, spaces replaced by underscores). Every column other
    than ``id`` is written as a node property.

    Returns:
        The sheet content as a DataFrame, including the ``id`` column.
    """
    sheet_df = gsheet.connect(spreadsheet_url).get(sheet_name=sheet_name)

    # Derive the identifier from the name when the sheet does not supply one
    has_id = "id" in sheet_df.columns
    if not has_id:
        sheet_df["id"] = sheet_df["name"].str.lower().str.strip().str.replace(' ', '_')

    # Everything except the identifier becomes a node property
    property_names = sheet_df.columns.tolist()
    property_names.remove("id")
    records = sheet_df.to_dict(orient="records")

    create_nodes(driver, sheet_name, records, property_names)
    return sheet_df

def clean_df(
    df_init,
    to_keep,
    to_rename,
):
    """Select, rename and normalise columns without touching ``df_init``.

    Args:
        df_init: Source DataFrame; left unmodified.
        to_keep: List of columns to keep.
        to_rename: Mapping of old -> new column names. After renaming, the
            frame must contain an ``"ID"`` column.

    Returns:
        A new DataFrame with rows whose ``ID`` is ``"TBD"`` or ``"NA"``
        removed, ``ID`` cast to ``str``, lower-cased column names and a fresh
        0..n-1 index.
    """
    # Column selection plus rename already yields a new frame, so the input
    # does not need to be copied wholesale first.
    df = df_init[to_keep].rename(columns=to_rename)

    # Take an explicit copy of the filtered rows: assigning into the bare
    # boolean-mask result raises SettingWithCopyWarning on pandas versions
    # without copy-on-write.
    df = df.loc[~df["ID"].isin(["TBD", "NA"])].copy()
    df["ID"] = df["ID"].astype(str)
    df.columns = df.columns.str.lower()
    return df.reset_index(drop=True)

def send_graph_data_to_gsheet(
    df_init,
    to_keep,
    to_rename,
    spreadsheet_url,
    sheet_name
):
    """Clean ``df_init`` with ``clean_df`` and send the result to a sheet.

    Args:
        df_init: Source DataFrame.
        to_keep: Columns to keep (forwarded to ``clean_df``).
        to_rename: Column rename mapping (forwarded to ``clean_df``).
        spreadsheet_url: Target spreadsheet URL.
        sheet_name: Target sheet name.

    Returns:
        The cleaned DataFrame that was sent.
    """
    # Cleaning (the previous defensive copy of df_init was overwritten on the
    # next line and never used, so it has been dropped)
    df = clean_df(df_init, to_keep, to_rename)

    # Send data to gsheet
    # NOTE(review): `send_data_to_gsheet` is not defined or imported in the
    # visible notebook — confirm it is provided before calling this helper.
    send_data_to_gsheet(df, pd.DataFrame(), spreadsheet_url, sheet_name)
    return df

def create_nodes_from_single_column(df, column):
    """Build node records from the distinct values of one column.

    Rows whose value is "Not Found", "NA" or "TBD" are ignored. Each remaining
    distinct value, in order of first appearance, yields a dict with the
    lower-cased value as ``id`` and the value itself as ``name``.
    """
    placeholders = ["Not Found", "NA", "TBD"]
    usable_values = df.loc[~df[column].isin(placeholders), column]
    return [
        {"id": value.lower(), "name": f"{value}"}
        for value in usable_values.unique().tolist()
    ]

### Connect to GraphDatabase

In [None]:
# Single driver instance shared by every node/relationship cell in the Model section.
driver = GraphDatabase.driver(url, auth=(username, password))

### ProfessionalRole

In [None]:
# Upsert one ProfessionalRole node per row of the "ProfessionalRole" sheet; the sheet is kept as a DataFrame.
data_professionalrole = create_node_from_gsheet(driver, spreadsheet_url, "ProfessionalRole")

### Sentiment

In [None]:
# Upsert one Sentiment node per row of the "Sentiment" sheet; the sheet is kept as a DataFrame.
data_sentiment = create_node_from_gsheet(driver, spreadsheet_url, "Sentiment")

### Objective

In [None]:
# Upsert one Objective node per row of the "Objective" sheet; the sheet is kept as a DataFrame.
data_objective = create_node_from_gsheet(driver, spreadsheet_url, "Objective")

### Content

In [None]:
# Node name
content_node = "Content"

# Get data from gsheet
# df_content keeps the sheet's original column names; the relationship cell
# below reads it through attributes such as row.ID and row.TYPE.
df_content = gsheet.connect(spreadsheet_url).get(sheet_name=sheet_posts)
# Work on a copy so lower-casing the column names does not affect df_content.
df = df_content.copy()
# Lower-cased column names become the node property keys, and the "ID" column
# becomes the "id" key that set_properties reads.
df.columns = df.columns.str.lower()
data = df.to_dict(orient="records")
properties = list(df.columns)
# "id" is the merge key, not a property; list.remove raises ValueError if the
# sheet has no ID column.
properties.remove("id")

# Create nodes
data_content = create_nodes(driver, content_node, data, properties)

#### Create Content Type node

In [None]:
# One ContentType node per distinct TYPE value, skipping the placeholders "Not Found", "NA" and "TBD".
data_content_type = create_nodes(driver, "ContentType", create_nodes_from_single_column(df_content, "TYPE"))

#### Create Content relationships with Concept, Sentiment, ProfessionalRole (target), Objective and ContentType

In [None]:
def create_content_relationships(
    text,
    node_label,
    relationship,
    content_id
):
    """Link a Content node to the entities listed in a pipe-separated field.

    ``text`` is expected to look like ``"Name: summary|Other name: summary"``.
    For every entry a ``node_label`` node is upserted (via ``set_properties``)
    and a ``relationship`` edge from the Content node ``content_id`` to it is
    merged, with the summary stored on the edge.

    Args:
        text: Raw field value. Nothing happens when it is one of the
            module-level ``excludes`` placeholders.
        node_label: Label of the target nodes. Interpolated into the query,
            so it must be a trusted, valid Cypher label.
        relationship: Relationship type. Interpolated into the query, so it
            must be a trusted, valid Cypher relationship type.
        content_id: ``id`` of the Content node the relationships start from.

    Entries without a ``:`` get an empty summary; blank entries are skipped.
    Uses the module-level ``driver``.
    """
    if text in excludes:
        return

    # Labels and relationship types cannot be Cypher parameters. Every data
    # value is passed as a parameter so quotes or backslashes in a summary can
    # no longer break (or alter) the query.
    cypher_query = (
        "MATCH (a:Content {id: $content_id}), (b:" + node_label + " {id: $uid}) "
        "MERGE (a)-[c:" + relationship + "]->(b) "
        "SET c.summary = $summary"
    )

    # One session for all entries of this field instead of two per entry.
    with driver.session() as session:
        for t in text.split("|"):
            # partition() never raises when the ':' separator is missing.
            name, _, summary = t.partition(":")
            name = name.strip()
            if not name:
                continue

            # Mirror the id normalisation done by set_properties (which also
            # drops apostrophes); otherwise the MATCH below looks for an id
            # that was never stored and the relationship is silently skipped.
            uid = name.lower().replace(' ', '_').replace("'", "")
            data = {"id": uid, "name": name}
            properties = ["name"]

            # Create node
            session.execute_write(set_properties, node_label, data, properties)

            # Create relationship; consume() makes sure the auto-commit
            # transaction has finished before the next write starts.
            session.run(
                cypher_query,
                content_id=content_id,
                uid=uid,
                summary=summary.strip(),
            ).consume()
    
content_relationships = "ContentRelationsShips"

# Content -> ContentType link. Values are passed as Cypher parameters so a
# quote or backslash in a type name or URL cannot break the query.
cypher_query = (
    "MATCH (a:Content {id: $content_id}), (b:ContentType {name: $type_name}) "
    "MERGE (a)-[c:IS_TYPE]->(b) "
    "SET c.url = $url"
)

# Create RelationShips
for row in df_content.itertuples():
    uid = str(row.ID)
    concepts = str(row.CONCEPT)
    sentiments = str(row.SENTIMENT)
    targets = str(row.TARGET)
    objectives = str(row.OBJECTIVE)

    # Each field is a pipe-separated list of "name: summary" entries
    create_content_relationships(concepts, "Concept", "DISCUSSES", uid)
    create_content_relationships(sentiments, "Sentiment", "EXPRESSES", uid)
    create_content_relationships(targets, "ProfessionalRole", "TARGETS", uid)
    create_content_relationships(objectives, "Objective", "AIMS_TO_ACHIEVE", uid)

    # Create relationships
    content_type = str(row.TYPE)
    article_shared = str(row.CONTENT_URL_SHARED)
    image_shared = str(row.IMAGE_SHARED)

    # Prefer the shared article URL, then the shared image, else no URL
    if article_shared not in excludes:
        url_shared = article_shared
    elif image_shared not in excludes:
        url_shared = image_shared
    else:
        url_shared = ""

    with driver.session() as session:
        # When no ContentType node has this name the MATCH finds nothing and
        # no relationship is created.
        session.run(
            cypher_query,
            content_id=uid,
            type_name=content_type,
            url=url_shared,
        )

### Close Neo4j connection

In [None]:
# NOTE(review): the Output section below calls driver.session() after this
# close — confirm a driver is (re)opened there, or move this close to the end.
driver.close()

## Output

### Create Pyvis

In [None]:
# Query your graph data
def get_graph_data(tx):
    """Return every (start node, relationship, end node) triple in the graph."""
    triples = []
    for record in tx.run("MATCH (n)-[r]->(m) RETURN n, r, m"):
        triples.append((record["n"], record["r"], record["m"]))
    return triples

# The driver was closed in the "Close Neo4j connection" cell, so open a fresh
# one for reading the graph back; it is closed at the end of this cell.
driver = GraphDatabase.driver(url, auth=(username, password))
with driver.session() as session:
    graph_data = session.execute_read(get_graph_data)

# Create a new PyVis graph
# notebook=True targets rendering inside a notebook; dark background with light labels.
net = Network(
    notebook=True, height="100%", width="100%", bgcolor="#222222", font_color="lightgrey"
)

# Adjust the spring length (edge length)
# Layout/physics parameters for the force-atlas-2 based solver.
net.force_atlas_2based(gravity=-50, central_gravity=0.01, spring_length=100, spring_strength=0.08, damping=0.4, overlap=0)

# Define colors for different node labels
# Maps a Neo4j node label to its display color; labels missing from this map
# are drawn in "gray" by the loop that adds the nodes.
colors = {
    "Content": "lightblue",
    "Concept": "#f24f4f",
    "Sentiment": "#F26C4F",
    "ProfessionalRole": "#741fb5",
    "Objective": "#767dcf",
    "ContentType": "#c476ff",
    # Disabled entries (no "People" nodes are created in this notebook, and
    # "ProfessionalRole" already has a color above):
#     "People": "#f2f24f",
#     "ProfessionalRole": "#525217",
}
# Add nodes and edges to the PyVis graph
def _node_display(node):
    """Return ``(label, caption)`` for a Neo4j node.

    The label is the node's first Neo4j label (``None`` if it has none). The
    caption is the ``title`` property for Content nodes and ``name`` for all
    others, truncated to 30 characters. A missing property yields an empty
    caption instead of failing on ``len(None)``.
    """
    label = next(iter(node.labels)) if node.labels else None
    caption = node.get("title") if label == "Content" else node.get("name")
    caption = "" if caption is None else str(caption)
    if len(caption) > 30:
        caption = caption[:27] + "..."
    return label, caption


# Distinct node labels encountered, in order of first appearance
nodes = []
for n, r, m in graph_data:
    n_label, n_name = _node_display(n)
    m_label, m_name = _node_display(m)
    for label in (n_label, m_label):
        if label not in nodes:
            nodes.append(label)

    # Both endpoints are added for every relationship, so nodes with several
    # relationships are passed to add_node more than once under the same id.
    net.add_node(n.element_id, label=n_name, title=n_label, color=colors.get(n_label, "gray"))
    net.add_node(m.element_id, label=m_name, title=m_label, color=colors.get(m_label, "gray"))
    net.add_edge(n.element_id, m.element_id, title=r.type, label=(r.type).capitalize(), arrows={"to": {"enabled": True, "scaleFactor": 2.0}}, font={"size": 10})

# Show the graph
# Writes the interactive graph to output_graph (an HTML file path).
network = net.show(output_graph)
# Register the generated HTML file as a naas asset, with the "inline" option enabled.
naas.asset.add(output_graph, {"inline": True})

# Release the Neo4j connection once the graph has been rendered.
driver.close()