# ANLI with LLM

In this notebook you will implement a better ANLI classifier using an LLM.
The classifier must be implemented with DSPy.


In [1]:
# Configure the DSPy environment with the language model.
# For Grok the model id is "xai/grok-3-mini" and the API key must be available
# in os.environ['XAI_API_KEY'] (here it is loaded from grok_key.ini).
import os
import dspy

from dotenv import load_dotenv

load_dotenv("grok_key.ini")
# Fail early with an actionable message instead of a bare KeyError('XAI_API_KEY')
# (or a late API failure for an empty key) when the key has not been provided.
if not os.environ.get('XAI_API_KEY'):
    raise KeyError(
        "XAI_API_KEY is not set: add XAI_API_KEY=<your key> to grok_key.ini "
        "or export it in the environment before running this notebook."
    )
lm = dspy.LM('xai/grok-3-mini', api_key=os.environ['XAI_API_KEY'])
# for ollama
# lm = dspy.LM('ollama_chat/devstral', api_base='http://localhost:11434', api_key='')
dspy.configure(lm=lm)

In [14]:
from typing import Literal
from sentence_transformers import SentenceTransformer, CrossEncoder

# Sentence-embedding model used for similarity scoring; the model id is kept in
# a named constant so it can be changed in one place.
SIMILARITY_MODEL_NAME = "all-MiniLM-L6-v2"
similarity_model = SentenceTransformer(SIMILARITY_MODEL_NAME)

# Joint prompt strategy: one signature whose outputs are both the explanation
# and the label, so a single predictor call yields the two together.
# NOTE(review): documented with `#` comments on purpose. Per the DSPy docs a
# Signature's class docstring is used as the prompt instructions, so adding a
# docstring here would change what is sent to the model.
class ANLIJointCoT(dspy.Signature):
    premise: str = dspy.InputField()
    hypothesis: str = dspy.InputField()
    # Declared before `label`; presumably so the explanation is generated
    # ahead of the label -- confirm against the DSPy field-ordering behavior.
    explanation: str = dspy.OutputField(desc="Explanation of the label")
    # Output restricted to the three NLI class names.
    label: Literal['entailment', 'neutral', 'contradiction'] = dspy.OutputField()

# Pipeline strategy, step 1: produce only an explanation for the
# premise/hypothesis pair; the label is predicted separately by ANLILabel.
# NOTE(review): the field `desc` below mentions "the label" although this
# signature has no label field -- consider rewording it if the prompt is tuned.
# NOTE(review): kept as `#` comments because, per the DSPy docs, a Signature's
# class docstring becomes the prompt instructions.
class ANLICOTExplanation(dspy.Signature):
    premise: str = dspy.InputField()
    hypothesis: str = dspy.InputField()
    explanation: str = dspy.OutputField(desc="Explanation of the label")

# Pipeline strategy, step 2: predict the label from the premise, the hypothesis
# and an explanation supplied as an *input* field.
# NOTE(review): kept as `#` comments because, per the DSPy docs, a Signature's
# class docstring becomes the prompt instructions.
class ANLILabel(dspy.Signature):
    premise: str = dspy.InputField()
    hypothesis: str = dspy.InputField()
    explanation: str = dspy.InputField()
    # Output restricted to the three NLI class names.
    label: Literal['entailment', 'neutral', 'contradiction'] = dspy.OutputField()

# One dspy.Predict module per signature, created in the same order as before:
# the joint predictor first, then the two pipeline stages.
joint_cot, explain_cot, label_cot = (
    dspy.Predict(signature)
    for signature in (ANLIJointCoT, ANLICOTExplanation, ANLILabel)
)

## Load ANLI dataset

In [3]:
from datasets import load_dataset

# Keep the unfiltered dataset under its own name so the filtered view below has
# an explicit lineage and the raw data stays available for inspection.
dataset_raw = load_dataset("facebook/anli")


def _has_reason(example):
    """Return True when the example's `reason` field is neither None nor empty."""
    reason = example['reason']
    return reason is not None and reason != ""


# Only examples with a non-empty `reason` are kept.
dataset = dataset_raw.filter(_has_reason)

## Evaluate Metrics

Let's use the Hugging Face `evaluate` package to compute the performance of the baseline.


In [4]:
from evaluate import load

# Load the four classification metrics as separate objects, in the same order
# as before: accuracy, precision, recall, f1.
accuracy, precision, recall, f1 = (
    load(metric_name)
    for metric_name in ("accuracy", "precision", "recall", "f1")
)


