In [None]:
!pip install faiss-gpu-cu12 xformers torch torchvision transformers  sentence_transformers -U

In [None]:
import os
import random
import textwrap

import datasets
import faiss
import numpy as np
from datasets import Dataset, DatasetDict
from huggingface_hub import HfApi, create_repo
from sentence_transformers import SentenceTransformer
from tqdm.auto import tqdm
from transformers import AutoConfig
from transformers import AutoTokenizer

# Run configuration for triplet mining, quality check and upload.
# Credentials come from the environment instead of being pasted into the
# notebook, so the file can be shared or committed without leaking a token.
CONFIG = {
    # Prompt family used by prompt_name(): "search" selects
    # search_query / search_document, anything else selects "clustering".
    "prefix_type":          "search",
    # Hard negatives are taken from 0-based FAISS neighbour ranks
    # hard_neg_rank_start..hard_neg_rank_end (inclusive) of each query.
    "hard_neg_rank_start":  10,
    "hard_neg_rank_end":    50,
    "hard_neg_count":       5,        # negatives per triplet; queries with fewer are skipped
    "max_triplets":         100_000,  # stop mining once this many triplets exist
    "log_every":            500,      # print a progress line every N examples
    "max_tokens":           1024,     # queries/passages longer than this are skipped
    "sample_size":          5_000,    # triplets re-scored by the quality check
    "margin_tau":           0.01,     # a negative must score < pos_sim - margin_tau

    "dataset_path":         "zloelias/lenta-ru",
    "query_column":         "title",
    "positive_column":      "text",
    "train_split":          "train",

    # Read from $HF_TOKEN. When it is unset this is None, and huggingface_hub
    # then uses its locally saved login token (e.g. from `huggingface-cli login`).
    "hf_token":             os.environ.get("HF_TOKEN"),
    "target_repo":          "Alexator26/lenta-ru-triplets"
}

# Encoder shared by every function below (module-level globals). The
# tokenizer is loaded separately only so count_tokens() can measure lengths.
MODEL_NAME = "deepvk/USER2-base"
tokenizer  = AutoTokenizer.from_pretrained(MODEL_NAME)
model      = SentenceTransformer(MODEL_NAME)

def count_tokens(text: str) -> int:
    """Return how many token ids the global ``tokenizer`` yields for *text*.

    Truncation is switched off, so over-long inputs are measured at their full
    length. Special tokens follow the tokenizer's default ``encode`` settings.
    """
    token_ids = tokenizer.encode(text, truncation=False)
    return len(token_ids)



def prompt_name(is_query: bool) -> str:
    """Return the SentenceTransformer prompt name for a query or a document.

    When ``CONFIG["prefix_type"]`` is ``"search"``, queries get
    ``"search_query"`` and documents get ``"search_document"``. Any other
    prefix type uses ``"clustering"`` for both.
    """
    if CONFIG["prefix_type"] != "search":
        return "clustering"
    if is_query:
        return "search_query"
    return "search_document"


def _collect_unique_docs(examples, p_field, max_tokens):
    """Deduplicate the positive passages that fit within ``max_tokens``.

    Returns ``(unique_docs, ex2doc_idx)``:
      * ``unique_docs`` holds the distinct passages in first-seen order.
      * ``ex2doc_idx`` maps an example's position in ``examples`` to its
        passage's index in ``unique_docs``.

    Examples whose passage exceeds ``max_tokens`` are absent from the map.
    """
    unique_docs, doc2idx, ex2doc_idx = [], {}, {}
    too_long = set()  # rejected passages, so repeats are not re-tokenised
    for ex_id, ex in enumerate(tqdm(examples, desc="Collect uniques")):
        doc = ex[p_field]
        if doc in too_long:
            continue
        if doc not in doc2idx:
            if count_tokens(doc) > max_tokens:
                too_long.add(doc)
                continue
            doc2idx[doc] = len(unique_docs)
            unique_docs.append(doc)
        ex2doc_idx[ex_id] = doc2idx[doc]
    return unique_docs, ex2doc_idx


