In [19]:
import nest_asyncio
# Patch asyncio so the `asyncio.run(...)` calls in later cells can be used
# from inside the notebook's already-running event loop.
nest_asyncio.apply()

In [37]:
import dspy
from dspy.teleprompt.random_search import BootstrapFewShotWithRandomSearch
from dspy.teleprompt.mipro_optimizer_v2 import MIPROv2
from dspy.evaluate import Evaluate
import asyncio
import os
import numpy as np
from dotenv import load_dotenv
from datasets import load_dataset
import logging
import pickle

from nano_graphrag._utils import compute_mdhash_id
from nano_graphrag.entity_extraction.extract import generate_dataset
from nano_graphrag.entity_extraction.module import EntityRelationshipExtractor
from nano_graphrag.entity_extraction.metric import relationship_similarity_metric, entity_recall_metric

In [21]:
# Pull variables from a local .env file into the process environment
# (the LM cell reads DEEPSEEK_API_KEY / DEEPSEEK_BASE_URL from os.environ).
load_dotenv()

# Quiet by default, but keep the "nano-graphrag" logger verbose so the
# per-example extraction statistics show up in the cell outputs.
logging.basicConfig(level=logging.WARNING)
logging.getLogger("nano-graphrag").setLevel(logging.DEBUG)

# Seed NumPy's global RNG; the np.random.permutation calls that pick the
# train/val/dev articles draw from it.
np.random.seed(1337)

# Directory that holds the generated datasets and the saved module.
WORKING_DIR = "./nano_graphrag_cache_finetune_entity_relationship_dspy"

In [22]:
# System prompt shared by both language-model clients defined below.
system_prompt = """
    You are a world-class AI system, capable of complex reasoning and reflection. 
    Reason through the query, and then provide your final response. 
    If you detect that you made a mistake in your reasoning at any point, correct yourself.
    Think carefully.
"""

# DeepSeek chat model reached through DSPy's OpenAI client. The API key and
# base URL are read from the environment; a missing variable raises KeyError.
lm = dspy.OpenAI(
    model="deepseek-chat",
    model_type="chat",
    api_key=os.environ["DEEPSEEK_API_KEY"],
    base_url=os.environ["DEEPSEEK_BASE_URL"],
    system_prompt=system_prompt,
    max_tokens=4096,
    temperature=1.0,
    top_p=1.0,
)

# Llama 3.1 served through Ollama; used later as the MIPROv2 task model.
llama_lm = dspy.OllamaLocal(
    model="llama3.1",
    model_type="chat",
    max_tokens=4096,
    system=system_prompt,
)

# Register the DeepSeek client as the LM in DSPy's global settings.
dspy.settings.configure(lm=lm)

In [23]:
# Number of articles sampled for each split.
train_len, val_len, dev_len = 20, 2, 3

# Output locations: one pickle path per split plus the JSON path for the module.
os.makedirs(WORKING_DIR, exist_ok=True)
entity_relationship_trainset_path = os.path.join(
    WORKING_DIR, "entity_relationship_extraction_news_trainset.pkl"
)
entity_relationship_valset_path = os.path.join(
    WORKING_DIR, "entity_relationship_extraction_news_valset.pkl"
)
entity_relationship_devset_path = os.path.join(
    WORKING_DIR, "entity_relationship_extraction_news_devset.pkl"
)
entity_relationship_module_path = os.path.join(
    WORKING_DIR, "entity_relationship_extraction_news.json"
)

# Source corpora, fetched with `datasets.load_dataset`.
fin_news = load_dataset("ashraq/financial-news-articles")
cnn_news = load_dataset("AyoubChLin/CNN_News_Articles_2011-2022")

# Shuffle row indices with NumPy's global RNG. The three calls consume the
# same random stream one after another, so their order must stay as-is to
# reproduce the same samples for a given seed.
fin_shuffled_indices = np.random.permutation(len(fin_news['train']))
cnn_train_shuffled_indices = np.random.permutation(len(cnn_news['train']))
cnn_test_shuffled_indices = np.random.permutation(len(cnn_news['test']))

