In [None]:
import contextlib
import itertools
import os
from functools import lru_cache
from pathlib import Path

import pandas as pd
from datasets import load_dataset
from tqdm import tqdm
from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline

@lru_cache(maxsize=1)
def load_nli_classifier(model_name="microsoft/deberta-v2-xlarge-mnli", batch_size=256):
    """Build a Hugging Face zero-shot-classification pipeline backed by an NLI model.

    The result is memoised with ``lru_cache(maxsize=1)``, so repeated calls with the
    same arguments return the same pipeline object instead of reloading the weights.
    Calling with different arguments evicts the previously cached pipeline.

    Args:
        model_name: Hub id (or local path) of the model to load.
        batch_size: Batch size the pipeline uses when scoring many candidate labels.

    Returns:
        The ``transformers`` zero-shot-classification pipeline.
    """
    return pipeline("zero-shot-classification", model=model_name, batch_size=batch_size)
    
@lru_cache(maxsize=1)
def load_msmarco_dataset():
    """Return the MS MARCO v2.1 ``train`` split opened with ``streaming=True``.

    The handle is memoised, so every caller in the notebook shares one object.
    """
    return load_dataset("ms_marco", "v2.1", split="train", streaming=True)

def test_entailment():
    """Smoke-test the NLI classifier on one premise and three hand-written hypotheses.

    The hypotheses are meant to cover an entailment, a contradiction and a neutral
    relation to the premise. Each hypothesis is therefore scored on its own
    (``multi_label=True``) and passed to the model verbatim (``hypothesis_template="{}"``).

    With the pipeline defaults, every hypothesis would be wrapped in
    "This example is {}.". The scores would also be softmax-normalised across the
    three hypotheses, making them a relative ranking rather than per-pair entailment
    scores.

    The raw pipeline result is printed.
    """
    premise = "Apple reported $119.6 billion in revenue for Q2 2025."

    hypotheses = [
        # intended: entailment.
        # NOTE(review): "beating analyst expectations" is not stated in the premise,
        # so strictly this pair is closer to neutral.
        "In the second quarter of 2025, Apple posted revenue of $119.6 billion, beating analyst expectations.",
        # intended: contradiction (different revenue figure)
        "Apple's Q2 2025 revenue was only $90 billion, which was below expectations.",
        # intended: neutral (unrelated fact)
        "Apple launched the Vision Pro headset in 2024 as part of its expansion into spatial computing.",
    ]

    classifier = load_nli_classifier()
    result = classifier(premise, hypotheses, hypothesis_template="{}", multi_label=True)
    print(result)

In [None]:
import torch

# NOTE(review): MS MARCO stores this placeholder in `answers` for queries without an
# answer. It must not be used as premise text or counted as an answer. Confirm it
# matches the HF ms_marco v2.1 data.
_MSMARCO_NO_ANSWER = "No Answer Present."


def annotate_msmarco(how_many=3, autocast_device="mps", classifier_kwargs=None):
    """Score every passage of the first ``how_many`` MS MARCO queries with the NLI classifier.

    The premise is chosen in order of preference: the first well-formed answer, the
    first real answer, then the query text itself. Every passage text of the query is
    passed to the classifier as a candidate label. Queries with no passage flagged
    ``is_selected == 1`` are recorded in ``queries`` but produce no annotations.

    Args:
        how_many: Maximum number of queries to read from the stream.
        autocast_device: ``device_type`` for the float16 ``torch.autocast`` context
            around the classifier call. A falsy value (e.g. ``None``) disables
            autocast, for example on machines without Apple MPS.
        classifier_kwargs: Extra keyword arguments forwarded to the zero-shot
            pipeline call (e.g. ``hypothesis_template``, ``multi_label``). With the
            pipeline defaults, each passage is inserted into "This example is {}.".
            The scores are then softmax-normalised across the passages of one query.

    Returns:
        ``(queries, annotations)``:
        - ``queries``: one dict per query read, with ``query_id``, ``query`` and the
          boolean flags ``answer`` and ``well_formed``.
        - ``annotations``: one dict per scored passage, with ``query_id``,
          ``passage``, ``selected`` and ``score``.
    """
    annotations = []
    queries = []
    call_kwargs = dict(classifier_kwargs or {})

    classifier = load_nli_classifier()
    msmarco_ds = load_msmarco_dataset()

    # islice yields at most `how_many` records. The previous
    # `len(queries) > how_many` check let one extra query through.
    records = itertools.islice(msmarco_ds, how_many)
    for e in tqdm(records, total=how_many, desc="Annotating MS MARCO"):
        query_id, query = e['query_id'], e['query']

        answer = next((a for a in e['answers'] if a and a != _MSMARCO_NO_ANSWER), None)
        well_formed = next(iter(e['wellFormedAnswers']), None)
        queries.append({'query_id': query_id, 'query': query, 'answer': bool(answer), 'well_formed': bool(well_formed)})

        is_selected = list(e['passages']['is_selected'])
        passage_texts = list(e['passages']['passage_text'])

        # Only score queries that have at least one selected passage
        # (this also skips queries with no passages at all).
        if not any(flag == 1 for flag in is_selected):
            continue

        premise = well_formed or answer or query

        autocast_ctx = (
            torch.autocast(device_type=autocast_device, dtype=torch.float16)
            if autocast_device
            else contextlib.nullcontext()
        )
        with autocast_ctx:
            result = classifier(premise, passage_texts, **call_kwargs)

        # The pipeline returns `labels`/`scores` sorted by descending score, not in
        # input order. Look each passage's score up by its text instead of zipping
        # by position. Identical passage texts receive identical scores, so keying
        # by text is safe.
        score_by_passage = dict(zip(result['labels'], result['scores']))

        for selected, passage_text in zip(is_selected, passage_texts):
            annotations.append({
                'query_id': query_id,
                'passage': passage_text,
                'selected': selected == 1,
                'score': score_by_passage[passage_text],
            })

    if len(queries) == how_many:
        print("dataset limit reached")

    return queries, annotations