def _encode_normalized(texts, *, is_query, show_progress_bar):
    """Encode ``texts`` with the global ``model`` and L2-normalise every row.

    Returns a contiguous float32 matrix (one row per text). Inner products
    against it, as in ``faiss.IndexFlatIP``, are therefore cosine similarities.
    """
    embs = model.encode(
        texts,
        prompt_name       = prompt_name(is_query=is_query),
        convert_to_numpy  = True,
        show_progress_bar = show_progress_bar,
    )
    embs = np.ascontiguousarray(embs, dtype="float32")
    faiss.normalize_L2(embs)
    return embs


def _pick_negatives(neigh_ids, neigh_sims, pos_idx, pos_sim, margin_tau):
    """Choose hard negatives for one query from its FAISS neighbour list.

    Only ranks ``CONFIG["hard_neg_rank_start"]``..``CONFIG["hard_neg_rank_end"]``
    (0-based, inclusive) are considered. The positive itself and FAISS padding
    ids (``-1``) are dropped. A candidate must also score below
    ``pos_sim - margin_tau``.

    Returns ``CONFIG["hard_neg_count"]`` randomly chosen document indices, or
    ``None`` if fewer candidates survive.
    """
    window = slice(CONFIG["hard_neg_rank_start"], CONFIG["hard_neg_rank_end"] + 1)
    candidates = [
        int(i)
        for i, sim in zip(neigh_ids[window], neigh_sims[window])
        if i >= 0 and i != pos_idx and sim < pos_sim - margin_tau
    ]
    random.shuffle(candidates)
    neg_idxs = candidates[:CONFIG["hard_neg_count"]]
    if len(neg_idxs) < CONFIG["hard_neg_count"]:
        return None
    return neg_idxs


def create_triplets(
        split,
        q_field: str,
        p_field: str,
        *,
        max_tokens : int   = 500,
        margin_tau : float = CONFIG['margin_tau'],
        query_batch_size : int = 256,
    ):
    """Mine ``{"query", "positive", "negatives"}`` triplets from ``split``.

    Examples missing either field are dropped and the rest are shuffled.
    Passages and queries longer than ``max_tokens`` are skipped. Every unique
    passage is embedded into an exact inner-product FAISS index. For each
    query, hard negatives are sampled from a fixed rank window of its nearest
    passages (see ``_pick_negatives``).

    Args:
        split: iterable of dict-like rows.
        q_field: column holding the query text.
        p_field: column holding the positive passage.
        max_tokens: length limit, measured with ``count_tokens``.
        margin_tau: required similarity gap between the positive and a negative.
        query_batch_size: queries encoded and searched together per step.

    Returns:
        A list of triplet dicts without exact duplicates, containing at most
        ``CONFIG["max_triplets"]`` entries. Empty if nothing usable was found.
    """
    examples = [ex for ex in split if ex.get(q_field) and ex.get(p_field)]
    random.shuffle(examples)
    if not examples:
        return []

    unique_docs, ex2doc_idx = _collect_unique_docs(examples, p_field, max_tokens)
    if not unique_docs:
        return []

    doc_embs = _encode_normalized(unique_docs, is_query=False, show_progress_bar=True)
    index = faiss.IndexFlatIP(doc_embs.shape[1])
    index.add(doc_embs)
    # Never request more neighbours than the index holds. FAISS pads missing
    # results with id -1, which would silently select the *last* document.
    k = min(CONFIG["hard_neg_rank_end"] + 1, index.ntotal)

    max_triplets = CONFIG["max_triplets"]
    log_every    = CONFIG["log_every"]
    batch_size   = max(1, query_batch_size)
    n_examples   = len(examples)
    triplets     = []

    with tqdm(total=n_examples, desc="Mining triplets") as progress:
        for start in range(0, n_examples, batch_size):
            if len(triplets) >= max_triplets:
                break
            chunk = range(start, min(start + batch_size, n_examples))

            # Apply the cheap filters first, so only usable queries are encoded.
            eligible = [
                ex_id for ex_id in chunk
                if ex_id in ex2doc_idx
                and count_tokens(examples[ex_id][q_field]) <= max_tokens
            ]
            row_of = {ex_id: row for row, ex_id in enumerate(eligible)}
            if eligible:
                # One batched encode + search per chunk instead of one per query.
                q_embs = _encode_normalized(
                    [examples[ex_id][q_field] for ex_id in eligible],
                    is_query=True,
                    show_progress_bar=False,
                )
                neigh_sims, neigh_ids = index.search(q_embs, k)

            for ex_id in chunk:
                if len(triplets) >= max_triplets:
                    break
                progress.update(1)
                if ex_id and ex_id % log_every == 0:
                    print(f"   {ex_id}/{len(examples)} → {len(triplets)} triplets")

                row = row_of.get(ex_id)
                if row is None:
                    continue
                pos_idx = ex2doc_idx[ex_id]
                pos_sim = float(q_embs[row] @ doc_embs[pos_idx])
                neg_idxs = _pick_negatives(
                    neigh_ids[row], neigh_sims[row], pos_idx, pos_sim, margin_tau
                )
                if neg_idxs is None:
                    continue

                triplets.append({
                    "query"    : examples[ex_id][q_field],
                    "positive" : unique_docs[pos_idx],
                    "negatives": [unique_docs[i] for i in neg_idxs],
                })

    # Drop exact duplicates (same query, positive and ordered negatives).
    unique_triplets = {
        (t["query"], t["positive"], tuple(t["negatives"])): t
        for t in triplets
    }
    return list(unique_triplets.values())