In [5]:
import evaluate
# Names of the metrics bundled into one object via evaluate.combine.
COMBINED_METRIC_NAMES = ["accuracy", "f1", "precision", "recall"]
clf_metrics = evaluate.combine(COMBINED_METRIC_NAMES)

## Your Turn

In [19]:
def rank_similarity(premise, hypothesis, human_explanation, predicted_explanation):
    """Score pairwise similarity between the NLI pair and the two explanations.

    The premise and hypothesis are merged into one passage, which is encoded
    together with the human and the predicted explanation by the module-level
    ``similarity_model``; the model's own ``similarity`` method then compares
    every embedding with every other one.

    Args:
        premise: Premise text of the example.
        hypothesis: Hypothesis text of the example.
        human_explanation: Reference explanation to compare against.
        predicted_explanation: Explanation produced by the model under test.

    Returns:
        dict with the keys ``premise_hypothesis_vs_human``,
        ``premise_hypothesis_vs_predicted`` and ``human_vs_predicted``, each
        holding the ``.item()`` value of the corresponding similarity entry.
    """
    nli_pair_text = f"Premise: {premise}\nHypothesis: {hypothesis}"
    texts = [nli_pair_text, human_explanation, predicted_explanation]

    vectors = similarity_model.encode(texts)
    # Pairwise matrix; index 0 = premise/hypothesis, 1 = human, 2 = predicted.
    pairwise = similarity_model.similarity(vectors, vectors)

    return {
        "premise_hypothesis_vs_human": pairwise[0][1].item(),
        "premise_hypothesis_vs_predicted": pairwise[0][2].item(),
        "human_vs_predicted": pairwise[1][2].item(),
    }

In [None]:
from tqdm import tqdm

# Number of dev_r3 examples to evaluate. Every example costs three predictor
# calls (one joint, two for the pipeline), so keep this small while iterating.
NUM_EXAMPLES = 50

dev_r3 = dataset["dev_r3"]
# min() guards against the filtered split holding fewer rows than requested.
data = dev_r3.select(range(min(NUM_EXAMPLES, len(dev_r3))))

label_names = ["entailment", "neutral", "contradiction"]
# Derived from label_names so the name -> id mapping cannot drift from the
# id -> name lookup used for the gold labels below.
label_to_int = {name: index for index, name in enumerate(label_names)}


def _make_record(gold_label, predicted_label, similarities):
    """Build one per-example result row from the labels and similarity scores."""
    return {
        "gold_label": gold_label,
        "predicted_label": predicted_label,
        "premise_hypothesis_vs_human": similarities["premise_hypothesis_vs_human"],
        "premise_hypothesis_vs_predicted": similarities["premise_hypothesis_vs_predicted"],
        "human_vs_predicted": similarities["human_vs_predicted"],
    }


def _classification_scores(predictions_int, references_int):
    """Return accuracy and macro-averaged F1/precision/recall as a dict."""
    return {
        "accuracy": accuracy.compute(predictions=predictions_int, references=references_int)['accuracy'],
        "f1": f1.compute(predictions=predictions_int, references=references_int, average='macro')['f1'],
        "precision": precision.compute(predictions=predictions_int, references=references_int, average='macro')['precision'],
        "recall": recall.compute(predictions=predictions_int, references=references_int, average='macro')['recall'],
    }


def _mean(values):
    """Arithmetic mean of a list; NaN for an empty list instead of ZeroDivisionError."""
    return sum(values) / len(values) if values else float("nan")


def _winner(joint_value, pipeline_value):
    """Name the method with the higher value, reporting exact ties as 'Tie'."""
    if joint_value > pipeline_value:
        return "Joint"
    if pipeline_value > joint_value:
        return "Pipeline"
    return "Tie"


def _print_scores(title, scores):
    """Print the four classification scores of one method under `title`."""
    print(title)
    print(f"\tAccuracy: {scores['accuracy']:.3f}")
    print(f"\tF1: {scores['f1']:.3f}")
    print(f"\tPrecision: {scores['precision']:.3f}")
    print(f"\tRecall: {scores['recall']:.3f}")


def _print_similarities(title, vs_human, vs_predicted, human_vs_predicted):
    """Print the mean of each similarity list of one method under `title`."""
    print(title)
    print(f"\tPremise-Hypothesis vs Human: {_mean(vs_human):.3f}")
    print(f"\tPremise-Hypothesis vs Predicted: {_mean(vs_predicted):.3f}")
    print(f"\tHuman vs Predicted: {_mean(human_vs_predicted):.3f}")


