# Neo4j Json Import Test

## Basic Data

In [None]:
# Paper title used as the search query for the metadata lookups below.
title = "Training Large Language Models to Reason in a Continuous Latent Space"

In [None]:
import sys
import os

# Parent of the current working directory (note: this is os.getcwd(), i.e.
# wherever the kernel was started, not the location of this file).
parent_dir = os.path.dirname(os.getcwd())

# Add the parent directory to sys.path, presumably so the project-level
# `apis` package imported below resolves. Guard against duplicates so that
# re-running the cell does not keep growing sys.path.
if parent_dir not in sys.path:
    sys.path.append(parent_dir)

In [None]:
from apis.arxiv_tool import ArxivKit
from apis.semanticscholar_tool import SemanticScholarKit

Arxiv Metadata

In [None]:
# Query arXiv for the paper by its title.
# max_cnt=3 presumably caps the number of returned candidates — confirm in ArxivKit.
arxiv = ArxivKit()
arxiv_metadata = arxiv.retrieve_metadata_by_paper(query_term=title, max_cnt=3)

SemanticScholar Metadata

In [None]:
# Query Semantic Scholar with the same title.
# limit=3 presumably caps the number of hits — confirm in SemanticScholarKit.
ss = SemanticScholarKit()
ss_metadata = ss.search_paper_by_keywords(query=title, limit=3)

In [None]:
# Use the 'paperId' of the first search hit as the id for the follow-up
# reference / cited-by queries. The commented line is an earlier variant that
# indexed one level deeper into the result.
# paper_ss_id = ss_metadata[0][0].get('paperId')
paper_ss_id = ss_metadata[0].get('paperId')
print(paper_ss_id)

Reference and Citedby data

In [None]:
# Fetch the references of the selected paper (limit=100) and show how many
# records came back.
reference_metadata = ss.get_semanticscholar_references(paper_id=paper_ss_id, limit=100)
len(reference_metadata)

In [None]:
# Fetch the cited-by records of the selected paper (limit=100) and show how
# many records came back.
citedby_metadata = ss.get_semanticscholar_citedby(paper_id=paper_ss_id, limit=100)
len(citedby_metadata)

## Metadata Processing

Helpers and record building for Semantic Scholar paper metadata.

In [None]:
def move_key_to_first(input_dict, key_to_move):
    """Return a dict in which ``key_to_move`` comes first.

    All other keys keep their original relative order. If the key is not
    present, the input dict itself (not a copy) is returned unchanged.
    """
    if key_to_move in input_dict:
        # Seed the result with the promoted key, then append everything else.
        reordered = {key_to_move: input_dict[key_to_move]}
        reordered.update(
            (name, val) for name, val in input_dict.items() if name != key_to_move
        )
        return reordered
    return input_dict

In [None]:
def filter_and_reorder_dict_comprehension(input_dict, keys_to_keep):
    """Return a new dict holding only ``keys_to_keep``, in that order.

    Keys listed in ``keys_to_keep`` that are missing from ``input_dict`` are
    silently ignored.
    """
    selected = {}
    for wanted in keys_to_keep:
        if wanted in input_dict:
            selected[wanted] = input_dict[wanted]
    return selected

In [None]:
import copy 

def delete_keys_del(input_dict, keys_to_delete):
    """Return a deep copy of ``input_dict`` without the given keys.

    The input dict is never modified; keys that are not present are ignored.
    The copy is returned so calls can be chained.
    """
    pruned = copy.deepcopy(input_dict)
    for unwanted in keys_to_delete:
        # pop with a default never raises for a missing key
        pruned.pop(unwanted, None)
    return pruned

In [None]:
def remove_kth_element(original_list, k):
    """Return a copy of ``original_list`` without its k-th element (1-based).

    The input list is left untouched. When ``k`` is outside ``1..len`` the
    result is simply an unmodified copy.
    """
    out_of_range = k <= 0 or k > len(original_list)
    trimmed = list(original_list)
    if not out_of_range:
        trimmed.pop(k - 1)  # k is 1-based, list indices are 0-based
    return trimmed

