In [1]:
from pathlib import Path
from pydantic import BaseModel
from typing import Any
from neo4j import GraphDatabase
from dotenv import load_dotenv
import os
import duckdb
load_dotenv()

True

In [2]:
class Neo4jGraph:

    def __init__(self, neo4j_uri:str, neo4j_username:str, neo4j_password:str, db:str)->None:
        self.uri  = neo4j_uri
        self.auth = (neo4j_username, neo4j_password)
        self.db = db
        self.driver = GraphDatabase.driver(self.uri, auth=self.auth)

    def query(self, query:str, params:dict):
        with self.driver.session(database=self.db) as session:
            result = session.run(query, params)
            return [r for r in result]

class Node(BaseModel):
    """Minimal description of a graph node.

    Attributes:
        id: integer key of the node.
        label: one label, or a list of labels.
        properties: extra key/value pairs; empty unless given.
    """

    id: int
    label: str | list[str]
    properties: dict[str, Any] = {}

class Relation(BaseModel):
    """A relationship of type ``label`` from ``source`` to ``target``.

    Attributes:
        id: optional integer identifier of the relationship; ``None`` (the
            default) when the relationship has no identifier of its own.
        label: the relationship type, a single string.
        properties: key/value pairs to store on the relationship.
        source: start node.
        target: end node.
    """

    # Defaulting to None makes the `int | None` annotation genuinely optional
    # instead of required-but-nullable; existing callers that pass id= still work.
    id: int | None = None
    label: str
    properties: dict[str, Any] = {}
    source: Node
    target: Node


In [3]:
def merge_rel(graph:Neo4jGraph, relations:list[Relation], source_label:str, target_label:str):
    """Merge a batch of relationships into Neo4j with a single UNWIND query.

    Endpoint nodes are located by their ``id`` property under
    ``source_label`` / ``target_label``. Each relationship is then merged with
    ``apoc.merge.relationship``. The identity properties are ``{'id': rel.id}``
    when the relation has an id, otherwise ``{}``. ``rel.properties`` are passed
    both as the on-create and the on-match properties.

    Args:
        graph: object whose ``query(query, params)`` runs Cypher and returns records.
        relations: relationships to merge; may be empty.
        source_label: label of the start nodes. It is interpolated into the
            Cypher text rather than passed as a parameter, so it must come from
            trusted code.
        target_label: label of the end nodes (same caveat as ``source_label``).

    Returns:
        The records returned by ``graph.query``, each exposing ``rel``. A row
        whose endpoint nodes are not matched produces no record, so the result
        can be shorter than ``relations``.
    """
    if not relations:
        # Nothing to merge; skip the round trip (UNWIND [] would return no rows).
        return []

    data = [
        {
            'id': {'id': rel.id} if rel.id is not None else {},
            'label': rel.label,
            # The query only reads the endpoint ids, so send just those rather than
            # the full node dump (avoids shipping and serialising unused properties).
            'source': {'id': rel.source.id},
            'target': {'id': rel.target.id},
            'properties': rel.properties,
        }
        for rel in relations
    ]
    return graph.query(
        "UNWIND $data AS row "
        f"MATCH (source_node:{source_label} {{ id: row.source.id }}) "
        f"MATCH (target_node:{target_label} {{ id: row.target.id }}) "
        "CALL apoc.merge.relationship(source_node, "
        "row.label, "
        "row.id, "
        "row.properties, "
        "target_node, "
        "row.properties ) "
        "YIELD rel "
        "RETURN rel ",
        {"data": data},
    )

In [4]:
# Data location: resolved from the parent of the kernel's working directory
# (assumes the notebook is started from a sub-folder of the project — confirm).
base_path = Path.cwd().parent
source_path = base_path / 'gold/anilist/fact-studio-produce.parquet'

# Neo4j connection settings; a missing variable raises KeyError here.
neo4j_uri, neo4j_username, neo4j_password, neo4j_dbname = (
    os.environ[var]
    for var in ('neo4j_uri', 'neo4j_username', 'neo4j_password', 'neo4j_dbname')
)

In [5]:
# One driver for the notebook, reused by the relationship-loading loop below.
graph = Neo4jGraph(
    neo4j_uri=neo4j_uri,
    neo4j_username=neo4j_username,
    neo4j_password=neo4j_password,
    db=neo4j_dbname,
)

In [6]:
def make_rel_list_from_table(
    rows: list[tuple],
    source_label: str,
    rel_labels: str,
    columns: list[str],
    target_label: str,
) -> list[Relation]:
    """Turn positional table rows into ``Relation`` objects.

    Each row must be laid out as ``(rel_id, source_id, target_id, *extra)``.
    Every extra value becomes a relationship property named by the matching
    entry of ``columns[3:]``.

    Args:
        rows: tuples fetched from the relationship table.
        source_label: label given to every start node.
        rel_labels: the relationship type. Despite the plural name it is a
            single string (it is stored in ``Relation.label: str``).
        columns: the table's column names, aligned with each row.
        target_label: label given to every end node.

    Returns:
        One ``Relation`` per row, in row order.

    Raises:
        ValueError: if a row has fewer than three values, or its number of
            extra values differs from ``len(columns) - 3`` (previously such
            properties were silently dropped).
    """
    prop_names = columns[3:]
    rels = []
    for rel_id, source_id, target_id, *extra in rows:
        rels.append(
            Relation(
                id=rel_id,
                label=rel_labels,
                properties=dict(zip(prop_names, extra, strict=True)),
                source=Node(id=source_id, label=source_label),
                target=Node(id=target_id, label=target_label),
            )
        )
    return rels

In [7]:
# Read the gold-layer fact table and show its (rows, columns) shape.
tb = duckdb.read_parquet(os.fspath(source_path))
tb.shape

(2355, 4)

In [9]:
# Column order matters: make_rel_list_from_table reads row[0] as the
# relationship id, row[1] as the source id and row[2] as the target id;
# any further columns (isMain) become relationship properties.
tb_rel = duckdb.query("""
SELECT
    studio_edge_id AS rel_id
    , studio_id
    , anime_id
    , isMain
FROM tb
""")
tb_rel.shape

(2355, 4)

In [10]:
# Loader settings.
BATCH_SIZE = 10_000
source_label = "Studio"
target_label = "Anime"
labels = 'PRODUCE'  # a single relationship type, despite the plural name

# Table metadata (num_rows / num_column are not used by the loop below).
columns = tb_rel.columns
num_rows, num_column = tb.shape

In [11]:
# Stream tb_rel in BATCH_SIZE chunks and merge each chunk into Neo4j.
# The walrus keeps the "stop on the first empty fetch" semantics without
# repeating the fetchmany() call before and at the end of the loop.
n = 0
while batch := tb_rel.fetchmany(size=BATCH_SIZE):
    n += 1
    rels = make_rel_list_from_table(
        rows=batch,
        source_label=source_label,
        rel_labels=labels,
        columns=columns,
        target_label=target_label,
    )
    res = merge_rel(graph, rels, source_label=source_label, target_label=target_label)
    print(f"Inserted relationship batch {n} ({len(res)} rows) into neo4j-{neo4j_dbname}")
    if len(res) < len(rels):
        # merge_rel's MATCH clauses drop rows whose endpoint nodes are absent,
        # so a short result means some relationships were not written.
        print(
            f"  warning: {len(rels) - len(res)} relationship(s) in batch {n} "
            "were not merged (endpoint node not found?)"
        )

Inserted relationship batch 1 (2355 rows) into neo4j-kgserve
