# ANLI with LLM

You have to implement in this notebook a better ANLI classifier using an LLM.
This classifier must be implemented using DSPy.


In [24]:
# Imports and Setup
import os
import dspy
from dspy.teleprompt import BootstrapFewShotWithRandomSearch
from dspy.primitives import Example
from dspy.evaluate import Evaluate
from datasets import load_dataset
from sentence_transformers import SentenceTransformer, util
from tqdm import tqdm
import numpy as np
import pandas as pd
from sklearn.metrics import cohen_kappa_score
from typing import Literal
import pickle
from collections import Counter
import random
from evaluate import load
import json
import warnings
warnings.filterwarnings('ignore')

# Load Grok API key
with open('grok_key.ini', 'r') as f:
    line = f.read().strip()
    if line.startswith('export XAI_API_KEY='):
        api_key = line.split('=', 1)[1]
        os.environ['XAI_API_KEY'] = api_key
    else:
        raise ValueError("Could not parse API key from file")

lm = dspy.LM('xai/grok-3-mini', api_key=os.environ['XAI_API_KEY'])
dspy.configure(lm=lm)

# Set seed for reproducibility
random.seed(42)
np.random.seed(42)
print("Imports complete!")

Imports complete!


In [25]:
#  Configuration setting parameters
CONFIG = {
    'OPTIMIZATION_SAMPLES': 100,
    'EVALUATION_SAMPLES': 100,
    'REFINE_N': 5,
    'MAX_BOOTSTRAPPED_DEMOS': 20,
    'NUM_CANDIDATE_PROGRAMS': 20  
}

# Load ANLI dataset and filter for non-empty reasons
dataset = load_dataset("facebook/anli")
dataset = dataset.filter(lambda x: x['reason'] is not None and x['reason'] != "")

# Split dev_r3: shuffle for diversity, then split and convert to list of dicts
dev_r3 = dataset['dev_r3'].shuffle(seed=42)
midpoint = len(dev_r3) // 2

batched_opt = dev_r3.select(range(midpoint))[:CONFIG['OPTIMIZATION_SAMPLES']]
dev_optimization = [dict(zip(batched_opt, t)) for t in zip(*batched_opt.values())]

batched_eval = dev_r3.select(range(midpoint, len(dev_r3)))[:CONFIG['EVALUATION_SAMPLES']]
dev_evaluation = [dict(zip(batched_eval, t)) for t in zip(*batched_eval.values())]

print(f"Optimization split: {len(dev_optimization)} samples")
print(f"Evaluation split: {len(dev_evaluation)} samples")

# Load sentence-transformers
similarity_model = SentenceTransformer('all-MiniLM-L6-v2')
print("Setup complete!")

Optimization split: 100 samples
Evaluation split: 100 samples
Setup complete!


In [26]:
# DSPy Signatures 
class JointCoT(dspy.Signature):
    """Generate a Chain-of-Thought explanation and classify the relationship."""
    premise: str = dspy.InputField()
    hypothesis: str = dspy.InputField()
    joint_explanation: str = dspy.OutputField(  
        desc="Step-by-step reasoning that analyzes the premise and hypothesis to justify the classification label."
    )
    label: Literal["entailment", "contradiction", "neutral"] = dspy.OutputField()

class GenerateExplanation(dspy.Signature):
    """Generate a relevant CoT explanation for the premise-hypothesis relation."""
    premise: str = dspy.InputField()
    hypothesis: str = dspy.InputField()
    generated_explanation: str = dspy.OutputField(  
        desc="Provide step-by-step reasoning that analyzes how the hypothesis relates to the premise. Consider what information is given, what can be inferred, and what contradictions exist."
    )

class ClassifyWithExplanation(dspy.Signature):
    """Classify based on premise, hypothesis, and explanation."""
    premise: str = dspy.InputField()
    hypothesis: str = dspy.InputField()
    explanation_text: str = dspy.InputField(  
        desc="The explanation to use for classification"
    )
    label: Literal["entailment", "contradiction", "neutral"] = dspy.OutputField(
        desc="Based on the explanation, classify as: 'entailment' if hypothesis must be true, 'contradiction' if hypothesis must be false, 'neutral' if hypothesis could be true or false."
    )
print("Signatures defined!")

Signatures defined!


In [27]:
# DSPy Modules 
class JointModule(dspy.Module):
    def __init__(self):
        super().__init__()
        self.generate = dspy.ChainOfThought(JointCoT)
    
    def forward(self, premise, hypothesis):
        return self.generate(premise=premise, hypothesis=hypothesis)

class PipelineModule(dspy.Module):
    def __init__(self):
        super().__init__()
        self.explain = dspy.ChainOfThought(GenerateExplanation)
        self.classify = dspy.ChainOfThought(ClassifyWithExplanation)
    
    def forward(self, premise, hypothesis):
        exp = self.explain(premise=premise, hypothesis=hypothesis)
        result = self.classify(premise=premise, hypothesis=hypothesis, explanation_text=exp.generated_explanation)  
        # Add the explanation to the result for evaluation
        result.pipeline_explanation = exp.generated_explanation  
        return result

print("Modules defined!")

Modules defined!


In [28]:
# Similarity Functions 
def compute_similarity(text1, text2):
    emb1 = similarity_model.encode(text1)
    emb2 = similarity_model.encode(text2)
    return util.cos_sim(emb1, emb2).item()

def explanation_quality(premise, hypothesis, pred_exp, human_reason):
    ph = f"{premise} {hypothesis}"
    return {
        'pred_vs_human': compute_similarity(pred_exp, human_reason),
        'pred_vs_ph': compute_similarity(pred_exp, ph),
        'human_vs_ph': compute_similarity(human_reason, ph)
    }

