In [1]:
import re
from pathlib import Path

import numpy as np
import polars as pl
import vllm
from omegaconf import OmegaConf
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from transformers import AutoTokenizer

from src.config import cfg
from src.seed import seed_everything

# Use the name of the current working directory as the experiment number.
cfg.exp_number = Path().resolve().name
# Print the fully resolved configuration so the run settings are recorded in the cell output.
print(OmegaConf.to_yaml(cfg, resolve=True))

# Seed through the project helper (presumably seeds every RNG in use; confirm in src.seed).
seed_everything(cfg.seed)

# Raise the displayed string length so long text columns are shown in full.
pl.Config.set_fmt_str_lengths(100000)


  from .autonotebook import tqdm as notebook_tqdm
2024-11-16 10:59:41,795	INFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.


exp_number: '001'
run_time: base
data:
  input_root: ../../data/input
  train_path: ../../data/input/train.csv
  test_path: ../../data/input/test.csv
  sample_submission_path: ../../data/input/sample_submission.csv
  mapping_path: ../../data/input/misconception_mapping.csv
  mapping_meta_path: ../../data/input/mapping_meta.parquet
  output_root: ../../data/output
  results_root: ../../results
  results_path: ../../results/001/base
seed: 42
embed_model: BAAI/bge-large-en-v1.5
k: 50
llm_model: Qwen/Qwen2.5-32B-Instruct-AWQ



polars.config.Config

### Loading Data

In [2]:
# Load the input tables: four CSV files plus the misconception metadata parquet.
csv_paths = (
    cfg.data.train_path,
    cfg.data.test_path,
    cfg.data.sample_submission_path,
    cfg.data.mapping_path,
)
train_df, test_df, sample_submission_df, mapping_df = (
    pl.read_csv(csv_path, try_parse_dates=True) for csv_path in csv_paths
)
mapping_meta_df = pl.read_parquet(cfg.data.mapping_meta_path)


### Data Preparation

In [3]:
# Reshape the test data to long format: one row per (question, answer option).
common_col = [
    "QuestionId",
    "ConstructName",
    "SubjectName",
    "QuestionText",
    "CorrectAnswer",
]
answer_text_cols = [f"Answer{letter}Text" for letter in "ABCD"]

# Text used for retrieval: construct, subject, question and the answer option in one string.
all_text_expr = (
    pl.lit("ConstructName: ")
    + pl.col("ConstructName")
    + pl.lit(" SubjectName: ")
    + pl.col("SubjectName")
    + pl.lit(" QuestionText: ")
    + pl.col("QuestionText")
    + pl.lit(" AnswerText: ")
    + pl.col("AnswerText")
).alias("AllText")

# Option letter (A-D) pulled out of the unpivoted column name, e.g. "AnswerBText" -> "B".
answer_alphabet_expr = (
    pl.col("AnswerType").str.extract(r"Answer([A-D])Text$").alias("AnswerAlphabet")
)

# Row key of the form "<QuestionId>_<letter>".
question_answer_expr = pl.concat_str(
    [
        pl.col("QuestionId"),
        pl.col("AnswerAlphabet"),
    ],
    separator="_",
).alias("QuestionId_Answer")

test_long = (
    test_df.select(pl.col(common_col + answer_text_cols))
    .unpivot(
        index=common_col,
        variable_name="AnswerType",
        value_name="AnswerText",
    )
    .with_columns(all_text_expr, answer_alphabet_expr)
    .with_columns(question_answer_expr)
    .sort("QuestionId_Answer")
)
test_long.head()


