In [None]:
!pip install faiss-gpu-cu12 xformers torch torchvision transformers  sentence_transformers -U

In [None]:
import os
import random

import datasets
import faiss
import numpy as np
from tqdm.auto import tqdm
from sentence_transformers import SentenceTransformer
from transformers import AutoConfig
from transformers import AutoTokenizer
from datasets import Dataset, DatasetDict
from huggingface_hub import HfApi, create_repo

# Pipeline settings read by the helpers in this notebook.
CONFIG = {
    # Prompt family for the embedding model: "search" selects the
    # search_query / search_document prompts, anything else "clustering".
    "prefix_type":          "search",
    # Hard negatives are drawn from FAISS neighbour positions
    # hard_neg_rank_start .. hard_neg_rank_end (0-based, inclusive).
    "hard_neg_rank_start":  10,
    "hard_neg_rank_end":    50,
    "hard_neg_count":       5,        # negatives required per triplet
    "max_triplets":         100_000,  # stop mining once this many are collected
    "log_every":            500,      # progress print interval, in examples
    "max_tokens":           1024,
    "sample_size":          5_000,
    "margin_tau":           0.01,     # min similarity gap positive vs. negative

    "dataset_path":         "zloelias/lenta-ru",
    "query_column":         "title",
    "positive_column":      "text",
    "train_split":          "train",

    # Take the Hugging Face token from the environment instead of pasting a
    # secret into the notebook; "<token>" remains the fallback placeholder.
    "hf_token":             os.environ.get("HF_TOKEN", "<token>"),
    "target_repo":          "Alexator26/lenta-ru-triplets",
}

# The tokenizer and the encoder are loaded from one checkpoint id so that
# count_tokens() measures lengths with the encoder's own vocabulary.
_MODEL_ID = "deepvk/USER2-base"
tokenizer = AutoTokenizer.from_pretrained(_MODEL_ID)
model     = SentenceTransformer(_MODEL_ID)

def count_tokens(text: str) -> int:
    """Return the number of token ids the module-level ``tokenizer`` yields for *text*.

    Truncation is switched off so that over-long inputs report their full
    length; special-token handling is left at the tokenizer's default.
    """
    token_ids = tokenizer.encode(text, truncation=False)
    return len(token_ids)



def prompt_name(is_query: bool) -> str:
    """Pick the SentenceTransformer prompt name for a query or a document.

    When ``CONFIG["prefix_type"]`` is ``"search"``, queries and documents get
    distinct prompts; for any other prefix type both use ``"clustering"``.
    """
    if CONFIG["prefix_type"] != "search":
        return "clustering"
    return "search_query" if is_query else "search_document"