# Learn threshold 
def learn_threshold(data):
    similarities = []
    for ex in tqdm(data, desc="Learning threshold"):
        ph = f"{ex['premise']} {ex['hypothesis']}"
        sim = compute_similarity(ex['reason'], ph)
        similarities.append(sim)
    threshold = np.mean(similarities) - np.std(similarities)
    print(f"Learned threshold: {threshold:.3f}")
    return threshold

threshold = learn_threshold(dev_optimization)

Learning threshold: 100%|██████████| 100/100 [00:01<00:00, 56.30it/s]

Learned threshold: 0.167





In [29]:
# Custom Metric 
label_map = {"entailment": 0, "neutral": 1, "contradiction": 2}

def custom_metric(example, pred, trace=None):
    pred_label = getattr(pred, 'label', '').lower()
    acc = 1.0 if label_map.get(pred_label, -1) == example['label'] else 0.0
    
    # Get explanation from either joint or pipeline field
    exp = getattr(pred, 'joint_explanation', '') or getattr(pred, 'pipeline_explanation', '')
    if not exp:
        return 0.0
    sim = compute_similarity(exp, example['reason'])
    quality = 1.0 if sim > threshold else 0.0
    
    return (acc + quality) / 2.0

print("Metric defined!")

Metric defined!


In [30]:
# Prepare Examples
trainset = [Example(premise=ex['premise'], hypothesis=ex['hypothesis'], label=ex['label'], reason=ex['reason']).with_inputs('premise', 'hypothesis') for ex in dev_optimization]


# Run optimization
compiler = BootstrapFewShotWithRandomSearch(metric=custom_metric, max_bootstrapped_demos=CONFIG['MAX_BOOTSTRAPPED_DEMOS'], num_candidate_programs=CONFIG['NUM_CANDIDATE_PROGRAMS'])
optimized_joint = compiler.compile(JointModule(), trainset=trainset)
optimized_pipeline = compiler.compile(PipelineModule(), trainset=trainset)

print("Optimization complete!")

Going to sample between 1 and 20 traces per predictor.
Will attempt to bootstrap 20 candidate sets.
Average Metric: 84.50 / 100 (84.5%): 100%|██████████| 100/100 [00:01<00:00, 56.32it/s]

2025/07/30 11:59:18 INFO dspy.evaluate.evaluate: Average Metric: 84.5 / 100 (84.5%)



New best score: 84.5 for seed -3
Scores so far: [84.5]
Best score so far: 84.5
Average Metric: 82.50 / 100 (82.5%): 100%|██████████| 100/100 [00:01<00:00, 60.82it/s]

2025/07/30 11:59:20 INFO dspy.evaluate.evaluate: Average Metric: 82.5 / 100 (82.5%)



Scores so far: [84.5, 82.5]
Best score so far: 84.5


 20%|██        | 20/100 [00:00<00:01, 49.65it/s]


Bootstrapped 20 full traces after 20 examples for up to 1 rounds, amounting to 20 attempts.
Average Metric: 84.00 / 100 (84.0%): 100%|██████████| 100/100 [00:01<00:00, 56.80it/s]

2025/07/30 11:59:22 INFO dspy.evaluate.evaluate: Average Metric: 84.0 / 100 (84.0%)



Scores so far: [84.5, 82.5, 84.0]
Best score so far: 84.5


 13%|█▎        | 13/100 [00:00<00:01, 46.89it/s]


Bootstrapped 13 full traces after 13 examples for up to 1 rounds, amounting to 13 attempts.
Average Metric: 84.50 / 100 (84.5%): 100%|██████████| 100/100 [00:01<00:00, 58.12it/s]

2025/07/30 11:59:24 INFO dspy.evaluate.evaluate: Average Metric: 84.5 / 100 (84.5%)



Scores so far: [84.5, 82.5, 84.0, 84.5]
Best score so far: 84.5


  5%|▌         | 5/100 [00:00<00:01, 58.16it/s]


Bootstrapped 5 full traces after 5 examples for up to 1 rounds, amounting to 5 attempts.
Average Metric: 82.00 / 100 (82.0%): 100%|██████████| 100/100 [00:01<00:00, 64.79it/s]

2025/07/30 11:59:26 INFO dspy.evaluate.evaluate: Average Metric: 82.0 / 100 (82.0%)



Scores so far: [84.5, 82.5, 84.0, 84.5, 82.0]
Best score so far: 84.5


  2%|▏         | 2/100 [00:00<00:02, 42.30it/s]


Bootstrapped 2 full traces after 2 examples for up to 1 rounds, amounting to 2 attempts.
Average Metric: 82.50 / 100 (82.5%): 100%|██████████| 100/100 [00:01<00:00, 63.41it/s]

2025/07/30 11:59:28 INFO dspy.evaluate.evaluate: Average Metric: 82.5 / 100 (82.5%)



Scores so far: [84.5, 82.5, 84.0, 84.5, 82.0, 82.5]
Best score so far: 84.5


  9%|▉         | 9/100 [00:00<00:02, 42.14it/s]


Bootstrapped 8 full traces after 9 examples for up to 1 rounds, amounting to 9 attempts.
Average Metric: 86.50 / 100 (86.5%): 100%|██████████| 100/100 [00:01<00:00, 67.11it/s]

2025/07/30 11:59:29 INFO dspy.evaluate.evaluate: Average Metric: 86.5 / 100 (86.5%)



New best score: 86.5 for seed 3
Scores so far: [84.5, 82.5, 84.0, 84.5, 82.0, 82.5, 86.5]
Best score so far: 86.5


  8%|▊         | 8/100 [00:00<00:01, 62.14it/s]

Bootstrapped 8 full traces after 8 examples for up to 1 rounds, amounting to 8 attempts.





Average Metric: 84.50 / 100 (84.5%): 100%|██████████| 100/100 [00:01<00:00, 64.68it/s]

2025/07/30 11:59:31 INFO dspy.evaluate.evaluate: Average Metric: 84.5 / 100 (84.5%)