In [None]:
N_QUERIES = 10_000  # number of MS MARCO queries to stream and annotate

queries, annotations = annotate_msmarco(how_many=N_QUERIES)
q_df = pd.DataFrame(queries)
a_df = pd.DataFrame(annotations)

# One row per scored passage, carrying its query-level `answer` / `well_formed`
# flags. The plotting and statistics cells below read this frame as `df`.
df = a_df.merge(q_df, on="query_id", how="left")

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np


# Boolean masks over the passage-level rows of `df`.
selected = df["selected"].astype(bool)
well_formed = df["well_formed"].astype(bool)
answered = df["answer"].astype(bool) & ~well_formed  # has an answer, but not a well-formed one

well_formed_df = df[well_formed]
well_formed_selected_df = df[well_formed & selected]
answered_df = df[answered]
answered_selected_df = df[answered & selected]
answered_notselected_df = df[answered & ~selected]

not_selected_df = df[~selected]

# Stack the groups with a `group` label so seaborn can plot them side by side.
# Groups overlap on purpose (e.g. "Answered" contains "Answered Selected").
plot_df = pd.concat([
    answered_df.assign(group="Answered"),
    answered_selected_df.assign(group="Answered Selected"),
    answered_notselected_df.assign(group="Answered Not Selected"),
    well_formed_df.assign(group="Well Formed"),
    well_formed_selected_df.assign(group="Well Formed Selected"),
    not_selected_df.assign(group="Not Selected"),
], ignore_index=True)

group_stats = plot_df.groupby("group")["score"].agg(["count", "median", "mean", "std"])
print(group_stats)

counts = plot_df["group"].value_counts()
plot_df["group_label"] = plot_df["group"].map(lambda g: f"{g}\n(n={counts[g]})")

fig, ax = plt.subplots(figsize=(10, 6))
sns.violinplot(data=plot_df, x="group_label", y="score", inner=None, ax=ax)

means_str = "; ".join(
    f"{g}: μ={group_stats.loc[g, 'mean']:.2f}"
    for g in group_stats.index
)

ax.set_title(f"Score Distributions by Group\n{means_str}")
ax.set_xlabel("Group")
ax.set_ylabel("Score")
fig.tight_layout()

# Save before showing. With Jupyter's inline backend the figure is closed once it
# has been shown, so a savefig after plt.show() writes a blank image.
fig.savefig("msmarco_score_distribution.png")
plt.show()


In [None]:
# Score above which a passage counts as a "hit" in the statistics below.
SCORE_THRESHOLD = 0.4

# base stats
group_stats = plot_df.groupby("group")["score"].agg(["count", "median", "mean", "std"])

# Per-row flag: is this passage's score above the threshold?
over_threshold = plot_df["score"] > SCORE_THRESHOLD

# For each group: count the rows above the threshold per query_id, then average
# those per-query counts over the query_ids in the group (not over rows).
rows_per_qid_over04 = (
    over_threshold.groupby([plot_df["group"], plot_df["query_id"]])
    .sum()
    .groupby(level="group")
    .mean()
    .rename(f"avg_rows_gt_{SCORE_THRESHOLD}_per_qid")
)

# Fraction of each group's rows whose score is above the threshold.
prop_over04 = (
    over_threshold.groupby(plot_df["group"])
    .mean()
    .rename(f"prop_rows_gt_{SCORE_THRESHOLD}")
)

# join into stats (both series are indexed by `group`)
group_stats = group_stats.join(rows_per_qid_over04).join(prop_over04)
print(group_stats)
