### Ref: https://medium.com/neo4j/implementing-from-local-to-global-graphrag-with-neo4j-and-langchain-constructing-the-graph-73924cc5bab4
### Ref(Github): https://github.com/tomasonjo/blogs/blob/master/llm/ms_graphrag.ipynb

In [57]:
import os

from langchain_community.graphs import Neo4jGraph

os.environ["NEO4J_URI"] = "bolt://localhost:7687"
os.environ["NEO4J_USERNAME"] = "neo4j"
os.environ["NEO4J_PASSWORD"] = "2wsx3edc"

database = os.environ.get('NEO4J_DATABASE')
graph = Neo4jGraph(database=database)

In [58]:
from langchain_community.vectorstores import Neo4jVector

from langchain_openai import AzureOpenAIEmbeddings

embedding = AzureOpenAIEmbeddings(
    model="text-embedding-ada-002",
    azure_endpoint='https://sales-chatbot-llm.openai.azure.com/openai/deployments/embedding-ada-002/embeddings?api-version=2023-05-15',
    azure_deployment='text-embedding-ada-002',
    openai_api_version='2023-05-15'
)

In [None]:
from langchain_community.vectorstores import Neo4jVector
# ! pip3 install -U langchain-huggingface
import os
os.environ['SENTENCE_TRANSFORMERS_HOME'] = '/storage/models/embedding_models'
from langchain_huggingface import HuggingFaceEmbeddings
# Choose from https://huggingface.co/spaces/mteb/leaderboard

# embedding = HuggingFaceEmbeddings(model_name="lier007/xiaobu-embedding-v2")

model_path = os.path.join(os.environ['SENTENCE_TRANSFORMERS_HOME'], 'models--lier007--xiaobu-embedding-v2/snapshots/ee0b4ecdf5eb449e8240f2e3de2e10eeae877691')
embedding = HuggingFaceEmbeddings(model_name=model_path)

In [59]:
node_label = '__Entity__'
embedding_node_property = 'embedding'
fetch_query = (
    f"MATCH (n:`{node_label}`) "
    f"WHERE n.{embedding_node_property} IS null "
    "AND any(k in $props WHERE n[k] IS NOT null) "
    f"RETURN elementId(n) AS id, reduce(str='',"
    "k IN $props | str + '\\n' + k + ':' + coalesce(n[k], '')) AS text "
    "LIMIT 1000"
)
datas = graph.query(fetch_query, params={"props": ['id', 'description']})
datas
import sys
sys.path.append('..')
from tools.TokenCounter import num_tokens_from_string

tokens_num = 0
for data in datas:
    tokens_num += num_tokens_from_string(data['text'])
tokens_num

789

In [60]:
vector = Neo4jVector.from_existing_graph(
    embedding,
    index_name='embedding',
    node_label='__Entity__',
    text_node_properties=['id', 'description'],
    embedding_node_property='embedding'
)

In [22]:
# ! pip3 install graphdatascience

In [61]:
from graphdatascience import GraphDataScience 

gds = GraphDataScience( 
    os.environ[ "NEO4J_URI" ], 
    auth=(os.environ[ "NEO4J_USERNAME" ], os.environ[ "NEO4J_PASSWORD" ]) 
)

In [62]:
gds.graph.drop("entities")