# Train/val come from the CNN train/test splits; dev comes from financial news.
dev_data = fin_news['train'].select(fin_shuffled_indices[:dev_len])
val_data = cnn_news['test'].select(cnn_test_shuffled_indices[:val_len])
train_data = cnn_news['train'].select(cnn_train_shuffled_indices[:train_len])

  table = cls._concat_blocks(blocks, axis=0)


In [None]:
# Preview the text of the first two sampled training articles.
train_data['text'][:2]

In [None]:
# Show the text of every sampled validation article (only `val_len` of them).
val_data['text']

In [None]:
# Preview the text of the first two sampled dev articles.
dev_data['text'][:2]

In [32]:
# Map each sampled training article to a "chunk-"-prefixed id produced by
# compute_mdhash_id, in the {"id": {"content": text}} layout passed to
# generate_dataset. The prefix is a plain string: the original f-string had
# no placeholders.
train_chunks = {
    compute_mdhash_id(text, prefix="chunk-"): {"content": text}
    for text in train_data["text"]
}
# asyncio.run is usable inside the notebook's running event loop because
# nest_asyncio.apply() was called in the first cell. `filepath` is presumably
# where generate_dataset persists the generated examples.
trainset = asyncio.run(
    generate_dataset(chunks=train_chunks, filepath=entity_relationship_trainset_path)
)

DEBUG:nano-graphrag:Entities: 17 | Missed Entities: 15 | Total Entities: 32
DEBUG:nano-graphrag:Entities: 9 | Missed Entities: 7 | Total Entities: 16
DEBUG:nano-graphrag:Entities: 27 | Missed Entities: 21 | Total Entities: 48
DEBUG:nano-graphrag:Entities: 18 | Missed Entities: 10 | Total Entities: 28
DEBUG:nano-graphrag:Entities: 9 | Missed Entities: 9 | Total Entities: 18
DEBUG:nano-graphrag:Entities: 13 | Missed Entities: 6 | Total Entities: 19
DEBUG:nano-graphrag:Entities: 14 | Missed Entities: 7 | Total Entities: 21
DEBUG:nano-graphrag:Entities: 8 | Missed Entities: 10 | Total Entities: 18
DEBUG:nano-graphrag:Entities: 28 | Missed Entities: 6 | Total Entities: 34
DEBUG:nano-graphrag:Entities: 13 | Missed Entities: 5 | Total Entities: 18
DEBUG:nano-graphrag:Entities: 15 | Missed Entities: 8 | Total Entities: 23
DEBUG:nano-graphrag:Entities: 14 | Missed Entities: 5 | Total Entities: 19
DEBUG:nano-graphrag:Entities: 21 | Missed Entities: 5 | Total Entities: 26
DEBUG:nano-graphrag:Enti

In [None]:
# Print every relationship in the training examples whose order is 2.
for example in trainset:
    for relationship in example.relationships.context:
        if relationship.order != 2:
            continue
        print(relationship)

In [None]:
# Print every relationship in the training examples whose order is 3.
for example in trainset:
    for relationship in example.relationships.context:
        if relationship.order != 3:
            continue
        print(relationship)

In [None]:
# Inspect the first two relationships of the first training example.
trainset[0].relationships.context[:2]

In [26]:
# Map each sampled validation article to a "chunk-"-prefixed id produced by
# compute_mdhash_id, in the {"id": {"content": text}} layout passed to
# generate_dataset. The prefix is a plain string: the original f-string had
# no placeholders.
val_chunks = {
    compute_mdhash_id(text, prefix="chunk-"): {"content": text}
    for text in val_data["text"]
}
# asyncio.run is usable inside the notebook's running event loop because
# nest_asyncio.apply() was called in the first cell. The log output reports
# the examples being saved, presumably to `filepath`.
valset = asyncio.run(
    generate_dataset(chunks=val_chunks, filepath=entity_relationship_valset_path)
)

