In [30]:
import pickle
import pprint
import re
from datetime import datetime
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import polars as pl
import pytz
import seaborn as sns
from omegaconf import OmegaConf
from sentence_transformers import SentenceTransformer, util
from sklearn.metrics.pairwise import cosine_similarity
from transformers import AutoTokenizer

from src.config import cfg
from src.dir import create_dir
from src.seed import seed_everything

# Use the name of the current working directory as the experiment number.
cfg.exp_number = Path().resolve().name
# Echo the configuration as YAML with interpolations resolved.
print(OmegaConf.to_yaml(cfg, resolve=True))

# Seed through the project helper — presumably covers every RNG in use; see src.seed to confirm.
seed_everything(cfg.seed)

# Raise the string display length so long text columns are printed in full.
pl.Config.set_fmt_str_lengths(100000)


exp_number: '001'
run_time: base
data:
  input_root: ../../data/input
  train_path: ../../data/input/train.csv
  test_path: ../../data/input/test.csv
  sample_submission_path: ../../data/input/sample_submission.csv
  mapping_path: ../../data/input/misconception_mapping.csv
  mapping_meta_path: ../../data/input/mapping_meta.parquet
  output_root: ../../data/output
  results_root: ../../results
  results_path: ../../results/001/base
seed: 42
embed_model: BAAI/bge-large-en-v1.5



polars.config.Config

- クエリ作成
- データベース作成
- エンコード処理
- 類似度計算 → データベースからTOPK抽出処理
- LLMによる絞り込み（2nd stage）

### Loading Data

In [2]:
# Load the input tables. The four CSV files share the same reader options.
train_df, test_df, sample_submission_df, mapping_df = (
    pl.read_csv(csv_path, try_parse_dates=True)
    for csv_path in (
        cfg.data.train_path,
        cfg.data.test_path,
        cfg.data.sample_submission_path,
        cfg.data.mapping_path,
    )
)
# The misconception metadata is stored as parquet rather than CSV.
mapping_meta_df = pl.read_parquet(cfg.data.mapping_meta_path)


In [3]:
# Preview the raw test questions before reshaping.
test_df


QuestionId,ConstructId,ConstructName,SubjectId,SubjectName,CorrectAnswer,QuestionText,AnswerAText,AnswerBText,AnswerCText,AnswerDText
i64,i64,str,i64,str,str,str,str,str,str,str
1869,856,"""Use the order of operations to carry out calculations involving powers""",33,"""BIDMAS""","""A""","""\[ 3 \times 2+4-5 \] Where do the brackets need to go to make the answer equal \( 13 \) ?""","""\( 3 \times(2+4)-5 \)""","""\( 3 \times 2+(4-5) \)""","""\( 3 \times(2+4-5) \)""","""Does not need brackets"""
1870,1612,"""Simplify an algebraic fraction by factorising the numerator""",1077,"""Simplifying Algebraic Fractions""","""D""","""Simplify the following, if possible: \( \frac{m^{2}+2 m-3}{m-3} \)""","""\( m+1 \)""","""\( m+2 \)""","""\( m-1 \)""","""Does not simplify"""
1871,2774,"""Calculate the range from a list of data""",339,"""Range and Interquartile Range from a List of Data""","""B""","""Tom and Katie are discussing the \( 5 \) plants with these heights: \( 24 \mathrm{~cm}, 17 \mathrm{~cm}, 42 \mathrm{~cm}, 26 \mathrm{~cm}, 13 \mathrm{~cm} \) Tom says if all the plants were cut in half, the range wouldn't change. Katie says if all the plants grew by \( 3 \mathrm{~cm} \) each, the range wouldn't change. Who do you agree with?""","""Only Tom""","""Only Katie""","""Both Tom and Katie""","""Neither is correct"""


### Data Preparation

In [17]:
# Reshape the test questions to long format: one row per (question, answer choice).
common_col = [
    "QuestionId",
    "ConstructName",
    "SubjectName",
    "QuestionText",
    "CorrectAnswer",
]
answer_text_cols = [f"Answer{alpha}Text" for alpha in "ABCD"]

# Text that is embedded for retrieval: construct, subject, question and the chosen answer.
all_text_expr = (
    pl.lit("ConstructName: ")
    + pl.col("ConstructName")
    + pl.lit(" SubjectName: ")
    + pl.col("SubjectName")
    + pl.lit(" QuestionText: ")
    + pl.col("QuestionText")
    + pl.lit(" AnswerText: ")
    + pl.col("AnswerText")
).alias("AllText")

