In [None]:
from argparse import ArgumentParser
from dataclasses import dataclass
from pathlib import Path
from collections import defaultdict

import pandas as pd
from transformers import AutoModel, AutoTokenizer
from torch import nn
import numpy as np
from torch import optim
import torch
from tqdm.auto import tqdm
import time
from sentence_transformers import SentenceTransformer
from sentence_transformers.util import cos_sim
from sentence_transformers.losses import CoSENTLoss, MultipleNegativesRankingLoss
from datasets import Dataset as HFDataset

from torch.utils.data import DataLoader
import torch.nn.functional as F
from torch import Tensor
from usearch.index import Index
import string
from sklearn.model_selection import GroupShuffleSplit
import random

In [None]:
from datasets import Dataset

# Column data for the pair dataset: row i is (anchors[i], positives[i]).
# Fill both lists here: open a file, do preprocessing, filtering, cleaning, etc.
anchors, positives = [], []

dataset = Dataset.from_dict(dict(anchor=anchors, positive=positives))

In [None]:
# --- models ---
# The same checkpoint is loaded twice: wrapped by SentenceTransformer (`model`)
# and as a plain transformers model + tokenizer (`model_`, `tokenizer_`).
# Only the plain model is explicitly moved to the GPU here.
model_name = "Alibaba-NLP/gte-large-en-v1.5"
model = SentenceTransformer(model_name, trust_remote_code=True)
model_ = AutoModel.from_pretrained(model_name, trust_remote_code=True).cuda()
tokenizer_ = AutoTokenizer.from_pretrained(model_name)

# --- data ---
# Single 70/30 shuffle split grouped by QuestionId, so all rows of one question
# land on the same side of the split.
df = pd.read_csv("data/eedi-paraphrased/train.csv")
gss = GroupShuffleSplit(n_splits=1, train_size=0.7, random_state=42)
train_idx, val_idx = next(gss.split(df, groups=df["QuestionId"]))
df_train, df_val = df.iloc[train_idx], df.iloc[val_idx]
# Validation keeps only rows where neither "AiCreated" flag is set.
df_val = df_val.loc[
    lambda rows: ~rows["QuestionAiCreated"] & ~rows["MisconceptionAiCreated"]
].reset_index(drop=True)

In [None]:
# Build the question table from the training split in one chain:
#   1. keep the columns that identify a (question, wrong answer, misconception) row
#      and drop exact duplicates,
#   2. render them into a single "QuestionComplete" prompt string,
#   3. keep only the misconception id and the prompt, ordered by misconception id.
df_q = (
    df_train[
        [
            "ConstructName",
            "SubjectName",
            "QuestionId_Answer",
            "QuestionText",
            "WrongText",
            "CorrectText",
            "MisconceptionId",
        ]
    ]
    .drop_duplicates()
    .assign(
        QuestionComplete=lambda q: (
            "Subject: "
            + q["SubjectName"]
            + ". Construct: "
            + q["ConstructName"]
            + ". Question: "
            + q["QuestionText"]
            + ". Correct answer: "
            + q["CorrectText"]
            + ". Wrong answer: "
            + q["WrongText"]
            + "."
        )
    )[["MisconceptionId", "QuestionComplete"]]
    .sort_values("MisconceptionId")
    .reset_index(drop=True)
)
df_q

In [None]:
# Misconception table taken from the full frame `df` (not just the training split):
# every distinct (MisconceptionId, MisconceptionText) pair, ordered by id.
# An id occupies several rows if it comes with more than one text.
df_m = (
    df.loc[:, ["MisconceptionId", "MisconceptionText"]]
    .sort_values(by="MisconceptionId")
    .drop_duplicates()
    .reset_index(drop=True)
)
df_m