DEBUG:nano-graphrag:Entities: 21 | Missed Entities: 14 | Total Entities: 35
DEBUG:nano-graphrag:Entities: 10 | Missed Entities: 5 | Total Entities: 15
DEBUG:nano-graphrag:Relationships: 22 | Missed Relationships: 14 | Total Relationships: 36
DEBUG:nano-graphrag:Relationships: 10 | Missed Relationships: 5 | Total Relationships: 15
DEBUG:nano-graphrag:Direct Relationships: 36 | Second-order: 0 | Third-order: 0 | Total Relationships: 36
DEBUG:nano-graphrag:Direct Relationships: 12 | Second-order: 3 | Third-order: 0 | Total Relationships: 15
INFO:nano-graphrag:Saved 2 examples with keys: ['input_text', 'entities', 'relationships']


In [30]:
# Inspect the first two relationships of the first validation example.
valset[0].relationships.context[:2]

[Relationship(src_id='PORTUGAL', tgt_id='EURO 2016', description='Portugal qualified for the final of Euro 2016.', weight=0.9, order=1),
 Relationship(src_id='PORTUGAL', tgt_id='WALES', description='Portugal defeated Wales in the semifinal of Euro 2016.', weight=0.9, order=1)]

In [None]:
# Print every relationship in the validation examples whose order is 2.
for example in valset:
    for relationship in example.relationships.context:
        if relationship.order != 2:
            continue
        print(relationship)

In [None]:
# Print every relationship in the validation examples whose order is 3.
for example in valset:
    for relationship in example.relationships.context:
        if relationship.order != 3:
            continue
        print(relationship)

In [31]:
# Map each sampled dev article to a "chunk-"-prefixed id produced by
# compute_mdhash_id, in the {"id": {"content": text}} layout passed to
# generate_dataset. The prefix is a plain string: the original f-string had
# no placeholders.
dev_chunks = {
    compute_mdhash_id(text, prefix="chunk-"): {"content": text}
    for text in dev_data["text"]
}
# asyncio.run is usable inside the notebook's running event loop because
# nest_asyncio.apply() was called in the first cell. The log output reports
# the examples being saved, presumably to `filepath`.
devset = asyncio.run(
    generate_dataset(chunks=dev_chunks, filepath=entity_relationship_devset_path)
)

DEBUG:nano-graphrag:Entities: 27 | Missed Entities: 9 | Total Entities: 36
DEBUG:nano-graphrag:Entities: 14 | Missed Entities: 7 | Total Entities: 21
DEBUG:nano-graphrag:Entities: 7 | Missed Entities: 4 | Total Entities: 11
DEBUG:nano-graphrag:Relationships: 19 | Missed Relationships: 8 | Total Relationships: 27
DEBUG:nano-graphrag:Relationships: 14 | Missed Relationships: 8 | Total Relationships: 22
DEBUG:nano-graphrag:Relationships: 8 | Missed Relationships: 8 | Total Relationships: 16
DEBUG:nano-graphrag:Direct Relationships: 27 | Second-order: 0 | Third-order: 0 | Total Relationships: 27
DEBUG:nano-graphrag:Direct Relationships: 18 | Second-order: 4 | Third-order: 0 | Total Relationships: 22
DEBUG:nano-graphrag:Direct Relationships: 12 | Second-order: 4 | Third-order: 0 | Total Relationships: 16
INFO:nano-graphrag:Saved 3 examples with keys: ['input_text', 'entities', 'relationships']


In [None]:
# Inspect the first two relationships of the first dev example.
devset[0].relationships.context[:2]

In [None]:
# Print every relationship in the dev examples whose order is 2.
for example in devset:
    for relationship in example.relationships.context:
        if relationship.order != 2:
            continue
        print(relationship)

In [None]:
# Print every relationship in the dev examples whose order is 3.
for example in devset:
    for relationship in example.relationships.context:
        if relationship.order != 3:
            continue
        print(relationship)

In [33]:
# Instantiate the un-optimized extractor. It is evaluated as-is and also
# handed to the optimizers' compile() calls in later cells.
model = EntityRelationshipExtractor()
# Display the module's repr (its predictor and signature).
model

