In [1]:
import pandas as pd
import numpy as np
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity

  from tqdm.autonotebook import tqdm, trange


In [2]:
# --- Settings ---------------------------------------------------------------
RETRIEVE_NUM = 25              # number of top-ranked candidates kept per row
DATA_PATH = 'dataset'          # folder containing the input CSV files

MODEL_OUTPUT_PATH = 'model/'   # location handed to SentenceTransformer

# --- Input tables (read in the same order: test first, then the mapping) -----
test, misconception_mapping = (
    pd.read_csv(f"{DATA_PATH}/test.csv"),
    pd.read_csv(f"{DATA_PATH}/misconception_mapping.csv"),
)

In [3]:
# Question-level columns that are repeated on every answer row after the melt.
common_col = [
    "QuestionId", "ConstructName", "SubjectName", "QuestionText", "CorrectAnswer",
]
# One text column per answer option: AnswerAText ... AnswerDText.
answer_cols = [f"Answer{alpha}Text" for alpha in "ABCD"]

# Wide -> long: one row per (question, answer option), then derive
#   AllText           - construct, subject, question and answer joined by spaces
#   AnswerAlphabet    - the option letter parsed out of the melted column name
#   QuestionId_Answer - "<QuestionId>_<letter>" key for each row
test_long = (
    test[common_col + answer_cols]
    .melt(
        id_vars=common_col,
        value_vars=answer_cols,
        var_name="AnswerType",
        value_name="AnswerText",
    )
    .assign(
        AllText=lambda df: (
            df["ConstructName"] + " " + df["SubjectName"] + " "
            + df["QuestionText"] + " " + df["AnswerText"]
        ),
        AnswerAlphabet=lambda df: df["AnswerType"].str.extract(
            r"Answer([A-D])Text$", expand=False
        ),
        # Later keyword arguments may use columns created by earlier ones.
        QuestionId_Answer=lambda df: (
            df["QuestionId"].astype(str) + "_" + df["AnswerAlphabet"]
        ),
    )
)


In [4]:
# Show the long-format table (one row per question/answer option) for inspection.
test_long