In [None]:
# Flatten the Semantic Scholar search results into a list of graph records:
# "node" records (Paper / Author / Journal / Venue) and "relationship"
# records (WRITES / PRINTS_ON / RELEASES_IN) in the shape the importers expect.
paper_metadata_json = []

for item in ss_metadata:
    paper_id = item.get('paperId')
    if paper_id is None:
        # Without a Semantic Scholar id the record cannot be keyed; skip it.
        continue

    # ---- paper node ------------------------------------------------------
    # 'externalIds' may be present but null, so `.get(key, {})` is not enough.
    external_ids = item.get('externalIds') or {}

    arxiv_id = external_ids.get('ArXiv')
    if arxiv_id is not None:
        arxiv_id = arxiv_id.replace('10.48550/arXiv.', '')
    item['arxivId'] = arxiv_id

    doi = external_ids.get('DOI')
    if doi is None:
        if arxiv_id is not None:
            # arXiv assigns DOIs of the form 10.48550/arXiv.<id>
            # https://info.arxiv.org/help/doi.html
            doi = f"10.48550/arXiv.{arxiv_id}"
        else:
            # Last resort: key the paper by its Semantic Scholar id.
            doi = paper_id
    item['DOI'] = doi

    # Null-safe access; at most the first 10 authors are kept.
    authors = (item.get('authors') or [])[:10]
    journal = item.get('journal') or {}
    venue = item.get('publicationVenue') or {}

    # NOTE(review): `item` still carries nested values (e.g. 'externalIds',
    # 'authors', 'journal'). Neo4j properties cannot hold maps, so the
    # importer may reject them — confirm whether they should be dropped or
    # flattened before import.
    paper_node = {
        "type": "node",
        "id": doi,
        "labels": ["Paper"],
        "properties": item
        }
    paper_metadata_json.append(paper_node)

    for idx, author in enumerate(authors):
        # ---- author node ---------------------------------------------
        author_id = author.get('authorId')
        if author_id is None:
            continue

        author_node = {
            "type": "node",
            "id": author_id,
            "labels": ["Author"],
            "properties": author}
        paper_metadata_json.append(author_node)

        # ---- author -> WRITES -> paper -------------------------------
        author_order = idx + 1
        # remove_kth_element is 1-based, so pass the 1-based position;
        # passing the 0-based idx would drop the wrong author.
        coauthors = remove_kth_element(authors, author_order)
        author_paper_relationship = {
            "type": "relationship",
            "relationshipType": "WRITES",
            "startNodeId": author_id,
            "endNodeId": doi,
            "properties": {'authorOrder': author_order, 'coauthors': coauthors}
            }
        paper_metadata_json.append(author_paper_relationship)

    journal_name = journal.get('name')
    if journal_name is not None:
        # ---- journal node --------------------------------------------
        journal_node = {
            "type": "node",
            "id": journal_name,
            "labels": ["Journal"],
            "properties": {"name": journal_name}}
        paper_metadata_json.append(journal_node)

        # arXiv-like journals would become hot-spot nodes with a huge number
        # of edges, so the relationship is excluded up front.
        if 'arxiv' not in journal_name.lower():
            # ---- paper -> PRINTS_ON -> journal -----------------------
            paper_journal_relationship = {
                "type": "relationship",
                "relationshipType": "PRINTS_ON",
                "startNodeId": doi,
                "endNodeId": journal_name,
                "properties": journal}
            paper_metadata_json.append(paper_journal_relationship)

    venue_id = venue.get('id')
    if venue_id is not None:
        # ---- venue node ----------------------------------------------
        venue_node = {
            "type": "node",
            "id": venue_id,
            "labels": ["Venue"],
            "properties": venue
            }
        paper_metadata_json.append(venue_node)

        # ---- paper -> RELEASES_IN -> venue ---------------------------
        # Same hot-spot exclusion as for journals; the venue name may be
        # missing, in which case it is treated as a non-arXiv venue.
        venue_name = venue.get('name') or ''
        if 'arxiv' not in venue_name.lower():
            paper_venue_relationship = {
                "type": "relationship",
                "relationshipType": "RELEASES_IN",
                "startNodeId": doi,
                "endNodeId": venue_id,
                "properties": {}}
            paper_metadata_json.append(paper_venue_relationship)

    

