In [None]:
import os

In [None]:
%load_ext autoreload
%autoreload 2
from curate_gpt.pipeline.pipelines import *


In [None]:
# Credentials must never be committed in a notebook: take the key from the
# environment of the process that launched the kernel instead of hard-coding it.
# NOTE(review): the key that used to be hard-coded in this cell has been exposed
# in version control and should be revoked/rotated.
import warnings

os.environ["USE_AZURE"] = "true"
_api_key = os.environ.get("AZURE_OPENAI_API_KEY") or os.environ.get("OPENAI_API_KEY")
if _api_key:
    # Mirror the same key under both variable names, as this cell always did.
    os.environ["OPENAI_API_KEY"] = _api_key
    os.environ["AZURE_OPENAI_API_KEY"] = _api_key
else:
    warnings.warn(
        "No API key found: export AZURE_OPENAI_API_KEY (or OPENAI_API_KEY) "
        "before starting the kernel; the language-model cells below will fail without it."
    )

In [None]:
# """Chat with data in a collection.

# Example:

#     curategpt extract-dspy -c hpoa "What is the HPO ID for breast cancer?"
# """
# query = "What is the HPO ID for breast cancer?"
# result = rag_dspy(query, "stagedb", "hpoa")

In [None]:
import pandas as pd
# Load the ground-truth table; later cells read its "INDICATION" text column and
# the JSON-encoded "QCED_HPO_IDS" / "QCED_HPOS" columns (path is relative to the working directory).
gt_dataset = pd.read_csv('data/v1_ground_truth_dataset.csv')

In [None]:
import dspy
from dspy.evaluate import Evaluate
from dspy.teleprompt import BootstrapFewShot, BootstrapFewShotWithRandomSearch, BootstrapFinetune

In [None]:
from sklearn.model_selection import train_test_split
import json 
import dspy

def parse_hpo_for_rq(row: pd.Series, col: str) -> list[str]:
    """Decode the JSON-encoded value stored under ``col`` in ``row``.

    Args:
        row: One record of the ground-truth table.
        col: Name of the column holding a JSON document.

    Returns:
        Whatever ``json.loads`` produces for that cell (a list of strings for the HPO columns).
    """
    encoded = row[col]
    decoded = json.loads(encoded)
    return decoded

def get_examples(df, limit, inputs):
    """Convert the first ``limit`` rows of ``df`` into ``dspy.Example`` objects.

    Args:
        df: DataFrame with an "INDICATION" text column and JSON-encoded
            "QCED_HPO_IDS" / "QCED_HPOS" columns.
        limit: Maximum number of rows to convert; ``None`` converts every row.
        inputs: Field names forwarded to ``Example.with_inputs``.

    Returns:
        A list of examples carrying ``context``, ``hpo_ids`` and ``hpo_terms`` fields.
    """
    rows = df if limit is None else df.head(limit)
    examples = []
    for _, row in rows.iterrows():
        # hpo_ids is kept as a list (not a joined string).
        # TODO: Add teleprompter support for list, otherwise we later see
        # AttributeError: 'list' object has no attribute 'split' from the TemplateV2 format handler
        example = dspy.Example(
            context=row["INDICATION"],
            hpo_ids=parse_hpo_for_rq(row, "QCED_HPO_IDS"),
            hpo_terms=parse_hpo_for_rq(row, "QCED_HPOS"),
        )
        examples.append(example.with_inputs(*inputs))
    return examples


seed = 10230495
# Keep the DataFrame halves under their own names so that `train` / `dev` are not
# bound to a DataFrame first and to a list of examples a few lines later.
train_df, dev_df = train_test_split(gt_dataset, test_size=0.3, random_state=seed)

# NOTE(review): 'labels' is declared as an input key although get_examples stores no
# field of that name; confirm this is intentional (the modules' forward() methods
# accept an optional `labels` argument).
train_examples = get_examples(train_df, 100, ['context', 'labels'])
dev_examples = get_examples(dev_df, 50, ['context', 'labels'])

