In [None]:
# Later cells call asyncio.run(...) directly. Inside a notebook an event loop
# is normally already running, so asyncio is patched here to tolerate nested
# use. This cell has to run before any of those cells.
import nest_asyncio
nest_asyncio.apply()

In [None]:
import dspy
from dspy.teleprompt.mipro_optimizer_v2 import MIPROv2
from dspy.evaluate import Evaluate
import asyncio
import os
import numpy as np
from dotenv import load_dotenv
from datasets import load_dataset
import logging

from nano_graphrag._utils import compute_mdhash_id
from nano_graphrag.entity_extraction.extract import generate_dataset
from nano_graphrag.entity_extraction.module import EntityRelationshipExtractor
from nano_graphrag.entity_extraction.metric import relationship_recall_metric, relationship_similarity_metric, entity_recall_metric

In [None]:
# Directory holding every artifact this notebook writes to disk.
WORKING_DIR = "./nano_graphrag_cache_finetune_entity_relationship_dspy"

# Seed NumPy's global RNG: the np.random.permutation calls that pick the
# sampled articles draw from it.
np.random.seed(1337)

# Load variables from a local .env file (if present) into the environment so
# the os.environ["DEEPSEEK_*"] lookups made when the LM is built can succeed.
load_dotenv()

# Root logging stays at WARNING; only the "nano-graphrag" logger is verbose.
logging.basicConfig(level=logging.WARNING)
nano_graphrag_logger = logging.getLogger("nano-graphrag")
nano_graphrag_logger.setLevel(logging.DEBUG)

In [None]:
# Instruction text handed to both language models below.
system_prompt = """
    You are a world-class AI system, capable of complex reasoning and reflection. 
    Reason through the query, and then provide your final response. 
    If you detect that you made a mistake in your reasoning at any point, correct yourself.
    Think carefully.
"""

# DeepSeek chat model reached through dspy's OpenAI client. The API key and
# base URL must be present in the environment; a missing one raises KeyError.
deepseek_settings = {
    "model": "deepseek-chat",
    "model_type": "chat",
    "api_key": os.environ["DEEPSEEK_API_KEY"],
    "base_url": os.environ["DEEPSEEK_BASE_URL"],
    "system_prompt": system_prompt,
    "temperature": 0.3,
    "top_p": 1.0,
    "max_tokens": 4096,
}
lm = dspy.OpenAI(**deepseek_settings)

# llama3.1 served by Ollama. Note that this client takes the prompt under the
# keyword `system`, not `system_prompt`.
ollama_settings = {
    "model": "llama3.1",
    "model_type": "chat",
    "system": system_prompt,
    "max_tokens": 4096,
}
llama_lm = dspy.OllamaLocal(**ollama_settings)

# DeepSeek is the default LM for every dspy call that does not override it.
dspy.settings.configure(lm=lm)

In [None]:
os.makedirs(WORKING_DIR, exist_ok=True)

# Number of articles sampled for the train / validation / dev splits.
train_len, val_len, dev_len = 20, 2, 3

# Artifact locations inside WORKING_DIR: one pickle per split, plus a JSON
# path for the extraction module.
(
    entity_relationship_trainset_path,
    entity_relationship_valset_path,
    entity_relationship_devset_path,
    entity_relationship_module_path,
) = (
    os.path.join(WORKING_DIR, artifact_name)
    for artifact_name in (
        "entity_relationship_extraction_news_trainset.pkl",
        "entity_relationship_extraction_news_valset.pkl",
        "entity_relationship_extraction_news_devset.pkl",
        "entity_relationship_extraction_news.json",
    )
)

fin_news = load_dataset("ashraq/financial-news-articles")
cnn_news = load_dataset("AyoubChLin/CNN_News_Articles_2011-2022")

# One shuffled index array per source split. The draw order (financial news,
# CNN train, CNN test) is fixed on purpose: all three consume the same global
# NumPy RNG, so reordering them would change which articles get sampled.
fin_shuffled_indices, cnn_train_shuffled_indices, cnn_test_shuffled_indices = (
    np.random.permutation(len(source_split))
    for source_split in (fin_news["train"], cnn_news["train"], cnn_news["test"])
)

# Train and validation come from CNN (train / test splits respectively);
# the dev set comes from the financial-news dataset.
train_data = cnn_news["train"].select(cnn_train_shuffled_indices[:train_len])
val_data = cnn_news["test"].select(cnn_test_shuffled_indices[:val_len])
dev_data = fin_news["train"].select(fin_shuffled_indices[:dev_len])

In [None]:
# Preview the text of the first two sampled training articles.
train_data['text'][:2]

In [None]:
# Show the text of every sampled validation article (the split is val_len items long).
val_data['text']

In [None]:
# Preview the text of the first two sampled dev (financial-news) articles.
dev_data['text'][:2]

In [None]:
# Key each training article by its "chunk-"-prefixed hash id. Identical texts
# hash to the same id and therefore collapse into a single entry.
train_chunks = {}
for article_text in train_data["text"]:
    chunk_id = compute_mdhash_id(article_text, prefix="chunk-")
    train_chunks[chunk_id] = {"content": article_text}

# Build the training examples from the chunks.
# NOTE(review): `filepath` presumably tells generate_dataset where to persist
# the result - confirm in nano_graphrag.entity_extraction.extract.
trainset = asyncio.run(
    generate_dataset(
        chunks=train_chunks,
        filepath=entity_relationship_trainset_path,
    )
)

In [None]:
# Print every training-set relationship whose `order` field equals 2.
train_order_2 = (
    rel
    for sample in trainset
    for rel in sample.relationships.context
    if rel.order == 2
)
for rel in train_order_2:
    print(rel)

