In [1]:
from argparse import ArgumentParser
from dataclasses import dataclass
from pathlib import Path

import pandas as pd
from transformers import AutoModel, AutoTokenizer
from torch import nn
import numpy as np
from torch import optim
import torch
from tqdm.auto import tqdm
import time

import torch.nn.functional as F
from torch import Tensor
from usearch.index import Index
import string
import random

In [2]:
# Encoder shared by questions and misconceptions, plus its matching tokenizer.
# NOTE(review): trust_remote_code=True executes code shipped with the model repo;
# consider pinning `revision=` for reproducibility and safety.
model_name = "Alibaba-NLP/gte-large-en-v1.5"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name, trust_remote_code=True).to("cuda")

# Training data (the directory name suggests a paraphrase-augmented Eedi set).
df = pd.read_csv(Path("data/eedi-paraphrased/train.csv"))


In [3]:
# Deduplicate the question/answer/misconception rows, then render each one as a
# single prompt string. String `+` propagates NaN exactly like the column-wise
# concatenation it replaces.
df_q = (
    df.loc[
        :,
        [
            "ConstructName",
            "SubjectName",
            "QuestionId_Answer",
            "QuestionText",
            "WrongText",
            "CorrectText",
            "MisconceptionId",
        ],
    ]
    .drop_duplicates()
    .assign(
        QuestionComplete=lambda rows: (
            "Subject: "
            + rows["SubjectName"]
            + ". Construct: "
            + rows["ConstructName"]
            + ". Question: "
            + rows["QuestionText"]
            + ". Correct answer: "
            + rows["CorrectText"]
            + ". Wrong answer: "
            + rows["WrongText"]
            + "."
        )
    )
    .loc[:, ["MisconceptionId", "QuestionComplete"]]
    .sort_values("MisconceptionId")
    .reset_index(drop=True)
)
df_q

Unnamed: 0,MisconceptionId,QuestionComplete
0,0,Subject: Angles in Triangles. Construct: Find ...
1,0,Subject: Angles in Triangles. Construct: Find ...
2,0,Subject: Angles in Triangles. Construct: Find ...
3,0,Subject: Angles in Triangles. Construct: Find ...
4,0,Subject: Angles in Triangles. Construct: Find ...
...,...,...
21853,2586,Subject: Rearranging Formula and Equations. Co...
21854,2586,Subject: Rearranging Formula and Equations. Co...
21855,2586,Subject: Rearranging Formula and Equations. Co...
21856,2586,Subject: Rearranging Formula and Equations. Co...


In [4]:
# Distinct (id, text) misconception rows; one id can carry several texts.
df_m = (
    df.loc[:, ["MisconceptionId", "MisconceptionText"]]
    .sort_values(by="MisconceptionId")
    .drop_duplicates()
    .reset_index(drop=True)
)
df_m

Unnamed: 0,MisconceptionId,MisconceptionText
0,0,Unaware that the total of angles in a triangle...
1,0,Lacks knowledge that the angles within a trian...
2,0,Is not aware that the sum of angles in a trian...
3,0,Doesn't understand that the angles inside a tr...
4,0,Does not know that angles in a triangle sum to...
...,...,...
8015,2586,Misinterprets the rules governing the sequence...
8016,2586,Does not correctly understand how to prioritiz...
8017,2586,Misunderstands order of operations in algebrai...
8018,2586,Fails to grasp the correct sequence for perfor...


In [5]:
@torch.inference_mode()
def batched_inference(
    model, texts: list[str], bs: int, desc: str, max_length: int = 256
) -> Tensor:
    """Encode ``texts`` in mini-batches and return L2-normalised CLS embeddings.

    Tokenisation uses the notebook-level ``tokenizer`` global, which must match
    ``model``. Runs under ``torch.inference_mode``, so the result cannot take
    part in autograd.

    Args:
        model: Encoder called as ``model(**encoded)``. Its output must expose
            ``last_hidden_state``. Tokenised inputs are moved to the device of
            the model's first parameter, so the model may live on CPU or GPU.
        texts: Texts to embed, in order.
        bs: Number of texts per forward pass.
        desc: Label shown on the tqdm progress bar.
        max_length: Per-text token limit; longer texts are truncated. The
            default of 256 was chosen by the author because it covers ~99% of
            the complete question texts (from a length plot).

    Returns:
        A CPU tensor with one row per input text, in input order. Each row is
        the unit-norm embedding of the first (CLS) token.

    Raises:
        ValueError: If ``texts`` is empty; there would be nothing to stack.
    """
    if not texts:
        raise ValueError("batched_inference() needs at least one text")
    # Follow wherever the model actually lives instead of hard-coding "cuda".
    device = next(model.parameters()).device
    chunks = []
    for start in tqdm(range(0, len(texts), bs), desc=desc):
        encoded = tokenizer(
            texts[start : start + bs],
            max_length=max_length,
            padding=True,  # pad to the longest text in this batch only
            truncation=True,
            return_tensors="pt",
        ).to(device)
        hidden = model(**encoded).last_hidden_state
        cls = F.normalize(hidden[:, 0], p=2, dim=-1)  # CLS pooling + L2 norm
        chunks.append(cls.cpu())
    return torch.cat(chunks)

