In [2]:
from argparse import ArgumentParser
from dataclasses import dataclass
from pathlib import Path

import pandas as pd
from transformers import AutoModel, AutoTokenizer
from torch import nn
import numpy as np
from torch import optim
import torch
from tqdm.auto import tqdm
import torch.nn.functional as F
from torch import Tensor
from usearch.index import Index
import string
import random

In [3]:
# load model
model_name = "Alibaba-NLP/gte-large-en-v1.5"
model = AutoModel.from_pretrained(model_name, trust_remote_code=True).cuda()
tokenizer = AutoTokenizer.from_pretrained(model_name)
# load data
df = pd.read_csv("data/eedi-paraphrased/train.csv")
# need to separate the data



In [4]:
df_q = (
    df[
        [
            "ConstructName",
            "SubjectName",
            "QuestionId_Answer",
            "QuestionText",
            "WrongText",
            "CorrectText",
        ]
    ]
    .drop_duplicates()
    .reset_index(drop=True)
)
df_q["QuestionComplete"] = (
    "Subject: "
    + df_q["SubjectName"]
    + ". Construct: "
    + df_q["ConstructName"]
    + ". Question: "
    + df_q["QuestionText"]
    + ". Correct answer: "
    + df_q["CorrectText"]
    + ". Wrong answer: "
    + df_q["WrongText"]
    + "."
)
df_q["QuestionComplete"].iloc[0]

'Subject: BIDMAS. Construct: Use the order of operations to carry out calculations involving powers. Question: \\[\n3 \\times 2+4-5\n\\]\nWhere do the brackets need to go to make the answer equal \\( 13 \\) ?. Correct answer: \\( 3 \\times(2+4)-5 \\). Wrong answer: Does not need brackets.'

In [5]:
df_m = df[["MisconceptionId", "MisconceptionText"]].sort_values("MisconceptionId").drop_duplicates().reset_index(drop=True)
df_m

Unnamed: 0,MisconceptionId,MisconceptionText
0,0,Unaware that the total of angles in a triangle...
1,0,Lacks knowledge that the angles within a trian...
2,0,Is not aware that the sum of angles in a trian...
3,0,Doesn't understand that the angles inside a tr...
4,0,Does not know that angles in a triangle sum to...
...,...,...
8015,2586,Misinterprets the rules governing the sequence...
8016,2586,Does not correctly understand how to prioritiz...
8017,2586,Misunderstands order of operations in algebrai...
8018,2586,Fails to grasp the correct sequence for perfor...


In [6]:
def randomstring():
    return "".join(random.choices(string.ascii_letters, k=20))

In [7]:
@torch.inference_mode()
def batched_inference(model, texts: list[str], bs: int) -> Tensor:
    embeddings  = []
    for i in tqdm(range(0, len(texts), bs)):
        # TODO calc max length
        encoded = tokenizer(texts[i:i+bs], max_length=8192, padding=True, truncation=True, return_tensors='pt').to("cuda")
        outputs = model(**encoded)
        emb = outputs.last_hidden_state[:, 0]  # cls token
        emb = F.normalize(emb, p=2, dim=-1)
        embeddings.append(emb.cpu())
    embeddings = torch.cat(embeddings)
    return embeddings

In [None]:
q_embeds = batched_inference(model, df_q["QuestionComplete"].tolist(), bs=32).numpy()
m_embeds = batched_inference(model, df_m["MisconceptionText"].tolist(), bs=32).numpy()
# add misc to index
index = Index(ndim=m_embeds.shape[-1], metric="ip", multi=True)
index.add(np.arange(m_embeds.shape[0]), m_embeds)
# (Optionally) normalize embeddings
# embeddings = F.normalize(embeddings, p=2, dim=1)
# scores = (embeddings[:1] @ embeddings[1:].T) * 100
# print(scores.tolist())

  0%|          | 0/684 [00:00<?, ?it/s]

In [None]:
tokenizer(["a", "b"], max_length=8192, padding=True, truncation=True, return_tensors='pt').to("cuda")

{'input_ids': tensor([[ 101, 1037,  102],
        [ 101, 1038,  102]], device='cuda:0'), 'token_type_ids': tensor([[0, 0, 0],
        [0, 0, 0]], device='cuda:0'), 'attention_mask': tensor([[1, 1, 1],
        [1, 1, 1]], device='cuda:0')}

In [None]:
asd