Scores so far: [84.5, 82.5, 84.0, 84.5, 82.0, 82.5, 86.5, 84.5]
Best score so far: 86.5


 21%|██        | 21/100 [00:00<00:01, 52.34it/s]


Bootstrapped 20 full traces after 21 examples for up to 1 rounds, amounting to 21 attempts.
Average Metric: 84.50 / 100 (84.5%): 100%|██████████| 100/100 [00:01<00:00, 62.79it/s]

2025/07/30 11:59:33 INFO dspy.evaluate.evaluate: Average Metric: 84.5 / 100 (84.5%)



Scores so far: [84.5, 82.5, 84.0, 84.5, 82.0, 82.5, 86.5, 84.5, 84.5]
Best score so far: 86.5


 19%|█▉        | 19/100 [00:00<00:01, 51.42it/s]


Bootstrapped 19 full traces after 19 examples for up to 1 rounds, amounting to 19 attempts.
Average Metric: 83.00 / 100 (83.0%): 100%|██████████| 100/100 [00:01<00:00, 63.07it/s]

2025/07/30 11:59:35 INFO dspy.evaluate.evaluate: Average Metric: 83.0 / 100 (83.0%)



Scores so far: [84.5, 82.5, 84.0, 84.5, 82.0, 82.5, 86.5, 84.5, 84.5, 83.0]
Best score so far: 86.5


 11%|█         | 11/100 [00:00<00:01, 46.56it/s]


Bootstrapped 11 full traces after 11 examples for up to 1 rounds, amounting to 11 attempts.
Average Metric: 84.50 / 100 (84.5%): 100%|██████████| 100/100 [00:01<00:00, 59.65it/s]

2025/07/30 11:59:37 INFO dspy.evaluate.evaluate: Average Metric: 84.5 / 100 (84.5%)



Scores so far: [84.5, 82.5, 84.0, 84.5, 82.0, 82.5, 86.5, 84.5, 84.5, 83.0, 84.5]
Best score so far: 86.5


  8%|▊         | 8/100 [00:00<00:02, 39.51it/s]


Bootstrapped 8 full traces after 8 examples for up to 1 rounds, amounting to 8 attempts.
Average Metric: 83.50 / 100 (83.5%): 100%|██████████| 100/100 [00:01<00:00, 64.22it/s]

2025/07/30 11:59:39 INFO dspy.evaluate.evaluate: Average Metric: 83.5 / 100 (83.5%)



Scores so far: [84.5, 82.5, 84.0, 84.5, 82.0, 82.5, 86.5, 84.5, 84.5, 83.0, 84.5, 83.5]
Best score so far: 86.5


 15%|█▌        | 15/100 [00:00<00:01, 46.48it/s]


Bootstrapped 15 full traces after 15 examples for up to 1 rounds, amounting to 15 attempts.
Average Metric: 83.50 / 100 (83.5%): 100%|██████████| 100/100 [00:01<00:00, 62.59it/s]

2025/07/30 11:59:41 INFO dspy.evaluate.evaluate: Average Metric: 83.5 / 100 (83.5%)



Scores so far: [84.5, 82.5, 84.0, 84.5, 82.0, 82.5, 86.5, 84.5, 84.5, 83.0, 84.5, 83.5, 83.5]
Best score so far: 86.5


 19%|█▉        | 19/100 [00:00<00:01, 46.25it/s]


Bootstrapped 19 full traces after 19 examples for up to 1 rounds, amounting to 19 attempts.
Average Metric: 83.00 / 100 (83.0%): 100%|██████████| 100/100 [00:01<00:00, 65.33it/s]

2025/07/30 11:59:43 INFO dspy.evaluate.evaluate: Average Metric: 83.0 / 100 (83.0%)



Scores so far: [84.5, 82.5, 84.0, 84.5, 82.0, 82.5, 86.5, 84.5, 84.5, 83.0, 84.5, 83.5, 83.5, 83.0]
Best score so far: 86.5


 15%|█▌        | 15/100 [00:00<00:01, 50.75it/s]


Bootstrapped 15 full traces after 15 examples for up to 1 rounds, amounting to 15 attempts.
Average Metric: 86.50 / 100 (86.5%): 100%|██████████| 100/100 [00:01<00:00, 62.63it/s]

2025/07/30 11:59:45 INFO dspy.evaluate.evaluate: Average Metric: 86.5 / 100 (86.5%)



Scores so far: [84.5, 82.5, 84.0, 84.5, 82.0, 82.5, 86.5, 84.5, 84.5, 83.0, 84.5, 83.5, 83.5, 83.0, 86.5]
Best score so far: 86.5


 16%|█▌        | 16/100 [00:00<00:01, 50.85it/s]


Bootstrapped 16 full traces after 16 examples for up to 1 rounds, amounting to 16 attempts.
Average Metric: 86.00 / 100 (86.0%): 100%|██████████| 100/100 [00:01<00:00, 68.24it/s]

2025/07/30 11:59:46 INFO dspy.evaluate.evaluate: Average Metric: 86.0 / 100 (86.0%)



Scores so far: [84.5, 82.5, 84.0, 84.5, 82.0, 82.5, 86.5, 84.5, 84.5, 83.0, 84.5, 83.5, 83.5, 83.0, 86.5, 86.0]
Best score so far: 86.5


  9%|▉         | 9/100 [00:00<00:01, 47.38it/s]

Bootstrapped 9 full traces after 9 examples for up to 1 rounds, amounting to 9 attempts.





Average Metric: 81.50 / 100 (81.5%): 100%|██████████| 100/100 [00:01<00:00, 66.41it/s]

2025/07/30 11:59:48 INFO dspy.evaluate.evaluate: Average Metric: 81.5 / 100 (81.5%)



Scores so far: [84.5, 82.5, 84.0, 84.5, 82.0, 82.5, 86.5, 84.5, 84.5, 83.0, 84.5, 83.5, 83.5, 83.0, 86.5, 86.0, 81.5]
Best score so far: 86.5


  4%|▍         | 4/100 [00:00<00:01, 56.02it/s]