QuestionId,ConstructName,SubjectName,QuestionText,CorrectAnswer,AnswerType,AnswerText,AllText,AnswerAlphabet,QuestionId_Answer
i64,str,str,str,str,str,str,str,str,str
1869,"""Use the order of operations to carry out calculations involving powers""","""BIDMAS""","""\[ 3 \times 2+4-5 \] Where do the brackets need to go to make the answer equal \( 13 \) ?""","""A""","""AnswerAText""","""\( 3 \times(2+4)-5 \)""","""ConstructName: Use the order of operations to carry out calculations involving powers SubjectName: BIDMAS QuestionText: \[ 3 \times 2+4-5 \] Where do the brackets need to go to make the answer equal \( 13 \) ? AnswerText: \( 3 \times(2+4)-5 \)""","""A""","""1869_A"""
1869,"""Use the order of operations to carry out calculations involving powers""","""BIDMAS""","""\[ 3 \times 2+4-5 \] Where do the brackets need to go to make the answer equal \( 13 \) ?""","""A""","""AnswerBText""","""\( 3 \times 2+(4-5) \)""","""ConstructName: Use the order of operations to carry out calculations involving powers SubjectName: BIDMAS QuestionText: \[ 3 \times 2+4-5 \] Where do the brackets need to go to make the answer equal \( 13 \) ? AnswerText: \( 3 \times 2+(4-5) \)""","""B""","""1869_B"""
1869,"""Use the order of operations to carry out calculations involving powers""","""BIDMAS""","""\[ 3 \times 2+4-5 \] Where do the brackets need to go to make the answer equal \( 13 \) ?""","""A""","""AnswerCText""","""\( 3 \times(2+4-5) \)""","""ConstructName: Use the order of operations to carry out calculations involving powers SubjectName: BIDMAS QuestionText: \[ 3 \times 2+4-5 \] Where do the brackets need to go to make the answer equal \( 13 \) ? AnswerText: \( 3 \times(2+4-5) \)""","""C""","""1869_C"""
1869,"""Use the order of operations to carry out calculations involving powers""","""BIDMAS""","""\[ 3 \times 2+4-5 \] Where do the brackets need to go to make the answer equal \( 13 \) ?""","""A""","""AnswerDText""","""Does not need brackets""","""ConstructName: Use the order of operations to carry out calculations involving powers SubjectName: BIDMAS QuestionText: \[ 3 \times 2+4-5 \] Where do the brackets need to go to make the answer equal \( 13 \) ? AnswerText: Does not need brackets""","""D""","""1869_D"""
1870,"""Simplify an algebraic fraction by factorising the numerator""","""Simplifying Algebraic Fractions""","""Simplify the following, if possible: \( \frac{m^{2}+2 m-3}{m-3} \)""","""D""","""AnswerAText""","""\( m+1 \)""","""ConstructName: Simplify an algebraic fraction by factorising the numerator SubjectName: Simplifying Algebraic Fractions QuestionText: Simplify the following, if possible: \( \frac{m^{2}+2 m-3}{m-3} \) AnswerText: \( m+1 \)""","""A""","""1870_A"""


In [4]:
# Extract the text of the correct option for every question.
is_correct_option = pl.col("CorrectAnswer") == pl.col("AnswerAlphabet")
test_correct_answer = test_long.filter(is_correct_option).select(
    pl.col("QuestionId"),
    pl.col("AnswerText").alias("CorrectAnswerText"),
)

# Attach the correct-answer text to every row, then drop the rows where the
# option itself is the correct answer (only wrong options need a misconception).
test_long = test_long.join(test_correct_answer, on="QuestionId", how="left").filter(
    pl.col("CorrectAnswer") != pl.col("AnswerAlphabet")
)
test_long.head(3)


