# ANLI with LLM

You have to implement in this notebook a better ANLI classifier using an LLM.
This classifier must be implemented using DSPy.


In [53]:
# Configure the DSPy environment with the language model - for grok the parameters must be:
# env variable should be in os.environ['XAI_API_KEY']
# "xai/grok-3-mini"
from dotenv import load_dotenv
import os
import dspy
# Read the key file so that XAI_API_KEY is available through os.environ.
load_dotenv("grok_key.ini")

# Hosted Grok model; the API key is taken from the environment, never hard-coded.
lm = dspy.LM(
    'xai/grok-3-mini',
    api_key=os.environ['XAI_API_KEY'],
)

# Local alternative (Ollama):
# lm = dspy.LM('ollama_chat/devstral', api_base='http://localhost:11434', api_key='')

# Register the language model as the one DSPy modules use.
dspy.configure(lm=lm)

Step 1: define the reward function and metrics

In [54]:
from typing import Literal

## Implement the DSPy classifier program.

from sentence_transformers import SentenceTransformer, util
import dspy

# Load the Sentence-Transformer model used to embed text for the reward function.
# 'all-MiniLM-L6-v2' was chosen as the most cost-effective model according to the SBERT documentation.
model = SentenceTransformer('all-MiniLM-L6-v2')

# Cosine similarity between the sentence embeddings of two strings
def semantic_similarity(premise_hypothesis: str, reason: str) -> float:
    """
    Embed both strings with the module-level `model` and return their cosine
    similarity as a plain Python float.

    Args:
        premise_hypothesis: text built from the premise and the hypothesis.
        reason: explanation text to compare against it.
    """
    source_vec = model.encode(premise_hypothesis, convert_to_tensor=True)
    reason_vec = model.encode(reason, convert_to_tensor=True)
    similarity = util.pytorch_cos_sim(source_vec, reason_vec)
    # .item() extracts the single value held by the result tensor
    return similarity.item()

# Reward function: how semantically close is the generated reason to the input pair?
def reward_semantic_similarity(args: dict, pred: dspy.Prediction) -> float:
    """
    Score a prediction by the similarity between its `reason` and the inputs.

    Args:
        args: input arguments of the call; must contain 'premise' and 'hypothesis'.
        pred: prediction object exposing a `reason` attribute.

    Returns:
        The value of `semantic_similarity` between "<premise>, <hypothesis>"
        and the predicted reason.
    """
    combined_input = f"{args['premise']}, {args['hypothesis']}"
    return semantic_similarity(combined_input, pred.reason)

Step 2: define the modules and signatures using Refine and the reward function.

IMPORTANT NOTE: we picked a heuristic threshold that is commonly used for similarity checks.
This threshold acts as a soft filter: if no attempt reaches it, the module still returns the prediction with the highest reward. See https://dspy.ai/api/modules/Refine/ for reference.

In [55]:

# -------------------------
# Signatures
# -------------------------
from tqdm import tqdm


# Reward threshold handed to dspy.Refine by the classifiers below.
# 0.6 is a popular heuristic threshold for semantic similarity
# (according to ChatGPT, Gemini, Perplexity, Claude).
threshold = 0.6

# Signature for the joint prompt: premise + hypothesis in, reason + label out.
# NOTE: the class docstring is the instruction text DSPy gives to the LM,
# so editing it changes model behavior, not just documentation.
class NLISignatureJoint(dspy.Signature):
    """
    Classify the relationship between the premise and hypothesis 
    into: entailment, neutral, or contradiction.
    Provide a justification for your answer.
    """
    premise: str = dspy.InputField()
    hypothesis: str = dspy.InputField()
    # Free-text justification; scored by the reward function.
    reason: str = dspy.OutputField()
    # Output restricted to the three NLI classes.
    label: Literal['entailment', 'neutral', 'contradiction'] = dspy.OutputField()


# Stage 1 of the pipeline: premise + hypothesis in, reason out (no label yet).
# NOTE: the class docstring is the instruction text DSPy gives to the LM,
# so editing it changes model behavior, not just documentation.
class NLISignatureReasonFirst(dspy.Signature):
    """
    First, reason about the relationship between the premise and hypothesis.
    """
    premise: str = dspy.InputField()
    hypothesis: str = dspy.InputField()
    # Free-text explanation, later fed to the labeling stage.
    reason: str = dspy.OutputField()


# Stage 2 of the pipeline: premise + hypothesis + reason in, label out.
# NOTE: the class docstring is the instruction text DSPy gives to the LM,
# so editing it changes model behavior, not just documentation.
class NLISignatureLabelSecond(dspy.Signature):
    """
    Given the premise, hypothesis, and a reasoning explanation,
    classify the relationship into: entailment, neutral, or contradiction.
    """
    premise: str = dspy.InputField()
    hypothesis: str = dspy.InputField()
    # Here the reason is an input: it comes from the reasoning stage.
    reason: str = dspy.InputField()
    # Output restricted to the three NLI classes.
    label: Literal['entailment', 'neutral', 'contradiction'] = dspy.OutputField()


# -------------------------
# Helper for batching
# -------------------------

def batch_predict(predictor, examples, batch_size=20, num_threads=8, desc="Processing"):
    """
    Run `predictor.batch` over `examples` in consecutive slices.

    Args:
        predictor: object exposing `batch(examples, num_threads=...)`.
        examples: sliceable sequence of examples.
        batch_size: number of examples per call to `predictor.batch`.
        num_threads: forwarded to `predictor.batch`.
        desc: label shown on the outer tqdm progress bar.

    Returns:
        A single list with the outputs of every slice, in order.
    """
    collected = []
    slice_starts = range(0, len(examples), batch_size)
    for start in tqdm(slice_starts, desc=desc):
        chunk = examples[start:start + batch_size]
        chunk_outputs = predictor.batch(chunk, num_threads=num_threads)
        collected.extend(chunk_outputs)
    return collected


# -------------------------
# Joint Prompt Classifier
# -------------------------

class NLIJointClassifier(dspy.Module):
    """
    Runs the joint-prompt model: predicts both reason and label in one step.

    The supplied predictor is wrapped in `dspy.Refine`, scored with
    `reward_semantic_similarity`.

    Args:
        joint_predictor: DSPy module that produces `reason` and `label`.
        batch_size: number of examples per sub-batch in `forward`.
        num_threads: worker threads used for each sub-batch.
        n_attempts: value passed to Refine as `N` (maximum attempts per example).
        reward_threshold: value passed to Refine as `threshold`; when None,
            the module-level `threshold` is used.
    """
    def __init__(self, joint_predictor, batch_size=20, num_threads=8,
                 n_attempts=3, reward_threshold=None):
        super().__init__()
        if reward_threshold is None:
            # Fall back to the notebook-wide similarity threshold.
            reward_threshold = threshold
        self.joint_predictor = dspy.Refine(
            module=joint_predictor,
            N=n_attempts,  # up to this many attempts per example
            reward_fn=reward_semantic_similarity,
            threshold=reward_threshold,
        )
        self.batch_size = batch_size
        self.num_threads = num_threads

    def forward(self, examples):
        """Predict on a list of examples in sub-batches and return the list of results."""
        return batch_predict(
            self.joint_predictor, examples,
            batch_size=self.batch_size, num_threads=self.num_threads,
            desc="Joint Prompt Processing"
        )


# -------------------------
# Pipeline Classifier
# -------------------------

