# ImpPres LLM Baseline

You have to implement in this notebook a baseline for ImpPres classification using an LLM.
This baseline must be implemented using DSPy.



In [111]:
# Configure the DSPy environment with the language model - for grok the parameters must be:
# env variable should be in os.environ['XAI_API_KEY']
# "xai/grok-3-mini"
from dotenv import load_dotenv
import os
import dspy
load_dotenv("grok_key.ini") 
lm = dspy.LM('xai/grok-3-mini', api_key=os.environ['XAI_API_KEY'])



# for ollama 
# lm = dspy.LM('ollama_chat/devstral', api_base='http://localhost:11434', api_key='')
dspy.configure(lm=lm)

In [94]:
from typing import Literal
from tqdm import tqdm
import dspy

# Signature for the NLI task
class NLISignature(dspy.Signature):
    """
    Classify the relationship between the premise and hypothesis 
    to a label: entailment, neutral or contradiction.
    """
    premise: str = dspy.InputField()
    hypothesis: str = dspy.InputField()
    label: Literal['entailment', 'neutral', 'contradiction'] = dspy.OutputField()

# A class for Parallel processing with progress display
class NLIClassifier(dspy.Module):
    def __init__(self, predictor_module: dspy.Module, batch_size: int = 15, num_threads: int = 6):
        super().__init__()
        self.predictor = predictor_module  # Predict, ChainOfThought, etc.
        self.batch_size = batch_size
        self.num_threads = num_threads

    def forward(self, examples: dspy.Example) -> list[dspy.Prediction]:
        # Display progress with tqdm while processing
        results = []
        for i in tqdm(range(0, len(examples), self.batch_size), desc="Processing"):
            sub_batch = examples[i:i + self.batch_size]
            processed = self.predictor.batch( # perform batch processing
                sub_batch,
                num_threads=self.num_threads
            )
            results.extend(processed)

        return results


## Load ImpPres dataset

In [95]:
from datasets import load_dataset

sections = ['presupposition_all_n_presupposition', 
            'presupposition_both_presupposition', 
            'presupposition_change_of_state', 
            'presupposition_cleft_existence', 
            'presupposition_cleft_uniqueness', 
            'presupposition_only_presupposition', 
            'presupposition_possessed_definites_existence', 
            'presupposition_possessed_definites_uniqueness', 
            'presupposition_question_presupposition']

dataset = {}
for section in sections:
    print(f"Loading dataset for section: {section}")
    dataset[section] = load_dataset("facebook/imppres", section)

Loading dataset for section: presupposition_all_n_presupposition
Loading dataset for section: presupposition_both_presupposition
Loading dataset for section: presupposition_change_of_state
Loading dataset for section: presupposition_cleft_existence
Loading dataset for section: presupposition_cleft_uniqueness
Loading dataset for section: presupposition_only_presupposition
Loading dataset for section: presupposition_possessed_definites_existence
Loading dataset for section: presupposition_possessed_definites_uniqueness
Loading dataset for section: presupposition_question_presupposition


In [96]:
dataset

