In [None]:
from argparse import ArgumentParser
from dataclasses import dataclass
from pathlib import Path
from collections import defaultdict

import pandas as pd
from transformers import AutoModel, AutoTokenizer
from torch import nn
import numpy as np
from torch import optim
import torch
from tqdm.auto import tqdm
import time
from sentence_transformers import SentenceTransformer
from sentence_transformers.util import cos_sim
from sentence_transformers.losses import CoSENTLoss, MultipleNegativesRankingLoss
from datasets import Dataset as HFDataset
from itertools import islice

from torch.utils.data import DataLoader
import torch.nn.functional as F
from torch import Tensor
from usearch.index import Index
import string
from sklearn.model_selection import GroupShuffleSplit
from sentence_transformers.evaluation import InformationRetrievalEvaluator
import random
from sentence_transformers import SentenceTransformer
from sentence_transformers.evaluation import InformationRetrievalEvaluator
from datasets import load_dataset

import random

In [None]:
# load model
# The same checkpoint is loaded twice on purpose:
#   * `model`  - SentenceTransformer wrapper (used by hn_mine_st and the trainer below)
#   * `model_` / `tokenizer_` - raw HF model + tokenizer (used by hn_mine_hf / batched_inference)
# NOTE(review): both copies end up resident at the same time; `.cuda()` requires a GPU.
model_name = "Alibaba-NLP/gte-large-en-v1.5"
model = SentenceTransformer(model_name, trust_remote_code=True)
model_ = AutoModel.from_pretrained(model_name, trust_remote_code=True).cuda()
tokenizer_ = AutoTokenizer.from_pretrained(model_name)
# load data
# `all_mis_texts` is the list of misconception names from the mapping file; it is later
# used as the retrieval corpus in make_ir_evaluator_dataset.
all_mis_texts = pd.read_csv("data/misconception_mapping.csv")["MisconceptionName"].tolist()
df = pd.read_csv("data/eedi-paraphrased/train.csv")
# Flatten the question fields into a single string that is used as the query/anchor text.
# String `+` propagates NaN, so a missing field makes the whole QuestionComplete NaN.
df["QuestionComplete"] = (
    "Subject: "
    + df["SubjectName"]
    + ". Construct: "
    + df["ConstructName"]
    + ". Question: "
    + df["QuestionText"]
    + ". Correct answer: "
    + df["CorrectText"]
    + ". Wrong answer: "
    + df["WrongText"]
    + "."
)
# Group-aware 70/30 split: rows sharing a QuestionId never straddle train and validation.
gss = GroupShuffleSplit(n_splits=1, train_size=0.7, random_state=42)
train_idx, val_idx = next(gss.split(df, groups=df["QuestionId"]))
df_train = df.iloc[train_idx]
df_val = df.iloc[val_idx]
# Validate only on rows where neither flag is set (assumes both columns are boolean).
df_val = df_val[~df_val["QuestionAiCreated"] & ~df_val["MisconceptionAiCreated"]].reset_index(drop=True)

In [44]:
# Unique (MisconceptionId, QuestionComplete) pairs from the TRAIN split only.
# These rows are the questions (anchors) passed to hard-negative mining below.
df_q = (
    df_train[["MisconceptionId", "QuestionComplete"]]
    .sort_values("MisconceptionId")
    .drop_duplicates()
    .reset_index(drop=True)
)
df_q

Unnamed: 0,MisconceptionId,QuestionComplete
0,0,Subject: Angles in Triangles. Construct: Find ...
1,0,Subject: Angles in Triangles. Construct: Find ...
2,0,Subject: Angles in Triangles. Construct: Find ...
3,0,Subject: Angles in Triangles. Construct: Find ...
4,0,Subject: Angles in Triangles. Construct: Find ...
...,...,...
15268,2586,Subject: Rearranging Formula and Equations. Co...
15269,2586,Subject: Rearranging Formula and Equations. Co...
15270,2586,Subject: Rearranging Formula and Equations. Co...
15271,2586,Subject: Rearranging Formula and Equations. Co...


In [45]:
# NOTE: this is not a bug, we look for misconceptions from the full dataset, not train dataset
# Unique (MisconceptionId, MisconceptionText) pairs taken from the full frame `df`.
# (Previously this read from `temp`, a name that is never defined at notebook scope,
# so the cell failed with NameError after Restart Kernel -> Run All.)
df_m = (
    df[["MisconceptionId", "MisconceptionText"]]
    .sort_values("MisconceptionId")
    .drop_duplicates()
    .reset_index(drop=True)
)
df_m

