In [15]:
import json
import os
from copy import deepcopy
from pathlib import Path

import dspy
import pandas as pd
import typer
from bellem.musique.eval import (
    aggregate_scores,
    compute_scores,
    compute_scores_dataframe,
)
from bellem.utils import set_seed
from datasets import load_dataset
from dotenv import load_dotenv
from dspy.evaluate import Evaluate
from dspy.teleprompt.ensemble import Ensemble
from rich.console import Console

# Route print() through a rich Console that writes to stderr.
# NOTE: this shadows the builtin `print` for every cell executed after this one.
print = Console(stderr=True).print

# Load variables from a .env file into the environment; configure_lm later reads
# OPENAI_BASE_URL and OPENAI_API_KEY via os.getenv.
load_dotenv()

# Fixed seed for the run (what exactly gets seeded is up to bellem.utils.set_seed).
set_seed(89)

In [None]:
# Initialise Weave under the "mhqa-dspy" project.
# NOTE(review): presumably this enables tracing of the LLM calls made below — confirm.
import weave
weave.init(project_name="mhqa-dspy")

In [16]:
# import mlflow

# mlflow.set_tracking_uri("http://127.0.0.1:5000/")
# mlflow.set_experiment("mhqa-dspy")
# mlflow.dspy.autolog()

In [17]:
def configure_lm(model, temperature):
    """Create a ``dspy.LM`` for ``model`` and hand it to ``dspy.configure``.

    Args:
        model: Model name; it is prefixed with ``"openai/"`` before being
            passed to ``dspy.LM``.
        temperature: Sampling temperature forwarded to the LM.

    The API base URL and key are read from the ``OPENAI_BASE_URL`` and
    ``OPENAI_API_KEY`` environment variables (``None`` when unset), and the
    LM is created with ``cache=False``.
    """
    lm_options = {
        "temperature": temperature,
        "cache": False,
        "api_base": os.getenv("OPENAI_BASE_URL"),
        "api_key": os.getenv("OPENAI_API_KEY"),
    }
    dspy.configure(lm=dspy.LM("openai/" + model, **lm_options))


In [18]:
from rerankers import Reranker

# Module-level T5 reranker used by `retrieve`; created once because loading the
# model is slow (the loader itself warns that it "might take a while").
ranker = Reranker(
    model_type="t5",
    model_name="unicamp-dl/mt5-base-mmarco-v2",
)


def retrieve(docs: list[dict], query: str, top_k: int = 3) -> list[dict]:
    """Rerank ``docs`` against ``query`` and return the best matches.

    Args:
        docs: Documents to search in. Each document must be a dict with a
            'text' field; any other fields (e.g. 'idx') are carried through
            untouched because the original dicts are returned.
        query: Query string to search for.
        top_k: Maximum number of documents to return (default: 3).

    Returns:
        Up to ``top_k`` of the original document dicts, in the order produced
        by the reranker. An empty list when ``docs`` is empty.
    """
    # Nothing to rank: return early instead of invoking the reranker model
    # on an empty batch.
    if not docs:
        return []

    texts = [doc["text"] for doc in docs]
    # Positional indices are used as doc ids so that every ranked result maps
    # straight back to its source dict in `docs`.
    ranking = ranker.rank(query=query, docs=texts, doc_ids=list(range(len(texts))))
    return [docs[result.doc_id] for result in ranking.results[:top_k]]

Loading T5Ranker model unicamp-dl/mt5-base-mmarco-v2 (this message can be suppressed by setting verbose=0)
No device set
Using device cpu
No dtype set
Using dtype torch.float32
Loading model unicamp-dl/mt5-base-mmarco-v2, this might take a while...
Using device cpu.
Using dtype torch.float32.
T5 true token set to ▁yes
T5 false token set to ▁no
Returning normalised scores...
Inputs template set to Query: {query} Document: {text} Relevant:


In [19]:
from mhqa.react import ReAct, RunContext


def format_paragraph(paragraph):
    """Render a paragraph record as a ``# <title>`` line followed by its text.

    Reads the 'paragraph_text' and 'title' keys of ``paragraph``.
    """
    body, heading = paragraph["paragraph_text"], paragraph["title"]
    return "# {}\n{}".format(heading, body)