{'presupposition_all_n_presupposition': DatasetDict({
     all_n_presupposition: Dataset({
         features: ['premise', 'hypothesis', 'trigger', 'trigger1', 'trigger2', 'presupposition', 'gold_label', 'UID', 'pairID', 'paradigmID'],
         num_rows: 1900
     })
 }),
 'presupposition_both_presupposition': DatasetDict({
     both_presupposition: Dataset({
         features: ['premise', 'hypothesis', 'trigger', 'trigger1', 'trigger2', 'presupposition', 'gold_label', 'UID', 'pairID', 'paradigmID'],
         num_rows: 1900
     })
 }),
 'presupposition_change_of_state': DatasetDict({
     change_of_state: Dataset({
         features: ['premise', 'hypothesis', 'trigger', 'trigger1', 'trigger2', 'presupposition', 'gold_label', 'UID', 'pairID', 'paradigmID'],
         num_rows: 1900
     })
 }),
 'presupposition_cleft_existence': DatasetDict({
     cleft_existence: Dataset({
         features: ['premise', 'hypothesis', 'trigger', 'trigger1', 'trigger2', 'presupposition', 'gold_label', 'UI

## Evaluate Metrics

Let's use the huggingface `evaluate` package to compute the performance of the baseline.


In [97]:
import evaluate
clf_metrics = evaluate.combine(["accuracy", "f1", "precision", "recall"])

## Your Turn

Compute the classification metrics on the baseline LLM model on each test section of the ANLI dataset for samples that have a non-empty 'reason' field.

You also must show a comparison between the DeBERTa baseline model and this LLM baseline model. The comparison metric should compute the agreement between the two models:
* On how many samples they are both correct [Correct]
* On how many samples Model1 is correct and Model2 is incorrect [Correct1]
* On how many samples Model1 is incorrect and Model2 is correct [Correct2]
* On how many samples both are incorrect [Incorrect]

# IMPORTANT NOTE:

Following discussing with Michael - it was agreed that we use smaller sample than the whole impPres data set (17,000 samples) , as long as it is diverse.
So - we took 212 samples from each of the 9 data sets , summing up to 1908 examples.

In [98]:
# Organizing the dataset into examples

section_to_split = {
    'presupposition_all_n_presupposition': 'all_n_presupposition',
    'presupposition_both_presupposition': 'both_presupposition',
    'presupposition_change_of_state': 'change_of_state',
    'presupposition_cleft_existence': 'cleft_existence',
    'presupposition_cleft_uniqueness': 'cleft_uniqueness',
    'presupposition_only_presupposition': 'only_presupposition',
    'presupposition_possessed_definites_existence': 'possessed_definites_existence',
    'presupposition_possessed_definites_uniqueness': 'possessed_definites_uniqueness',
    'presupposition_question_presupposition': 'question_presupposition',
}

preprocessed_examples = [
    dspy.Example(
        premise=row["premise"],
        hypothesis=row["hypothesis"],
        label=row["gold_label"]
    ).with_inputs("premise", "hypothesis")
    for k, v in section_to_split.items()
    for row in dataset[k][v]
]

# use the first 212 examples for each section (in total ~1900 examples)
examples = []
for i in range(len(section_to_split)):
    start_index = i * 212
    end_index = start_index + 212
    examples.extend(preprocessed_examples[start_index:end_index])

print(f"Total examples: {len(examples)}")


# Orgnaizing the data sets set
# pick examples randomly for training
import random
train_set_size = 30 # tradeoff between quality and speed after testing
trainset = random.sample(examples[:len(examples)//2], train_set_size) # pick examples from the first half to avoid bias


# pick examples randomly for evaluation
eval_set_size = 60  # size of the evaluation set
evaluate_set = random.sample(examples[len(examples)//2:], eval_set_size)

# unrequired verbosity
import warnings
warnings.filterwarnings("ignore", category=FutureWarning, message=".*encoder_attention_mask.*")



Total examples: 1908


In [None]:
from dspy.teleprompt import BootstrapFewShot, COPRO, KNNFewShot
from sentence_transformers import SentenceTransformer
import evaluate

# Load metrics
accuracy = evaluate.load("accuracy")
f1 = evaluate.load("f1")
precision = evaluate.load("precision")
recall = evaluate.load("recall")

label2id = {"entailment": 0, "neutral": 1, "contradiction": 2}

def exact_match(example, pred, trace=None):
    # Ensure both labels are strings and lowercase
    ex_label = str(example.label).strip().lower()
    pred_label = str(pred.label).strip().lower()

    # In case example.label is already an int, use reverse mapping
    if ex_label.isdigit():
        id2label = {v: k for k, v in label2id.items()}
        ex_label = id2label[int(ex_label)]

    return label2id.get(pred_label) == label2id.get(ex_label)

def compute_metrics(preds, golds):
    return {
        "accuracy": accuracy.compute(predictions=preds, references=golds)["accuracy"],
        "precision": precision.compute(predictions=preds, references=golds, average="macro")["precision"],
        "recall": recall.compute(predictions=preds, references=golds, average="macro")["recall"],
        "f1": f1.compute(predictions=preds, references=golds, average="macro")["f1"],
    }

# Prepare models
model_simple = dspy.Predict(NLISignature)
model_cot = dspy.ChainOfThought(NLISignature)

# --- Optimization phase ---
# 1. BootstrapFewShot
bootstrap = BootstrapFewShot(metric=exact_match)
optimized_bootstrap = bootstrap.compile(student=model_simple, trainset=trainset)

# 2. COPRO on our cot model
copro = COPRO(metric=exact_match, max_trials=5, depth=2, breadth=3)
optimized_cot = copro.compile(student=model_cot, trainset=trainset,eval_kwargs={})

# # 3. KNNFewShot
# class EmbedderWrapper:
#     def __init__(self, model_name): self.model = SentenceTransformer(model_name)
#     def __call__(self, texts): return self.model.encode(texts, convert_to_numpy=True)

# vectorizer = EmbedderWrapper('all-MiniLM-L6-v2')
# knn = KNNFewShot(k=5, trainset=trainset, vectorizer=vectorizer)
# optimized_knn = knn.compile(student=model_simple)



 13%|█▎        | 4/30 [00:30<03:20,  7.70s/it]
2025/08/07 20:12:50 INFO dspy.teleprompt.copro_optimizer: Iteration Depth: 1/2.
2025/08/07 20:12:50 INFO dspy.teleprompt.copro_optimizer: At Depth 1/2, Evaluating Prompt Candidate #1/3 for Predictor 1 of 1.


Bootstrapped 4 full traces after 4 examples for up to 1 rounds, amounting to 4 attempts.


2025/08/07 20:13:13 INFO dspy.evaluate.evaluate: Average Metric: 27 / 30 (90.0%)
2025/08/07 20:13:13 INFO dspy.teleprompt.copro_optimizer: At Depth 1/2, Evaluating Prompt Candidate #2/3 for Predictor 1 of 1.
2025/08/07 20:13:39 INFO dspy.evaluate.evaluate: Average Metric: 29 / 30 (96.7%)
2025/08/07 20:13:39 INFO dspy.teleprompt.copro_optimizer: At Depth 1/2, Evaluating Prompt Candidate #3/3 for Predictor 1 of 1.
2025/08/07 20:14:05 INFO dspy.evaluate.evaluate: Average Metric: 28 / 30 (93.3%)
2025/08/07 20:14:12 INFO dspy.teleprompt.copro_optimizer: Iteration Depth: 2/2.
2025/08/07 20:14:12 INFO dspy.teleprompt.copro_optimizer: At Depth 2/2, Evaluating Prompt Candidate #1/3 for Predictor 1 of 1.
2025/08/07 20:14:39 INFO dspy.evaluate.evaluate: Average Metric: 27 / 30 (90.0%)
2025/08/07 20:14:39 INFO dspy.teleprompt.copro_optimizer: At Depth 2/2, Evaluating Prompt Candidate #2/3 for Predictor 1 of 1.
2025/08/07 20:15:06 INFO dspy.evaluate.evaluate: Average Metric: 28 / 30 (93.3%)
2025/08

In [104]:
import contextlib
import io

# --- Evaluation phase (Silent Mode) ---

print("📊 Evaluating Models... this might take a long time... adjust eval_set_size , train_set_size to speed up the process")

results = {}
silent_output = io.StringIO()
optimized_dict = {
    "BootstrapFewShot": optimized_bootstrap,
    "COPRO_CoT": optimized_cot,
    # "KNNFewShot": optimized_knn,
}

with contextlib.redirect_stdout(silent_output), contextlib.redirect_stderr(silent_output):
    for name, predictor in optimized_dict.items():
        program = NLIClassifier(predictor_module=predictor)
        predictions = program(evaluate_set) # "forward" method is called implicitly

        y_true = [ex.label for ex in evaluate_set[:len(predictions)]]
        y_pred = [label2id[pred.label.strip().lower()] for pred in predictions]

        results[name] = compute_metrics(y_pred, y_true)

# --- Display results ---
from pprint import pprint
print("📊 Model Evaluation Results:\n")
pprint(results)

📊 Evaluating Models... this might take a long time... adjust eval_set_size , train_set_size to speed up the process
📊 Model Evaluation Results:

{'BootstrapFewShot': {'accuracy': 0.9833333333333333,
                      'f1': 0.9808187134502924,
                      'precision': 0.9885057471264368,
                      'recall': 0.9743589743589745},
 'COPRO_CoT': {'accuracy': 0.9333333333333333,
               'f1': 0.9358730158730159,
               'precision': 0.9583333333333334,
               'recall': 0.921727395411606}}


In [105]:
# As you can see from the above output, BootstrapFewShot has the highest scores.
# reference output:

# {'BootstrapFewShot': {'accuracy': 0.9833333333333333,
#                       'f1': 0.9808187134502924,
#                       'precision': 0.9885057471264368,
#                       'recall': 0.9743589743589745},
#  'COPRO_CoT': {'accuracy': 0.9333333333333333,
#                'f1': 0.9358730158730159,
#                'precision': 0.9583333333333334,
#                'recall': 0.921727395411606}}


# The best performing model is optimized_bootstrap -> BootstrapFewShot ,  which automates the creation of effective few-shot demonstrations 

In [106]:
# run the best model on all of the examples , and report the results
best_model = optimized_bootstrap
program = NLIClassifier(predictor_module=best_model)
predictions = program.forward(examples)
y_true = [ex.label for ex in examples]
y_pred = [label2id[pred.label.strip().lower()] for pred in predictions]
# Compute final metrics
final_results = compute_metrics(y_pred, y_true)

Processing:   0%|          | 0/128 [00:00<?, ?it/s]

Processed 15 / 15 examples: 100%|██████████| 15/15 [00:19<00:00,  1.33s/it]

Processing:   1%|          | 1/128 [00:20<42:21, 20.01s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:23<00:00,  1.58s/it]

Processing:   2%|▏         | 2/128 [00:43<46:37, 22.20s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:25<00:00,  1.72s/it]

Processing:   2%|▏         | 3/128 [01:09<49:40, 23.85s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:15<00:00,  1.03s/it]

Processing:   3%|▎         | 4/128 [01:25<42:28, 20.55s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:22<00:00,  1.51s/it]

Processing:   4%|▍         | 5/128 [01:47<43:43, 21.33s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:17<00:00,  1.19s/it]

Processing:   5%|▍         | 6/128 [02:05<41:00, 20.17s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:21<00:00,  1.41s/it]

Processing:   5%|▌         | 7/128 [02:26<41:19, 20.50s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:15<00:00,  1.04s/it]

Processing:   6%|▋         | 8/128 [02:42<37:53, 18.95s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:14<00:00,  1.03it/s]

Processing:   7%|▋         | 9/128 [02:57<34:53, 17.59s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:14<00:00,  1.01it/s]

Processing:   8%|▊         | 10/128 [03:11<32:55, 16.74s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:16<00:00,  1.10s/it]

Processing:   9%|▊         | 11/128 [03:28<32:31, 16.68s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:13<00:00,  1.11it/s]

Processing:   9%|▉         | 12/128 [03:42<30:26, 15.74s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:12<00:00,  1.23it/s]

Processing:  10%|█         | 13/128 [03:54<28:07, 14.68s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:13<00:00,  1.10it/s]

Processing:  11%|█         | 14/128 [04:07<27:17, 14.37s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:13<00:00,  1.12it/s]

Processing:  12%|█▏        | 15/128 [04:21<26:29, 14.06s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:12<00:00,  1.22it/s]

Processing:  12%|█▎        | 16/128 [04:33<25:18, 13.55s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:24<00:00,  1.65s/it]

Processing:  13%|█▎        | 17/128 [04:58<31:18, 16.92s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:11<00:00,  1.27it/s]

Processing:  14%|█▍        | 18/128 [05:10<28:12, 15.39s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:19<00:00,  1.29s/it]

Processing:  15%|█▍        | 19/128 [05:29<30:08, 16.59s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:16<00:00,  1.11s/it]

Processing:  16%|█▌        | 20/128 [05:46<29:57, 16.64s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:18<00:00,  1.24s/it]

Processing:  16%|█▋        | 21/128 [06:05<30:45, 17.24s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:15<00:00,  1.01s/it]

Processing:  17%|█▋        | 22/128 [06:20<29:22, 16.62s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:12<00:00,  1.17it/s]

Processing:  18%|█▊        | 23/128 [06:33<27:05, 15.48s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:15<00:00,  1.03s/it]

Processing:  19%|█▉        | 24/128 [06:48<26:52, 15.51s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:17<00:00,  1.19s/it]

Processing:  20%|█▉        | 25/128 [07:06<27:51, 16.23s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:14<00:00,  1.06it/s]

Processing:  20%|██        | 26/128 [07:20<26:31, 15.61s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:14<00:00,  1.03it/s]

Processing:  21%|██        | 27/128 [07:35<25:47, 15.32s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:15<00:00,  1.06s/it]

Processing:  22%|██▏       | 28/128 [07:51<25:48, 15.49s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:16<00:00,  1.12s/it]

Processing:  23%|██▎       | 29/128 [08:08<26:14, 15.90s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:13<00:00,  1.10it/s]

Processing:  23%|██▎       | 30/128 [08:21<24:52, 15.23s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:14<00:00,  1.02it/s]

Processing:  24%|██▍       | 31/128 [08:36<24:21, 15.06s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:14<00:00,  1.00it/s]

Processing:  25%|██▌       | 32/128 [08:51<24:03, 15.04s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:16<00:00,  1.11s/it]

Processing:  26%|██▌       | 33/128 [09:07<24:33, 15.51s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:20<00:00,  1.34s/it]

Processing:  27%|██▋       | 34/128 [09:28<26:26, 16.88s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:12<00:00,  1.22it/s]

Processing:  27%|██▋       | 35/128 [09:40<24:01, 15.50s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:13<00:00,  1.12it/s]

Processing:  28%|██▊       | 36/128 [09:53<22:50, 14.89s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:24<00:00,  1.64s/it]

Processing:  29%|██▉       | 37/128 [10:18<27:00, 17.80s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:15<00:00,  1.06s/it]

Processing:  30%|██▉       | 38/128 [10:34<25:52, 17.25s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:18<00:00,  1.24s/it]

Processing:  30%|███       | 39/128 [10:52<26:10, 17.65s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:17<00:00,  1.17s/it]

Processing:  31%|███▏      | 40/128 [11:10<25:50, 17.62s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:15<00:00,  1.00s/it]

Processing:  32%|███▏      | 41/128 [11:25<24:26, 16.86s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:16<00:00,  1.08s/it]

Processing:  33%|███▎      | 42/128 [11:41<23:52, 16.66s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:14<00:00,  1.05it/s]

Processing:  34%|███▎      | 43/128 [11:56<22:35, 15.95s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:14<00:00,  1.05it/s]

Processing:  34%|███▍      | 44/128 [12:10<21:37, 15.45s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:16<00:00,  1.08s/it]

Processing:  35%|███▌      | 45/128 [12:26<21:41, 15.68s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:15<00:00,  1.01s/it]

Processing:  36%|███▌      | 46/128 [12:41<21:13, 15.53s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:14<00:00,  1.01it/s]

Processing:  37%|███▋      | 47/128 [12:56<20:40, 15.31s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:14<00:00,  1.02it/s]

Processing:  38%|███▊      | 48/128 [13:11<20:10, 15.13s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:14<00:00,  1.05it/s]

Processing:  38%|███▊      | 49/128 [13:25<19:37, 14.90s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:16<00:00,  1.13s/it]

Processing:  39%|███▉      | 50/128 [13:42<20:11, 15.54s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:13<00:00,  1.14it/s]

Processing:  40%|███▉      | 51/128 [13:55<19:03, 14.85s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:14<00:00,  1.05it/s]

Processing:  41%|████      | 52/128 [14:10<18:37, 14.71s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:13<00:00,  1.15it/s]

Processing:  41%|████▏     | 53/128 [14:23<17:46, 14.23s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:13<00:00,  1.13it/s]

Processing:  42%|████▏     | 54/128 [14:36<17:13, 13.97s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:31<00:00,  2.10s/it]

Processing:  43%|████▎     | 55/128 [15:08<23:26, 19.26s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:17<00:00,  1.19s/it]

Processing:  44%|████▍     | 56/128 [15:26<22:37, 18.86s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:23<00:00,  1.60s/it]

Processing:  45%|████▍     | 57/128 [15:50<24:08, 20.40s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:17<00:00,  1.20s/it]

Processing:  45%|████▌     | 58/128 [16:08<22:57, 19.67s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:12<00:00,  1.18it/s]

Processing:  46%|████▌     | 59/128 [16:20<20:13, 17.59s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:19<00:00,  1.32s/it]

Processing:  47%|████▋     | 60/128 [16:40<20:42, 18.27s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:15<00:00,  1.01s/it]

Processing:  48%|████▊     | 61/128 [16:56<19:22, 17.35s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:12<00:00,  1.24it/s]

Processing:  48%|████▊     | 62/128 [17:08<17:22, 15.80s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:14<00:00,  1.06it/s]

Processing:  49%|████▉     | 63/128 [17:22<16:35, 15.32s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:15<00:00,  1.04s/it]

Processing:  50%|█████     | 64/128 [17:38<16:26, 15.42s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:13<00:00,  1.14it/s]

Processing:  51%|█████     | 65/128 [17:51<15:28, 14.74s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:12<00:00,  1.18it/s]

Processing:  52%|█████▏    | 66/128 [18:03<14:37, 14.15s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:11<00:00,  1.30it/s]

Processing:  52%|█████▏    | 67/128 [18:15<13:36, 13.39s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:14<00:00,  1.03it/s]

Processing:  53%|█████▎    | 68/128 [18:30<13:46, 13.77s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:11<00:00,  1.28it/s]

Processing:  54%|█████▍    | 69/128 [18:42<12:56, 13.16s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:11<00:00,  1.25it/s]

Processing:  55%|█████▍    | 70/128 [18:54<12:23, 12.82s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:14<00:00,  1.03it/s]

Processing:  55%|█████▌    | 71/128 [19:08<12:41, 13.36s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:16<00:00,  1.07s/it]

Processing:  56%|█████▋    | 72/128 [19:24<13:13, 14.18s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:13<00:00,  1.12it/s]

Processing:  57%|█████▋    | 73/128 [19:38<12:47, 13.96s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:15<00:00,  1.02s/it]

Processing:  58%|█████▊    | 74/128 [19:53<12:56, 14.37s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:16<00:00,  1.10s/it]

Processing:  59%|█████▊    | 75/128 [20:10<13:15, 15.01s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:18<00:00,  1.24s/it]

Processing:  59%|█████▉    | 76/128 [20:28<13:56, 16.10s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:20<00:00,  1.33s/it]

Processing:  60%|██████    | 77/128 [20:48<14:41, 17.28s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:13<00:00,  1.14it/s]

Processing:  61%|██████    | 78/128 [21:01<13:22, 16.06s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:14<00:00,  1.01it/s]

Processing:  62%|██████▏   | 79/128 [21:16<12:48, 15.68s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:12<00:00,  1.24it/s]

Processing:  62%|██████▎   | 80/128 [21:28<11:41, 14.62s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:13<00:00,  1.13it/s]

Processing:  63%|██████▎   | 81/128 [21:42<11:08, 14.23s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:12<00:00,  1.16it/s]

Processing:  64%|██████▍   | 82/128 [21:55<10:36, 13.84s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:20<00:00,  1.36s/it]

Processing:  65%|██████▍   | 83/128 [22:15<11:51, 15.81s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:12<00:00,  1.22it/s]

Processing:  66%|██████▌   | 84/128 [22:27<10:49, 14.76s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:14<00:00,  1.06it/s]

Processing:  66%|██████▋   | 85/128 [22:42<10:27, 14.60s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:14<00:00,  1.02it/s]

Processing:  67%|██████▋   | 86/128 [22:56<10:15, 14.66s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:13<00:00,  1.09it/s]

Processing:  68%|██████▊   | 87/128 [23:10<09:49, 14.39s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:14<00:00,  1.04it/s]

Processing:  69%|██████▉   | 88/128 [23:25<09:36, 14.41s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:22<00:00,  1.51s/it]

Processing:  70%|██████▉   | 89/128 [23:47<10:58, 16.89s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:34<00:00,  2.27s/it]

Processing:  70%|███████   | 90/128 [24:21<13:57, 22.04s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:19<00:00,  1.27s/it]

Processing:  71%|███████   | 91/128 [24:40<13:03, 21.17s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:24<00:00,  1.63s/it]

Processing:  72%|███████▏  | 92/128 [25:05<13:18, 22.18s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:24<00:00,  1.62s/it]

Processing:  73%|███████▎  | 93/128 [25:29<13:18, 22.83s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:17<00:00,  1.19s/it]

Processing:  73%|███████▎  | 94/128 [25:47<12:06, 21.36s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:17<00:00,  1.20s/it]

Processing:  74%|███████▍  | 95/128 [26:05<11:11, 20.35s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:21<00:00,  1.44s/it]

Processing:  75%|███████▌  | 96/128 [26:27<11:04, 20.75s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:20<00:00,  1.35s/it]

Processing:  76%|███████▌  | 97/128 [26:47<10:38, 20.61s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:13<00:00,  1.09it/s]

Processing:  77%|███████▋  | 98/128 [27:01<09:17, 18.57s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:15<00:00,  1.06s/it]

Processing:  77%|███████▋  | 99/128 [27:17<08:35, 17.77s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:11<00:00,  1.26it/s]

Processing:  78%|███████▊  | 100/128 [27:29<07:28, 16.01s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:15<00:00,  1.04s/it]

Processing:  79%|███████▉  | 101/128 [27:44<07:08, 15.89s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:14<00:00,  1.02it/s]

Processing:  80%|███████▉  | 102/128 [27:59<06:44, 15.56s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:14<00:00,  1.02it/s]

Processing:  80%|████████  | 103/128 [28:14<06:22, 15.32s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:09<00:00,  1.62it/s]

Processing:  81%|████████▏ | 104/128 [28:23<05:24, 13.52s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:36<00:00,  2.42s/it]

Processing:  82%|████████▏ | 105/128 [29:00<07:48, 20.37s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:22<00:00,  1.47s/it]

Processing:  83%|████████▎ | 106/128 [29:22<07:39, 20.90s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:15<00:00,  1.04s/it]

Processing:  84%|████████▎ | 107/128 [29:37<06:46, 19.34s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:18<00:00,  1.23s/it]

Processing:  84%|████████▍ | 108/128 [29:56<06:21, 19.08s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:24<00:00,  1.61s/it]

Processing:  85%|████████▌ | 109/128 [30:20<06:31, 20.62s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:12<00:00,  1.18it/s]

Processing:  86%|████████▌ | 110/128 [30:33<05:28, 18.27s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:15<00:00,  1.03s/it]

Processing:  87%|████████▋ | 111/128 [30:48<04:56, 17.41s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:18<00:00,  1.26s/it]

Processing:  88%|████████▊ | 112/128 [31:07<04:46, 17.88s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:23<00:00,  1.59s/it]

Processing:  88%|████████▊ | 113/128 [31:31<04:55, 19.71s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:13<00:00,  1.12it/s]

Processing:  89%|████████▉ | 114/128 [31:45<04:09, 17.84s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:14<00:00,  1.01it/s]

Processing:  90%|████████▉ | 115/128 [32:00<03:40, 16.98s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:33<00:00,  2.25s/it]

Processing:  91%|█████████ | 116/128 [32:34<04:24, 22.02s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:17<00:00,  1.17s/it]

Processing:  91%|█████████▏| 117/128 [32:51<03:47, 20.69s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:13<00:00,  1.09it/s]

Processing:  92%|█████████▏| 118/128 [33:05<03:06, 18.63s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:14<00:00,  1.05it/s]

Processing:  93%|█████████▎| 119/128 [33:19<02:36, 17.34s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:15<00:00,  1.03s/it]

Processing:  94%|█████████▍| 120/128 [33:35<02:14, 16.80s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:16<00:00,  1.07s/it]

Processing:  95%|█████████▍| 121/128 [33:51<01:56, 16.59s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:13<00:00,  1.09it/s]

Processing:  95%|█████████▌| 122/128 [34:05<01:34, 15.76s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:21<00:00,  1.41s/it]

Processing:  96%|█████████▌| 123/128 [34:26<01:26, 17.38s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:25<00:00,  1.68s/it]

Processing:  97%|█████████▋| 124/128 [34:51<01:18, 19.72s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:18<00:00,  1.23s/it]

Processing:  98%|█████████▊| 125/128 [35:10<00:58, 19.36s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:12<00:00,  1.16it/s]

Processing:  98%|█████████▊| 126/128 [35:23<00:34, 17.45s/it]


Processed 15 / 15 examples: 100%|██████████| 15/15 [00:19<00:00,  1.29s/it]

Processing:  99%|█████████▉| 127/128 [35:42<00:18, 18.03s/it]


Processed 3 / 3 examples: 100%|██████████| 3/3 [00:05<00:00,  1.98s/it]

Processing: 100%|██████████| 128/128 [35:48<00:00, 16.78s/it]







# TASK 2.3

Compute the classification metrics on the baseline LLM model on each test section of the ANLI dataset for samples that have a non-empty 'reason' field.

In [109]:
print("@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@")
print("\n")
print(" -> Final Evaluation Results on All Examples: <- \n"
      f"Accuracy: {final_results['accuracy']}, \n"
      f"F1: {final_results['f1']}, \n"
      f"Precision: {final_results['precision']}, \n"
      f"Recall: {final_results['recall']} \n")
print("\n")
print("@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@")

@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@


 -> Final Evaluation Results on All Examples: <- 
Accuracy: 0.9858490566037735, 
F1: 0.985355433901831, 
Precision: 0.9889684783159963, 
Recall: 0.9823544187906642 



@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@


You also must show a comparison between the DeBERTa baseline model and this LLM baseline model. The comparison metric should compute the agreement between the two models:
* On how many samples they are both correct [Correct]
* On how many samples Model1 is correct and Model2 is incorrect [Correct1]
* On how many samples Model1 is incorrect and Model2 is correct [Correct2]
* On how many samples both are incorrect [Incorrect]

In [110]:
# how many samples they are both correct
%store -r DeBERTa_v3_predictions
optimized_llm_predictions = predictions
TEST_SIZE = len(optimized_llm_predictions)

non_llm_predictions = []
for i in range(len(section_to_split)):
    start_index = i * 212
    end_index = start_index + 212
    non_llm_predictions.extend(DeBERTa_v3_predictions[start_index:end_index])

# on how many samples both models are correct
correct_predictions = sum(
    1 for pred, llm_pred in zip(non_llm_predictions,optimized_llm_predictions)
    if (llm_pred.label == pred['gold_label']) and (pred['gold_label'] == pred['pred_label'])
)

print(f"Both models are correct on {correct_predictions} out of {TEST_SIZE} samples.")
print(f"Both models are correct on {correct_predictions / TEST_SIZE * 100:.2f}% of the samples.")
print("\n")

# On how many samples llm is correct and DeBERTa_v3_ is incorrect
llm_correct_deberta_incorrect = sum(
    1 for pred, llm_pred in zip(non_llm_predictions,optimized_llm_predictions)
    if (llm_pred.label == pred['gold_label']) and (pred['gold_label'] != pred['pred_label'])
)
print(f"LLM is correct and DeBERTa_v3 is incorrect on {llm_correct_deberta_incorrect} out of {TEST_SIZE} samples.")
print(f"LLM is correct and DeBERTa_v3 is incorrect on {llm_correct_deberta_incorrect / TEST_SIZE * 100:.2f}% of the samples.")
print("\n")

# On how many samples DeBERTa_v3 is correct and llm is incorrect
deberta_correct_llm_incorrect = sum(
    1 for pred, llm_pred in zip(non_llm_predictions,optimized_llm_predictions)
    if (pred['pred_label'] == pred['gold_label']) and (llm_pred.label != pred['gold_label'])
)
print(f"DeBERTa_v3 is correct and LLM is incorrect on {deberta_correct_llm_incorrect} out of {TEST_SIZE} samples.")
print(f"DeBERTa_v3 is correct and LLM is incorrect on {deberta_correct_llm_incorrect / TEST_SIZE * 100:.2f}% of the samples.")
print("\n")


# on how many samples both models are incorrect
both_incorrect = sum(
    1 for pred, llm_pred in zip(non_llm_predictions,optimized_llm_predictions)
    if (llm_pred.label != pred['gold_label']) and (pred['pred_label'] != pred['gold_label'])
)
print(f"Both models are incorrect on {both_incorrect} out of {TEST_SIZE} samples.")
print(f"Both models are incorrect on {both_incorrect / TEST_SIZE * 100:.2f}% of the samples.")
print("\n")


Both models are correct on 879 out of 1908 samples.
Both models are correct on 46.07% of the samples.


LLM is correct and DeBERTa_v3 is incorrect on 994 out of 1908 samples.
LLM is correct and DeBERTa_v3 is incorrect on 52.10% of the samples.


DeBERTa_v3 is correct and LLM is incorrect on 0 out of 1908 samples.
DeBERTa_v3 is correct and LLM is incorrect on 0.00% of the samples.


Both models are incorrect on 27 out of 1908 samples.
Both models are incorrect on 1.42% of the samples.


