In [1]:
from copy import deepcopy
import json
import os
import pandas as pd
import typer
from pathlib import Path

import dspy
from dspy.evaluate import Evaluate
from datasets import load_dataset
from bellem.utils import set_seed
from bellem.musique.eval import (
    aggregate_scores,
    compute_scores,
    compute_scores_dataframe,
)
from dotenv import load_dotenv
from rich.console import Console

print = Console(stderr=True).print

load_dotenv()

set_seed(89)


def configure_lm(model, temperature):
    lm = dspy.LM(
        "openai/" + model,
        temperature=temperature,
        cache=False,
        api_base=os.getenv("OPENAI_BASE_URL"),
        api_key=os.getenv("OPENAI_API_KEY"),
    )
    dspy.configure(lm=lm)


def format_paragraph(paragraph):
    text = paragraph["paragraph_text"]
    title = paragraph["title"]
    return f"# {title}\n{text}"


def make_example(record):
    supporting_paragraphs = [p for p in record["paragraphs"] if p["is_supporting"]]
    context = "\n\n".join([format_paragraph(p) for p in supporting_paragraphs])
    return dspy.Example(
        id=record["id"],
        question=record["question"],
        question_decomposition=record["question_decomposition"],
        context=context,
        answer=record["answer"],
        answers=[record["answer"], *record["answer_aliases"]],
    ).with_inputs("question", "context")


class GenerateAnswer(dspy.Signature):
    """Answer the question based on the given context."""

    context = dspy.InputField(desc="may contain relevant facts")
    question = dspy.InputField()
    answer = dspy.OutputField(desc="often between 1 and 5 words")


class QAModule(dspy.Module):
    def __init__(self, predict_cls=dspy.Predict):
        super().__init__()
        self.generate_answer = predict_cls(GenerateAnswer)

    def forward(self, context, question):
        return self.generate_answer(context=context, question=question)

In [2]:
print("Triplets: \nGlenhis Hern\u00e1ndez | birth place | Havana\nMarta Hern\u00e1ndez Romero | mayor of| Havana\n\nAnswer: Marta Hern\u00e1ndez Romero")

In [3]:
print("Triplets:\nRotst\u00f6ckli | part of | Urner Alps\nUrner Alps | part of | Western Alps\n\n")

In [4]:
## LLM as Judge
class JERXQualityJudge(dspy.Signature):
    """Judge whether the extracted entity-relation-entity triples are in correct format. A triple must in (subject, relation, object) format.

    # Good triples
    Glenhis Hernández;birth place;Havana
    Marta Hernández Romero;mayor of;Havana
    Urner Alps;part of;Western Alps

    # Bad triples
    Belgium;and;the Netherlands
    Belgium and the Netherlands;refer to an institution like a German Fachhochschule as;hogeschool

    Output 'Yes' if the triplets are in correct format, otherwise 'No'.
    """

    triples: str = dspy.InputField(desc="The extracted entity-relation-entity triples")
    correct: str = dspy.OutputField(desc="Are the triples in correct format? [Yes/No]", prefix="[Yes/No]:")


jerx_quality_judge = dspy.Predict(JERXQualityJudge)

# Updated evaluation function using the judge
def evaluate_jerx_llm(triples):
    result = jerx_quality_judge(triples=triples)
    return bool(result.correct.strip().lower() == "yes")

In [5]:
class JERX(dspy.Signature):
    """Extract triples relevant to the question from the given context."""

    context = dspy.InputField(desc="may contain relevant facts")
    question = dspy.InputField()
    triples: list[tuple[str, str, str]] = dspy.OutputField(desc="List of triples (subject, predicate, object)")


class QA(dspy.Signature):
    """Answer the question based on the given triples."""

    triples: str = dspy.InputField(desc="List of triples (subject, predicate, object)")
    question = dspy.InputField()
    answer = dspy.OutputField(desc="often between 1 and 5 words")


def validate_triple_format(triple):
    return len(triple) == 3

def validate_number_of_triples(triples, max_n_triples: int):
    if isinstance(triples, str):
        triples = triples.split("\n")
    return len(triples) < 8