QuestionId,ConstructName,SubjectName,QuestionText,CorrectAnswer,AnswerType,AnswerText,AllText,AnswerAlphabet,QuestionId_Answer,CorrectAnswerText
i64,str,str,str,str,str,str,str,str,str,str
1869,"""Use the order of operations to carry out calculations involving powers""","""BIDMAS""","""\[ 3 \times 2+4-5 \] Where do the brackets need to go to make the answer equal \( 13 \) ?""","""A""","""AnswerBText""","""\( 3 \times 2+(4-5) \)""","""ConstructName: Use the order of operations to carry out calculations involving powers SubjectName: BIDMAS QuestionText: \[ 3 \times 2+4-5 \] Where do the brackets need to go to make the answer equal \( 13 \) ? AnswerText: \( 3 \times 2+(4-5) \)""","""B""","""1869_B""","""\( 3 \times(2+4)-5 \)"""
1869,"""Use the order of operations to carry out calculations involving powers""","""BIDMAS""","""\[ 3 \times 2+4-5 \] Where do the brackets need to go to make the answer equal \( 13 \) ?""","""A""","""AnswerCText""","""\( 3 \times(2+4-5) \)""","""ConstructName: Use the order of operations to carry out calculations involving powers SubjectName: BIDMAS QuestionText: \[ 3 \times 2+4-5 \] Where do the brackets need to go to make the answer equal \( 13 \) ? AnswerText: \( 3 \times(2+4-5) \)""","""C""","""1869_C""","""\( 3 \times(2+4)-5 \)"""
1869,"""Use the order of operations to carry out calculations involving powers""","""BIDMAS""","""\[ 3 \times 2+4-5 \] Where do the brackets need to go to make the answer equal \( 13 \) ?""","""A""","""AnswerDText""","""Does not need brackets""","""ConstructName: Use the order of operations to carry out calculations involving powers SubjectName: BIDMAS QuestionText: \[ 3 \times 2+4-5 \] Where do the brackets need to go to make the answer equal \( 13 \) ? AnswerText: Does not need brackets""","""D""","""1869_D""","""\( 3 \times(2+4)-5 \)"""


In [5]:
# Build the text that is embedded for each misconception.
has_subjects = pl.col("SubjectNames").list.len() > 0

# At least one subject is known: mention the misconception together with its subjects.
text_with_subjects = (
    pl.lit("The misconception '")
    + pl.col("MisconceptionName")
    + pl.lit("' is primarily observed in the following subjects: ")
    + pl.col("SubjectNames").list.join(", ")
)
# No subject is known: use the misconception name on its own.
text_without_subjects = pl.lit("The misconception is: ") + pl.col("MisconceptionName")

mapping_meta_df = mapping_meta_df.with_columns(
    pl.when(has_subjects)
    .then(text_with_subjects)
    .otherwise(text_without_subjects)
    .alias("MisconceptionName_with_SubjectNames")
)

mapping_meta_df.head()


MisconceptionId,MisconceptionName,SubjectNames,MisconceptionName_with_SubjectNames
i64,str,list[str],str
0,"""Does not know that angles in a triangle sum to 180 degrees""","[""Angles in Triangles""]","""The misconception 'Does not know that angles in a triangle sum to 180 degrees' is primarily observed in the following subjects: Angles in Triangles"""
1,"""Uses dividing fractions method for multiplying fractions""","[""Multiplying and Dividing Negative Numbers"", ""Multiplying Fractions""]","""The misconception 'Uses dividing fractions method for multiplying fractions' is primarily observed in the following subjects: Multiplying and Dividing Negative Numbers, Multiplying Fractions"""
2,"""Believes there are 100 degrees in a full turn""","[""Types, Naming and Estimating"", ""Measuring Angles""]","""The misconception 'Believes there are 100 degrees in a full turn' is primarily observed in the following subjects: Types, Naming and Estimating, Measuring Angles"""
3,"""Thinks a quadratic without a non variable term, can not be factorised""","[""Factorising into a Single Bracket""]","""The misconception 'Thinks a quadratic without a non variable term, can not be factorised' is primarily observed in the following subjects: Factorising into a Single Bracket"""
4,"""Believes addition of terms and powers of terms are equivalent e.g. a + c = a^c""","[""Simplifying Expressions by Collecting Like Terms""]","""The misconception 'Believes addition of terms and powers of terms are equivalent e.g. a + c = a^c' is primarily observed in the following subjects: Simplifying Expressions by Collecting Like Terms"""


### 埋め込みモデルでTOPkを抽出（1st stage）

In [6]:
# Stage 1: embed with several models -> average their similarities -> take the top-k per row.
models = {
    "model1": SentenceTransformer(cfg.embed_model),
    "model2": SentenceTransformer("../../data/Joseph-Eedi-finetuned-bge"),
    # "model3": SentenceTransformer("paraphrase-multilingual-mpnet-base-v2"),
}