# dataset aliases:
train = trainset = train_examples
dev = devset = dev_examples

In [None]:


# NOTE(review): GPT and BasicQA are not defined in this notebook; presumably they come
# from the star import of curate_gpt.pipeline.pipelines - confirm.
# temperature=0.7 means completions are sampled, so repeated runs can give different answers.
language_model = GPT(temperature=0.7, use_azure=True)

# Register the model as the default LM for every dspy module created below.
dspy.settings.configure(lm=language_model)

# Define the predictor.
generate_answer = dspy.Predict(BasicQA)

# Call the predictor on a particular input.
ex0 = train_examples[0].context
print(f"Question: {ex0}")
# The prediction is stored in `result` but not displayed; the prompt/completion is shown via inspect_history.
result = generate_answer(question=f"Predict HPO IDs for the following text: {ex0}")

# Show the most recent call made through the language model.
language_model.inspect_history(n=1)

In [None]:
# Inspect the gold HPO IDs of the first training example (compare with the prediction above).
train_examples[0].labels().hpo_ids

In [None]:
from typing import Union

# Exact-match metric shipped with dspy; within this notebook it is only referenced
# by the commented-out MultiHop compile cell.
metric_EM = dspy.evaluate.answer_exact_match

def normalize(hpo_id: str) -> str:
    """Return ``hpo_id`` with leading and trailing whitespace removed."""
    return hpo_id.strip()


# NOTE: the input order is preserved (nothing is sorted) so callers can truncate to the top-K items.
def normalize_list(hpo_ids: list[str]) -> list[str]:
    """Normalize every ID in ``hpo_ids`` and drop the entries that end up empty."""
    return [cleaned for cleaned in map(normalize, hpo_ids) if cleaned]


def metric_recall(gold: list[str], pred: Union[list[str], str], K:int=10) -> float:
    """Compute recall@K of predicted HPO IDs against the gold HPO IDs.

    Args:
        gold: Gold-standard HPO IDs.
        pred: Predicted HPO IDs, either as a list or as one comma-separated
            string (the raw model output). ``None`` is treated as no prediction.
        K: Only the first ``K`` distinct predictions are scored.

    Returns:
        ``|gold & top-K pred| / |gold|``. When ``gold`` is empty the ratio is
        undefined; 1.0 is returned if nothing was predicted either and 0.0
        otherwise, instead of raising ``ZeroDivisionError``.
    """
    if pred is None:
        pred = []
    elif isinstance(pred, str):
        pred = pred.split(",")

    gold_ids = set(normalize_list(gold))
    # De-duplicate (keeping first occurrences) before truncating, so a repeated
    # ID cannot use up several of the K slots.
    top_k = list(dict.fromkeys(normalize_list(pred)))[:K]
    pred_ids = set(top_k)

    if not gold_ids:
        # Nothing to recall: reward an empty prediction, penalize spurious IDs.
        return 1.0 if not pred_ids else 0.0

    return len(gold_ids & pred_ids) / len(gold_ids)

def metric_recallK(gold: list[str], pred: Union[list[str], str], K:int=10) -> float:
    """Alias of :func:`metric_recall` that makes the ``K`` cut-off explicit in the name."""
    return metric_recall(gold, pred, K=K)

# wrap the recall@K metric so it can take dspy Examples
def dspy_metric_recall10(gold: dspy.Example, pred: dspy.Example, trace=None) -> float:
    """Recall@10 adapter taking dspy objects, so it can be passed as ``metric=``.

    Args:
        gold: Example whose ``labels()`` expose the gold ``hpo_ids``.
        pred: Prediction exposing the predicted ``hpo_ids``.
        trace: Accepted for interface compatibility; not used.

    Returns:
        The recall@10 score computed by ``metric_recallK``.
    """
    gold_ids = gold.labels().hpo_ids
    predicted_ids = pred.hpo_ids
    return metric_recallK(gold_ids, predicted_ids, K=10)


