In [45]:
from argparse import ArgumentParser
from dataclasses import dataclass
from pathlib import Path
from collections import defaultdict

import pandas as pd
from transformers import AutoModel, AutoTokenizer
from torch import nn
import numpy as np
from torch import optim
import torch
from tqdm.auto import tqdm
import time

from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from torch import Tensor
from usearch.index import Index
import string
from sklearn.model_selection import GroupShuffleSplit
import random

In [None]:
# Seed every RNG up front: the training dataset samples positives/negatives via
# `random`, and numpy/torch randomness may be used downstream (e.g. model init
# under trust_remote_code, training). One constant keeps runs reproducible.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# load model
model_name = "Alibaba-NLP/gte-large-en-v1.5"
model = AutoModel.from_pretrained(model_name, trust_remote_code=True).cuda()
tokenizer = AutoTokenizer.from_pretrained(model_name)

# load data; split by QuestionId so no question appears in both train and val
df = pd.read_csv("data/eedi-paraphrased/train.csv")
gss = GroupShuffleSplit(n_splits=1, train_size=0.7, random_state=SEED)
train_idx, val_idx = next(gss.split(df, groups=df["QuestionId"]))
df_train = df.iloc[train_idx]
df_val = df.iloc[val_idx]
# validate only on rows whose question AND misconception are not AI-created
df_val = df_val[~df_val["QuestionAiCreated"] & ~df_val["MisconceptionAiCreated"]].reset_index(drop=True)

In [39]:
# One row per distinct (question, answer pair, misconception) in the training
# split, with all question fields flattened into a single labelled string that
# serves as the retrieval query text.
df_q = (
    df_train[
        [
            "ConstructName",
            "SubjectName",
            "QuestionId_Answer",
            "QuestionText",
            "WrongText",
            "CorrectText",
            "MisconceptionId",
        ]
    ]
    .drop_duplicates()
    .assign(
        QuestionComplete=lambda d: (
            "Subject: " + d["SubjectName"]
            + ". Construct: " + d["ConstructName"]
            + ". Question: " + d["QuestionText"]
            + ". Correct answer: " + d["CorrectText"]
            + ". Wrong answer: " + d["WrongText"]
            + "."
        )
    )
    .loc[:, ["MisconceptionId", "QuestionComplete"]]
    .sort_values("MisconceptionId")
    .reset_index(drop=True)
)
df_q

Unnamed: 0,MisconceptionId,QuestionComplete
0,0,Subject: Angles in Triangles. Construct: Find ...
1,0,Subject: Angles in Triangles. Construct: Find ...
2,0,Subject: Angles in Triangles. Construct: Find ...
3,0,Subject: Angles in Triangles. Construct: Find ...
4,0,Subject: Angles in Triangles. Construct: Find ...
...,...,...
15288,2586,Subject: Rearranging Formula and Equations. Co...
15289,2586,Subject: Rearranging Formula and Equations. Co...
15290,2586,Subject: Rearranging Formula and Equations. Co...
15291,2586,Subject: Rearranging Formula and Equations. Co...


In [42]:
# Misconception corpus: every distinct (id, text) pair from the full dataset,
# ordered by id and re-indexed 0..n-1 so row positions can act as keys.
df_m = (
    df.loc[:, ["MisconceptionId", "MisconceptionText"]]
    .sort_values("MisconceptionId")
    .drop_duplicates()
    .reset_index(drop=True)
)
df_m

Unnamed: 0,MisconceptionId,MisconceptionText
0,0,Unaware that the total of angles in a triangle...
1,0,Lacks knowledge that the angles within a trian...
2,0,Is not aware that the sum of angles in a trian...
3,0,Doesn't understand that the angles inside a tr...
4,0,Does not know that angles in a triangle sum to...
...,...,...
8015,2586,Misinterprets the rules governing the sequence...
8016,2586,Does not correctly understand how to prioritiz...
8017,2586,Misunderstands order of operations in algebrai...
8018,2586,Fails to grasp the correct sequence for perfor...