# The input texts are identical for every model, so convert them to Python lists once
# instead of on every loop iteration.
test_texts = test_long["AllText"].to_list()
misconception_texts = mapping_meta_df["MisconceptionName_with_SubjectNames"].to_list()

# Similarity matrix per model: rows follow test_long, columns follow mapping_meta_df.
similarities = {}
for model_name, model in models.items():
    # Embed both sides with the current model.
    test_vec = model.encode(test_texts, normalize_embeddings=True)
    misconception_vec = model.encode(misconception_texts, normalize_embeddings=True)

    # Cosine similarity between every test row and every misconception.
    similarities[model_name] = cosine_similarity(test_vec, misconception_vec)

# Ensemble by simple averaging over the models.
ensemble_sim = np.mean(list(similarities.values()), axis=0)

# Top-k: column positions (i.e. row positions in mapping_meta_df) sorted by descending similarity.
test_sorted_indices = np.argsort(-ensemble_sim, axis=1)
topk_ids = test_sorted_indices[:, : cfg.k]


In [8]:
# Show the stage-1 top-k candidate positions for the first test row.
topk_ids[0]


array([1672, 2488,   15,  328, 1963, 1862, 2532, 1516, 2518, 1054,  706,
       2181, 1005, 2306, 1941, 1345, 1999,   77,  987, 2221, 2586, 1119,
       1642, 1507, 2270, 2140,  158, 2175, 2326, 1316, 1856, 2131, 2087,
       1890,  593, 2441,   27,  561, 1576,  751, 2278, 2556,  234,  107,
        400,  261, 1125, 1972, 1679,  217])

### LLMによる絞り込み（2nd stage）

In [9]:
# Prepare the tokenizer of the LLM (used to apply the chat template and count tokens).
tokenizer = AutoTokenizer.from_pretrained(cfg.llm_model)


In [10]:
# Prompt template for the stage-2 LLM re-ranking.
# Placeholders filled via str.format: {SubjectName}, {ConstructName}, {Question},
# {CorrectAnswer}, {IncorrectAnswer}, {k} and {misconception_topk}.
# Note that the number of misconceptions to select (25) is written literally in the text.
prompt = """
You are a mathematics education expert analyzing student misconceptions.

INPUT CONTEXT:
Subject: {SubjectName}
Topic: {ConstructName}
Question: {Question}
Correct Answer: {CorrectAnswer}
Student's Incorrect Answer: {IncorrectAnswer}

Below are {k} potential misconception candidates:
{misconception_topk}

TASK:
1. Select exactly 25 most relevant misconceptions
2. Rank them by likelihood (most likely first)

EVALUATION CRITERIA:
Primary Analysis Factors:
- Direct correlation with the observed error pattern
- Logical pathway between misconception and the given incorrect answer
- Mathematical concept alignment with question requirements
- Strength of causal relationship between misconception and error

Secondary Analysis Factors:
- Complexity match with the required solution steps
- Frequency of occurrence in similar mathematical contexts
- Cognitive load alignment with the problem-solving process
- Interaction effects between multiple misconceptions
- Impact on calculation sequence and final result
- Relevance to key mathematical principles involved

OUTPUT REQUIREMENTS:
1. Format: [numbers,numbers,...,numbers]
2. Rules:
   - Must contain exactly 25 numbers
   - Numbers must be between 1-{k}
   - Use commas without spaces
   - No duplicate numbers
   - Numbers must be ordered by likelihood (most likely first)
   - Must be enclosed in square brackets
   - No additional text or line breaks
   - No spaces anywhere in the output

EXAMPLE VALID OUTPUT:
[4,12,7,25,1,15,8,30,22,3,45,11,19,27,33,38,42,16,20,29,35,41,47,50,44]

EXAMPLE INVALID OUTPUTS:
4,12,7,25 (Missing brackets)
[4, 12, 7, 25] (Contains spaces)
[4,12,7] (Less than 25 numbers)
[4,12,7,25,...] (Incomplete list)

VALIDATION CHECKLIST:
- Contains exactly 25 numbers
- All numbers between 1-{k}
- No duplicate numbers
- Proper formatting with brackets
- No spaces or additional text
- Numbers ordered by likelihood
- Single line output only

Analyze the misconceptions using the specified criteria and provide your ranked output in the exact format specified above.
"""


