# ANLI Baseline

This model illustrates how to use the DeBERTa-v3-base-mnli-fever-anli model to perform specialized inference on the ANLI dataset.
This dataset has 184M parameters. It was trained in 2021 on the basis of a BERT-like embedding approach: 
* The premise and the hypothesis are encoded using the DeBERTa-v3-base contextual encoder
* The encodings are then compared on a fine-tuned model to predict a distribution over the classification labels (entailment, contradiction, neutral)

Reported accuracy on ANLI is 0.495 (see https://huggingface.co/MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli) 



In [1]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

model_name = "MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
premise = "I first thought that I liked the movie, but upon second thought it was actually disappointing."
hypothesis = "The movie was good."

input = tokenizer(premise, hypothesis, truncation=True, return_tensors="pt")
output = model(input["input_ids"].to(device))  # device = "cuda:0" or "cpu"
prediction = torch.softmax(output["logits"][0], -1).tolist()
label_names = ["entailment", "neutral", "contradiction"]
prediction = {name: round(float(pred) * 100, 1) for pred, name in zip(prediction, label_names)}
print(prediction)


Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


{'entailment': 6.6, 'neutral': 17.3, 'contradiction': 76.1}


In [3]:
def evaluate(premise, hypothesis):
    input = tokenizer(premise, hypothesis, truncation=True, return_tensors="pt")
    output = model(input["input_ids"].to(device))
    prediction = torch.softmax(output["logits"][0], -1).tolist()
    prediction = {name: round(float(pred) * 100, 1) for pred, name in zip(prediction, label_names)}
    return prediction

In [4]:
evaluate("The weather is nice today.", "It is sunny outside.")

{'entailment': 0.1, 'neutral': 99.8, 'contradiction': 0.0}

In [5]:
def get_prediction(pred_dict):
    if pred_dict["entailment"] > pred_dict["contradiction"]  and pred_dict["entailment"] > pred_dict["neutral"]:
        return "entailment"
    elif pred_dict["contradiction"] > pred_dict["entailment"]:
        return "contradiction"
    else:
        return "neutral"

In [6]:
get_prediction(evaluate("The weather is nice today.", "It is sunny outside."))

'neutral'

In [7]:
get_prediction(evaluate("It is sunny outside.", "The weather is nice today."))

'entailment'

In [8]:
get_prediction(evaluate("It is sunny outside.", "The weather is terrible today."))

'contradiction'

## Load ANLI dataset

In [9]:
from datasets import load_dataset

dataset = load_dataset("facebook/anli")
dataset = dataset.filter(lambda x: x['reason'] != None and x['reason'] != "")

In [10]:
dataset

DatasetDict({
    train_r1: Dataset({
        features: ['uid', 'premise', 'hypothesis', 'label', 'reason'],
        num_rows: 2923
    })
    dev_r1: Dataset({
        features: ['uid', 'premise', 'hypothesis', 'label', 'reason'],
        num_rows: 1000
    })
    test_r1: Dataset({
        features: ['uid', 'premise', 'hypothesis', 'label', 'reason'],
        num_rows: 1000
    })
    train_r2: Dataset({
        features: ['uid', 'premise', 'hypothesis', 'label', 'reason'],
        num_rows: 4861
    })
    dev_r2: Dataset({
        features: ['uid', 'premise', 'hypothesis', 'label', 'reason'],
        num_rows: 1000
    })
    test_r2: Dataset({
        features: ['uid', 'premise', 'hypothesis', 'label', 'reason'],
        num_rows: 1000
    })
    train_r3: Dataset({
        features: ['uid', 'premise', 'hypothesis', 'label', 'reason'],
        num_rows: 13375
    })
    dev_r3: Dataset({
        features: ['uid', 'premise', 'hypothesis', 'label', 'reason'],
        num_rows: 1200


In [11]:
# Evaluate the model on the ANLI dataset
from tqdm import tqdm
def evaluate_on_dataset(dataset):
    results = []
    label_names = ["entailment", "neutral", "contradiction"]
    for example in tqdm(dataset):
        premise = example['premise']
        hypothesis = example['hypothesis']
        prediction = evaluate(premise, hypothesis)
        results.append({
            'premise': premise,
            'hypothesis': hypothesis,
            'prediction': prediction,
            'pred_label': get_prediction(prediction),
            'gold_label': label_names[example['label']],
            'reason': example['reason']
        })
    return results

### Task 1.1 - Evaluating ANLI samples on test sections

In [12]:
pred_test_r1 = evaluate_on_dataset(dataset['test_r1'])
pred_test_r2 = evaluate_on_dataset(dataset['test_r2'])

  0%|          | 0/1000 [00:00<?, ?it/s]

100%|██████████| 1000/1000 [07:51<00:00,  2.12it/s]
100%|██████████| 1000/1000 [10:24<00:00,  1.60it/s]


In [13]:
pred_test_r3 = evaluate_on_dataset(dataset['test_r3'])

100%|██████████| 1200/1200 [08:46<00:00,  2.28it/s]


In [14]:
pred_test_r1[:5]

[{'premise': 'Ernest Jones is a British jeweller and watchmaker. Established in 1949, its first store was opened in Oxford Street, London. Ernest Jones specialises in diamonds and watches, stocking brands such as Gucci and Emporio Armani. Ernest Jones is part of the Signet Jewelers group.',
  'hypothesis': 'The first Ernest Jones store was opened on the continent of Europe.',
  'prediction': {'entailment': 99.5, 'neutral': 0.1, 'contradiction': 0.3},
  'pred_label': 'entailment',
  'gold_label': 'entailment',
  'reason': "The first store was opened in London, which is in Europe. It may have been difficult for the system because continents weren't mentioned."},
 {'premise': 'Old Trafford is a football stadium in Old Trafford, Greater Manchester, England, and the home of Manchester United. With a capacity of 75,643, it is the largest club football stadium in the United Kingdom, the second-largest football stadium, and the eleventh-largest in Europe. It is about 0.5 mi from Old Trafford C

In [15]:
pred_test_r2[:5]

[{'premise': 'There is a little Shia community in El Salvador. There is an Islamic Library operated by the Shia community, named "Fatimah Az-Zahra". They published the first Islamic magazine in Central America: "Revista Biblioteca Islámica". Additionally, they are credited with providing the first and only Islamic library dedicated to spreading Islamic culture in the country.',
  'hypothesis': 'The community is south of the United States.',
  'prediction': {'entailment': 94.5, 'neutral': 1.7, 'contradiction': 3.8},
  'pred_label': 'entailment',
  'gold_label': 'entailment',
  'reason': 'The community is in El Salvador which is south of the US.'},
 {'premise': '"Look at Me (When I Rock Wichoo)" is a song by American indie rock band Black Kids, taken from their debut album "Partie Traumatic". It was released in the UK by Almost Gold Recordings on September 8, 2008 and debuted on the Top 200 UK Singles Chart at number 175.',
  'hypothesis': 'The song was released in America in September 2

In [16]:
pred_test_r3[:5]  # Display the first 5 predictions

[{'premise': "It is Sunday today, let's take a look at the most popular posts of the last couple of days. Most of the articles this week deal with the iPhone, its future version called the iPhone 8 or iPhone Edition, and new builds of iOS and macOS. There are also some posts that deal with the iPhone rival called the Galaxy S8 and some other interesting stories. The list of the most interesting articles is available below. Stay tuned for more rumors and don't forget to follow us on Twitter.",
  'hypothesis': 'The day of the passage is usually when Christians praise the lord together',
  'prediction': {'entailment': 2.4, 'neutral': 97.4, 'contradiction': 0.2},
  'pred_label': 'neutral',
  'gold_label': 'entailment',
  'reason': "Sunday is considered Lord's Day"},
 {'premise': 'By The Associated Press WELLINGTON, New Zealand (AP) — All passengers and crew have survived a crash-landing of a plane in a lagoon in the Federated States of Micronesia. WELLINGTON, New Zealand (AP) — All passeng

## Evaluate Metrics

Let's use the huggingface `evaluate` package to compute the performance of the baseline.


In [21]:
from evaluate import load
!pip install scikit-learn

accuracy = load("accuracy")
precision = load("precision")
recall = load("recall")
f1 = load("f1")


Collecting scikit-learn
  Using cached scikit_learn-1.7.0-cp313-cp313-win_amd64.whl.metadata (14 kB)
Collecting scipy>=1.8.0 (from scikit-learn)
  Using cached scipy-1.16.0-cp313-cp313-win_amd64.whl.metadata (60 kB)
Collecting joblib>=1.2.0 (from scikit-learn)
  Using cached joblib-1.5.1-py3-none-any.whl.metadata (5.6 kB)
Collecting threadpoolctl>=3.1.0 (from scikit-learn)
  Using cached threadpoolctl-3.6.0-py3-none-any.whl.metadata (13 kB)
Using cached scikit_learn-1.7.0-cp313-cp313-win_amd64.whl (10.7 MB)
Using cached joblib-1.5.1-py3-none-any.whl (307 kB)
Using cached scipy-1.16.0-cp313-cp313-win_amd64.whl (38.4 MB)
Using cached threadpoolctl-3.6.0-py3-none-any.whl (18 kB)
Installing collected packages: threadpoolctl, scipy, joblib, scikit-learn

   ---------- ----------------------------- 1/4 [scipy]
   ---------- ----------------------------- 1/4 [scipy]
   ---------- ----------------------------- 1/4 [scipy]
   ---------- ----------------------------- 1/4 [scipy]
   ---------- --

Downloading builder script: 7.56kB [00:00, 20.1MB/s]
Downloading builder script: 7.38kB [00:00, 36.7MB/s]
Downloading builder script: 6.79kB [00:00, 12.9MB/s]


In [24]:
import evaluate
clf_metrics = evaluate.combine(["accuracy", "f1", "precision", "recall"])

In [25]:
clf_metrics.compute(predictions=[0, 1, 0], references=[0, 1, 1])

{'accuracy': 0.6666666666666666,
 'f1': 0.6666666666666666,
 'precision': 1.0,
 'recall': 0.5}

### Task 1.2 - Investigate Errors of the NLI Model

Sample 20 errors from the baseline model, and investigate the reasons the model made a mistake.

In [35]:
import random
import pandas as pd

# Combine all test results
all_predictions = pred_test_r1 + pred_test_r2 + pred_test_r3

errors = [pred for pred in all_predictions if pred['pred_label'] != pred['gold_label']]

print(f"Total predictions: {len(all_predictions)}")
print(f"Total errors: {len(errors)}")
print(f"Error rate: {len(errors)/len(all_predictions):.3f}")

# Sample 20 random errors
random.seed(42)
sampled_errors = random.sample(errors, min(20, len(errors)))

# Create analysis table
error_analysis = []
for i, error in enumerate(sampled_errors):
    pred_label = error['pred_label']
    gold_label = error['gold_label']
    premise = error['premise']
    hypothesis = error['hypothesis']
    reason = error['reason']
    
    error_type = f"{gold_label} → {pred_label}"
    
    # Length analysis
    premise_len = len(premise.split())
    hypothesis_len = len(hypothesis.split())
    
    # Check for negation
    has_negation = any(neg in premise.lower() or neg in hypothesis.lower() 
                      for neg in ['not', 'no', 'never', 'none', 'nobody', 'nothing', 'neither'])
    
    # Check for complex reasoning
    has_complex_reasoning = any(word in premise.lower() or word in hypothesis.lower() 
                               for word in ['because', 'since', 'therefore', 'however', 'although', 'unless'])
    
    error_analysis.append({
        'Error_ID': i+1,
        'Error_Type': error_type,
        'Premise': premise[:100] + "..." if len(premise) > 100 else premise,
        'Hypothesis': hypothesis[:100] + "..." if len(hypothesis) > 100 else hypothesis,
        'Gold_Label': gold_label,
        'Pred_Label': pred_label,
        'Premise_Length': premise_len,
        'Hypothesis_Length': hypothesis_len,
        'Has_Negation': has_negation,
        'Complex_Reasoning': has_complex_reasoning,
        'Human_Reason': reason[:150] + "..." if len(reason) > 150 else reason
    })

df_errors = pd.DataFrame(error_analysis)

Total predictions: 3200
Total errors: 1500
Error rate: 0.469


In [36]:
df_errors.head(20)

Unnamed: 0,Error_ID,Error_Type,Premise,Hypothesis,Gold_Label,Pred_Label,Premise_Length,Hypothesis_Length,Has_Negation,Complex_Reasoning,Human_Reason
0,1,neutral → contradiction,A missed call is a telephone call that is deli...,Pre-agreed missed call messages are only pract...,neutral,contradiction,65,10,False,False,The context does specify if the countries men...
1,2,contradiction → entailment,"John-Michael Hakim Gibson, (born August 15, 19...",Gibson was 18 years old when he released his f...,contradiction,entailment,53,11,True,False,He was born in 1990 and released his debut alb...
2,3,entailment → contradiction,Svein Holden (born 23 August 1973) is a Norweg...,Svein Holden is 45 years old.,entailment,contradiction,40,6,True,False,"he was born on august 23 1973, thus as of 8/11..."
3,4,entailment → contradiction,"""Vanlose Stairway"" is a song written by Northe...",Vanlose Stairway is a Van Morrison Song and on...,entailment,contradiction,39,11,True,False,Its just vague enough to cause it problems
4,5,neutral → entailment,Peeya Rai Chowdhary is an Indian actress. Peey...,Peeya Rai was not married to Munshi while she ...,neutral,entailment,61,17,True,False,The context doesn't state what year the TV sho...
5,6,neutral → contradiction,Flatbush Avenue is a major avenue in the New Y...,The north end extension was going to be called...,neutral,contradiction,55,16,True,False,There's no indication whether it was ever goin...
6,7,neutral → contradiction,The South Kalgoorlie Gold Mine is a gold mine ...,The mine should be called South West Kalgoorie...,neutral,contradiction,49,9,False,False,My statement was a matter of opinion based off...
7,8,contradiction → entailment,The Robinson R44 is a four-seat light helicopt...,It took three years for the Robinson R44 to re...,contradiction,entailment,51,27,False,True,It took only two years for the certification t...
8,9,entailment → contradiction,The exchanges resulted in greatly improved fin...,Great Britain refused to address the latest Is...,entailment,contradiction,26,16,False,False,The created statement says that there is a pub...
9,10,neutral → contradiction,How to get a bank account<br>Choose a banking ...,You can only open it in a branch,neutral,contradiction,49,8,False,False,It does not really say whether or not you can ...


In [34]:
# Convert to DataFrame for better display
df_errors = pd.DataFrame(error_analysis)
print("\n=== ERROR ANALYSIS TABLE ===")
print(df_errors.to_string(index=False))

# Summary statistics
print("\n=== ERROR SUMMARY ===")
print(f"Error type distribution:")
error_type_counts = df_errors['Error_Type'].value_counts()
print(error_type_counts)

print(f"\nAverage premise length: {df_errors['Premise_Length'].mean():.1f} words")
print(f"Average hypothesis length: {df_errors['Hypothesis_Length'].mean():.1f} words")
print(f"Errors with negation: {df_errors['Has_Negation'].sum()}/20 ({df_errors['Has_Negation'].sum()/20*100:.1f}%)")
print(f"Errors with complex reasoning: {df_errors['Complex_Reasoning'].sum()}/20 ({df_errors['Complex_Reasoning'].sum()/20*100:.1f}%)")


=== ERROR ANALYSIS TABLE ===
 Error_ID                 Error_Type                                                                                                 Premise                                                                                              Hypothesis    Gold_Label    Pred_Label  Premise_Length  Hypothesis_Length  Has_Negation  Complex_Reasoning                                                                                                                                              Human_Reason
        1    neutral → contradiction A missed call is a telephone call that is deliberately terminated by the caller before being answere...                                      Pre-agreed missed call messages are only practiced in 3 countries.       neutral contradiction              65                 10         False              False                                                     The context  does specify if the countries mentioned are the only ones using this 

## Your Turn

Compute the classification metrics on the baseline model on each section of the ANLI dataset.

https://www.kaggle.com/code/faijanahamadkhan/llm-evaluation-framework-hugging-face provides good documentation on how to use the Huggingface evaluate library.

In [45]:
import numpy as np

def labels_to_numeric(labels):
    label_to_num = {"entailment": 0, "neutral": 1, "contradiction": 2}
    return [label_to_num[label] for label in labels]

def compute_metrics_for_section(predictions, section_name):
    pred_labels = [pred['pred_label'] for pred in predictions]
    gold_labels = [pred['gold_label'] for pred in predictions]
    
    pred_numeric = labels_to_numeric(pred_labels)
    gold_numeric = labels_to_numeric(gold_labels)
    
    acc = accuracy.compute(predictions=pred_numeric, references=gold_numeric)
    prec = precision.compute(predictions=pred_numeric, references=gold_numeric, average='weighted')
    rec = recall.compute(predictions=pred_numeric, references=gold_numeric, average='weighted')
    f1_score = f1.compute(predictions=pred_numeric, references=gold_numeric, average='weighted')
    
    metrics = {
        'accuracy': acc['accuracy'],
        'precision': prec['precision'],
        'recall': rec['recall'],
        'f1': f1_score['f1']
    }
    
    print(f"\n=== {section_name} METRICS ===")
    print(f"Accuracy: {metrics['accuracy']:.4f}")
    print(f"F1 Score: {metrics['f1']:.4f}")
    print(f"Precision: {metrics['precision']:.4f}")
    print(f"Recall: {metrics['recall']:.4f}")
    print(f"Total samples: {len(predictions)}")
    
    return metrics

# Compute metrics for each test section
metrics_r1 = compute_metrics_for_section(pred_test_r1, "TEST_R1")
metrics_r2 = compute_metrics_for_section(pred_test_r2, "TEST_R2") 
metrics_r3 = compute_metrics_for_section(pred_test_r3, "TEST_R3")

all_test_predictions = pred_test_r1 + pred_test_r2 + pred_test_r3
metrics_overall = compute_metrics_for_section(all_test_predictions, "OVERALL")

# Create summary table
import pandas as pd

summary_data = {
    'Section': ['Test_R1', 'Test_R2', 'Test_R3', 'Overall'],
    'Accuracy': [metrics_r1['accuracy'], metrics_r2['accuracy'], metrics_r3['accuracy'], metrics_overall['accuracy']],
    'F1': [metrics_r1['f1'], metrics_r2['f1'], metrics_r3['f1'], metrics_overall['f1']],
    'Precision': [metrics_r1['precision'], metrics_r2['precision'], metrics_r3['precision'], metrics_overall['precision']],
    'Recall': [metrics_r1['recall'], metrics_r2['recall'], metrics_r3['recall'], metrics_overall['recall']],
    'Samples': [len(pred_test_r1), len(pred_test_r2), len(pred_test_r3), len(all_test_predictions)]
}

df_summary = pd.DataFrame(summary_data)
print("\n=== SUMMARY TABLE ===")
print(df_summary.to_string(index=False, float_format='%.4f'))

# Additional analysis: per-class metrics
print("\n=== PER-CLASS METRICS (Overall) ===")
from sklearn.metrics import classification_report

all_pred_labels = [pred['pred_label'] for pred in all_test_predictions]
all_gold_labels = [pred['gold_label'] for pred in all_test_predictions]

print(classification_report(all_gold_labels, all_pred_labels, 
                          target_names=['entailment', 'neutral', 'contradiction'],
                          digits=4))


=== TEST_R1 METRICS ===
Accuracy: 0.6190
F1 Score: 0.6046
Precision: 0.6332
Recall: 0.6190
Total samples: 1000

=== TEST_R2 METRICS ===
Accuracy: 0.5040
F1 Score: 0.4894
Precision: 0.5077
Recall: 0.5040
Total samples: 1000

=== TEST_R3 METRICS ===
Accuracy: 0.4808
F1 Score: 0.4622
Precision: 0.4651
Recall: 0.4808
Total samples: 1200

=== OVERALL METRICS ===
Accuracy: 0.5312
F1 Score: 0.5149
Precision: 0.5292
Recall: 0.5312
Total samples: 3200

=== SUMMARY TABLE ===
Section  Accuracy     F1  Precision  Recall  Samples
Test_R1    0.6190 0.6046     0.6332  0.6190     1000
Test_R2    0.5040 0.4894     0.5077  0.5040     1000
Test_R3    0.4808 0.4622     0.4651  0.4808     1200
Overall    0.5312 0.5149     0.5292  0.5312     3200

=== PER-CLASS METRICS (Overall) ===
               precision    recall  f1-score   support

   entailment     0.4990    0.6968    0.5815      1062
      neutral     0.5950    0.6178    0.6061      1070
contradiction     0.4934    0.2800    0.3572      1068

     

In [44]:
df_summary

Unnamed: 0,Section,Accuracy,F1,Precision,Recall,Samples
0,Test_R1,0.619,0.604639,0.633233,0.619,1000
1,Test_R2,0.504,0.489391,0.507715,0.504,1000
2,Test_R3,0.480833,0.462211,0.465068,0.480833,1200
3,Overall,0.53125,0.5149,0.529213,0.53125,3200