# "AnswerBText" -> "B"
answer_alphabet_expr = (
    pl.col("AnswerType").str.extract(r"Answer([A-D])Text$").alias("AnswerAlphabet")
)

# Row key of the form "<QuestionId>_<letter>", e.g. "1869_B".
question_answer_expr = pl.concat_str(
    [pl.col("QuestionId"), pl.col("AnswerAlphabet")],
    separator="_",
).alias("QuestionId_Answer")

test_long = (
    test_df.select(common_col + answer_text_cols)
    .unpivot(index=common_col, variable_name="AnswerType", value_name="AnswerText")
    .with_columns(all_text_expr, answer_alphabet_expr)
    # Needs AnswerAlphabet, so it runs in a second with_columns step.
    .with_columns(question_answer_expr)
    .sort("QuestionId_Answer")
)
test_long.head()


QuestionId,ConstructName,SubjectName,QuestionText,CorrectAnswer,AnswerType,AnswerText,AllText,AnswerAlphabet,QuestionId_Answer
i64,str,str,str,str,str,str,str,str,str
1869,"""Use the order of operations to carry out calculations involving powers""","""BIDMAS""","""\[ 3 \times 2+4-5 \] Where do the brackets need to go to make the answer equal \( 13 \) ?""","""A""","""AnswerAText""","""\( 3 \times(2+4)-5 \)""","""ConstructName: Use the order of operations to carry out calculations involving powers SubjectName: BIDMAS QuestionText: \[ 3 \times 2+4-5 \] Where do the brackets need to go to make the answer equal \( 13 \) ? AnswerText: \( 3 \times(2+4)-5 \)""","""A""","""1869_A"""
1869,"""Use the order of operations to carry out calculations involving powers""","""BIDMAS""","""\[ 3 \times 2+4-5 \] Where do the brackets need to go to make the answer equal \( 13 \) ?""","""A""","""AnswerBText""","""\( 3 \times 2+(4-5) \)""","""ConstructName: Use the order of operations to carry out calculations involving powers SubjectName: BIDMAS QuestionText: \[ 3 \times 2+4-5 \] Where do the brackets need to go to make the answer equal \( 13 \) ? AnswerText: \( 3 \times 2+(4-5) \)""","""B""","""1869_B"""
1869,"""Use the order of operations to carry out calculations involving powers""","""BIDMAS""","""\[ 3 \times 2+4-5 \] Where do the brackets need to go to make the answer equal \( 13 \) ?""","""A""","""AnswerCText""","""\( 3 \times(2+4-5) \)""","""ConstructName: Use the order of operations to carry out calculations involving powers SubjectName: BIDMAS QuestionText: \[ 3 \times 2+4-5 \] Where do the brackets need to go to make the answer equal \( 13 \) ? AnswerText: \( 3 \times(2+4-5) \)""","""C""","""1869_C"""
1869,"""Use the order of operations to carry out calculations involving powers""","""BIDMAS""","""\[ 3 \times 2+4-5 \] Where do the brackets need to go to make the answer equal \( 13 \) ?""","""A""","""AnswerDText""","""Does not need brackets""","""ConstructName: Use the order of operations to carry out calculations involving powers SubjectName: BIDMAS QuestionText: \[ 3 \times 2+4-5 \] Where do the brackets need to go to make the answer equal \( 13 \) ? AnswerText: Does not need brackets""","""D""","""1869_D"""
1870,"""Simplify an algebraic fraction by factorising the numerator""","""Simplifying Algebraic Fractions""","""Simplify the following, if possible: \( \frac{m^{2}+2 m-3}{m-3} \)""","""D""","""AnswerAText""","""\( m+1 \)""","""ConstructName: Simplify an algebraic fraction by factorising the numerator SubjectName: Simplifying Algebraic Fractions QuestionText: Simplify the following, if possible: \( \frac{m^{2}+2 m-3}{m-3} \) AnswerText: \( m+1 \)""","""A""","""1870_A"""


In [18]:
# Attach the text of the correct answer to every row, then drop the correct choices.
#
# The cell rebinds `test_long`, so it is guarded: once `CorrectAnswerText` exists,
# running the cell again is a no-op. Without the guard a re-run only sees the
# already-filtered incorrect rows, finds no correct answers, and the join adds an
# all-null, suffixed duplicate of the column.
if "CorrectAnswerText" not in test_long.columns:
    # Extract the correct answer text per question
    # (assumes exactly one choice per question matches CorrectAnswer — TODO confirm).
    test_correct_answer = (
        test_long.filter(pl.col("CorrectAnswer") == pl.col("AnswerAlphabet"))
        .select(pl.col(["QuestionId", "AnswerText"]))
        .rename({"AnswerText": "CorrectAnswerText"})
    )

    # Join it back onto test_long as a new CorrectAnswerText column.
    test_long = test_long.join(test_correct_answer, on="QuestionId", how="left")

    # Rows where CorrectAnswer equals AnswerAlphabet are the correct choices; exclude them.
    test_long = test_long.filter(pl.col("CorrectAnswer") != pl.col("AnswerAlphabet"))