In [11]:
# Helpers that clean the text and fill the prompt template.
def preprocess_text(x):
    """Normalize a piece of prompt text before it is sent to the LLM.

    Steps, in order:
      1. remove http:// and https:// URLs (scheme through the next whitespace),
      2. collapse runs of periods into one period and runs of commas into one comma,
      3. replace the LaTeX inline-math delimiters ``\\(`` and ``\\)`` with a space,
      4. collapse runs of spaces into a single space (newlines are left alone),
      5. strip leading and trailing whitespace.

    Args:
        x: The text to clean.

    Returns:
        The cleaned text.
    """
    # The previous pattern r"http\w+" never matched "http://..." (":" is not a word
    # character) and removed only the scheme of "https://...", leaving "://host/..." behind.
    x = re.sub(r"https?://\S+", "", x)  # Delete URLs
    x = re.sub(r"\.+", ".", x)  # Collapse consecutive periods into one
    x = re.sub(r"\,+", ",", x)  # Collapse consecutive commas into one
    x = re.sub(r"\\[()]", " ", x)  # Drop the "\(" and "\)" math delimiters
    x = re.sub(r"[ ]{1,}", " ", x)  # Collapse runs of spaces
    x = x.strip()  # Remove whitespace at the beginning and end
    return x


def apply_template(row, tokenizer):
    """Build the chat-formatted LLM prompt for one test row.

    Args:
        row: Mapping with the keys "ConstructName", "SubjectName", "QuestionText",
            "AnswerText", "CorrectAnswerText" and "misconception_topk".
        tokenizer: Object providing ``apply_chat_template``.

    Returns:
        The prompt text (not tokenized) with the generation prompt appended.
    """
    # Fill the module-level template, then clean the whole filled-in text.
    filled_prompt = prompt.format(
        SubjectName=row["SubjectName"],
        ConstructName=row["ConstructName"],
        Question=row["QuestionText"],
        CorrectAnswer=row["CorrectAnswerText"],
        IncorrectAnswer=row["AnswerText"],
        k=cfg.k,
        misconception_topk=row["misconception_topk"],
    )
    user_message = {"role": "user", "content": preprocess_text(filled_prompt)}
    return tokenizer.apply_chat_template(
        [user_message], tokenize=False, add_generation_prompt=True
    )


In [12]:
# Add the stage-1 top-k result to test_long as the "misconception_topk" column:
# one numbered candidate ("1. <name>") per line, in stage-1 rank order.
misconception_names = mapping_meta_df["MisconceptionName"].to_list()
misconception_topk_list = [
    "\n".join(
        f"{rank}. {misconception_names[int(mid)]}" for rank, mid in enumerate(topk_id, start=1)
    )
    for topk_id in topk_ids
]

test_long = test_long.with_columns(pl.Series(misconception_topk_list).alias("misconception_topk"))

print(test_long["misconception_topk"][0])