Unnamed: 0,MisconceptionId,MisconceptionText
0,0,Unaware that the total of angles in a triangle...
1,0,Lacks knowledge that the angles within a trian...
2,0,Is not aware that the sum of angles in a trian...
3,0,Doesn't understand that the angles inside a tr...
4,0,Does not know that angles in a triangle sum to...
...,...,...
8015,2586,Misinterprets the rules governing the sequence...
8016,2586,Does not correctly understand how to prioritiz...
8017,2586,Misunderstands order of operations in algebrai...
8018,2586,Fails to grasp the correct sequence for perfor...


In [5]:
@torch.inference_mode()
def batched_inference(model, tokenizer, texts: list[str], bs: int, desc: str) -> Tensor:
    """Embed `texts` in mini-batches and return the stacked embeddings on CPU.

    A leaner stand-in for SentenceTransformer.encode: each batch is tokenized, moved to
    the GPU, run through `model`, and its result is moved back to CPU immediately so
    only one batch of activations lives on the GPU at a time.

    Args:
        model: Model whose output exposes `last_hidden_state`.
        tokenizer: Tokenizer matching `model`.
        texts (list[str]): Texts to embed.
        bs (int): Batch size.
        desc (str): Label shown on the progress bar.

    Returns:
        Tensor: One L2-normalised CLS embedding per input text, in input order.
    """
    chunks = []
    batch_starts = range(0, len(texts), bs)
    for start in tqdm(batch_starts, desc=desc):
        # max_length=256 was picked from the length distribution of the complete
        # question text: 256 tokens cover 99% of it.
        batch = tokenizer(
            texts[start : start + bs],
            max_length=256,
            padding=True,
            truncation=True,
            return_tensors="pt",
        )
        batch = batch.to("cuda")
        hidden = model(**batch).last_hidden_state
        cls_vectors = hidden[:, 0]  # first position = CLS token
        chunks.append(F.normalize(cls_vectors, p=2, dim=-1).cpu())
    return torch.cat(chunks)

In [6]:
def hn_mine_hf(
    model,
    tokenizer,
    q_texts: list[str],
    q_mis_ids: list[int],
    mis_texts: list[str],
    mis_ids: list[int],
    k: int,
    bs: int,
) -> list[list[int]]:
    """Mine hard negative misconceptions per question using a raw HF model + tokenizer.

    Unlike sentence_transformers.util.mine_hard_negatives
    (https://www.sbert.net/docs/package_reference/util.html#sentence_transformers.util.mine_hard_negatives),
    which treats every other row as a negative, this filters by misconception id:
    with paraphrased data several rows can describe the same misconception, and those
    must not be returned as negatives.

    Args:
        model: HF model passed to `batched_inference`.
        tokenizer: Tokenizer passed to `batched_inference`.
        q_texts (list[str]): Question texts.
        q_mis_ids (list[int]): Ground truth misconception id of each question.
        mis_texts (list[str]): Candidate misconception texts.
        mis_ids (list[int]): Misconception id of each candidate text.
        k (int): Number of nearest candidates retrieved per question, before filtering.
        bs (int): Batch size used for embedding.

    Returns:
        list[list[int]]:
            For each question, the retrieved candidates whose misconception id differs
            from the question's own, in retrieval order. Each list holds at most `k`
            entries and may hold fewer. Entries are positions into `mis_texts` /
            `mis_ids`, NOT misconception ids.
    """
    assert len(q_texts) == len(q_mis_ids)
    assert len(mis_texts) == len(mis_ids)

    # Index every candidate misconception under its list position.
    mis_vectors = batched_inference(
        model, tokenizer, mis_texts, bs=bs, desc="miscon"
    ).numpy()
    ann = Index(ndim=mis_vectors.shape[-1], metric="ip")
    ann.add(np.arange(mis_vectors.shape[0]), mis_vectors)

    # Retrieve the k nearest candidates for every question.
    q_vectors = batched_inference(
        model, tokenizer, q_texts, bs=bs, desc="questions"
    ).numpy()
    neighbours = ann.search(q_vectors, count=k)

    # Keep only candidates that belong to a different misconception than the question.
    hards = [
        [m.key for m in matches if mis_ids[m.key] != q_mis_ids[row]]
        for row, matches in enumerate(neighbours)  # type: ignore
    ]
    assert len(hards) == len(q_texts)
    return hards