In [None]:
def _format_hpo_id_list(value):
    """Render a list of HPO IDs as one comma-separated string; other values pass through unchanged."""
    if isinstance(value, list):
        return ', '.join(value)
    return value


class PredictHPOs(dspy.Signature):
    # The signature's instructions are supplied through __doc__.
    __doc__ = """Given a snippet from a patient's medical history, identify the Human Phenotype Ontology (HPO) identifier for each phenotype in the text. If none are mentioned in the snippet, say '\n'."""

    # Input: the text snippet to annotate.
    context = dspy.InputField()
    # Output: the predicted IDs; list values are flattened to text by the formatter.
    hpo_ids = dspy.OutputField(desc="list of comma-separated HPO IDs", format=_format_hpo_id_list)


class CoT(dspy.Module):
    """Chain-of-thought program that predicts HPO IDs for a text snippet via ``PredictHPOs``."""

    def __init__(self):
        super().__init__()

        # The chain-of-thought sub-module is kept as an attribute so that it can
        # be compiled later (e.g., taught a prompt).
        self.generate_answer = dspy.ChainOfThought(PredictHPOs)

    def forward(self, context, labels=None):
        """Run the predictor on ``context``; ``labels`` is accepted but not used."""
        prediction = self.generate_answer(context=context)
        return prediction

In [None]:
# Baseline: score the uncompiled CoT program with recall@10 on a small dev sample.
example_limit = 10
threads = 10
sample_devset = dev_examples[:example_limit]
# display_table=15 asks the evaluator to also show a per-example table.
evaluate_hpo = Evaluate(devset=sample_devset, metric=dspy_metric_recall10, num_threads=threads, display_progress=True, display_table=15)
result = evaluate_hpo(CoT())
result

In [None]:
# Recall@10 is around 20% for the uncompiled program when evaluating with 10 threads.
# Oddly, the result was ~36% when stepping through each example with pdb, single-threaded.
# NOTE(review): pausing does not give the model more time; the gap is more likely sampling
# noise (the LM is configured with temperature=0.7 and only 10 examples are scored).
# TODO: re-run several times (or with temperature=0) to confirm.
print(36.67)

In [None]:
# Compile the CoT program: bootstrap at most 2 demonstrations from the first 10
# training examples, using recall@10 as the metric.
teleprompter = BootstrapFewShot(metric=dspy_metric_recall10, max_bootstrapped_demos=2)
cot_compiled = teleprompter.compile(CoT(), trainset=train_examples[:10])

In [None]:
# Sanity check: run the compiled program on the first training example's text.
cot_compiled(train_examples[0].context)

In [None]:
# Rebuild the evaluator over the first 10 dev examples (same slice size as the baseline run),
# now with 32 worker threads. This rebinds `evaluate_hpo` for the cells that follow.
NUM_THREADS = 32
evaluate_hpo = Evaluate(devset=dev_examples[:10], metric=dspy_metric_recall10, num_threads=NUM_THREADS, display_progress=True, display_table=15)

In [None]:
# Score the compiled CoT program with the evaluator defined in the previous cell.
result = evaluate_hpo(cot_compiled)
result

In [None]:
# huge increase (20% -> 40-50% recall) just by compiling the model over 10 examples. let's inspect the prompt:
# NOTE(review): both numbers come from only 10 dev examples with a sampled LM (temperature=0.7),
# so the size of the improvement should be confirmed on a larger dev slice.
language_model.inspect_history(n=1)

In [None]:
# dir(cot_compiled)
# Persist the compiled program's state; the path is relative to the working directory.
cot_compiled.save("data/cot_compiled_50pct_recall")