def make_example(record):
    """Convert a raw dataset record into a ``dspy.Example``.

    Args:
        record: Mapping with 'id', 'question', 'question_decomposition',
            'paragraphs', 'answer' and 'answer_aliases' keys. Each paragraph
            must provide 'idx', 'title' and 'paragraph_text'.

    Returns:
        A ``dspy.Example`` whose input fields are 'question' and 'docs'; the
        remaining fields (id, decomposition, answer, answers) are not inputs.
    """
    docs = [{"text": format_paragraph(p), "idx": p["idx"]} for p in record["paragraphs"]]
    example = dspy.Example(
        id=record["id"],
        question=record["question"],
        question_decomposition=record["question_decomposition"],
        docs=docs,
        answer=record["answer"],
        # Gold answer first, then every accepted alias.
        answers=[record["answer"], *record["answer_aliases"]],
    )
    # `docs` has to be an input field: the `search` tool looks the documents up
    # via `run_context.input["docs"]`. The example has no "context" field, so
    # listing that name instead left the documents out of the inputs.
    return example.with_inputs("question", "docs")


def search(query: str, run_context: RunContext) -> list[str]:
    """Find relevant documents for the query."""
    # NOTE(review): the docstring is kept verbatim on purpose — presumably it is
    # shown to the LLM as the tool description; confirm in mhqa.react before editing it.
    candidate_docs = run_context.input["docs"]
    top_docs = retrieve(candidate_docs, query, 3)
    # Only the formatted text is returned, not the whole document dict.
    return [doc["text"] for doc in top_docs]


def make_program():
    """Build a new ReAct program for ``question -> answer`` with the `search` tool."""
    signature = "question -> answer"
    available_tools = [search]
    return ReAct(signature, tools=available_tools)


def evaluate_answer(example, pred, trace=None):
    """Metric: the 'f1' entry of ``compute_scores`` for one prediction.

    Args:
        example: Object exposing the reference answers as ``example.answers``.
        pred: Object exposing the predicted answer as ``pred.answer``.
        trace: Unused here; accepted so the function can be called with a
            third argument.

    Returns:
        ``compute_scores(pred.answer, example.answers)["f1"]``.
    """
    predicted, references = pred.answer, example.answers
    return compute_scores(predicted, references)["f1"]


def dynamic_import(module, name):
    """Import ``module`` by its dotted path and return its attribute ``name``.

    Raises whatever the import raises for an unknown module, and
    ``AttributeError`` when the module has no attribute called ``name``.
    """
    # Imported locally so the helper is self-contained.
    from importlib import import_module

    loaded_module = import_module(module)
    return getattr(loaded_module, name)


def make_optimizer(optimizer_config: dict):
    """Instantiate the ``dspy.teleprompt`` optimizer described by ``optimizer_config``.

    Args:
        optimizer_config: Dict with
            'class' – name of the optimizer class inside ``dspy.teleprompt``;
            'params' – keyword arguments for its constructor;
            'with_metric' – when truthy, ``evaluate_answer`` is passed as ``metric``.

    Returns:
        The constructed optimizer instance.
    """
    optimizer_cls = dynamic_import("dspy.teleprompt", optimizer_config["class"])
    # Work on a deep copy so that injecting the metric never mutates the caller's config.
    init_kwargs = deepcopy(optimizer_config["params"])
    if optimizer_config["with_metric"]:
        init_kwargs["metric"] = evaluate_answer
    return optimizer_cls(**init_kwargs)


def preprocess_result(result):
    """Flatten one ``(example, prediction, score)`` triple into a single dict.

    The example's fields are kept under their own names, every prediction
    field is stored under ``predicted_<field>``, and the score is stored as a
    float under ``"score"``.
    """
    example, pred, score = result
    predicted_fields = dict(pred)
    row = dict(example)
    for field, value in predicted_fields.items():
        row[f"predicted_{field}"] = value
    row["score"] = float(score)
    return row


def make_results_dataframe(results):
    """Turn evaluation results into a DataFrame and pass it to ``compute_scores_dataframe``.

    Each result is flattened with ``preprocess_result``; an ``n_hops`` column
    is added as the length of ``question_decomposition``, and missing
    predicted answers are replaced by the string "No Answer".
    """
    records = [preprocess_result(item) for item in results]
    frame = pd.json_normalize(records)

    hop_counts = frame["question_decomposition"].apply(len)
    frame["n_hops"] = hop_counts

    # Rows without a prediction would otherwise carry NaN here.
    answers = frame["predicted_answer"].fillna("No Answer")
    frame["predicted_answer"] = answers

    return compute_scores_dataframe(frame)