In [None]:
# Display the third-from-last generated record as a quick sanity check.
paper_metadata_json[-3]

## For Paper Entity

In [None]:
import json
from neo4j import GraphDatabase  # pip install neo4j https://github.com/neo4j/neo4j-python-driver
# import jsonschema  # pip install jsonschema https://github.com/python-jsonschema/jsonschema
# from jsonschema import Draft7Validator

import os

# Neo4j connection settings. Every value can be overridden through an
# environment variable; the literals are only local-development fallbacks.
# NOTE(review): a real-looking password is committed as the fallback below —
# move it out of the notebook (and rotate it) once NEO4J_PASSWORD is set.
neo4j_uri = os.environ.get("NEO4J_URI", "bolt://localhost:7687")  # Neo4j Bolt URI
neo4j_user = os.environ.get("NEO4J_USER", "neo4j")                # Neo4j user name
neo4j_password = os.environ.get("NEO4J_PASSWORD", "25216590")     # Neo4j password
database = os.environ.get("NEO4J_DATABASE", "paper-graph-v0-1")   # target database name

In [None]:
def import_data_to_neo4j_with_merge(processed_data, uri, username, password, database):
    """Import pre-processed graph records into Neo4j using MERGE (upsert).

    Args:
        processed_data: iterable of dicts. Node records carry ``type='node'``,
            ``id``, ``labels`` and optional ``properties``; relationship
            records carry ``type='relationship'``, ``relationshipType``,
            ``startNodeId``, ``endNodeId`` and optional ``properties``.
        uri: Bolt URI of the Neo4j server.
        username: Neo4j user name.
        password: Neo4j password.
        database: name of the target database.

    Nodes are merged on their ``id`` property; relationships are merged
    between the nodes whose ``id`` matches the start/end ids. Properties are
    passed as a single map parameter, so property names never end up in the
    query text and cannot clash with the ``id`` / ``startId`` / ``endId``
    parameters. Records of any other ``type`` are ignored.
    """

    def quote(identifier):
        # Labels and relationship types cannot be query parameters, so they
        # are interpolated as backtick-escaped identifiers.
        return "`" + str(identifier).replace("`", "``") + "`"

    driver = GraphDatabase.driver(uri, auth=(username, password))
    try:
        with driver.session(database=database) as session:
            for item in processed_data:
                properties = item.get('properties') or {}

                if item['type'] == 'node':
                    labels = "".join(f":{quote(label)}" for label in item['labels'])
                    # 'id' is the merge key; never let a property of the same
                    # name overwrite it.
                    node_props = {k: v for k, v in properties.items() if k != 'id'}
                    # A plain SET runs on both create and match, which is what
                    # identical ON CREATE / ON MATCH clauses would do.
                    cypher_query = (
                        f"MERGE (n{labels} {{id: $id}}) "
                        "SET n += $props "
                        "RETURN n"
                    )
                    session.run(cypher_query, {"id": item['id'], "props": node_props})

                elif item['type'] == 'relationship':
                    rel_type = quote(item['relationshipType'])
                    cypher_query = (
                        "MATCH (a {id: $startId}), (b {id: $endId}) "
                        f"MERGE (a)-[r:{rel_type}]->(b) "
                        "SET r += $props "
                        "RETURN r"
                    )
                    session.run(
                        cypher_query,
                        {
                            "startId": item['startNodeId'],
                            "endId": item['endNodeId'],
                            "props": properties,
                        },
                    )
    finally:
        # Release the connection pool even if a query fails.
        driver.close()