In [None]:
import importlib
from curate_gpt.pipeline import retrieval
# Force a re-import of the retrieval module so edits made to it are picked up.
# NOTE(review): %autoreload 2 is already enabled at the top of the notebook, so this is probably redundant.
importlib.reload(retrieval)

# Build the retrieval model from the "hpoa" collection persisted in the local "stagedb" directory.
retrieve_model = retrieval.ChromadbForAzureRM.from_dir(
    persist_directory="stagedb",
    collection_name="hpoa",
    use_azure=True
)
# Register the retriever alongside the language model as dspy defaults.
dspy.settings.configure(rm=retrieve_model, lm=language_model)


In [None]:
class SearchQueryForHPOs(dspy.Signature):
    # The signature's instructions are supplied through __doc__.
    __doc__ = """Given a snippet from a patient's medical history, create a search query for the Human Phenotype Ontology (HPO) identifier for each phenotype in the text."""

    # Input: the text snippet to annotate.
    context = dspy.InputField()
    # Output: the query that is sent to the retriever.
    search_query = dspy.OutputField(desc="search query to retrieve HPO document texts")


def _join_documents(value):
    """Join a list of document texts with blank lines; other values pass through unchanged."""
    if isinstance(value, list):
        return '\n\n'.join(value)
    return value


def _join_hpo_ids(value):
    """Join a list of HPO IDs with ', '; other values pass through unchanged."""
    if isinstance(value, list):
        return ', '.join(value)
    return value


class PredictWithSearchHPOs(dspy.Signature):
    # The signature's instructions are supplied through __doc__.
    __doc__ = """Given a snippet from a patient's medical history and the search results, identify the Human Phenotype Ontology (HPO) identifier for each phenotype in the text. If none are mentioned in the snippet, say '\n'."""

    # Inputs: the text snippet and the retrieved documents.
    context = dspy.InputField()
    documents = dspy.InputField(desc="HPO document texts", format=_join_documents)
    # Output: the predicted IDs; list values are flattened to text by the formatter.
    hpo_ids = dspy.OutputField(desc="list of comma-separated HPO IDs", format=_join_hpo_ids)


class RAG(dspy.Module):
    """Retrieval-augmented program: generate a search query, retrieve passages, predict HPO IDs."""

    def __init__(self, num_passages=3):
        super().__init__()

        # Three sub-modules: the retriever (top ``num_passages`` passages),
        # the query generator, and the answer generator.
        self.retrieve = dspy.Retrieve(k=num_passages)
        self.generate_query = dspy.ChainOfThought(SearchQueryForHPOs)
        self.generate_answer = dspy.ChainOfThought(PredictWithSearchHPOs)

    def forward(self, context, labels=None):
        """Predict HPO IDs for ``context``; ``labels`` is accepted but not used."""
        # Step 1: turn the snippet into a search query.
        query_prediction = self.generate_query(context=context)
        # Step 2: fetch the passages matching that query.
        retrieved = self.retrieve(query_prediction.search_query)
        # Step 3: answer from the snippet plus the retrieved passages.
        return self.generate_answer(context=context, documents=retrieved.passages)

In [None]:
# Score the uncompiled RAG program with the current `evaluate_hpo` evaluator, without the per-example table.
evaluate_hpo(RAG(), display_table=0)

In [None]:
# Pretty good! >28% uncompiled!

In [None]:
# Inspect what the retriever returns when the raw snippet itself (not a generated query) is used as the search text.
query = train_examples[2].context
retrieve = dspy.Retrieve(k=3)
top_passages = retrieve(query).passages
print(f"Query: {query}")
for passage in top_passages:
    # Separator line between passages.
    print("=" * 30)
    print(passage)

In [None]:
# Compile the RAG program with random search over 2 candidate programs:
# 20 training examples for bootstrapping, the first 10 dev examples for validation.
threads = 10
teleprompter2 = BootstrapFewShotWithRandomSearch(metric=dspy_metric_recall10, max_bootstrapped_demos=2, num_candidate_programs=2, num_threads=threads)
rag_compiled = teleprompter2.compile(RAG(), trainset=train_examples[:20], valset=dev_examples[:10])