In [8]:
def hn_mine_st(
    model: SentenceTransformer,
    q_texts: list[str],
    q_mis_ids: list[int],
    mis_texts: list[str],
    mis_ids: list[int],
    k: int,
    bs: int,
) -> list[list[int]]:
    """Mine hard negative misconceptions per question using a SentenceTransformer.

    Unlike sentence_transformers.util.mine_hard_negatives
    (https://www.sbert.net/docs/package_reference/util.html#sentence_transformers.util.mine_hard_negatives),
    which treats every other row as a negative, this filters by misconception id:
    with paraphrased data several rows can describe the same misconception, and those
    must not be returned as negatives.

    Args:
        model (SentenceTransformer): Model used to embed both questions and candidates.
        q_texts (list[str]): Question texts.
        q_mis_ids (list[int]): Ground truth misconception id of each question.
        mis_texts (list[str]): Candidate misconception texts.
        mis_ids (list[int]): Misconception id of each candidate text.
        k (int): Number of nearest candidates retrieved per question, before filtering.
        bs (int): Batch size used for embedding.

    Returns:
        list[list[int]]:
            For each question, the retrieved candidates whose misconception id differs
            from the question's own, in retrieval order. Each list holds at most `k`
            entries and may hold fewer. Entries are positions into `mis_texts` /
            `mis_ids`, NOT misconception ids.
    """
    assert len(q_texts) == len(q_mis_ids)
    assert len(mis_texts) == len(mis_ids)

    # Shared encode settings; embeddings are normalised before the inner-product search.
    encode_kwargs = dict(
        batch_size=bs,
        normalize_embeddings=True,
        show_progress_bar=True,
        device="cuda",
    )

    # Index every candidate misconception under its list position.
    mis_vectors = model.encode(mis_texts, **encode_kwargs)
    ann = Index(ndim=mis_vectors.shape[-1], metric="ip")
    ann.add(np.arange(mis_vectors.shape[0]), mis_vectors)

    # Retrieve the k nearest candidates for every question.
    q_vectors = model.encode(q_texts, **encode_kwargs)
    neighbours = ann.search(q_vectors, count=k)

    # Keep only candidates that belong to a different misconception than the question.
    hards = [
        [m.key for m in matches if mis_ids[m.key] != q_mis_ids[row]]
        for row, matches in enumerate(neighbours)  # type: ignore
    ]
    assert len(hards) == len(q_texts)
    return hards

In [None]:
# Hard-negative mining with the raw HF model (`model_` / `tokenizer_`).
# Questions come from the train split (df_q); candidates come from df_m.
# Retrieves the 100 nearest candidates per question, embedding in batches of 16.
hards = hn_mine_hf(
    model_,
    tokenizer_,
    q_texts=df_q["QuestionComplete"].tolist(),
    q_mis_ids=df_q["MisconceptionId"].tolist(),
    mis_texts=df_m["MisconceptionText"].tolist(),
    mis_ids=df_m["MisconceptionId"].tolist(),
    k=100,
    bs=16,
)

In [11]:
# Same mining as above, but through the SentenceTransformer wrapper (`model`).
# `hards_st` is the result the dataset builders below consume.
# Batch size is 4 here versus 16 for the HF variant.
hards_st = hn_mine_st(
    model,
    q_texts=df_q["QuestionComplete"].tolist(),
    q_mis_ids=df_q["MisconceptionId"].tolist(),
    mis_texts=df_m["MisconceptionText"].tolist(),
    mis_ids=df_m["MisconceptionId"].tolist(),
    k=100,
    bs=4,
)

Batches:   0%|          | 0/2005 [00:00<?, ?it/s]

Batches:   0%|          | 0/3824 [00:00<?, ?it/s]