In [None]:
# Import the generated records with the MERGE-based (create-or-update) importer.
import_data_to_neo4j_with_merge(processed_data=paper_metadata_json, uri=neo4j_uri, username=neo4j_user, password=neo4j_password, database=database)

In [None]:
def import_data_to_neo4j(processed_data, uri, username, password, database):
    """Import pre-processed graph records into Neo4j using plain CREATE.

    Unlike a MERGE-based import, every call creates new nodes and
    relationships, so importing the same data twice produces duplicates.

    Args:
        processed_data: iterable of dicts. Node records carry ``type='node'``,
            ``id``, ``labels`` and optional ``properties``; relationship
            records carry ``type='relationship'``, ``relationshipType``,
            ``startNodeId``, ``endNodeId`` and optional ``properties``.
        uri: Bolt URI of the Neo4j server.
        username: Neo4j user name.
        password: Neo4j password.
        database: name of the target database.
    """

    def quote(identifier):
        # Labels and relationship types cannot be query parameters, so they
        # are interpolated as backtick-escaped identifiers.
        return "`" + str(identifier).replace("`", "``") + "`"

    driver = GraphDatabase.driver(uri, auth=(username, password))
    try:
        with driver.session(database=database) as session:
            for item in processed_data:
                print(item)
                properties = item.get('properties') or {}

                if item['type'] == 'node':
                    labels = "".join(f":{quote(label)}" for label in item['labels'])
                    # Properties go in as one map parameter; the record id is
                    # set last so it always wins over a property named 'id'.
                    cypher_query = (
                        f"CREATE (n{labels}) "
                        "SET n += $props "
                        "SET n.id = $id "
                        "RETURN n"
                    )
                    session.run(cypher_query, {"props": properties, "id": item['id']})

                elif item['type'] == 'relationship':
                    # The relationship type has to be interpolated into the
                    # query text; a "${...}" placeholder inside backticks is
                    # taken literally by Cypher.
                    rel_type = quote(item['relationshipType'])
                    cypher_query = (
                        "MATCH (startNode {id: $startNodeId}), (endNode {id: $endNodeId}) "
                        f"CREATE (startNode)-[r:{rel_type}]->(endNode) "
                        "SET r += $properties "
                        "RETURN r"
                    )
                    session.run(
                        cypher_query,
                        {
                            "startNodeId": item['startNodeId'],
                            "endNodeId": item['endNodeId'],
                            "properties": properties,
                        },
                    )
    finally:
        # Release the connection pool even if a query fails.
        driver.close()

In [None]:
# Import the same records with the CREATE-based importer.
# NOTE(review): running this after the MERGE-based import presumably
# duplicates the nodes — confirm before executing both cells.
import_data_to_neo4j(processed_data=paper_metadata_json, uri=neo4j_uri, username=neo4j_user, password=neo4j_password, database=database)

In [None]:
def load_json(raw_data):
    """Normalise JSON input to a list of records.

    Args:
        raw_data: a JSON string, a list, or a dict.

    Returns:
        list: a list is returned as-is, a dict is wrapped in a one-element
        list, and a JSON string is decoded first and then treated the same
        way. For an undecodable string, or any value that is neither a list
        nor a dict (including a JSON string that decodes to a scalar or
        null), an error is printed and ``[]`` is returned.
    """
    if isinstance(raw_data, str):
        try:
            raw_data = json.loads(raw_data)
        except json.JSONDecodeError:
            print("Error: Invalid JSON string provided.")
            return []
        # Fall through: the decoded value is validated like any other input,
        # so a decoded scalar no longer escapes as a non-list return value.

    if isinstance(raw_data, list):
        return raw_data
    if isinstance(raw_data, dict):
        return [raw_data]  # always hand back a list

    print("Error: Invalid JSON data type. Please provide a JSON string, list or dict.")
    return []