class ConnectTheEntities(dspy.Module):
    def __init__(self, max_n_triples=8):
        super().__init__()
        self._jerx = dspy.Predict(JERX)
        self._qa = dspy.Predict(QA)
        self.max_n_triples = max_n_triples

    def forward(self, context, question):
        triple_list = self._jerx(context=context, question=question).triples
        dspy.Suggest(
            all(validate_triple_format(triple) for triple in triple_list),
            "Triples must be in the format of (subject, predicate, object)",
            target_module=self._jerx,
        )
        dspy.Suggest(
            validate_number_of_triples(triple_list, self.max_n_triples),
            f"There must be max {self.max_n_triples} triples",
            target_module=self._jerx,
        )

        if isinstance(triple_list, list):
            triples = "\n".join(";".join(triple) for triple in triple_list)
        elif isinstance(triple_list, str):
            triples = triple_list
        else:
            raise ValueError("Unexpected type for triples")

        dspy.Suggest(
            evaluate_jerx_llm(triples),
            "The extracted triples are not in correct format",
            target_module=self._jerx,

        )
        pred = self._qa(triples=triples, question=question)
        return dspy.Prediction(triples=triples, answer=pred.answer)

In [6]:
def get_predict_cls(technique):
    if technique == "standard":
        return dspy.Predict
    elif technique == "cot":
        return dspy.ChainOfThought
    elif technique == "cte":
        return ConnectTheEntities
    else:
        raise ValueError(f"Unknown technique: {technique}")


def evaluate_answer(example, pred, trace=None):
    scores = compute_scores(pred.answer, example.answers)
    return scores["f1"]


def dynamic_import(module, name):
    import importlib

    return getattr(importlib.import_module(module), name)


def make_optimizer(optimizer_config: dict):
    cls = dynamic_import("dspy.teleprompt", optimizer_config["class"])
    kwargs = deepcopy(optimizer_config["params"])
    if optimizer_config["with_metric"]:
        kwargs["metric"] = evaluate_answer
    return cls(**kwargs)


def preprocess_result(result):
    example, pred, score = result
    predictions = {f"predicted_{k}": v for k, v in dict(pred).items()}
    return {**dict(example), **predictions, "score": float(score)}


def make_results_dataframe(results):
    dataf = pd.json_normalize([preprocess_result(result) for result in results])
    dataf["n_hops"] = dataf["question_decomposition"].apply(len)
    dataf['predicted_answer'] = dataf['predicted_answer'].fillna("No Answer")
    return compute_scores_dataframe(dataf)


def train_main(
    dataset_path: str = typer.Option(..., help="Path to the dataset"),
    dataset_name: str = typer.Option(..., help="Name of the dataset"),
    dataset_split: str = typer.Option(..., help="Dataset split to use (e.g., 'train', 'validation')"),
    model: str = typer.Option(..., help="Name of the model to use"),
    temperature: float = typer.Option(..., help="Temperature parameter for the model"),
    technique: str = typer.Option(..., help="Prompting technique to use"),
    load_from: str = typer.Option(default="UNSET", help="Path to a saved model to load"),
    optimizer_path: Path = typer.Option(..., help="Path to the optimizer config"),
):
    # Set up LLM
    configure_lm(model, temperature)

    # Load and preprocess datasets
    ds = load_dataset(dataset_path, dataset_name, split=dataset_split)
    examples = [make_example(record) for record in ds]
    print(f"Loaded {len(examples)} examples")

    # Create the program
    program = ConnectTheEntities()
    if load_from and load_from != "UNSET":
        print(f"Loading model from {load_from}")
        program.load(load_from)

    # Train the program
    with open(optimizer_path) as f:
        optimizer_config = json.load(f)

    if optimizer_config:
        optimizer = make_optimizer(optimizer_config)
        compile_params = optimizer_config.get("compile_params", {})
        trained_program = optimizer.compile(program, trainset=examples, **compile_params)
    else:
        trained_program = program

    # Save the trained program
    return trained_program