In [None]:
matches = index.search(m_embeds, count=5)
for match in matches:
    for top in match:
        print(top)
    print("asd")

Match(key=0, distance=0.00116539)
Match(key=54, distance=0.18357867)
Match(key=161, distance=0.18665767)
Match(key=51, distance=0.18724597)
Match(key=283, distance=0.18785274)
asd
Match(key=1, distance=0.0009351373)
Match(key=277, distance=0.15661108)
Match(key=283, distance=0.1732446)
Match(key=111, distance=0.18654776)
Match(key=193, distance=0.18737775)
asd
Match(key=2, distance=0.0013642907)
Match(key=166, distance=0.22117203)
Match(key=249, distance=0.23533094)
Match(key=247, distance=0.23660499)
Match(key=230, distance=0.23827046)
asd
Match(key=3, distance=0.0019643903)
Match(key=5, distance=0.19357985)
Match(key=51, distance=0.21026236)
Match(key=73, distance=0.21306616)
Match(key=65, distance=0.21601468)
asd
Match(key=4, distance=0.00029712915)
Match(key=249, distance=0.24169004)
Match(key=36, distance=0.25013936)
Match(key=213, distance=0.25562704)
Match(key=296, distance=0.25801486)
asd
Match(key=5, distance=0.0001784563)
Match(key=74, distance=0.13958502)
Match(key=18, dista

In [None]:
outputs.last_hidden_state[:, 0].size()

torch.Size([4, 1024])

In [None]:
# how would the dataset look like??
df[df["QuestionId"] == 1]

Unnamed: 0,QuestionId,ConstructId,ConstructName,SubjectId,SubjectName,CorrectChoice,CorrectText,QuestionText,WrongChoice,WrongText,MisconceptionId,QuestionId_Answer,MisconceptionText,QuestionAiCreated,MisconceptionAiCreated
25,1,1612,Simplify an algebraic fraction by factorising ...,1077,Simplifying Algebraic Fractions,D,Does not simplify,"Simplify the following, if possible: \( \frac{...",A,\( m+1 \),2142,1_A,Does not know that to factorise a quadratic ex...,False,False
26,1,1612,Simplify an algebraic fraction by factorising ...,1077,Simplifying Algebraic Fractions,D,Does not simplify,"Simplify the following, if possible: \( \frac{...",A,\( m+1 \),2142,1_A,Is unaware that in order to factor a quadratic...,False,True
27,1,1612,Simplify an algebraic fraction by factorising ...,1077,Simplifying Algebraic Fractions,D,Does not simplify,"Simplify the following, if possible: \( \frac{...",A,\( m+1 \),2142,1_A,Lacks knowledge that factoring a quadratic exp...,False,True
28,1,1612,Simplify an algebraic fraction by factorising ...,1077,Simplifying Algebraic Fractions,D,Does not simplify,"Simplify the following, if possible: \( \frac{...",A,\( m+1 \),2142,1_A,Doesn't realize that to factor a quadratic exp...,False,True
29,1,1612,Simplify an algebraic fraction by factorising ...,1077,Simplifying Algebraic Fractions,D,Does not simplify,"Simplify the following, if possible: \( \frac{...",A,\( m+1 \),2142,1_A,Is not informed that factoring a quadratic exp...,False,True
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
95,1,1612,Simplify an algebraic fraction by factorising ...,1077,Simplifying Algebraic Fractions,D,Does not simplify,Could you simplify the following fraction: \( ...,C,\( m-1 \),2142,1_C,Does not know that to factorise a quadratic ex...,True,False
96,1,1612,Simplify an algebraic fraction by factorising ...,1077,Simplifying Algebraic Fractions,D,Does not simplify,Could you simplify the following fraction: \( ...,C,\( m-1 \),2142,1_C,Is unaware that in order to factor a quadratic...,True,True
97,1,1612,Simplify an algebraic fraction by factorising ...,1077,Simplifying Algebraic Fractions,D,Does not simplify,Could you simplify the following fraction: \( ...,C,\( m-1 \),2142,1_C,Lacks knowledge that factoring a quadratic exp...,True,True
98,1,1612,Simplify an algebraic fraction by factorising ...,1077,Simplifying Algebraic Fractions,D,Does not simplify,Could you simplify the following fraction: \( ...,C,\( m-1 \),2142,1_C,Doesn't realize that to factor a quadratic exp...,True,True