In [13]:
def make_mnr_dataset(
    q_texts: list[str],
    q_mis_ids: list[int],
    mis_texts: list[str],
    mis_ids: list[int],
    hards: list[list[int]],
    n_negatives: int,
    seed: "int | None" = None,
) -> HFDataset:
    """Create SentenceTransformer dataset suitable for MultipleNegativesRankingLoss.
    The format is (anchor, positive, negative_1, …, negative_n).
    Example: https://huggingface.co/datasets/tomaarsen/gooaq-hard-negatives

    Args:
        q_texts (list[str]): Question texts (anchors).
        q_mis_ids (list[int]): Ground truth misconception id of each question.
        mis_texts (list[str]): Misconception texts.
        mis_ids (list[int]): Misconception id of each entry in `mis_texts`.
        hards (list[list[int]]): Per question, positions into `mis_texts` that are
            usable as hard negatives (e.g. the output of `hn_mine_st`).
        n_negatives (int): Number of negatives sampled per question.
        seed (int | None): If given, sampling uses a private `random.Random(seed)` so
            the dataset is reproducible. If None (default), the global `random` state
            is used, exactly as before.

    Returns:
        HFDataset: One row per question with columns
            `q`, `mis`, `neg_1`, …, `neg_{n_negatives}`.

    Raises:
        AssertionError: If input lengths disagree, or some question has fewer than
            `n_negatives` hard candidates.
        IndexError: If a question's misconception id has no entry in `mis_ids`
            (there is no positive to choose from).
    """
    assert len(q_texts) == len(q_mis_ids) == len(hards)
    assert len(mis_texts) == len(mis_ids)
    assert all(n_negatives <= len(hard) for hard in hards)
    # Both the `random` module and a `random.Random` instance expose choice()/sample().
    rng = random if seed is None else random.Random(seed)
    # create reverse mapping: misconception id -> every position holding a text for it
    mis_id_to_mis_idx = defaultdict(list)
    for i, mis_id in enumerate(mis_ids):
        mis_id_to_mis_idx[mis_id].append(i)
    # make hf dataset (column order matters for the loss: anchor, positive, negatives)
    d = {"q": [], "mis": []}
    for i in range(1, n_negatives + 1):
        d[f"neg_{i}"] = []
    for i, (q_text, q_mis_id) in enumerate(zip(q_texts, q_mis_ids)):
        # The positive is one random paraphrase of the question's own misconception.
        rand_pos = rng.choice(mis_id_to_mis_idx[q_mis_id])
        rand_negs = rng.sample(hards[i], k=n_negatives)
        d["q"].append(q_text)
        d["mis"].append(mis_texts[rand_pos])
        for j, rand_neg in enumerate(rand_negs, 1):
            d[f"neg_{j}"].append(mis_texts[rand_neg])
    return HFDataset.from_dict(d)


# Build the (anchor, positive, 10 negatives) dataset from the SentenceTransformer-mined
# hard negatives. Positives and negatives are sampled randomly, so rows differ between
# runs unless the `random` module is seeded beforehand.
ds = make_mnr_dataset(
    q_texts=df_q["QuestionComplete"].tolist(),
    q_mis_ids=df_q["MisconceptionId"].tolist(),
    mis_texts=df_m["MisconceptionText"].tolist(),
    mis_ids=df_m["MisconceptionId"].tolist(),
    hards=hards_st,
    n_negatives=10,
)
ds

Dataset({
    features: ['q', 'mis', 'neg_1', 'neg_2', 'neg_3', 'neg_4', 'neg_5', 'neg_6', 'neg_7', 'neg_8', 'neg_9', 'neg_10'],
    num_rows: 15293
})

