In [1]:
from datetime import datetime
from pathlib import Path

import numpy as np
import polars as pl
import pytz
import torch
from datasets import Dataset
from omegaconf import OmegaConf
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
)
from sentence_transformers.losses import MultipleNegativesRankingLoss
from sentence_transformers.training_args import BatchSamplers
from sklearn.metrics.pairwise import cosine_similarity

from src.config import cfg
from src.data import add_subject_name_info, preprocess_train
from src.dir import create_dir
from src.seed import seed_everything

# Use the name of the current working directory as the experiment number, then
# print the fully resolved config so the run settings are recorded in the output.
experiment_dir = Path().resolve()
cfg.exp_number = experiment_dir.name
resolved_config_yaml = OmegaConf.to_yaml(cfg, resolve=True)
print(resolved_config_yaml)

# Apply the configured seed through the project helper before any data/model work.
seed_everything(cfg.seed)


  from .autonotebook import tqdm as notebook_tqdm


exp_number: '002'
run_time: base
data:
  input_root: ../../data/input
  train_path: ../../data/input/train.csv
  test_path: ../../data/input/test.csv
  sample_submission_path: ../../data/input/sample_submission.csv
  mapping_path: ../../data/input/misconception_mapping.csv
  mapping_meta_path: ../../data/input/mapping_meta.parquet
  output_root: ../../data/output
  results_root: ../../results
  results_path: ../../results/002/base
seed: 42
k: 25
model:
  model_name: BAAI/bge-large-en-v1.5
  epoch: 2
  lr: 2.0e-05
  batch_size: 8



In [2]:
# Debug switch: when True, the training dataset is truncated to its first 500 rows
# before fine-tuning (see the `if DEBUG:` cell further down). Keep False for a full run.
DEBUG = False


### Data Load

In [3]:
def _read_csv(path):
    """Read one CSV file with polars, letting it try to parse date-like columns."""
    return pl.read_csv(path, try_parse_dates=True)


# Load the input tables from the paths configured under cfg.data.
train = _read_csv(cfg.data.train_path)
test = _read_csv(cfg.data.test_path)
sample_submission = _read_csv(cfg.data.sample_submission_path)
mapping = _read_csv(cfg.data.mapping_path)
# NOTE(review): mapping_meta is reassigned in a later cell from
# add_subject_name_info(train, mapping) before it is first used.
mapping_meta = pl.read_parquet(cfg.data.mapping_meta_path)


In [4]:
# Preprocess train. Judging by the preview below, the result has one row per
# (question, answer) pair with AllText, QuestionId_Answer and MisconceptionId
# columns. NOTE(review): confirm the exact rules in src.data.preprocess_train.
train_long = preprocess_train(train)
train_long.head()