def evaluate_main(
    dataset_path: str = typer.Option(..., help="Path to the dataset"),
    dataset_name: str = typer.Option(..., help="Name of the dataset"),
    dataset_split: str = typer.Option(..., help="Dataset split to use (e.g., 'train', 'validation')"),
    model: str = typer.Option(..., help="Name of the model to use"),
    temperature: float = typer.Option(..., help="Temperature parameter for the model"),
    technique: str = typer.Option(..., help="Prompting technique to use"),
    program = None,
):
    # Set up LLM
    configure_lm(model, temperature)

    # Load and preprocess datasets
    ds = load_dataset(dataset_path, dataset_name, split=dataset_split)
    examples = [make_example(record) for record in ds]
    print(f"Loaded {len(examples)} examples")

    # Create the program
    if program is None:
        program = QAModule(predict_cls=get_predict_cls(technique))

    # Evaluate the program
    evaluate_program = Evaluate(
        metric=evaluate_answer,
        devset=examples,
        num_threads=16,
        display_progress=True,
        return_outputs=True,
    )
    _, results = evaluate_program(program)

    # Save the results
    result_df = make_results_dataframe(results)

    # Save the scores
    scores = aggregate_scores(result_df)
    for n_hops in result_df["n_hops"].unique():
        scores[f"{n_hops}hops"] = aggregate_scores(result_df[result_df["n_hops"] == n_hops])


    return result_df, scores

In [7]:
# trained_program = train_main(
#     dataset_path="bdsaglam/musique-mini",
#     dataset_name="answerable",
#     dataset_split="train",
#     model="llama-3-70b-tgi",
#     temperature=0.1,
#     technique="cte",
#     optimizer_path=Path("bfsrs-medium.json"),
#     load_from="UNSET",
# )

In [8]:
dataset_path="bdsaglam/musique-mini"
dataset_name="answerable"
dataset_split="train"
model="llama-3-70b-tgi"
temperature=0.1
technique="cte"
optimizer_path=Path("bfsrs-medium.json")
load_from="UNSET"

# Set up LLM
configure_lm(model, temperature)

# Load and preprocess datasets
ds = load_dataset(dataset_path, dataset_name, split=dataset_split)
examples = [make_example(record) for record in ds]
print(f"Loaded {len(examples)} examples")

# Create the program
program = ConnectTheEntities()


In [9]:
example = examples[0]
print(example.context)
print(example.question)
pred = program.forward(context=example.context, question=example.question)
pred

Prediction(
    triples='Missouri;average summer temperature;24 ° C\nMissouri;average summer temperature;75 ° F\nCenterpoint Medical Center;location;Independence, Missouri',
    answer='24 ° C or 75 ° F'
)

In [10]:
# Train the program
optimizer_config = json.load(optimizer_path.open())
optimizer = make_optimizer(optimizer_config)
compile_params = optimizer_config.get("compile_params", {})
trained_program = optimizer.compile(program, trainset=examples, **compile_params)

Going to sample between 1 and 4 traces per predictor.
Will attempt to bootstrap 16 candidate sets.


  0%|          | 0/300 [00:00<?, ?it/s][2m2024-10-27T18:07:55.345161Z[0m [[31m[1merror    [0m] [1mError for example in dev set: 		 The extracted triples are not in correct format. Set `provide_traceback=True` to see the stack trace.[0m [[0m[1m[34mdspy.evaluate.evaluate[0m][0m [36mfilename[0m=[35mevaluate.py[0m [36mlineno[0m=[35m200[0m
Average Metric: 0.0 / 1  (0.0):   0%|          | 1/300 [00:27<2:15:05, 27.11s/it][2m2024-10-27T18:07:55.550668Z[0m [[31m[1merror    [0m] [1mError for example in dev set: 		 The extracted triples are not in correct format. Set `provide_traceback=True` to see the stack trace.[0m [[0m[1m[34mdspy.evaluate.evaluate[0m][0m [36mfilename[0m=[35mevaluate.py[0m [36mlineno[0m=[35m200[0m
Average Metric: 23.03333333333333 / 32  (72.0):  11%|█         | 32/300 [01:29<09:09,  2.05s/it] [2m2024-10-27T18:09:00.307678Z[0m [[31m[1merror    [0m] [1mError for example in dev set: 		 The extracted triples are not in correct format. 

ValueError: Expected dict_keys(['correct']) but got dict_keys([])

In [None]:
trained_program.save("qa-cte-two-step-program.json")