1. Confuses the order of operations, believes addition comes before multiplication 
2. Answers order of operations questions with brackets as if the brackets are not there
3. Confuses the order of operations, believes addition comes before division
4. Performs addition ahead of multiplication
5. Performs subtraction right to left if priority order means doing a calculation to the right first
6. Performs addition ahead of division
7. Believes order of operations does not affect the answer to a calculation
8. Performs addition ahead of any other operation
9. May have made a calculation error using the order of operations
10. Performs subtraction in wrong order
11. Carries out operations from right to left regardless of priority order
12. Performs addition ahead of subtraction
13. Carries out operations from left to right regardless of priority order, unless brackets are used
14. Applies BIDMAS in strict order (does not realize addition and subtraction, and multiplication and division, ar

In [13]:
# Apply apply_template to every row and store the result as the "prompt" column.
prompt_texts = [apply_template(row, tokenizer) for row in test_long.iter_rows(named=True)]
test_long = test_long.with_columns(pl.Series("prompt", prompt_texts))


In [14]:
# Inspect the prompt that was generated for the first row.
print(test_long["prompt"][0])


<|im_start|>system
You are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>
<|im_start|>user
You are a mathematics education expert analyzing student misconceptions.

INPUT CONTEXT:
Subject: BIDMAS
Topic: Use the order of operations to carry out calculations involving powers
Question: \[
3 \times 2+4-5
\]
Where do the brackets need to go to make the answer equal 13 ?
Correct Answer: 3 \times(2+4)-5 
Student's Incorrect Answer: 3 \times 2+(4-5) 

Below are 50 potential misconception candidates:
1. Confuses the order of operations, believes addition comes before multiplication 
2. Answers order of operations questions with brackets as if the brackets are not there
3. Confuses the order of operations, believes addition comes before division
4. Performs addition ahead of multiplication
5. Performs subtraction right to left if priority order means doing a calculation to the right first
6. Performs addition ahead of division
7. Believes order of operations does not affe

In [15]:
# Check the prompt token counts (min / max / mean) against the model's context length.
token_lengths = [len(token_ids) for token_ids in map(tokenizer.encode, test_long["prompt"].to_list())]
print("トークン数の統計:")
print(f"最小: {min(token_lengths)}")
print(f"最大: {max(token_lengths)}")
print(f"平均: {sum(token_lengths)/len(token_lengths):.1f}")


トークン数の統計:
最小: 1359
最大: 1529
平均: 1466.0


In [17]:
# Load the LLM with vLLM for inference.

llm = vllm.LLM(
    cfg.llm_model,
    quantization="awq",  # NOTE(review): the vLLM startup log reports awq_marlin is usable and faster
    tensor_parallel_size=1,
    gpu_memory_utilization=0.90,
    trust_remote_code=True,
    dtype="half",
    enforce_eager=True,
    max_model_len=3824,  # must cover the prompt length plus the generated tokens
    disable_log_stats=True,
)
# Rebinds `tokenizer` (previously the AutoTokenizer instance) to the engine's own tokenizer.
tokenizer = llm.get_tokenizer()

# Depending on the gpu_memory_utilization / max_model_len settings, this ran out of memory on an RTX 4090 (24 GB).


INFO 11-06 14:07:21 awq_marlin.py:101] Detected that the model can run with awq_marlin, however you specified quantization=awq explicitly, so forcing awq. Use quantization=awq_marlin for faster inference
INFO 11-06 14:07:21 llm_engine.py:237] Initializing an LLM engine (v0.6.3.post1) with config: model='Qwen/Qwen2.5-32B-Instruct-AWQ', speculative_config=None, tokenizer='Qwen/Qwen2.5-32B-Instruct-AWQ', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, rope_scaling=None, rope_theta=None, tokenizer_revision=None, trust_remote_code=True, dtype=torch.float16, max_seq_len=3824, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=awq, enforce_eager=True, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None, col

Loading safetensors checkpoint shards:   0% Completed | 0/5 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  20% Completed | 1/5 [00:00<00:02,  1.90it/s]
Loading safetensors checkpoint shards:  40% Completed | 2/5 [00:01<00:01,  1.79it/s]
Loading safetensors checkpoint shards:  60% Completed | 3/5 [00:01<00:01,  1.70it/s]
Loading safetensors checkpoint shards:  80% Completed | 4/5 [00:02<00:00,  1.62it/s]
Loading safetensors checkpoint shards: 100% Completed | 5/5 [00:03<00:00,  1.55it/s]
Loading safetensors checkpoint shards: 100% Completed | 5/5 [00:03<00:00,  1.62it/s]



INFO 11-06 14:07:26 model_runner.py:1067] Loading model weights took 18.1448 GB
INFO 11-06 14:07:28 gpu_executor.py:122] # GPU blocks: 286, # CPU blocks: 1024
INFO 11-06 14:07:28 gpu_executor.py:126] Maximum concurrency for 3824 tokens per request: 1.20x


