# SummEval example

In this notebook, we use one summarization example from SummEval to demonstrate how to use PairS.

In [1]:
import sys
# Make the project's helper modules (imported in later cells) importable.
# The membership check keeps the cell idempotent: re-running it does not
# pile up duplicate entries on sys.path.
SCRIPTS_DIR = '../scripts'
if SCRIPTS_DIR not in sys.path:
    sys.path.append(SCRIPTS_DIR)


In [2]:
from utils import shuffle_lists, calculate_correlation, load_newsroom, load_summEval, calculate_uncertainty, load_sf_data, CompareResultObject, insert_index_to_anchors


# Location of the paired SummEval annotation file, relative to this notebook.
summ_eval_path = '../data/SummEval/model_annotations.aligned.paired.jsonl'

# Load the dataset un-flattened so it can be indexed per document below.
input_doc, output_doc, scores_by_aspect = load_summEval(
    summ_eval_path,
    flat_output=False,
)
# Only the 'coherence' scores are used in this example.
scores_doc = scores_by_aspect['coherence']

# Select a single document and take its entry from each collection.
# NOTE(review): `input` shadows the Python builtin; later cells rely on this
# name, so it is kept as-is here.
doc_id = 42
input, output, scores = (
    per_doc[doc_id] for per_doc in (input_doc, output_doc, scores_doc)
)
print('Number of summary candidates:', len(output))

Number of summary candidates: 16


## PairS-greedy

In [1]:
from transformers import LlamaForCausalLM, AutoTokenizer
import torch
# Use the GPU when one is available; otherwise fall back to the CPU so the
# cell still runs (slowly) on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Hugging Face checkpoint to load. Swap in 'meta-llama/Meta-Llama-3-8B' to use
# the base (non-instruct) model instead.
model_name = 'meta-llama/Meta-Llama-3-8B-Instruct'
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Load the weights in bfloat16 and place the whole model on `device`.
model = LlamaForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map=device,
)


  from .autonotebook import tqdm as notebook_tqdm
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Downloading shards: 100%|██████████| 4/4 [1:09:33<00:00, 1043.27s/it]
Loading checkpoint shards: 100%|██████████| 4/4 [00:06<00:00,  1.60s/it]


In [3]:
from tqdm import tqdm
import numpy as np

# Meta-parameters for the PairS-greedy run. This dict is handed to
# merge_sort_indices; a 'progress_bar' entry is added to it in a later cell.
params = dict(
    dataset='SummEval',
    engine="mistralai/Mistral-7B-Instruct-v0.1",
    aspect='coherence',
    eval_method='pairwise comparison',
    confidence_beam=False,  # False for PairS-greedy search
    # Left unset for the greedy search (the PairS-beam section sets them):
    # beam_size=2000,
    # prob_gap=0.1,
    api_call=0,
    with_input=True,
    compare_log={},
    calibration=False,
)


In [4]:
from sorting import merge_sort_indices, merge_sort
import random

# Fix the seed of Python's `random` module for reproducibility
# (presumably consumed by shuffle_lists -- TODO confirm).
random.seed(42)

# Set the progress bar. The total is only an estimate: n^2 steps for the
# beam search, n * log2(n) for the greedy search, where n is the number of
# candidates.
n_candidates = len(input)
if params['confidence_beam']:
    total_steps = int(n_candidates ** 2)
else:
    total_steps = int(n_candidates * np.log2(n_candidates))
params['progress_bar'] = tqdm(total=total_steps, desc='Processing')

try:
    # Shuffle the input, output, and scores together.
    # NOTE(review): this rebinds the three names to the shuffled lists, so
    # re-running the cell shuffles already-shuffled data.
    input, output, scores = shuffle_lists(input, output, scores)

    # Perform the PairS-greedy ranking
    # Please note: All prompts are saved in /scripts/prompts.py
    ranking_indices = merge_sort_indices(input, output, params)
finally:
    # Always close the bar, even if shuffling or sorting raises, so a stale
    # progress bar does not linger in the notebook output.
    params['progress_bar'].close()

  from .autonotebook import tqdm as notebook_tqdm
Loading checkpoint shards: 100%|██████████| 2/2 [00:03<00:00,  1.88s/it]
Processing:  58%|█████▊    | 37/64 [00:16<00:12,  2.21it/s]


In [9]:
# Calculate the correlation between the scores, reordered by the PairS-greedy
# ranking, and the positions 0..n-1.
ranked_scores = np.array(scores)[ranking_indices]
reference_positions = list(range(len(scores)))
spearman_corr, kendall_tau = calculate_correlation(ranked_scores, reference_positions)


Spearmans correlation: 0.119
Kendall tau: 0.104


## PairS-beam

In [10]:
from tqdm import tqdm
import numpy as np

# Meta-parameters for the PairS-beam run. This dict is handed to
# merge_sort_indices; a 'progress_bar' entry is added to it in a later cell.
params = dict(
    dataset='SummEval',
    engine="mistralai/Mistral-7B-Instruct-v0.1",
    aspect='coherence',
    eval_method='pairwise comparison',
    confidence_beam=True,  # True for PairS-beam search
    beam_size=2000,
    api_call=0,
    prob_gap=0.1,
    with_input=True,
    compare_log={},
    calibration=False,
)


In [13]:
from sorting import merge_sort_indices, merge_sort
import random

# Fix the seed of Python's `random` module for reproducibility
# (presumably consumed by shuffle_lists -- TODO confirm).
random.seed(42)

# Set the progress bar. The total is only an estimate: n^2 steps for the
# beam search, n * log2(n) for the greedy search, where n is the number of
# candidates.
n_candidates = len(input)
if params['confidence_beam']:
    total_steps = int(n_candidates ** 2)
else:
    total_steps = int(n_candidates * np.log2(n_candidates))
params['progress_bar'] = tqdm(total=total_steps, desc='Processing')

try:
    # Shuffle the input, output, and scores together.
    # NOTE(review): this rebinds the three names to the shuffled lists, so
    # the data shuffled here is whatever order an earlier run of a shuffle
    # cell left behind, not the originally loaded order.
    input, output, scores = shuffle_lists(input, output, scores)

    # Perform the PairS-beam ranking
    # Please note: All prompts are saved in /scripts/prompts.py
    ranking_indices = merge_sort_indices(input, output, params)
finally:
    # Always close the bar, even if shuffling or sorting raises, so a stale
    # progress bar does not linger in the notebook output.
    params['progress_bar'].close()

Processing:  25%|██▌       | 64/256 [00:19<00:58,  3.26it/s]


In [14]:
# Calculate the correlation between the scores, reordered by the PairS-beam
# ranking, and the positions 0..n-1.
ranked_scores = np.array(scores)[ranking_indices]
reference_positions = list(range(len(scores)))
spearman_corr, kendall_tau = calculate_correlation(ranked_scores, reference_positions)


Spearmans correlation: 0.326
Kendall tau: 0.261
