In [1]:
import pickle
import pprint
from datetime import datetime
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import polars as pl
import pytz
import seaborn as sns
from omegaconf import OmegaConf
from sentence_transformers import SentenceTransformer, util
from sklearn.metrics.pairwise import cosine_similarity

from src.config import cfg
from src.dir import create_dir
from src.seed import seed_everything

cfg.exp_number = Path().resolve().name
print(OmegaConf.to_yaml(cfg, resolve=True))

seed_everything(cfg.seed)

pl.Config.set_fmt_str_lengths(100000)


  from tqdm.autonotebook import tqdm, trange


exp_number: '001'
run_time: base
data:
  input_root: ../../data/input
  train_path: ../../data/input/train.csv
  test_path: ../../data/input/test.csv
  sample_submission_path: ../../data/input/sample_submission.csv
  mapping_path: ../../data/input/misconception_mapping.csv
  mapping_meta_path: ../../data/input/mapping_meta.parquet
  output_root: ../../data/output
  results_root: ../../results
  results_path: ../../results/001/base
seed: 42
embed_model: BAAI/bge-large-en-v1.5



polars.config.Config

- クエリ作成
- データベース作成
- エンコード処理
- 類似度計算 → データベースからTOPK抽出処理
- 

### Loading Data

In [2]:
# データの読み込み
train_df = pl.read_csv(cfg.data.train_path, try_parse_dates=True)
test_df = pl.read_csv(cfg.data.test_path, try_parse_dates=True)
sample_submission_df = pl.read_csv(cfg.data.sample_submission_path, try_parse_dates=True)
mapping_df = pl.read_csv(cfg.data.mapping_path, try_parse_dates=True)
mapping_meta_df = pl.read_parquet(cfg.data.mapping_meta_path)


In [3]:
test_df


QuestionId,ConstructId,ConstructName,SubjectId,SubjectName,CorrectAnswer,QuestionText,AnswerAText,AnswerBText,AnswerCText,AnswerDText
i64,i64,str,i64,str,str,str,str,str,str,str
1869,856,"""Use the order of operations to carry out calculations involving powers""",33,"""BIDMAS""","""A""","""\[ 3 \times 2+4-5 \] Where do the brackets need to go to make the answer equal \( 13 \) ?""","""\( 3 \times(2+4)-5 \)""","""\( 3 \times 2+(4-5) \)""","""\( 3 \times(2+4-5) \)""","""Does not need brackets"""
1870,1612,"""Simplify an algebraic fraction by factorising the numerator""",1077,"""Simplifying Algebraic Fractions""","""D""","""Simplify the following, if possible: \( \frac{m^{2}+2 m-3}{m-3} \)""","""\( m+1 \)""","""\( m+2 \)""","""\( m-1 \)""","""Does not simplify"""
1871,2774,"""Calculate the range from a list of data""",339,"""Range and Interquartile Range from a List of Data""","""B""","""Tom and Katie are discussing the \( 5 \) plants with these heights: \( 24 \mathrm{~cm}, 17 \mathrm{~cm}, 42 \mathrm{~cm}, 26 \mathrm{~cm}, 13 \mathrm{~cm} \) Tom says if all the plants were cut in half, the range wouldn't change. Katie says if all the plants grew by \( 3 \mathrm{~cm} \) each, the range wouldn't change. Who do you agree with?""","""Only Tom""","""Only Katie""","""Both Tom and Katie""","""Neither is correct"""


### Data Preparation

In [40]:
# testデータの整形
common_col = [
    "QuestionId",
    "ConstructName",
    "SubjectName",
    "QuestionText",
    "CorrectAnswer",
]

test_long = (
    test_df.select(pl.col(common_col + [f"Answer{alpha}Text" for alpha in ["A", "B", "C", "D"]]))
    .unpivot(
        index=common_col,
        variable_name="AnswerType",
        value_name="AnswerText",
    )
    .with_columns(
        (
            pl.lit("ConstructName: ")
            + pl.col("ConstructName")
            + pl.lit(" SubjectName: ")
            + pl.col("SubjectName")
            + pl.lit(" QuestionText: ")
            + pl.col("QuestionText")
            + pl.lit(" AnswerText: ")
            + pl.col("AnswerText")
        ).alias("AllText"),
        pl.col("AnswerType").str.extract(r"Answer([A-D])Text$").alias("AnswerAlphabet"),
    )
    .with_columns(
        pl.concat_str(
            [
                pl.col("QuestionId"),
                pl.col("AnswerAlphabet"),
            ],
            separator="_",
        ).alias("QuestionId_Answer"),
    )
    .sort("QuestionId_Answer")
)
test_long.head()