In [None]:
def infer_node_mapping_with_schema(json_data, top_n=None):
    """Infer the node part of a Neo4j mapping from JSON records.

    The first key of each record is used as the node type name; a record only
    contributes a node definition when the value under that key is a string
    and the node type has not been seen before. Every other key becomes a
    property whose type is derived from the Python type of its value.
    Despite the function name, no JSON Schema library is involved.

    Args:
        json_data: list of dict records.
        top_n (int, optional): only the first ``top_n`` records are
            inspected. ``None`` (default) inspects all records.

    Returns:
        dict: ``{node_label: {"node_label", "properties", "relationships"}}``.
    """

    def infer_type(value):
        # bool must be tested before int because bool is a subclass of int;
        # otherwise True/False would be reported as "integer".
        if isinstance(value, bool):
            return "boolean"
        if isinstance(value, int):
            return "integer"
        if isinstance(value, float):
            return "number"
        if isinstance(value, list):
            return "array"
        if isinstance(value, dict):
            return "object"
        return "string"  # default, also used for None

    node_mapping = {}
    seen_node_types = set()  # node types already turned into a definition
    elements_to_process = json_data[:top_n] if top_n is not None else json_data

    for record in elements_to_process:
        if not isinstance(record, dict):
            print("Warning: Skipping non-dictionary record:", record)
            continue

        # The first key of the record names the node type. next() with a
        # default keeps an empty dict from raising StopIteration.
        node_type_name = next(iter(record), None)
        if not node_type_name or not isinstance(node_type_name, str):
            print("Warning: Skipping record without keys:", record)
            continue

        if node_type_name in seen_node_types or not isinstance(record[node_type_name], str):
            continue
        seen_node_types.add(node_type_name)

        node_def = {
            "node_label": node_type_name.capitalize(),
            "properties": {
                key: {"neo4j_property": key, "type": infer_type(value)}
                for key, value in record.items()
                if key != node_type_name
            },
            "relationships": [],
        }
        node_mapping[node_def["node_label"]] = node_def

    return node_mapping

In [None]:
# Infer a node mapping from all Semantic Scholar records (top_n=None means no limit).
node_mapping = infer_node_mapping_with_schema(ss_metadata, top_n=None)

In [None]:
# Explicit mapping for the 'Paper' node. Each property maps onto a Neo4j
# property of the same name; only the declared type varies.
# NOTE(review): 'isOpenAccess' is declared 'integer' and 'openAccessPdf'
# 'string' — presumably these should be boolean / object; confirm.
node_mapping = {
    'Paper': {
        'node_label': 'Paperid',
        'properties': {
            prop_name: {'neo4j_property': prop_name, 'type': prop_type}
            for prop_name, prop_type in (
                ('externalIds', 'object'),
                ('corpusId', 'integer'),
                ('publicationVenue', 'object'),
                ('url', 'string'),
                ('title', 'string'),
                ('abstract', 'string'),
                ('venue', 'string'),
                ('year', 'integer'),
                ('referenceCount', 'integer'),
                ('citationCount', 'integer'),
                ('influentialCitationCount', 'integer'),
                ('isOpenAccess', 'integer'),
                ('openAccessPdf', 'string'),
                ('fieldsOfStudy', 'array'),
                ('s2FieldsOfStudy', 'array'),
                ('publicationTypes', 'array'),
                ('publicationDate', 'string'),
                ('journal', 'object'),
                ('citationStyles', 'object'),
                ('authors', 'array'),
            )
        },
        'relationships': [],
    },
}