In [None]:
def hn_mine(
    q_texts: list[str],
    q_mis_ids: list[int],
    mis_texts: list[str],
    mis_ids: list[int],
    k: int,
    bs: int,
) -> list[list[int]]:
    """Hard negative mining.

    Embeds all misconception texts and question texts with the notebook-level
    ``model`` (through ``batched_inference``). It indexes the misconception
    embeddings in a usearch inner-product index and retrieves the ``k`` nearest
    misconception rows for every question. Rows whose misconception id equals
    the question's ground-truth id are then dropped.

    Args:
        q_texts (list[str]): Question texts.
        q_mis_ids (list[int]): Ground truth misconception ids for the questions.
        mis_texts (list[str]): Misconception texts.
        mis_ids (list[int]): Misconception id of each entry in ``mis_texts``.
        k (int): Number of nearest misconception rows retrieved per question
            *before* filtering. Each returned list therefore has at most ``k``
            entries, and fewer when ground-truth rows are among the matches.
        bs (int): Batch size used for embedding.

    Returns:
        list[list[int]]: One list per question, in ``q_texts`` order. Each list
        holds positions into ``mis_texts`` / ``mis_ids``, in the order returned
        by the index (presumably nearest first; confirm against usearch docs).
        These are row positions, NOT misconception ids; map them with
        ``mis_ids[pos]``. Several positions may share one misconception id when
        ``mis_texts`` contains multiple texts per id.

    Raises:
        ValueError: If ``k`` is not positive.
    """
    if k <= 0:
        raise ValueError(f"k must be positive, got {k}")
    assert len(q_texts) == len(q_mis_ids)
    assert len(mis_texts) == len(mis_ids)
    m_embeds = batched_inference(model, mis_texts, bs=bs, desc="miscon").numpy()
    # Keys are row positions 0..N-1, so search hits index straight into mis_ids.
    index = Index(ndim=m_embeds.shape[-1], metric="ip")
    index.add(np.arange(m_embeds.shape[0]), m_embeds)
    q_embeds = batched_inference(model, q_texts, bs=bs, desc="questions").numpy()
    # NOTE(review): a single query may come back as a non-batched result
    # object in some usearch versions; confirm if len(q_texts) == 1 is possible.
    batch_matches = index.search(q_embeds, count=k)
    hards = []
    for gt_mis_id, matches in zip(q_mis_ids, batch_matches):  # type: ignore
        # usearch keys are numpy unsigned ints; cast so the result really is
        # list[int] (np.uint64 + int silently becomes float64 on NumPy < 2).
        nth_miscons = [int(m.key) for m in matches]
        hards.append([nth for nth in nth_miscons if mis_ids[nth] != gt_mis_id])
    assert len(hards) == len(q_texts)
    return hards


In [7]:
# Mine up to 100 hard negatives per question; entries are row positions into df_m.
hards = hn_mine(
    q_texts=df_q.QuestionComplete.to_list(),
    q_mis_ids=df_q.MisconceptionId.to_list(),
    mis_texts=df_m.MisconceptionText.to_list(),
    mis_ids=df_m.MisconceptionId.to_list(),
    k=100,
    bs=32,
)

miscon:   0%|          | 0/251 [00:00<?, ?it/s]

questions:   0%|          | 0/684 [00:00<?, ?it/s]

searching q=21858 mis=8020 took 0.51 secs


In [8]:
# TODO (unfinished resource notes): RAM usage 5 GB -> ?
# TODO: record GPU memory usage