# ANLI Baseline with LLM

In this notebook, you will implement a baseline for ANLI classification using an LLM.
The baseline must be implemented with DSPy.



In [76]:
# Configure the DSPy environment with a language model.
# For Grok: read the API key from os.environ['XAI_API_KEY'] and use the model "xai/grok-3-mini".
import os
import dspy

# lm = dspy.LM('xai/grok-3-mini', api_key=os.environ['XAI_API_KEY'])
# For Ollama via the chat endpoint:
# lm = dspy.LM('ollama_chat/llama3.2', api_base='http://localhost:11434', api_key='')
# dspy.configure(lm=lm)
lm = dspy.LM(
    "ollama/llama3.2:latest",
    api_base="http://localhost:11434",
    format="json",  # forwarded by LiteLLM to Ollama to request JSON-formatted output
)
# JSONAdapter formats the prompt for, and parses the reply as, JSON.
dspy.configure(lm=lm, adapter=dspy.JSONAdapter())

In [98]:
from typing import Literal

## Implement the DSPy classifier program.
class NLIClassifier(dspy.Signature):
    premise: str = dspy.InputField(desc="A short passage or statement. All facts should be inferred from this text alone.")
    hypothesis: str = dspy.InputField(desc="A second statement to evaluate. Check whether it follows from, contradicts, or is unrelated to the premise.")
    label: Literal["entailment", "neutral", "contradiction"] = dspy.OutputField(
        desc=(
            "Return one of: 'entailment', 'neutral', or 'contradiction'.\n"
            "- 'entailment': The hypothesis must be true if the premise is true.\n"
            "- 'contradiction': The hypothesis must be false if the premise is true.\n"
            "- 'neutral': The hypothesis could be either true or false given the premise."
        )
    )

predictor = dspy.Predict(NLIClassifier)

def zero_shot_nli_classifier(x):
    """Return the predicted label string for a dict-like example with 'premise' and 'hypothesis' keys."""
    return predictor(premise=x['premise'], hypothesis=x['hypothesis']).label
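The classifier returns a label *string*, while the ANLI `label` column stores integer ids (0 = entailment, 1 = neutral, 2 = contradiction, the same order used when converting examples later in this notebook). Before computing metrics, the strings have to be mapped back to ids. The helper below is a minimal, hypothetical sketch (`label_to_id` and `LABEL_TO_ID` are not part of DSPy); it assumes that any output outside the three labels should be treated as an invalid prediction with id `-1`, which then counts as wrong.

```python
# Hypothetical helper: map the LLM's label string onto ANLI's integer ids.
LABEL_NAMES = ["entailment", "neutral", "contradiction"]  # ANLI ids 0, 1, 2
LABEL_TO_ID = {name: i for i, name in enumerate(LABEL_NAMES)}


def label_to_id(label, invalid_id=-1):
    """Return 0/1/2 for a recognised label string, or `invalid_id` otherwise."""
    if not isinstance(label, str):
        return invalid_id
    # Tolerate stray whitespace, quotes and capitalisation in the model output.
    return LABEL_TO_ID.get(label.strip().strip("'\"").lower(), invalid_id)


print(label_to_id("contradiction"))  # 2
print(label_to_id(" Entailment\n"))  # 0
print(label_to_id("unknown"))        # -1
```

The `Literal` output type should already constrain DSPy's parsed output; the fallback is only a defensive choice so that a single malformed answer does not crash an evaluation run.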

## Load ANLI dataset

In [78]:
from datasets import load_dataset

dataset = load_dataset("facebook/anli")
# Keep only samples with a non-empty 'reason' field.
dataset = dataset.filter(lambda x: x['reason'] is not None and x['reason'] != "")

In [79]:
dataset

DatasetDict({
    train_r1: Dataset({
        features: ['uid', 'premise', 'hypothesis', 'label', 'reason'],
        num_rows: 2923
    })
    dev_r1: Dataset({
        features: ['uid', 'premise', 'hypothesis', 'label', 'reason'],
        num_rows: 1000
    })
    test_r1: Dataset({
        features: ['uid', 'premise', 'hypothesis', 'label', 'reason'],
        num_rows: 1000
    })
    train_r2: Dataset({
        features: ['uid', 'premise', 'hypothesis', 'label', 'reason'],
        num_rows: 4861
    })
    dev_r2: Dataset({
        features: ['uid', 'premise', 'hypothesis', 'label', 'reason'],
        num_rows: 1000
    })
    test_r2: Dataset({
        features: ['uid', 'premise', 'hypothesis', 'label', 'reason'],
        num_rows: 1000
    })
    train_r3: Dataset({
        features: ['uid', 'premise', 'hypothesis', 'label', 'reason'],
        num_rows: 13375
    })
    dev_r3: Dataset({
        features: ['uid', 'premise', 'hypothesis', 'label', 'reason'],
        num_rows: 1200


In [99]:
example = dataset['test_r3'][0]
label_names = ["entailment", "neutral", "contradiction"]
print(label_names[example['label']])
print(zero_shot_nli_classifier(example))



entailment
contradiction


In [93]:
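# Let's optimize the predictor with bootstrapped few-shot demonstrations.
from dspy import BootstrapFewShot

label_names = ["entailment", "neutral", "contradiction"]

def accuracy_metric(example, pred, trace=None):
    # 1 if the predicted label string matches the gold label string, else 0.
    return int(pred.label.strip().lower() == example["label"])

opt = BootstrapFewShot(
    metric=accuracy_metric,
    # Each successful trace (up to this cap) is added to the prompt as a demo,
    # so a cap this large can yield prompts longer than the model's context window.
    max_bootstrapped_demos=500,
    max_labeled_demos=16,
    max_rounds=1,
)

def convert_dict(ex):
    # Convert a raw ANLI row into a dspy.Example with a string label; mark the two input fields.
    return dspy.Example(
        premise=ex["premise"],
        hypothesis=ex["hypothesis"],
        label=label_names[ex["label"]],
    ).with_inputs("premise", "hypothesis")

# Use the dev_r3 split as the pool of training examples for bootstrapping.
trainset = [convert_dict(x) for x in dataset['dev_r3'].to_list()]
compiled_clf = opt.compile(predictor, trainset=trainset)  # returns a new module with demos attached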

100%|██████████| 1200/1200 [00:57<00:00, 20.98it/s]

Bootstrapped 421 full traces after 1199 examples for up to 1 rounds, amounting to 1200 attempts.
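The log above suggests how `BootstrapFewShot` picks its demos: it runs the program on training examples, keeps those whose output passes the metric (421 here), and tops up with raw labeled examples only if fewer than `max_labeled_demos` traces were kept. The sketch below is a simplified, hypothetical re-creation of that idea in plain Python, not DSPy's implementation (it ignores multi-module traces, multiple rounds, and the teacher/student distinction); `bootstrap_demos` and its arguments are illustrative names. It shows why `max_labeled_demos=16` does not cap the total: with `max_bootstrapped_demos=500`, every passing trace becomes a demo.

```python
import random


def bootstrap_demos(trainset, teacher, metric, max_bootstrapped_demos, max_labeled_demos, seed=0):
    """Simplified demo selection: passing teacher outputs first, then raw labeled examples."""
    bootstrapped, used_idx = [], set()
    for idx, example in enumerate(trainset):
        if len(bootstrapped) >= max_bootstrapped_demos:
            break  # cap reached: stop querying the teacher
        prediction = teacher(example)
        if metric(example, prediction):
            # Keep the example together with the teacher's own output as a demo ("trace").
            bootstrapped.append({**example, "label": prediction})
            used_idx.add(idx)
    # Every example not turned into a trace is a candidate raw labeled demo.
    remaining = [ex for idx, ex in enumerate(trainset) if idx not in used_idx]
    # Raw demos only top up to max_labeled_demos; they never trim the bootstrapped list.
    n_raw = max(0, min(max_labeled_demos - len(bootstrapped), len(remaining)))
    raw = random.Random(seed).sample(remaining, n_raw)
    return bootstrapped + raw


toy = [
    {"premise": f"p{i}", "hypothesis": f"h{i}", "label": "neutral" if i % 2 == 0 else "entailment"}
    for i in range(10)
]
exact_match = lambda ex, pred: ex["label"] == pred

# A teacher that always answers "neutral" passes on the even-indexed examples.
demos = bootstrap_demos(toy, lambda ex: "neutral", exact_match, max_bootstrapped_demos=3, max_labeled_demos=4)
print(len(demos))  # 4: 3 bootstrapped traces + 1 raw example

# A perfect teacher with a huge cap: all 10 become traces, exceeding max_labeled_demos.
demos = bootstrap_demos(toy, lambda ex: ex["label"], exact_match, max_bootstrapped_demos=500, max_labeled_demos=4)
print(len(demos))  # 10
```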





In [94]:
def few_shots_nli_classifier(x):
    return compiled_clf(premise=x['premise'], hypothesis=x['hypothesis']).label

In [95]:
example = dataset['test_r3'][0]
label_names = ["entailment", "neutral", "contradiction"]
print(example)
print(few_shots_nli_classifier(example))



{'uid': 'b0e63408-53af-4b46-b33d-bf5ba302949f', 'premise': "It is Sunday today, let's take a look at the most popular posts of the last couple of days. Most of the articles this week deal with the iPhone, its future version called the iPhone 8 or iPhone Edition, and new builds of iOS and macOS. There are also some posts that deal with the iPhone rival called the Galaxy S8 and some other interesting stories. The list of the most interesting articles is available below. Stay tuned for more rumors and don't forget to follow us on Twitter.", 'hypothesis': 'The day of the passage is usually when Christians praise the lord together', 'label': 0, 'reason': "Sunday is considered Lord's Day"}
contradiction


## Evaluate Metrics

Let's use the Hugging Face `evaluate` package to compute the performance of the baseline.
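ANLI has three classes, while `evaluate`'s `precision`, `recall` and `f1` default to `average="binary"`, so they need an explicit `average="macro"` on this task (for example `f1.compute(predictions=preds, references=refs, average="macro")`). If invalid LLM outputs are encoded as `-1`, also passing `labels=[0, 1, 2]` keeps `-1` from being treated as a fourth class. As a dependency-free cross-check, here is a minimal sketch of accuracy plus macro-averaged precision, recall and F1; `macro_metrics` is a hypothetical name, and it assumes integer labels with zero-division cases scored as 0.0.

```python
def macro_metrics(predictions, references, labels=(0, 1, 2)):
    """Accuracy plus macro-averaged precision, recall and F1 over `labels`.

    A class with no predicted (or no gold) samples gets precision (or recall) 0.0.
    Predictions outside `labels` (e.g. -1 for unparseable output) simply count as wrong.
    """
    if len(predictions) != len(references) or not references:
        raise ValueError("predictions and references must be non-empty and equally long")
    pairs = list(zip(predictions, references))
    accuracy = sum(p == r for p, r in pairs) / len(pairs)
    per_class = []
    for c in labels:
        tp = sum(p == c and r == c for p, r in pairs)
        fp = sum(p == c and r != c for p, r in pairs)
        fn = sum(p != c and r == c for p, r in pairs)
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        per_class.append((precision, recall, f1))
    k = len(per_class)
    # Macro average: unweighted mean of the per-class scores.
    return {
        "accuracy": accuracy,
        "precision": sum(pc[0] for pc in per_class) / k,
        "recall": sum(pc[1] for pc in per_class) / k,
        "f1": sum(pc[2] for pc in per_class) / k,
    }


# Accuracy 0.6; macro precision ~0.833, recall ~0.667, F1 ~0.722.
print(macro_metrics([0, 1, 2, 2, -1], [0, 1, 1, 2, 2]))
```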


In [None]:
from evaluate import load

accuracy = load("accuracy")
precision = load("precision")
recall = load("recall")
f1 = load("f1")


In [None]:
import evaluate
clf_metrics = evaluate.combine(["accuracy", "f1", "precision", "recall"])

In [None]:
clf_metrics.compute(predictions=[0, 1, 0], references=[0, 1, 1])

## Your Turn

Compute the classification metrics of the LLM baseline on each test split of the ANLI dataset (`test_r1`, `test_r2`, `test_r3`), restricted to samples that have a non-empty 'reason' field.
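One way to gather what the metrics need is to run a classifier over a split and collect integer predictions alongside the gold ids. The sketch below is hypothetical (`collect_predictions` is an illustrative name); it takes any callable that maps a row to a label string, such as `zero_shot_nli_classifier` or `few_shots_nli_classifier`, and it assumes that a failed LM call or an unrecognised label should be scored as the invalid id `-1` rather than stopping the run.

```python
from typing import Callable, Iterable, List, Mapping, Tuple

LABEL_NAMES = ["entailment", "neutral", "contradiction"]  # ANLI ids 0, 1, 2


def collect_predictions(
    classifier: Callable[[Mapping], str],
    rows: Iterable[Mapping],
    invalid_id: int = -1,
) -> Tuple[List[int], List[int]]:
    """Run `classifier` on each row; return (predicted ids, gold ids)."""
    predictions, references = [], []
    for row in rows:
        try:
            label = classifier(row)
        except Exception:
            # Assumption: a failed call or unparseable output counts as a wrong answer.
            label = None
        name = label.strip().lower() if isinstance(label, str) else None
        predictions.append(LABEL_NAMES.index(name) if name in LABEL_NAMES else invalid_id)
        references.append(row["label"])
    return predictions, references


# Stand-in classifier so the sketch runs without an LLM.
def stub_classifier(row):
    if row["hypothesis"] == "boom":
        raise ValueError("parse failure")
    return {"h0": "Entailment", "h1": "neutral "}[row["hypothesis"]]


rows = [
    {"premise": "p", "hypothesis": "h0", "label": 0},
    {"premise": "p", "hypothesis": "h1", "label": 2},
    {"premise": "p", "hypothesis": "boom", "label": 1},
]
print(collect_predictions(stub_classifier, rows))  # ([0, 1, -1], [0, 2, 1])

# In the notebook (assumes `dataset` and `zero_shot_nli_classifier` from the cells above):
# for split in ["test_r1", "test_r2", "test_r3"]:
#     preds, refs = collect_predictions(zero_shot_nli_classifier, dataset[split])
```

Because `dataset` was already filtered on the 'reason' field, iterating a test split this way yields only samples with a non-empty reason.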

You must also compare the DeBERTa baseline model with this LLM baseline model. The comparison metric should measure the agreement between the two models:
* On how many samples they are both correct [Correct]
* On how many samples Model1 is correct and Model2 is incorrect [Correct1]
* On how many samples Model1 is incorrect and Model2 is correct [Correct2]
* On how many samples both are incorrect [Incorrect]
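The four agreement counts can be computed in a single pass once both models' predictions are aligned row by row with the gold labels. The sketch below is hypothetical (`agreement_counts` is an illustrative name); it assumes integer label ids, that Model1 is the first prediction list and Model2 the second (for example DeBERTa and the LLM), and that all three lists cover the same samples in the same order.

```python
from collections import Counter
from typing import Dict, Sequence


def agreement_counts(
    preds1: Sequence[int], preds2: Sequence[int], references: Sequence[int]
) -> Dict[str, int]:
    """Count Correct / Correct1 / Correct2 / Incorrect for two models on the same samples."""
    if not (len(preds1) == len(preds2) == len(references)):
        raise ValueError("all three sequences must have the same length")
    counts = Counter()
    for p1, p2, ref in zip(preds1, preds2, references):
        ok1, ok2 = p1 == ref, p2 == ref
        if ok1 and ok2:
            counts["Correct"] += 1    # both models right
        elif ok1:
            counts["Correct1"] += 1   # only Model1 right
        elif ok2:
            counts["Correct2"] += 1   # only Model2 right
        else:
            counts["Incorrect"] += 1  # both models wrong
    # Counter returns 0 for missing keys, so every category is always present.
    return {k: counts[k] for k in ("Correct", "Correct1", "Correct2", "Incorrect")}


references = [0, 1, 2, 1, 0]
deberta_preds = [0, 1, 0, 0, 0]  # Model1 (toy values)
llm_preds = [0, 2, 2, 0, 1]      # Model2 (toy values)
print(agreement_counts(deberta_preds, llm_preds, references))
# {'Correct': 1, 'Correct1': 2, 'Correct2': 1, 'Incorrect': 1}
```

The four counts always sum to the number of samples, which is a quick sanity check that the two prediction lists were aligned correctly.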