In [None]:
import os

In [None]:
from curate_gpt.pipeline.pipelines import *


In [None]:
# Set the USE_AZURE flag in the process environment (presumably read by the
# curate_gpt pipeline code to select Azure-hosted models — confirm).
# NOTE(review): this runs after the pipelines import above; if that module
# reads the flag at import time, the assignment comes too late — verify.
os.environ.update({"USE_AZURE": "true"})

In [None]:
"""Chat with data in a collection.

Example:

    curategpt extract-dspy -c hpoa "What is the HPO ID for breast cancer?"
"""
# Ask a single free-text question against the "hpoa" collection stored in the
# "stagedb" directory. rag_dspy is not defined in this notebook; presumably it
# is provided by the wildcard curate_gpt import — confirm.
query = "What is the HPO ID for breast cancer?"
# NOTE: `result` is not displayed here and is rebound by a later cell.
result = rag_dspy(query, "stagedb", "hpoa")

In [None]:
# from dspy.teleprompt import BootstrapFewShot

# # Validation logic: check that the predicted answer is correct.
# # Also check that the retrieved context does actually contain that answer.
# def validate_context_and_answer(example, pred, trace=None):
#     answer_EM = dspy.evaluate.answer_exact_match(example, pred)
#     answer_PM = dspy.evaluate.answer_passage_match(example, pred)
#     return answer_EM and answer_PM

# # Set up a basic teleprompter, which will compile our RAG program.
# teleprompter = BootstrapFewShot(metric=validate_context_and_answer)

# # Compile!
# compiled_rag = teleprompter.compile(RAG(), trainset=trainset)

In [None]:
import pandas as pd
# Load the ground-truth table; the path is relative to the working directory
# the notebook kernel was started in.
gt_dataset = pd.read_csv(
    filepath_or_buffer='data/v1_ground_truth_dataset.csv',
)

In [None]:
import dspy
from dspy.evaluate import Evaluate
from dspy.teleprompt import BootstrapFewShot, BootstrapFewShotWithRandomSearch, BootstrapFinetune

In [None]:
from sklearn.model_selection import train_test_split
import json 


def parse_hpo_for_rq(row: pd.Series, col: str) -> list[str]:
    """Decode the JSON text stored under ``col`` in ``row``.

    Args:
        row: A row (or any mapping-like object) indexable by column name.
        col: Name of the column whose cell holds JSON text.

    Returns:
        The decoded JSON value (callers expect a list of HPO id strings).
    """
    raw_cell = row[col]
    decoded = json.loads(raw_cell)
    return decoded

def get_examples(df, limit):
    """Build DSPy examples from the first ``limit`` rows of a ground-truth frame.

    Args:
        df: DataFrame with an ``INDICATION`` text column and a
            ``QCED_HPO_IDS`` column holding a JSON-encoded list of HPO ids.
        limit: Maximum number of rows (taken from the top of ``df``) to convert.

    Returns:
        A list of ``dspy.Example`` objects. Each has a ``question`` field
        (marked as the only input) and an ``answer`` field containing the
        HPO ids joined with ``", "``.
    """
    examples = []
    for _, row in df.head(limit).iterrows():
        question = f'What are the HPO ids for phenotypes in this text? Text: {row["INDICATION"]}'
        hpo_ids = parse_hpo_for_rq(row, "QCED_HPO_IDS")
        # TODO: Add teleprompter support for list answers. Passing a list
        # currently fails later with "AttributeError: 'list' object has no
        # attribute 'split'" in the TemplateV2 format handler, so the ids are
        # flattened into a single comma-separated string.
        # The separator is ", " (the previous " ," put the space on the wrong
        # side of the comma and produced malformed reference answers).
        answer = ", ".join(hpo_ids)
        examples.append(
            dspy.Example(question=question, answer=answer).with_inputs("question")
        )
    return examples


# Hold out 20% of the ground-truth rows as a dev split; the fixed seed is
# passed as random_state so the partition can be reproduced.
seed = 10230495
train, dev = train_test_split(
    gt_dataset,
    test_size=0.2,
    random_state=seed,
)

# Only the leading rows of each split are converted into DSPy examples.
train_examples = get_examples(train, limit=10)
dev_examples = get_examples(dev, limit=20)


In [None]:


