# ImpPres LLM Baseline

In this notebook, you have to implement a baseline for ImpPres classification using an LLM.
This baseline must be implemented using DSPy.



In [None]:
# Configure the DSPy environment with the language model - for grok the parameters must be:
# env variable should be in os.environ['XAI_API_KEY']
# "xai/grok-3-mini"
import os
from datetime import datetime

import dspy
import pandas as pd

# Provider-prefixed model id; the API key comes from the environment, so a missing
# XAI_API_KEY raises KeyError here instead of failing later on the first request.
MODEL_NAME = 'xai/grok-3-mini'
lm = dspy.LM(MODEL_NAME, api_key=os.environ['XAI_API_KEY'])

# Local alternatives served by Ollama (replace the `lm` assignment above to use one):
# lm = dspy.LM('ollama_chat/devstral', api_base='http://localhost:11434', api_key='')
# lm = dspy.LM(
#     "ollama/llama3.1:8b",
#     api_base="http://localhost:11434",
#     format="json"        # litellm translates this to Ollama's stream=false
# )

# Register `lm` as the default language model for every DSPy module in this notebook.
dspy.configure(lm=lm)

In [None]:
import logging
# Keep only ERROR-level records from DSPy's JSON adapter logger; its lower-severity
# messages (e.g. warnings) are suppressed in the cell outputs below.
_json_adapter_logger = logging.getLogger("dspy.adapters.json_adapter")
_json_adapter_logger.setLevel(logging.ERROR)

In [None]:
from typing import Literal

## Implement the DSPy program to classify pairs (premise, hypothesis) as entailment, contradiction, or neutral.
class NLIImPresClassifier(dspy.Signature):
    # DSPy signature for three-way NLI: (premise, hypothesis) -> label.
    # NOTE(review): this class intentionally carries no docstring here. DSPy uses a
    # Signature's docstring as the task instructions sent to the LM, so adding one would
    # change the prompt and invalidate the recorded results. Confirm against the DSPy docs.
    # The `desc` strings below are likewise prompt text, not just documentation.
    premise     :str = dspy.InputField(desc="A short passage or statement. All facts should be inferred from this text alone.")
    hypothesis  :str = dspy.InputField(desc="A second statement to evaluate. Check if this follows from, contradicts, or is unrelated to the premise.")
    # Literal restricts the expected output to exactly the three label strings used in
    # `label_names`. Outputs that cannot be parsed presumably surface as predictions without
    # a `label` (handled downstream via hasattr(prediction, "label")).
    label       : Literal["entailment", "neutral", "contradiction"] = dspy.OutputField(
        desc=(
            "Return one of: 'entailment', 'neutral', or 'contradiction'.\n"
            "- 'entailment': The hypothesis must be true if the premise is true.\n"
            "- 'contradiction': The hypothesis must be false if the premise is true.\n"
            "- 'neutral': The hypothesis could be either true or false based on the premise."
        )
    )

# Zero-shot predictor built directly from the signature above.
predictor = dspy.Predict(NLIImPresClassifier)
# label_names[i] is the label string for the integer gold label i used by the dataset rows.
label_names = ["entailment", "neutral", "contradiction"]


def zero_shot_nli_classifier(x):
    """Classify one raw ImpPres row with the zero-shot predictor.

    Args:
        x: A row mapping with 'premise', 'hypothesis' and an integer 'gold_label'.

    Returns:
        A dict with the premise, the hypothesis, the predicted label string
        ('pred_label') and the gold label converted to its name ('gold_label').
    """
    premise = x['premise']
    hypothesis = x['hypothesis']
    prediction = predictor(premise=premise, hypothesis=hypothesis)
    return {
        'premise': premise,
        'hypothesis': hypothesis,
        'pred_label': prediction.label,
        'gold_label': label_names[x['gold_label']],
    }

## Load ImpPres dataset

In [None]:
from datasets import load_dataset

# The nine ImpPres presupposition configurations of "facebook/imppres"; each is loaded
# separately and stored under its configuration name.
_presupposition_configs = [
    'all_n_presupposition',
    'both_presupposition',
    'change_of_state',
    'cleft_existence',
    'cleft_uniqueness',
    'only_presupposition',
    'possessed_definites_existence',
    'possessed_definites_uniqueness',
    'question_presupposition',
]
sections = [f'presupposition_{config}' for config in _presupposition_configs]

dataset = {}
for section in sections:
    print(f"Loading dataset for section: {section}")
    dataset[section] = load_dataset("facebook/imppres", section)

In [None]:
# Show the section -> loaded-dataset mapping as this cell's output.
dataset

## Evaluate Metrics

Let's use the Hugging Face `evaluate` package to compute the performance of the baseline.


In [None]:
import evaluate
# One combined `evaluate` module: a single .compute() call returns accuracy, F1,
# precision and recall together.
# NOTE(review): the metric cells visible below build their own `metric_prf` / `acc`
# and do not use `clf_metrics`.
_clf_metric_names = ["accuracy", "f1", "precision", "recall"]
clf_metrics = evaluate.combine(_clf_metric_names)

## Your Turn

Compute the classification metrics of the baseline LLM model on each section of the ImpPres dataset (ImpPres has no 'reason' field, so every sample is used).

You also must show a comparison between the DeBERTa baseline model and this LLM baseline model. The comparison metric should compute the agreement between the two models:
* On how many samples they are both correct [Correct]
* On how many samples Model1 is correct and Model2 is incorrect [Correct1]
* On how many samples Model1 is incorrect and Model2 is correct [Correct2]
* On how many samples both are incorrect [Incorrect]

We first run the DSPy classifier over every section of the dataset:

In [None]:
def accuracy_metric(example, pred, *args):
    """Exact-match metric used with DSPy's Evaluate and BootstrapFewShot.

    Args:
        example: Gold example; its `label` attribute holds the reference label.
        pred: Program output; its `label` attribute holds the predicted label.
        *args: Any extra positional arguments are accepted and ignored.

    Returns:
        True when the predicted label equals the gold label, otherwise False.
    """
    predicted_label = pred.label
    return predicted_label == example.label

In [None]:
import pandas as pd
# Convert every raw ImpPres row into a dspy.Example whose inputs are (premise, hypothesis)
# and whose `label` is the gold label name. Each section is read from its first split.
def _row_to_example(row):
    """Build a dspy.Example from one row, mapping the integer gold label to its name."""
    example = dspy.Example(
        premise=row['premise'],
        hypothesis=row['hypothesis'],
        label=label_names[row['gold_label']],
    )
    return example.with_inputs("premise", "hypothesis")


dspy_examples = {}
for sec_name, sec_splits in dataset.items():
    first_split_name = list(sec_splits.keys())[0]
    dspy_examples[sec_name] = [_row_to_example(row) for row in sec_splits[first_split_name]]

df = pd.DataFrame(dspy_examples)
display(df)

In [None]:
from dspy.evaluate import Evaluate
from evaluate import combine, load

# 1. Run the zero-shot predictor over the FULL dataset, with one DSPy evaluation per section.
#    Collect (prediction, reference) label pairs for the metric cells below. Results whose
#    prediction has no `label` (presumably failed or unparseable LM calls) are set aside
#    in `not_predicted`.
EVAL_NUM_THREADS = 50  # parallel LM requests per section; lower it if the provider rate-limits


def _split_outputs(result_tuples):
    """Separate usable predictions from failed ones.

    Args:
        result_tuples: (example, prediction, score) triples returned by Evaluate.

    Returns:
        (preds, refs, failures): predicted and gold label strings for every triple whose
        prediction has a `label`, plus the untouched triples whose prediction does not.
    """
    preds, refs, failures = [], [], []
    for example, prediction, score in result_tuples:
        if hasattr(prediction, "label"):
            preds.append(prediction.label)
            refs.append(example.label)
        else:
            failures.append((example, prediction, score))
    return preds, refs, failures


results = {}        # section -> {"preds": [...], "refs": [...]}
not_predicted = {}  # section -> summary of the results that came back without a label
for sec in dspy_examples:
    print(f"Evaluating section:\t{sec}")
    evaluator = Evaluate(
        devset=dspy_examples[sec],
        metric=accuracy_metric,
        return_outputs=True,  # also return the per-example (example, prediction, score) triples
        num_threads=EVAL_NUM_THREADS,
        display_progress=True,
        display_table=False,
        provide_traceback=False,
    )
    _, result_tuples = evaluator(predictor)
    print(f"number of results:\t{len(result_tuples)}")
    preds, refs, failures = _split_outputs(result_tuples)
    results[sec] = {"preds": preds, "refs": refs}
    not_predicted[sec] = {
        'section': sec,
        'num_not_predicted': len(failures),
        'not_predicted': failures,
    }

Let's display some statistics about the results:

In [None]:
from collections import Counter

def _print_prediction_stats(preds, refs, indent=""):
    """Print counts, class distributions, matches and quick accuracy for one prediction set.

    Args:
        preds: Predicted label strings.
        refs: Gold label strings aligned with `preds`.
        indent: Prefix for every printed line.

    Accuracy is reported as n/a when `refs` is empty (e.g. every call in a section failed)
    instead of raising ZeroDivisionError.
    """
    agree = sum(p == r for p, r in zip(preds, refs))
    print(f"{indent}Total predictions: {len(preds)}")
    print(f"{indent}Total references:  {len(refs)}")
    print(f"{indent}Class distribution in predictions: {Counter(preds)}")
    print(f"{indent}Class distribution in references:  {Counter(refs)}")
    print(f"{indent}Number of matches (agreement): {agree}")
    if refs:
        print(f"{indent}Accuracy (quick): {agree / len(refs):.3f}")
    else:
        print(f"{indent}Accuracy (quick): n/a (no references)")


for sec, data in results.items():
    print(f"Section: {sec}")
    _print_prediction_stats(data['preds'], data['refs'], indent="  ")
    print()

# Overall stats. Flatten with a comprehension; sum(lists, []) is quadratic in the total length.
all_preds = [p for data in results.values() for p in data['preds']]
all_refs = [r for data in results.values() for r in data['refs']]
print("=== OVERALL ===")
_print_prediction_stats(all_preds, all_refs)


We will now show information about non-predicted examples:

In [None]:
# Tabulate every result that came back without a label: one row per failed example,
# with the example fields read by name.
FAILURE_COLUMNS = ["section", "premise", "hypothesis", "gold_label", "raw_output", "score"]
failure_records = [
    {
        "section": sec,
        "premise": ex.premise,
        "hypothesis": ex.hypothesis,
        "gold_label": ex.label,
        "raw_output": repr(raw_out),
        "score": score,
    }
    for sec, info in not_predicted.items()
    for ex, raw_out, score in info['not_predicted']
]
# Passing `columns` keeps the frame well-formed even when there are no failures at all.
df_details = pd.DataFrame(failure_records, columns=FAILURE_COLUMNS)
display(df_details)

for sec, info in not_predicted.items():
    print(f"=== Section: {sec} — {info['num_not_predicted']} failures ===")
    for ex, raw_out, score in info['not_predicted']:
        # Attribute access instead of unpacking the Example, which depends on its key order.
        print(f"🎯 Ref label: {ex.label}")
        print(f"💬 Premise: {ex.premise}")
        print(f"💬 Hypothesis: {ex.hypothesis}")
        print(f"🛑 Raw output: {raw_out!r}")
        print(f"⚠️ Score: {score}")
        print("-" * 40)

In [None]:
# 2. Score the zero-shot predictions: accuracy plus weighted precision / recall / F1,
#    per section and over all sections pooled together.
metric_prf = combine(["precision", "recall", "f1"])
acc = load("accuracy")
label2id = {"entailment": 0, "neutral": 1, "contradiction": 2}


def _score_row(section, pred_ids, ref_ids):
    """Return one metrics-table row for integer-encoded predictions and references."""
    weighted = metric_prf.compute(predictions=pred_ids, references=ref_ids, average="weighted")
    accuracy = acc.compute(predictions=pred_ids, references=ref_ids)["accuracy"]
    return {"section": section, "accuracy": accuracy, **weighted}


rows = []
all_preds, all_refs = [], []
for sec, data in results.items():
    print(f"Computing metrics for section: {sec}")
    sec_pred_ids = [label2id[name] for name in data["preds"]]
    sec_ref_ids = [label2id[name] for name in data["refs"]]
    rows.append(_score_row(sec, sec_pred_ids, sec_ref_ids))
    all_preds.extend(sec_pred_ids)
    all_refs.extend(sec_ref_ids)

# 3. Pooled metrics over every section, appended as the final "all" row.
rows.append(_score_row("all", all_preds, all_refs))

df_metrics = pd.DataFrame(rows)
display(df_metrics.set_index("section"))

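In our experiment (zero-shot grok-3-mini over every section), we got the following results; precision, recall and F1 are weighted averages over the three classes:
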
| section                                       | accuracy | precision | recall  | f1       |
|----------------------------------------------|----------|-----------|---------|----------|
| presupposition_all_n_presupposition          | 0.942632 | 0.949257  | 0.942632| 0.942783 |
| presupposition_both_presupposition           | 0.973158 | 0.974034  | 0.973158| 0.973184 |
| presupposition_change_of_state               | 0.557895 | 0.655905  | 0.557895| 0.493381 |
| presupposition_cleft_existence               | 0.686316 | 0.812531  | 0.686316| 0.669707 |
| presupposition_cleft_uniqueness              | 0.474211 | 0.503028  | 0.474211| 0.350207 |
| presupposition_only_presupposition           | 0.668947 | 0.778061  | 0.668947| 0.654415 |
| presupposition_possessed_definites_existence | 0.923158 | 0.929153  | 0.923158| 0.923322 |
| presupposition_possessed_definites_uniqueness| 0.475263 | 0.626211  | 0.475263| 0.352235 |
| presupposition_question_presupposition       | 0.841053 | 0.863356  | 0.841053| 0.838288 |
| all                                          | 0.726959 | 0.815532  | 0.726959| 0.717863 |

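Overall, grok-3-mini reaches an accuracy of 0.727 and a weighted F1 of 0.718 (the `all` row). Weighted recall always equals accuracy, which is why those two columns coincide. Let's try to optimize the model.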


## Optimizing the model
We first create a dev/test split (70% dev, 30% test) for each section:

In [None]:
import random

# Reproducible per-section dev/test split. Each section is shuffled on a COPY with a seeded
# RNG, so `dspy_examples` keeps its original order and re-running this cell gives the same split.
SPLIT_SEED = 42
DEV_FRACTION = 0.7  # 70% dev (used as the optimizer's trainset), 30% held-out test

split_rng = random.Random(SPLIT_SEED)
dev_data = {}
test_data = {}

for sec, examples in dspy_examples.items():
    shuffled = list(examples)
    split_rng.shuffle(shuffled)
    split_point = int(DEV_FRACTION * len(shuffled))
    dev_data[sec] = shuffled[:split_point]
    test_data[sec] = shuffled[split_point:]

# Split sizes per section; the full example lists are too large to display usefully.
display(pd.DataFrame({
    "dev": {sec: len(exs) for sec, exs in dev_data.items()},
    "test": {sec: len(exs) for sec, exs in test_data.items()},
}))

Let's try few-shot example optimization.
We optimize a separate prompt for each section with DSPy's `BootstrapFewShot`, searching for few-shot demonstrations on that section's dev split.

In [None]:
from dspy.teleprompt import BootstrapFewShot
from datetime import datetime
# Compile one few-shot program per section with BootstrapFewShot. The section's dev split
# is the trainset, and `accuracy_metric` judges candidate demonstrations.
MAX_BOOTSTRAPPED_DEMOS = 50
MAX_LABELED_DEMOS = 10

optimized_pipelines = {}  # section -> compiled program for that section
for sec, dev_set in dev_data.items():
    bs = BootstrapFewShot(
        metric=accuracy_metric,
        max_bootstrapped_demos=MAX_BOOTSTRAPPED_DEMOS,
        max_labeled_demos=MAX_LABELED_DEMOS,
    )
    optimized_pipelines[sec] = bs.compile(student=predictor, trainset=dev_set)
    print(f"✅ Completed Bootstrapped few-shot for section `{sec}`")

# `dspy.BetterTogether` is an optimizer, not a container that merges programs, so the
# section programs are kept separate: evaluate section `sec` with optimized_pipelines[sec].
# Each program's state is saved under one shared timestamp. Use datetime.now(); calling
# timestamp() on the class itself raises TypeError.
run_stamp = datetime.now().strftime("%Y%m%d-%H%M%S")
for sec, compiled in optimized_pipelines.items():
    compiled.save(f"optimized_{sec}_{run_stamp}.json", save_program=False)

In [None]:
from dspy.evaluate import Evaluate

from types import SimpleNamespace

# Evaluate each section's OPTIMIZED program on that section's held-out test split.
# Evaluate is configured like the zero-shot run (return_outputs=True -> (score, triples)).
# Each section's outcome is stored as an object exposing `.score` and `.results`, the
# (example, prediction, score) triples.
test_results = {}

for sec, examples in test_data.items():
    print(f"Evaluating on test section: {sec}")
    evaluator = Evaluate(
        devset=examples,
        metric=accuracy_metric,
        return_outputs=True,
        num_threads=20,
        display_progress=True,
        display_table=5,
        provide_traceback=True,
        max_errors=5,
    )
    score, result_tuples = evaluator(optimized_pipelines[sec])
    test_results[sec] = SimpleNamespace(score=score, results=result_tuples)

In [None]:
# Score the optimized programs on the held-out test splits, using the same metrics as the
# zero-shot table: accuracy plus weighted precision / recall / F1.
metric_prf = combine(["precision", "recall", "f1"])
acc = load("accuracy")
label2id = {"entailment": 0, "neutral": 1, "contradiction": 2}


def _result_triples(res):
    """Return the (example, prediction, score) triples held in one evaluation result.

    Accepts an object with a `.results` attribute, a `(score, triples)` tuple, or the list
    of triples itself. Any other type (e.g. a bare score) raises TypeError.
    """
    if hasattr(res, "results"):
        return list(res.results)
    if isinstance(res, tuple) and len(res) == 2:
        return list(res[1])
    if isinstance(res, list):
        return res
    raise TypeError(
        f"Unsupported evaluation result of type {type(res).__name__}; "
        "run Evaluate with return_outputs=True."
    )


rows = []
all_preds, all_refs = [], []

for sec, res in test_results.items():
    print(f"Metrics for section: {sec}")
    triples = _result_triples(res)
    # Skip results without a predicted label (failed calls), as the zero-shot cell does.
    labelled = [(ex, p) for ex, p, _ in triples if hasattr(p, "label")]
    if len(labelled) < len(triples):
        print(f"  skipped {len(triples) - len(labelled)} result(s) without a predicted label")
    preds = [label2id[p.label] for _, p in labelled]
    refs = [label2id[ex.label] for ex, _ in labelled]

    prf = metric_prf.compute(predictions=preds, references=refs, average="weighted")
    accuracy = acc.compute(predictions=preds, references=refs)["accuracy"]

    rows.append({"section": sec, "accuracy": accuracy, **prf})
    all_preds += preds
    all_refs += refs

overall_prf = metric_prf.compute(predictions=all_preds, references=all_refs, average="weighted")
overall_acc = acc.compute(predictions=all_preds, references=all_refs)["accuracy"]
rows.append({"section": "all", "accuracy": overall_acc, **overall_prf})

df_metrics = pd.DataFrame(rows)
display(df_metrics.set_index("section"))