# ANLI Baseline with LLM

You have to implement in this notebook a baseline for ANLI classification using an LLM.
This baseline must be implemented using DSPy.



In [1]:
# Configure the DSPy environment with the language model - for grok the parameters must be:
# env variable should be in os.environ['XAI_API_KEY']
# "xai/grok-3-mini"
import os
import dspy

lm = dspy.LM('xai/grok-3-mini', api_key=os.environ['XAI_API_KEY'])
# for ollama 
# lm = dspy.LM('ollama_chat/devstral', api_base='http://localhost:11434', api_key='')
dspy.configure(lm=lm)

In [36]:
import dspy

class BatchedNLIPredictor(dspy.Module):
    def __init__(self, model, batch_size=15):
        super().__init__()
        self.model = model
        self.batch_size = batch_size

    def forward(self, examples):
        # examples: list of dspy.Example(premise=..., hypothesis=...)
        
        # Build a single prompt with multiple pairs
        prompt = "Choose the correct relationship between the hypothesis and the premise: entailment/neutral/contradiction  .\n\n"
        for i, ex in enumerate(examples, start=1):
            prompt += f"Example {i}:\nPremise: {ex.premise}\nHypothesis: {ex.hypothesis}\nAnswer (entailment/neutral/contradiction):\n"

        # Single LLM call
        response = self.model(prompt)

        # Parse the response line by line
        predictions = []
        for line in response.splitlines():
            if line.strip().lower().startswith(("entailment", "neutral", "contradiction")):
                predictions.append(line.strip().lower())

        # Ensure same length
        while len(predictions) < len(examples):
            predictions.append("unknown")

        return predictions



In [44]:
from concurrent.futures import ThreadPoolExecutor, as_completed
import re

batched_predictor = BatchedNLIPredictor(model=lm, batch_size=15)

def process_chunk(chunk_examples, chunk_uids, batch_index):
    # Build single prompt for this batch
    prompt = "Classify the hypothesis and premise relationship: entailment / neutral / contradiction. **provide 1 word answer**.\n\n"
    for idx, ex in enumerate(chunk_examples, start=1):
        prompt += f"Example {idx}:\nPremise: {ex.premise}\nHypothesis: {ex.hypothesis}\nAnswer (entailment/neutral/contradiction):\n"

    # LLM call (one API call per chunk)
    print(f"Processing batch {batch_index} with {len(chunk_examples)} examples")
    response = batched_predictor.model(prompt)
    print(f"Done with batch {batch_index}")
    

    # Handle response text
    response_text = "\n".join(response) if isinstance(response, list) else str(response)
    lines = response_text.splitlines()

    # Extract predictions
    predictions = []
    for line in lines:
        m = re.search(r"(entailment|neutral|contradiction)", line, re.IGNORECASE)
        if m:
            predictions.append(m.group(1).lower())

    # Pad missing
    while len(predictions) < len(chunk_examples):
        predictions.append("unknown")

    return batch_index, list(zip(chunk_uids, predictions))


def predict_batch_parallel(batch, batch_size=15, max_workers=6):
    uids = list(batch["uid"])
    premises = list(batch["premise"])
    hypotheses = list(batch["hypothesis"])

    examples = [
        dspy.Example(premise=p, hypothesis=h)
        for p, h in zip(premises, hypotheses)
    ]

    futures = []
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        for i in range(0, len(examples), batch_size):
            chunk = examples[i:i+batch_size]
            chunk_uids = uids[i:i+batch_size]
            batch_index = i // batch_size
            futures.append(executor.submit(process_chunk, chunk, chunk_uids, batch_index))

        # Collect results in order
        results_by_index = {}
        for future in as_completed(futures):
            batch_index, batch_result = future.result()
            results_by_index[batch_index] = batch_result

    # Flatten results in original order
    ordered_results = []
    for batch_index in sorted(results_by_index.keys()):
        ordered_results.extend(results_by_index[batch_index])

    return ordered_results


## Load ANLI dataset

In [5]:
from datasets import load_dataset

dataset = load_dataset("facebook/anli")
dataset = dataset.filter(lambda x: x['reason'] != None and x['reason'] != "")

In [59]:
dataset

