# ANLI Baseline with LLM

In this notebook, you will implement a baseline for ANLI classification using an LLM.
The baseline must be implemented with DSPy.



In [1]:
# Configure DSPy with the language model used for every prediction in this notebook.
# For Grok:
#   - the API key is read from the XAI_API_KEY environment variable
#     (a missing variable raises KeyError on the os.environ lookup below)
#   - the model identifier is "xai/grok-3-mini"
import os
import dspy

lm = dspy.LM('xai/grok-3-mini', api_key=os.environ['XAI_API_KEY'])
# Alternative: a local Ollama server (swap this line in for the Grok line above)
# lm = dspy.LM('ollama_chat/devstral', api_base='http://localhost:11434', api_key='')
# Register `lm` as the DSPy-wide default language model.
dspy.configure(lm=lm)

In [None]:
# from typing import Literal

# class NLIClassifier(dspy.Signature):
#     premise = dspy.InputField(desc="A factual statement")
#     hypothesis = dspy.InputField(desc="A statement to evaluate against the premise")
#     label = dspy.OutputField(
#         desc="The relationship between premise and hypothesis: entailment, neutral, or contradiction",
#         choices=["entailment", "neutral", "contradiction"]
#     )


# # Create a Predict module
# nli_predict = dspy.Predict(NLIClassifier)


In [36]:
import dspy

class BatchedNLIPredictor(dspy.Module):
    """Classify several premise/hypothesis pairs with a single LLM call.

    All pairs are packed into one prompt, the model is called once, and the
    reply is scanned line by line for one of the three NLI labels.
    """

    def __init__(self, model, batch_size=15):
        """Store the model callable and the intended batch size.

        Args:
            model: callable invoked with the prompt string; its result may be
                a string or a list of completions.
            batch_size: number of pairs a caller intends to send per call.
                It is only stored; ``forward`` does not chunk its input.
        """
        super().__init__()
        self.model = model
        self.batch_size = batch_size

    def forward(self, examples):
        """Return one label per example, in input order.

        Args:
            examples: list of objects exposing ``premise`` and ``hypothesis``
                attributes (e.g. ``dspy.Example(premise=..., hypothesis=...)``).

        Returns:
            A list with exactly ``len(examples)`` strings, each one of
            "entailment", "neutral", "contradiction", or "unknown" when the
            reply did not contain enough labels.
        """
        # Function-scope import: the cell defining this class only imports dspy.
        import re

        # Nothing to classify: skip the (billable) model call entirely.
        if not examples:
            return []

        # Build a single prompt with multiple pairs
        prompt = "Choose the correct relationship between the hypothesis and the premise: entailment/neutral/contradiction  .\n\n"
        for i, ex in enumerate(examples, start=1):
            prompt += f"Example {i}:\nPremise: {ex.premise}\nHypothesis: {ex.hypothesis}\nAnswer (entailment/neutral/contradiction):\n"

        # Single LLM call
        response = self.model(prompt)

        # The model may hand back a list of completions instead of a string;
        # calling .splitlines() on a list would raise AttributeError.
        if isinstance(response, list):
            response_text = "\n".join(str(part) for part in response)
        else:
            response_text = str(response)

        # Parse the response line by line. An echoed option list such as
        # "(entailment/neutral/contradiction)" is removed first so it is not
        # mistaken for an answer, then the bare label is extracted so that
        # "Entailment." or "Example 2: neutral" normalise to the label itself.
        option_list = re.compile(r"\(?\s*entailment\s*/\s*neutral\s*/\s*contradiction\s*\)?", re.IGNORECASE)
        label = re.compile(r"(entailment|neutral|contradiction)", re.IGNORECASE)
        predictions = []
        for line in response_text.splitlines():
            match = label.search(option_list.sub(" ", line))
            if match:
                predictions.append(match.group(1).lower())

        # Ensure same length: drop surplus labels, pad missing ones.
        predictions = predictions[:len(examples)]
        predictions.extend(["unknown"] * (len(examples) - len(predictions)))

        return predictions



In [44]:
from concurrent.futures import ThreadPoolExecutor, as_completed
import re

# Shared predictor instance wrapping the configured `lm`; process_chunk reads
# its `.model` attribute to issue one LLM call per chunk.
batched_predictor = BatchedNLIPredictor(model=lm, batch_size=15)