def train_main(
    dataset_path: str = typer.Option(..., help="Path to the dataset"),
    dataset_name: str = typer.Option(..., help="Name of the dataset"),
    dataset_split: str = typer.Option(..., help="Dataset split to use (e.g., 'train', 'validation')"),
    model: str = typer.Option(..., help="Name of the model to use"),
    temperature: float = typer.Option(..., help="Temperature parameter for the model"),
    load_from: str = typer.Option(default="UNSET", help="Path to a saved model to load"),
    optimizer_path: Path = typer.Option(..., help="Path to the optimizer config"),
    ensemble: str = typer.Option("no", help="Whether to use an ensemble of models"),
    out: Path = typer.Option(..., help="Output file for trained program"),
):
    """Optimize the ReAct program on a dataset split and save it to ``out``.

    Returns the trained (optionally ensembled) program. When the optimizer
    config file contains an empty JSON value, no optimization is run and the
    program is saved as created/loaded.
    """
    # Accept plain strings as well as Path objects for the output file.
    out = Path(out)
    out.parent.mkdir(parents=True, exist_ok=True)

    # Set up LLM
    configure_lm(model, temperature)

    # Load and preprocess datasets
    ds = load_dataset(dataset_path, dataset_name, split=dataset_split)
    examples = [make_example(record) for record in ds]
    print(f"Loaded {len(examples)} examples")

    # Create the program; "UNSET" is the sentinel for "no checkpoint to load".
    program = make_program()
    if load_from and load_from != "UNSET":
        print(f"Loading model from {load_from}")
        program.load(load_from)

    # Train the program
    with open(optimizer_path) as f:
        optimizer_config = json.load(f)

    if optimizer_config:
        optimizer = make_optimizer(optimizer_config)
        compile_params = optimizer_config.get("compile_params", {})
        trained_program = optimizer.compile(program, trainset=examples, **compile_params)
    else:
        trained_program = program

    if ensemble == "yes":
        # Only optimizers that attach `candidate_programs` to their result can be
        # ensembled; fail with a clear message instead of an AttributeError.
        candidates = getattr(trained_program, "candidate_programs", None)
        if not candidates:
            raise ValueError(
                "ensemble='yes' requires an optimizer that attaches a non-empty "
                "`candidate_programs` list to the compiled program"
            )
        # Entries are handled in both shapes: a dict with a "program" key, or a
        # sequence whose last element is the program (shape varies by dspy version).
        candidate_programs = [
            entry["program"] if isinstance(entry, dict) else entry[-1] for entry in candidates
        ]
        ensemble_optimizer = Ensemble(reduce_fn=dspy.majority)
        trained_program = ensemble_optimizer.compile(candidate_programs)

    # Save the trained program
    trained_program.save(out)

    return trained_program

def evaluate_main(
    dataset_path: str = typer.Option(..., help="Path to the dataset"),
    dataset_name: str = typer.Option(..., help="Name of the dataset"),
    dataset_split: str = typer.Option(..., help="Dataset split to use (e.g., 'train', 'validation')"),
    model: str = typer.Option(..., help="Name of the model to use"),
    temperature: float = typer.Option(..., help="Temperature parameter for the model"),
    load_from: str = typer.Option(default="UNSET", help="Path to a saved model to load"),
    out: Path = typer.Option(..., help="Output directory for generated results"),
):
    """Evaluate the ReAct program on a dataset split.

    Writes per-example results to ``out/results.jsonl`` and aggregated scores
    (overall plus one entry per hop count) to ``out/scores.json``.
    """
    out.mkdir(parents=True, exist_ok=True)

    # The LM must be configured before the program is run.
    configure_lm(model, temperature)

    # Dataset records -> dspy examples
    records = load_dataset(dataset_path, dataset_name, split=dataset_split)
    examples = [make_example(record) for record in records]
    print(f"Loaded {len(examples)} examples")

    # Fresh program, optionally restored from a saved state ("UNSET" means none).
    program = make_program()
    if load_from and load_from != "UNSET":
        print(f"Loading model from {load_from}")
        program.load(load_from)

    # Single-threaded evaluation that also returns the per-example outputs.
    evaluator = Evaluate(
        metric=evaluate_answer,
        devset=examples,
        num_threads=1,
        display_progress=True,
        return_outputs=True,
    )
    _, results = evaluator(program)

    # Per-example results, one JSON object per line.
    result_df = make_results_dataframe(results)
    results_path = out / "results.jsonl"
    result_df.to_json(results_path, orient="records", lines=True)

    # Overall scores first, then a nested entry for every hop count present.
    scores = aggregate_scores(result_df)
    for n_hops in result_df["n_hops"].unique():
        hop_subset = result_df[result_df["n_hops"] == n_hops]
        scores[f"{n_hops}hops"] = aggregate_scores(hop_subset)

    scores_path = out / "scores.json"
    with open(scores_path, "w") as f:
        json.dump(scores, f, indent=2)