Bootstrapped 4 full traces after 4 examples for up to 1 rounds, amounting to 4 attempts.
Average Metric: 83.00 / 100 (83.0%): 100%|██████████| 100/100 [00:01<00:00, 63.94it/s]

2025/07/30 11:59:50 INFO dspy.evaluate.evaluate: Average Metric: 83.0 / 100 (83.0%)



Scores so far: [84.5, 82.5, 84.0, 84.5, 82.0, 82.5, 86.5, 84.5, 84.5, 83.0, 84.5, 83.5, 83.5, 83.0, 86.5, 86.0, 81.5, 83.0]
Best score so far: 86.5


  7%|▋         | 7/100 [00:00<00:01, 52.20it/s]

Bootstrapped 7 full traces after 7 examples for up to 1 rounds, amounting to 7 attempts.





Average Metric: 82.50 / 100 (82.5%): 100%|██████████| 100/100 [00:01<00:00, 59.81it/s]

2025/07/30 11:59:52 INFO dspy.evaluate.evaluate: Average Metric: 82.5 / 100 (82.5%)



Scores so far: [84.5, 82.5, 84.0, 84.5, 82.0, 82.5, 86.5, 84.5, 84.5, 83.0, 84.5, 83.5, 83.5, 83.0, 86.5, 86.0, 81.5, 83.0, 82.5]
Best score so far: 86.5


 12%|█▏        | 12/100 [00:00<00:01, 49.77it/s]


Bootstrapped 12 full traces after 12 examples for up to 1 rounds, amounting to 12 attempts.
Average Metric: 85.50 / 100 (85.5%): 100%|██████████| 100/100 [00:01<00:00, 62.96it/s]

2025/07/30 11:59:53 INFO dspy.evaluate.evaluate: Average Metric: 85.5 / 100 (85.5%)



Scores so far: [84.5, 82.5, 84.0, 84.5, 82.0, 82.5, 86.5, 84.5, 84.5, 83.0, 84.5, 83.5, 83.5, 83.0, 86.5, 86.0, 81.5, 83.0, 82.5, 85.5]
Best score so far: 86.5


 17%|█▋        | 17/100 [00:00<00:01, 49.51it/s]


Bootstrapped 17 full traces after 17 examples for up to 1 rounds, amounting to 17 attempts.
Average Metric: 86.00 / 100 (86.0%): 100%|██████████| 100/100 [00:01<00:00, 64.57it/s]

2025/07/30 11:59:55 INFO dspy.evaluate.evaluate: Average Metric: 86.0 / 100 (86.0%)



Scores so far: [84.5, 82.5, 84.0, 84.5, 82.0, 82.5, 86.5, 84.5, 84.5, 83.0, 84.5, 83.5, 83.5, 83.0, 86.5, 86.0, 81.5, 83.0, 82.5, 85.5, 86.0]
Best score so far: 86.5


  6%|▌         | 6/100 [00:00<00:01, 55.82it/s]

Bootstrapped 6 full traces after 6 examples for up to 1 rounds, amounting to 6 attempts.





Average Metric: 82.50 / 100 (82.5%): 100%|██████████| 100/100 [00:01<00:00, 63.92it/s]

2025/07/30 11:59:57 INFO dspy.evaluate.evaluate: Average Metric: 82.5 / 100 (82.5%)



Scores so far: [84.5, 82.5, 84.0, 84.5, 82.0, 82.5, 86.5, 84.5, 84.5, 83.0, 84.5, 83.5, 83.5, 83.0, 86.5, 86.0, 81.5, 83.0, 82.5, 85.5, 86.0, 82.5]
Best score so far: 86.5


  2%|▏         | 2/100 [00:00<00:01, 53.72it/s]

Bootstrapped 2 full traces after 2 examples for up to 1 rounds, amounting to 2 attempts.





Average Metric: 82.00 / 100 (82.0%): 100%|██████████| 100/100 [00:01<00:00, 66.86it/s]

2025/07/30 11:59:59 INFO dspy.evaluate.evaluate: Average Metric: 82.0 / 100 (82.0%)



Scores so far: [84.5, 82.5, 84.0, 84.5, 82.0, 82.5, 86.5, 84.5, 84.5, 83.0, 84.5, 83.5, 83.5, 83.0, 86.5, 86.0, 81.5, 83.0, 82.5, 85.5, 86.0, 82.5, 82.0]
Best score so far: 86.5
23 candidate programs found.
Average Metric: 83.00 / 100 (83.0%): 100%|██████████| 100/100 [00:01<00:00, 68.86it/s]

2025/07/30 12:00:00 INFO dspy.evaluate.evaluate: Average Metric: 83.0 / 100 (83.0%)



New best score: 83.0 for seed -3
Scores so far: [83.0]
Best score so far: 83.0
Average Metric: 85.50 / 100 (85.5%): 100%|██████████| 100/100 [00:01<00:00, 68.24it/s]

2025/07/30 12:00:02 INFO dspy.evaluate.evaluate: Average Metric: 85.5 / 100 (85.5%)



New best score: 85.5 for seed -2
Scores so far: [83.0, 85.5]
Best score so far: 85.5


 20%|██        | 20/100 [00:00<00:01, 56.88it/s]


Bootstrapped 20 full traces after 20 examples for up to 1 rounds, amounting to 20 attempts.
Average Metric: 84.00 / 100 (84.0%): 100%|██████████| 100/100 [00:01<00:00, 56.66it/s]

2025/07/30 12:00:04 INFO dspy.evaluate.evaluate: Average Metric: 84.0 / 100 (84.0%)



Scores so far: [83.0, 85.5, 84.0]
Best score so far: 85.5


 13%|█▎        | 13/100 [00:00<00:01, 55.67it/s]


Bootstrapped 13 full traces after 13 examples for up to 1 rounds, amounting to 13 attempts.
Average Metric: 84.50 / 100 (84.5%): 100%|██████████| 100/100 [00:01<00:00, 60.81it/s]

