# ImpPres with LLM

In this notebook, you have to implement a better ImpPres classifier using an LLM.
This classifier must be implemented using DSPy.


In [58]:
# Configure the DSPy environment with the language model - for grok the parameters must be:
# env variable should be in os.environ['XAI_API_KEY']
# "xai/grok-3-mini"
from dotenv import load_dotenv
import os
import dspy
load_dotenv("grok_key.ini")  # copy variables from grok_key.ini into os.environ
# Fail early with an actionable message instead of a bare KeyError (or an empty key
# that would only fail later, on the first LM call).
if not os.environ.get("XAI_API_KEY"):
    raise KeyError(
        "XAI_API_KEY is not set: add it to grok_key.ini or export it before starting the kernel."
    )
lm = dspy.LM('xai/grok-3-mini', api_key=os.environ['XAI_API_KEY'], max_tokens=8000, temperature=0.2)
# for ollama 
# lm = dspy.LM('ollama_chat/devstral', api_base='http://localhost:11434', api_key='')
dspy.configure(lm=lm)

In [59]:
from typing import Literal
import json
import random
from collections import defaultdict, Counter
from dataclasses import dataclass

# Paradigm-level Signature: JSON in/out to handle batches of ~19 pairs
class ParadigmNLISignature(dspy.Signature):
    """
    Given a list of NLI pairs (premise, hypothesis), predict a label for each pair
    and provide a brief explanation per pair. Use one of: entailment, neutral, contradiction.
    The inputs and outputs are JSON-encoded lists to ensure deterministic parsing.
    """
    # The docstring above is sent to the LM as the task instructions (COPRO rewrites it
    # during compile), so format requirements are spelled out in the field descriptions.

    # JSON list of {"premise": ..., "hypothesis": ...} objects built with safe_json_dumps.
    pairs_json: str = dspy.InputField(
        desc='JSON list of {"premise": str, "hypothesis": str} objects.'
    )
    # Must be a flat list aligned index-by-index with pairs_json: downstream code parses
    # it with parse_labels_json and treats any length mismatch as a failed prediction.
    labels_json: str = dspy.OutputField(
        desc='JSON list of label strings, exactly one per input pair and in the same order; '
             'each one of "entailment", "neutral", "contradiction".'
    )
    explanations_json: str = dspy.OutputField(
        desc="JSON list of short explanation strings, one per input pair, in the same order."
    )

# Helper: label mapping. Integer ids follow the position of each name in LABELS.
LABELS = ["entailment", "neutral", "contradiction"]
label2id = {name: idx for idx, name in enumerate(LABELS)}
id2label = {idx: name for name, idx in label2id.items()}

# Utility: compute per-paradigm accuracy and consistency

def compute_paradigm_accuracy(gold_labels: list[int], pred_labels: list[int]) -> float:
    """Fraction of gold positions whose prediction matches.

    Labels are aligned with ``zip`` (surplus predictions are ignored) and the
    denominator is ``len(gold_labels)``, so missing predictions count as wrong.
    An empty gold list scores 0.0.
    """
    if not gold_labels:
        return 0.0
    hits = sum(g == p for g, p in zip(gold_labels, pred_labels))
    return hits / len(gold_labels)


def compute_paradigm_consistency(gold_labels: list[int], pred_labels: list[int]) -> float:
    """
    Majority-coherence within each gold label group:
    - For each gold label value present in the paradigm, find the majority predicted label inside that group.
    - Score for the group = fraction of items in that group that match the group's majority predicted label.
    - Return the average across gold groups (unweighted by group size).
    Range: [0, 1]. Equals 1.0 when predictions are consistent inside each gold cluster.
    Returns 0.0 when there are no aligned (gold, pred) pairs.
    """
    preds_by_gold: defaultdict[int, list[int]] = defaultdict(list)
    for gold, pred in zip(gold_labels, pred_labels):
        preds_by_gold[gold].append(pred)
    if not preds_by_gold:
        return 0.0
    # Every group holds at least one prediction by construction, so no empty-group guard is needed.
    group_scores = [
        Counter(preds).most_common(1)[0][1] / len(preds)
        for preds in preds_by_gold.values()
    ]
    return sum(group_scores) / len(group_scores)


def combined_reward(gold_labels: list[int], pred_labels: list[int], alpha: float = 0.75) -> float:
    """Weighted blend of paradigm accuracy and consistency.

    Returns ``alpha * accuracy + (1 - alpha) * consistency``. The default alpha of
    0.75 favours accuracy over consistency.
    """
    accuracy_score = compute_paradigm_accuracy(gold_labels, pred_labels)
    consistency_score = compute_paradigm_consistency(gold_labels, pred_labels)
    consistency_weight = 1 - alpha
    return alpha * accuracy_score + consistency_weight * consistency_score

# Parser utilities

def parse_labels_json(labels_json: str, mapping: "dict[str, int] | None" = None) -> list[int]:
    """Parse an LLM-produced JSON list of labels into integer label ids.

    Accepted element forms (they may be mixed within one list):
      * label names such as "entailment" (case- and whitespace-insensitive),
      * integer ids that appear among ``mapping``'s values,
      * objects carrying a "label" key, e.g. {"label": "neutral", "explanation": "..."},
        which is the shape the COPRO-optimized instructions ask the model for.
    A surrounding Markdown code fence (```json ... ```) is tolerated.

    Args:
        labels_json: JSON text produced by the model (or by ``safe_json_dumps``).
        mapping: label-name -> id table; defaults to the module-level ``label2id``.

    Returns:
        One id per list element, with -1 for unrecognised elements (unknown names,
        out-of-range ints, booleans). Returns ``[]`` when the input is not text, is
        not valid JSON, or does not decode to a JSON list.
    """
    if mapping is None:
        mapping = label2id
    valid_ids = set(mapping.values())
    try:
        if isinstance(labels_json, (bytes, bytearray)):
            labels_json = labels_json.decode("utf-8")
        if not isinstance(labels_json, str):
            return []
        text = labels_json.strip()
        # Models often wrap JSON in a Markdown fence; strip it before decoding.
        if text.startswith("```"):
            text = text.strip("`").strip()
            if text.lower().startswith("json"):
                text = text[4:]
        data = json.loads(text)
    except ValueError:  # includes json.JSONDecodeError and UnicodeDecodeError
        return []
    # A bare JSON string would otherwise be iterated character by character.
    if not isinstance(data, list):
        return []
    parsed = []
    for item in data:
        if isinstance(item, dict):
            item = item.get("label")
        if isinstance(item, bool):  # bool is an int subclass; never a valid label
            parsed.append(-1)
        elif isinstance(item, int):
            parsed.append(item if item in valid_ids else -1)
        else:
            parsed.append(mapping.get(str(item).strip().lower(), -1))
    return parsed


def safe_json_dumps(obj) -> str:
    """Serialize ``obj`` to JSON text, keeping non-ASCII characters as-is (no \\uXXXX escapes)."""
    encoded = json.dumps(obj, ensure_ascii=False)
    return encoded


## Load ImpPres Dataset

In [60]:
from datasets import load_dataset

# ImpPres presupposition configurations to download from the Hugging Face Hub.
sections = [
    "presupposition_all_n_presupposition",
    "presupposition_both_presupposition",
    "presupposition_change_of_state",
    "presupposition_cleft_existence",
    "presupposition_cleft_uniqueness",
    "presupposition_only_presupposition",
    "presupposition_possessed_definites_existence",
    "presupposition_possessed_definites_uniqueness",
    "presupposition_question_presupposition",
]

# Section name -> object returned by load_dataset for that configuration.
dataset = {}
for section in sections:
    print(f"Loading dataset for section: {section}")
    dataset[section] = load_dataset("facebook/imppres", section)

Loading dataset for section: presupposition_all_n_presupposition
Loading dataset for section: presupposition_both_presupposition
Loading dataset for section: presupposition_change_of_state
Loading dataset for section: presupposition_cleft_existence
Loading dataset for section: presupposition_cleft_uniqueness
Loading dataset for section: presupposition_only_presupposition
Loading dataset for section: presupposition_possessed_definites_existence
Loading dataset for section: presupposition_possessed_definites_uniqueness
Loading dataset for section: presupposition_question_presupposition


## Evaluate Metrics

Let's use the Hugging Face `evaluate` package to compute the performance of the baseline.


In [61]:
from evaluate import load

# One metric object per name, loaded in this order and bound to the matching variable.
accuracy, precision, recall, f1 = (
    load(metric_name) for metric_name in ("accuracy", "precision", "recall", "f1")
)


In [62]:
import evaluate
# Combined metric object over the four classification metric names listed below.
clf_metrics = evaluate.combine(
    ["accuracy", "f1", "precision", "recall"]
)

## Your Turn

Compute the classification metrics on the baseline LLM model on each test section of the ANLI dataset for samples that have a non-empty 'reason' field.

You also must show a comparison between the DeBERTa baseline model and this LLM baseline model. The comparison metric should compute the agreement between the two models:
* On how many samples they are both correct [Correct]
* On how many samples Model1 is correct and Model2 is incorrect [Correct1]
* On how many samples Model1 is incorrect and Model2 is correct [Correct2]
* On how many samples both are incorrect [Incorrect]

In [63]:
# Per-item signature used for fallback predictions
class NLISignature(dspy.Signature):
    """
    Classify the relationship between the premise and hypothesis to one of:
    entailment, neutral, contradiction. Provide a short explanation.
    """
    # The docstring above is the instruction text sent to the LM, so it is left as-is.
    # There is no explanation output field: ParadigmPredictor wraps this signature in
    # dspy.ChainOfThought and reads the `reasoning` output that wrapper produces.
    premise: str = dspy.InputField()  # premise text of a single pair
    hypothesis: str = dspy.InputField()  # hypothesis text judged against the premise
    # The Literal restricts the answer to the three label names used as keys in label2id.
    label: Literal['entailment', 'neutral', 'contradiction'] = dspy.OutputField()

In [64]:
# Build paradigm-grouped DSPy examples with shuffling
from datasets import load_dataset, Dataset

# Seed for the within-paradigm shuffle. A dedicated Random instance keeps the pair
# order (and therefore prompts and results) reproducible across kernel restarts
# without disturbing the global `random` state used later for the train split.
PAIR_SHUFFLE_SEED = 0

section_to_split = {
    'presupposition_all_n_presupposition': 'all_n_presupposition',
    'presupposition_both_presupposition': 'both_presupposition',
    'presupposition_change_of_state': 'change_of_state',
    'presupposition_cleft_existence': 'cleft_existence',
    'presupposition_cleft_uniqueness': 'cleft_uniqueness',
    'presupposition_only_presupposition': 'only_presupposition',
    'presupposition_possessed_definites_existence': 'possessed_definites_existence',
    'presupposition_possessed_definites_uniqueness': 'possessed_definites_uniqueness',
    'presupposition_question_presupposition': 'question_presupposition',
}

# Load data: section name -> the split used for that section.
raw = {}
for section, split in section_to_split.items():
    raw[section] = load_dataset("facebook/imppres", section)[split]

# Group items by paradigmID per section, and create one example per paradigm.
paradigm_examples = []
section_index = defaultdict(list)  # section -> indices into paradigm_examples
pair_rng = random.Random(PAIR_SHUFFLE_SEED)

for section, ds in raw.items():
    # Build index for this section
    pid_to_rows = defaultdict(list)
    for row in ds:
        pid_to_rows[row["paradigmID"]].append(row)

    # Create a DSPy example per paradigmID
    for pid, rows in pid_to_rows.items():
        # Shuffle a copy to avoid positional leakage (gold order must not be guessable).
        rows_shuffled = rows[:]
        pair_rng.shuffle(rows_shuffled)

        pairs = [{"premise": r["premise"], "hypothesis": r["hypothesis"]} for r in rows_shuffled]
        gold = [int(r["gold_label"]) for r in rows_shuffled]

        ex = dspy.Example(
            pairs_json=safe_json_dumps(pairs),
            labels_json=safe_json_dumps(gold),
            section=section,
            paradigm_id=str(pid),
        ).with_inputs("pairs_json")

        paradigm_examples.append(ex)
        section_index[section].append(len(paradigm_examples) - 1)

assert sum(len(v) for v in section_index.values()) == len(paradigm_examples)
len(paradigm_examples)


900

In [65]:
# CoT Predictor over Paradigms

class ParadigmPredictor(dspy.Module):
    """Label every (premise, hypothesis) pair of one paradigm with a single batched LM call.

    The batched ChainOfThought over ParadigmNLISignature is tried first. If its
    labels_json cannot be parsed or has the wrong length, every pair is re-labelled
    with a per-pair ChainOfThought over NLISignature. If it parses but some labels
    are unrecognised, only those pairs are re-labelled.
    """

    def __init__(self, max_items: int = 25):
        super().__init__()
        # Use ChainOfThought to elicit explanations
        self.predict = dspy.ChainOfThought(ParadigmNLISignature)
        # Per-pair fallback, created once here instead of on every fallback call.
        self.single = dspy.ChainOfThought(NLISignature)
        self.max_items = max_items

    def _predict_single(self, pair: dict) -> tuple[int, str]:
        """Label one pair with the per-pair predictor; unrecognised labels default to neutral (1)."""
        out = self.single(premise=pair["premise"], hypothesis=pair["hypothesis"])
        label = label2id.get(str(out.label).strip().lower(), 1)
        return label, str(getattr(out, "reasoning", ""))

    @staticmethod
    def _parse_explanations(raw, n: int) -> list:
        """Decode explanations_json; return n empty strings if missing, invalid or misaligned."""
        try:
            explanations = json.loads(raw)
        except (TypeError, ValueError):
            return [""] * n
        if not isinstance(explanations, list) or len(explanations) != n:
            return [""] * n
        return explanations

    def forward(self, pairs: list[dict]) -> tuple[list[int], list[str]]:
        """Predict label ids and explanations for up to ``max_items`` pairs.

        Args:
            pairs: dicts with "premise" and "hypothesis" keys (missing keys become "").

        Returns:
            (labels, explanations), aligned with the first ``max_items`` pairs. Every
            label is one of the ids in ``id2label``.
        """
        # Safety cap: pairs beyond max_items are not labelled.
        pairs = pairs[: self.max_items]
        prompt_pairs = [
            {"premise": p.get("premise", ""), "hypothesis": p.get("hypothesis", "")}
            for p in pairs
        ]
        pred = self.predict(pairs_json=safe_json_dumps(prompt_pairs))
        labels = parse_labels_json(pred.labels_json)

        if not labels or len(labels) != len(prompt_pairs):
            # Batched output unusable: recover with one LM call per pair.
            results = [self._predict_single(p) for p in prompt_pairs]
            return [lab for lab, _ in results], [exp for _, exp in results]

        explanations = self._parse_explanations(
            getattr(pred, "explanations_json", "[]"), len(labels)
        )
        # Repair only the positions the batched call labelled with an unknown value (e.g. -1).
        for i, lab in enumerate(labels):
            if lab not in id2label:
                labels[i], explanations[i] = self._predict_single(prompt_pairs[i])
        return labels, explanations

# Instantiate a module wrapper for batch processing in evaluation
paradigm_predictor = ParadigmPredictor()


In [66]:
# Define optimization metric: per-paradigm accuracy + consistency

def metric_paradigm(example: dspy.Example, pred: dspy.Prediction, trace=None) -> float:
    """DSPy optimization metric: ``combined_reward`` (alpha=0.75) over one paradigm.

    Scores 0.0 when either label list is empty or their lengths differ. ``trace`` is
    accepted to match DSPy's metric call signature and is not used.
    """
    gold_ids = parse_labels_json(example.labels_json)
    pred_ids = parse_labels_json(getattr(pred, "labels_json", "[]"))
    aligned = bool(gold_ids) and bool(pred_ids) and len(gold_ids) == len(pred_ids)
    return combined_reward(gold_ids, pred_ids, alpha=0.75) if aligned else 0.0

# Student module for DSPy compile: produces labels_json from pairs_json
class StudentParadigmModule(dspy.Module):
    """DSPy-compilable module: pairs_json in, Prediction(labels_json, explanations_json) out.

    ``labels_json`` in the returned Prediction is always re-serialized from the parsed
    integer ids (``"[]"`` when nothing could be recovered), so the metric and the
    evaluation loop see one normalized format.
    """

    def __init__(self):
        super().__init__()
        self.inner = dspy.ChainOfThought(ParadigmNLISignature)

    @staticmethod
    def _labels_from_text(text) -> list[int]:
        """Parse the span from the first '[' to the last ']' as a JSON label list; [] if none."""
        if not isinstance(text, str):
            return []
        start, end = text.find("["), text.rfind("]")
        if start == -1 or end <= start:
            return []
        return parse_labels_json(text[start:end + 1])

    def forward(self, pairs_json: str) -> dspy.Prediction:
        pred = self.inner(pairs_json=pairs_json)
        raw_labels = getattr(pred, "labels_json", "[]")
        labels = parse_labels_json(raw_labels)
        if not labels:
            # Last resort: the model may have surrounded the JSON list with prose or a
            # code fence, so retry on just the bracketed span of labels_json.
            labels = self._labels_from_text(raw_labels)
        return dspy.Prediction(
            labels_json=safe_json_dumps(labels),
            explanations_json=getattr(pred, "explanations_json", "[]"),
        )



In [67]:
# Compile with DSPy using a small train/eval split of paradigms
from dspy.teleprompt import COPRO

# Keep budget small; sample paradigms across sections.
# Fixed seed so the train/eval paradigm selection is reproducible.
random.seed(123)
all_indices = list(range(len(paradigm_examples)))
random.shuffle(all_indices)
# min(60, 5% of paradigms): with 900 paradigms this is 45, matching "/ 45" in the COPRO log.
train_k = min(60, int(0.05 * len(all_indices)))  # ~5% of paradigms
trainset = [paradigm_examples[i] for i in all_indices[:train_k]]
# The next 2 * train_k shuffled paradigms.
# NOTE(review): evalset is not passed to COPRO below and is not read by the evaluation cell.
evalset = [paradigm_examples[i] for i in all_indices[train_k:train_k + 2 * train_k]]

student = StudentParadigmModule()
# The log shows `depth` rounds of `breadth` instruction candidates, each scored with
# metric_paradigm on the trainset paradigms. NOTE(review): confirm what max_trials
# controls in the installed COPRO version.
tele = COPRO(metric=metric_paradigm, max_trials=4, depth=2, breadth=3)
# Expensive: every candidate is evaluated with live LM calls (roughly an hour in the logged run).
optimized_student = tele.compile(student=student, trainset=trainset, eval_kwargs={})

optimized_student


2025/08/11 00:17:35 INFO dspy.teleprompt.copro_optimizer: Iteration Depth: 1/2.
2025/08/11 00:17:35 INFO dspy.teleprompt.copro_optimizer: At Depth 1/2, Evaluating Prompt Candidate #1/3 for Predictor 1 of 1.
2025/08/11 00:25:05 INFO dspy.evaluate.evaluate: Average Metric: 20.4281067251462 / 45 (45.4%)
2025/08/11 00:25:05 INFO dspy.teleprompt.copro_optimizer: At Depth 1/2, Evaluating Prompt Candidate #2/3 for Predictor 1 of 1.
2025/08/11 00:36:03 INFO dspy.evaluate.evaluate: Average Metric: 31.61483918128655 / 45 (70.3%)
2025/08/11 00:36:03 INFO dspy.teleprompt.copro_optimizer: At Depth 1/2, Evaluating Prompt Candidate #3/3 for Predictor 1 of 1.
2025/08/11 00:43:18 INFO dspy.evaluate.evaluate: Average Metric: 29.873062865497065 / 45 (66.4%)
2025/08/11 00:45:42 INFO dspy.teleprompt.copro_optimizer: Iteration Depth: 2/2.
2025/08/11 00:45:42 INFO dspy.teleprompt.copro_optimizer: At Depth 2/2, Evaluating Prompt Candidate #1/3 for Predictor 1 of 1.
2025/08/11 01:13:48 INFO dspy.evaluate.evalu

inner.predict = Predict(StringSignature(pairs_json -> reasoning, labels_json, explanations_json
    instructions='You are an expert in Natural Language Inference (NLI). Your task is to process a JSON-encoded list of pairs, where each pair contains a "premise" and a "hypothesis". For each pair, determine the relationship between the premise and the hypothesis by selecting one of the following labels: "entailment" (the premise logically implies the hypothesis), "neutral" (the premise is unrelated or neither supports nor contradicts the hypothesis), or "contradiction" (the premise directly conflicts with the hypothesis). Provide a concise, one-sentence explanation for each label to justify your decision. Ensure your output is structured as a JSON-encoded list of objects, where each object includes the fields "label" and "explanation", corresponding to the input pairs. Focus on accuracy, logical reasoning, and brevity to make your response clear and easy to parse.'
    pairs_json = Field(a

In [None]:
# Evaluation: per-section and per-transformation metrics
import numpy as np


def predict_paradigm_with_model(model: dspy.Module, pairs: list[dict]) -> list[int]:
    """Run ``model`` on one paradigm's pairs and return its parsed label ids ([] if unparseable)."""
    payload = safe_json_dumps(pairs)
    prediction = model(pairs_json=payload)
    raw_labels = getattr(prediction, "labels_json", "[]")
    return parse_labels_json(raw_labels)


# Build reverse index to the raw rows to recover transformation types.
# NOTE(review): assumes pairID is numeric and pairID % 19 identifies the transformation
# (19 items per paradigm) -- TODO confirm against the dataset. Rows whose pairID is not
# an int are left out of the per-transformation table.
N_TRANSFORMS = 19

# Paradigms COPRO used to optimize the prompt. Scoring them again would inflate the
# reported numbers, so they are excluded from evaluation by default.
EXCLUDE_TRAIN_PARADIGMS = True
train_keys = (
    {(ex.section, ex.paradigm_id) for ex in trainset} if EXCLUDE_TRAIN_PARADIGMS else set()
)

section_pid_to_rows = defaultdict(dict)
for section, ds in raw.items():
    for row in ds:
        section_pid_to_rows[section].setdefault(row["paradigmID"], []).append(row)

# Evaluate
results_by_section = {}
results_by_transform = {t: {"correct": 0, "total": 0} for t in range(N_TRANSFORMS)}

for section, indices in section_index.items():
    accs, conss = [], []
    totals = corrects = skipped = 0
    for idx in indices:
        ex = paradigm_examples[idx]
        if (ex.section, ex.paradigm_id) in train_keys:
            continue
        pairs = json.loads(ex.pairs_json)
        gold = parse_labels_json(ex.labels_json)
        preds = predict_paradigm_with_model(optimized_student, pairs)
        if not preds or len(preds) != len(gold):
            # Unparseable or misaligned output: counted as skipped, not scored.
            skipped += 1
            continue
        accs.append(compute_paradigm_accuracy(gold, preds))
        conss.append(compute_paradigm_consistency(gold, preds))
        totals += len(gold)
        corrects += sum(int(g == p) for g, p in zip(gold, preds))

        # Position of each (premise, hypothesis) in the shuffled pairs; first wins on duplicates.
        position = {}
        for i, pr in enumerate(pairs):
            position.setdefault((pr["premise"], pr["hypothesis"]), i)
        for row in section_pid_to_rows[section][int(ex.paradigm_id)]:
            try:
                t = int(row["pairID"]) % N_TRANSFORMS
            except (KeyError, TypeError, ValueError):
                continue
            i = position.get((row["premise"], row["hypothesis"]))
            if i is None:
                continue
            results_by_transform[t]["total"] += 1
            results_by_transform[t]["correct"] += int(preds[i] == int(row["gold_label"]))
    if accs:
        results_by_section[section] = {
            "accuracy": float(np.mean(accs)),  # mean of per-paradigm accuracies
            "consistency": float(np.mean(conss)),
            "micro_accuracy": corrects / totals,  # pair-level accuracy over scored pairs
            "n_paradigms": len(accs),
            "n_skipped": skipped,
        }

results_by_section, {t: (v["correct"] / v["total"] if v["total"] else None) for t, v in results_by_transform.items()}



## Why this design

- We group by `paradigmID` and shuffle items within each paradigm to prevent position leakage and to compute a meaningful reward over the whole paradigm.
- The optimization metric is a weighted combination of accuracy and intra-paradigm consistency, aligning with the assignment. Consistency is measured via majority-coherence within gold-label clusters.
- We use a JSON-IO `Signature` so the model returns aligned lists of labels and explanations; this reduces parsing errors for batched predictions.
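- We adopt a CoT predictor to elicit short explanations that often stabilize classification. `ParadigmPredictor` falls back to per-item CoT when the batched JSON output is malformed. The compiled `StudentParadigmModule` used in the evaluation has no such fallback, so paradigms with malformed output are skipped.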
- During evaluation, we report per-section averages of accuracy and consistency. We also approximate per-transformation accuracy by mapping `pairID % 19` to a transformation index. This assumes `pairID` is numeric; rows with non-numeric IDs are left out.