graphName                                                         entities
database                                                             neo4j
databaseLocation                                                     local
memoryUsage                                                               
sizeInBytes                                                             -1
nodeCount                                                               25
relationshipCount                                                       85
configuration            {'relationshipProjection': {'__ALL__': {'aggre...
density                                                           0.141667
creationTime                           2024-09-09T04:45:12.577411763+00:00
modificationTime                       2024-09-09T04:45:13.928689638+00:00
schema                   {'graphProperties': {}, 'nodes': {'__Entity__'...
schemaWithOrientation    {'graphProperties': {}, 'nodes': {'__Entity__'...
Name: 0, dtype: object

In [63]:

G, result = gds.graph.project(
    "entities",                   # Graph name
    "__Entity__",                 # Node projection
    "*",                          # Relationship projection
    nodeProperties=["embedding"]  # Configuration parameters
)

In [64]:
similarity_threshold = 0.95

gds.knn.mutate(
  G,
  nodeProperties=['embedding'],
  mutateRelationshipType= 'SIMILAR',
  mutateProperty= 'score',
  similarityCutoff=similarity_threshold
)

ranIterations                                                             3
nodePairsConsidered                                                    6519
didConverge                                                            True
preProcessingMillis                                                       0
computeMillis                                                            24
mutateMillis                                                             11
postProcessingMillis                                                      0
nodesCompared                                                            26
relationshipsWritten                                                     54
similarityDistribution    {'min': 0.9500160217285156, 'p5': 0.9503250122...
configuration             {'mutateProperty': 'score', 'jobId': 'de751ef3...
Name: 0, dtype: object

In [65]:
gds.wcc.write(
    G,
    writeProperty="wcc",
    relationshipTypes=["SIMILAR"]
)

writeMillis                                                             46
nodePropertiesWritten                                                   26
componentCount                                                           8
componentDistribution    {'min': 1, 'p5': 1, 'max': 7, 'p999': 7, 'p99'...
postProcessingMillis                                                    11
preProcessingMillis                                                      0
computeMillis                                                            1
configuration            {'writeProperty': 'wcc', 'jobId': 'd95c8b7f-bd...
Name: 0, dtype: object

In [66]:
word_edit_distance = 3
potential_duplicate_candidates = graph.query(
    """MATCH (e:`__Entity__`)
    WHERE size(e.id) > 3 // longer than 3 characters
    WITH e.wcc AS community, collect(e) AS nodes, count(*) AS count
    WHERE count > 1
    UNWIND nodes AS node
    // Add text distance
    WITH distinct
      [n IN nodes WHERE apoc.text.distance(toLower(node.id), toLower(n.id)) < $distance 
                  OR node.id CONTAINS n.id | n.id] AS intermediate_results
    WHERE size(intermediate_results) > 1
    WITH collect(intermediate_results) AS results
    // combine groups together if they share elements
    UNWIND range(0, size(results)-1, 1) as index
    WITH results, index, results[index] as result
    WITH apoc.coll.sort(reduce(acc = result, index2 IN range(0, size(results)-1, 1) |
            CASE WHEN index <> index2 AND
                size(apoc.coll.intersection(acc, results[index2])) > 0
                THEN apoc.coll.union(acc, results[index2])
                ELSE acc
            END
    )) as combinedResult
    WITH distinct(combinedResult) as combinedResult
    // extra filtering
    WITH collect(combinedResult) as allCombinedResults
    UNWIND range(0, size(allCombinedResults)-1, 1) as combinedResultIndex
    WITH allCombinedResults[combinedResultIndex] as combinedResult, combinedResultIndex, allCombinedResults
    WHERE NOT any(x IN range(0,size(allCombinedResults)-1,1)
        WHERE x <> combinedResultIndex
        AND apoc.coll.containsAll(allCombinedResults[x], combinedResult)
    )
    RETURN combinedResult
    """, params={'distance': word_edit_distance})
potential_duplicate_candidates

[]

In [50]:
import os

from langchain_experimental.graph_transformers import LLMGraphTransformer
# from langchain_openai import ChatOpenAI
from langchain_openai import AzureChatOpenAI

llm = AzureChatOpenAI(
    azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT"],
    azure_deployment=os.environ["AZURE_OPENAI_DEPLOYMENT_NAME"],
    openai_api_version=os.environ["AZURE_OPENAI_API_VERSION"],
    temperature=0
)

In [26]:
from langchain_experimental.llms.ollama_functions import OllamaFunctions
model_name = 'qwen2:72b-instruct-q8_0'
llm = OllamaFunctions(model=model_name, temperature=0)
llm

  warn_deprecated(


OllamaFunctions(model='qwen2:72b-instruct-q8_0')

In [51]:
from langchain_core.prompts import ChatPromptTemplate

system_prompt = """You are a data processing assistant. Your task is to identify duplicate entities in a list and decide which of them should be merged.
The entities might be slightly different in format or content, but essentially refer to the same thing. Use your analytical skills to determine duplicates.

Here are the rules for identifying duplicates:
1. Entities with minor typographical differences should be considered duplicates.
2. Entities with different formats but the same content should be considered duplicates.
3. Entities that refer to the same real-world object or concept, even if described differently, should be considered duplicates.
4. If it refers to different numbers, dates, or products, do not merge results
"""
user_template = """
Here is the list of entities to process:
{entities}

Please identify duplicates, merge them, and provide the merged list.
"""

from typing import List, Optional
from pydantic import BaseModel, Field

class DuplicateEntities(BaseModel):
    entities: List[str] = Field(
        description="Entities that represent the same object or real-world entity and should be merged"
    )


class Disambiguate(BaseModel):
    merge_entities: Optional[List[DuplicateEntities]] = Field(
        description="Lists of entities that represent the same object or real-world entity and should be merged"
    )


extraction_llm = llm.with_structured_output(
    Disambiguate
)

extraction_prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            system_prompt,
        ),
        (
            "human",
            user_template,
        ),
    ]
)

In [52]:
extraction_chain = extraction_prompt | extraction_llm

In [53]:
from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm import tqdm

def resolve_and_merge_entities_with_llm(potential_duplicate_candidates, max_retry=0) -> List[List[str]]:
    '''
    parmas:
        potential_duplicate_candidates(List[dict['combinedResult': List[str]]): 有可能需要合併的清單 
                                                                                e.g.[{'combinedResult': ['土地銀行', '第一銀行']}]
        max_retry: 最多嘗試次數, 假設為2, 則最多遞迴執行 2+1=3次
    return:
        merged_entities (List[dict['combinedResult': List[str]]) : LLM 確認過需要合併的清單
                                                                    e.g.[{'combinedResult': ['土地銀行', '第一銀行']}]
    '''
    def entity_resolution(entities: List[str]) -> Optional[List[List[str]]]:
        return [
            el.entities
            for el in extraction_chain.invoke({"entities": entities}).merge_entities
        ]
        
    merged_entities_result = []
    merged_future_map = {}
    futures = []
    merged_failds = []
    with ThreadPoolExecutor(max_workers=5) as executor:
        # Submitting all tasks and creating a list of future objects
        for el in potential_duplicate_candidates:
            future = executor.submit(entity_resolution, el['combinedResult'])
            merged_future_map[future] = el
            futures.append(future)
        for future in tqdm(
            as_completed(futures), total=len(futures), desc="Processing documents"
        ):
            try:
                to_merge = future.result()
                if to_merge:
                    merged_entities_result.extend(to_merge)
            except Exception as e:
                el = merged_future_map[future]
                print(f'process element faild!:{el['combinedResult']}, error:\n{e}')
                merged_failds.append(el)
    if len(merged_failds) > 0 and max_retry > 0:
        merged_entities_result.extend(resolve_and_merge_entities_with_llm(merged_failds, max_retry=max_retry-1))
    return merged_entities_result
merged_entities = resolve_and_merge_entities_with_llm(potential_duplicate_candidates, max_retry=1)

Processing documents: 100%|██████████| 1/1 [00:00<00:00,  1.10it/s]


In [55]:
merged_entities

[]

In [56]:

count = 0
for merge_entity in merged_entities:
    merge_entities = [merge_entity]
    results = graph.query("""
  UNWIND $data AS candidates
  CALL {
    WITH candidates
    MATCH (e:__Entity__) WHERE e.id IN candidates
    RETURN collect(e) AS nodes
  }
  CALL apoc.refactor.mergeNodes(nodes, {
      properties: {
        description: 'combine',
        `.*`: 'discard'
      }
    })
  YIELD node
  RETURN count(*)
  """, params={"data": merge_entities})
    count += results[0]['count(*)']
count

0