In [None]:
# Print every training-set relationship whose `order` field equals 3.
train_order_3 = (
    rel
    for sample in trainset
    for rel in sample.relationships.context
    if rel.order == 3
)
for rel in train_order_3:
    print(rel)

In [None]:
# Inspect the first two relationships attached to the first training example.
trainset[0].relationships.context[:2]

In [None]:
# Key each validation article by its "chunk-"-prefixed hash id. Identical
# texts hash to the same id and therefore collapse into a single entry.
val_chunks = {}
for article_text in val_data["text"]:
    chunk_id = compute_mdhash_id(article_text, prefix="chunk-")
    val_chunks[chunk_id] = {"content": article_text}

# Build the validation examples from the chunks.
# NOTE(review): `filepath` presumably tells generate_dataset where to persist
# the result - confirm in nano_graphrag.entity_extraction.extract.
valset = asyncio.run(
    generate_dataset(
        chunks=val_chunks,
        filepath=entity_relationship_valset_path,
    )
)

In [None]:
# Inspect the first two relationships attached to the first validation example.
valset[0].relationships.context[:2]

In [None]:
# Print every validation-set relationship whose `order` field equals 2.
val_order_2 = (
    rel
    for sample in valset
    for rel in sample.relationships.context
    if rel.order == 2
)
for rel in val_order_2:
    print(rel)

In [None]:
# Print every validation-set relationship whose `order` field equals 3.
val_order_3 = (
    rel
    for sample in valset
    for rel in sample.relationships.context
    if rel.order == 3
)
for rel in val_order_3:
    print(rel)

In [None]:
# Key each dev article by its "chunk-"-prefixed hash id. Identical texts hash
# to the same id and therefore collapse into a single entry.
dev_chunks = {}
for article_text in dev_data["text"]:
    chunk_id = compute_mdhash_id(article_text, prefix="chunk-")
    dev_chunks[chunk_id] = {"content": article_text}

# Build the dev examples from the chunks.
# NOTE(review): `filepath` presumably tells generate_dataset where to persist
# the result - confirm in nano_graphrag.entity_extraction.extract.
devset = asyncio.run(
    generate_dataset(
        chunks=dev_chunks,
        filepath=entity_relationship_devset_path,
    )
)

In [None]:
# Inspect the first two relationships attached to the first dev example.
devset[0].relationships.context[:2]

In [None]:
# Print every dev-set relationship whose `order` field equals 2.
dev_order_2 = (
    rel
    for sample in devset
    for rel in sample.relationships.context
    if rel.order == 2
)
for rel in dev_order_2:
    print(rel)

In [None]:
# Print every dev-set relationship whose `order` field equals 3.
dev_order_3 = (
    rel
    for sample in devset
    for rel in sample.relationships.context
    if rel.order == 3
)
for rel in dev_order_3:
    print(rel)

In [None]:
# Un-optimized extractor: the baseline that is evaluated first and then handed
# to the optimizer. The bare name on the last line displays its repr.
model = EntityRelationshipExtractor()
model

In [None]:
# Baseline: score the un-optimized extractor on the dev set, one metric at a
# time, under the globally configured LM.
metrics = [relationship_recall_metric, entity_recall_metric, relationship_similarity_metric]

# Loop-invariant, so it is looked up once instead of once per metric.
eval_threads = os.cpu_count()

# Whatever each evaluator returns is kept, keyed by metric name, so the
# baseline numbers stay available as data rather than only as printed output.
baseline_scores = {}
for metric in metrics:
    evaluate = Evaluate(
        devset=devset,
        metric=metric,
        num_threads=eval_threads,
        display_progress=True,
        display_table=5,
    )
    metric_name = getattr(metric, "__name__", repr(metric))
    baseline_scores[metric_name] = evaluate(model)

baseline_scores

In [None]:
# Prompt optimization with MIPROv2, scored by relationship recall.
# NOTE(review): going by the argument names, DeepSeek (`lm`) acts as the
# prompt model and the local llama3.1 (`llama_lm`) as the task model, while
# the evaluation cells run under the globally configured LM - confirm this
# asymmetry is intended.
optimizer_settings = {
    "prompt_model": lm,
    "task_model": llama_lm,
    "metric": relationship_recall_metric,
    "init_temperature": 0.7,
    "num_candidates": 4,
}
optimizer = MIPROv2(**optimizer_settings)

# NOTE(review): the optimized program is not written to disk here, and
# entity_relationship_module_path is not used by any cell visible in this
# notebook section - confirm it is saved elsewhere.
optimized_model = optimizer.compile(
    model,
    trainset=trainset,
    valset=valset,
    num_batches=5,
    max_labeled_demos=5,
    max_bootstrapped_demos=3,
)
optimized_model

In [None]:
# Post-optimization: score the optimized extractor on the same dev set with
# the same three metrics, under the globally configured LM.
metrics = [relationship_recall_metric, entity_recall_metric, relationship_similarity_metric]

# Loop-invariant, so it is looked up once instead of once per metric.
eval_threads = os.cpu_count()

# Whatever each evaluator returns is kept, keyed by metric name, so the
# optimized numbers stay available as data rather than only as printed output.
optimized_scores = {}
for metric in metrics:
    evaluate = Evaluate(
        devset=devset,
        metric=metric,
        num_threads=eval_threads,
        display_progress=True,
        display_table=5,
    )
    metric_name = getattr(metric, "__name__", repr(metric))
    optimized_scores[metric_name] = evaluate(optimized_model)

optimized_scores