QuestionId,ConstructName,SubjectName,QuestionText,CorrectAnswer,AnswerType,AnswerText,AllText,AnswerAlphabet,QuestionId_Answer
i64,str,str,str,str,str,str,str,str,str
1869,"""Use the order of operations to carry out calculations involving powers""","""BIDMAS""","""\[ 3 \times 2+4-5 \] Where do the brackets need to go to make the answer equal \( 13 \) ?""","""A""","""AnswerAText""","""\( 3 \times(2+4)-5 \)""","""ConstructName: Use the order of operations to carry out calculations involving powers SubjectName: BIDMAS QuestionText: \[ 3 \times 2+4-5 \] Where do the brackets need to go to make the answer equal \( 13 \) ? AnswerText: \( 3 \times(2+4)-5 \)""","""A""","""1869_A"""
1869,"""Use the order of operations to carry out calculations involving powers""","""BIDMAS""","""\[ 3 \times 2+4-5 \] Where do the brackets need to go to make the answer equal \( 13 \) ?""","""A""","""AnswerBText""","""\( 3 \times 2+(4-5) \)""","""ConstructName: Use the order of operations to carry out calculations involving powers SubjectName: BIDMAS QuestionText: \[ 3 \times 2+4-5 \] Where do the brackets need to go to make the answer equal \( 13 \) ? AnswerText: \( 3 \times 2+(4-5) \)""","""B""","""1869_B"""
1869,"""Use the order of operations to carry out calculations involving powers""","""BIDMAS""","""\[ 3 \times 2+4-5 \] Where do the brackets need to go to make the answer equal \( 13 \) ?""","""A""","""AnswerCText""","""\( 3 \times(2+4-5) \)""","""ConstructName: Use the order of operations to carry out calculations involving powers SubjectName: BIDMAS QuestionText: \[ 3 \times 2+4-5 \] Where do the brackets need to go to make the answer equal \( 13 \) ? AnswerText: \( 3 \times(2+4-5) \)""","""C""","""1869_C"""
1869,"""Use the order of operations to carry out calculations involving powers""","""BIDMAS""","""\[ 3 \times 2+4-5 \] Where do the brackets need to go to make the answer equal \( 13 \) ?""","""A""","""AnswerDText""","""Does not need brackets""","""ConstructName: Use the order of operations to carry out calculations involving powers SubjectName: BIDMAS QuestionText: \[ 3 \times 2+4-5 \] Where do the brackets need to go to make the answer equal \( 13 \) ? AnswerText: Does not need brackets""","""D""","""1869_D"""
1870,"""Simplify an algebraic fraction by factorising the numerator""","""Simplifying Algebraic Fractions""","""Simplify the following, if possible: \( \frac{m^{2}+2 m-3}{m-3} \)""","""D""","""AnswerAText""","""\( m+1 \)""","""ConstructName: Simplify an algebraic fraction by factorising the numerator SubjectName: Simplifying Algebraic Fractions QuestionText: Simplify the following, if possible: \( \frac{m^{2}+2 m-3}{m-3} \) AnswerText: \( m+1 \)""","""A""","""1870_A"""


In [47]:
# 正解のテキストを抽出
test_correct_answer = (
    test_long.filter(pl.col("CorrectAnswer") == pl.col("AnswerAlphabet"))
    .select(pl.col(["QuestionId", "AnswerText"]))
    .rename({"AnswerText": "CorrectAnswerText"})
)
test_correct_answer.head()

# test_longに結合し、正解のテキスト列を追加
test_long = test_long.join(test_correct_answer, on="QuestionId", how="left")

test_long.head(3)