2025/07/30 12:00:06 INFO dspy.evaluate.evaluate: Average Metric: 84.5 / 100 (84.5%)



Scores so far: [83.0, 85.5, 84.0, 84.5]
Best score so far: 85.5


  5%|▌         | 5/100 [00:00<00:01, 57.87it/s]


Bootstrapped 5 full traces after 5 examples for up to 1 rounds, amounting to 5 attempts.
Average Metric: 83.50 / 100 (83.5%): 100%|██████████| 100/100 [00:01<00:00, 60.31it/s]

2025/07/30 12:00:07 INFO dspy.evaluate.evaluate: Average Metric: 83.5 / 100 (83.5%)



Scores so far: [83.0, 85.5, 84.0, 84.5, 83.5]
Best score so far: 85.5


  2%|▏         | 2/100 [00:00<00:01, 60.90it/s]

Bootstrapped 2 full traces after 2 examples for up to 1 rounds, amounting to 2 attempts.





Average Metric: 82.50 / 100 (82.5%): 100%|██████████| 100/100 [00:01<00:00, 63.49it/s]

2025/07/30 12:00:09 INFO dspy.evaluate.evaluate: Average Metric: 82.5 / 100 (82.5%)



Scores so far: [83.0, 85.5, 84.0, 84.5, 83.5, 82.5]
Best score so far: 85.5


  8%|▊         | 8/100 [00:00<00:02, 43.44it/s]

Bootstrapped 8 full traces after 8 examples for up to 1 rounds, amounting to 8 attempts.





Average Metric: 85.00 / 100 (85.0%): 100%|██████████| 100/100 [00:01<00:00, 61.91it/s]

2025/07/30 12:00:11 INFO dspy.evaluate.evaluate: Average Metric: 85.0 / 100 (85.0%)



Scores so far: [83.0, 85.5, 84.0, 84.5, 83.5, 82.5, 85.0]
Best score so far: 85.5


  8%|▊         | 8/100 [00:00<00:01, 58.45it/s]

Bootstrapped 8 full traces after 8 examples for up to 1 rounds, amounting to 8 attempts.





Average Metric: 84.50 / 100 (84.5%): 100%|██████████| 100/100 [00:01<00:00, 65.85it/s]

2025/07/30 12:00:13 INFO dspy.evaluate.evaluate: Average Metric: 84.5 / 100 (84.5%)



Scores so far: [83.0, 85.5, 84.0, 84.5, 83.5, 82.5, 85.0, 84.5]
Best score so far: 85.5


 21%|██        | 21/100 [00:00<00:01, 50.00it/s]


Bootstrapped 20 full traces after 21 examples for up to 1 rounds, amounting to 21 attempts.
Average Metric: 84.00 / 100 (84.0%): 100%|██████████| 100/100 [00:01<00:00, 58.51it/s]

2025/07/30 12:00:15 INFO dspy.evaluate.evaluate: Average Metric: 84.0 / 100 (84.0%)



Scores so far: [83.0, 85.5, 84.0, 84.5, 83.5, 82.5, 85.0, 84.5, 84.0]
Best score so far: 85.5


 19%|█▉        | 19/100 [00:00<00:01, 55.12it/s]


Bootstrapped 19 full traces after 19 examples for up to 1 rounds, amounting to 19 attempts.
Average Metric: 86.00 / 100 (86.0%): 100%|██████████| 100/100 [00:01<00:00, 59.70it/s]

2025/07/30 12:00:17 INFO dspy.evaluate.evaluate: Average Metric: 86.0 / 100 (86.0%)



New best score: 86.0 for seed 6
Scores so far: [83.0, 85.5, 84.0, 84.5, 83.5, 82.5, 85.0, 84.5, 84.0, 86.0]
Best score so far: 86.0


 11%|█         | 11/100 [00:00<00:01, 54.99it/s]


Bootstrapped 11 full traces after 11 examples for up to 1 rounds, amounting to 11 attempts.
Average Metric: 81.50 / 100 (81.5%): 100%|██████████| 100/100 [00:02<00:00, 49.67it/s]

2025/07/30 12:00:19 INFO dspy.evaluate.evaluate: Average Metric: 81.5 / 100 (81.5%)



Scores so far: [83.0, 85.5, 84.0, 84.5, 83.5, 82.5, 85.0, 84.5, 84.0, 86.0, 81.5]
Best score so far: 86.0


  8%|▊         | 8/100 [00:00<00:01, 56.40it/s]

Bootstrapped 8 full traces after 8 examples for up to 1 rounds, amounting to 8 attempts.





Average Metric: 84.00 / 100 (84.0%): 100%|██████████| 100/100 [00:01<00:00, 62.09it/s]

2025/07/30 12:00:21 INFO dspy.evaluate.evaluate: Average Metric: 84.0 / 100 (84.0%)



Scores so far: [83.0, 85.5, 84.0, 84.5, 83.5, 82.5, 85.0, 84.5, 84.0, 86.0, 81.5, 84.0]
Best score so far: 86.0


 15%|█▌        | 15/100 [00:00<00:01, 54.81it/s]


Bootstrapped 15 full traces after 15 examples for up to 1 rounds, amounting to 15 attempts.
Average Metric: 85.00 / 100 (85.0%): 100%|██████████| 100/100 [00:01<00:00, 60.41it/s]

2025/07/30 12:00:23 INFO dspy.evaluate.evaluate: Average Metric: 85.0 / 100 (85.0%)



Scores so far: [83.0, 85.5, 84.0, 84.5, 83.5, 82.5, 85.0, 84.5, 84.0, 86.0, 81.5, 84.0, 85.0]
Best score so far: 86.0


 20%|██        | 20/100 [00:00<00:01, 55.09it/s]


