# ANLI Baseline with LLM

You have to implement in this notebook a baseline for ANLI classification using an LLM.
This baseline must be implemented using DSPy.



In [1]:
# Configure the DSPy environment with the language model - for grok the parameters must be:
# env variable should be in os.environ['XAI_API_KEY']
# "xai/grok-3-mini"
from dotenv import load_dotenv
import os
import dspy
load_dotenv("grok_key.ini") 
lm = dspy.LM('xai/grok-3-mini', api_key=os.environ['XAI_API_KEY'])

# for ollama 
# lm = dspy.LM('ollama_chat/devstral', api_base='http://localhost:11434', api_key='')
dspy.configure(lm=lm)

In [2]:
## Implement the DSPy classifier program.

from typing import Literal
from tqdm import tqdm
import dspy

# Signature for the NLI task
class NLISignature(dspy.Signature):
    """
    Classify the relationship between the premise and hypothesis 
    to a label: entailment, neutral or contradiction.
    """
    premise: str = dspy.InputField()
    hypothesis: str = dspy.InputField()
    label: Literal['entailment', 'neutral', 'contradiction'] = dspy.OutputField()

# A class for Parallel processing with progress display
class NLIClassifier(dspy.Module):
    def __init__(self, predictor_module: dspy.Module, batch_size: int = 20, num_threads: int = 8):
        super().__init__()
        self.predictor = predictor_module  # Predict, ChainOfThought, etc.
        self.batch_size = batch_size
        self.num_threads = num_threads

    def forward(self, examples: dspy.Example) -> list[dspy.Prediction]:
        # Display progress with tqdm while processing
        results = []
        for i in tqdm(range(0, len(examples), self.batch_size), desc="Processing"):
            sub_batch = examples[i:i + self.batch_size]
            processed = self.predictor.batch( # perform batch processing
                sub_batch,
                num_threads=self.num_threads
            )
            results.extend(processed)

        return results

## Load ANLI dataset

In [3]:
from datasets import load_dataset

dataset = load_dataset("facebook/anli")
dataset = dataset.filter(lambda x: x['reason'] != None and x['reason'] != "")

In [4]:
dataset