QuestionId,ConstructName,SubjectName,QuestionText,CorrectAnswer,AnswerType,AnswerText,AllText,AnswerAlphabet,QuestionId_Answer,MisconceptionId
i64,str,str,str,str,str,str,str,str,str,i64
0,"""Use the order of operations to…","""BIDMAS""","""\[ 3 \times 2+4-5 \] Where do …","""A""","""AnswerDText""","""Does not need brackets""","""ConstructName: Use the order o…","""D""","""0_D""",1672
1000,"""Simplify an algebraic fraction…","""Simplifying Algebraic Fraction…","""Simplify the following, if pos…","""B""","""AnswerAText""","""\( t \)""","""ConstructName: Simplify an alg…","""A""","""1000_A""",891
1000,"""Simplify an algebraic fraction…","""Simplifying Algebraic Fraction…","""Simplify the following, if pos…","""B""","""AnswerCText""","""\( 1 \)""","""ConstructName: Simplify an alg…","""C""","""1000_C""",891
1000,"""Simplify an algebraic fraction…","""Simplifying Algebraic Fraction…","""Simplify the following, if pos…","""B""","""AnswerDText""","""Does not simplify""","""ConstructName: Simplify an alg…","""D""","""1000_D""",353
1001,"""Round numbers to two decimal p…","""Rounding to Decimal Places""","""What is \( \mathbf{3 . 5 1 6 3…","""B""","""AnswerAText""","""\( 3.51 \)""","""ConstructName: Round numbers t…","""A""","""1001_A""",1379


In [5]:
# Add the SubjectName information from train to the misconception mapping.
# This reassigns mapping_meta, replacing the frame that was read from parquet earlier.
mapping_meta = add_subject_name_info(train, mapping)
mapping_meta.head()


MisconceptionId,MisconceptionName,SubjectNames,MisconceptionName_with_SubjectNames
i64,str,list[str],str
0,"""Does not know that angles in a…","[""Angles in Triangles""]","""The misconception 'Does not kn…"
1,"""Uses dividing fractions method…","[""Multiplying and Dividing Negative Numbers"", ""Multiplying Fractions""]","""The misconception 'Uses dividi…"
2,"""Believes there are 100 degrees…","[""Types, Naming and Estimating"", ""Measuring Angles""]","""The misconception 'Believes th…"
3,"""Thinks a quadratic without a n…","[""Factorising into a Single Bracket""]","""The misconception 'Thinks a qu…"
4,"""Believes addition of terms and…","[""Simplifying Expressions by Collecting Like Terms""]","""The misconception 'Believes ad…"


### Make retrieval data

In [6]:
# Pretrained (not yet fine-tuned) encoder, used only to retrieve candidate misconceptions.
model = SentenceTransformer(cfg.model.model_name, trust_remote_code=True)

# Texts to embed: the combined question/answer text and the misconception descriptions.
question_texts = train_long["AllText"].to_list()
misconception_texts = mapping_meta["MisconceptionName_with_SubjectNames"].to_list()

# Embeddings are requested already normalized (normalize_embeddings=True).
train_long_vec = model.encode(question_texts, normalize_embeddings=True)
misconception_mapping_vec = model.encode(misconception_texts, normalize_embeddings=True)
print(train_long_vec.shape, misconception_mapping_vec.shape)


(4370, 1024) (2587, 1024)


In [7]:
# Cosine similarity between every train-row embedding and every misconception
# embedding; the result has one row per train row and one column per misconception.
train_cos_sim_arr = cosine_similarity(train_long_vec, misconception_mapping_vec)
print(f"train_cos_sim_arr.shape: {train_cos_sim_arr.shape}")
print(train_cos_sim_arr[0])

# Sort each row so the column indices come out in descending order of similarity
# (negating the scores turns argsort's ascending order into descending).
# The stored values are column positions, i.e. row positions in mapping_meta.
train_sorted_indices = np.argsort(-train_cos_sim_arr, axis=1)
print(f"\ntrain_sorted_indices.shape: {train_sorted_indices.shape}")
print(train_sorted_indices[0])


train_cos_sim_arr.shape: (4370, 2587)
[0.5054938  0.6122866  0.42998317 ... 0.44713372 0.64220685 0.6939039 ]

train_sorted_indices.shape: (4370, 2587)
[2488  871 1316 ... 2299  237  211]


In [8]:
# train_sorted_indices holds row positions in mapping_meta, not ids. Translate the
# top-k positions into the actual MisconceptionId values so the column stays correct
# even if mapping_meta is not ordered 0..N-1 by MisconceptionId. When ids equal row
# positions (as in the preview above), the result is the same as using the positions.
misconception_ids = mapping_meta["MisconceptionId"].to_numpy()
top_k_misconception_ids = misconception_ids[train_sorted_indices[:, : cfg.k]]

# One list of cfg.k candidate ids per train row, most similar first.
train_long = train_long.with_columns(
    pl.Series(top_k_misconception_ids.tolist()).alias("PredictMisconceptionId")
)
train_long.head(1)


QuestionId,ConstructName,SubjectName,QuestionText,CorrectAnswer,AnswerType,AnswerText,AllText,AnswerAlphabet,QuestionId_Answer,MisconceptionId,PredictMisconceptionId
i64,str,str,str,str,str,str,str,str,str,i64,list[i64]
0,"""Use the order of operations to…","""BIDMAS""","""\[ 3 \times 2+4-5 \] Where do …","""A""","""AnswerDText""","""Does not need brackets""","""ConstructName: Use the order o…","""D""","""0_D""",1672,"[2488, 871, … 2181]"


In [9]:
# Preview only: show mapping_meta with every column name prefixed by "Predict".
# The result is not assigned; the same rename is applied again in the join cell below.
mapping_meta.rename(lambda x: "Predict" + x)


PredictMisconceptionId,PredictMisconceptionName,PredictSubjectNames,PredictMisconceptionName_with_SubjectNames
i64,str,list[str],str
0,"""Does not know that angles in a…","[""Angles in Triangles""]","""The misconception 'Does not kn…"
1,"""Uses dividing fractions method…","[""Multiplying and Dividing Negative Numbers"", ""Multiplying Fractions""]","""The misconception 'Uses dividi…"
2,"""Believes there are 100 degrees…","[""Types, Naming and Estimating"", ""Measuring Angles""]","""The misconception 'Believes th…"
3,"""Thinks a quadratic without a n…","[""Factorising into a Single Bracket""]","""The misconception 'Thinks a qu…"
4,"""Believes addition of terms and…","[""Simplifying Expressions by Collecting Like Terms""]","""The misconception 'Believes ad…"
…,…,…,…
2582,"""When multiplying numbers with …",[],"""The misconception is: When mul…"
2583,"""Does not know what a cube numb…","[""Square Roots, Cube Roots, etc"", ""Squares, Cubes, etc""]","""The misconception 'Does not kn…"
2584,"""Believes that any percentage o…",[],"""The misconception is: Believes…"
2585,"""Believes a cubic expression sh…","[""Expanding Triple Brackets and more""]","""The misconception 'Believes a …"


In [10]:
# Metadata for the retrieved candidates: the same table with every column
# name prefixed by "Predict" so it can be joined next to the ground-truth columns.
predicted_meta = mapping_meta.rename({col: "Predict" + col for col in mapping_meta.columns})

# One row per (train row, retrieved candidate id).
exploded = train_long.explode("PredictMisconceptionId")

# Attach the ground-truth misconception metadata, then the candidate's metadata.
with_truth = exploded.join(mapping_meta, on="MisconceptionId")
with_candidates = with_truth.join(predicted_meta, on="PredictMisconceptionId")

# Drop rows where the retrieved candidate is the true misconception (a correct hit),
# keeping only candidates that differ from the ground truth.
is_wrong_candidate = pl.col("MisconceptionId") != pl.col("PredictMisconceptionId")
train_retrieved = with_candidates.filter(is_wrong_candidate)

display(train_retrieved.head(3))
print(f"train_retrieved.shape: {train_retrieved.shape}")


QuestionId,ConstructName,SubjectName,QuestionText,CorrectAnswer,AnswerType,AnswerText,AllText,AnswerAlphabet,QuestionId_Answer,MisconceptionId,PredictMisconceptionId,MisconceptionName,SubjectNames,MisconceptionName_with_SubjectNames,PredictMisconceptionName,PredictSubjectNames,PredictMisconceptionName_with_SubjectNames
i64,str,str,str,str,str,str,str,str,str,i64,i64,str,list[str],str,str,list[str],str
0,"""Use the order of operations to…","""BIDMAS""","""\[ 3 \times 2+4-5 \] Where do …","""A""","""AnswerDText""","""Does not need brackets""","""ConstructName: Use the order o…","""D""","""0_D""",1672,2488,"""Confuses the order of operatio…","[""BIDMAS""]","""The misconception 'Confuses th…","""Answers order of operations qu…","[""Multiplying and Dividing Algebraic Fractions"", ""Substitution into Formula"", … ""Function Machines""]","""The misconception 'Answers ord…"
0,"""Use the order of operations to…","""BIDMAS""","""\[ 3 \times 2+4-5 \] Where do …","""A""","""AnswerDText""","""Does not need brackets""","""ConstructName: Use the order o…","""D""","""0_D""",1672,871,"""Confuses the order of operatio…","[""BIDMAS""]","""The misconception 'Confuses th…","""Does not include brackets when…","[""Linear Equations""]","""The misconception 'Does not in…"
0,"""Use the order of operations to…","""BIDMAS""","""\[ 3 \times 2+4-5 \] Where do …","""A""","""AnswerDText""","""Does not need brackets""","""ConstructName: Use the order o…","""D""","""0_D""",1672,1316,"""Confuses the order of operatio…","[""BIDMAS""]","""The misconception 'Confuses th…","""Does not understand that the n…","[""BIDMAS""]","""The misconception 'Does not un…"


train_retrieved.shape: (106042, 18)


### Fine-tuning

In [11]:
# Convert the polars frame into a Hugging Face Dataset, which is what the trainer
# below is given; the bare expression displays its features and row count.
train_dataset = Dataset.from_polars(train_retrieved)
train_dataset


Dataset({
    features: ['QuestionId', 'ConstructName', 'SubjectName', 'QuestionText', 'CorrectAnswer', 'AnswerType', 'AnswerText', 'AllText', 'AnswerAlphabet', 'QuestionId_Answer', 'MisconceptionId', 'PredictMisconceptionId', 'MisconceptionName', 'SubjectNames', 'MisconceptionName_with_SubjectNames', 'PredictMisconceptionName', 'PredictSubjectNames', 'PredictMisconceptionName_with_SubjectNames'],
    num_rows: 106042
})

In [12]:
if DEBUG:
    # Debug runs train on a small prefix of the data. Cap the prefix at the dataset
    # length so select() is never handed out-of-range indices on a small dataset.
    train_dataset = train_dataset.select(range(min(500, len(train_dataset))))


In [13]:
# Create the directory that stores this run's results. run_time is set to the
# current Tokyo time first; the "Directory created" output shows results_path
# picking up that timestamp.
japan_tz = pytz.timezone("Asia/Tokyo")
cfg.run_time = datetime.now(japan_tz).strftime("%Y%m%d_%H%M%S")
create_dir(cfg.data.results_path)

# Reload the pretrained weights so fine-tuning starts from a fresh model object.
model = SentenceTransformer(cfg.model.model_name, trust_remote_code=True)

loss = MultipleNegativesRankingLoss(model)

# Mixed precision: use bf16 when the GPU supports it, otherwise fp16 on GPU.
# Without CUDA both flags stay False, and is_bf16_supported() is never queried.
cuda_available = torch.cuda.is_available()
BF = cuda_available and torch.cuda.is_bf16_supported()
FP = cuda_available and not BF
print(f"torch.cuda.is_bf16_supported()={BF}")

print(f"{cfg.model.model_name}のfine-tuningを開始します。")

# Accumulate gradients up to an effective batch of 128 samples per optimizer step.
# max(1, ...) keeps the step count valid if the per-device batch size exceeds 128.
TARGET_EFFECTIVE_BATCH_SIZE = 128
accumulation_steps = max(1, TARGET_EFFECTIVE_BATCH_SIZE // cfg.model.batch_size)

args = SentenceTransformerTrainingArguments(
    # Required parameter:
    output_dir=cfg.data.results_path,
    # Optional training parameters:
    num_train_epochs=cfg.model.epoch,
    per_device_train_batch_size=cfg.model.batch_size,
    gradient_accumulation_steps=accumulation_steps,
    per_device_eval_batch_size=cfg.model.batch_size,
    eval_accumulation_steps=accumulation_steps,
    learning_rate=cfg.model.lr,
    weight_decay=0.01,
    warmup_ratio=0.1,
    fp16=FP,  # Set to False if you get an error that your GPU can't run on FP16
    bf16=BF,  # Set to True if you have a GPU that supports BF16
    batch_sampler=BatchSamplers.NO_DUPLICATES,  # no duplicate samples in a batch
    # Optional tracking/debugging parameters:
    lr_scheduler_type="cosine_with_restarts",
    save_strategy="steps",
    save_steps=0.1,  # a float below 1 is read as a fraction of the total training steps
    save_total_limit=2,
    logging_steps=100,
    do_eval=False,
)

# The column order matters here: it is read as (anchor, positive, negative), i.e.
# question text, true misconception, retrieved-but-wrong misconception.
training_columns = [
    "AllText",
    "MisconceptionName_with_SubjectNames",
    "PredictMisconceptionName_with_SubjectNames",
]

trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset.select_columns(training_columns),
    loss=loss,
)


trainer.train()
model.save_pretrained(cfg.data.results_path, create_model_card=False)


Directory created: ../../results/002/20241113_100343
torch.cuda.is_bf16_supported()=True
BAAI/bge-large-en-v1.5のfine-tuningを開始します。


Step,Training Loss
100,1.2398
200,0.995
300,1.0424
400,1.0593


Error while generating model card:                                   
Traceback (most recent call last):
  File "/home/marumarukun/pj/compe/kaggle_eedi/.venv/lib/python3.12/site-packages/sentence_transformers/SentenceTransformer.py", line 1233, in _create_model_card
    model_card = generate_model_card(self)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/marumarukun/pj/compe/kaggle_eedi/.venv/lib/python3.12/site-packages/sentence_transformers/model_card.py", line 962, in generate_model_card
    model_card = ModelCard.from_template(card_data=model.model_card_data, template_path=template_path, hf_emoji="🤗")
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/marumarukun/pj/compe/kaggle_eedi/.venv/lib/python3.12/site-packages/huggingface_hub/repocard.py", line 416, in from_template
    return super().from_template(card_data, template_path, template_str, **template_kwargs)
           ^^^^^^^^^^^^^^^^^