Bootstrapped 19 full traces after 20 examples for up to 1 rounds, amounting to 20 attempts.
Average Metric: 85.50 / 100 (85.5%): 100%|██████████| 100/100 [00:01<00:00, 57.40it/s]

2025/07/30 12:00:25 INFO dspy.evaluate.evaluate: Average Metric: 85.5 / 100 (85.5%)



Scores so far: [83.0, 85.5, 84.0, 84.5, 83.5, 82.5, 85.0, 84.5, 84.0, 86.0, 81.5, 84.0, 85.0, 85.5]
Best score so far: 86.0


 15%|█▌        | 15/100 [00:00<00:01, 53.56it/s]


Bootstrapped 15 full traces after 15 examples for up to 1 rounds, amounting to 15 attempts.
Average Metric: 85.00 / 100 (85.0%): 100%|██████████| 100/100 [00:01<00:00, 59.99it/s]

2025/07/30 12:00:27 INFO dspy.evaluate.evaluate: Average Metric: 85.0 / 100 (85.0%)



Scores so far: [83.0, 85.5, 84.0, 84.5, 83.5, 82.5, 85.0, 84.5, 84.0, 86.0, 81.5, 84.0, 85.0, 85.5, 85.0]
Best score so far: 86.0


 16%|█▌        | 16/100 [00:00<00:01, 55.11it/s]


Bootstrapped 16 full traces after 16 examples for up to 1 rounds, amounting to 16 attempts.
Average Metric: 82.50 / 100 (82.5%): 100%|██████████| 100/100 [00:01<00:00, 60.09it/s]

2025/07/30 12:00:29 INFO dspy.evaluate.evaluate: Average Metric: 82.5 / 100 (82.5%)



Scores so far: [83.0, 85.5, 84.0, 84.5, 83.5, 82.5, 85.0, 84.5, 84.0, 86.0, 81.5, 84.0, 85.0, 85.5, 85.0, 82.5]
Best score so far: 86.0


  9%|▉         | 9/100 [00:00<00:01, 58.46it/s]

Bootstrapped 9 full traces after 9 examples for up to 1 rounds, amounting to 9 attempts.





Average Metric: 87.00 / 100 (87.0%): 100%|██████████| 100/100 [00:01<00:00, 61.98it/s]

2025/07/30 12:00:31 INFO dspy.evaluate.evaluate: Average Metric: 87.0 / 100 (87.0%)



New best score: 87.0 for seed 13
Scores so far: [83.0, 85.5, 84.0, 84.5, 83.5, 82.5, 85.0, 84.5, 84.0, 86.0, 81.5, 84.0, 85.0, 85.5, 85.0, 82.5, 87.0]
Best score so far: 87.0


  4%|▍         | 4/100 [00:00<00:01, 63.55it/s]


Bootstrapped 4 full traces after 4 examples for up to 1 rounds, amounting to 4 attempts.
Average Metric: 86.00 / 100 (86.0%): 100%|██████████| 100/100 [00:01<00:00, 58.26it/s]

2025/07/30 12:00:33 INFO dspy.evaluate.evaluate: Average Metric: 86.0 / 100 (86.0%)



Scores so far: [83.0, 85.5, 84.0, 84.5, 83.5, 82.5, 85.0, 84.5, 84.0, 86.0, 81.5, 84.0, 85.0, 85.5, 85.0, 82.5, 87.0, 86.0]
Best score so far: 87.0


  7%|▋         | 7/100 [00:00<00:01, 52.01it/s]

Bootstrapped 7 full traces after 7 examples for up to 1 rounds, amounting to 7 attempts.





Average Metric: 87.00 / 100 (87.0%): 100%|██████████| 100/100 [00:01<00:00, 56.64it/s]

2025/07/30 12:00:35 INFO dspy.evaluate.evaluate: Average Metric: 87.0 / 100 (87.0%)



Scores so far: [83.0, 85.5, 84.0, 84.5, 83.5, 82.5, 85.0, 84.5, 84.0, 86.0, 81.5, 84.0, 85.0, 85.5, 85.0, 82.5, 87.0, 86.0, 87.0]
Best score so far: 87.0


 12%|█▏        | 12/100 [00:00<00:01, 54.32it/s]


Bootstrapped 12 full traces after 12 examples for up to 1 rounds, amounting to 12 attempts.
Average Metric: 83.50 / 100 (83.5%): 100%|██████████| 100/100 [00:01<00:00, 61.11it/s]

2025/07/30 12:00:36 INFO dspy.evaluate.evaluate: Average Metric: 83.5 / 100 (83.5%)



Scores so far: [83.0, 85.5, 84.0, 84.5, 83.5, 82.5, 85.0, 84.5, 84.0, 86.0, 81.5, 84.0, 85.0, 85.5, 85.0, 82.5, 87.0, 86.0, 87.0, 83.5]
Best score so far: 87.0


 17%|█▋        | 17/100 [00:00<00:01, 52.11it/s]


Bootstrapped 17 full traces after 17 examples for up to 1 rounds, amounting to 17 attempts.
Average Metric: 85.50 / 100 (85.5%): 100%|██████████| 100/100 [00:01<00:00, 56.20it/s]

2025/07/30 12:00:39 INFO dspy.evaluate.evaluate: Average Metric: 85.5 / 100 (85.5%)



Scores so far: [83.0, 85.5, 84.0, 84.5, 83.5, 82.5, 85.0, 84.5, 84.0, 86.0, 81.5, 84.0, 85.0, 85.5, 85.0, 82.5, 87.0, 86.0, 87.0, 83.5, 85.5]
Best score so far: 87.0


  6%|▌         | 6/100 [00:00<00:01, 55.00it/s]

Bootstrapped 6 full traces after 6 examples for up to 1 rounds, amounting to 6 attempts.





Average Metric: 84.50 / 100 (84.5%): 100%|██████████| 100/100 [00:01<00:00, 64.90it/s]

2025/07/30 12:00:40 INFO dspy.evaluate.evaluate: Average Metric: 84.5 / 100 (84.5%)



