In [10]:
import pickle
import pprint
from datetime import datetime
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import polars as pl
import pytz
import seaborn as sns
from omegaconf import OmegaConf
from sklearn.model_selection import StratifiedKFold
from sklearn.preprocessing import LabelEncoder

from src.config import cfg
from src.data import add_subject_name_info
from src.dir import create_dir
from src.seed import seed_everything

# Name the experiment after the directory this notebook is executed from.
notebook_dir = Path().resolve()
cfg.exp_number = notebook_dir.name

# Print the config with interpolations resolved, so the settings used for this
# run are visible in the cell output.
resolved_cfg_yaml = OmegaConf.to_yaml(cfg, resolve=True)
print(resolved_cfg_yaml)

# Seed through the project helper, using the seed stored in the config.
seed_everything(cfg.seed)


exp_number: '000'
run_time: base
data:
  input_root: ../../data/input
  train_path: ../../data/input/train.csv
  test_path: ../../data/input/test.csv
  sample_submission_path: ../../data/input/sample_submission.csv
  mapping_path: ../../data/input/misconception_mapping.csv
  output_root: ../../data/output
  results_root: ../../results
  results_path: ../../results/000/base
seed: 42



### データの読み込み

In [11]:
def _read_table(csv_path):
    """Read one input CSV with polars, with automatic date parsing enabled."""
    return pl.read_csv(csv_path, try_parse_dates=True)


# Load the input tables from the paths declared in the config.
train_df = _read_table(cfg.data.train_path)
test_df = _read_table(cfg.data.test_path)
sample_submission_df = _read_table(cfg.data.sample_submission_path)
mapping_df = _read_table(cfg.data.mapping_path)

# Cross-validation splitter: 5 shuffled stratified folds, seeded from the config.
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=cfg.seed)


In [5]:
# TODO: add the embedding models to compare to this list.
# Re-assigning MODEL_NAME on consecutive lines keeps only the last value, so the
# candidates are kept in a list instead of repeated assignments.
MODEL_NAMES = [
    "BAAI/bge-large-en-v1.5",
]

# Model used by the evaluation cell. The last entry is selected, which matches
# the "last assignment wins" outcome of assigning MODEL_NAME several times.
MODEL_NAME = MODEL_NAMES[-1]


# TODO: add any model-specific processing here.


In [13]:
# One score per fold; their mean is reported after the loop.
cv_scores = []
fold_splits = skf.split(train_df, train_df["SubjectName"])  # stratified on SubjectName
for train_idx, valid_idx in fold_splits:
    # Select this fold's rows of train_df by integer index.
    train, valid = train_df[train_idx], train_df[valid_idx]

    # Add SubjectName information to mapping_df, using the training split only.
    mapping_meta = add_subject_name_info(train, mapping_df)

    # TODO: format train (it has ground truth, so this differs from test handling)

    # TODO: extract the TOP50 candidates with the embedding model

    # TODO: check whether the ground truth is among the TOP50
    is_gt_in_top50 = [False]  # or True

    # TODO: the mean of the hits is this fold's CV score
    avg_score = np.mean(is_gt_in_top50)
    cv_scores += [avg_score]

cv_mean = np.mean(cv_scores)
print(f"モデル: {MODEL_NAME}")
print(f"CVスコア: {cv_mean}")




(1495,) (374,)
(1495,) (374,)
(1495,) (374,)
(1495,) (374,)
(1496,) (373,)
モデル: BAAI/bge-large-en-v1.5
CVスコア: 0.0


In [14]:
# Preview the first rows of `train`. The name is bound inside the CV loop, so
# after the loop it refers to the training split of the last fold.
train.head()


QuestionId,ConstructId,ConstructName,SubjectId,SubjectName,CorrectAnswer,QuestionText,AnswerAText,AnswerBText,AnswerCText,AnswerDText,MisconceptionAId,MisconceptionBId,MisconceptionCId,MisconceptionDId
i64,i64,str,i64,str,str,str,str,str,str,str,f64,f64,f64,f64
0,856,"""Use the order of operations to…",33,"""BIDMAS""","""A""","""\[ 3 \times 2+4-5 \] Where do …","""\( 3 \times(2+4)-5 \)""","""\( 3 \times 2+(4-5) \)""","""\( 3 \times(2+4-5) \)""","""Does not need brackets""",,,,1672.0
1,1612,"""Simplify an algebraic fraction…",1077,"""Simplifying Algebraic Fraction…","""D""","""Simplify the following, if pos…","""\( m+1 \)""","""\( m+2 \)""","""\( m-1 \)""","""Does not simplify""",2142.0,143.0,2142.0,
3,2377,"""Recall and use the intersectin…",88,"""Properties of Quadrilaterals""","""C""","""The angles highlighted on this…","""acute""","""obtuse""","""\( 90^{\circ} \)""","""Not enough information""",1180.0,1180.0,,1180.0
4,3387,"""Substitute positive integer va…",67,"""Substitution into Formula""","""A""","""The equation \( f=3 r^{2}+3 \)…","""\( 30 \)""","""\( 27 \)""","""\( 51 \)""","""\( 24 \)""",,,,1818.0
5,2052,"""Identify a unit of area""",75,"""Area of Simple Shapes""","""D""","""James has answered a question …","""\( m \)""","""\( \mathrm{cm} \)""","""\( \mathrm{km}^{3} \)""","""\( \mathrm{mm}^{2} \)""",686.0,686.0,686.0,