In [None]:
@torch.inference_mode()
def batched_inference(
    model,
    tokenizer,
    texts: list[str],
    bs: int,
    desc: str,
    max_length: int = 256,
) -> Tensor:
    """Basically SentenceTransformer.encode, but consume less vram.

    Texts are processed `bs` at a time: tokenize, run the model, take the hidden
    state of the first token (CLS), L2-normalize it, and move the result to the
    CPU before the next batch is processed.

    Args:
        model: Called as `model(**encoded)`; its output must expose
            `last_hidden_state`. Inputs are sent to the device of the model's
            first parameter.
        tokenizer: Called with a list of strings plus `max_length`, `padding`,
            `truncation` and `return_tensors="pt"`; its result must support
            `.to(device)` and `**` unpacking.
        texts (list[str]): Texts to embed. Expected to be non-empty, since the
            per-batch results are concatenated with `torch.cat`.
        bs (int): Batch size.
        desc (str): Label shown on the progress bar.
        max_length (int): Token limit per text; longer texts are truncated.
            The default of 256 comes from plotting the complete question text,
            and 256 covers 99%.

    Returns:
        Tensor: CPU tensor with one L2-normalized embedding row per text.
    """
    # Follow the model's own device instead of assuming CUDA, so the same code
    # works for CPU models and for models placed on a specific GPU. If the model
    # exposes no parameters, keep the previous behaviour and target "cuda".
    params = getattr(model, "parameters", None)
    first_param = next(params(), None) if callable(params) else None
    device = first_param.device if first_param is not None else torch.device("cuda")

    embeddings = []
    for i in tqdm(range(0, len(texts), bs), desc=desc):
        encoded = tokenizer(
            texts[i : i + bs],
            max_length=max_length,
            padding=True,
            truncation=True,
            return_tensors="pt",
        ).to(device)
        outputs = model(**encoded)
        emb = outputs.last_hidden_state[:, 0]  # cls token
        emb = F.normalize(emb, p=2, dim=-1)
        # Move each batch off the device right away; only the CPU copies are kept.
        embeddings.append(emb.cpu())
    return torch.cat(embeddings)

In [None]:
def hn_mine_hf(
    model,
    tokenizer,
    q_texts: list[str],
    q_mis_ids: list[int],
    mis_texts: list[str],
    mis_ids: list[int],
    k: int,
    bs: int,
) -> list[list[int]]:
    """Mine hard negative misconceptions for each question with a plain transformers model.

    Differs from https://www.sbert.net/docs/package_reference/util.html#sentence_transformers.util.mine_hard_negatives:
    Sentence Transformers' version assumes different rows are always negatives,
    which does not hold for paraphrased data. Here a retrieved misconception only
    counts as a negative when its id differs from the question's ground-truth id.

    All misconception texts are embedded and put into an inner-product index; each
    question embedding is then searched against it for `k` neighbours.

    Args:
        model: Forwarded to `batched_inference`.
        tokenizer: Forwarded to `batched_inference`.
        q_texts (list[str]): Question texts.
        q_mis_ids (list[int]): Ground truth misconception id per question.
        mis_texts (list[str]): Misconception texts.
        mis_ids (list[int]): Misconception id per misconception text.
        k (int): Neighbours retrieved per question. Since ground-truth hits are
            removed afterwards, a question gets at most `k` hard negatives.
        bs (int): Batch size used for embedding.

    Returns:
        list[list[int]]:
            Hard misconceptions for each question, in the order the index
            returned them. Entries are positions into `mis_texts` / `mis_ids`,
            NOT misconception ids.
    """
    assert len(q_texts) == len(q_mis_ids)
    assert len(mis_texts) == len(mis_ids)

    def embed(texts: list[str], desc: str):
        # Embeddings come back L2-normalized, hence the inner-product metric below.
        return batched_inference(model, tokenizer, texts, bs=bs, desc=desc).numpy()

    mis_vecs = embed(mis_texts, "miscon")
    ann = Index(ndim=mis_vecs.shape[-1], metric="ip")
    # Index keys are row positions, so a hit's key indexes straight into mis_ids.
    ann.add(np.arange(mis_vecs.shape[0]), mis_vecs)

    q_vecs = embed(q_texts, "questions")
    neighbours = ann.search(q_vecs, count=k)

    hards = [
        [hit.key for hit in hits if mis_ids[hit.key] != q_mis_ids[row]]
        for row, hits in enumerate(neighbours)  # type: ignore
    ]
    assert len(hards) == len(q_texts)
    return hards