class NLIPipelineClassifier(dspy.Module):
    """
    Runs the reasoning-first → label-second pipeline.

    Stage 1 generates a reason per example (wrapped in `dspy.Refine`, scored
    with `reward_semantic_similarity`); stage 2 predicts the label from the
    premise, hypothesis and that reason.

    Args:
        reason_predictor: DSPy module that produces `reason`.
        label_predictor: DSPy module that produces `label` given the reason.
        batch_size: number of examples per sub-batch in each stage.
        num_threads: worker threads used for each sub-batch.
        n_attempts: value passed to Refine as `N` (maximum attempts per example).
        reward_threshold: value passed to Refine as `threshold`; when None,
            the module-level `threshold` is used.
    """
    def __init__(self, reason_predictor, label_predictor, batch_size=20, num_threads=8,
                 n_attempts=3, reward_threshold=None):
        super().__init__()
        if reward_threshold is None:
            # Fall back to the notebook-wide similarity threshold.
            reward_threshold = threshold
        self.reason_predictor = dspy.Refine(
            module=reason_predictor,
            N=n_attempts,  # up to this many attempts per example
            reward_fn=reward_semantic_similarity,
            threshold=reward_threshold,
        )
        self.label_predictor = label_predictor
        self.batch_size = batch_size
        self.num_threads = num_threads

    def forward(self, examples):
        """
        Run both stages over `examples` and return the stage-2 predictions,
        each with the stage-1 `reason` attached.
        """
        # Stage 1: Reasoning
        reasoning_results = batch_predict(
            self.reason_predictor, examples,
            batch_size=self.batch_size, num_threads=self.num_threads,
            desc="Stage 1: Reasoning"
        )

        # NOTE(review): this assumes every entry returned by `batch` is a
        # prediction object; if a failed example can come back as None, the
        # `.reason` accesses below would raise — confirm for the DSPy version in use.

        # Build label stage input: the generated reason becomes an input field
        label_inputs = [
            dspy.Example(
                premise=ex.premise,
                hypothesis=ex.hypothesis,
                reason=reasoning.reason
            ).with_inputs("premise", "hypothesis", "reason")
            for ex, reasoning in zip(examples, reasoning_results)
        ]

        # Stage 2: Label prediction
        label_results = batch_predict(
            self.label_predictor, label_inputs,
            batch_size=self.batch_size, num_threads=self.num_threads,
            desc="Stage 2: Labeling"
        )

        # Attach reasoning to final predictions
        for pred, reasoning in zip(label_results, reasoning_results):
            pred.reason = reasoning.reason

        return label_results


## Load ANLI dataset

In [56]:
from datasets import load_dataset

# Load every split of ANLI.
dataset = load_dataset("facebook/anli")
# Keep only the rows whose 'reason' field is present and non-empty.
# `is not None` is the idiomatic None check (PEP 8) instead of `!= None`.
dataset = dataset.filter(lambda x: x['reason'] is not None and x['reason'] != "")

## Evaluate Metrics

Let's use the huggingface `evaluate` package to compute the performance of the baseline.


In [57]:
from evaluate import load

# Load the four Hugging Face `evaluate` metrics individually, in this order.
accuracy, precision, recall, f1 = (
    load(metric_name)
    for metric_name in ("accuracy", "precision", "recall", "f1")
)


In [58]:
import evaluate
# Bundle of the four classification metrics computed together.
clf_metrics = evaluate.combine(
    ["accuracy", "f1", "precision", "recall"]
)

## Your Turn

Compute the classification metrics on the baseline LLM model on each test section of the ANLI dataset for samples that have a non-empty 'reason' field.

You also must show a comparison between the DeBERTa baseline model and this LLM baseline model. The comparison metric should compute the agreement between the two models:
* On how many samples they are both correct [Correct]
* On how many samples Model1 is correct and Model2 is incorrect [Correct1]
* On how many samples Model1 is incorrect and Model2 is correct [Correct2]
* On how many samples both are incorrect [Incorrect]

Step 3: prepare the data and the classifying programs, and define a metric computation function

In [59]:
# Prepare the evaluation examples from dev_r3 (this data is used for testing, not training).
def _build_examples(rows, include_gold):
    """
    Turn dataset rows into DSPy examples whose inputs are premise and hypothesis.

    Args:
        rows: iterable of mappings with 'premise' and 'hypothesis'
            (and 'reason' / 'label' when `include_gold` is True).
        include_gold: when True, also store the row's 'reason' and 'label'
            on the example.

    Returns:
        A new list of `dspy.Example` objects, one per row.
    """
    built = []
    for row in rows:
        fields = {
            "premise": row["premise"],
            "hypothesis": row["hypothesis"],
        }
        if include_gold:
            fields["reason"] = row["reason"]
            fields["label"] = row["label"]
        built.append(dspy.Example(**fields).with_inputs("premise", "hypothesis"))
    return built


# Each list is built separately so that no Example objects are shared between them.
preprocessed_examples_for_joint = _build_examples(dataset['dev_r3'], include_gold=True)

preprocessed_joint_no_reason = _build_examples(dataset['dev_r3'], include_gold=False)

preprocessed_examples_for_pipeline = _build_examples(dataset['dev_r3'], include_gold=True)

preprocessed_examples_for_pipeline_no_reason = _build_examples(dataset['dev_r3'], include_gold=False)


# --- Joint prompt: one call yields both reason and label ---
joint_predictor = dspy.ChainOfThought(NLISignatureJoint)
joint_program = NLIJointClassifier(joint_predictor)


# --- Pipeline: reason first, then label ---
# Chain-of-thought is applied to the reasoning stage only; the labeling
# stage is a plain Predict.
reason_predictor = dspy.ChainOfThought(NLISignatureReasonFirst)
label_predictor = dspy.Predict(NLISignatureLabelSecond)
pipeline_program = NLIPipelineClassifier(reason_predictor, label_predictor)


def compute_metrics(preds, golds):
    """
    Compute accuracy plus macro-averaged precision, recall and F1.

    Args:
        preds: predicted labels.
        golds: reference labels, aligned with `preds`.

    Returns:
        Dict with keys "accuracy", "precision", "recall", "f1".
    """
    scores = {
        "accuracy": accuracy.compute(predictions=preds, references=golds)["accuracy"],
    }
    # The remaining three metrics share the same call shape and macro averaging.
    macro_metrics = (("precision", precision), ("recall", recall), ("f1", f1))
    for key, metric in macro_metrics:
        scores[key] = metric.compute(predictions=preds, references=golds, average="macro")[key]
    return scores

Step 4: calculate the results

In [60]:
# Upper bound on the number of evaluated examples; 1200 is the number of items
# in dev_r3, so the slices below cover the whole split.
NEW_TEST_SIZE = 1200

print("\n\nRunning joint-prompt model on the test set...")
results_joint = joint_program(preprocessed_joint_no_reason[:NEW_TEST_SIZE])

print("\n\nRunning pipeline model on the test set...")
# Use the gold-free examples here as well, matching the joint run: the pipeline
# only reads premise and hypothesis, so the gold reason/label need not be passed in.
results_pipeline = pipeline_program(preprocessed_examples_for_pipeline_no_reason[:NEW_TEST_SIZE])



Running joint-prompt model on the test set...


Joint Prompt Processing:   0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  return forward_call(*args, **kwargs)


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:01<00:00, 13.14it/s]

Joint Prompt Processing:   2%|▏         | 1/60 [00:01<01:31,  1.54s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:01<00:00, 10.68it/s]