In [None]:
# Evaluate the compiled RAG program on dev examples 40-49, a slice that does not
# overlap the dev_examples[:10] validation set used while compiling.
threads = 10
ex = dev_examples[40:50]
# Show how many dev examples exist in total.
print(len(dev_examples))
evaluate_hpo = Evaluate(devset=ex, metric=dspy_metric_recall10, num_threads=threads, display_progress=True, display_table=15)
result = evaluate_hpo(rag_compiled)
result

In [None]:
# Look at earlier LM calls from the evaluation.
# NOTE(review): presumably skip=2 skips the two most recent calls and n=2 shows the two before them - confirm.
language_model.inspect_history(n=2, skip=2)

In [None]:
# Re-display the value returned by the most recent evaluation run.
result

In [None]:
# Drill into a single dev example (index 45 lies inside the 40:50 slice evaluated above).
avm_example = dev_examples[45]
pred = rag_compiled(avm_example.context)

In [None]:
# Display the full prediction object returned by the compiled RAG program.
pred

In [None]:
# Display only the predicted HPO IDs.
pred.hpo_ids

In [None]:
# Recall@10 of that single prediction against the example's gold HPO IDs.
avm_result = dspy_metric_recall10(avm_example, pred)
avm_result

In [None]:
# Run the compiled RAG program on the third training example and print input and output side by side.
query = train_examples[2].context

answer = rag_compiled(query)

print(f"Query: {query}")
print(f"Answer: {answer}")

In [None]:
# Show the LM call before the most recent one.
# NOTE(review): presumably this is the search-query generation step of the RAG run above - confirm.
language_model.inspect_history(n=1, skip=1)

In [None]:
# Gold HPO IDs for the same training example, for comparison with the answer printed above.
train_examples[2].labels().hpo_ids

In [None]:
# One problem with the current approach is that we're only retrieving results for one of the phenotypes it seems, e.g. Goiter in the above example.
# One option is to use MultiHop
# Another is to use something like PredictThenGround
# https://colab.research.google.com/drive/1CpsOiLiLYKeGrhmq579_FmtGsD5uZ3Qe#scrollTo=0TjOZmXEUDie

In [None]:
# from dsp.utils.utils import deduplicate

# class MultiHop(dspy.Module):
#     def __init__(self, num_passages=10):
#         super().__init__()

#         self.retrieve = dspy.Retrieve(k=num_passages)
#         self.generate_query = dspy.ChainOfThought("question -> search_query")

#         self.generate_query_from_context = dspy.ChainOfThought("context, question -> search_query")

#         self.generate_answer = dspy.ChainOfThought("context, question -> answer")
    
#     def forward(self, question):
#         passages = []
        
#         search_query = self.generate_query(question=question).search_query
#         passages += self.retrieve(search_query).passages

#         search_query2 = self.generate_query_from_context(context=deduplicate(passages), question=question).search_query

#         passages += self.retrieve(search_query2).passages

#         return self.generate_answer(context=deduplicate(passages), question=question)

In [None]:
# threads = 32
# teleprompter3 = BootstrapFewShotWithRandomSearch(metric=metric_EM, max_bootstrapped_demos=2, num_candidate_programs=2, num_threads=threads)
# multihop_compiled = teleprompter3.compile(MultiHop(), trainset=train_examples, valset=dev_examples)

In [None]:
# evaluate_hpo(multihop_compiled, devset=dev_examples)

In [None]:
# query = train_examples[3].question

# multihop_compiled(question=query)

# language_model.inspect_history(n=1, skip=2)

In [None]:
# language_model.inspect_history(n=1, skip=3)

In [None]:
# train_examples[3].answer

In [None]:
# query

In [None]:
# TODOs
# Work on multihop to only query text from original passage.... the question was changed.
# 