In [14]:
def make_cosent_dataset(
    q_texts: list[str],
    q_mis_ids: list[int],
    mis_texts: list[str],
    mis_ids: list[int],
    hards: list[list[int]],
    n_negatives: int,
    seed: "int | None" = None,
) -> HFDataset:
    """Create SentenceTransformer dataset suitable for CoSENTLoss.
    The format is (sentence_A, sentence_B, label): every question contributes one
    positive pair labelled 1.0 followed by `n_negatives` negative pairs labelled -1.0.
    Example: https://sbert.net/docs/sentence_transformer/training_overview.html#loss-function

    Args:
        q_texts (list[str]): Question texts (sentence_A).
        q_mis_ids (list[int]): Ground truth misconception id of each question.
        mis_texts (list[str]): Misconception texts (sentence_B candidates).
        mis_ids (list[int]): Misconception id of each entry in `mis_texts`.
        hards (list[list[int]]): Per question, positions into `mis_texts` that are
            usable as hard negatives (e.g. the output of `hn_mine_st`).
        n_negatives (int): Number of negative pairs sampled per question.
        seed (int | None): If given, sampling uses a private `random.Random(seed)` so
            the dataset is reproducible. If None (default), the global `random` state
            is used, exactly as before.

    Returns:
        HFDataset: `len(q_texts) * (1 + n_negatives)` rows with columns
            `q`, `mis`, `label`.

    Raises:
        AssertionError: If input lengths disagree, or some question has fewer than
            `n_negatives` hard candidates.
        IndexError: If a question's misconception id has no entry in `mis_ids`
            (there is no positive to choose from).
    """
    assert len(q_texts) == len(q_mis_ids) == len(hards)
    assert len(mis_texts) == len(mis_ids)
    assert all(n_negatives <= len(hard) for hard in hards)
    # Both the `random` module and a `random.Random` instance expose choice()/sample().
    rng = random if seed is None else random.Random(seed)
    # create reverse mapping: misconception id -> every position holding a text for it
    mis_id_to_mis_idx = defaultdict(list)
    for i, mis_id in enumerate(mis_ids):
        mis_id_to_mis_idx[mis_id].append(i)
    # make hf dataset
    d = {"q": [], "mis": [], "label": []}
    for i, (q_text, q_mis_id) in enumerate(zip(q_texts, q_mis_ids)):
        # insert positive: one random paraphrase of the question's own misconception
        rand_pos = rng.choice(mis_id_to_mis_idx[q_mis_id])
        d["q"].append(q_text)
        d["mis"].append(mis_texts[rand_pos])
        d["label"].append(1.0)
        # insert negatives
        for rand_neg in rng.sample(hards[i], k=n_negatives):
            d["q"].append(q_text)
            d["mis"].append(mis_texts[rand_neg])
            d["label"].append(-1.0)
    return HFDataset.from_dict(d)


# Build the labelled pair dataset (q, mis, label) for CoSENTLoss from the same mined
# negatives: 1 positive + 10 negatives per question.
# NOTE: because of the `label` column this dataset is NOT interchangeable with `ds`.
ds2 = make_cosent_dataset(
    q_texts=df_q["QuestionComplete"].tolist(),
    q_mis_ids=df_q["MisconceptionId"].tolist(),
    mis_texts=df_m["MisconceptionText"].tolist(),
    mis_ids=df_m["MisconceptionId"].tolist(),
    hards=hards_st,
    n_negatives=10,
)
ds2

Dataset({
    features: ['q', 'mis', 'label'],
    num_rows: 168223
})

# Raw training — TODO: move this to a script and cache the hard-negative mining, because we will run that script a LOT!

In [None]:
from datasets import load_dataset
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
    SentenceTransformerModelCardData,
)
from sentence_transformers.losses import MultipleNegativesRankingLoss
from sentence_transformers.training_args import BatchSamplers
from sentence_transformers.evaluation import TripletEvaluator

# 1. Load a model to finetune with 2. (Optional) model card data
# -- done above: `model` is the SentenceTransformer loaded at the top of the notebook.

# 3. Load a dataset to finetune on
# -- done above: `ds` (anchor, positive, neg_1..neg_n) built by make_mnr_dataset.
#    Do NOT pass `ds2` here: it carries a `label` column for CoSENTLoss, and
#    MultipleNegativesRankingLoss would treat its label=-1.0 pairs as positives.

# 4. Define a loss function
loss = MultipleNegativesRankingLoss(model)

# 5. (Optional) Specify training arguments
args = SentenceTransformerTrainingArguments(
    # Required parameter:
    output_dir="models/gte-large-en-1",
    # Optional training parameters:
    num_train_epochs=1,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    learning_rate=2e-5,
    warmup_ratio=0.1,
    fp16=True,  # Set to False if you get an error that your GPU can't run on FP16
    bf16=False,  # Set to True if you have a GPU that supports BF16
    batch_sampler=BatchSamplers.NO_DUPLICATES,  # MultipleNegativesRankingLoss benefits from no duplicate samples in a batch
    # Optional tracking/debugging parameters:
    # No eval dataset/evaluator exists at this point in the notebook, so in-training
    # evaluation is disabled; the trainer rejects eval_strategy="steps" without one.
    eval_strategy="no",
    save_strategy="steps",
    save_steps=100,
    save_total_limit=2,
    logging_steps=100,
    run_name="gte-large-en-1",  # Will be used in W&B if `wandb` is installed
)