DatasetDict({
    train_r1: Dataset({
        features: ['uid', 'premise', 'hypothesis', 'label', 'reason'],
        num_rows: 2923
    })
    dev_r1: Dataset({
        features: ['uid', 'premise', 'hypothesis', 'label', 'reason'],
        num_rows: 1000
    })
    test_r1: Dataset({
        features: ['uid', 'premise', 'hypothesis', 'label', 'reason'],
        num_rows: 1000
    })
    train_r2: Dataset({
        features: ['uid', 'premise', 'hypothesis', 'label', 'reason'],
        num_rows: 4861
    })
    dev_r2: Dataset({
        features: ['uid', 'premise', 'hypothesis', 'label', 'reason'],
        num_rows: 1000
    })
    test_r2: Dataset({
        features: ['uid', 'premise', 'hypothesis', 'label', 'reason'],
        num_rows: 1000
    })
    train_r3: Dataset({
        features: ['uid', 'premise', 'hypothesis', 'label', 'reason'],
        num_rows: 13375
    })
    dev_r3: Dataset({
        features: ['uid', 'premise', 'hypothesis', 'label', 'reason'],
        num_rows: 1200


## Evaluate Metrics

Let's use the huggingface `evaluate` package to compute the performance of the baseline.


In [8]:
from evaluate import load

accuracy = load("accuracy")
precision = load("precision")
recall = load("recall")
f1 = load("f1")


In [9]:
import evaluate
clf_metrics = evaluate.combine(["accuracy", "f1", "precision", "recall"])

In [10]:
clf_metrics.compute(predictions=[0, 1, 0], references=[0, 1, 1])

{'accuracy': 0.6666666666666666,
 'f1': 0.6666666666666666,
 'precision': 1.0,
 'recall': 0.5}

## Your Turn

Compute the classification metrics on the baseline LLM model on each test section of the ANLI dataset for samples that have a non-empty 'reason' field.

You also must show a comparison between the DeBERTa baseline model and this LLM baseline model. The comparison metric should compute the agreement between the two models:
* On how many samples they are both correct [Correct]
* On how many samples Model1 is correct and Model2 is incorrect [Correct1]
* On how many samples Model1 is incorrect and Model2 is correct [Correct2]
* On how many samples both are incorrect [Incorrect]

first we will optimize the model on a train set from "dev_r3":

In [None]:
# prepare the training set
import random 

preprocessed_examples = [
    dspy.Example(
        premise=row["premise"],
        hypothesis=row["hypothesis"],
        label=row["label"]
    ).with_inputs("premise", "hypothesis")
    for row in dataset['dev_r3']  # Use the 'dev_r3' split for training
]

train_set_size = 40 # tradeoff between quality and speed after testing, permitted range is 20-100
trainset = random.sample(preprocessed_examples, train_set_size)  # pick examples randomly for training to avoid bias
print(f"Total examples: {len(trainset)}")

Total examples: 40


In [12]:
# Do the optimization using few-shot learning - only in task 1.4 we will use CoT

from dspy.teleprompt import BootstrapFewShot

label2id = {"entailment": 0, "neutral": 1, "contradiction": 2}

def exact_match(example, pred, trace=None):
    # Ensure both labels are strings and lowercase
    ex_label = str(example.label).strip().lower()
    pred_label = str(pred.label).strip().lower()

    # In case example.label is already an int, use reverse mapping
    if ex_label.isdigit():
        id2label = {v: k for k, v in label2id.items()}
        ex_label = id2label[int(ex_label)]

    return label2id.get(pred_label) == label2id.get(ex_label)

def compute_metrics(preds, golds):
    return {
        "accuracy": accuracy.compute(predictions=preds, references=golds)["accuracy"],
        "precision": precision.compute(predictions=preds, references=golds, average="macro")["precision"],
        "recall": recall.compute(predictions=preds, references=golds, average="macro")["recall"],
        "f1": f1.compute(predictions=preds, references=golds, average="macro")["f1"],
    }

model_simple = dspy.Predict(NLISignature)
bootstrap = BootstrapFewShot(metric=exact_match)
optimized_bootstrap = bootstrap.compile(student=model_simple, trainset=trainset)

 25%|██▌       | 10/40 [00:52<02:38,  5.28s/it]

Bootstrapped 4 full traces after 10 examples for up to 1 rounds, amounting to 10 attempts.





In [21]:
# run the optimized model on 'test_r3' split
testset_with_labels = [
    dspy.Example(
        premise=row["premise"],
        hypothesis=row["hypothesis"],
        label=row["label"]
    ).with_inputs("premise", "hypothesis")
    for row in dataset['test_r3']  # Use the 'test_r3' split for evaluation
]

testset_no_labels = [
    dspy.Example(
        premise=row["premise"],
        hypothesis=row["hypothesis"]
    ).with_inputs("premise", "hypothesis")
    for row in dataset['test_r3']  # Use the 'test_r3' split for evaluation
]


program = NLIClassifier(optimized_bootstrap)
predictions = program(testset_no_labels)
pred_labels = [label2id[pred.label] for pred in predictions]

Processing:   0%|          | 0/60 [00:00<?, ?it/s]

Processed 20 / 20 examples: 100%|██████████| 20/20 [00:00<00:00, 1608.68it/s]
Processed 20 / 20 examples: 100%|██████████| 20/20 [00:00<00:00, 2188.29it/s]

Processing:   3%|▎         | 2/60 [00:00<00:03, 16.60it/s]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:00<00:00, 2097.89it/s]
Processed 20 / 20 examples: 100%|██████████| 20/20 [00:00<00:00, 1484.68it/s]

Processing:   7%|▋         | 4/60 [00:00<00:04, 13.04it/s]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:00<00:00, 1718.24it/s]
Processed 20 / 20 examples: 100%|██████████| 20/20 [00:13<00:00,  1.52it/s]

Processing:  10%|█         | 6/60 [00:13<02:45,  3.07s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:11<00:00,  1.72it/s]

Processing:  12%|█▏        | 7/60 [00:25<04:34,  5.18s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:43<00:00,  2.17s/it]

Processing:  13%|█▎        | 8/60 [01:08<13:06, 15.12s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:15<00:00,  1.28it/s]

Processing:  15%|█▌        | 9/60 [01:24<12:58, 15.27s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:27<00:00,  1.39s/it]

Processing:  17%|█▋        | 10/60 [01:52<15:37, 18.76s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:13<00:00,  1.48it/s]

Processing:  18%|█▊        | 11/60 [02:05<14:06, 17.27s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:10<00:00,  1.91it/s]

Processing:  20%|██        | 12/60 [02:16<12:14, 15.31s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:11<00:00,  1.77it/s]

Processing:  22%|██▏       | 13/60 [02:27<11:05, 14.15s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:10<00:00,  1.82it/s]

Processing:  23%|██▎       | 14/60 [02:38<10:08, 13.23s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:10<00:00,  1.95it/s]

Processing:  25%|██▌       | 15/60 [02:48<09:16, 12.36s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:11<00:00,  1.76it/s]

Processing:  27%|██▋       | 16/60 [03:00<08:51, 12.08s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:13<00:00,  1.45it/s]

Processing:  28%|██▊       | 17/60 [03:14<09:01, 12.60s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:11<00:00,  1.80it/s]

Processing:  30%|███       | 18/60 [03:25<08:31, 12.18s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:10<00:00,  1.85it/s]

Processing:  32%|███▏      | 19/60 [03:36<08:02, 11.77s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:20<00:00,  1.03s/it]

Processing:  33%|███▎      | 20/60 [03:56<09:37, 14.44s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:24<00:00,  1.22s/it]

Processing:  35%|███▌      | 21/60 [04:21<11:21, 17.46s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:22<00:00,  1.11s/it]

Processing:  37%|███▋      | 22/60 [04:43<11:58, 18.92s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:17<00:00,  1.15it/s]

Processing:  38%|███▊      | 23/60 [05:00<11:23, 18.46s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:16<00:00,  1.22it/s]

Processing:  40%|████      | 24/60 [05:17<10:42, 17.85s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:13<00:00,  1.46it/s]

Processing:  42%|████▏     | 25/60 [05:31<09:42, 16.63s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:17<00:00,  1.12it/s]

Processing:  43%|████▎     | 26/60 [05:49<09:37, 17.00s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:20<00:00,  1.03s/it]

Processing:  45%|████▌     | 27/60 [06:09<09:57, 18.11s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:20<00:00,  1.02s/it]

Processing:  47%|████▋     | 28/60 [06:30<10:01, 18.80s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:11<00:00,  1.74it/s]

Processing:  48%|████▊     | 29/60 [06:41<08:35, 16.62s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:12<00:00,  1.64it/s]

Processing:  50%|█████     | 30/60 [06:53<07:39, 15.31s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:12<00:00,  1.65it/s]

Processing:  52%|█████▏    | 31/60 [07:06<06:56, 14.37s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:11<00:00,  1.71it/s]

Processing:  53%|█████▎    | 32/60 [07:17<06:20, 13.58s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:11<00:00,  1.70it/s]

Processing:  55%|█████▌    | 33/60 [07:29<05:52, 13.04s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:11<00:00,  1.72it/s]

Processing:  57%|█████▋    | 34/60 [07:41<05:28, 12.63s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:09<00:00,  2.00it/s]

Processing:  58%|█████▊    | 35/60 [07:51<04:56, 11.85s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:12<00:00,  1.57it/s]

Processing:  60%|██████    | 36/60 [08:04<04:51, 12.14s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:12<00:00,  1.62it/s]

Processing:  62%|██████▏   | 37/60 [08:16<04:40, 12.20s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:22<00:00,  1.14s/it]

Processing:  63%|██████▎   | 38/60 [08:39<05:38, 15.38s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:14<00:00,  1.41it/s]

Processing:  65%|██████▌   | 39/60 [08:53<05:15, 15.02s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:18<00:00,  1.10it/s]

Processing:  67%|██████▋   | 40/60 [09:11<05:19, 15.99s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:12<00:00,  1.54it/s]

Processing:  68%|██████▊   | 41/60 [09:24<04:46, 15.10s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:57<00:00,  2.85s/it]

Processing:  70%|███████   | 42/60 [10:21<08:18, 27.69s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:28<00:00,  1.45s/it]

Processing:  72%|███████▏  | 43/60 [10:50<07:57, 28.08s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:14<00:00,  1.36it/s]

Processing:  73%|███████▎  | 44/60 [11:05<06:25, 24.09s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:13<00:00,  1.47it/s]

Processing:  75%|███████▌  | 45/60 [11:19<05:14, 20.96s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:13<00:00,  1.51it/s]

Processing:  77%|███████▋  | 46/60 [11:32<04:21, 18.66s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:40<00:00,  2.01s/it]

Processing:  78%|███████▊  | 47/60 [12:12<05:26, 25.14s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:14<00:00,  1.37it/s]

Processing:  80%|████████  | 48/60 [12:27<04:23, 21.97s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:18<00:00,  1.07it/s]

Processing:  82%|████████▏ | 49/60 [12:46<03:50, 21.00s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:13<00:00,  1.53it/s]

Processing:  83%|████████▎ | 50/60 [12:59<03:06, 18.62s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:17<00:00,  1.14it/s]

Processing:  85%|████████▌ | 51/60 [13:16<02:44, 18.33s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:13<00:00,  1.49it/s]

Processing:  87%|████████▋ | 52/60 [13:30<02:15, 16.88s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:15<00:00,  1.32it/s]

Processing:  88%|████████▊ | 53/60 [13:45<01:54, 16.38s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:11<00:00,  1.71it/s]

Processing:  90%|█████████ | 54/60 [13:57<01:29, 15.00s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:28<00:00,  1.44s/it]

Processing:  92%|█████████▏| 55/60 [14:26<01:35, 19.14s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:12<00:00,  1.59it/s]

Processing:  93%|█████████▎| 56/60 [14:38<01:08, 17.19s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:12<00:00,  1.58it/s]

Processing:  95%|█████████▌| 57/60 [14:51<00:47, 15.83s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:17<00:00,  1.17it/s]

Processing:  97%|█████████▋| 58/60 [15:08<00:32, 16.21s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:14<00:00,  1.40it/s]

Processing:  98%|█████████▊| 59/60 [15:22<00:15, 15.63s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:17<00:00,  1.15it/s]

Processing: 100%|██████████| 60/60 [15:40<00:00, 15.67s/it]







# TASK 1.3 Answers

a) Compute the classification metrics on the baseline LLM model on each test section of the ANLI dataset for samples that have a non-empty 'reason' field.

In [22]:
# use compute_metrics to evaluate the model and print the results
gold_labels = [ex.label for ex in testset_with_labels]
metrics = compute_metrics(pred_labels, gold_labels)

# Print the metrics
print("@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@")
print("Model scores:")
print(f"F1 score: {metrics['f1']:.4f}")
print(f"Accuracy: {metrics['accuracy']:.4f}")
print(f"Precision: {metrics['precision']:.4f}")
print(f"Recall: {metrics['recall']:.4f}")
print("@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@")

@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
Model scores:
F1 score: 0.7204
Accuracy: 0.7167
Precision: 0.7373
Recall: 0.7168
@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@


Compare the results with the baseline and provide agreement metrics between the two models.

In [23]:
# how many samples they are both correct
%store -r pred_test_r3
optimized_llm_predictions = predictions
non_llm_predictions = pred_test_r3 # DeBERTa_v3_predictions
TEST_SIZE = len(optimized_llm_predictions)
# gold_labels = [label2id[row['gold_label']] for row in pred_test_r3]


# on how many samples both models are correct
correct_predictions = sum(
    1 for row, llm_pred in zip(non_llm_predictions,optimized_llm_predictions)
    if (llm_pred.label == row['gold_label']) and (row['gold_label'] == row['pred_label'])
)

print(f"Both models are correct on {correct_predictions} out of {TEST_SIZE} samples.")
print(f"Both models are correct on {correct_predictions / TEST_SIZE * 100:.2f}% of the samples.")
print("\n")

# On how many samples llm is correct and DeBERTa_v3_ is incorrect
llm_correct_deberta_incorrect = sum(
    1 for row, llm_pred in zip(non_llm_predictions,optimized_llm_predictions)
    if (llm_pred.label == row['gold_label']) and (row['gold_label'] != row['pred_label'])
)
print(f"LLM is correct and DeBERTa_v3 is incorrect on {llm_correct_deberta_incorrect} out of {TEST_SIZE} samples.")
print(f"LLM is correct and DeBERTa_v3 is incorrect on {llm_correct_deberta_incorrect / TEST_SIZE * 100:.2f}% of the samples.")
print("\n")

# On how many samples DeBERTa_v3 is correct and llm is incorrect
deberta_correct_llm_incorrect = sum(
    1 for row, llm_pred in zip(non_llm_predictions,optimized_llm_predictions)
    if (row['pred_label'] == row['gold_label']) and (llm_pred.label != row['gold_label'])
)
print(f"DeBERTa_v3 is correct and LLM is incorrect on {deberta_correct_llm_incorrect} out of {TEST_SIZE} samples.")
print(f"DeBERTa_v3 is correct and LLM is incorrect on {deberta_correct_llm_incorrect / TEST_SIZE * 100:.2f}% of the samples.")
print("\n")


# on how many samples both models are incorrect
both_incorrect = sum(
    1 for row, llm_pred in zip(non_llm_predictions,optimized_llm_predictions)
    if (llm_pred.label != row['gold_label']) and (row['pred_label'] != row['gold_label'])
)
print(f"Both models are incorrect on {both_incorrect} out of {TEST_SIZE} samples.")
print(f"Both models are incorrect on {both_incorrect / TEST_SIZE * 100:.2f}% of the samples.")
print("\n")

Both models are correct on 441 out of 1200 samples.
Both models are correct on 36.75% of the samples.


LLM is correct and DeBERTa_v3 is incorrect on 419 out of 1200 samples.
LLM is correct and DeBERTa_v3 is incorrect on 34.92% of the samples.


DeBERTa_v3 is correct and LLM is incorrect on 136 out of 1200 samples.
DeBERTa_v3 is correct and LLM is incorrect on 11.33% of the samples.


Both models are incorrect on 204 out of 1200 samples.
Both models are incorrect on 17.00% of the samples.


