# ImpPres with LLM

In this notebook, you have to implement a better ImpPres classifier using an LLM.
The classifier must be implemented with DSPy.


In [2]:
# Configure the DSPy environment with the language model - for grok the parameters must be:
# env variable should be in os.environ['XAI_API_KEY']
# "xai/grok-3-mini"
import os
import dspy

def load_api_key(path: str, name: str) -> None:
    """Read a `name=value` line from an ini-style file and export it to os.environ."""
    with open(path) as f:
        for line in f:
            line = line.strip()
            # Skip comments and lines that do not mention the key
            if line.startswith("#") or name not in line:
                continue
            _, sep, value = line.partition("=")
            if sep and value.strip():
                os.environ[name] = value.split()[0]

load_api_key("grok_key.ini", "XAI_API_KEY")
load_api_key("gemini_key.ini", "GEMINI_API_KEY")

In [3]:

lm = dspy.LM('xai/grok-3-mini', api_key=os.environ['XAI_API_KEY'])
# for ollama 
# lm = dspy.LM('ollama_chat/devstral', api_base='http://localhost:11434', api_key='')
dspy.configure(lm=lm)

In [75]:
from typing import Literal

## Implement the DSPy classifier program.

#Basic label signature
class anli_classification_signature(dspy.Signature):

    """Label the relationship between given premise and hypothesis."""
    
    premise: str = dspy.InputField()
    hypothesis: str = dspy.InputField()

    label: Literal['entailment', 'contradiction', 'neutral'] = dspy.OutputField()

#Paradigm signature
class paradigm_signature(dspy.Signature):

    """Label the relationship between 19 transformations of a ( premise , hypothesis ) pair."""

    premises: list[str] = dspy.InputField()
    hypotheses: list[str] = dspy.InputField()
    output_labels: list[Literal['entailment', 'contradiction', 'neutral']] = dspy.OutputField()

#Using CoT
label_prompt = dspy.ChainOfThought(anli_classification_signature)

#Creating a module to predict a complete paradigm
class paradigm_module(dspy.Module):
    def __init__(self):
        super().__init__()
        self.signature = paradigm_signature
        #using a single prompt per pair
        self.label_prompt = label_prompt

    #packaging all predictions into a list    
    def forward(self, premises, hypotheses, **kwargs) -> dspy.Prediction:
        
        output_labels = []

        for p, h in zip(premises, hypotheses):
            result = self.label_prompt(premise=p, hypothesis=h)  # label_prompt is a DSPy ChainOfThought module
            output_labels.append(result.label)
        
        return dspy.Prediction(output_labels=output_labels)

label_map = {
    "entailment": 0,
    "neutral": 1,
    "contradiction": 2
}

#Reward: fraction of pairs in the paradigm whose predicted label matches the gold label.
#`data` holds the keyword arguments passed to the module, so it includes gold_labels.
def paradigm_reward(data, pred: dspy.Prediction):
    print(f"pred.labels: {pred.output_labels}")
    golds = data["gold_labels"]
    preds = [label_map[item] for item in pred.output_labels]
    return sum(p == g for p, g in zip(preds, golds)) / len(golds)

paradigm_refine = dspy.Refine(
    module=paradigm_module(), 
    N=3, reward_fn=paradigm_reward, 
    threshold=0.7)



## Load ImpPres Dataset

In [5]:
from datasets import load_dataset

sections = ['presupposition_all_n_presupposition', 
            'presupposition_both_presupposition', 
            'presupposition_change_of_state', 
            'presupposition_cleft_existence', 
            'presupposition_cleft_uniqueness', 
            'presupposition_only_presupposition', 
            'presupposition_possessed_definites_existence', 
            'presupposition_possessed_definites_uniqueness', 
            'presupposition_question_presupposition']

dataset = {}
for section in sections:
    print(f"Loading dataset for section: {section}")
    dataset[section] = load_dataset("facebook/imppres", section)

Loading dataset for section: presupposition_all_n_presupposition
Loading dataset for section: presupposition_both_presupposition
Loading dataset for section: presupposition_change_of_state
Loading dataset for section: presupposition_cleft_existence
Loading dataset for section: presupposition_cleft_uniqueness
Loading dataset for section: presupposition_only_presupposition
Loading dataset for section: presupposition_possessed_definites_existence
Loading dataset for section: presupposition_possessed_definites_uniqueness
Loading dataset for section: presupposition_question_presupposition


## Evaluate Metrics

Let's use the huggingface `evaluate` package to compute the performance of the baseline.


In [7]:
from evaluate import load

accuracy = load("accuracy")
precision = load("precision")
recall = load("recall")
f1 = load("f1")


In [8]:
import evaluate
clf_metrics = evaluate.combine(["accuracy", "f1", "precision", "recall"])

## Your Turn

Compute the classification metrics on the baseline LLM model on each test section of the ANLI dataset for samples that have a non-empty 'reason' field.

You also must show a comparison between the DeBERTa baseline model and this LLM baseline model. The comparison metric should compute the agreement between the two models:
* On how many samples they are both correct [Correct]
* On how many samples Model1 is correct and Model2 is incorrect [Correct1]
* On how many samples Model1 is incorrect and Model2 is correct [Correct2]
* On how many samples both are incorrect [Incorrect]
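
The four agreement counts listed above could be computed along these lines. This is only a sketch: `agreement_counts` and its argument names are hypothetical, the labels in the toy example are made up, and it assumes that the gold labels and both models' predictions are aligned lists of integer labels.

```python
def agreement_counts(gold, preds1, preds2):
    """Count agreement between two models against the gold labels.

    Assumes the three sequences are aligned (same sample at each index).
    """
    if not (len(gold) == len(preds1) == len(preds2)):
        raise ValueError("gold, preds1 and preds2 must have the same length")

    counts = {"Correct": 0, "Correct1": 0, "Correct2": 0, "Incorrect": 0}
    for g, p1, p2 in zip(gold, preds1, preds2):
        ok1, ok2 = p1 == g, p2 == g
        if ok1 and ok2:
            counts["Correct"] += 1      # both models correct
        elif ok1:
            counts["Correct1"] += 1     # only Model1 correct
        elif ok2:
            counts["Correct2"] += 1     # only Model2 correct
        else:
            counts["Incorrect"] += 1    # both models wrong
    return counts


# Toy example with made-up labels (0=entailment, 1=neutral, 2=contradiction)
gold = [0, 1, 2, 0, 1]
model1 = [0, 1, 1, 2, 1]
model2 = [0, 2, 2, 2, 1]
print(agreement_counts(gold, model1, model2))
```

The four counts always sum to the number of samples, which gives a quick sanity check when comparing the DeBERTa and LLM predictions.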

In [79]:
list(dataset["presupposition_change_of_state"]["change_of_state"])[:19]

[{'premise': 'The guest had found John.',
  'hypothesis': 'John used to be in an unknown location.',
  'trigger': 'unembedded',
  'trigger1': 'Not_In_Example',
  'trigger2': 'Not_In_Example',
  'presupposition': 'positive',
  'gold_label': 0,
  'UID': 'change_of_state',
  'pairID': '0e',
  'paradigmID': 0},
 {'premise': 'The guest had found John.',
  'hypothesis': "John didn't used to be in an unknown location.",
  'trigger': 'unembedded',
  'trigger1': 'Not_In_Example',
  'trigger2': 'Not_In_Example',
  'presupposition': 'negated',
  'gold_label': 2,
  'UID': 'change_of_state',
  'pairID': '1c',
  'paradigmID': 0},
 {'premise': 'The guest had found John.',
  'hypothesis': 'Peter used to be in an unknown location.',
  'trigger': 'unembedded',
  'trigger1': 'Not_In_Example',
  'trigger2': 'Not_In_Example',
  'presupposition': 'neutral',
  'gold_label': 1,
  'UID': 'change_of_state',
  'pairID': '2n',
  'paradigmID': 0},
 {'premise': "The guest hadn't found John.",
  'hypothesis': 'John 

In [81]:
# Take the first 19 entries, i.e. one paradigm's worth of pairs
paradigm = list(dataset["presupposition_change_of_state"]["change_of_state"])[:19]
premises = [entry["premise"] for entry in paradigm]
hypotheses = [entry["hypothesis"] for entry in paradigm]
gold_labels = [entry["gold_label"] for entry in paradigm]
prediction = paradigm_refine(premises=premises, hypotheses=hypotheses, gold_labels=gold_labels)


pred.labels: ['neutral', 'contradiction', 'neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'contradiction', 'neutral', 'neutral', 'neutral']
pred.labels: ['entailment', 'contradiction', 'neutral', 'entailment', 'contradiction', 'neutral', 'entailment', 'contradiction', 'neutral', 'entailment', 'contradiction', 'neutral', 'entailment', 'contradiction', 'neutral', 'contradiction', 'entailment', 'neutral', 'neutral']


In [82]:

predictions = [label_map[item] for item in prediction.output_labels]
references = [item for item in gold_labels]


print(precision.compute(predictions= predictions, references= references, average="weighted"))

{'precision': 0.9561403508771931}
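
Note that the number printed above is support-weighted precision, which is not the same quantity as accuracy. The pure-Python sketch below illustrates how the two can differ on the same predictions. The function names are hypothetical and the labels are made up; the weighting follows the usual definition, where each class's precision is weighted by its number of gold instances.

```python
from collections import Counter


def simple_accuracy(gold, pred):
    """Fraction of positions where the prediction equals the gold label."""
    return sum(g == p for g, p in zip(gold, pred)) / len(gold)


def weighted_precision(gold, pred):
    """Per-class precision, averaged with weights equal to each class's gold count."""
    support = Counter(gold)      # gold instances per class
    predicted = Counter(pred)    # predictions per class
    true_pos = Counter(g for g, p in zip(gold, pred) if g == p)

    total = 0.0
    for label, n_gold in support.items():
        # Precision is taken as 0 when a class is never predicted
        prec = true_pos[label] / predicted[label] if predicted[label] else 0.0
        total += n_gold * prec
    return total / len(gold)


# Made-up labels: one entailment example is mislabelled as neutral
gold = [0, 0, 1, 1, 2, 2]
pred = [0, 1, 1, 1, 2, 2]
print(simple_accuracy(gold, pred))     # 5 of 6 predictions are correct
print(weighted_precision(gold, pred))  # not equal to the accuracy above
```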


In [107]:
import json
with open("imppres_dataset.json", "r") as f:
    merged_dataset = json.load(f)

In [92]:
import random
from collections import defaultdict

# Assuming merged_dataset is a list or iterable of all entries from all sections,
# each entry has a 'section' key and a 'paradigmID' key.

all_results = []

section_groups = defaultdict(list)
for entry in merged_dataset:
    section_groups[entry["section"]].append(entry)

for section, entries_in_section in section_groups.items():
    print(f"Processing section: {section}")

    paradigm_groups = defaultdict(list)
    for entry in entries_in_section:
        paradigm_groups[entry["paradigmID"]].append(entry)

    # Select the first 7 paradigms sorted by paradigmID
    selected_paradigm_ids = sorted(paradigm_groups.keys())[:7]

    for paradigm_id in selected_paradigm_ids:
        print(f"Processing paradigm: {paradigm_id}")
        paradigm_entries = paradigm_groups[paradigm_id]
        random.shuffle(paradigm_entries)

        premises = [e["premise"] for e in paradigm_entries]
        hypotheses = [e["hypothesis"] for e in paradigm_entries]
        gold_labels = [e["gold_label"] for e in paradigm_entries]

        input_batch = {
            "premises": premises,
            "hypotheses": hypotheses,
            "gold_labels": gold_labels
        }

        prediction = paradigm_refine(**input_batch)

        predicted_labels = prediction.output_labels

        for e, pred_label in zip(paradigm_entries, predicted_labels):
            result_entry = dict(e)
            result_entry["pred_label"] = pred_label
            all_results.append(result_entry)
        
        print("Paradigm results:")

        predictions = [label_map[item] for item in predicted_labels]
        references = [item["gold_label"] for item in paradigm_entries]

        # Note: this reports support-weighted precision, not accuracy
        print(precision.compute(predictions=predictions, references=references, average="weighted"))
        

print(f"Total results collected: {len(all_results)}")

Processing section: presupposition_all_n_presupposition
Processing paradigm: 0
pred.labels: ['neutral', 'neutral', 'contradiction', 'entailment', 'neutral', 'neutral', 'entailment', 'contradiction', 'neutral', 'contradiction', 'entailment', 'contradiction', 'contradiction', 'entailment', 'neutral', 'contradiction', 'neutral', 'entailment', 'neutral']
Paradigm results:
{'precision': 1.0}
Processing paradigm: 1
pred.labels: ['contradiction', 'neutral', 'neutral', 'contradiction', 'contradiction', 'contradiction', 'neutral', 'entailment', 'neutral', 'neutral', 'entailment', 'contradiction', 'neutral', 'entailment', 'entailment', 'neutral', 'contradiction', 'neutral', 'neutral']
Paradigm results:
{'precision': 0.9532163742690059}
Processing paradigm: 2
pred.labels: ['entailment', 'neutral', 'neutral', 'contradiction', 'neutral', 'contradiction', 'entailment', 'entailment', 'contradiction', 'neutral', 'neutral', 'neutral', 'entailment', 'contradiction', 'entailment', 'neutral', 'contradicti

  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])


pred.labels: ['neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'contradiction', 'neutral', 'neutral']
pred.labels: ['neutral', 'neutral', 'contradiction', 'entailment', 'neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'contradiction', 'neutral', 'neutral']
pred.labels: ['neutral', 'neutral', 'contradiction', 'entailment', 'neutral', 'neutral', 'contradiction', 'neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'contradiction', 'neutral', 'neutral']
Paradigm results:
{'precision': 0.8035087719298245}
Processing paradigm: 3
pred.labels: ['neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'contradiction', 'neutral', 'neutra

  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])


pred.labels: ['neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'contradiction', 'neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'neutral']
pred.labels: ['neutral', 'entailment', 'neutral', 'neutral', 'neutral', 'neutral', 'entailment', 'contradiction', 'contradiction', 'neutral', 'contradiction', 'neutral', 'neutral', 'neutral', 'entailment', 'contradiction', 'contradiction', 'neutral', 'contradiction']




pred.labels: ['neutral', 'entailment', 'neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'contradiction', 'contradiction', 'neutral', 'contradiction', 'neutral', 'neutral', 'neutral', 'entailment', 'contradiction', 'neutral', 'contradiction', 'neutral']
Paradigm results:
{'precision': 0.5029239766081871}
Processing paradigm: 5


  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])


pred.labels: ['neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'contradiction', 'neutral', 'neutral']
pred.labels: ['neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'entailment', 'neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'contradiction', 'contradiction', 'neutral']
pred.labels: ['neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'entailment', 'neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'contradiction', 'neutral', 'neutral']
Paradigm results:
{'precision': 0.7894736842105263}
Processing paradigm: 6
pred.labels: ['neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'contradiction', 'neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'neutral



pred.labels: ['neutral', 'neutral', 'neutral', 'entailment', 'contradiction', 'neutral', 'neutral', 'contradiction', 'neutral', 'entailment', 'neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'contradiction', 'entailment', 'entailment', 'neutral']
Paradigm results:
{'precision': 0.7587719298245614}
Processing section: presupposition_only_presupposition
Processing paradigm: 0
pred.labels: ['neutral', 'contradiction', 'neutral', 'neutral', 'neutral', 'neutral', 'contradiction', 'neutral', 'entailment', 'neutral', 'contradiction', 'neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'entailment', 'contradiction', 'neutral']
Paradigm results:
{'precision': 0.8380566801619433}
Processing paradigm: 1
pred.labels: ['contradiction', 'neutral', 'neutral', 'neutral', 'neutral', 'contradiction', 'neutral', 'entailment', 'entailment', 'neutral', 'contradiction', 'contradiction', 'neutral', 'neutral', 'neutral', 'contradiction', 'contradiction', 'neutral', 'neutral']
Paradigm results:
{'prec



pred.labels: ['contradiction', 'contradiction', 'contradiction', 'contradiction', 'contradiction', 'neutral', 'neutral', 'contradiction', 'contradiction', 'neutral', 'contradiction', 'neutral', 'entailment', 'neutral', 'entailment', 'contradiction', 'neutral', 'contradiction', 'neutral']
Paradigm results:
{'precision': 0.7033492822966507}
Processing paradigm: 3
pred.labels: ['contradiction', 'neutral', 'neutral', 'contradiction', 'neutral', 'neutral', 'neutral', 'contradiction', 'neutral', 'entailment', 'contradiction', 'neutral', 'neutral', 'neutral', 'neutral', 'entailment', 'neutral', 'neutral', 'neutral']
pred.labels: ['neutral', 'neutral', 'contradiction', 'contradiction', 'neutral', 'neutral', 'neutral', 'contradiction', 'neutral', 'entailment', 'contradiction', 'neutral', 'neutral', 'neutral', 'neutral', 'entailment', 'neutral', 'neutral', 'neutral']
Paradigm results:
{'precision': 0.8380566801619433}
Processing paradigm: 4
pred.labels: ['neutral', 'neutral', 'neutral', 'neutral

  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])


pred.labels: ['neutral', 'contradiction', 'neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'neutral']
pred.labels: ['entailment', 'neutral', 'neutral', 'neutral', 'contradiction', 'neutral', 'neutral', 'neutral', 'entailment', 'entailment', 'entailment', 'neutral', 'entailment', 'neutral', 'entailment', 'contradiction', 'entailment', 'entailment', 'neutral']
pred.labels: ['neutral', 'contradiction', 'neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'neutral']
Paradigm results:
{'precision': 0.7141812865497076}
Processing section: presupposition_question_presupposition
Processing paradigm: 0
pred.labels: ['contradiction', 'neutral', 'contradiction', 'neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'contradiction', 'contradiction', '

In [106]:
import pandas as pd

# Convert results to a DataFrame for easy grouping
df = pd.DataFrame(all_results)

# Map labels to integers
label_map = {
    "entailment": 0,
    "neutral": 1,
    "contradiction": 2
}
df["gold_int"] = df["gold_label"]
df["pred_int"] = df["pred_label"].map(label_map)

# ---- 1. Accuracy per section ----
section_perf = (
    df.groupby("section")
      .apply(lambda g: (g["gold_int"] == g["pred_int"]).mean(), include_groups=False)
      .reset_index(name="accuracy")
)

# ---- 2. Accuracy per transformation ----
transform_perf = (
    df.groupby(["trigger", "trigger1", "trigger2", "presupposition"])
      .agg(
          accuracy=("gold_int", lambda x: (x == df.loc[x.index, "pred_int"]).mean()),
          n_examples=("gold_int", "size")
      )
      .reset_index()
      .sort_values("accuracy", ascending=False)
)

# Filter rows where trigger1 == 'negated'
negated_df = df[df["trigger1"] == "negated"]

# Now group by the other columns, e.g., trigger, trigger2, presupposition
negated_perf = (
    negated_df.groupby(["trigger", "trigger2", "presupposition"])
        .agg(
            accuracy=("gold_int", lambda x: (x == negated_df.loc[x.index, "pred_int"]).mean()),
            n_examples=("gold_int", "size")
        )
        .reset_index()
        .sort_values("accuracy", ascending=False)
)

print(negated_perf)
# This metric is support-weighted precision, so it is labelled as such rather than as accuracy
print("overall weighted precision:")
print(precision.compute(predictions=df["pred_int"], references=df["gold_int"], average="weighted", zero_division=0))
print("=== Accuracy per Section ===")
print(section_perf)
print(transform_perf)



Empty DataFrame
Columns: [trigger, trigger2, presupposition, accuracy, n_examples]
Index: []
overall weighted precision:
{'precision': 0.8232283709539563}
=== Accuracy per Section ===
                                         section  accuracy
0            presupposition_all_n_presupposition  0.932331
1             presupposition_both_presupposition  0.977444
2                 presupposition_change_of_state  0.714286
3                 presupposition_cleft_existence  0.774436
4                presupposition_cleft_uniqueness  0.548872
5             presupposition_only_presupposition  0.714286
6   presupposition_possessed_definites_existence  0.939850
7  presupposition_possessed_definites_uniqueness  0.661654
8         presupposition_question_presupposition  0.887218
           trigger        trigger1        trigger2  presupposition  accuracy  \
11         negated  Not_In_Example  Not_In_Example         neutral  0.968254   
2      conditional  Not_In_Example  Not_In_Example         neutral  0.952381

It seems I made a mistake in imppres.ipynb and wrote the trigger2 data into the trigger1 key, which is why transformation 0 has 252 = 4 * 7 * 9 entries.

I've fixed it now, but applying the fix here would mean somehow working around the randomization.
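
One way to apply corrected fields without depending on the shuffled order might be to join on an identifier instead of on position. The sketch below assumes, hypothetically, that (`section`, `paradigmID`, `pairID`) uniquely identifies an entry and that a corrected copy of the dataset is available as a list of dicts. `patch_trigger_fields` and the toy records are illustrative only.

```python
def patch_trigger_fields(results, corrected, fields=("trigger1", "trigger2")):
    """Return copies of `results` with `fields` overwritten from `corrected`.

    Entries are matched by (section, paradigmID, pairID), so the order of
    `results` does not matter. Assumes that key is unique per entry.
    """
    def key(entry):
        return (entry["section"], entry["paradigmID"], entry["pairID"])

    lookup = {key(e): e for e in corrected}
    patched = []
    for r in results:
        fixed = dict(r)  # keeps pred_label and all other fields
        source = lookup.get(key(r))
        if source is not None:
            for f in fields:
                fixed[f] = source[f]
        patched.append(fixed)
    return patched


# Made-up records: `results` is in shuffled order and has the two fields swapped
corrected = [
    {"section": "s", "paradigmID": 0, "pairID": "0e",
     "trigger1": "Not_In_Example", "trigger2": "Not_In_Example"},
    {"section": "s", "paradigmID": 0, "pairID": "1c",
     "trigger1": "negated", "trigger2": "Not_In_Example"},
]
results = [
    {"section": "s", "paradigmID": 0, "pairID": "1c",
     "trigger1": "Not_In_Example", "trigger2": "negated", "pred_label": "neutral"},
    {"section": "s", "paradigmID": 0, "pairID": "0e",
     "trigger1": "Not_In_Example", "trigger2": "Not_In_Example", "pred_label": "entailment"},
]
patched = patch_trigger_fields(results, corrected)
print(patched[0]["trigger1"], patched[0]["pred_label"])
```

Because the predictions stay attached to their entries, nothing has to be re-run and the shuffle inside the evaluation loop no longer matters.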

I don't trust the way I've refined this prompt.

The process simply retries the whole paradigm with a different temperature if the group's average reward is below the threshold (0.7 here, with at most N=3 attempts).
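
The retry behaviour described above can be pictured roughly as follows. This is a plain-Python illustration, not DSPy's actual implementation: `run_once` and `reward` are hypothetical callables standing in for the paradigm module and the reward function, and the stub outputs are made up.

```python
def best_of_n(run_once, reward, n=3, threshold=0.7):
    """Run `run_once(attempt)` up to n times, stopping once reward >= threshold.

    Returns (best_output, best_reward, attempts_used).
    """
    best_output, best_reward = None, float("-inf")
    for attempt in range(n):
        output = run_once(attempt)
        score = reward(output)
        if score > best_reward:
            best_output, best_reward = output, score
        if score >= threshold:
            return best_output, best_reward, attempt + 1
    return best_output, best_reward, n


# Stub: each attempt returns a whole list of labels, so a single weak pair
# cannot be retried without redoing the entire group.
gold = [0, 1, 2, 0]
attempts = [[1, 1, 1, 1], [0, 1, 1, 1], [0, 1, 2, 1]]

output, score, used = best_of_n(
    run_once=lambda i: attempts[i],
    reward=lambda out: sum(p == g for p, g in zip(out, gold)) / len(gold),
)
print(output, score, used)
```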

I had trouble finding a way to apply DSPy refinement to a single prompt when the score is computed over a group of 19 pairs.
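
A possible alternative, sketched here under assumptions rather than taken from DSPy, is to work per pair instead of per group: sample each pair several times and keep the majority label. Unlike the reward function used with `dspy.Refine` in this notebook, which reads `gold_labels`, this needs no gold labels at prediction time. `classify` is a hypothetical callable wrapping a single-pair predictor, and the stub below only simulates noisy outputs.

```python
from collections import Counter


def majority_label(classify, premise, hypothesis, n_samples=3):
    """Call `classify` n_samples times on one pair and return the most common label.

    On a tie, the label that was seen first wins.
    """
    votes = Counter(classify(premise, hypothesis, i) for i in range(n_samples))
    return votes.most_common(1)[0][0]


def predict_paradigm(classify, premises, hypotheses, n_samples=3):
    """Majority-vote prediction for every (premise, hypothesis) pair."""
    return [
        majority_label(classify, p, h, n_samples)
        for p, h in zip(premises, hypotheses)
    ]


def noisy_stub(premise, hypothesis, sample_idx):
    # Deterministic stand-in for an LLM call: the second sample disagrees
    return "neutral" if sample_idx == 1 else "entailment"


labels = predict_paradigm(noisy_stub, ["p1", "p2"], ["h1", "h2"])
print(labels)
```

The trade-off is cost: every pair is queried `n_samples` times, whereas the group-level retry only repeats a paradigm when its score is low.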