def process_chunk(chunk_examples, chunk_uids, batch_index):
    """Classify one chunk of premise/hypothesis pairs with a single LLM call.

    Uses the module-level ``batched_predictor.model`` callable to send the prompt.

    Args:
        chunk_examples: sequence of objects exposing ``premise`` and
            ``hypothesis`` attributes.
        chunk_uids: identifiers aligned position-by-position with
            ``chunk_examples``.
        batch_index: position of this chunk in the overall run; returned
            unchanged so the caller can restore the original order.

    Returns:
        ``(batch_index, pairs)`` where ``pairs`` is a list of ``(uid, label)``
        tuples. A label is "entailment", "neutral", "contradiction", or
        "unknown" when the reply held fewer labels than examples.
    """
    # Build single prompt for this batch
    prompt = "Classify the hypothesis and premise relationship: entailment / neutral / contradiction. **provide 1 word answer**.\n\n"
    for idx, ex in enumerate(chunk_examples, start=1):
        prompt += f"Example {idx}:\nPremise: {ex.premise}\nHypothesis: {ex.hypothesis}\nAnswer (entailment/neutral/contradiction):\n"

    # LLM call (one API call per chunk)
    print(f"Processing batch {batch_index} with {len(chunk_examples)} examples")
    response = batched_predictor.model(prompt)
    print(f"Done with batch {batch_index}")

    # Handle response text: the call may return a list of completions; stringify
    # each part so a non-string element cannot break the join.
    if isinstance(response, list):
        response_text = "\n".join(str(part) for part in response)
    else:
        response_text = str(response)

    # Extract predictions, one per line at most.
    # Models often echo the answer template, e.g.
    # "Answer (entailment/neutral/contradiction): neutral". Taking the first
    # label on such a line would always yield "entailment", so the echoed
    # option list is removed before searching for the actual answer.
    option_list = re.compile(r"\(?\s*entailment\s*/\s*neutral\s*/\s*contradiction\s*\)?", re.IGNORECASE)
    label = re.compile(r"(entailment|neutral|contradiction)", re.IGNORECASE)
    predictions = []
    for line in response_text.splitlines():
        m = label.search(option_list.sub(" ", line))
        if m:
            predictions.append(m.group(1).lower())

    # Force exactly one prediction per example: drop surplus, pad missing.
    predictions = predictions[:len(chunk_examples)]
    predictions.extend(["unknown"] * (len(chunk_examples) - len(predictions)))

    return batch_index, list(zip(chunk_uids, predictions))


def predict_batch_parallel(batch, batch_size=15, max_workers=6):
    """Label every row of ``batch`` by fanning fixed-size chunks out to a thread pool.

    Args:
        batch: object indexable by the column names "uid", "premise" and
            "hypothesis", each yielding an iterable of equal length.
        batch_size: number of rows sent to ``process_chunk`` per LLM call.
        max_workers: maximum number of chunks processed concurrently.

    Returns:
        A flat list of ``(uid, label)`` tuples in the original row order.
    """
    all_uids = list(batch["uid"])
    all_premises = list(batch["premise"])
    all_hypotheses = list(batch["hypothesis"])

    examples = [
        dspy.Example(premise=premise, hypothesis=hypothesis)
        for premise, hypothesis in zip(all_premises, all_hypotheses)
    ]

    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        # One task per chunk; the chunk number travels with the task so the
        # results can be put back in order afterwards.
        pending = [
            pool.submit(
                process_chunk,
                examples[start:start + batch_size],
                all_uids[start:start + batch_size],
                start // batch_size,
            )
            for start in range(0, len(examples), batch_size)
        ]

        # Chunks finish in arbitrary order; key each result by its chunk number.
        chunk_results = dict(done.result() for done in as_completed(pending))

    # Stitch the chunks back together in submission order.
    return [
        pair
        for chunk_number in sorted(chunk_results)
        for pair in chunk_results[chunk_number]
    ]


## Load ANLI dataset

In [5]:
from datasets import load_dataset

def _has_reason(row):
    """Return True when the row carries a non-empty 'reason' annotation."""
    reason = row['reason']
    return reason is not None and reason != ""


# Load every ANLI split, then keep only the rows annotated with a reason.
dataset = load_dataset("facebook/anli").filter(_has_reason)

In [6]:
# Display the filtered DatasetDict: split names, columns and remaining row counts.
dataset