test_long.head(3)


QuestionId,ConstructName,SubjectName,QuestionText,CorrectAnswer,AnswerType,AnswerText,AllText,AnswerAlphabet,QuestionId_Answer,CorrectAnswerText
i64,str,str,str,str,str,str,str,str,str,str
1869,"""Use the order of operations to carry out calculations involving powers""","""BIDMAS""","""\[ 3 \times 2+4-5 \] Where do the brackets need to go to make the answer equal \( 13 \) ?""","""A""","""AnswerBText""","""\( 3 \times 2+(4-5) \)""","""ConstructName: Use the order of operations to carry out calculations involving powers SubjectName: BIDMAS QuestionText: \[ 3 \times 2+4-5 \] Where do the brackets need to go to make the answer equal \( 13 \) ? AnswerText: \( 3 \times 2+(4-5) \)""","""B""","""1869_B""","""\( 3 \times(2+4)-5 \)"""
1869,"""Use the order of operations to carry out calculations involving powers""","""BIDMAS""","""\[ 3 \times 2+4-5 \] Where do the brackets need to go to make the answer equal \( 13 \) ?""","""A""","""AnswerCText""","""\( 3 \times(2+4-5) \)""","""ConstructName: Use the order of operations to carry out calculations involving powers SubjectName: BIDMAS QuestionText: \[ 3 \times 2+4-5 \] Where do the brackets need to go to make the answer equal \( 13 \) ? AnswerText: \( 3 \times(2+4-5) \)""","""C""","""1869_C""","""\( 3 \times(2+4)-5 \)"""
1869,"""Use the order of operations to carry out calculations involving powers""","""BIDMAS""","""\[ 3 \times 2+4-5 \] Where do the brackets need to go to make the answer equal \( 13 \) ?""","""A""","""AnswerDText""","""Does not need brackets""","""ConstructName: Use the order of operations to carry out calculations involving powers SubjectName: BIDMAS QuestionText: \[ 3 \times 2+4-5 \] Where do the brackets need to go to make the answer equal \( 13 \) ? AnswerText: Does not need brackets""","""D""","""1869_D""","""\( 3 \times(2+4)-5 \)"""


In [19]:
# Build the text that is embedded for each misconception.
has_subjects = pl.col("SubjectNames").list.len() > 0

# Used when at least one subject is listed for the misconception.
text_with_subjects = (
    pl.lit("The misconception '")
    + pl.col("MisconceptionName")
    + pl.lit("' is primarily observed in the following subjects: ")
    + pl.col("SubjectNames").list.join(", ")
)
# Fallback when the subject list is empty.
text_name_only = pl.lit("The misconception is: ") + pl.col("MisconceptionName")

mapping_meta_df = mapping_meta_df.with_columns(
    pl.when(has_subjects)
    .then(text_with_subjects)
    .otherwise(text_name_only)
    .alias("MisconceptionName_with_SubjectNames")
)

mapping_meta_df.head()


MisconceptionId,MisconceptionName,SubjectNames,MisconceptionName_with_SubjectNames
i64,str,list[str],str
0,"""Does not know that angles in a triangle sum to 180 degrees""","[""Angles in Triangles""]","""The misconception 'Does not know that angles in a triangle sum to 180 degrees' is primarily observed in the following subjects: Angles in Triangles"""
1,"""Uses dividing fractions method for multiplying fractions""","[""Multiplying and Dividing Negative Numbers"", ""Multiplying Fractions""]","""The misconception 'Uses dividing fractions method for multiplying fractions' is primarily observed in the following subjects: Multiplying and Dividing Negative Numbers, Multiplying Fractions"""
2,"""Believes there are 100 degrees in a full turn""","[""Types, Naming and Estimating"", ""Measuring Angles""]","""The misconception 'Believes there are 100 degrees in a full turn' is primarily observed in the following subjects: Types, Naming and Estimating, Measuring Angles"""
3,"""Thinks a quadratic without a non variable term, can not be factorised""","[""Factorising into a Single Bracket""]","""The misconception 'Thinks a quadratic without a non variable term, can not be factorised' is primarily observed in the following subjects: Factorising into a Single Bracket"""
4,"""Believes addition of terms and powers of terms are equivalent e.g. a + c = a^c""","[""Simplifying Expressions by Collecting Like Terms""]","""The misconception 'Believes addition of terms and powers of terms are equivalent e.g. a + c = a^c' is primarily observed in the following subjects: Simplifying Expressions by Collecting Like Terms"""