In [None]:
def hn_mine_st(
    model: SentenceTransformer,
    q_texts: list[str],
    q_mis_ids: list[int],
    mis_texts: list[str],
    mis_ids: list[int],
    k: int,
    bs: int,
    device: "str | None" = "cuda",
) -> list[list[int]]:
    """Hard negative mining, but different from: https://www.sbert.net/docs/package_reference/util.html#sentence_transformers.util.mine_hard_negatives.
    Sentence Transformers' version assumes different rows are always negatives, but that is not the case if we use paraphrased data.

    Misconception texts are encoded with `model.encode` and stored in an
    inner-product index; every question is then searched against that index and
    hits sharing the question's ground-truth misconception id are discarded.

    Args:
        model (SentenceTransformer): Model used to encode both text lists.
        q_texts (list[str]): Question texts.
        q_mis_ids (list[int]): Ground truth misconception ids for the questions.
        mis_texts (list[str]): Misconception texts.
        mis_ids (list[int]): Misconception ids.
        k (int): Top k hard misconception ids per question (at max).
        bs (int): Batch size.
        device (str | None): Passed through as `device=` to `model.encode`.
            Defaults to "cuda" (the previous hard-coded value); `None` defers to
            `SentenceTransformer.encode`'s own default.

    Returns:
        list[list[int]]:
            Hard misconceptions for each question.
            This is NOT misconception ids, but the actual list index.
    """
    assert len(q_texts) == len(q_mis_ids)
    assert len(mis_texts) == len(mis_ids)

    def encode(texts: list[str]):
        # One place for the encode settings so both text lists are embedded identically.
        return model.encode(
            texts,
            batch_size=bs,
            normalize_embeddings=True,
            show_progress_bar=True,
            device=device,
        )

    m_embeds = encode(mis_texts)
    index = Index(ndim=m_embeds.shape[-1], metric="ip")
    # Keys are row positions, so a match key indexes directly into mis_ids / mis_texts.
    index.add(np.arange(m_embeds.shape[0]), m_embeds)
    q_embeds = encode(q_texts)
    batch_matches = index.search(q_embeds, count=k)
    hards = []
    for i, matches in enumerate(batch_matches):  # type: ignore
        nth_miscons: list[int] = [m.key for m in matches]
        # Drop hits that are actually the question's own (ground-truth) misconception.
        hard_miscons = [nth for nth in nth_miscons if mis_ids[nth] != q_mis_ids[i]]
        hards.append(hard_miscons)
    assert len(hards) == len(q_texts)
    return hards

In [None]:
# Hard negatives from the plain transformers model: 100 neighbours are retrieved
# per question, with texts embedded 16 at a time.
hards = hn_mine_hf(
    model=model_,
    tokenizer=tokenizer_,
    q_texts=df_q["QuestionComplete"].to_list(),
    q_mis_ids=df_q["MisconceptionId"].to_list(),
    mis_texts=df_m["MisconceptionText"].to_list(),
    mis_ids=df_m["MisconceptionId"].to_list(),
    k=100,
    bs=16,
)

In [None]:
# Same mining run through the SentenceTransformer wrapper: 100 neighbours are
# retrieved per question, with texts encoded 4 at a time.
hards_st = hn_mine_st(
    model=model,
    q_texts=df_q["QuestionComplete"].to_list(),
    q_mis_ids=df_q["MisconceptionId"].to_list(),
    mis_texts=df_m["MisconceptionText"].to_list(),
    mis_ids=df_m["MisconceptionId"].to_list(),
    k=100,
    bs=4,
)

In [None]:
def make_mnr_dataset(
    q_texts: list[str],
    q_mis_ids: list[int],
    mis_texts: list[str],
    mis_ids: list[int],
    hards: list[list[int]],
    n_negatives: int,
    seed: "int | None" = None,
) -> HFDataset:
    """Create SentenceTransformer dataset suitable for MultipleNegativesRankingLoss.
    The format is (anchor, positive, negative_1, …, negative_n).
    Example: https://huggingface.co/datasets/tomaarsen/gooaq-hard-negatives

    One row is produced per question: the question text ("q"), one randomly
    chosen misconception text sharing the question's misconception id ("mis"),
    and `n_negatives` texts sampled without replacement from that question's
    hard negatives ("neg_1" … "neg_n").

    Args:
        q_texts (list[str]): Question texts.
        q_mis_ids (list[int]): Ground truth misconception id per question.
        mis_texts (list[str]): Misconception texts.
        mis_ids (list[int]): Misconception id per misconception text.
        hards (list[list[int]]): Per question, positions into `mis_texts` that
            may be used as negatives.
        n_negatives (int): Negatives per row; every entry of `hards` must hold at
            least this many candidates.
        seed (int | None): When given, sampling uses a private
            `random.Random(seed)` and is reproducible. When `None` (default) the
            global `random` module state is used, as before.

    Returns:
        HFDataset: Columns "q", "mis", "neg_1", …, f"neg_{n_negatives}".

    Raises:
        AssertionError: On mismatched lengths, too few hard negatives, or a
            question whose misconception id does not occur in `mis_ids`.
    """
    assert len(q_texts) == len(q_mis_ids) == len(hards)
    assert len(mis_texts) == len(mis_ids)
    assert all(n_negatives <= len(hard) for hard in hards)
    # create reverse mapping: misconception id -> every row position carrying it
    mis_id_to_mis_idx: dict[int, list[int]] = {}
    for i, mis_id in enumerate(mis_ids):
        mis_id_to_mis_idx.setdefault(mis_id, []).append(i)
    # Fail early and clearly instead of hitting random.choice on an empty list.
    missing = set(q_mis_ids) - mis_id_to_mis_idx.keys()
    assert not missing, f"question misconception ids missing from mis_ids: {missing}"

    rng = random if seed is None else random.Random(seed)
    # make hf dataset; insertion order fixes the column order q, mis, neg_1, ...
    neg_cols = [f"neg_{j}" for j in range(1, n_negatives + 1)]
    d = {"q": [], "mis": [], **{col: [] for col in neg_cols}}
    for q_text, q_mis_id, hard in zip(q_texts, q_mis_ids, hards):
        rand_pos = rng.choice(mis_id_to_mis_idx[q_mis_id])
        rand_negs = rng.sample(hard, k=n_negatives)
        d["q"].append(q_text)
        d["mis"].append(mis_texts[rand_pos])
        for col, rand_neg in zip(neg_cols, rand_negs):
            d[col].append(mis_texts[rand_neg])
    return HFDataset.from_dict(d)