# 6. (Optional) Create an evaluator & evaluate the base model
# TODO: pass the InformationRetrievalEvaluator built from df_val (see the cells further
# down) as `evaluator=` and switch eval_strategy back to "steps" once it is defined
# before this cell.

# 7. Create a trainer & train
trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=ds,
    loss=loss,
)
trainer.train()

# 8. Save the trained model (kept next to the checkpoints written to output_dir)
model.save_pretrained("models/gte-large-en-1/final")

# 9. (Optional) Push it to the Hugging Face Hub (requires being logged in)
model.push_to_hub("gte-large-en-1")

In [None]:


# Sanity check of InformationRetrievalEvaluator on a public dataset, using a small
# baseline model. `modelmini` is reused further down for the misconception evaluator.
# Load a model
modelmini = SentenceTransformer('all-MiniLM-L6-v2')

# Load the Touche-2020 IR dataset (https://huggingface.co/datasets/BeIR/webis-touche2020, https://huggingface.co/datasets/BeIR/webis-touche2020-qrels)
corpus = load_dataset("BeIR/webis-touche2020", "corpus", split="corpus")
queries = load_dataset("BeIR/webis-touche2020", "queries", split="queries")
relevant_docs_data = load_dataset("BeIR/webis-touche2020-qrels", split="test")

# For this dataset, we want to concatenate the title and texts for the corpus
corpus = corpus.map(lambda x: {'text': x['title'] + " " + x['text']}, remove_columns=['title'])

# Shrink the corpus size heavily to only the relevant documents + 30,000 random documents
# NOTE: the random sample is unseeded in this cell, so the corpus subset (and therefore
# the metrics printed below) changes from run to run.
required_corpus_ids = set(map(str, relevant_docs_data["corpus-id"]))
required_corpus_ids |= set(random.sample(corpus["_id"], k=30_000))
corpus = corpus.filter(lambda x: x["_id"] in required_corpus_ids)

# Convert the datasets to dictionaries
corpus = dict(zip(corpus["_id"], corpus["text"]))  # Our corpus (cid => document)
queries = dict(zip(queries["_id"], queries["text"]))  # Our queries (qid => question)
relevant_docs = {}  # Query ID to relevant documents (qid => set([relevant_cids])
# Each qrels row is one (query id, corpus id) pair; ids are cast to str before use as keys.
for qid, corpus_ids in zip(relevant_docs_data["query-id"], relevant_docs_data["corpus-id"]):
    qid = str(qid)
    corpus_ids = str(corpus_ids)
    if qid not in relevant_docs:
        relevant_docs[qid] = set()
    relevant_docs[qid].add(corpus_ids)

# Given queries, a corpus and a mapping with relevant documents, the InformationRetrievalEvaluator computes different IR metrics.
ir_evaluator = InformationRetrievalEvaluator(
    queries=queries,
    corpus=corpus,
    relevant_docs=relevant_docs,
    name="BeIR-touche2020-subset-test",
)
results = ir_evaluator(modelmini)
# The string below is a bare string statement kept as a reference of expected output;
# it is not displayed and has no effect at runtime.
'''
Information Retrieval Evaluation of the model on the BeIR-touche2020-test dataset:
Queries: 49
Corpus: 31923

Score-Function: cosine
Accuracy@1: 77.55%
Accuracy@3: 93.88%
Accuracy@5: 97.96%
Accuracy@10: 100.00%
Precision@1: 77.55%
Precision@3: 72.11%
Precision@5: 71.43%
Precision@10: 62.65%
Recall@1: 1.72%
Recall@3: 4.78%
Recall@5: 7.90%
Recall@10: 13.86%
MRR@10: 0.8580
NDCG@10: 0.6606
MAP@100: 0.2934
'''
print(ir_evaluator.primary_metric)
# => "BeIR-touche2020-test_cosine_map@100"
print(results[ir_evaluator.primary_metric])
# => 0.29335196224364596

0000.parquet:   0%|          | 0.00/268M [00:00<?, ?B/s]

0001.parquet:   0%|          | 0.00/95.0M [00:00<?, ?B/s]

Generating corpus split:   0%|          | 0/382545 [00:00<?, ? examples/s]

queries/queries/0000.parquet:   0%|          | 0.00/3.72k [00:00<?, ?B/s]

Generating queries split:   0%|          | 0/49 [00:00<?, ? examples/s]

README.md:   0%|          | 0.00/14.0k [00:00<?, ?B/s]