def evaluate_hard_negative_quality(triplets, sample_size=CONFIG['sample_size']):
    """Re-score a random sample of triplets to check how hard the negatives are.

    Queries, positives and negatives are re-encoded with the global ``model``
    and unit-normalised, so every dot product is a cosine similarity.

    Args:
        triplets: list of dicts with ``"query"``, ``"positive"`` and
            ``"negatives"`` (list of str), e.g. the output of ``create_triplets``.
        sample_size: maximum number of triplets to inspect.

    Returns:
        dict with ``sampled_triplets``, ``mean_pos_sim``, ``mean_neg_sim``,
        ``mean_margin_pos_vs_hardest`` and ``triplets_with_harder_negative``.
        The last one counts triplets whose hardest negative scores >= the
        positive.

    Raises:
        ValueError: if ``triplets`` is empty.
    """
    if not triplets:
        raise ValueError("evaluate_hard_negative_quality: `triplets` is empty")

    sample = random.sample(triplets, k=min(sample_size, len(triplets)))

    def _unit_embeddings(texts, is_query):
        embs = model.encode(texts,
                            prompt_name=prompt_name(is_query),
                            convert_to_numpy=True, show_progress_bar=True)
        return embs / np.linalg.norm(embs, axis=1, keepdims=True)

    q_embs    = _unit_embeddings([t["query"] for t in sample], True)
    pos_embs  = _unit_embeddings([t["positive"] for t in sample], False)
    neg_texts = [n for t in sample for n in t["negatives"]]
    # q_embs[:0] is a (0, dim) matrix, so slicing below stays valid when no
    # triplet has any negatives.
    neg_embs  = _unit_embeddings(neg_texts, False) if neg_texts else q_embs[:0]

    pos_sims = (q_embs * pos_embs).sum(axis=1)

    neg_sims, margins = [], []
    offset, bad_cnt = 0, 0  # offset: start of this triplet's rows in neg_embs
    for i, t in tqdm(enumerate(sample), desc='evaluate_hard_negative_quality',
                     total=len(sample)):
        n = len(t["negatives"])
        if n == 0:
            continue  # no negatives → no hardest negative to compare against
        sims = q_embs[i] @ neg_embs[offset: offset + n].T
        offset += n

        hardest = sims.max()
        neg_sims.extend(sims.tolist())

        if hardest >= pos_sims[i]:
            bad_cnt += 1
        margins.append(pos_sims[i] - hardest)

    def _mean(values):
        return float(np.mean(values)) if len(values) else float("nan")

    stats = {
        "sampled_triplets":              len(sample),
        "mean_pos_sim":                  _mean(pos_sims),
        "mean_neg_sim":                  _mean(neg_sims),
        "mean_margin_pos_vs_hardest":    _mean(margins),
        "triplets_with_harder_negative": bad_cnt
    }
    return stats