Unnamed: 0,QuestionId,ConstructName,SubjectName,QuestionText,CorrectAnswer,AnswerType,AnswerText,AllText,AnswerAlphabet,QuestionId_Answer
0,794,Know that the probabilities of exhaustive even...,Probability of Single Events,"A bag contains some strawberry sweets, some or...",D,AnswerAText,\( \frac{1}{3} \),Know that the probabilities of exhaustive even...,A,794_A
1,248,Recognise a quadratic graph from its shape,Quadratic Graphs-Others,Which of the following is an example of a quad...,D,AnswerAText,![A graph showing a straight line from top lef...,Recognise a quadratic graph from its shape Qua...,A,248_A
2,1376,Identify percentages of a shape where the perc...,Fractions of an Amount,What percentage of this shape is shaded? ![A b...,B,AnswerAText,\( 30 \% \),Identify percentages of a shape where the perc...,A,1376_A
3,473,Recognise the place value of each digit in int...,Place Value,What is the value of the \( 3 \) in the number...,D,AnswerAText,\( 30000 \),Recognise the place value of each digit in int...,A,473_A
4,695,Use the order of operations to carry out calcu...,BIDMAS,\( 18 \div 2+2 \times 2= \),B,AnswerAText,\( 3 \),Use the order of operations to carry out calcu...,A,695_A
...,...,...,...,...,...,...,...,...,...,...
743,207,Simplify algebraic expressions to maintain equ...,Expanding Single Brackets,\(\nx(3 x-5)+6(2 x-3) \equiv P x^{2}+Q x+R\n\)...,B,AnswerDText,\( -3 \),Simplify algebraic expressions to maintain equ...,D,207_D
744,1341,Label angles using correct 3-letter notation (...,"Types, Naming and Estimating",Which of the following correctly describes the...,D,AnswerDText,FGH,Label angles using correct 3-letter notation (...,D,1341_D
745,144,Continue a sequence involving triangle numbers,Other Sequences,If you add together consecutive triangle numbe...,A,AnswerDText,a Fibonacci number,Continue a sequence involving triangle numbers...,D,144_D
746,1563,Calculate the midpoint between two coordinates...,Midpoint Between Two Co-ordinates,Cara is trying to work out the midpoint of the...,C,AnswerDText,\( 3.5 \),Calculate the midpoint between two coordinates...,D,1563_D


In [5]:
# Load the sentence-embedding model from MODEL_OUTPUT_PATH.
model = SentenceTransformer(MODEL_OUTPUT_PATH)

# Texts to embed: one string per test row and one per misconception name.
query_texts = test_long["AllText"].tolist()
misconception_texts = misconception_mapping["MisconceptionName"].tolist()

# Both sides are encoded with normalize_embeddings=True.
test_long_vec = model.encode(query_texts, normalize_embeddings=True)
misconception_mapping_vec = model.encode(misconception_texts, normalize_embeddings=True)

# Sanity check: print each embedding matrix shape on its own line.
print(test_long_vec.shape, misconception_mapping_vec.shape, sep="\n")


(748, 1024)
(2587, 1024)


In [6]:
# Similarity matrix: one row per test text, one column per misconception.
test_cos_sim_arr = cosine_similarity(test_long_vec, misconception_mapping_vec)

# Sorting the negated scores ascending ranks the columns from most to least
# similar; only the first RETRIEVE_NUM positions of each row are kept.
ranked_positions = np.negative(test_cos_sim_arr).argsort(axis=1)
test_sorted_indices = ranked_positions[:, :RETRIEVE_NUM]

In [7]:
# `test_sorted_indices` holds ROW POSITIONS in `misconception_mapping`, not ids.
# Position and id only coincide when the CSV lists ids 0..n-1 in file order, so
# translate positions to the stored ids explicitly when the column is available.
if "MisconceptionId" in misconception_mapping.columns:
    misconception_ids = misconception_mapping["MisconceptionId"].to_numpy()
else:
    # No id column present: fall back to row positions (the previous behaviour).
    misconception_ids = np.arange(len(misconception_mapping))

# Materialise the name column once instead of slicing the frame for every row.
misconception_names = misconception_mapping["MisconceptionName"].to_numpy()

# Space-separated ranked ids and newline-separated ranked names for each row.
test_long["MisconceptionId"] = [
    " ".join(map(str, misconception_ids[indices])) for indices in test_sorted_indices
]
test_long["MisconceptionText"] = [
    "\n".join(misconception_names[indices]) for indices in test_sorted_indices
]

# Keep only rows whose answer option differs from the question's correct answer.
filtered_test_long = test_long[test_long["CorrectAnswer"] != test_long["AnswerAlphabet"]]

# Select the two submission columns; the sort on QuestionId_Answer is a string
# (lexicographic) sort, so e.g. "1001_A" is ordered before "953_C".
submission = filtered_test_long[["QuestionId_Answer", "MisconceptionId"]].sort_values(by="QuestionId_Answer")
submission.to_csv("submission.csv", index=False)

In [8]:
# Show the submission table that was written to submission.csv.
submission

Unnamed: 0,QuestionId_Answer,MisconceptionId
5,1001_A,2330 1379 1988 1605 1675 2515 1272 1529 1248 8...
379,1001_C,2330 1379 1988 1605 1675 2515 1248 1272 1529 2...
566,1001_D,2330 1379 1988 1605 2515 1675 1248 1272 2392 8...
112,1006_A,1883 2353 2332 2450 1788 2197 305 2103 396 238...
486,1006_C,1883 2353 2332 2450 1788 2197 305 2103 396 117...
...,...,...
418,953_C,81 1058 1198 2016 1582 2252 2386 936 523 1293 ...
605,953_D,81 1198 1058 2016 1582 2386 2252 523 936 1240 ...
43,969_A,1471 2185 2119 2315 2165 1937 134 1231 621 201...
230,969_B,1471 2185 2119 2315 2165 134 1937 1231 2019 22...