test.tsv:   0%|          | 0.00/101k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/2214 [00:00<?, ? examples/s]

Map:   0%|          | 0/382545 [00:00<?, ? examples/s]

Filter:   0%|          | 0/382545 [00:00<?, ? examples/s]

BeIR-touche2020-subset-test_cosine_map@100
0.30561077016600136


In [None]:
def make_ir_evaluator_dataset(
    df: pd.DataFrame, all_mis_texts: list[str]
) -> tuple[dict, dict, dict]:
    """Build the three dictionaries consumed by InformationRetrievalEvaluator.

    Args:
        df (pd.DataFrame): Frame with the columns `QuestionId_Answer`,
            `QuestionComplete`, `MisconceptionId` and `MisconceptionText`.
        all_mis_texts (list[str]): Every misconception text; the list position is used
            as the corpus id (assumed to coincide with `MisconceptionId` - confirm
            against the mapping file).

    Returns:
        tuple[dict, dict, dict]:
            * queries: `QuestionId_Answer` -> `QuestionComplete`
            * corpus: list position -> misconception text
            * relevant docs: `QuestionId_Answer` -> one-element list `[MisconceptionId]`

            If a `QuestionId_Answer` occurs on several deduplicated rows, the last
            row wins in both question-keyed dictionaries.
    """
    needed_columns = [
        "QuestionId_Answer",
        "QuestionComplete",
        "MisconceptionId",
        "MisconceptionText",
    ]
    unique_rows = df[needed_columns].drop_duplicates()
    by_question = unique_rows.set_index("QuestionId_Answer")

    queries = by_question["QuestionComplete"].to_dict()
    relevant = {
        question_id: [mis_id]
        for question_id, mis_id in by_question["MisconceptionId"].items()
    }
    corpus = dict(enumerate(all_mis_texts))
    return queries, corpus, relevant


# Queries come from the validation split; the corpus is the full misconception list.
q, mis, mapping = make_ir_evaluator_dataset(df_val, all_mis_texts)

In [None]:

# TODO: this evaluation takes a long time - look into speeding it up.
# * the misconception corpus must come from the FULL misconception list (`mis`)
# * but the ... (original note is unfinished; presumably the queries come only from
#   the validation split - confirm)
evaluator = InformationRetrievalEvaluator(
    queries=q,
    corpus=mis,
    relevant_docs=mapping,
    map_at_k=[1, 3, 5, 10, 25],
    batch_size=4,
    show_progress_bar=True,
)

In [110]:
# Baseline: scores the small MiniLM model (`modelmini`), not the gte model loaded as `model`.
# This rebinds `results`, replacing the BeIR metrics computed earlier.
results = evaluator(modelmini)

Batches:   0%|          | 0/329 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:01<00:00,  1.70s/it]


In [85]:
# mis

In [111]:
# (number of validation queries, number of misconceptions in the corpus)
len(q), len(mis)

(1313, 2587)

In [112]:
# Metrics returned by `evaluator(modelmini)` above, keyed by metric name.
results

{'cosine_accuracy@1': 0.09367859862909368,
 'cosine_accuracy@3': 0.2063975628332064,
 'cosine_accuracy@5': 0.2802741812642803,
 'cosine_accuracy@10': 0.40594059405940597,
 'cosine_precision@1': 0.09367859862909368,
 'cosine_precision@3': 0.06879918761106879,
 'cosine_precision@5': 0.05605483625285606,
 'cosine_precision@10': 0.040594059405940595,
 'cosine_recall@1': 0.09367859862909368,
 'cosine_recall@3': 0.2063975628332064,
 'cosine_recall@5': 0.2802741812642803,
 'cosine_recall@10': 0.40594059405940597,
 'cosine_ndcg@10': 0.22940932083812754,
 'cosine_mrr@10': 0.17549112054062532,
 'cosine_map@1': 0.09367859862909368,
 'cosine_map@3': 0.1419141914191419,
 'cosine_map@5': 0.15866971312515865,
 'cosine_map@10': 0.1754911205406255,
 'cosine_map@25': 0.18624516297058172,
 'dot_accuracy@1': 0.09367859862909368,
 'dot_accuracy@3': 0.2063975628332064,
 'dot_accuracy@5': 0.2802741812642803,
 'dot_accuracy@10': 0.40594059405940597,
 'dot_precision@1': 0.09367859862909368,
 'dot_precision@3':