Joint Prompt Processing:   3%|▎         | 2/60 [00:03<01:43,  1.78s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:01<00:00, 10.20it/s]

Joint Prompt Processing:   5%|▌         | 3/60 [00:05<01:48,  1.91s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:02<00:00,  9.34it/s]

Joint Prompt Processing:   7%|▋         | 4/60 [00:07<01:54,  2.05s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:02<00:00,  7.45it/s]

Joint Prompt Processing:   8%|▊         | 5/60 [00:10<02:08,  2.34s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [02:02<00:00,  6.10s/it]

Joint Prompt Processing:  10%|█         | 6/60 [02:12<38:46, 43.09s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [01:47<00:00,  5.35s/it]

Joint Prompt Processing:  12%|█▏        | 7/60 [04:00<56:34, 64.04s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [01:59<00:00,  5.96s/it]

Joint Prompt Processing:  13%|█▎        | 8/60 [05:59<1:10:47, 81.68s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [01:44<00:00,  5.21s/it]

Joint Prompt Processing:  15%|█▌        | 9/60 [07:43<1:15:29, 88.81s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [01:31<00:00,  4.57s/it]

Joint Prompt Processing:  17%|█▋        | 10/60 [09:15<1:14:45, 89.71s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [01:19<00:00,  3.96s/it]

Joint Prompt Processing:  18%|█▊        | 11/60 [10:35<1:10:42, 86.58s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [01:35<00:00,  4.80s/it]

Joint Prompt Processing:  20%|██        | 12/60 [12:11<1:11:40, 89.58s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [01:52<00:00,  5.63s/it]

Joint Prompt Processing:  22%|██▏       | 13/60 [14:04<1:15:43, 96.66s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [01:25<00:00,  4.27s/it]

Joint Prompt Processing:  23%|██▎       | 14/60 [15:30<1:11:34, 93.36s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [01:23<00:00,  4.17s/it]

Joint Prompt Processing:  25%|██▌       | 15/60 [16:54<1:07:50, 90.45s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [01:47<00:00,  5.38s/it]

Joint Prompt Processing:  27%|██▋       | 16/60 [18:42<1:10:12, 95.73s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [01:29<00:00,  4.50s/it]

Joint Prompt Processing:  28%|██▊       | 17/60 [20:12<1:07:27, 94.12s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [01:38<00:00,  4.91s/it]

Joint Prompt Processing:  30%|███       | 18/60 [21:51<1:06:50, 95.49s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [01:40<00:00,  5.02s/it]

Joint Prompt Processing:  32%|███▏      | 19/60 [23:32<1:06:26, 97.24s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:59<00:00,  2.97s/it]

Joint Prompt Processing:  33%|███▎      | 20/60 [24:33<57:30, 86.27s/it]  


Processed 20 / 20 examples: 100%|██████████| 20/20 [01:42<00:00,  5.10s/it]

Joint Prompt Processing:  35%|███▌      | 21/60 [26:16<59:25, 91.41s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [02:10<00:00,  6.53s/it]

Joint Prompt Processing:  37%|███▋      | 22/60 [28:28<1:05:32, 103.48s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [01:43<00:00,  5.16s/it]

Joint Prompt Processing:  38%|███▊      | 23/60 [30:12<1:04:00, 103.81s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [01:42<00:00,  5.15s/it]

Joint Prompt Processing:  40%|████      | 24/60 [31:57<1:02:24, 104.01s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [01:49<00:00,  5.48s/it]

Joint Prompt Processing:  42%|████▏     | 25/60 [33:47<1:01:51, 106.05s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [01:29<00:00,  4.48s/it]

Joint Prompt Processing:  43%|████▎     | 26/60 [35:18<57:28, 101.43s/it]  


Processed 20 / 20 examples: 100%|██████████| 20/20 [01:32<00:00,  4.62s/it]

Joint Prompt Processing:  45%|████▌     | 27/60 [36:51<54:25, 98.96s/it] 


Processed 20 / 20 examples: 100%|██████████| 20/20 [01:48<00:00,  5.42s/it]

Joint Prompt Processing:  47%|████▋     | 28/60 [38:40<54:23, 101.98s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [01:39<00:00,  4.99s/it]

Joint Prompt Processing:  48%|████▊     | 29/60 [40:21<52:29, 101.61s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [01:50<00:00,  5.54s/it]

Joint Prompt Processing:  50%|█████     | 30/60 [42:14<52:26, 104.89s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [01:26<00:00,  4.30s/it]

Joint Prompt Processing:  52%|█████▏    | 31/60 [43:42<48:18, 99.96s/it] 


Processed 20 / 20 examples: 100%|██████████| 20/20 [01:45<00:00,  5.28s/it]

Joint Prompt Processing:  53%|█████▎    | 32/60 [45:29<47:39, 102.13s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [01:19<00:00,  3.98s/it]

Joint Prompt Processing:  55%|█████▌    | 33/60 [46:51<43:12, 96.03s/it] 


Processed 20 / 20 examples: 100%|██████████| 20/20 [01:52<00:00,  5.62s/it]

Joint Prompt Processing:  57%|█████▋    | 34/60 [48:46<44:02, 101.62s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [01:41<00:00,  5.10s/it]

Joint Prompt Processing:  58%|█████▊    | 35/60 [50:30<42:42, 102.49s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [01:31<00:00,  4.59s/it]

Joint Prompt Processing:  60%|██████    | 36/60 [52:05<40:01, 100.05s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [01:31<00:00,  4.60s/it]

Joint Prompt Processing:  62%|██████▏   | 37/60 [53:39<37:41, 98.34s/it] 


Processed 20 / 20 examples: 100%|██████████| 20/20 [01:23<00:00,  4.17s/it]

Joint Prompt Processing:  63%|██████▎   | 38/60 [55:05<34:40, 94.55s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [01:40<00:00,  5.04s/it]

Joint Prompt Processing:  65%|██████▌   | 39/60 [56:48<33:59, 97.13s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [01:42<00:00,  5.15s/it]

Joint Prompt Processing:  67%|██████▋   | 40/60 [58:33<33:08, 99.43s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [01:54<00:00,  5.73s/it]

Joint Prompt Processing:  68%|██████▊   | 41/60 [1:00:29<33:08, 104.66s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [01:57<00:00,  5.86s/it]

Joint Prompt Processing:  70%|███████   | 42/60 [1:02:29<32:45, 109.18s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [01:51<00:00,  5.56s/it]

Joint Prompt Processing:  72%|███████▏  | 43/60 [1:04:23<31:18, 110.48s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [01:49<00:00,  5.47s/it]

Joint Prompt Processing:  73%|███████▎  | 44/60 [1:06:15<29:36, 111.03s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [01:41<00:00,  5.10s/it]

Joint Prompt Processing:  75%|███████▌  | 45/60 [1:07:59<27:14, 108.97s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [02:03<00:00,  6.16s/it]

Joint Prompt Processing:  77%|███████▋  | 46/60 [1:10:05<26:35, 113.95s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [01:44<00:00,  5.23s/it]

Joint Prompt Processing:  78%|███████▊  | 47/60 [1:11:52<24:16, 112.02s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [01:43<00:00,  5.16s/it]

Joint Prompt Processing:  80%|████████  | 48/60 [1:13:39<22:03, 110.33s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [01:52<00:00,  5.62s/it]

Joint Prompt Processing:  82%|████████▏ | 49/60 [1:15:34<20:30, 111.86s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [01:54<00:00,  5.72s/it]

Joint Prompt Processing:  83%|████████▎ | 50/60 [1:17:31<18:54, 113.46s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [01:36<00:00,  4.81s/it]

Joint Prompt Processing:  85%|████████▌ | 51/60 [1:19:10<16:22, 109.12s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [01:04<00:00,  3.22s/it]

Joint Prompt Processing:  87%|████████▋ | 52/60 [1:20:17<12:52, 96.53s/it] 


Processed 20 / 20 examples: 100%|██████████| 20/20 [02:03<00:00,  6.20s/it]

Joint Prompt Processing:  88%|████████▊ | 53/60 [1:22:24<12:19, 105.68s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [01:28<00:00,  4.44s/it]

Joint Prompt Processing:  90%|█████████ | 54/60 [1:23:56<10:08, 101.49s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [01:41<00:00,  5.05s/it]

Joint Prompt Processing:  92%|█████████▏| 55/60 [1:25:40<08:31, 102.23s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [01:59<00:00,  5.97s/it]

Joint Prompt Processing:  93%|█████████▎| 56/60 [1:27:42<07:12, 108.22s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [01:49<00:00,  5.47s/it]

Joint Prompt Processing:  95%|█████████▌| 57/60 [1:29:35<05:28, 109.53s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [01:47<00:00,  5.38s/it]

Joint Prompt Processing:  97%|█████████▋| 58/60 [1:31:26<03:40, 110.01s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [01:41<00:00,  5.10s/it]

Joint Prompt Processing:  98%|█████████▊| 59/60 [1:33:11<01:48, 108.48s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [01:23<00:00,  4.15s/it]

Joint Prompt Processing: 100%|██████████| 60/60 [1:34:37<00:00, 94.62s/it] 





Running pipeline model on the test set...


Stage 1: Reasoning:   0%|          | 0/60 [00:00<?, ?it/s]

Processed 20 / 20 examples: 100%|██████████| 20/20 [00:19<00:00,  1.01it/s]

Stage 1: Reasoning:   2%|▏         | 1/60 [00:23<23:23, 23.78s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:22<00:00,  1.10s/it]

Stage 1: Reasoning:   3%|▎         | 2/60 [00:49<24:07, 24.95s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:19<00:00,  1.02it/s]

Stage 1: Reasoning:   5%|▌         | 3/60 [01:12<22:56, 24.15s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:23<00:00,  1.16s/it]

Stage 1: Reasoning:   7%|▋         | 4/60 [01:39<23:39, 25.34s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:23<00:00,  1.18s/it]

Stage 1: Reasoning:   8%|▊         | 5/60 [02:07<23:56, 26.12s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [02:00<00:00,  6.02s/it]

Stage 1: Reasoning:  10%|█         | 6/60 [04:11<53:32, 59.48s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [01:49<00:00,  5.48s/it]

Stage 1: Reasoning:  12%|█▏        | 7/60 [06:04<1:08:02, 77.03s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [02:11<00:00,  6.57s/it]

Stage 1: Reasoning:  13%|█▎        | 8/60 [08:19<1:22:34, 95.29s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [02:09<00:00,  6.49s/it]

Stage 1: Reasoning:  15%|█▌        | 9/60 [10:31<1:30:54, 106.95s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [01:46<00:00,  5.32s/it]

Stage 1: Reasoning:  17%|█▋        | 10/60 [12:21<1:29:50, 107.81s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [01:35<00:00,  4.76s/it]

Stage 1: Reasoning:  18%|█▊        | 11/60 [13:59<1:25:38, 104.86s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [01:51<00:00,  5.56s/it]

Stage 1: Reasoning:  20%|██        | 12/60 [15:54<1:26:12, 107.76s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [01:40<00:00,  5.02s/it]

Stage 1: Reasoning:  22%|██▏       | 13/60 [17:38<1:23:32, 106.65s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [01:38<00:00,  4.93s/it]

Stage 1: Reasoning:  23%|██▎       | 14/60 [19:20<1:20:47, 105.39s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [01:46<00:00,  5.35s/it]

Stage 1: Reasoning:  25%|██▌       | 15/60 [21:11<1:20:13, 106.96s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [01:45<00:00,  5.29s/it]

Stage 1: Reasoning:  27%|██▋       | 16/60 [23:01<1:19:04, 107.84s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [01:24<00:00,  4.20s/it]

Stage 1: Reasoning:  28%|██▊       | 17/60 [24:28<1:12:48, 101.59s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [01:48<00:00,  5.41s/it]

Stage 1: Reasoning:  30%|███       | 18/60 [26:19<1:13:13, 104.61s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [02:08<00:00,  6.44s/it]

Stage 1: Reasoning:  32%|███▏      | 19/60 [28:33<1:17:20, 113.19s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [01:45<00:00,  5.29s/it]

Stage 1: Reasoning:  33%|███▎      | 20/60 [30:22<1:14:47, 112.20s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [01:38<00:00,  4.94s/it]

Stage 1: Reasoning:  35%|███▌      | 21/60 [32:05<1:10:57, 109.18s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [01:38<00:00,  4.95s/it]

Stage 1: Reasoning:  37%|███▋      | 22/60 [33:47<1:07:52, 107.16s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [02:08<00:00,  6.44s/it]

Stage 1: Reasoning:  38%|███▊      | 23/60 [36:00<1:10:46, 114.78s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [01:45<00:00,  5.28s/it]

Stage 1: Reasoning:  40%|████      | 24/60 [37:49<1:07:55, 113.20s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [02:35<00:00,  7.76s/it]

Stage 1: Reasoning:  42%|████▏     | 25/60 [40:28<1:14:05, 127.02s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [01:57<00:00,  5.87s/it]

Stage 1: Reasoning:  43%|████▎     | 26/60 [42:30<1:11:05, 125.45s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [02:11<00:00,  6.56s/it]

Stage 1: Reasoning:  45%|████▌     | 27/60 [44:45<1:10:38, 128.43s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [02:02<00:00,  6.12s/it]

Stage 1: Reasoning:  47%|████▋     | 28/60 [46:51<1:08:06, 127.69s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [01:39<00:00,  4.97s/it]

Stage 1: Reasoning:  48%|████▊     | 29/60 [48:35<1:02:17, 120.58s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [01:54<00:00,  5.72s/it]

Stage 1: Reasoning:  50%|█████     | 30/60 [50:33<59:46, 119.53s/it]  


Processed 20 / 20 examples: 100%|██████████| 20/20 [02:04<00:00,  6.21s/it]

Stage 1: Reasoning:  52%|█████▏    | 31/60 [52:42<59:11, 122.47s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [02:01<00:00,  6.06s/it]

Stage 1: Reasoning:  53%|█████▎    | 32/60 [54:47<57:27, 123.13s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [01:38<00:00,  4.91s/it]

Stage 1: Reasoning:  55%|█████▌    | 33/60 [56:29<52:38, 116.98s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [02:04<00:00,  6.22s/it]

Stage 1: Reasoning:  57%|█████▋    | 34/60 [58:37<52:08, 120.33s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [01:57<00:00,  5.86s/it]

Stage 1: Reasoning:  58%|█████▊    | 35/60 [1:00:38<50:13, 120.54s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [01:49<00:00,  5.48s/it]

Stage 1: Reasoning:  60%|██████    | 36/60 [1:02:33<47:29, 118.74s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [02:04<00:00,  6.21s/it]

Stage 1: Reasoning:  62%|██████▏   | 37/60 [1:04:41<46:38, 121.69s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [01:34<00:00,  4.74s/it]

Stage 1: Reasoning:  63%|██████▎   | 38/60 [1:06:20<42:02, 114.64s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [01:52<00:00,  5.60s/it]

Stage 1: Reasoning:  65%|██████▌   | 39/60 [1:08:16<40:19, 115.21s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [02:00<00:00,  6.01s/it]

Stage 1: Reasoning:  67%|██████▋   | 40/60 [1:10:19<39:12, 117.63s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [02:54<00:00,  8.70s/it]

Stage 1: Reasoning:  68%|██████▊   | 41/60 [1:13:17<42:58, 135.71s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [02:06<00:00,  6.34s/it]

Stage 1: Reasoning:  70%|███████   | 42/60 [1:15:28<40:16, 134.26s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [02:17<00:00,  6.88s/it]

Stage 1: Reasoning:  72%|███████▏  | 43/60 [1:17:51<38:46, 136.85s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [02:08<00:00,  6.44s/it]

Stage 1: Reasoning:  73%|███████▎  | 44/60 [1:20:04<36:10, 135.68s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [02:18<00:00,  6.91s/it]

Stage 1: Reasoning:  75%|███████▌  | 45/60 [1:22:26<34:23, 137.54s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [02:17<00:00,  6.86s/it]

Stage 1: Reasoning:  77%|███████▋  | 46/60 [1:24:47<32:21, 138.65s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [02:02<00:00,  6.11s/it]

Stage 1: Reasoning:  78%|███████▊  | 47/60 [1:26:53<29:13, 134.89s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [01:59<00:00,  5.98s/it]

Stage 1: Reasoning:  80%|████████  | 48/60 [1:28:56<26:14, 131.18s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [01:41<00:00,  5.09s/it]

Stage 1: Reasoning:  82%|████████▏ | 49/60 [1:30:42<22:39, 123.63s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [02:04<00:00,  6.24s/it]

Stage 1: Reasoning:  83%|████████▎ | 50/60 [1:32:51<20:51, 125.20s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [01:44<00:00,  5.25s/it]

Stage 1: Reasoning:  85%|████████▌ | 51/60 [1:34:39<18:00, 120.04s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [01:37<00:00,  4.86s/it]

Stage 1: Reasoning:  87%|████████▋ | 52/60 [1:36:22<15:18, 114.86s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [02:11<00:00,  6.58s/it]

Stage 1: Reasoning:  88%|████████▊ | 53/60 [1:38:37<14:06, 120.90s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [01:55<00:00,  5.76s/it]

Stage 1: Reasoning:  90%|█████████ | 54/60 [1:40:36<12:02, 120.41s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [02:17<00:00,  6.87s/it]

Stage 1: Reasoning:  92%|█████████▏| 55/60 [1:42:57<10:32, 126.59s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [01:57<00:00,  5.90s/it]

Stage 1: Reasoning:  93%|█████████▎| 56/60 [1:44:59<08:21, 125.27s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [01:43<00:00,  5.18s/it]

Stage 1: Reasoning:  95%|█████████▌| 57/60 [1:46:48<06:00, 120.26s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [02:13<00:00,  6.69s/it]

Stage 1: Reasoning:  97%|█████████▋| 58/60 [1:49:05<04:11, 125.52s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [01:49<00:00,  5.47s/it]

Stage 1: Reasoning:  98%|█████████▊| 59/60 [1:50:57<02:01, 121.37s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [02:11<00:00,  6.59s/it]

Stage 1: Reasoning: 100%|██████████| 60/60 [1:53:13<00:00, 113.23s/it]





Stage 2: Labeling:   0%|          | 0/60 [00:00<?, ?it/s]

Processed 20 / 20 examples: 100%|██████████| 20/20 [00:00<00:00, 1259.25it/s]
Processed 20 / 20 examples: 100%|██████████| 20/20 [00:00<00:00, 1312.85it/s]
Processed 20 / 20 examples: 100%|██████████| 20/20 [00:00<00:00, 1525.98it/s]

Stage 2: Labeling:   5%|▌         | 3/60 [00:00<00:02, 21.94it/s]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:00<00:00, 1289.72it/s]
Processed 20 / 20 examples: 100%|██████████| 20/20 [00:00<00:00, 1968.56it/s]
Processed 20 / 20 examples: 100%|██████████| 20/20 [00:21<00:00,  1.09s/it]

Stage 2: Labeling:  10%|█         | 6/60 [00:21<03:52,  4.30s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:22<00:00,  1.12s/it]

Stage 2: Labeling:  12%|█▏        | 7/60 [00:44<07:19,  8.29s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:13<00:00,  1.44it/s]

Stage 2: Labeling:  13%|█▎        | 8/60 [00:58<08:21,  9.64s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:14<00:00,  1.39it/s]

Stage 2: Labeling:  15%|█▌        | 9/60 [01:12<09:13, 10.86s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:14<00:00,  1.38it/s]

Stage 2: Labeling:  17%|█▋        | 10/60 [01:27<09:53, 11.86s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:16<00:00,  1.21it/s]

Stage 2: Labeling:  18%|█▊        | 11/60 [01:44<10:44, 13.16s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:23<00:00,  1.19s/it]

Stage 2: Labeling:  20%|██        | 12/60 [02:08<12:57, 16.19s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:16<00:00,  1.22it/s]

Stage 2: Labeling:  22%|██▏       | 13/60 [02:24<12:43, 16.24s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:17<00:00,  1.16it/s]

Stage 2: Labeling:  23%|██▎       | 14/60 [02:41<12:41, 16.55s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:15<00:00,  1.32it/s]

Stage 2: Labeling:  25%|██▌       | 15/60 [02:56<12:06, 16.15s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:16<00:00,  1.19it/s]

Stage 2: Labeling:  27%|██▋       | 16/60 [03:13<11:58, 16.33s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:19<00:00,  1.03it/s]

Stage 2: Labeling:  28%|██▊       | 17/60 [03:32<12:20, 17.23s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:23<00:00,  1.19s/it]

Stage 2: Labeling:  30%|███       | 18/60 [03:56<13:25, 19.17s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:23<00:00,  1.15s/it]

Stage 2: Labeling:  32%|███▏      | 19/60 [04:19<13:54, 20.34s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:23<00:00,  1.16s/it]

Stage 2: Labeling:  33%|███▎      | 20/60 [04:43<14:08, 21.21s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:16<00:00,  1.23it/s]

Stage 2: Labeling:  35%|███▌      | 21/60 [04:59<12:49, 19.74s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:18<00:00,  1.11it/s]

Stage 2: Labeling:  37%|███▋      | 22/60 [05:17<12:10, 19.22s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:19<00:00,  1.00it/s]

Stage 2: Labeling:  38%|███▊      | 23/60 [05:37<11:59, 19.45s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:18<00:00,  1.09it/s]

Stage 2: Labeling:  40%|████      | 24/60 [05:55<11:28, 19.12s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:16<00:00,  1.24it/s]

Stage 2: Labeling:  42%|████▏     | 25/60 [06:11<10:37, 18.22s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:17<00:00,  1.14it/s]

Stage 2: Labeling:  43%|████▎     | 26/60 [06:29<10:12, 18.02s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:23<00:00,  1.19s/it]

Stage 2: Labeling:  45%|████▌     | 27/60 [06:53<10:52, 19.76s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:17<00:00,  1.17it/s]

Stage 2: Labeling:  47%|████▋     | 28/60 [07:10<10:06, 18.97s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:16<00:00,  1.19it/s]

Stage 2: Labeling:  48%|████▊     | 29/60 [07:27<09:28, 18.33s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:21<00:00,  1.08s/it]

Stage 2: Labeling:  50%|█████     | 30/60 [07:48<09:39, 19.31s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:15<00:00,  1.31it/s]

Stage 2: Labeling:  52%|█████▏    | 31/60 [08:04<08:44, 18.09s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:14<00:00,  1.42it/s]

Stage 2: Labeling:  53%|█████▎    | 32/60 [08:18<07:53, 16.89s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:16<00:00,  1.19it/s]

Stage 2: Labeling:  55%|█████▌    | 33/60 [08:34<07:35, 16.89s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:17<00:00,  1.17it/s]

Stage 2: Labeling:  57%|█████▋    | 34/60 [08:52<07:21, 16.97s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:16<00:00,  1.21it/s]

Stage 2: Labeling:  58%|█████▊    | 35/60 [09:08<07:01, 16.84s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:16<00:00,  1.21it/s]

Stage 2: Labeling:  60%|██████    | 36/60 [09:25<06:42, 16.76s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:26<00:00,  1.30s/it]

Stage 2: Labeling:  62%|██████▏   | 37/60 [09:51<07:29, 19.55s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:15<00:00,  1.29it/s]

Stage 2: Labeling:  63%|██████▎   | 38/60 [10:06<06:43, 18.35s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:18<00:00,  1.07it/s]

Stage 2: Labeling:  65%|██████▌   | 39/60 [10:25<06:27, 18.45s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:15<00:00,  1.32it/s]

Stage 2: Labeling:  67%|██████▋   | 40/60 [10:40<05:49, 17.47s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:14<00:00,  1.36it/s]

Stage 2: Labeling:  68%|██████▊   | 41/60 [10:55<05:16, 16.65s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:18<00:00,  1.09it/s]

Stage 2: Labeling:  70%|███████   | 42/60 [11:13<05:08, 17.15s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:18<00:00,  1.08it/s]

Stage 2: Labeling:  72%|███████▏  | 43/60 [11:32<04:58, 17.57s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:13<00:00,  1.48it/s]

Stage 2: Labeling:  73%|███████▎  | 44/60 [11:45<04:21, 16.37s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:15<00:00,  1.29it/s]

Stage 2: Labeling:  75%|███████▌  | 45/60 [12:01<04:01, 16.12s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:17<00:00,  1.11it/s]

Stage 2: Labeling:  77%|███████▋  | 46/60 [12:19<03:53, 16.68s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:20<00:00,  1.04s/it]

Stage 2: Labeling:  78%|███████▊  | 47/60 [12:40<03:52, 17.90s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:17<00:00,  1.11it/s]

Stage 2: Labeling:  80%|████████  | 48/60 [12:58<03:35, 17.93s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:19<00:00,  1.03it/s]

Stage 2: Labeling:  82%|████████▏ | 49/60 [13:17<03:22, 18.40s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:14<00:00,  1.38it/s]

Stage 2: Labeling:  83%|████████▎ | 50/60 [13:32<02:52, 17.23s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:22<00:00,  1.12s/it]

Stage 2: Labeling:  85%|████████▌ | 51/60 [13:54<02:48, 18.78s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:20<00:00,  1.00s/it]

Stage 2: Labeling:  87%|████████▋ | 52/60 [14:14<02:33, 19.16s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:16<00:00,  1.24it/s]

Stage 2: Labeling:  88%|████████▊ | 53/60 [14:30<02:07, 18.28s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:22<00:00,  1.10s/it]

Stage 2: Labeling:  90%|█████████ | 54/60 [14:52<01:56, 19.43s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:17<00:00,  1.15it/s]

Stage 2: Labeling:  92%|█████████▏| 55/60 [15:10<01:34, 18.83s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:18<00:00,  1.09it/s]

Stage 2: Labeling:  93%|█████████▎| 56/60 [15:28<01:14, 18.67s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:17<00:00,  1.14it/s]

Stage 2: Labeling:  95%|█████████▌| 57/60 [15:46<00:54, 18.31s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:20<00:00,  1.02s/it]

Stage 2: Labeling:  97%|█████████▋| 58/60 [16:06<00:37, 18.93s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:17<00:00,  1.12it/s]

Stage 2: Labeling:  98%|█████████▊| 59/60 [16:24<00:18, 18.61s/it]


Processed 20 / 20 examples: 100%|██████████| 20/20 [00:14<00:00,  1.38it/s]

Stage 2: Labeling: 100%|██████████| 60/60 [16:38<00:00, 16.65s/it]







Step 5: display the Metrics results

In [61]:
# Translate the textual labels produced by the models into the integer ids
# that the gold labels are compared against.
label2id = {"entailment": 0, "neutral": 1, "contradiction": 2}

# (printed caption, key in the dict returned by compute_metrics), in display order.
score_rows = (
    ("F1 score", "f1"),
    ("Accuracy", "accuracy"),
    ("Precision", "precision"),
    ("Recall", "recall"),
)

# ---------------------------- JOINT model ----------------------------
gold_labels_joint = [example.label for example in preprocessed_examples_for_joint[:NEW_TEST_SIZE]]
pred_labels_joint = [label2id[prediction.label] for prediction in results_joint]
metrics_joint = compute_metrics(pred_labels_joint, gold_labels_joint)

# Report the JOINT scores between two banner lines.
print("\n@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@")
print("JOINT Model scores:")
for caption, key in score_rows:
    print(f"{caption}: {metrics_joint[key]:.4f}")
print("@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@\n")

#######################################################################

# --------------------------- PIPELINE model ---------------------------
gold_labels_pipeline = [example.label for example in preprocessed_examples_for_pipeline[:NEW_TEST_SIZE]]
pred_labels_pipeline = [label2id[prediction.label] for prediction in results_pipeline]
metrics_pipeline = compute_metrics(pred_labels_pipeline, gold_labels_pipeline)

# Report the PIPELINE scores between two banner lines.
print("@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@")
print("PIPELINE Model scores:")
for caption, key in score_rows:
    print(f"{caption}: {metrics_pipeline[key]:.4f}")
print("@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@")



@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
JOINT Model scores:
F1 score: 0.6967
Accuracy: 0.6917
Precision: 0.7223
Recall: 0.6916
@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@

@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
PIPELINE Model scores:
F1 score: 0.7004
Accuracy: 0.7017
Precision: 0.6996
Recall: 0.7017
@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@


Step 6: define a function to calculate and present the Semantic similarity scores across the 3 passages and two models

In [62]:
def present_semantic_similarity(pairs: list[dspy.Example], predictions: list[dspy.Example]) -> None:
    """
    Compute and print semantic-similarity statistics between reference texts and reasons.

    Each element of ``pairs`` is matched positionally with the element of
    ``predictions`` at the same index, and ``semantic_similarity`` is called on
    (reference text, ``prediction.reason``). The reference text is chosen once,
    from the first element of ``pairs``:

    * if ``len(pairs[0]) == 1`` the examples are treated as reason-only and the
      reference text is ``pair.reason`` (used for the human-reason comparison);
    * otherwise the reference text is ``"<premise>, <hypothesis>"``.

    The function prints how many scores reach the notebook-level ``threshold``
    and how many reach the average score; it returns nothing.

    Args:
        pairs: Reference examples; either reason-only or carrying 'premise' and 'hypothesis'.
        predictions: Objects exposing a ``reason`` attribute, one per element of ``pairs``.

    Raises:
        AssertionError: If ``pairs`` and ``predictions`` differ in length.
    """
    assert len(pairs) == len(predictions), "Input and output lists must have the same length."

    # Guard against empty input: indexing pairs[0] and dividing by len(pairs)
    # below would otherwise raise IndexError / ZeroDivisionError.
    if not pairs:
        print("No pairs to compare; skipping semantic similarity statistics.")
        return

    total = len(pairs)

    # The mode is decided from the first example only, so all examples in
    # `pairs` are expected to share the same fields.
    reason_only = len(pairs[0]) == 1

    scores = []
    for pair, pred in zip(pairs, predictions):
        if reason_only:
            reference_text = f"{pair.reason}"
        else:
            reference_text = f"{pair.premise}, {pair.hypothesis}"
        scores.append(semantic_similarity(reference_text, pred.reason))

    # `threshold` is read from the notebook's global namespace.
    above_threshold_count = sum(score >= threshold for score in scores)
    print(f"the threshold is: {threshold}")
    print(f"Number of pairs with score above threshold ({threshold}): {above_threshold_count} out of {total}")
    print(f"Percentage of pairs with score above threshold:{above_threshold_count / total * 100:.2f}%")

    average_score = sum(scores) / len(scores)
    print(f"Average semantic similarity score for model: {average_score:.4f}")
    above_average_count = sum(score >= average_score for score in scores)
    print(f"Number of pairs with score above average ({average_score:.4f}): {above_average_count} out of {total}")
    print(f"Percentage of pairs with score above average:{above_average_count / total * 100:.2f}%")


Step 7: Perform the semantic similarity calculation on the three passages

In [63]:
# Human-written reasons for the first NEW_TEST_SIZE rows of dataset['dev_r3'].
# NOTE(review): confirm results_joint / results_pipeline were produced on this
# same split and in the same order, otherwise passages 1 and 3 compare unrelated rows.
human_reasons = [row["reason"] for row in dataset['dev_r3']][:NEW_TEST_SIZE]

# The same reasons wrapped twice: as reference examples (passage 1) and as
# stand-in "predictions" (passage 3).
human_examples = [dspy.Example(reason=text) for text in human_reasons]
human_predictions = [dspy.Prediction(reason=text) for text in human_reasons]

# Each entry: (banner to print, reference examples, objects carrying the reason to score).
comparisons = [
    ######## PASSAGE 1 ##########
    # HUMAN reason vs. LLM reason
    (
        "\n\n@@@ Calculating semantic similarity of HUMAN reason vs JOINT LLM reason...@@@",
        human_examples,
        results_joint,
    ),
    (
        "\n\n@@@ Calculating semantic similarity for the HUMAN reason vs PIPELINE LLM reason...@@@",
        human_examples,
        results_pipeline,
    ),
    ######## PASSAGE 2 ##########
    # (premise, hypothesis) pairs vs. the reasons generated by each model
    (
        "\n\n@@@ Calculating semantic similarity for the (premise, hypothesis) pairs vs JOINT results...@@@",
        preprocessed_joint_no_reason[:NEW_TEST_SIZE],
        results_joint,
    ),
    (
        "\n\n@@@ Calculating semantic similarity for the (premise, hypothesis) pairs vs PIPELINE results...@@@",
        preprocessed_examples_for_pipeline_no_reason[:NEW_TEST_SIZE],
        results_pipeline,
    ),
    ######## PASSAGE 3 ##########
    # (premise, hypothesis) pairs vs. the human-written reasons.
    # The joint examples are used here; the original author notes that the
    # joint and pipeline example lists hold the same data for this purpose.
    (
        "\n\n@@@ Calculating semantic similarity for the (premise, hypothesis) pairs vs HUMAN REASON results... @@@",
        preprocessed_examples_for_joint[:NEW_TEST_SIZE],
        human_predictions,
    ),
]

for banner, reference_examples, scored_reasons in comparisons:
    print(banner)
    present_semantic_similarity(reference_examples, scored_reasons)





@@@ Calculating semantic similarity of HUMAN reason vs JOINT LLM reason...@@@


  return forward_call(*args, **kwargs)


the threshold is: 0.6
Number of pairs with score above threshold (0.6): 559 out of 1200
Percentage of pairs with score above threshold:46.58%
Average semantic similarity score for model: 0.5569
Number of pairs with score above average (0.5569): 649 out of 1200
Percentage of pairs with score above average:54.08%


@@@ Calculating semantic similarity for the HUMAN reason vs PIPELINE LLM reason...@@@
the threshold is: 0.6
Number of pairs with score above threshold (0.6): 505 out of 1200
Percentage of pairs with score above threshold:42.08%
Average semantic similarity score for model: 0.5368
Number of pairs with score above average (0.5368): 649 out of 1200
Percentage of pairs with score above average:54.08%


@@@ Calculating semantic similarity for the (premise, hypothesis) pairs vs JOINT results...@@@
the threshold is: 0.6
Number of pairs with score above threshold (0.6): 720 out of 1200
Percentage of pairs with score above threshold:60.00%
Average semantic similarity score for model: 0.

Step 8: Compute the agreement between the models
NOTE: MODEL 1 = JOINT , MODEL 2 = PIPELINE

In [64]:
# Agreement between the two LLM programs on the evaluated samples.
# MODEL 1 == JOINT
# MODEL 2 == PIPELINE
gold_labels = [row.label for row in preprocessed_examples_for_joint] # no importance, both preprocessed contain the same data
TEST_SIZE = len(results_joint)

# Per-sample correctness flags, computed once and reused for all four buckets.
# Predictions carry textual labels, gold labels are integer ids, hence label2id.
joint_is_correct = [label2id[results_joint[i].label] == gold_labels[i] for i in range(TEST_SIZE)]
pipeline_is_correct = [label2id[results_pipeline[i].label] == gold_labels[i] for i in range(TEST_SIZE)]
flag_pairs = list(zip(joint_is_correct, pipeline_is_correct))

# on how many samples both models are correct
correct_predictions = sum(1 for joint_ok, pipeline_ok in flag_pairs if joint_ok and pipeline_ok)
print(f"Both models are correct on {correct_predictions} out of {TEST_SIZE} samples.")
print(f"Both models are correct on {correct_predictions / TEST_SIZE * 100:.2f}% of the samples.")
print("\n")

# On how many samples JOINT is correct and PIPELINE is incorrect
only_joint_correct = sum(1 for joint_ok, pipeline_ok in flag_pairs if joint_ok and not pipeline_ok)
print(f"only JOINT correct on {only_joint_correct} out of {TEST_SIZE} samples.")
print(f"only JOINT correct on {only_joint_correct / TEST_SIZE * 100:.2f}% of the samples.")
print("\n")


# On how many samples PIPELINE is correct and JOINT is incorrect
only_pipeline_correct = sum(1 for joint_ok, pipeline_ok in flag_pairs if pipeline_ok and not joint_ok)
print(f"only PIPELINE correct on {only_pipeline_correct} out of {TEST_SIZE} samples.")
print(f"only PIPELINE correct on {only_pipeline_correct / TEST_SIZE * 100:.2f}% of the samples.")
print("\n")


# on how many samples both models are incorrect
both_wrong = sum(1 for joint_ok, pipeline_ok in flag_pairs if not joint_ok and not pipeline_ok)
print(f"both WRONG on {both_wrong} out of {TEST_SIZE} samples.")
print(f"both WRONG on {both_wrong / TEST_SIZE * 100:.2f}% of the samples.")
print("\n")

# The four buckets partition the evaluated samples.
assert correct_predictions + only_joint_correct + only_pipeline_correct + both_wrong == TEST_SIZE

Both models are correct on 724 out of 1200 samples.
Both models are correct on 60.33% of the samples.


only JOINT correct on 106 out of 1200 samples.
only JOINT correct on 8.83% of the samples.


only PIPELINE correct on 118 out of 1200 samples.
only PIPELINE correct on 9.83% of the samples.


both WRONG on 252 out of 1200 samples.
both WRONG on 21.00% of the samples.




Step 9: Compute the agreement between DeBERTa_v3 and each of the LLM models (JOINT and PIPELINE)

In [None]:
# DeBERTa_v3 vs JOINT
%store -r pred_test_r3
optimized_llm_predictions = results_joint
non_llm_predictions = pred_test_r3 # DeBERTa_v3_predictions
TEST_SIZE = len(optimized_llm_predictions)
# gold_labels = [label2id[row['gold_label']] for row in pred_test_r3]


# on how many samples both models are correct
correct_predictions = sum(
    1 for row, llm_pred in zip(non_llm_predictions,optimized_llm_predictions)
    if (llm_pred.label == row['gold_label']) and (row['gold_label'] == row['pred_label'])
)

print(f"Both models are correct on {correct_predictions} out of {TEST_SIZE} samples.")
print(f"Both models are correct on {correct_predictions / TEST_SIZE * 100:.2f}% of the samples.")
print("\n")

# On how many samples llm is correct and DeBERTa_v3_ is incorrect
llm_correct_deberta_incorrect = sum(
    1 for row, llm_pred in zip(non_llm_predictions,optimized_llm_predictions)
    if (llm_pred.label == row['gold_label']) and (row['gold_label'] != row['pred_label'])
)
print(f"LLM JOINT is correct and DeBERTa_v3 is incorrect on {llm_correct_deberta_incorrect} out of {TEST_SIZE} samples.")
print(f"LLM JOINT is correct and DeBERTa_v3 is incorrect on {llm_correct_deberta_incorrect / TEST_SIZE * 100:.2f}% of the samples.")
print("\n")

# On how many samples DeBERTa_v3 is correct and llm is incorrect
deberta_correct_llm_incorrect = sum(
    1 for row, llm_pred in zip(non_llm_predictions,optimized_llm_predictions)
    if (row['pred_label'] == row['gold_label']) and (llm_pred.label != row['gold_label'])
)
print(f"DeBERTa_v3 is correct and LLM JOINT is incorrect on {deberta_correct_llm_incorrect} out of {TEST_SIZE} samples.")
print(f"DeBERTa_v3 is correct and LLM JOINT is incorrect on {deberta_correct_llm_incorrect / TEST_SIZE * 100:.2f}% of the samples.")
print("\n")


# on how many samples both models are incorrect
both_incorrect = sum(
    1 for row, llm_pred in zip(non_llm_predictions,optimized_llm_predictions)
    if (llm_pred.label != row['gold_label']) and (row['pred_label'] != row['gold_label'])
)
print(f"Both models are incorrect on {both_incorrect} out of {TEST_SIZE} samples.")
print(f"Both models are incorrect on {both_incorrect / TEST_SIZE * 100:.2f}% of the samples.")
print("\n")

Both models are correct on 282 out of 1200 samples.
Both models are correct on 23.50% of the samples.


LLM JOINT is correct and DeBERTa_v3 is incorrect on 357 out of 1200 samples.
LLM JOINT is correct and DeBERTa_v3 is incorrect on 29.75% of the samples.


DeBERTa_v3 is correct and LLM JOINT is incorrect on 295 out of 1200 samples.
DeBERTa_v3 is correct and LLM JOINT is incorrect on 24.58% of the samples.


Both models are incorrect on 266 out of 1200 samples.
Both models are incorrect on 22.17% of the samples.




In [66]:
# DeBERTa_v3 vs PIPELINE
%store -r pred_test_r3
optimized_llm_predictions = results_pipeline
non_llm_predictions = pred_test_r3 # DeBERTa_v3_predictions
TEST_SIZE = len(optimized_llm_predictions)
# gold_labels = [label2id[row['gold_label']] for row in pred_test_r3]


# on how many samples both models are correct
correct_predictions = sum(
    1 for row, llm_pred in zip(non_llm_predictions,optimized_llm_predictions)
    if (llm_pred.label == row['gold_label']) and (row['gold_label'] == row['pred_label'])
)

print(f"Both models are correct on {correct_predictions} out of {TEST_SIZE} samples.")
print(f"Both models are correct on {correct_predictions / TEST_SIZE * 100:.2f}% of the samples.")
print("\n")

# On how many samples llm is correct and DeBERTa_v3_ is incorrect
llm_correct_deberta_incorrect = sum(
    1 for row, llm_pred in zip(non_llm_predictions,optimized_llm_predictions)
    if (llm_pred.label == row['gold_label']) and (row['gold_label'] != row['pred_label'])
)
print(f"LLM PIPELINE is correct and DeBERTa_v3 is incorrect on {llm_correct_deberta_incorrect} out of {TEST_SIZE} samples.")
print(f"LLM PIPELINE is correct and DeBERTa_v3 is incorrect on {llm_correct_deberta_incorrect / TEST_SIZE * 100:.2f}% of the samples.")
print("\n")

# On how many samples DeBERTa_v3 is correct and llm is incorrect
deberta_correct_llm_incorrect = sum(
    1 for row, llm_pred in zip(non_llm_predictions,optimized_llm_predictions)
    if (row['pred_label'] == row['gold_label']) and (llm_pred.label != row['gold_label'])
)
print(f"DeBERTa_v3 is correct and LLM PIPELINE is incorrect on {deberta_correct_llm_incorrect} out of {TEST_SIZE} samples.")
print(f"DeBERTa_v3 is correct and LLM PIPELINE is incorrect on {deberta_correct_llm_incorrect / TEST_SIZE * 100:.2f}% of the samples.")
print("\n")


# on how many samples both models are incorrect
both_incorrect = sum(
    1 for row, llm_pred in zip(non_llm_predictions,optimized_llm_predictions)
    if (llm_pred.label != row['gold_label']) and (row['pred_label'] != row['gold_label'])
)
print(f"Both models are incorrect on {both_incorrect} out of {TEST_SIZE} samples.")
print(f"Both models are incorrect on {both_incorrect / TEST_SIZE * 100:.2f}% of the samples.")
print("\n")

Both models are correct on 321 out of 1200 samples.
Both models are correct on 26.75% of the samples.


LLM PIPELINE is correct and DeBERTa_v3 is incorrect on 334 out of 1200 samples.
LLM PIPELINE is correct and DeBERTa_v3 is incorrect on 27.83% of the samples.


DeBERTa_v3 is correct and LLM PIPELINE is incorrect on 256 out of 1200 samples.
DeBERTa_v3 is correct and LLM PIPELINE is incorrect on 21.33% of the samples.


Both models are incorrect on 289 out of 1200 samples.
Both models are incorrect on 24.08% of the samples.