In [None]:
class Json2Neo4j:
    """Hold JSON records plus a Neo4j driver and infer a node mapping from the records."""

    def __init__(self, input_data, neo4j_uri, neo4j_user, database, neo4j_password=None):
        """Load the input records and open a Neo4j driver.

        Args:
            input_data (str or list or dict): JSON string, list of JSON
                objects, or a single JSON object.
            neo4j_uri: Bolt URI of the Neo4j server.
            neo4j_user: Neo4j user name.
            database: name of the target database.
            neo4j_password (str, optional): Neo4j password. When omitted the
                previous behaviour is kept and ``database`` is sent as the
                password, which is almost certainly not what is wanted —
                callers should pass the real password.
        """
        self.json_data = self.load_json(input_data)
        self.neo4j_uri = neo4j_uri
        self.neo4j_user = neo4j_user
        self.database = database
        if neo4j_password is None:
            # Legacy fallback, kept only so existing four-argument calls
            # behave exactly as before.
            neo4j_password = database
        # NOTE: the former `driver is None` check was removed; presumably
        # GraphDatabase.driver raises on failure rather than returning None.
        self.driver = GraphDatabase.driver(uri=neo4j_uri, auth=(neo4j_user, neo4j_password))

    def close(self):
        """Close the underlying Neo4j driver, if one was created."""
        driver = getattr(self, "driver", None)
        if driver is not None:
            driver.close()
            self.driver = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        return False

    def load_json(self, raw_data):
        """Normalise JSON input to a list of records.

        A list is returned as-is, a dict is wrapped in a one-element list and
        a JSON string is decoded first. An undecodable string or any other
        value (including a JSON string that decodes to a scalar or null)
        prints an error and yields ``[]``.
        """
        if isinstance(raw_data, str):
            try:
                raw_data = json.loads(raw_data)
            except json.JSONDecodeError:
                print("Error: Invalid JSON string provided.")
                return []
            # Fall through so the decoded value is validated as well.

        if isinstance(raw_data, list):
            return raw_data
        if isinstance(raw_data, dict):
            return [raw_data]  # always hand back a list

        print("Error: Invalid JSON data type. Please provide a JSON string, list or dict.")
        return []

    @staticmethod
    def _infer_property_type(value):
        """Map a Python value to the type name stored in the mapping."""
        # bool must be tested before int because bool is a subclass of int.
        if isinstance(value, bool):
            return "boolean"
        if isinstance(value, int):
            return "integer"
        if isinstance(value, float):
            return "number"
        if isinstance(value, list):
            return "array"
        if isinstance(value, dict):
            return "object"
        return "string"  # default, also used for None

    def infer_node_mapping_with_schema(self, top_n=None):
        """Infer the node part of a Neo4j mapping from ``self.json_data``.

        The first key of each record is used as the node type name; a record
        only contributes a node definition when the value under that key is a
        string and the node type has not been seen before. Every other key
        becomes a property typed after the Python type of its value. Despite
        the method name, no JSON Schema library is involved.

        Args:
            top_n (int, optional): only the first ``top_n`` records are
                inspected. ``None`` (default) inspects all records.

        Returns:
            dict: ``{node_label: {"node_label", "properties", "relationships"}}``.
        """
        node_mapping = {}
        seen_node_types = set()  # node types already turned into a definition
        elements_to_process = self.json_data[:top_n] if top_n is not None else self.json_data

        for record in elements_to_process:
            if not isinstance(record, dict):
                print("Warning: Skipping non-dictionary record:", record)
                continue

            # First key names the node type; the default keeps an empty dict
            # from raising StopIteration.
            node_type_name = next(iter(record), None)
            if not node_type_name or not isinstance(node_type_name, str):
                print("Warning: Skipping record without keys:", record)
                continue

            if node_type_name in seen_node_types or not isinstance(record[node_type_name], str):
                continue
            seen_node_types.add(node_type_name)

            node_def = {
                "node_label": node_type_name.capitalize(),
                "properties": {
                    key: {"neo4j_property": key, "type": self._infer_property_type(value)}
                    for key, value in record.items()
                    if key != node_type_name
                },
                "relationships": [],
            }
            node_mapping[node_def["node_label"]] = node_def

        return node_mapping