In [46]:
@torch.inference_mode()
def batched_inference(
    model,
    texts: list[str],
    bs: int,
    desc: str,
    tok=None,
    max_length: int = 256,
) -> Tensor:
    """Embed ``texts`` in mini-batches and return L2-normalised first-token (CLS) vectors.

    Args:
        model: Encoder whose output exposes ``last_hidden_state``. Inputs are moved
            to the device of the model's first parameter.
        texts: Texts to embed; must be non-empty (``torch.cat`` of no batches fails).
        bs: Number of texts per forward pass.
        desc: Label for the tqdm progress bar.
        tok: Tokenizer to use; defaults to the notebook-level ``tokenizer``.
        max_length: Truncation length in tokens. 256 comes from plotting the
            complete question text lengths; it covers ~99% of them.

    Returns:
        CPU tensor with one unit-norm row per text, in input order
        (assumes ``last_hidden_state`` is (batch, seq, hidden) — as for HF encoders).
    """
    if tok is None:
        tok = tokenizer
    device = next(model.parameters()).device
    # inference_mode turns off autograd but NOT dropout: force eval mode so the
    # embeddings are deterministic, then restore whatever mode the caller had.
    was_training = model.training
    model.eval()
    try:
        embeddings = []
        for start in tqdm(range(0, len(texts), bs), desc=desc):
            encoded = tok(
                texts[start : start + bs],
                max_length=max_length,
                padding=True,
                truncation=True,
                return_tensors="pt",
            ).to(device)
            hidden = model(**encoded).last_hidden_state
            emb = F.normalize(hidden[:, 0], p=2, dim=-1)  # cls token
            embeddings.append(emb.cpu())
    finally:
        model.train(was_training)
    return torch.cat(embeddings)

In [47]:
def hn_mine(
    q_texts: list[str],
    q_mis_ids: list[int],
    mis_texts: list[str],
    mis_ids: list[int],
    k: int,
    bs: int,
) -> list[list[int]]:
    """Hard negative mining.

    Embeds all misconceptions and questions with the global ``model`` (through
    ``batched_inference``), indexes the misconception embeddings in a usearch
    inner-product index, and keeps each question's nearest misconceptions whose
    id differs from the question's ground-truth id.

    Args:
        q_texts (list[str]): Question texts.
        q_mis_ids (list[int]): Ground truth misconception ids for the questions.
        mis_texts (list[str]): Misconception texts.
        mis_ids (list[int]): Misconception ids.
        k (int): Number of hard negatives to keep per question.
        bs (int): Batch size.

    Returns:
        list[list[int]]:
            Hard misconceptions for each question, in the order usearch returns
            them (nearest first). This is NOT misconception ids, but the actual
            list index into ``mis_texts`` / ``mis_ids``. Each list has up to ``k``
            entries (fewer only if the index cannot supply enough candidates).
    """
    from collections import Counter

    assert len(q_texts) == len(q_mis_ids)
    assert len(mis_texts) == len(mis_ids)
    m_embeds = batched_inference(model, mis_texts, bs=bs, desc="miscon").numpy()
    index = Index(ndim=m_embeds.shape[-1], metric="ip")
    # Keys are row positions, so a match key indexes straight into mis_ids/mis_texts.
    index.add(np.arange(m_embeds.shape[0]), m_embeds)
    q_embeds = batched_inference(model, q_texts, bs=bs, desc="questions").numpy()
    # Up to `max_paraphrases` neighbours can share the question's ground-truth id
    # and get filtered out below; over-fetch so k true negatives remain (without
    # this, a list could end up short or even empty).
    max_paraphrases = max(Counter(mis_ids).values(), default=0)
    count = min(k + max_paraphrases, len(mis_ids))
    batch_matches = index.search(q_embeds, count=count)
    hards = []
    for i, matches in enumerate(batch_matches):  # type: ignore
        # Keys may come back as numpy integers; normalise to plain ints.
        nth_miscons: list[int] = [int(m.key) for m in matches]
        hard_miscons = [nth for nth in nth_miscons if mis_ids[nth] != q_mis_ids[i]]
        hards.append(hard_miscons[:k])
    assert len(hards) == len(q_texts)
    return hards