Scores so far: [83.0, 85.5, 84.0, 84.5, 83.5, 82.5, 85.0, 84.5, 84.0, 86.0, 81.5, 84.0, 85.0, 85.5, 85.0, 82.5, 87.0, 86.0, 87.0, 83.5, 85.5, 84.5]
Best score so far: 87.0


  2%|▏         | 2/100 [00:00<00:01, 58.84it/s]

Bootstrapped 2 full traces after 2 examples for up to 1 rounds, amounting to 2 attempts.





Average Metric: 86.00 / 100 (86.0%): 100%|██████████| 100/100 [00:01<00:00, 64.35it/s]

2025/07/30 12:00:42 INFO dspy.evaluate.evaluate: Average Metric: 86.0 / 100 (86.0%)



Scores so far: [83.0, 85.5, 84.0, 84.5, 83.5, 82.5, 85.0, 84.5, 84.0, 86.0, 81.5, 84.0, 85.0, 85.5, 85.0, 82.5, 87.0, 86.0, 87.0, 83.5, 85.5, 84.5, 86.0]
Best score so far: 87.0
23 candidate programs found.
Optimization complete!


In [31]:
# Custom Best-of-N Refinement (DSPy refine implementation)
def best_of_n(module, premise, hypothesis, n=CONFIG['REFINE_N'], metric=custom_metric):
    """Generate n predictions and return the best one based on the custom metric."""
    predictions = []
    
    for _ in range(n):
        try:
            pred = module(premise=premise, hypothesis=hypothesis)
            predictions.append(pred)
        except Exception as e:
            print(f"Prediction error: {e}")
            continue
    
    if not predictions:
        # Fallback: try one more time
        return module(premise=premise, hypothesis=hypothesis)
    
    # Score each prediction using the custom metric
    best_pred = predictions[0]
    best_score = -1
    
    for pred in predictions:
        try:
            # Create a dummy example for metric evaluation
            example = {'premise': premise, 'hypothesis': hypothesis, 'label': 0, 'reason': ''}
            score = metric(example, pred)
            if score > best_score:
                best_score = score
                best_pred = pred
        except:
            continue
    
    return best_pred

print("Best-of-N refinement ready!")


Best-of-N refinement ready!


In [32]:
# Prediction with Refinement
def predict_with_refine(module, premise, hypothesis):
    return best_of_n(module, premise, hypothesis)

print("Prediction ready!")

Prediction ready!


In [33]:
# Evaluation Function 
def evaluate_model(module, data, deberta_preds=None):
    preds, golds = [], []
    sims = {'pred_vs_human': [], 'pred_vs_ph': [], 'human_vs_ph': []}
    
    for ex in tqdm(data, desc='Evaluating'):
        try:
            pred = predict_with_refine(module, ex['premise'], ex['hypothesis'])
            pred_label = getattr(pred, 'label', '').lower()
            preds.append(label_map.get(pred_label, -1))
            golds.append(ex['label'])
            
            # Get explanation from either joint or pipeline field
            exp = getattr(pred, 'joint_explanation', '') or getattr(pred, 'pipeline_explanation', '')
            if exp:
                metrics = explanation_quality(ex['premise'], ex['hypothesis'], exp, ex['reason'])
                for k in sims:
                    sims[k].append(metrics[k])
            else:
                for k in sims:
                    sims[k].append(0.0)
        except Exception as e:
            print(f"Error: {e}")
            preds.append(-1)
            golds.append(ex['label'])
            for k in sims:
                sims[k].append(0.0)

    # Calculate classification metrics
    from sklearn.metrics import accuracy_score, precision_recall_fscore_support
    
    # Filter out invalid predictions
    valid_indices = [i for i, p in enumerate(preds) if p != -1]
    valid_preds = [preds[i] for i in valid_indices]
    valid_golds = [golds[i] for i in valid_indices]
    
    print(f"Total predictions: {len(preds)}, Valid predictions: {len(valid_preds)}")
    
    if len(valid_preds) == 0:
        print("ERROR: No valid predictions found!")
        return None
    
    accuracy = accuracy_score(valid_golds, valid_preds)
    precision, recall, f1, _ = precision_recall_fscore_support(valid_golds, valid_preds, average='weighted', zero_division=0)
    
    # Calculate average similarities
    avg_sims = {k: np.mean(v) if v else 0.0 for k, v in sims.items()}
    
    # Return results dictionary
    return {
        'predictions': preds,
        'golds': golds,
        'metrics': {
            'accuracy': accuracy,
            'precision': precision,
            'recall': recall,
            'f1': f1
        },
        'avg_sims': avg_sims
    }

print("Evaluation ready!")

Evaluation ready!


In [34]:
# Comparison Functions
def compare_results(joint_results, pipeline_results):
    comparison_df = pd.DataFrame({
        'Metric': ['Accuracy', 'F1', 'Precision', 'Recall', 'Pred-Human Sim', 'Pred-PH Sim', 'Human-PH Sim'],
        'Joint': [
            joint_results['metrics']['accuracy'], joint_results['metrics']['f1'],
            joint_results['metrics']['precision'], joint_results['metrics']['recall'],
            joint_results['avg_sims']['pred_vs_human'], joint_results['avg_sims']['pred_vs_ph'],
            joint_results['avg_sims']['human_vs_ph']
        ],
        'Pipeline': [
            pipeline_results['metrics']['accuracy'], pipeline_results['metrics']['f1'],
            pipeline_results['metrics']['precision'], pipeline_results['metrics']['recall'],
            pipeline_results['avg_sims']['pred_vs_human'], pipeline_results['avg_sims']['pred_vs_ph'],
            pipeline_results['avg_sims']['human_vs_ph']
        ]
    })
    return comparison_df