def push_dataset(dsdict, readme_fragment):
    """Upload ``dsdict`` and a README to the private dataset repo ``CONFIG["target_repo"]``.

    The repo is created if missing. ``exist_ok=True`` makes re-runs safe; without
    it a second run fails because the repo already exists.

    Args:
        dsdict: a ``DatasetDict`` to push.
        readme_fragment: Markdown text uploaded as the repo's README.md.
    """
    api     = HfApi()
    token   = CONFIG["hf_token"]
    repo_id = CONFIG["target_repo"]

    create_repo(
        repo_id=repo_id,
        repo_type="dataset",
        private=True,
        token=token,
        exist_ok=True,
    )

    dsdict.push_to_hub(
        repo_id=repo_id,
        token=token,
        private=True,
    )

    # Upload straight from memory as UTF-8 bytes. This avoids clobbering any
    # README.md in the working directory. It also avoids locale-dependent
    # encoding errors, since the text may contain non-ASCII characters.
    # NOTE(review): this overwrites whatever README push_to_hub wrote
    # (presumably a card with YAML metadata). Consider appending to it instead.
    api.upload_file(
        path_or_fileobj=readme_fragment.encode("utf-8"),
        path_in_repo="README.md",
        repo_id=repo_id,
        repo_type="dataset",
        token=token,
    )

def main():
    """Mine triplets, sanity-check them, and publish them to the Hub.

    Triplets are mined from ``CONFIG["dataset_path"]`` and published, together
    with a README summary, to ``CONFIG["target_repo"]``.

    Raises:
        RuntimeError: if no triplets could be mined.
    """
    raw_ds  = datasets.load_dataset(CONFIG["dataset_path"])
    split   = raw_ds[CONFIG["train_split"]]

    triplets = create_triplets(
        split,
        CONFIG["query_column"],
        CONFIG["positive_column"],
        max_tokens=CONFIG["max_tokens"]
    )

    if not triplets:
        raise RuntimeError("Triplets list empty!")

    train_ds = Dataset.from_list(triplets)
    dsdict   = DatasetDict({"train": train_ds})
    print(dsdict)

    stats = evaluate_hard_negative_quality(triplets)

    # The literal is indented to match the code. Without dedent every line
    # starts with four spaces, and Markdown renders the whole card as a code
    # block instead of a heading and a table.
    readme_fragment = textwrap.dedent(f"""
    ### 🔎 Hard‑negative sanity check
    Randomly inspected {stats['sampled_triplets']:,} triplets with `deepvk/USER2-base`.

    | metric | value |
    | --- | --- |
    | mean cos‑sim(query, **positive**) | **{stats['mean_pos_sim']:.4f}** |
    | mean cos‑sim(query, negatives)   | {stats['mean_neg_sim']:.4f} |
    | mean margin = pos – hardest_neg  | {stats['mean_margin_pos_vs_hardest']:.4f} |
    | bad cases (neg ≥ pos)            | {stats['triplets_with_harder_negative']}/{stats['sampled_triplets']} |

    Lower margin ⇒ harder negatives.
    Ideally the last line should be 0.
    """)

    print(readme_fragment)
    push_dataset(dsdict, readme_fragment)


if __name__ == "__main__":
    main()