DatasetDict({
    train_r1: Dataset({
        features: ['uid', 'premise', 'hypothesis', 'label', 'reason'],
        num_rows: 2923
    })
    dev_r1: Dataset({
        features: ['uid', 'premise', 'hypothesis', 'label', 'reason'],
        num_rows: 1000
    })
    test_r1: Dataset({
        features: ['uid', 'premise', 'hypothesis', 'label', 'reason'],
        num_rows: 1000
    })
    train_r2: Dataset({
        features: ['uid', 'premise', 'hypothesis', 'label', 'reason'],
        num_rows: 4861
    })
    dev_r2: Dataset({
        features: ['uid', 'premise', 'hypothesis', 'label', 'reason'],
        num_rows: 1000
    })
    test_r2: Dataset({
        features: ['uid', 'premise', 'hypothesis', 'label', 'reason'],
        num_rows: 1000
    })
    train_r3: Dataset({
        features: ['uid', 'premise', 'hypothesis', 'label', 'reason'],
        num_rows: 13375
    })
    dev_r3: Dataset({
        features: ['uid', 'premise', 'hypothesis', 'label', 'reason'],
        num_rows: 1200


In [None]:
# Smoke test: label the first 18 rows of the filtered `test_r3` split.
# With batch_size=15 this exercises one full chunk plus one partial chunk.
test_r3 = dataset['test_r3']
mini_test_r3 = test_r3.select(range(18))


import asyncio

# predict_batch_parallel is an ordinary synchronous call backed by a thread pool;
# no asyncio event loop is involved, and the import above is unused in this cell.
ordered_predictions = predict_batch_parallel(mini_test_r3, batch_size=15)

for uid, label in ordered_predictions:
    print(uid, label)

# Disabled sanity checks: predictions should come back in dataset order, one per row.
# original_uids = mini_test_r3["uid"]
# print(mini_test_r3["uid"] == [uid for uid, _ in ordered_predictions])
# print(len(mini_test_r3["uid"]) == len([uid for uid, _ in ordered_predictions]))


Processing batch 0 with 15 examplesProcessing batch 1 with 3 examples

Done with batch 0
Done with batch 1
b0e63408-53af-4b46-b33d-bf5ba302949f neutral
41ac8273-490a-4c14-adc9-28e7992b40e3 entailment
9b4b2be0-7f5e-456f-b7af-627309123ad0 neutral
db7fef31-4f2f-4b5a-855e-831209eab172 neutral
4f73b484-af35-4922-8f90-4881682041cd contradiction
769d15ea-f94c-4387-b6db-04f7121e420e entailment
6c59f001-b2cc-4a9a-a4a8-e04ccc73e4d3 entailment
9cc974da-688c-4fc5-9d4b-475ec410576e entailment
641310d4-2120-4fa9-98a2-7f750ae42c72 neutral
33fd6df2-0810-49c2-8fe2-662229badebd entailment
8fc6bda7-e103-4c8c-8802-b4dab96a1734 entailment
7d98f706-81fe-4160-ad9d-4011a6a1dad6 entailment
ec8e2a30-c20b-4b7b-af9d-fccbbb2d9906 entailment
433e53c4-c555-4e65-bc02-ff8fc7ced582 entailment
49d6eafb-5f08-4b5c-996d-106614ee1092 contradiction
f1ed4114-24bd-49ae-8026-24464026970d entailment
2b6616c8-d20e-49c1-8a44-767570d5baf1 contradiction
d38256b1-bd97-4eb0-ac1d-0777bea1fb58 entailment
True
True


In [48]:

# Size check on the evaluation subset: row count vs. number of uids.
# NOTE(review): the recorded output reports 65 rows, while the cell that defines
# mini_test_r3 selects range(18) - the output looks stale; re-run from a fresh
# kernel to confirm.
print("len of mini_test_r3:", len(mini_test_r3))
print("len of original_uids:", len(mini_test_r3["uid"]))

len of mini_test_r3: 65
len of original_uids: 65


## Evaluate Metrics

Let's use the Hugging Face `evaluate` package to compute the performance of the baseline.


In [26]:
from evaluate import load

# Load each classification metric individually by name from the `evaluate` package.
accuracy = load("accuracy")
precision = load("precision")
recall = load("recall")
f1 = load("f1")


In [27]:
import evaluate
# Bundle the four metrics so a single compute() call reports all of them.
# NOTE(review): ANLI has three labels (entailment / neutral / contradiction);
# f1, precision and recall presumably default to binary averaging - confirm how
# an `average` setting can be applied before scoring ANLI predictions with this.
clf_metrics = evaluate.combine(["accuracy", "f1", "precision", "recall"])

In [29]:
# Toy check on binary labels: 2 of the 3 predictions match the references,
# so the reported accuracy should be 2/3.
clf_metrics.compute(predictions=[0, 1, 0], references=[0, 1, 1])

{'accuracy': 0.6666666666666666,
 'f1': 0.6666666666666666,
 'precision': 1.0,
 'recall': 0.5}

## Your Turn

Compute the classification metrics of the LLM baseline on each test section of the ANLI dataset, restricted to samples that have a non-empty 'reason' field.

You must also show a comparison between the DeBERTa baseline model and this LLM baseline model. The comparison should measure the agreement between the two models:
* On how many samples they are both correct [Correct]
* On how many samples Model1 is correct and Model2 is incorrect [Correct1]
* On how many samples Model1 is incorrect and Model2 is correct [Correct2]
* On how many samples both are incorrect [Incorrect]