# Triplet-style dataset for MultipleNegativesRankingLoss, built from the negatives
# mined into `hards`: 10 negatives per question.
ds = make_mnr_dataset(
    q_texts=df_q["QuestionComplete"].to_list(),
    q_mis_ids=df_q["MisconceptionId"].to_list(),
    mis_texts=df_m["MisconceptionText"].to_list(),
    mis_ids=df_m["MisconceptionId"].to_list(),
    hards=hards,
    n_negatives=10,
)
ds

In [None]:
def make_cosent_dataset(
    q_texts: list[str],
    q_mis_ids: list[int],
    mis_texts: list[str],
    mis_ids: list[int],
    hards: list[list[int]],
    n_negatives: int,
    seed: "int | None" = None,
) -> HFDataset:
    """Create SentenceTransformer dataset suitable for CoSENTLoss.
    The format is (sentence_A, sentence_B).
    Example: https://sbert.net/docs/sentence_transformer/training_overview.html#loss-function

    Each question contributes `1 + n_negatives` rows: first one positive pair
    (a randomly chosen misconception text sharing the question's misconception
    id, label 1.0), then `n_negatives` pairs sampled without replacement from
    that question's hard negatives (label -1.0).

    Args:
        q_texts (list[str]): Question texts.
        q_mis_ids (list[int]): Ground truth misconception id per question.
        mis_texts (list[str]): Misconception texts.
        mis_ids (list[int]): Misconception id per misconception text.
        hards (list[list[int]]): Per question, positions into `mis_texts` that
            may be used as negatives.
        n_negatives (int): Negative pairs per question; every entry of `hards`
            must hold at least this many candidates.
        seed (int | None): When given, sampling uses a private
            `random.Random(seed)` and is reproducible. When `None` (default) the
            global `random` module state is used, as before.

    Returns:
        HFDataset: Columns "q", "mis" and "label".

    Raises:
        AssertionError: On mismatched lengths, too few hard negatives, or a
            question whose misconception id does not occur in `mis_ids`.
    """
    assert len(q_texts) == len(q_mis_ids) == len(hards)
    assert len(mis_texts) == len(mis_ids)
    assert all(n_negatives <= len(hard) for hard in hards)
    # create reverse mapping: misconception id -> every row position carrying it
    mis_id_to_mis_idx: dict[int, list[int]] = {}
    for i, mis_id in enumerate(mis_ids):
        mis_id_to_mis_idx.setdefault(mis_id, []).append(i)
    # Fail early and clearly instead of hitting random.choice on an empty list.
    missing = set(q_mis_ids) - mis_id_to_mis_idx.keys()
    assert not missing, f"question misconception ids missing from mis_ids: {missing}"

    rng = random if seed is None else random.Random(seed)
    # make hf dataset
    d = {"q": [], "mis": [], "label": []}
    for q_text, q_mis_id, hard in zip(q_texts, q_mis_ids, hards):
        # insert positive
        rand_pos = rng.choice(mis_id_to_mis_idx[q_mis_id])
        d["q"].append(q_text)
        d["mis"].append(mis_texts[rand_pos])
        d["label"].append(1.0)
        # insert negatives
        for rand_neg in rng.sample(hard, k=n_negatives):
            d["q"].append(q_text)
            d["mis"].append(mis_texts[rand_neg])
            d["label"].append(-1.0)
    return HFDataset.from_dict(d)


# Labelled pair dataset for CoSENTLoss, built from the negatives mined into `hards`:
# 10 negative pairs per question.
# NOTE: this rebinds `ds`, replacing the dataset returned by make_mnr_dataset.
ds = make_cosent_dataset(
    q_texts=df_q["QuestionComplete"].to_list(),
    q_mis_ids=df_q["MisconceptionId"].to_list(),
    mis_texts=df_m["MisconceptionText"].to_list(),
    mis_ids=df_m["MisconceptionId"].to_list(),
    hards=hards,
    n_negatives=10,
)
ds