In [None]:
import json
import torch
from torch import Tensor
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel
from datasets import load_dataset
from utils import get_data, get_batches

In [None]:
# Number of (question, context) pairs grouped together by get_batches; within a group the
# other contexts form the candidate pool for hard-negative mining.
BATCH_SIZE: int = 10

In [None]:
MODEL_NAME = "intfloat/e5-mistral-7b-instruct"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# device_map="auto" lets accelerate place the weights; .eval() returns the module itself,
# so the model can be switched to inference mode in the same statement.
model = AutoModel.from_pretrained(MODEL_NAME, device_map="auto").eval()

In [None]:
dataset = load_dataset("kuznetsoffandrey/sberquad")
train_split = dataset["train"]
# Keep only examples that actually carry a context passage.
train_data = [
    {"question": question, "context": context}
    for question, context in zip(train_split["question"], train_split["context"])
    if context is not None
]
queries_train, passages_train = get_data(range(len(train_data)), train_data)
train_data_batched = get_batches(queries_train, passages_train, BATCH_SIZE)

In [None]:
def last_token_pool(last_hidden_states: Tensor,
                    attention_mask: Tensor) -> Tensor:
    """Pool each sequence to the hidden state of its last attended token.

    If the final attention-mask column is 1 for every row (left padding, or no padding at
    all), the last position already holds the last real token and is returned directly.
    Otherwise the mask is assumed to be right-padded: the last real token of each row sits
    at index ``mask.sum(dim=1) - 1``.

    Args:
        last_hidden_states: Model outputs; presumably (batch, seq, hidden) — TODO confirm.
        attention_mask: 0/1 mask aligned with the first two dims of ``last_hidden_states``.

    Returns:
        One pooled vector per row of the batch.
    """
    n_rows = attention_mask.shape[0]
    if attention_mask[:, -1].sum() != n_rows:
        last_positions = attention_mask.sum(dim=1) - 1
        row_ids = torch.arange(last_hidden_states.shape[0], device=last_hidden_states.device)
        return last_hidden_states[row_ids, last_positions]
    return last_hidden_states[:, -1]


def get_detailed_instruct(task_description: str, query: str) -> str:
    """Wrap a query in the instruction prompt template: an ``Instruct:`` line, then ``Query:``."""
    template = "Instruct: {}\nQuery: {}"
    return template.format(task_description, query)


def hard_negative_mining(batch_index: int, query: str, positive_doc: str, candidate_docs: list, 
                            task: str, margin: float = 0.95, top_k: int = 5):
    """Select the hardest negative passages for one query from a candidate pool.

    The instructed query, the positive passage and every candidate are embedded in a single
    forward pass of the module-level ``model``, pooled with ``last_token_pool`` and
    L2-normalised, so the dot products below are cosine similarities.

    Candidates whose score is not strictly below ``margin * pos_score`` are discarded as
    likely false negatives. Of the survivors, the ``top_k`` most similar to the query are kept.

    Args:
        batch_index: Index of the source batch; copied into the result unchanged.
        query: Raw question text (the task instruction is prepended here).
        positive_doc: Passage that answers ``query``.
        candidate_docs: Candidate negative passages (any sequence of str; may be empty).
        task: Task description used to build the instructed query.
        margin: Fraction of the positive score a candidate must stay strictly below.
        top_k: Maximum number of hard negatives to return (<= 0 yields none).

    Returns:
        dict with keys ``batch_index``, ``query``, ``positive`` and ``hard_negatives``;
        ``hard_negatives`` is ordered from highest to lowest query similarity.
    """
    query_text = get_detailed_instruct(task, query)
    input_texts = [query_text, positive_doc] + list(candidate_docs)

    batch_dict = tokenizer(input_texts, max_length=4096, padding=True, truncation=True, return_tensors='pt')
    # The tokenizer returns CPU tensors. Move them to the device of the model's input layer,
    # otherwise a model placed wholly on one GPU by device_map="auto" rejects them.
    batch_dict = batch_dict.to(model.device)
    with torch.no_grad():
        outputs = model(**batch_dict)
    embeddings = last_token_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
    embeddings = F.normalize(embeddings, p=2, dim=1)

    q_emb = embeddings[0:1]
    pos_emb = embeddings[1:2]
    cand_embs = embeddings[2:]

    pos_score = (q_emb @ pos_emb.T).item()
    scores = (q_emb @ cand_embs.T).squeeze(0)

    # NOTE(review): if pos_score is negative, pos_score * margin is larger than pos_score,
    # so this filter no longer guarantees that kept candidates score below the positive.
    threshold = pos_score * margin
    mask = scores < threshold
    filtered_scores = scores[mask]
    # Create the index tensor on the mask's device: indexing a CPU tensor with a CUDA
    # boolean mask raises a RuntimeError.
    filtered_indices = torch.arange(len(candidate_docs), device=scores.device)[mask]

    # Keep the top-k hardest (highest-scoring) candidates that survived the margin filter.
    k = min(top_k, len(filtered_scores))
    if k > 0:
        top_positions = torch.topk(filtered_scores, k).indices
        hard_negatives = [candidate_docs[i] for i in filtered_indices[top_positions].tolist()]
    else:
        hard_negatives = []

    return {
        "batch_index": batch_index,
        "query": query,
        "positive": positive_doc,
        "hard_negatives": hard_negatives
    }

In [None]:
task = 'Given a web search query, retrieve relevant passages that answer the query'
output_file = "hardneg_dataset.jsonl"

# Write one JSON line per (query, positive) pair with its mined hard negatives.
with open(output_file, "w", encoding="utf-8") as f:
    for batch_index, batch in enumerate(train_data_batched):
        questions = batch["question"]
        contexts = batch["context"]
        # Iterate over the batch's real contents rather than range(BATCH_SIZE), so a shorter
        # batch (e.g. a trailing partial one) does not raise IndexError.
        for idx, (query, positive_doc) in enumerate(zip(questions, contexts)):
            # In SQuAD-style data several questions usually share one context, so the positive
            # passage can reappear at another index. A passage cannot be a negative for itself,
            # so drop such copies, then de-duplicate the remaining candidates (order preserved).
            candidate_docs = list(dict.fromkeys(
                doc for j, doc in enumerate(contexts) if j != idx and doc != positive_doc
            ))

            result = hard_negative_mining(batch_index, query, positive_doc, candidate_docs, task, margin=0.95, top_k=5)

            f.write(json.dumps(result, ensure_ascii=False) + "\n")