# Create the language model client. GPT is not defined in this notebook;
# presumably it comes from the wildcard curate_gpt import — confirm.
# NOTE(review): temperature=0.7 presumably yields sampled, non-deterministic
# completions, so repeated runs may score differently — confirm.
language_model = GPT(temperature=0.7)

# Register this LM in the global DSPy settings.
dspy.settings.configure(lm=language_model)

# Define the predictor. BasicQA is also not defined here; presumably it comes
# from the same wildcard import — confirm.
generate_answer = dspy.Predict(BasicQA)

# Call the predictor on a particular input.
print(f"Question: {train_examples[0].question}")
# NOTE: this rebinds `result`, which an earlier cell assigned from rag_dspy();
# the prediction itself is not displayed in this cell.
result = generate_answer(question=train_examples[0].question)

# Show the most recent entry of the LM client's call history (n=1).
language_model.inspect_history(n=1)

In [None]:
# Display the reference answer of the first training example (the HPO ids
# joined into a single string by get_examples).
train_examples[0].answer

In [None]:
# Metric used by every teleprompter and evaluator below: DSPy's exact-match
# check between the example's answer and the prediction's answer.
# NOTE(review): the reference answers are id lists flattened to one string,
# so a prediction with the same ids in a different order presumably scores
# as a miss — confirm whether a set-based metric is wanted.
metric_EM = dspy.evaluate.answer_exact_match


In [None]:

class CoT(dspy.Module):
    """Minimal DSPy program: answer a question with a chain-of-thought predictor."""

    def __init__(self):
        super().__init__()

        # Kept as an attribute so a teleprompter can later compile it
        # (e.g., teach it a prompt).
        self.generate_answer = dspy.ChainOfThought('question -> answer')

    def forward(self, question):
        """Run the chain-of-thought predictor on ``question`` and return its prediction."""
        prediction = self.generate_answer(question=question)
        return prediction

In [None]:

# Few-shot bootstrapping optimizer scored with exact match; at most two
# bootstrapped demonstrations are requested.
teleprompter = BootstrapFewShot(metric=metric_EM, max_bootstrapped_demos=2)
# Compile a fresh CoT program against the training examples.
cot_compiled = teleprompter.compile(CoT(), trainset=train_examples)

In [None]:
# Run the compiled CoT program on the first training question and display the
# returned prediction (this question was part of the compile trainset).
cot_compiled(train_examples[0].question)

In [None]:
# Thread count passed to the evaluator here and to the random-search
# teleprompters in later cells.
NUM_THREADS = 32
# Reusable evaluator over the dev examples, scored with exact match, with
# progress display enabled and display_table set to 15.
evaluate_hpo = Evaluate(devset=dev_examples, metric=metric_EM, num_threads=NUM_THREADS, display_progress=True, display_table=15)

In [None]:
# Score the compiled CoT program on the dev examples.
evaluate_hpo(cot_compiled)
# Leftover debugging aid: uncomment to display the dev examples instead.
# dev_examples

In [None]:
# Build the retrieval model over the "hpoa" collection persisted in "stagedb".
# ChromadbForAzureRM is not defined in this notebook; presumably it comes from
# the wildcard curate_gpt import — confirm.
retrieve_model = ChromadbForAzureRM.from_dir(
    persist_directory="stagedb",
    collection_name="hpoa",
)
# Re-configure DSPy with both the retriever and the language model created
# earlier.
dspy.settings.configure(rm=retrieve_model, lm=language_model)


In [None]:
class RAG(dspy.Module):
    """Single-hop retrieval-augmented QA program.

    The question is first rewritten into a search query, the query is used to
    retrieve passages, and the answer is generated from those passages plus
    the original question.
    """

    def __init__(self, num_passages=3):
        super().__init__()

        # Three sub-modules: the retriever, a query generator, and an answer generator.
        self.retrieve = dspy.Retrieve(k=num_passages)
        self.generate_query = dspy.ChainOfThought("question -> search_query")
        self.generate_answer = dspy.ChainOfThought("context, question -> answer")

    def forward(self, question):
        """Answer ``question`` using passages retrieved for a generated search query."""
        # Turn the question into a search query and fetch passages for it.
        query_prediction = self.generate_query(question=question)
        retrieved = self.retrieve(query_prediction.search_query)

        # Produce the answer from the retrieved passages and the question.
        return self.generate_answer(context=retrieved.passages, question=question)