### 埋め込みモデルでTOP50を抽出（1st stage）

In [20]:
# Embedding model for the 1st-stage retrieval.
model = SentenceTransformer(cfg.embed_model)

query_texts = test_long["AllText"].to_list()
corpus_texts = mapping_meta_df["MisconceptionName_with_SubjectNames"].to_list()

# Both sides are encoded with normalize_embeddings=True.
test_long_vec = model.encode(query_texts, normalize_embeddings=True)
misconception_vec = model.encode(corpus_texts, normalize_embeddings=True)


In [21]:
# Sanity-check the embedding shapes: queries first, then the misconception corpus.
for embeddings in (test_long_vec, misconception_vec):
    print(embeddings.shape)


(9, 1024)
(2587, 1024)


In [25]:
# Retrieve the 50 most similar misconceptions for every query row.
top50ids = util.semantic_search(
    query_embeddings=test_long_vec,
    corpus_embeddings=misconception_vec,
    top_k=50,
)


In [69]:
# Inspect the retrieval hits (corpus_id / score pairs) for the first query row.
top50ids[0]


[{'corpus_id': 2488, 'score': 0.7744837403297424},
 {'corpus_id': 2306, 'score': 0.7182890176773071},
 {'corpus_id': 1672, 'score': 0.7162009477615356},
 {'corpus_id': 1963, 'score': 0.7103480696678162},
 {'corpus_id': 1316, 'score': 0.7066697478294373},
 {'corpus_id': 15, 'score': 0.7040739059448242},
 {'corpus_id': 328, 'score': 0.6977964043617249},
 {'corpus_id': 1005, 'score': 0.6934683322906494},
 {'corpus_id': 1054, 'score': 0.6923346519470215},
 {'corpus_id': 2532, 'score': 0.6913980841636658},
 {'corpus_id': 2586, 'score': 0.68998783826828},
 {'corpus_id': 871, 'score': 0.6881226897239685},
 {'corpus_id': 1516, 'score': 0.6867824792861938},
 {'corpus_id': 1862, 'score': 0.6867375373840332},
 {'corpus_id': 1345, 'score': 0.6852461695671082},
 {'corpus_id': 1666, 'score': 0.6822766065597534},
 {'corpus_id': 217, 'score': 0.6821097135543823},
 {'corpus_id': 987, 'score': 0.6819729804992676},
 {'corpus_id': 1468, 'score': 0.6798923015594482},
 {'corpus_id': 2131, 'score': 0.6781441

### LLMによる絞り込み（2nd stage）

In [32]:
# Prepare the tokenizer of the 2nd-stage LLM (only the tokenizer is loaded in this cell).
model_name = "Qwen/Qwen2.5-32B-Instruct-AWQ"
tokenizer = AutoTokenizer.from_pretrained(model_name)


In [66]:
# Prompt template for the 2nd-stage LLM. The {placeholders} are filled in with
# str.format by apply_template.
# NOTE(review): the counts "50" and "25" are hard-coded in the text and are not
# derived from the retrieval top_k — keep them in sync by hand.
prompt = """
Question Details:
Subject: {SubjectName}
Topic: {ConstructName}
Question: {Question}
Correct Answer: {CorrectAnswer}
Student's Incorrect Answer: {IncorrectAnswer}

You are an experienced mathematics teacher analyzing student misconceptions. Your task is to identify the underlying misconceptions that led to this incorrect answer.

Below are 50 potential misconceptions identified by semantic similarity analysis for this specific question:
{misconception_top50}

Instructions:
1. From these semantically similar misconceptions, select 25 that are most likely to explain this student's error
2. Rank your selections by confidence level (most likely first)
3. Provide only the numbers in a comma-separated format (e.g., 1,10,11,12,...)

Key considerations:
- Consider the student's likely problem-solving process
- Take into account how well each misconception matches the specific error pattern
- Pay special attention to the semantic relevance already identified
- Consider how the incorrect answer might have been derived from these misconceptions

Output format: [numbers only, comma-separated]
"""


In [64]:
# プロンプトを用いてテキストを前処理する用の関数
def preprocess_text(x):
    """Clean up a prompt string before it is handed to the chat template.

    Steps, in order: remove http(s) URLs, collapse runs of periods and runs of
    commas to a single character, replace the LaTeX inline delimiters ``\\(``
    and ``\\)`` with a space, collapse runs of spaces, and strip leading and
    trailing whitespace.

    Args:
        x: Text to clean.

    Returns:
        The cleaned text.
    """
    # Remove the whole URL. The previous pattern (r"http\w+") stopped at the
    # first non-word character and left "://host/path" behind in the text.
    x = re.sub(r"https?://\S+", "", x)
    x = re.sub(r"\.+", ".", x)  # "..." -> "."
    x = re.sub(r",+", ",", x)  # ",," -> ","
    x = re.sub(r"\\\(", " ", x)  # LaTeX inline-math opener
    x = re.sub(r"\\\)", " ", x)  # LaTeX inline-math closer
    x = re.sub(r" +", " ", x)  # collapse repeated spaces (newlines are kept)
    return x.strip()


def apply_template(row, tokenizer):
    """Build the chat-formatted LLM prompt for one row.

    The module-level ``prompt`` template is filled from ``row``, cleaned with
    ``preprocess_text`` and wrapped as a single user message.

    Args:
        row: Mapping with the keys "ConstructName", "SubjectName",
            "QuestionText", "AnswerText", "CorrectAnswerText" and
            "misconception_top50".
        tokenizer: Object providing ``apply_chat_template``.

    Returns:
        Whatever ``tokenizer.apply_chat_template`` returns for
        ``tokenize=False, add_generation_prompt=True``.
    """
    filled_prompt = prompt.format(
        ConstructName=row["ConstructName"],
        SubjectName=row["SubjectName"],
        Question=row["QuestionText"],
        IncorrectAnswer=row["AnswerText"],
        CorrectAnswer=row["CorrectAnswerText"],
        misconception_top50=row["misconception_top50"],
    )
    user_message = {"role": "user", "content": preprocess_text(filled_prompt)}
    return tokenizer.apply_chat_template(
        [user_message], tokenize=False, add_generation_prompt=True
    )


In [63]:
# Add the stage-1 top-50 hits to test_long as a `misconception_top50` column:
# one numbered, newline-separated list of misconception names per row.
misconception_names = mapping_meta_df["MisconceptionName"].to_list()

misconception_top50_list = [
    "\n".join(
        # corpus_id is used as the row position in mapping_meta_df.
        f"{rank}. {misconception_names[hit['corpus_id']]}"
        for rank, hit in enumerate(hits, start=1)
    )
    for hits in top50ids
]

test_long = test_long.with_columns(
    pl.Series(name="misconception_top50", values=misconception_top50_list)
)

print(test_long["misconception_top50"][0])



1. Answers order of operations questions with brackets as if the brackets are not there
2. Applies BIDMAS in strict order (does not realize addition and subtraction, and multiplication and division, are of equal priority)
3. Confuses the order of operations, believes addition comes before multiplication 
4. Performs subtraction right to left if priority order means doing a calculation to the right first
5. Does not understand that the numerator and denominator of fractions represent groupings and have the same order of priority as brackets
6. Confuses the order of operations, believes addition comes before division
7. Performs addition ahead of multiplication
8. Carries out operations from left to right regardless of priority order, unless brackets are used
9. Performs subtraction in wrong order
10. Believes order of operations does not affect the answer to a calculation
11. Misunderstands order of operations in algebraic expressions
12. Does not include brackets when attempting to mul

In [67]:
# Run apply_template on every row and store the result in a `prompt` column.
prompt_texts = [apply_template(row, tokenizer) for row in test_long.iter_rows(named=True)]
test_long = test_long.with_columns(pl.Series(name="prompt", values=prompt_texts))


In [68]:
# Inspect the first generated prompt.
print(test_long["prompt"][0])



<|im_start|>system
You are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>
<|im_start|>user
Question Details:
Subject: BIDMAS
Topic: Use the order of operations to carry out calculations involving powers
Question: \[
3 \times 2+4-5
\]
Where do the brackets need to go to make the answer equal 13 ?
Correct Answer: 3 \times(2+4)-5 
Student's Incorrect Answer: 3 \times 2+(4-5) 

You are an experienced mathematics teacher analyzing student misconceptions. Your task is to identify the underlying misconceptions that led to this incorrect answer.

Below are 50 potential misconceptions identified by semantic similarity analysis for this specific question:
1. Answers order of operations questions with brackets as if the brackets are not there
2. Applies BIDMAS in strict order (does not realize addition and subtraction, and multiplication and division, are of equal priority)
3. Confuses the order of operations, believes addition comes before multiplication 
4. Performs subt