def _print_summary_line(title, joint_value, pipeline_value):
    """Print one summary row; both numbers are labeled so their order is unambiguous."""
    print(
        f"{title}: {_winner(joint_value, pipeline_value)} "
        f"(Joint {joint_value:.3f} vs Pipeline {pipeline_value:.3f})"
    )


joint_results = []
pipeline_results = []

for example in tqdm(data):
    premise = example['premise']
    hypothesis = example['hypothesis']
    human_explanation = example['reason']
    gold_label = label_names[example['label']]

    # Joint method: explanation and label come from a single call.
    joint_output = joint_cot(premise=premise, hypothesis=hypothesis)
    joint_explanation = joint_output.explanation
    joint_label = joint_output.label

    # Pipeline method: first the explanation, then the label conditioned on it.
    pipeline_explanation = explain_cot(premise=premise, hypothesis=hypothesis).explanation
    pipeline_label = label_cot(premise=premise, hypothesis=hypothesis, explanation=pipeline_explanation).label

    # Similarities
    joint_similarities = rank_similarity(premise, hypothesis, human_explanation, joint_explanation)
    pipeline_similarities = rank_similarity(premise, hypothesis, human_explanation, pipeline_explanation)

    joint_results.append(_make_record(gold_label, joint_label, joint_similarities))
    pipeline_results.append(_make_record(gold_label, pipeline_label, pipeline_similarities))

# Calculate metrics
# Gold labels are identical in both result lists; read them from the joint one.
references = [result['gold_label'] for result in joint_results]

joint_predictions = [result['predicted_label'] for result in joint_results]
pipeline_predictions = [result['predicted_label'] for result in pipeline_results]

references_int = [label_to_int[ref] for ref in references]
joint_predictions_int = [label_to_int[pred] for pred in joint_predictions]
pipeline_predictions_int = [label_to_int[pred] for pred in pipeline_predictions]

joint_scores = _classification_scores(joint_predictions_int, references_int)
pipeline_scores = _classification_scores(pipeline_predictions_int, references_int)

# Individual names kept so later cells that reference them keep working.
joint_accuracy_score = joint_scores["accuracy"]
joint_f1_score = joint_scores["f1"]
joint_precision_score = joint_scores["precision"]
joint_recall_score = joint_scores["recall"]

pipeline_accuracy_score = pipeline_scores["accuracy"]
pipeline_f1_score = pipeline_scores["f1"]
pipeline_precision_score = pipeline_scores["precision"]
pipeline_recall_score = pipeline_scores["recall"]

_print_scores("Joint Method Results:", joint_scores)
print("-" * 50)
_print_scores("Pipeline Method Results:", pipeline_scores)
print("="*50)
print()
print("="*50)

# analyze similarities
joint_premise_hypothesis_vs_human = [result['premise_hypothesis_vs_human'] for result in joint_results]
joint_premise_hypothesis_vs_predicted = [result['premise_hypothesis_vs_predicted'] for result in joint_results]
joint_human_vs_predicted = [result['human_vs_predicted'] for result in joint_results]
pipeline_premise_hypothesis_vs_human = [result['premise_hypothesis_vs_human'] for result in pipeline_results]
pipeline_premise_hypothesis_vs_predicted = [result['premise_hypothesis_vs_predicted'] for result in pipeline_results]
pipeline_human_vs_predicted = [result['human_vs_predicted'] for result in pipeline_results]

_print_similarities(
    "Joint Method Similarities:",
    joint_premise_hypothesis_vs_human,
    joint_premise_hypothesis_vs_predicted,
    joint_human_vs_predicted,
)
print("-" * 50)
_print_similarities(
    "Pipeline Method Similarities:",
    pipeline_premise_hypothesis_vs_human,
    pipeline_premise_hypothesis_vs_predicted,
    pipeline_human_vs_predicted,
)

# Summary comparison
# Ties are reported as "Tie" (previously an accuracy tie printed "Pipeline" and
# a similarity tie printed "Joint"), and every row lists Joint before Pipeline.
print("="*50)
print("SUMMARY COMPARISON:")
_print_summary_line("Better Classification", joint_accuracy_score, pipeline_accuracy_score)
_print_summary_line(
    "Better Explanation Relevance",
    _mean(joint_premise_hypothesis_vs_predicted),
    _mean(pipeline_premise_hypothesis_vs_predicted),
)
_print_summary_line(
    "Better Human Similarity",
    _mean(joint_human_vs_predicted),
    _mean(pipeline_human_vs_predicted),
)

  0%|          | 0/50 [00:00<?, ?it/s]

 48%|████▊     | 24/50 [01:00<04:30, 10.39s/it]