In [None]:
# Baseline: score an uncompiled RAG program on the dev examples, overriding
# display_table to 0 for this run.
evaluate_hpo(RAG(), display_table=0)

In [None]:
# Sanity-check the configured retriever directly on one training question.
query = train_examples[2].question
retrieve = dspy.Retrieve(k=3)
top_passages = retrieve(query).passages

print(f"Query: {query}")
for passage in top_passages:
    # A 30-character ruler line, then the passage text on the next line.
    print("=" * 30, passage, sep="\n")

In [None]:
# Random-search variant of few-shot bootstrapping: at most two bootstrapped
# demos, eight candidate programs, scored with exact match.
teleprompter2 = BootstrapFewShotWithRandomSearch(metric=metric_EM, max_bootstrapped_demos=2, num_candidate_programs=8, num_threads=NUM_THREADS)
# NOTE(review): dev_examples is passed as valset here and is also the devset
# of evaluate_hpo, so a later evaluation of rag_compiled on it would not be on
# held-out data — confirm whether a separate split is wanted.
rag_compiled = teleprompter2.compile(RAG(), trainset=train_examples, valset=dev_examples)

In [None]:
# Ask the compiled RAG program the same training question that was used for
# the retriever sanity check (train_examples[2]).
query = train_examples[2].question

answer = rag_compiled(query)

# `answer` is the object returned by the program, printed via its string form.
print(f"Query: {query}")
print(f"Answer: {answer}")

In [None]:
# Show the most recent entry of the LM client's call history (n=1).
language_model.inspect_history(n=1)

In [None]:
from dsp.utils.utils import deduplicate

class MultiHop(dspy.Module):
    """Two-hop retrieval-augmented QA program.

    A first search query is generated from the question alone; a second one
    is generated from the question plus the passages found so far. The answer
    is produced from the de-duplicated passages of both retrieval rounds.
    """

    def __init__(self, num_passages=10):
        super().__init__()

        # Retriever asked for `num_passages` passages per search query.
        self.retrieve = dspy.Retrieve(k=num_passages)
        # Hop 1: search query from the question only.
        self.generate_query = dspy.ChainOfThought("question -> search_query")
        # Hop 2: follow-up search query from the question plus retrieved context.
        self.generate_query_from_context = dspy.ChainOfThought("context, question -> search_query")
        # Final answer from the accumulated context and the question.
        self.generate_answer = dspy.ChainOfThought("context, question -> answer")

    def forward(self, question):
        """Answer ``question`` after two rounds of query generation and retrieval."""
        collected = []

        # First hop.
        first_query = self.generate_query(question=question).search_query
        collected.extend(self.retrieve(first_query).passages)

        # Second hop: the follow-up query sees the de-duplicated first-hop passages.
        follow_up = self.generate_query_from_context(
            context=deduplicate(collected),
            question=question,
        )
        collected.extend(self.retrieve(follow_up.search_query).passages)

        return self.generate_answer(context=deduplicate(collected), question=question)

In [None]:
# Random-search bootstrapping for the multi-hop program: at most two
# bootstrapped demos and two candidate programs, scored with exact match.
teleprompter3 = BootstrapFewShotWithRandomSearch(metric=metric_EM, max_bootstrapped_demos=2, num_candidate_programs=2, num_threads=NUM_THREADS)

# NOTE(review): dev_examples is used as valset here and is the same set the
# compiled program is evaluated on afterwards — confirm this is intended.
multihop_compiled = teleprompter3.compile(MultiHop(), trainset=train_examples, valset=dev_examples)

In [None]:
# Score the compiled multi-hop program on dev_examples. The devset argument
# repeats the set the evaluator was constructed with; the same examples were
# also passed as valset when this program was compiled.
evaluate_hpo(multihop_compiled, devset=dev_examples)

In [None]:
# Run the compiled multi-hop program on one training question.
query = train_examples[3].question

# NOTE: this call is not the cell's last expression and its return value is
# not assigned, so the prediction is not shown; only the history output below
# appears.
multihop_compiled(question=query)

# Show one entry of the LM call history, with skip=2 (presumably skipping the
# two most recent calls so an earlier hop's prompt is shown — confirm).
language_model.inspect_history(n=1, skip=2)