QuestionId,ConstructName,SubjectName,QuestionText,CorrectAnswer,AnswerType,AnswerText,AllText,AnswerAlphabet,QuestionId_Answer,CorrectAnswerText,CorrectAnswerText_right
i64,str,str,str,str,str,str,str,str,str,str,str
1869,"""Use the order of operations to carry out calculations involving powers""","""BIDMAS""","""\[ 3 \times 2+4-5 \] Where do the brackets need to go to make the answer equal \( 13 \) ?""","""A""","""AnswerAText""","""\( 3 \times(2+4)-5 \)""","""ConstructName: Use the order of operations to carry out calculations involving powers SubjectName: BIDMAS QuestionText: \[ 3 \times 2+4-5 \] Where do the brackets need to go to make the answer equal \( 13 \) ? AnswerText: \( 3 \times(2+4)-5 \)""","""A""","""1869_A""","""\( 3 \times(2+4)-5 \)""","""\( 3 \times(2+4)-5 \)"""
1869,"""Use the order of operations to carry out calculations involving powers""","""BIDMAS""","""\[ 3 \times 2+4-5 \] Where do the brackets need to go to make the answer equal \( 13 \) ?""","""A""","""AnswerBText""","""\( 3 \times 2+(4-5) \)""","""ConstructName: Use the order of operations to carry out calculations involving powers SubjectName: BIDMAS QuestionText: \[ 3 \times 2+4-5 \] Where do the brackets need to go to make the answer equal \( 13 \) ? AnswerText: \( 3 \times 2+(4-5) \)""","""B""","""1869_B""","""\( 3 \times(2+4)-5 \)""","""\( 3 \times(2+4)-5 \)"""
1869,"""Use the order of operations to carry out calculations involving powers""","""BIDMAS""","""\[ 3 \times 2+4-5 \] Where do the brackets need to go to make the answer equal \( 13 \) ?""","""A""","""AnswerCText""","""\( 3 \times(2+4-5) \)""","""ConstructName: Use the order of operations to carry out calculations involving powers SubjectName: BIDMAS QuestionText: \[ 3 \times 2+4-5 \] Where do the brackets need to go to make the answer equal \( 13 \) ? AnswerText: \( 3 \times(2+4-5) \)""","""C""","""1869_C""","""\( 3 \times(2+4)-5 \)""","""\( 3 \times(2+4)-5 \)"""


In [5]:
# mappingデータの整形
mapping_meta_df = mapping_meta_df.with_columns(
    [
        pl.when(pl.col("SubjectNames").list.len() > 0)
        .then(
            pl.lit("The misconception '")
            + pl.col("MisconceptionName")
            + pl.lit("' is primarily observed in the following subjects: ")
            + pl.col("SubjectNames").list.join(", ")
        )
        .otherwise(pl.lit("The misconception is: ") + pl.col("MisconceptionName"))
        .alias("MisconceptionName_with_SubjectNames")
    ]
)

mapping_meta_df.head()


MisconceptionId,MisconceptionName,SubjectNames,MisconceptionName_with_SubjectNames
i64,str,list[str],str
0,"""Does not know that angles in a triangle sum to 180 degrees""","[""Angles in Triangles""]","""The misconception 'Does not know that angles in a triangle sum to 180 degrees' is primarily observed in the following subjects: Angles in Triangles"""
1,"""Uses dividing fractions method for multiplying fractions""","[""Multiplying and Dividing Negative Numbers"", ""Multiplying Fractions""]","""The misconception 'Uses dividing fractions method for multiplying fractions' is primarily observed in the following subjects: Multiplying and Dividing Negative Numbers, Multiplying Fractions"""
2,"""Believes there are 100 degrees in a full turn""","[""Types, Naming and Estimating"", ""Measuring Angles""]","""The misconception 'Believes there are 100 degrees in a full turn' is primarily observed in the following subjects: Types, Naming and Estimating, Measuring Angles"""
3,"""Thinks a quadratic without a non variable term, can not be factorised""","[""Factorising into a Single Bracket""]","""The misconception 'Thinks a quadratic without a non variable term, can not be factorised' is primarily observed in the following subjects: Factorising into a Single Bracket"""
4,"""Believes addition of terms and powers of terms are equivalent e.g. a + c = a^c""","[""Simplifying Expressions by Collecting Like Terms""]","""The misconception 'Believes addition of terms and powers of terms are equivalent e.g. a + c = a^c' is primarily observed in the following subjects: Simplifying Expressions by Collecting Like Terms"""


### 埋め込みモデルでTOP100を抽出（1st stage）

In [6]:
model = SentenceTransformer(cfg.embed_model)

test_long_vec = model.encode(test_long["AllText"].to_list(), normalize_embeddings=True)
misconception_vec = model.encode(
    mapping_meta_df["MisconceptionName_with_SubjectNames"].to_list(), normalize_embeddings=True
)


In [7]:
print(test_long_vec.shape)
print(misconception_vec.shape)


(12, 1024)
(2587, 1024)


In [28]:
# 類似度TOP100を取得
top100ids = util.semantic_search(test_long_vec, misconception_vec, top_k=100)


