# Improve the dataset used for SBERT finetuning

The dataset used for finetuning the ModernBERT model for similarity looks strange: its similarity values are very coarse and opinionated. We can do better by scoring each sentence pair with a (very good) reranking model and using those scores as the similarity values.

It turns out that this also improves the performance of the resulting finetuned model!

In [1]:
import torch
from transformers import AutoModel, AutoTokenizer, AutoModelForCausalLM

The reranker is an instruction-following model, so each sentence pair has to be wrapped in a prompt:

In [2]:
def format_instruction(instruction, query, doc):
    output = "<Instruct>: {instruction}\n<Query>: {query}\n<Document>: {doc}".format(instruction=instruction,query=query, doc=doc)
    return output
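To make the template concrete, here is a small sketch that formats one pair from the dataset. The helper is repeated (as an f-string) only so that the snippet runs on its own; the instruction text is the one used further down in this notebook.

```python
def format_instruction(instruction, query, doc):
    # Same template as above, repeated so the snippet is self-contained.
    return f"<Instruct>: {instruction}\n<Query>: {query}\n<Document>: {doc}"

prompt = format_instruction(
    "Given two sentences, calculate their similarity",
    "Children smiling and waving at camera",
    "There are children present",
)
# The prompt has three lines: instruction, query, document.
print(prompt)
```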

In [3]:
def process_inputs(pairs):
    inputs = tokenizer(
        pairs, padding=False, truncation='longest_first',
        return_attention_mask=False, max_length=max_length - len(prefix_tokens) - len(suffix_tokens)
    )
    for i, ele in enumerate(inputs['input_ids']):
        inputs['input_ids'][i] = prefix_tokens + ele + suffix_tokens
    inputs = tokenizer.pad(inputs, padding=True, return_tensors="pt", max_length=max_length)
    for key in inputs:
        inputs[key] = inputs[key].to(model.device)
    return inputs
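What `process_inputs` does to the token ids can be illustrated without a tokenizer. The following is only a toy sketch: the ids, `pad_id` and the helper name `wrap_and_pad` are made up for illustration, and the real work is done by the Hugging Face tokenizer. It shows the three steps: truncate to the remaining budget, wrap with prefix and suffix, and pad on the left.

```python
# Toy stand-ins for token ids; the real values come from the Qwen tokenizer.
prefix_tokens = [1, 2]
suffix_tokens = [9]
max_length = 8
pad_id = 0  # hypothetical padding id, for illustration only

def wrap_and_pad(sequences):
    # Leave room for the prefix and suffix inside max_length.
    budget = max_length - len(prefix_tokens) - len(suffix_tokens)
    wrapped = [prefix_tokens + seq[:budget] + suffix_tokens for seq in sequences]
    width = max(len(seq) for seq in wrapped)
    # Left padding keeps the last real token in the final position of every row.
    input_ids = [[pad_id] * (width - len(seq)) + seq for seq in wrapped]
    attention_mask = [[0] * (width - len(seq)) + [1] * len(seq) for seq in wrapped]
    return input_ids, attention_mask

input_ids, attention_mask = wrap_and_pad([[5, 6, 7], [5, 6, 7, 8, 5, 6, 7, 8]])
for ids, mask in zip(input_ids, attention_mask):
    print(ids, mask)
```

Left padding matters here because the scores are later read from the last position of each row, so that position has to hold the end of the suffix and not a padding token.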

We are interested in the probabilities of `yes` and `no`, so we take the logits at the last token position, keep only the entries for these two tokens, and normalize them:

In [4]:
@torch.no_grad()
def compute_logits(inputs, **kwargs):
    batch_scores = model(**inputs).logits[:, -1, :]
    true_vector = batch_scores[:, token_true_id]
    false_vector = batch_scores[:, token_false_id]
    batch_scores = torch.stack([false_vector, true_vector], dim=1)
    batch_scores = torch.nn.functional.log_softmax(batch_scores, dim=1)
    scores = batch_scores[:, 1].exp().tolist()
    return scores
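The normalization in `compute_logits` is a softmax over just two logits, which is the same as a sigmoid of the difference between the `yes` and the `no` logit. The sketch below checks this in plain Python with made-up logit values; the function names are only for illustration.

```python
import math

def yes_probability(logit_no, logit_yes):
    # Two-way softmax over [no, yes], shifted by the maximum for stability.
    top = max(logit_no, logit_yes)
    exp_no = math.exp(logit_no - top)
    exp_yes = math.exp(logit_yes - top)
    return exp_yes / (exp_no + exp_yes)

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

# Made-up logits: equal, clearly "no", clearly "yes".
for logit_no, logit_yes in [(2.0, 2.0), (5.0, -1.0), (-3.0, 4.0)]:
    p = yes_probability(logit_no, logit_yes)
    assert math.isclose(p, sigmoid(logit_yes - logit_no))
    print(f"no={logit_no:+.1f} yes={logit_yes:+.1f} -> P(yes)={p:.4f}")
```

All other vocabulary entries are ignored, so the score only says how much the model prefers `yes` over `no`, not how likely `yes` is among all tokens.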

In [5]:
# this is large and slow, but slightly better:
# model_name = "Qwen/Qwen3-Reranker-4B"
model_name = "Qwen/Qwen3-Reranker-0.6B"
tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side='left')
model = AutoModelForCausalLM.from_pretrained(model_name, dtype=torch.float16, attn_implementation="flash_attention_2").cuda().eval()

Use the *flawed* dataset again

In [6]:
from datasets import load_dataset
train_dataset = load_dataset("sentence-transformers/all-nli", "pair-score", split="train")

In [7]:
df = train_dataset.to_pandas()

In [8]:
import pandas as pd
pd.set_option('display.max_colwidth', None)

In [9]:
df.head(20).style.background_gradient(cmap='coolwarm')

|   | sentence1 | sentence2 | score |
|---|---|---|---|
| 0 | A person on a horse jumps over a broken down airplane. | A person is training his horse for a competition. | 0.5 |
| 1 | A person on a horse jumps over a broken down airplane. | A person is at a diner, ordering an omelette. | 0.0 |
| 2 | A person on a horse jumps over a broken down airplane. | A person is outdoors, on a horse. | 1.0 |
| 3 | Children smiling and waving at camera | They are smiling at their parents | 0.5 |
| 4 | Children smiling and waving at camera | There are children present | 1.0 |
| 5 | Children smiling and waving at camera | The kids are frowning | 0.0 |
| 6 | A boy is jumping on skateboard in the middle of a red bridge. | The boy skates down the sidewalk. | 0.0 |
| 7 | A boy is jumping on skateboard in the middle of a red bridge. | The boy does a skateboarding trick. | 1.0 |
| 8 | A boy is jumping on skateboard in the middle of a red bridge. | The boy is wearing safety equipment. | 0.5 |
| 9 | An older man sits with his orange juice at a small table in a coffee shop while employees in bright colored shirts smile in the background. | An older man drinks his juice as he waits for his daughter to get off work. | 0.5 |


Get the token IDs for `yes` and `no`, which are used for calculating the probabilities later:

In [10]:
token_false_id = tokenizer.convert_tokens_to_ids("no")
token_true_id = tokenizer.convert_tokens_to_ids("yes")

We use the Qwen chat template (more about that later) and a special system prompt:

In [11]:
max_length = 8192

prefix = "<|im_start|>system\nJudge whether the documents are similar . Note that the answer can only be \"yes\" or \"no\".<|im_end|>\n<|im_start|>user\n"
suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
prefix_tokens = tokenizer.encode(prefix, add_special_tokens=False)
suffix_tokens = tokenizer.encode(suffix, add_special_tokens=False)

Create a temporary dataframe containing the start of the dataset and some arbitrary samples:

In [12]:
dft = pd.concat([df.head(20), df.sample(20, random_state=42)])

Create the instruction prompts for the LLM:

In [13]:
pairs = [
    format_instruction("Given two sentences, calculate their similarity", query, doc)
    for query, doc in zip(dft["sentence1"], dft["sentence2"])
]

Run the actual reranking process:

In [14]:
%%time
# Tokenize the input texts
max_length = 8192
inputs = process_inputs(pairs)
scores = compute_logits(inputs)

You're using a Qwen2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


CPU times: user 243 ms, sys: 140 ms, total: 383 ms
Wall time: 377 ms
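The call above scores all 40 pairs in a single batch. For many more rows this will probably not fit into GPU memory at once, so the pairs would have to be scored in chunks. A minimal sketch, assuming a hypothetical helper `score_in_batches` and a dummy scorer that stands in for the model (neither is part of the original notebook):

```python
def score_in_batches(pairs, score_batch, batch_size=32):
    # Score the pairs chunk by chunk and collect the results in order.
    scores = []
    for start in range(0, len(pairs), batch_size):
        batch = pairs[start:start + batch_size]
        scores.extend(score_batch(batch))
    return scores

# Dummy scorer for illustration only. In this notebook the real one would be
# something like: lambda batch: compute_logits(process_inputs(batch))
def dummy_score_batch(batch):
    return [len(text) / 100 for text in batch]

example_pairs = ["a" * n for n in range(1, 8)]
example_scores = score_in_batches(example_pairs, dummy_score_batch, batch_size=3)
print(example_scores)
```

The batch size is a tuning knob: larger batches are faster until memory runs out, and because padding goes up to the longest prompt in each batch, sorting the pairs by length beforehand may help as well.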


In [15]:
# integrate scores
dft["scores"] = scores

In [16]:
dft.style.background_gradient(cmap='coolwarm')

|   | sentence1 | sentence2 | score | scores |
|---|---|---|---|---|
| 0 | A person on a horse jumps over a broken down airplane. | A person is training his horse for a competition. | 0.5 | 0.008514 |
| 1 | A person on a horse jumps over a broken down airplane. | A person is at a diner, ordering an omelette. | 0.0 | 4.5e-05 |
| 2 | A person on a horse jumps over a broken down airplane. | A person is outdoors, on a horse. | 1.0 | 0.992188 |
| 3 | Children smiling and waving at camera | They are smiling at their parents | 0.5 | 0.031158 |
| 4 | Children smiling and waving at camera | There are children present | 1.0 | 0.845703 |
| 5 | Children smiling and waving at camera | The kids are frowning | 0.0 | 0.003885 |
| 6 | A boy is jumping on skateboard in the middle of a red bridge. | The boy skates down the sidewalk. | 0.0 | 0.039093 |
| 7 | A boy is jumping on skateboard in the middle of a red bridge. | The boy does a skateboarding trick. | 1.0 | 0.983398 |
| 8 | A boy is jumping on skateboard in the middle of a red bridge. | The boy is wearing safety equipment. | 0.5 | 0.149048 |
| 9 | An older man sits with his orange juice at a small table in a coffee shop while employees in bright colored shirts smile in the background. | An older man drinks his juice as he waits for his daughter to get off work. | 0.5 | 0.855957 |