In [48]:
# Mine hard negatives for every training question against the full misconception
# corpus; k=100 bounds the negatives per question, bs=32 texts per forward pass.
hards = hn_mine(
    k=100,
    bs=32,
    q_texts=df_q["QuestionComplete"].to_list(),
    q_mis_ids=df_q["MisconceptionId"].to_list(),
    mis_texts=df_m["MisconceptionText"].to_list(),
    mis_ids=df_m["MisconceptionId"].to_list(),
)

miscon:   0%|          | 0/251 [00:00<?, ?it/s]

questions:   0%|          | 0/478 [00:00<?, ?it/s]

In [None]:
class HardNegativeEediTrainDataset(Dataset):
    """Dataset of (question, positive misconception, hard negative misconception) triples.

    Item ``i`` pairs question ``i`` with a random paraphrase of its ground-truth
    misconception (positive) and a random entry of ``hards[i]`` (negative).
    Sampling uses the module-level ``random`` state, so accessing the same index
    twice can return different pos/neg texts.

    Args:
        q_texts: Question texts (anchors).
        q_mis_ids: Ground-truth misconception id for each question.
        mis_texts: Misconception texts; several rows may share one id (paraphrases).
        mis_ids: Misconception id of each row in ``mis_texts``.
        hards: Per question, row indices into ``mis_texts`` to use as negatives
            (e.g. the output of ``hn_mine``).
    """

    def __init__(
        self,
        q_texts: list[str],
        q_mis_ids: list[int],
        mis_texts: list[str],
        mis_ids: list[int],
        hards: list[list[int]],
    ) -> None:
        assert len(q_texts) == len(q_mis_ids) == len(hards)
        assert len(mis_texts) == len(mis_ids)
        self.q_texts = q_texts
        self.q_mis_ids = q_mis_ids
        self.mis_texts = mis_texts
        self.mis_ids = mis_ids
        self.hards = hards
        # Reverse mapping: misconception id -> row indices of all its paraphrases.
        # Frozen into a plain dict so looking up an unknown id raises KeyError
        # instead of silently inserting (and then sampling from) an empty list.
        mis_id_to_mis_idx = defaultdict(list)
        for idx, mis_id in enumerate(mis_ids):
            mis_id_to_mis_idx[mis_id].append(idx)
        self.mis_id_to_mis_idx = dict(mis_id_to_mis_idx)

    def __len__(self) -> int:
        return len(self.q_texts)

    def __getitem__(self, i: int) -> dict:
        """Return one anchor/positive/negative triple for question ``i``.

        The anchor is the question text itself, the positive a random paraphrase
        of its ground-truth misconception, the negative a random hard negative.

        Raises:
            KeyError: the question's ground-truth id has no misconception text.
            ValueError: ``hards[i]`` is empty and no misconception with another
                id exists to fall back on.
        """
        ground_truth_mis_id = self.q_mis_ids[i]
        rand_pos = random.choice(self.mis_id_to_mis_idx[ground_truth_mis_id])
        candidates = self.hards[i]
        if len(candidates) == 0:
            # Mining can leave no negatives (every retrieved neighbour shared the
            # ground-truth id); fall back to any misconception with a different id.
            candidates = [
                j for j, mis_id in enumerate(self.mis_ids) if mis_id != ground_truth_mis_id
            ]
            if not candidates:
                raise ValueError(
                    f"no negative misconception available for question {i} "
                    f"(ground truth id {ground_truth_mis_id})"
                )
        rand_neg = random.choice(candidates)
        return {
            "q": self.q_texts[i],
            "pos": self.mis_texts[rand_pos],
            "neg": self.mis_texts[rand_neg],
            "gt_mis_id": ground_truth_mis_id,
        }