In [18]:
# Run inference on every prompt.
responses = llm.generate(
    # Pass a plain list of str (the documented prompt input type) rather than a NumPy object array.
    test_long["prompt"].to_list(),
    vllm.SamplingParams(
        n=1,  # Number of output sequences to return for each prompt.
        top_p=0.8,  # Cumulative probability of the top tokens to consider (presumably unused with temperature=0; confirm).
        temperature=0,  # Randomness of the sampling; 0 means greedy decoding.
        seed=777,  # Seed for reproducibility
        skip_special_tokens=False,  # Whether to skip special tokens in the output.
        max_tokens=100,  # Maximum number of tokens to generate per output sequence.
    ),
    use_tqdm=True,
)


Processed prompts:   0%|          | 0/9 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts: 100%|██████████| 9/9 [00:12<00:00,  1.42s/it, est. speed input: 858.25 toks/s, output: 54.12 toks/s]


In [46]:
def extract_numbers(response_text, k=None, n_select=25):
    """Parse the LLM's ranked answer into 0-based candidate positions.

    Every run of digits in ``response_text`` is read as one number. Numbers outside
    ``1..k`` and repeated numbers are dropped (the first occurrence wins), and the
    result is cut to ``n_select`` entries. If fewer than ``n_select`` valid numbers
    remain, the list is filled with the not-yet-used candidates in ascending order,
    i.e. in stage-1 rank order, so the output never contains duplicates.

    Args:
        response_text: Raw text generated by the LLM, e.g. ``"[4,12,7]"``.
        k: Number of candidates that were shown to the LLM (valid answers are
            ``1..k`` inclusive). Defaults to ``cfg.k``.
        n_select: Number of positions to return. Defaults to 25.

    Returns:
        A list of ``min(n_select, k)`` distinct 0-based positions, best first.
        Unparsable input yields the first ``min(n_select, k)`` positions.
    """
    if k is None:
        k = cfg.k  # number of stage-1 candidates listed in the prompt
    target_len = max(0, min(n_select, k))

    try:
        # Digit runs instead of "strip everything but digits and commas": the latter
        # fuses space-separated numbers ("4 12" -> 412).
        parsed = [int(token) for token in re.findall(r"\d+", response_text)]
    except Exception as e:
        # Best effort: a broken response must not abort the whole submission.
        print(f"Error processing response: {e}")
        parsed = []

    # Keep valid, first-seen numbers only. The upper bound is inclusive because the
    # candidates are numbered 1..k in the prompt.
    picked = list(dict.fromkeys(n for n in parsed if 1 <= n <= k))[:target_len]

    # Fill up with unused candidates in stage-1 order instead of repeating a number.
    if len(picked) < target_len:
        used = set(picked)
        unused = [n for n in range(1, k + 1) if n not in used]
        picked += unused[: target_len - len(picked)]

    # Convert to 0-based positions.
    return [n - 1 for n in picked]


In [None]:
# Parse each LLM response into a list of 0-based candidate positions.
llm_raw_responses = [extract_numbers(response.outputs[0].text) for response in responses]


In [None]:
# Convert the selected candidate positions into the corresponding MisconceptionIds.
# topk_ids holds row positions in mapping_meta_df, so the id is looked up in the
# MisconceptionId column instead of assuming that row position == MisconceptionId.
misconception_id_lookup = mapping_meta_df["MisconceptionId"].to_list()
assert len(topk_ids) == len(llm_raw_responses), "one LLM response is expected per test row"

misconception_top25_list = [
    [str(misconception_id_lookup[int(row_topk[pos])]) for pos in selected_positions]
    for row_topk, selected_positions in zip(topk_ids, llm_raw_responses)
]


In [None]:
# Build the submission from test_long: one row per QuestionId_Answer with the
# selected MisconceptionIds joined by single spaces.
top25_series = pl.Series(misconception_top25_list).alias("MisconceptionId")
submission = test_long.select("QuestionId_Answer").with_columns(top25_series)
submission = submission.with_columns(pl.col("MisconceptionId").list.join(" "))
# The result must have the same columns and dtypes as the sample submission.
assert sample_submission_df.schema == submission.schema

# save
submission.write_csv("submission.csv")