In [20]:
# Model name passed to train_main (configure_lm prefixes it with "openai/").
# Uncomment one of the alternatives below to switch models.
model = "meta-llama/Llama-3.3-70B-Instruct-Turbo"
# model = "llama3.1:8b-instruct-q8_0"
# model = "llama-3.1-8b-instant"
# model = "gemini-2.0-flash-lite-preview-02-05"

In [21]:
# `out` stays the output directory; the trained program gets its own file inside it.
out = Path('out')

trained_program = train_main(
    dataset_path='bdsaglam/musique-mini',
    dataset_name='answerable',
    dataset_split='train',
    model=model,
    temperature=0.1,
    load_from='UNSET',
    optimizer_path=Path('../data/raw/optimizer-configs/bfsrs-medium.json'),
    # Passed explicitly: when train_main is called directly instead of through the
    # typer CLI, an omitted option keeps its typer.Option(...) placeholder as value.
    ensemble='no',
    # Module.save expects a path ending in .json or .pkl (unless save_program=True),
    # so the program is written to a file with an explicit extension; a bare 'out'
    # would make the save fail only after the whole optimisation run has finished.
    out=out / 'trained-program.json',
)

Going to sample between 1 and 8 traces per predictor.
Will attempt to bootstrap 16 candidate sets.
Average Metric: 0.00 / 0 (0%):   1%|          | 2/300 [06:55<17:12:49, 207.95s/it]t/s]
Average Metric: 50.05 / 300 (16.7%): 100%|██████████| 300/300 [03:23<00:00,  1.47it/s]

2025/02/10 21:44:20 INFO dspy.evaluate.evaluate: Average Metric: 50.0466085981505 / 300 (16.7%)



New best score: 16.68 for seed -3
Scores so far: [16.68]
Best score so far: 16.68
Average Metric: 34.99 / 184 (19.0%):  61%|██████▏   | 184/300 [01:57<01:21,  1.42it/s]

2025/02/10 21:46:18 ERROR dspy.utils.parallelizer: Error processing item Example({'id': '3hop1__643047_859552_846191', 'question': "In which country does the child of Silverado's director hold citizenship?", 'question_decomposition': [{'answer': 'Lawrence Kasdan', 'id': 643047, 'paragraph_support_idx': 11, 'question': 'Silverado >> director'}, {'answer': 'Jake Kasdan', 'id': 859552, 'paragraph_support_idx': 18, 'question': '#1 >> child'}, {'answer': 'America', 'id': 846191, 'paragraph_support_idx': 1, 'question': '#2 >> country of citizenship'}], 'docs': [{'text': '# Child labour\nIn addition to setting the international law, the United Nations initiated International Program on the Elimination of Child Labour (IPEC) in 1992. This initiative aims to progressively eliminate child labour through strengthening national capacities to address some of the causes of child labour. Amongst the key initiative is the so-called time-bounded programme countries, where child labour is most prevalent

Average Metric: 53.86 / 299 (18.0%): 100%|██████████| 300/300 [03:22<00:00,  1.48it/s]

2025/02/10 21:47:43 INFO dspy.evaluate.evaluate: Average Metric: 53.86186068485928 / 300 (18.0%)



New best score: 17.95 for seed -2
Scores so far: [16.68, 17.95]
Best score so far: 17.95


  2%|▏         | 5/300 [05:26<5:20:37, 65.21s/it] 


KeyboardInterrupt: 