In [1]:
%pip install -q pysbd

Note: you may need to restart the kernel to use updated packages.


In [2]:
import warnings

import pysbd
import torch
from datasets import concatenate_datasets, load_dataset
from tqdm import tqdm
from transformers import T5ForConditionalGeneration, T5Tokenizer

warnings.filterwarnings('ignore')

2025-06-19 21:19:39.467314: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1750367979.491189    8416 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1750367979.498650    8416 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [None]:
subreddits = ["askphysics"]

splits = ["train", "validation", "test"]

all_data = {split: [] for split in splits}
for subreddit in subreddits:
    data = load_dataset("stanfordnlp/shp", data_dir=subreddit)
    for split in splits:
        all_data[split].append(data[split])

final_dataset = {
    split: concatenate_datasets(all_data[split]) for split in splits
}

train_dataset = final_dataset["train"]
val_dataset = final_dataset["validation"]
test_dataset = final_dataset["test"]


In [4]:
device = 'cuda'

tokenizer = T5Tokenizer.from_pretrained('stanfordnlp/SteamSHP-flan-t5-xl', verbose=False)

model = T5ForConditionalGeneration.from_pretrained(
    'stanfordnlp/SteamSHP-flan-t5-xl',
    device_map=device
).to(device)
model.eval()

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

T5ForConditionalGeneration(
  (shared): Embedding(32128, 2048)
  (encoder): T5Stack(
    (embed_tokens): Embedding(32128, 2048)
    (block): ModuleList(
      (0): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=2048, out_features=2048, bias=False)
              (k): Linear(in_features=2048, out_features=2048, bias=False)
              (v): Linear(in_features=2048, out_features=2048, bias=False)
              (o): Linear(in_features=2048, out_features=2048, bias=False)
              (relative_attention_bias): Embedding(32, 32)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerFF(
            (DenseReluDense): T5DenseGatedActDense(
              (wi_0): Linear(in_features=2048, out_features=5120, bias=False)
              (wi_1): Linear(in_features=2048, out_features=5120, bias=False)
       

In [5]:
segmenter = pysbd.Segmenter(language="en", clean=False)

def create_full_input(prompt, response_a, response_b=None):
    return (
        f"POST: {prompt.strip()}\n\n"
        f"RESPONSE A: {response_a.strip()}\n\n"
        f"RESPONSE B: {response_b.strip() if response_b else '.'}\n\n"
        f"Which response is better? RESPONSE"
    )

def prepare_input_with_pysbd(tokenizer, prompt, response_a, response_b=None, max_tokens=512):
    prompt_sentences = segmenter.segment(prompt.strip())
    response_b = response_b if response_b else "."

    for start in range(len(prompt_sentences)):
        truncated_prompt = " ".join(prompt_sentences[start:])
        full_input = create_full_input(truncated_prompt, response_a, response_b)
        input_ids = tokenizer(full_input, return_tensors="pt").input_ids
        if input_ids.shape[1] <= max_tokens:
            return full_input

    fallback_input = create_full_input(prompt, response_a, response_b)
    encoded = tokenizer(
        fallback_input,
        truncation=True,
        max_length=max_tokens,
        return_tensors="pt"
    )
    return tokenizer.decode(encoded["input_ids"][0], skip_special_tokens=True)


In [6]:
def score_steam_likelihood(model, tokenizer, prompt, response):
    input_text = prepare_input_with_pysbd(tokenizer, prompt, response, max_tokens=512)
    
    inputs = tokenizer(
        [input_text], 
        return_tensors='pt',
        truncation=True,
        max_length=512
    ).to(model.device)
    
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            return_dict_in_generate=True, 
            output_scores=True, 
            max_new_tokens=1
        )
        
    # Index 71 corresponds to the token for 'A'
    score = torch.exp(outputs.scores[0][:, 71]) / torch.exp(outputs.scores[0][:,:]).sum(axis=1).item()
    return score


In [7]:
def score_pairwise(prompt, answer_a, answer_b):
    score_a = score_steam_likelihood(model, tokenizer, prompt, answer_a)
    score_b = score_steam_likelihood(model, tokenizer, prompt, answer_b)

    return 1 if score_a > score_b else 0


In [None]:
def eval_model(num_eval):
    print(f"Evaluating model on {num_eval} samples")

    samples = val_dataset.select(range(num_eval))

    correct = 0
    invalid_predictions = 0

    for ex in tqdm(samples):
        pred = score_pairwise(
            ex["history"],
            ex["human_ref_A"],
            ex["human_ref_B"]
        )

        if pred is None:
            print("Warning: Received None prediction")
            invalid_predictions += 1
            continue

        correct += int(pred == ex["labels"])

    valid_samples = num_eval - invalid_predictions
    if valid_samples > 0:
        accuracy = correct / valid_samples
        print("\nResults:")
        print(f"Correct predictions: {correct}")
        print(f"Valid samples evaluated: {valid_samples}/{num_eval}")
        print(f"Pairwise zero-shot accuracy: {accuracy:.4f}")
    else:
        print("No valid predictions were made.")


In [9]:
eval_model(100)

Evaluating model on 100 samples


  0%|          | 0/100 [00:00<?, ?it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (553 > 512). Running this sequence through the model will result in indexing errors
100%|██████████| 100/100 [01:35<00:00,  1.05it/s]


Results:
Correct predictions: 84
Valid samples evaluated: 100/100
Pairwise zero-shot accuracy: 0.8400