class TrainCollator:
    """Collate dataset triples into tokenized batches for contrastive training.

    Each text field ("q", "pos", "neg") of the incoming items is tokenized as its
    own padded batch; the ``gt_mis_id`` values are stacked into a tensor.

    Args:
        tokenizer: Callable tokenizer (e.g. a Hugging Face tokenizer) accepting a
            list of texts plus ``max_length``/``padding``/``truncation``/``return_tensors``.
        max_length: Truncation length in tokens; 256 matches the limit used for
            embedding in ``batched_inference``.
    """

    def __init__(self, tokenizer, max_length: int = 256):
        self.tokenizer = tokenizer
        self.max_length = max_length

    def _encode(self, texts: list[str]):
        """Tokenize one field's texts, padding to the batch's longest and truncating."""
        return self.tokenizer(
            texts,
            max_length=self.max_length,
            padding=True,
            truncation=True,
            return_tensors="pt",
        )

    def __call__(self, batch: list[dict]) -> dict:
        collated = {
            field: self._encode([item[field] for item in batch])
            for field in ("q", "pos", "neg")
        }
        collated["gt_mis_id"] = torch.tensor([item["gt_mis_id"] for item in batch])
        return collated


# Wrap the training questions, the full misconception corpus and the mined
# hard negatives into the triple-sampling training dataset.
train_ds = HardNegativeEediTrainDataset(
    hards=hards,
    q_texts=df_q["QuestionComplete"].to_list(),
    q_mis_ids=df_q["MisconceptionId"].to_list(),
    mis_texts=df_m["MisconceptionText"].to_list(),
    mis_ids=df_m["MisconceptionId"].to_list(),
)
train_ds

<__main__.HardNegativeEediTrainDataset at 0x7da1b8f64210>

In [114]:
# Spot-check one training triple. pos/neg are re-sampled on every access, so
# re-running this cell can show different texts.
example_idx = 444
train_ds[example_idx]

{'q': 'Subject: Converting between Decimals and Percentages. Construct: Convert decimals less than 1 with 2 decimal place to percentages. Question: How much does 0.01 added to 57 percent equal?. Correct answer: \\( 58 \\% \\). Wrong answer: \\( 57.1 \\% \\).',
 'pos': 'Multiplies by 10 instead of 100 when converting decimal to percentage',
 'neg': "Considers that a fraction taken from a number that isn't 100 stands for a percentage.",
 'gt_mis_id': 61}

In [57]:
# Inspect the rows where neither question nor misconception is AI-created,
# restricted to misconception 4.
human_mask = ~df["QuestionAiCreated"] & ~df["MisconceptionAiCreated"]
orig = df[human_mask]
orig[orig["MisconceptionId"].eq(4)]

Unnamed: 0,QuestionId,ConstructId,ConstructName,SubjectId,SubjectName,CorrectChoice,CorrectText,QuestionText,WrongChoice,WrongText,MisconceptionId,QuestionId_Answer,MisconceptionText,QuestionAiCreated,MisconceptionAiCreated
52790,902,1273,Simplify algebraic expressions to maintain equ...,251,Simplifying Expressions by Collecting Like Terms,D,Does not simplify,"Simplify, if possible:\n\[\nz+5\n\]",B,\( z^{5} \),4,902_B,Believes addition of terms and powers of terms...,False,False
58615,998,1274,Simplify algebraic expressions to maintain equ...,251,Simplifying Expressions by Collecting Like Terms,D,Does not simplify,"Simplify, if possible:\n\[\nc+a\n\]",C,\( c^{a} \),4,998_C,Believes addition of terms and powers of terms...,False,False


In [8]:
# TODO: finish resource-usage notes (incomplete)
# RAM: ~5 GB -> ?   GPU memory: ?