def infer_relationship_mapping_with_schema(json_data, node_mapping, num_elements=None):
    """Infer relationship definitions and add them to an existing node mapping.

    For every record whose first key names a node type present in
    ``node_mapping``, the optional ``"relationship"`` list is inspected. A
    relation entry is used when it is a dict with a string ``"type"`` and its
    second key contains ``"target_"``; that second key names the target.
    ``node_mapping`` is updated in place and also returned.

    Args:
        json_data (str or list or dict): JSON string, list of JSON objects,
            or a single JSON object.
        node_mapping (dict): node mapping as produced by
            ``infer_node_mapping_with_schema``.
        num_elements (int, optional): only the first ``num_elements`` records
            are inspected. ``None`` (default) inspects all records.

    Returns:
        dict: the mapping including the inferred relationships. On invalid
        or empty input an error is printed and ``node_mapping`` is returned
        unchanged.
    """
    if isinstance(json_data, str):
        try:
            data = json.loads(json_data)
        except json.JSONDecodeError:
            print("Error: Invalid JSON string provided.")
            return node_mapping
    else:
        data = json_data

    # Validate after decoding so a JSON string holding a single object is
    # wrapped in a list just like a dict argument.
    if isinstance(data, dict):
        data = [data]
    elif not isinstance(data, list):
        print("Error: Invalid JSON data type. Please provide a JSON string, list or dict.")
        return node_mapping

    if not data:
        print("Error: Empty JSON data.")
        return node_mapping

    elements_to_process = data[:num_elements] if num_elements is not None else data

    for record in elements_to_process:
        if not isinstance(record, dict):
            continue

        # First key names the node type; the default keeps an empty record
        # from raising StopIteration.
        node_type_name = next(iter(record), None)
        if (
            not node_type_name
            or not isinstance(node_type_name, str)
            or not isinstance(record[node_type_name], str)
        ):
            continue

        node_label = node_type_name.capitalize()
        if node_label not in node_mapping or "relationships" not in node_mapping[node_label]:
            continue
        node_def = node_mapping[node_label]

        # Very basic convention: the "relationship" field lists the relations.
        relations = record.get("relationship")
        if not isinstance(relations, list):
            continue

        for relation in relations:
            if not isinstance(relation, dict) or not isinstance(relation.get("type"), str):
                continue

            # The target is identified by the SECOND key of the relation, so
            # entries with fewer than two keys cannot describe one.
            relation_keys = list(relation)
            if len(relation_keys) < 2:
                continue
            target_key = relation_keys[1]
            if not isinstance(target_key, str) or "target_" not in target_key:
                continue

            relationship_def = {
                "relationship_type": relation["type"].upper(),
                "target_node_label": target_key.replace("target_", "").capitalize(),
                "target_node_property_key": target_key,
                "source_property_key": node_type_name,
                "target_json_key": target_key,
                "json_relationship_type": relation["type"],
            }
            if relationship_def not in node_def["relationships"]:
                node_def["relationships"].append(relationship_def)

    return node_mapping


def infer_mapping_from_json_with_schema(json_data, num_elements=None):
    """Infer a complete Neo4j mapping (nodes and relationships) from JSON data.

    The input is first normalised with ``load_json`` so that a JSON string or
    a single dict is handled the same way as a list of records; previously a
    string was passed straight to node inference, which expects a list and
    therefore produced an empty mapping.

    Args:
        json_data (str or list or dict): JSON string, list of JSON objects,
            or a single JSON object.
        num_elements (int, optional): only the first ``num_elements`` records
            are used for inference. ``None`` (default) uses all records.

    Returns:
        dict: the inferred mapping with node and relationship definitions.
    """
    # Normalise once so both inference passes see the same list of records.
    records = load_json(json_data)
    node_mapping = infer_node_mapping_with_schema(records, num_elements)
    full_mapping = infer_relationship_mapping_with_schema(records, node_mapping, num_elements)
    return full_mapping