extractor.predictor = Predict(CombinedExtraction(input_text, entity_types -> entities, relationships
    instructions='Signature for extracting both entities and relationships from input text.'
    input_text = Field(annotation=str required=True json_schema_extra={'desc': 'The text to extract entities and relationships from.', '__dspy_field_type': 'input', 'prefix': 'Input Text:'})
    entity_types = Field(annotation=EntityTypes required=True json_schema_extra={'__dspy_field_type': 'input', 'prefix': 'Entity Types:', 'desc': '${entity_types}'})
    entities = Field(annotation=Entities required=True json_schema_extra={'desc': '\n        Format:\n        {\n            "context": [\n                {\n                    "entity_name": "ENTITY NAME",\n                    "entity_type": "ENTITY TYPE",\n                    "description": "Detailed description",\n                    "importance_score": 0.8\n                },\n                ...\n            ]\n        }\n        Each enti

In [39]:
# Score the un-optimized extractor on the dev set, once per metric.
metrics = [entity_recall_metric, relationship_similarity_metric]
# Keep what each evaluation run returns, keyed by metric name. Previously the
# return value was discarded, so the score was only readable from the
# progress-bar text and could not be compared with later runs.
baseline_scores = {}
for metric in metrics:
    evaluate = Evaluate(
        devset=devset,
        metric=metric,
        num_threads=os.cpu_count(),
        display_progress=True,
        display_table=5,
    )
    metric_name = getattr(metric, "__name__", type(metric).__name__)
    baseline_scores[metric_name] = evaluate(model)
# Last expression of the cell: show the collected scores.
baseline_scores

DEBUG:nano-graphrag:Entities: 13 | Missed Entities: 10 | Total Entities: 23
DEBUG:nano-graphrag:Entities: 22 | Missed Entities: 14 | Total Entities: 36
  0%|          | 0/3 [00:00<?, ?it/s]DEBUG:nano-graphrag:Entities: 9 | Missed Entities: 2 | Total Entities: 11
DEBUG:nano-graphrag:Relationships: 11 | Missed Relationships: 9 | Total Relationships: 20
DEBUG:nano-graphrag:Relationships: 16 | Missed Relationships: 14 | Total Relationships: 30
DEBUG:nano-graphrag:Relationships: 9 | Missed Relationships: 3 | Total Relationships: 12
DEBUG:nano-graphrag:Direct Relationships: 18 | Second-order: 2 | Third-order: 0 | Total Relationships: 20
DEBUG:nano-graphrag:Direct Relationships: 30 | Second-order: 0 | Third-order: 0 | Total Relationships: 30
Average Metric: 1.6150793650793651 / 2  (80.8):  33%|███▎      | 1/3 [00:00<00:00, 619.63it/s]DEBUG:nano-graphrag:Direct Relationships: 10 | Second-order: 2 | Third-order: 0 | Total Relationships: 12
Average Metric: 2.342352092352092 / 3  (78.1): 100%|███

Unnamed: 0,input_text,example_entities,example_relationships,pred_entities,pred_relationships,entity_recall_metric
0,"As students from Marjory Stoneman Douglas High School confront lawmakers with demands to restrict sales of assault rifles, there were warnings by the president of...","context=[Entity(entity_name='MARJORY STONEMAN DOUGLAS HIGH SCHOOL', entity_type='ORGANIZATION', description='A high school in Florida where a mass shooting occurred.', importance_score=0.9), Entity(entity_name='NIKOLAS CRUZ', entity_type='PERSON', description='The gunman who carried out...","context=[Relationship(src_id='MARJORY STONEMAN DOUGLAS HIGH SCHOOL', tgt_id='NIKOLAS CRUZ', description='Nikolas Cruz carried out a mass shooting at Marjory Stoneman Douglas High School.', weight=0.9, order=1), Relationship(src_id='NIKOLAS CRUZ', tgt_id='FLORIDA',...","context=[Entity(entity_name='MARJORY STONEMAN DOUGLAS HIGH SCHOOL', entity_type='ORGANIZATION', description='A high school in Florida where a mass shooting occurred.', importance_score=0.9), Entity(entity_name='NIKOLAS CRUZ', entity_type='PERSON', description='The gunman who carried out...","context=[Relationship(src_id='NIKOLAS CRUZ', tgt_id='MARJORY STONEMAN DOUGLAS HIGH SCHOOL', description='Nikolas Cruz carried out a mass shooting at Marjory Stoneman Douglas High School.', weight=0.9, order=1), Relationship(src_id='LAURENZO PRADO', tgt_id='MARJORY...",✔️ [0.8055555555555556]
1,"From ferrying people to and from their place of work to transporting nuclear waste and coal, railways are not only an integral part of 21st...","context=[Entity(entity_name='RAILWAYS', entity_type='VEHICLE', description='Transportation system used for ferrying people and transporting nuclear waste and coal.', importance_score=0.9), Entity(entity_name='BELGIUM', entity_type='LOCATION', description='Country where a business is looking to innovate...","context=[Relationship(src_id='RAILNOVA', tgt_id='BRUSSELS', description='Railnova is based in Brussels.', weight=0.9, order=1), Relationship(src_id='RAILNOVA', tgt_id='DEUTSCHE BAHN', description='Railnova serves Deutsche Bahn as a client.', weight=0.8, order=1), Relationship(src_id='RAILNOVA', tgt_id='SNCF', description='Railnova serves...","context=[Entity(entity_name='RAILWAYS', entity_type='VEHICLE', description='A mode of transportation that involves trains running on tracks, used for various purposes including passenger and cargo transport.', importance_score=0.9), Entity(entity_name='BELGIUM', entity_type='LOCATION', description='A...","context=[Relationship(src_id='RAILNOVA', tgt_id='DEUTSCHE BAHN', description='Railnova provides innovative technology solutions to Deutsche Bahn, a German railway company.', weight=0.8, order=1), Relationship(src_id='RAILNOVA', tgt_id='SNCF', description='Railnova offers its technology services to...",✔️ [0.8095238095238095]
2,Jan 22 (Reuters) - Shanghai Stock Exchange Filing * SHOWS BLOCK TRADE OF YONGHUI SUPERSTORES Co LTd's 166.3 MILLION SHARES INVOLVING 1.63 BILLION YUAN ($254.68...,"context=[Entity(entity_name='YONGHUI SUPERSTORES', entity_type='ORGANIZATION', description='A company involved in a block trade of its shares.', importance_score=0.9), Entity(entity_name='SHANGHAI STOCK EXCHANGE', entity_type='ORGANIZATION', description='The stock exchange where the block trade...","context=[Relationship(src_id='YONGHUI SUPERSTORES', tgt_id='SHANGHAI STOCK EXCHANGE', description=""YONGHUI SUPERSTORES' shares were traded on the SHANGHAI STOCK EXCHANGE."", weight=0.9, order=1), Relationship(src_id='YONGHUI SUPERSTORES', tgt_id='166.3 MILLION SHARES', description='YONGHUI SUPERSTORES was...","context=[Entity(entity_name='YONGHUI SUPERSTORES', entity_type='ORGANIZATION', description='A company involved in a block trade of its shares.', importance_score=0.9), Entity(entity_name='SHANGHAI STOCK EXCHANGE', entity_type='ORGANIZATION', description='The stock exchange where the block trade...","context=[Relationship(src_id='YONGHUI SUPERSTORES', tgt_id='166.3 MILLION SHARES', description='YONGHUI SUPERSTORES was involved in a block trade of 166.3 million shares.', weight=0.9, order=1), Relationship(src_id='166.3 MILLION SHARES', tgt_id='1.63 BILLION YUAN',...",✔️ [0.7272727272727273]


DEBUG:nano-graphrag:Entities: 22 | Missed Entities: 14 | Total Entities: 36
DEBUG:nano-graphrag:Entities: 13 | Missed Entities: 10 | Total Entities: 23
  0%|          | 0/3 [00:00<?, ?it/s]DEBUG:nano-graphrag:Entities: 9 | Missed Entities: 2 | Total Entities: 11
DEBUG:nano-graphrag:Relationships: 16 | Missed Relationships: 14 | Total Relationships: 30
DEBUG:nano-graphrag:Relationships: 11 | Missed Relationships: 9 | Total Relationships: 20
DEBUG:nano-graphrag:Relationships: 9 | Missed Relationships: 3 | Total Relationships: 12
DEBUG:nano-graphrag:Direct Relationships: 30 | Second-order: 0 | Third-order: 0 | Total Relationships: 30
DEBUG:nano-graphrag:Direct Relationships: 18 | Second-order: 2 | Third-order: 0 | Total Relationships: 20
DEBUG:nano-graphrag:Direct Relationships: 10 | Second-order: 2 | Third-order: 0 | Total Relationships: 12

[A

Batches: 100%|██████████| 1/1 [00:00<00:00, 39.39it/s]

Batches: 100%|██████████| 1/1 [00:00<00:00, 35.07it/s]


Batches: 100%|██████████| 1/1 

Unnamed: 0,input_text,example_entities,example_relationships,pred_entities,pred_relationships,relationship_similarity_metric
0,"As students from Marjory Stoneman Douglas High School confront lawmakers with demands to restrict sales of assault rifles, there were warnings by the president of...","context=[Entity(entity_name='MARJORY STONEMAN DOUGLAS HIGH SCHOOL', entity_type='ORGANIZATION', description='A high school in Florida where a mass shooting occurred.', importance_score=0.9), Entity(entity_name='NIKOLAS CRUZ', entity_type='PERSON', description='The gunman who carried out...","context=[Relationship(src_id='MARJORY STONEMAN DOUGLAS HIGH SCHOOL', tgt_id='NIKOLAS CRUZ', description='Nikolas Cruz carried out a mass shooting at Marjory Stoneman Douglas High School.', weight=0.9, order=1), Relationship(src_id='NIKOLAS CRUZ', tgt_id='FLORIDA',...","context=[Entity(entity_name='MARJORY STONEMAN DOUGLAS HIGH SCHOOL', entity_type='ORGANIZATION', description='A high school in Florida where a mass shooting occurred.', importance_score=0.9), Entity(entity_name='NIKOLAS CRUZ', entity_type='PERSON', description='The gunman who carried out...","context=[Relationship(src_id='NIKOLAS CRUZ', tgt_id='MARJORY STONEMAN DOUGLAS HIGH SCHOOL', description='Nikolas Cruz carried out a mass shooting at Marjory Stoneman Douglas High School.', weight=0.9, order=1), Relationship(src_id='LAURENZO PRADO', tgt_id='MARJORY...",✔️ [0.946203351020813]
1,"From ferrying people to and from their place of work to transporting nuclear waste and coal, railways are not only an integral part of 21st...","context=[Entity(entity_name='RAILWAYS', entity_type='VEHICLE', description='Transportation system used for ferrying people and transporting nuclear waste and coal.', importance_score=0.9), Entity(entity_name='BELGIUM', entity_type='LOCATION', description='Country where a business is looking to innovate...","context=[Relationship(src_id='RAILNOVA', tgt_id='BRUSSELS', description='Railnova is based in Brussels.', weight=0.9, order=1), Relationship(src_id='RAILNOVA', tgt_id='DEUTSCHE BAHN', description='Railnova serves Deutsche Bahn as a client.', weight=0.8, order=1), Relationship(src_id='RAILNOVA', tgt_id='SNCF', description='Railnova serves...","context=[Entity(entity_name='RAILWAYS', entity_type='VEHICLE', description='A mode of transportation that involves trains running on tracks, used for various purposes including passenger and cargo transport.', importance_score=0.9), Entity(entity_name='BELGIUM', entity_type='LOCATION', description='A...","context=[Relationship(src_id='RAILNOVA', tgt_id='DEUTSCHE BAHN', description='Railnova provides innovative technology solutions to Deutsche Bahn, a German railway company.', weight=0.8, order=1), Relationship(src_id='RAILNOVA', tgt_id='SNCF', description='Railnova offers its technology services to...",✔️ [0.9310485124588013]
2,Jan 22 (Reuters) - Shanghai Stock Exchange Filing * SHOWS BLOCK TRADE OF YONGHUI SUPERSTORES Co LTd's 166.3 MILLION SHARES INVOLVING 1.63 BILLION YUAN ($254.68...,"context=[Entity(entity_name='YONGHUI SUPERSTORES', entity_type='ORGANIZATION', description='A company involved in a block trade of its shares.', importance_score=0.9), Entity(entity_name='SHANGHAI STOCK EXCHANGE', entity_type='ORGANIZATION', description='The stock exchange where the block trade...","context=[Relationship(src_id='YONGHUI SUPERSTORES', tgt_id='SHANGHAI STOCK EXCHANGE', description=""YONGHUI SUPERSTORES' shares were traded on the SHANGHAI STOCK EXCHANGE."", weight=0.9, order=1), Relationship(src_id='YONGHUI SUPERSTORES', tgt_id='166.3 MILLION SHARES', description='YONGHUI SUPERSTORES was...","context=[Entity(entity_name='YONGHUI SUPERSTORES', entity_type='ORGANIZATION', description='A company involved in a block trade of its shares.', importance_score=0.9), Entity(entity_name='SHANGHAI STOCK EXCHANGE', entity_type='ORGANIZATION', description='The stock exchange where the block trade...","context=[Relationship(src_id='YONGHUI SUPERSTORES', tgt_id='166.3 MILLION SHARES', description='YONGHUI SUPERSTORES was involved in a block trade of 166.3 million shares.', weight=0.9, order=1), Relationship(src_id='166.3 MILLION SHARES', tgt_id='1.63 BILLION YUAN',...",✔️ [0.9334976673126221]


In [None]:
# Random-search few-shot optimizer scored with the relationship-similarity metric.
optimizer = BootstrapFewShotWithRandomSearch(
    metric=relationship_similarity_metric,
    max_bootstrapped_demos=3,
    max_labeled_demos=5,
    num_candidate_programs=4,
    num_threads=os.cpu_count(),
)

# Compile the baseline extractor against the train and validation sets.
rs_model = optimizer.compile(
    model,
    trainset=trainset,
    valset=valset,
)

# Display the compiled program.
rs_model

In [None]:
# Score the random-search-optimized extractor on the dev set, once per metric.
metrics = [entity_recall_metric, relationship_similarity_metric]
# Keep what each evaluation run returns, keyed by metric name. Previously the
# return value was discarded, so the score was only readable from the
# progress-bar text and could not be compared with other runs.
rs_scores = {}
for metric in metrics:
    evaluate = Evaluate(
        devset=devset,
        metric=metric,
        num_threads=os.cpu_count(),
        display_progress=True,
        display_table=5,
    )
    metric_name = getattr(metric, "__name__", type(metric).__name__)
    rs_scores[metric_name] = evaluate(rs_model)
# Last expression of the cell: show the collected scores.
rs_scores

In [None]:
# MIPROv2 optimizer: DeepSeek is the prompt model, the Ollama-served Llama
# model is the task model.
# NOTE(review): the evaluation cells run with whichever LM is configured in
# dspy.settings (the DeepSeek client), not `llama_lm` — confirm that
# optimizing against one model and evaluating with another is intended.
optimizer = MIPROv2(
    metric=relationship_similarity_metric,
    prompt_model=lm,
    task_model=llama_lm,
    num_candidates=4,
    init_temperature=1.0,
)

# Compile the baseline extractor against the train and validation sets.
miprov2_model = optimizer.compile(
    model,
    trainset=trainset,
    valset=valset,
    num_batches=5,
    max_labeled_demos=5,
    max_bootstrapped_demos=3,
)

# Display the compiled program.
miprov2_model

In [None]:
# Score the MIPROv2-optimized extractor on the dev set, once per metric.
metrics = [entity_recall_metric, relationship_similarity_metric]
# Keep what each evaluation run returns, keyed by metric name. Previously the
# return value was discarded, so the score was only readable from the
# progress-bar text and could not be compared with other runs.
miprov2_scores = {}
for metric in metrics:
    evaluate = Evaluate(
        devset=devset,
        metric=metric,
        num_threads=os.cpu_count(),
        display_progress=True,
        display_table=5,
    )
    metric_name = getattr(metric, "__name__", type(metric).__name__)
    miprov2_scores[metric_name] = evaluate(miprov2_model)
# Last expression of the cell: show the collected scores.
miprov2_scores