DatasetDict({
    train_r1: Dataset({
        features: ['uid', 'premise', 'hypothesis', 'label', 'reason'],
        num_rows: 2923
    })
    dev_r1: Dataset({
        features: ['uid', 'premise', 'hypothesis', 'label', 'reason'],
        num_rows: 1000
    })
    test_r1: Dataset({
        features: ['uid', 'premise', 'hypothesis', 'label', 'reason'],
        num_rows: 1000
    })
    train_r2: Dataset({
        features: ['uid', 'premise', 'hypothesis', 'label', 'reason'],
        num_rows: 4861
    })
    dev_r2: Dataset({
        features: ['uid', 'premise', 'hypothesis', 'label', 'reason'],
        num_rows: 1000
    })
    test_r2: Dataset({
        features: ['uid', 'premise', 'hypothesis', 'label', 'reason'],
        num_rows: 1000
    })
    train_r3: Dataset({
        features: ['uid', 'premise', 'hypothesis', 'label', 'reason'],
        num_rows: 13375
    })
    dev_r3: Dataset({
        features: ['uid', 'premise', 'hypothesis', 'label', 'reason'],
        num_rows: 1200


In [79]:
test_r3 = dataset['test_r3']
ordered_llm_predictions = predict_batch_parallel(test_r3, batch_size=15)

Processing batch 0 with 15 examples
Done with batch 0
Processing batch 1 with 15 examples
Done with batch 1
Processing batch 2 with 15 examples
Done with batch 2
Processing batch 3 with 15 examples
Processing batch 4 with 15 examples
Processing batch 5 with 15 examples
Done with batch 3
Done with batch 4
Done with batch 5
Processing batch 8 with 15 examples
Processing batch 6 with 15 examples
Processing batch 7 with 15 examples
Processing batch 9 with 15 examples
Processing batch 10 with 15 examples
Processing batch 11 with 15 examples
Done with batch 8
Processing batch 12 with 15 examples
Done with batch 6
Processing batch 13 with 15 examples
Done with batch 11
Processing batch 14 with 15 examples
Done with batch 9
Processing batch 15 with 15 examples
Done with batch 10
Processing batch 16 with 15 examples
Done with batch 7
Processing batch 17 with 15 examples
Done with batch 14
Processing batch 18 with 15 examples
Done with batch 12
Processing batch 19 with 15 examples
Done with batc

## Evaluate Metrics

Let's use the huggingface `evaluate` package to compute the performance of the baseline.


In [80]:
from evaluate import load

# Load evaluation metrics from the `evaluate` library
accuracy = load("accuracy")
precision = load("precision")
recall = load("recall")
f1 = load("f1")

# Define mapping from string labels to integer IDs
# Defined in Cell 11 in anli_baseline.ipynb
label2id = {"entailment": 0, "neutral": 1, "contradiction": 2}

# Extract predicted labels from ordered_llm_predictions (second element in each tuple)
predicted_labels = [label2id[label.lower()] for uid, label in ordered_llm_predictions]

# Extract gold labels from test_r3, slice to match predictions length to be safe
gold_labels = test_r3['label'][:len(predicted_labels)]


# Compute all metrics using integer IDs for predicted and gold labels
acc_result = accuracy.compute(predictions=predicted_labels, references=gold_labels)
prec_result = precision.compute(predictions=predicted_labels, references=gold_labels, average="weighted")
rec_result = recall.compute(predictions=predicted_labels, references=gold_labels, average="weighted")
f1_result = f1.compute(predictions=predicted_labels, references=gold_labels, average="weighted")

# Print out the evaluation results
print("Accuracy:", acc_result["accuracy"])
print("Precision:", prec_result["precision"])
print("Recall:", rec_result["recall"])
print("F1:", f1_result["f1"])


Accuracy: 0.6775
Precision: 0.6859799390313412
Recall: 0.6775
F1: 0.6805853334720283


## Your Turn

Compute the classification metrics on the baseline LLM model on each test section of the ANLI dataset for samples that have a non-empty 'reason' field.

You also must show a comparison between the DeBERTa baseline model and this LLM baseline model. The comparison metric should compute the agreement between the two models:
* On how many samples they are both correct [Correct]
* On how many samples Model1 is correct and Model2 is incorrect [Correct1]
* On how many samples Model1 is incorrect and Model2 is correct [Correct2]
* On how many samples both are incorrect [Incorrect]

# NOTE FOR MICHAEL: 

1) in the assignment instructions on git , it says "Evaluate the model on the "test_r3" partition of the ANLI dataset". not on EACH test parition.
So we will compare only to "test_r3" parition.

2) we define DeBERTa baseline model as "Model 1" and the LLM baseline model as "Model 2"

In [81]:
%store -r pred_test_r3

In [82]:
# Model 1 predictions and gold labels (from pred_test_r3)
model1_preds = pred_test_r3[:len(ordered_llm_predictions)]

# Model 2 predictions from ordered_llm_predictions
model2_preds = ordered_llm_predictions

n = len(model2_preds)

both_correct = 0
only_model1_correct = 0
only_model2_correct = 0
both_incorrect = 0

for example, (uid, pred2) in zip(model1_preds, model2_preds):
    gold = example['gold_label']
    pred1 = example['pred_label']
    
    model1_correct = (pred1 == gold)
    model2_correct = (pred2 == gold)
    
    if model1_correct and model2_correct:
        both_correct += 1
    elif model1_correct and not model2_correct:
        only_model1_correct += 1
    elif not model1_correct and model2_correct:
        only_model2_correct += 1
    else:
        both_incorrect += 1

print(f"Both models correct on {both_correct} samples ({both_correct / n * 100:.2f}%).")
print(f"Only Model 1 correct on {only_model1_correct} samples ({only_model1_correct / n * 100:.2f}%).")
print(f"Only Model 2 correct on {only_model2_correct} samples ({only_model2_correct / n * 100:.2f}%).")
print(f"Both models incorrect on {both_incorrect} samples ({both_incorrect / n * 100:.2f}%).")


Both models correct on 433 samples (36.08%).
Only Model 1 correct on 144 samples (12.00%).
Only Model 2 correct on 380 samples (31.67%).
Both models incorrect on 243 samples (20.25%).
