# ANLI with LLM

In this notebook, you will implement an improved ANLI classifier using an LLM.
The classifier must be implemented with DSPy.


In [1]:
# Configure the DSPy environment with the language model.
# For Grok, the API key must be set in os.environ['XAI_API_KEY']
# and the model name is "xai/grok-3-mini".
import os
import dspy

from dotenv import load_dotenv
load_dotenv("grok_key.ini")
lm = dspy.LM('xai/grok-3-mini', api_key=os.environ['XAI_API_KEY'])
# for ollama 
# lm = dspy.LM('ollama_chat/devstral', api_base='http://localhost:11434', api_key='')
dspy.configure(lm=lm)

In [None]:
from typing import Literal
from sentence_transformers import CrossEncoder

# Load the model for ranking
reranker_model = CrossEncoder("cross-encoder/ms-marco-MiniLM-L6-v2")

# Joint prompt strategy: a single call produces both the explanation and the label
class ANLIJointCoT(dspy.Signature):
    premise: str = dspy.InputField()
    hypothesis: str = dspy.InputField()
    explanation: str = dspy.OutputField(desc="Explanation of the label")
    label: Literal['entailment', 'neutral', 'contradiction'] = dspy.OutputField()

# Pipeline strategy: one call produces the explanation, a second call produces the label from it
class ANLICOTExplanation(dspy.Signature):
    premise: str = dspy.InputField()
    hypothesis: str = dspy.InputField()
    explanation: str = dspy.OutputField(desc="Explanation of the label")

class ANLILabel(dspy.Signature):
    premise: str = dspy.InputField()
    hypothesis: str = dspy.InputField()
    explanation: str = dspy.InputField()
    label: Literal['entailment', 'neutral', 'contradiction'] = dspy.OutputField()

joint_cot = dspy.Predict(ANLIJointCoT)
explain_cot = dspy.Predict(ANLICOTExplanation)
label_cot = dspy.Predict(ANLILabel)

def rank_similarity(query, passages):
    """Rank passages by cross-encoder relevance to the query, most relevant first."""
    return reranker_model.rank(query, passages, return_documents=True)

## Load ANLI dataset

In [4]:
from datasets import load_dataset

dataset = load_dataset("facebook/anli")
# Keep only the samples that have a non-empty 'reason' field
dataset = dataset.filter(lambda x: x['reason'] is not None and x['reason'] != "")

## Evaluate Metrics

We use the Hugging Face `evaluate` package to compute the performance of the baseline.


In [5]:
from evaluate import load

accuracy = load("accuracy")
precision = load("precision")
recall = load("recall")
f1 = load("f1")


In [6]:
import evaluate
clf_metrics = evaluate.combine(["accuracy", "f1", "precision", "recall"])

## Your Turn

Compute the classification metrics for the baseline LLM on each test split of the ANLI dataset (`test_r1`, `test_r2`, `test_r3`), restricted to samples that have a non-empty 'reason' field.
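
One possible shape for this evaluation loop is sketched below. It is only an illustration: `evaluate_split` and `predict_fn` are hypothetical names that do not appear elsewhere in this notebook, and accuracy and macro-F1 are computed in plain Python here instead of through `evaluate`, so that the sketch runs without any model or network access. It assumes the ANLI label encoding 0 = entailment, 1 = neutral, 2 = contradiction.

```python
# Label names in the order of the integer labels used by facebook/anli.
LABELS = ["entailment", "neutral", "contradiction"]


def evaluate_split(examples, predict_fn):
    """Return accuracy and macro-F1 of predict_fn over ANLI-style rows.

    examples   -- iterable of dicts with 'premise', 'hypothesis', 'label' (int)
    predict_fn -- hypothetical callable (premise, hypothesis) -> label string
    """
    examples = list(examples)
    if not examples:
        raise ValueError("examples must not be empty")

    gold = [LABELS[ex["label"]] for ex in examples]
    pred = [predict_fn(ex["premise"], ex["hypothesis"]) for ex in examples]
    pairs = list(zip(gold, pred))

    accuracy = sum(g == p for g, p in pairs) / len(pairs)

    # Per-class F1, averaged over all three classes (macro average).
    # A class with no gold and no predicted samples contributes 0.0 here.
    f1_scores = []
    for label in LABELS:
        tp = sum(g == label and p == label for g, p in pairs)
        fp = sum(g != label and p == label for g, p in pairs)
        fn = sum(g == label and p != label for g, p in pairs)
        denom = 2 * tp + fp + fn
        f1_scores.append(2 * tp / denom if denom else 0.0)

    return {"accuracy": accuracy, "macro_f1": sum(f1_scores) / len(LABELS)}
```

With the predictors defined earlier, `predict_fn` could be something like `lambda p, h: joint_cot(premise=p, hypothesis=h).label`, applied to `dataset["test_r1"]`, `dataset["test_r2"]`, and `dataset["test_r3"]` in turn.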

You must also show a comparison between the DeBERTa baseline model and this LLM baseline model. The comparison should measure the agreement between the two models:
* On how many samples they are both correct [Correct]
* On how many samples Model1 is correct and Model2 is incorrect [Correct1]
* On how many samples Model1 is incorrect and Model2 is correct [Correct2]
* On how many samples both are incorrect [Incorrect]
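
The four agreement categories above can be counted with a small helper. This is a minimal sketch, not a required implementation: the function name `agreement_counts` and its arguments are hypothetical, and it assumes the gold labels and both models' predictions are aligned lists that use the same label representation.

```python
def agreement_counts(gold, preds1, preds2):
    """Count agreement between two models against the gold labels.

    gold, preds1, preds2 -- aligned sequences of labels (same length, same encoding)
    Returns a dict with the keys Correct, Correct1, Correct2, Incorrect.
    """
    if not (len(gold) == len(preds1) == len(preds2)):
        raise ValueError("gold, preds1 and preds2 must have the same length")

    counts = {"Correct": 0, "Correct1": 0, "Correct2": 0, "Incorrect": 0}
    for g, p1, p2 in zip(gold, preds1, preds2):
        ok1, ok2 = p1 == g, p2 == g
        if ok1 and ok2:
            counts["Correct"] += 1      # both models are right
        elif ok1:
            counts["Correct1"] += 1     # only Model1 is right
        elif ok2:
            counts["Correct2"] += 1     # only Model2 is right
        else:
            counts["Incorrect"] += 1    # both models are wrong
    return counts
```

The four counts always sum to the number of samples, which gives a quick sanity check when the helper is run per test split.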