def compute_agreement_metrics(llm_preds, deberta_preds, golds):
    both_correct = llm_correct_deberta_wrong = deberta_correct_llm_wrong = both_wrong = 0
    
    deberta_mapped = [label_map.get(p.lower() if isinstance(p, str) else p, -1) for p in deberta_preds]
    
    for llm_pred, deberta_pred, gold in zip(llm_preds, deberta_mapped, golds):
        if llm_pred == -1 or deberta_pred == -1:
            continue
        llm_correct = (llm_pred == gold)
        deberta_correct = (deberta_pred == gold)
        if llm_correct and deberta_correct:
            both_correct += 1
        elif llm_correct:
            llm_correct_deberta_wrong += 1
        elif deberta_correct:
            deberta_correct_llm_wrong += 1
        else:
            both_wrong += 1
    
    total = both_correct + llm_correct_deberta_wrong + deberta_correct_llm_wrong + both_wrong
    if total == 0:
        return {'both_correct': 0, 'llm_correct_deberta_wrong': 0, 'deberta_correct_llm_wrong': 0, 'both_wrong': 0}
    pct = lambda x: (x / total * 100) if total > 0 else 0
    return {
        'both_correct': both_correct, 'both_correct_pct': pct(both_correct),
        'llm_correct_deberta_wrong': llm_correct_deberta_wrong, 'llm_correct_deberta_wrong_pct': pct(llm_correct_deberta_wrong),
        'deberta_correct_llm_wrong': deberta_correct_llm_wrong, 'deberta_correct_llm_wrong_pct': pct(deberta_correct_llm_wrong),
        'both_wrong': both_wrong, 'both_wrong_pct': pct(both_wrong)
    }

def print_agreement_analysis(joint_results, pipeline_results, deberta_preds):
    joint_ag = compute_agreement_metrics(joint_results['predictions'], deberta_preds[:len(joint_results['predictions'])], joint_results['golds'])
    pipeline_ag = compute_agreement_metrics(pipeline_results['predictions'], deberta_preds[:len(pipeline_results['predictions'])], pipeline_results['golds'])
    
    print("\nJOINT vs DeBERTa:")
    for k in ['both_correct', 'llm_correct_deberta_wrong', 'deberta_correct_llm_wrong', 'both_wrong']:
        print(f"{k}: {joint_ag[k]} ({joint_ag[f'{k}_pct']:.1f}%)")
    
    print("\nPIPELINE vs DeBERTa:")
    for k in ['both_correct', 'llm_correct_deberta_wrong', 'deberta_correct_llm_wrong', 'both_wrong']:
        print(f"{k}: {pipeline_ag[k]} ({pipeline_ag[f'{k}_pct']:.1f}%)")

print("Comparison ready!")

Comparison ready!


In [35]:
# Main Evaluation
import pickle
with open('deberta_preds.pkl', 'rb') as f:
    deberta_preds = pickle.load(f)[:CONFIG['EVALUATION_SAMPLES']]

print("Starting evaluation...")
joint_results = evaluate_model(optimized_joint, dev_evaluation, deberta_preds)
pipeline_results = evaluate_model(optimized_pipeline, dev_evaluation, deberta_preds)

comparison_df = compare_results(joint_results, pipeline_results)
print("\nJOINT vs PIPELINE COMPARISON")
print(comparison_df.round(4))

print("\nAGREEMENT ANALYSIS")
print_agreement_analysis(joint_results, pipeline_results, deberta_preds)

# Summary
print("\nTASK 1.4 SUMMARY")
print("✓ All requirements met!")
joint_acc = joint_results['metrics']['accuracy']
pipeline_acc = pipeline_results['metrics']['accuracy']
print(f"Joint Accuracy: {joint_acc:.3f}, Pipeline: {pipeline_acc:.3f}")

Starting evaluation...


Evaluating: 100%|██████████| 100/100 [00:15<00:00,  6.35it/s]


Total predictions: 100, Valid predictions: 100


Evaluating: 100%|██████████| 100/100 [00:13<00:00,  7.15it/s]

Total predictions: 100, Valid predictions: 100

JOINT vs PIPELINE COMPARISON
           Metric   Joint  Pipeline
0        Accuracy  0.7500    0.6900
1              F1  0.7529    0.6871
2       Precision  0.7867    0.6880
3          Recall  0.7500    0.6900
4  Pred-Human Sim  0.5105    0.5563
5     Pred-PH Sim  0.5869    0.6348
6    Human-PH Sim  0.4027    0.4027

AGREEMENT ANALYSIS

JOINT vs DeBERTa:
both_correct: 28 (28.0%)
llm_correct_deberta_wrong: 47 (47.0%)
deberta_correct_llm_wrong: 3 (3.0%)
both_wrong: 22 (22.0%)

PIPELINE vs DeBERTa:
both_correct: 25 (25.0%)
llm_correct_deberta_wrong: 44 (44.0%)
deberta_correct_llm_wrong: 6 (6.0%)
both_wrong: 25 (25.0%)

TASK 1.4 SUMMARY
✓ All requirements met!
Joint Accuracy: 0.750, Pipeline: 0.690





In [36]:
comparison_df.to_csv('task14_results.csv')

# END

## Evaluate Metrics

Let's use the huggingface `evaluate` package to compute the performance of the baseline.


In [37]:
# from evaluate import load

# accuracy = load("accuracy")
# precision = load("precision")
# recall = load("recall")
# f1 = load("f1")


In [38]:
# import evaluate
# clf_metrics = evaluate.combine(["accuracy", "f1", "precision", "recall"])

## Your Turn

Compute the classification metrics on the baseline LLM model on each test section of the ANLI dataset for samples that have a non-empty 'reason' field.

You also must show a comparison between the DeBERTa baseline model and this LLM baseline model. The comparison metric should compute the agreement between the two models:
* On how many samples they are both correct [Correct]
* On how many samples Model1 is correct and Model2 is incorrect [Correct1]
* On how many samples Model1 is incorrect and Model2 is correct [Correct2]
* On how many samples both are incorrect [Incorrect]