def create_triplets(
        split,
        q_field: str,
        p_field: str,
        *,
        max_tokens : int   = 500,
        margin_tau : float = CONFIG['margin_tau']
    ):
    """Mine (query, positive, hard-negatives) triplets from a dataset split.

    Examples whose ``q_field`` and ``p_field`` are both truthy are shuffled.
    Their positive texts are de-duplicated, embedded with the module-level
    ``model`` and stored in an exact inner-product FAISS index. Embeddings
    are L2-normalised, so scores are cosine similarities. For every query, the
    neighbours at positions ``CONFIG["hard_neg_rank_start"]`` ..
    ``CONFIG["hard_neg_rank_end"]`` (inclusive) are candidate negatives.
    The positive itself is dropped, as is any candidate not at least
    ``margin_tau`` less similar than the positive. Then
    ``CONFIG["hard_neg_count"]`` of the rest are sampled at random.

    Args:
        split: Iterable of dict-like examples (read via ``.get`` / ``[]``).
        q_field: Key holding the query text.
        p_field: Key holding the positive document text.
        max_tokens: Queries and documents longer than this many tokens (per
            ``count_tokens``) are skipped.
            NOTE(review): the default (500) differs from
            CONFIG["max_tokens"] (1024); callers presumably pass it
            explicitly — confirm.
        margin_tau: Minimum similarity gap between the positive and each
            negative; filters likely false negatives. The default is read
            from CONFIG once, when this function is defined.

    Returns:
        A list of dicts with keys ``"query"``, ``"positive"`` and
        ``"negatives"`` (a list of ``CONFIG["hard_neg_count"]`` texts).
        Mining stops at ``CONFIG["max_triplets"]`` and exact duplicates are
        removed afterwards. Returns an empty list when no usable examples
        remain.
    """
    examples = [ex for ex in split if ex.get(q_field) and ex.get(p_field)]
    random.shuffle(examples)
    if not examples:
        return []

    unique_docs, ex2doc_idx = _collect_unique_docs(examples, p_field, max_tokens)
    if not unique_docs:
        return []

    # Normalised vectors + inner-product index == cosine similarity search.
    doc_embs = model.encode(
        unique_docs,
        prompt_name       = prompt_name(is_query=False),
        convert_to_numpy  = True,
        show_progress_bar = True
    ).astype("float32")
    faiss.normalize_L2(doc_embs)
    index = faiss.IndexFlatIP(doc_embs.shape[1])
    index.add(doc_embs)

    n_neighbours = CONFIG["hard_neg_rank_end"] + 1
    query_prompt = prompt_name(is_query=True)
    triplets = []

    for ex_id, ex in enumerate(tqdm(examples, desc="Mining triplets")):
        if len(triplets) >= CONFIG["max_triplets"]:
            break
        if ex_id and ex_id % CONFIG["log_every"] == 0:
            print(f"   {ex_id}/{len(examples)} → {len(triplets)} triplets")

        # Examples whose positive exceeded the token budget have no doc
        # index; check that first so their queries are not tokenized for
        # nothing.
        pos_idx = ex2doc_idx.get(ex_id)
        if pos_idx is None:
            continue

        q_text = ex[q_field]
        if count_tokens(q_text) > max_tokens:
            continue

        q_emb = model.encode(
            q_text,
            prompt_name       = query_prompt,
            convert_to_numpy  = True,
            show_progress_bar = False
        ).astype("float32").reshape(1, -1)   # FAISS works on 2-D batches
        faiss.normalize_L2(q_emb)

        pos_sim = float(np.dot(q_emb[0], doc_embs[pos_idx]))
        sims, neigh = index.search(q_emb, n_neighbours)

        neg_idxs = _select_hard_negatives(
            sims[0], neigh[0], pos_idx, pos_sim, margin_tau
        )
        if neg_idxs is None:
            continue

        triplets.append({
            "query"    : q_text,
            "positive" : unique_docs[pos_idx],
            "negatives": [unique_docs[i] for i in neg_idxs],
        })

    # Drop exact duplicates (same query, positive and ordered negatives),
    # keeping first-seen order.
    unique_triplets = {
        (t["query"], t["positive"], tuple(t["negatives"])): t
        for t in triplets
    }
    return list(unique_triplets.values())


def _collect_unique_docs(examples, p_field, max_tokens):
    """De-duplicate positive texts that fit within ``max_tokens``.

    Returns ``(unique_docs, ex2doc_idx)``. ``unique_docs`` holds the distinct
    accepted texts in first-seen order. ``ex2doc_idx`` maps an example's
    position in *examples* to its index in ``unique_docs``; examples whose
    text is too long are absent from it.
    """
    unique_docs, doc2idx, ex2doc_idx = [], {}, {}
    for ex_id, ex in enumerate(tqdm(examples, desc="Collect uniques")):
        doc = ex[p_field]
        if doc not in doc2idx:
            # Only tokenize texts that have not already been accepted.
            if count_tokens(doc) > max_tokens:
                continue
            doc2idx[doc] = len(unique_docs)
            unique_docs.append(doc)
        ex2doc_idx[ex_id] = doc2idx[doc]
    return unique_docs, ex2doc_idx


def _select_hard_negatives(sims, ids, pos_idx, pos_sim, margin_tau):
    """Sample hard-negative doc indices from one query's FAISS hits.

    ``sims`` and ``ids`` are one row of ``index.search`` output, best match
    first. Only positions ``CONFIG["hard_neg_rank_start"]`` ..
    ``CONFIG["hard_neg_rank_end"]`` are considered. The positive is excluded,
    as is any doc not at least ``margin_tau`` less similar than it. Returns
    ``CONFIG["hard_neg_count"]`` randomly chosen indices, or ``None`` if too
    few candidates qualify.
    """
    lo = CONFIG["hard_neg_rank_start"]
    hi = CONFIG["hard_neg_rank_end"] + 1
    candidates = []
    for sim, doc_idx in zip(sims[lo:hi], ids[lo:hi]):
        # FAISS pads results with id -1 when the index holds fewer than k
        # vectors; using -1 as an index would silently alias the last doc.
        if doc_idx < 0 or doc_idx == pos_idx:
            continue
        if sim < pos_sim - margin_tau:
            candidates.append(int(doc_idx))

    random.shuffle(candidates)
    wanted = CONFIG["hard_neg_count"]
    if len(candidates) < wanted:
        return None
    return candidates[:wanted]

if __name__ == "__main__":
    # NOTE(review): no `main` function is defined anywhere in this file as
    # seen, so this line raises NameError unless `main` is provided elsewhere
    # (e.g. in a notebook cell run beforehand) — confirm, or define `main`
    # to load CONFIG["dataset_path"], call create_triplets and publish.
    main()