In [30]:
top100ids[0]


[{'corpus_id': 2488, 'score': 0.7755094170570374},
 {'corpus_id': 2306, 'score': 0.7184184789657593},
 {'corpus_id': 1672, 'score': 0.715743362903595},
 {'corpus_id': 1963, 'score': 0.7100574970245361},
 {'corpus_id': 1316, 'score': 0.7064339518547058},
 {'corpus_id': 15, 'score': 0.7034092545509338},
 {'corpus_id': 328, 'score': 0.697814404964447},
 {'corpus_id': 1005, 'score': 0.6940087080001831},
 {'corpus_id': 1054, 'score': 0.6926456689834595},
 {'corpus_id': 2586, 'score': 0.6917635202407837},
 {'corpus_id': 2532, 'score': 0.6915374994277954},
 {'corpus_id': 871, 'score': 0.6897042393684387},
 {'corpus_id': 1862, 'score': 0.6863470673561096},
 {'corpus_id': 1516, 'score': 0.686028778553009},
 {'corpus_id': 1345, 'score': 0.6853331327438354},
 {'corpus_id': 1666, 'score': 0.6844601631164551},
 {'corpus_id': 987, 'score': 0.6817137002944946},
 {'corpus_id': 217, 'score': 0.6816158294677734},
 {'corpus_id': 2131, 'score': 0.6804996132850647},
 {'corpus_id': 1468, 'score': 0.67923694

### LLMによる絞り込み（2nd stage）

In [36]:
test_df.head()


QuestionId,ConstructId,ConstructName,SubjectId,SubjectName,CorrectAnswer,QuestionText,AnswerAText,AnswerBText,AnswerCText,AnswerDText
i64,i64,str,i64,str,str,str,str,str,str,str
1869,856,"""Use the order of operations to carry out calculations involving powers""",33,"""BIDMAS""","""A""","""\[ 3 \times 2+4-5 \] Where do the brackets need to go to make the answer equal \( 13 \) ?""","""\( 3 \times(2+4)-5 \)""","""\( 3 \times 2+(4-5) \)""","""\( 3 \times(2+4-5) \)""","""Does not need brackets"""
1870,1612,"""Simplify an algebraic fraction by factorising the numerator""",1077,"""Simplifying Algebraic Fractions""","""D""","""Simplify the following, if possible: \( \frac{m^{2}+2 m-3}{m-3} \)""","""\( m+1 \)""","""\( m+2 \)""","""\( m-1 \)""","""Does not simplify"""
1871,2774,"""Calculate the range from a list of data""",339,"""Range and Interquartile Range from a List of Data""","""B""","""Tom and Katie are discussing the \( 5 \) plants with these heights: \( 24 \mathrm{~cm}, 17 \mathrm{~cm}, 42 \mathrm{~cm}, 26 \mathrm{~cm}, 13 \mathrm{~cm} \) Tom says if all the plants were cut in half, the range wouldn't change. Katie says if all the plants grew by \( 3 \mathrm{~cm} \) each, the range wouldn't change. Who do you agree with?""","""Only Tom""","""Only Katie""","""Both Tom and Katie""","""Neither is correct"""


In [35]:
test_long.head(2)


QuestionId,ConstructName,SubjectName,QuestionText,CorrectAnswer,AnswerType,AnswerText,AllText,AnswerAlphabet,QuestionId_Answer
i64,str,str,str,str,str,str,str,str,str
1869,"""Use the order of operations to carry out calculations involving powers""","""BIDMAS""","""\[ 3 \times 2+4-5 \] Where do the brackets need to go to make the answer equal \( 13 \) ?""","""A""","""AnswerBText""","""\( 3 \times 2+(4-5) \)""","""ConstructName: Use the order of operations to carry out calculations involving powers SubjectName: BIDMAS QuestionText: \[ 3 \times 2+4-5 \] Where do the brackets need to go to make the answer equal \( 13 \) ? AnswerText: \( 3 \times 2+(4-5) \)""","""B""","""1869_B"""
1869,"""Use the order of operations to carry out calculations involving powers""","""BIDMAS""","""\[ 3 \times 2+4-5 \] Where do the brackets need to go to make the answer equal \( 13 \) ?""","""A""","""AnswerCText""","""\( 3 \times(2+4-5) \)""","""ConstructName: Use the order of operations to carry out calculations involving powers SubjectName: BIDMAS QuestionText: \[ 3 \times 2+4-5 \] Where do the brackets need to go to make the answer equal \( 13 \) ? AnswerText: \( 3 \times(2+4-5) \)""","""C""","""1869_C"""


In [None]:
# プロンプト
prompt = """
Here is a question about {ConstructName}({SubjectName}).
Question: {Question}
Correct Answer: {CorrectAnswer}
Incorrect Answer: {IncorrectAnswer}

You are a math teacher. Your task is to deduce and identify the misconceptions behind incorrect answers to the questions.
Below are 50 examples as reasons for wrong answers.
From the 50 examples below, please provide 25 that you believe are appropriate reasons for the incorrect answers.
You do not need to explain your reasoning process.
Sample answers: 1, 10, 11, 12, ....
"""


In [27]:
# # submissionの作成 保留中
# submission_df = (
#     test_long.with_columns(
#         pl.Series([" ".join([str(x["corpus_id"]) for x in top25id]) for top25id in top25ids]).alias("MisconceptionId")
#         )
#     .filter(pl.col("CorrectAnswer") != pl.col("AnswerAlphabet"))
#     .select(pl.col(["QuestionId_Answer", "MisconceptionId"]))
#     .sort("QuestionId_Answer")
# )

# assert submission_df.columns == sample_submission_df.columns
# submission_df


QuestionId_Answer,MisconceptionId
str,str
"""1869_B""","""2488 2306 1672 1963 1316 15 328 1005 1054 2532 2586 871 1516 1862 1345 1666 217 987 1468 2131 1642 2181 234 1365 1432 2518 466 1332 2221 706 1999 907 107 296 1690 1421 655 77 954 1971 1507 228 519 1124 2277 2264 189 2326 1207 1075"""
"""1869_C""","""2488 2306 1672 1963 1316 15 328 1005 1054 2586 2532 871 1345 1862 1516 1666 2131 217 987 1468 2181 1642 234 1432 1365 2518 2221 466 1332 706 1999 107 907 296 1690 1421 655 954 77 1971 519 228 1507 91 2277 1124 189 1207 1075 2326"""
"""1869_D""","""2488 871 1316 1005 1345 2306 1963 1672 1666 2131 15 328 1432 2586 1054 2532 1468 107 217 1862 1421 1516 1642 2277 987 1332 2181 2518 234 373 1365 466 91 519 1433 907 1172 1690 296 655 2143 628 228 77 911 2221 220 706 1004 189"""
"""1870_A""","""1593 2398 59 2142 1755 167 1540 2307 1825 363 519 2078 633 91 2068 907 1075 891 2134 353 715 2363 1980 954 1916 1048 80 547 1535 979 848 120 1549 2240 2234 1374 143 217 317 2147 1666 838 1610 78 792 2066 606 3 265 628"""
"""1870_B""","""1593 2398 59 2142 1755 167 1540 1825 2307 363 519 2078 633 91 2068 907 2134 1075 891 353 715 1980 2363 954 1916 547 80 1048 1535 979 848 120 1549 2240 2234 1374 143 217 317 838 2147 1666 78 1610 792 265 2066 606 3 347"""
"""1870_C""","""1593 2398 59 2142 1755 1540 167 1825 2307 363 519 2078 633 91 2068 907 891 2134 353 1075 715 2363 1980 80 1916 954 1048 547 1535 979 848 120 2240 2234 1549 1374 143 1666 317 217 2147 1610 838 2066 606 78 792 265 3 628"""
"""1871_A""","""397 691 1349 1677 1073 2551 1287 2346 365 2319 632 110 1797 1177 2456 631 2151 2064 1306 655 2457 1098 867 1623 2269 2119 307 461 1790 2426 1908 1602 1815 2427 1675 70 2327 2024 1527 2283 2044 618 1968 1611 1835 188 1982 2252 2035 2012"""
"""1871_C""","""397 691 1349 1677 1073 2551 1287 2346 365 2319 632 110 1797 1177 2456 631 2064 2151 1306 655 2457 1098 1623 867 2119 2269 461 307 1790 2426 1908 1602 2427 1815 1675 2024 70 1527 1968 2327 188 2283 618 2044 1923 1982 1611 2252 1835 377"""
"""1871_D""","""397 691 1677 1349 1073 2551 1287 2346 365 2319 632 1797 110 2151 2456 631 2064 1177 655 1306 2457 1098 867 1623 307 2269 1790 2119 461 2426 1815 1675 1908 2427 1602 2283 2044 70 2327 2024 1835 1611 1527 618 2252 1968 2012 2035 188 296"""
