In [1]:
import pickle
import pprint
from datetime import datetime
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import polars as pl
import pytz
import seaborn as sns
from omegaconf import OmegaConf
from sentence_transformers import SentenceTransformer, util
from sklearn.model_selection import GroupKFold

from src.config import cfg
from src.data import add_subject_name_info, preprocess_train
from src.dir import create_dir
from src.seed import seed_everything

cfg.exp_number = Path().resolve().name
print(OmegaConf.to_yaml(cfg, resolve=True))

seed_everything(cfg.seed)


  from tqdm.autonotebook import tqdm, trange


exp_number: '000'
run_time: base
data:
  input_root: ../../data/input
  train_path: ../../data/input/train.csv
  test_path: ../../data/input/test.csv
  sample_submission_path: ../../data/input/sample_submission.csv
  mapping_path: ../../data/input/misconception_mapping.csv
  output_root: ../../data/output
  results_root: ../../results
  results_path: ../../results/000/base
seed: 42



### データの読み込み

In [2]:
# データの読み込み
train_df = pl.read_csv(cfg.data.train_path, try_parse_dates=True)
test_df = pl.read_csv(cfg.data.test_path, try_parse_dates=True)
sample_submission_df = pl.read_csv(cfg.data.sample_submission_path, try_parse_dates=True)
mapping_df = pl.read_csv(cfg.data.mapping_path, try_parse_dates=True)

# CV
gkf = GroupKFold(n_splits=5)


In [3]:
# 比較したい埋め込みモデルをここに追加していく(MTEBランクは2024/11/09時点)
model_names = [
    "BAAI/bge-large-en-v1.5",  # MTEB rank: 42, Model size: 335(Million parameters)
    "dunzhang/stella_en_400M_v5",  # MTEB rank: 6, Model size: 435(Million parameters)
    # "dunzhang/stella_en_1.5B_v5",  # MTEB rank: 3, Model size: 1543(Million parameters)
    "Alibaba-NLP/gte-large-en-v1.5",  # MTEB rank: 28, Model size: 434(Million parameters)
    "jinaai/jina-embeddings-v3",  # MTEB rank: 25, Model size: 572(Million parameters)
]
task = "text-matching"  # jina-embeddings-v3にはtaskが必要そう

# NOTE: ローカルでは動作しないので、stella_en_1.5B_v5は一旦除外


# # modelのロードと埋め込みができるか試す
# for model_name in model_names:
#     model = SentenceTransformer(model_name, trust_remote_code=True)
#     print(f"モデル: {model_name} ロードOK")
#     embed_trial = model.encode(train_df["SubjectName"].to_list()[:5], normalize_embeddings=True)
#     print(f"{model_name} 埋め込みテストOK\n")


In [4]:
# 埋め込みモデルの比較をCVで行う

# QuestionIdでGroupKFold
for model_name in model_names:
    print(f"モデル: {model_name}")

    model = SentenceTransformer(model_name, trust_remote_code=True)

    cv_scores = []
    for i, (train_idx, valid_idx) in enumerate(gkf.split(train_df, groups=train_df["QuestionId"])):
        # train_dfの分割
        train = train_df[train_idx]
        valid = train_df[valid_idx]

        # trainのSubjectName情報をmapping_dfに追加
        mapping_meta = add_subject_name_info(train, mapping_df)

        # trainの前処理
        train_long = preprocess_train(train)

        # 埋め込みモデルでベクトル化
        train_long_embed = model.encode(train_long["AllText"].to_list(), normalize_embeddings=True)
        misconception_vec = model.encode(
            mapping_meta["MisconceptionName_with_SubjectNames"].to_list(), normalize_embeddings=True
        )
        # jina-embeddings-v3の場合のみtaskを指定
        if model_name == "jinaai/jina-embeddings-v3":
            train_long_embed = model.encode(
                train_long["AllText"].to_list(), task=task, prompt_name=task, normalize_embeddings=True
            )
            misconception_vec = model.encode(
                mapping_meta["MisconceptionName_with_SubjectNames"].to_list(),
                task=task,
                prompt_name=task,
                normalize_embeddings=True,
            )

        # 埋め込みからTOP100を抽出
        top100ids = util.semantic_search(train_long_embed, misconception_vec, top_k=100)

        # top100idsから100個のcorpus_id(=misconception_id)を抽出
        top100ids_lists = [[qid["corpus_id"] for qid in top100id] for top100id in top100ids]
        # gt_misconception_idを抽出
        gt_misconception_ids = train_long["MisconceptionId"].to_list()

        # 抽出したTOP100の中にgt_misconception_idがあるか確認
        is_gt_in_top100 = []
        for top100ids_list, gt_misconception_id in zip(top100ids_lists, gt_misconception_ids, strict=True):
            if gt_misconception_id in top100ids_list:
                is_gt_in_top100.append(True)
            else:
                is_gt_in_top100.append(False)

        # 平均をとってCVスコアとする
        avg_score = np.mean(is_gt_in_top100)
        cv_scores.append(avg_score)
        print(f"Fold {i+1}: {avg_score}")

    print(f"CVスコア: {np.mean(cv_scores)}\n")


モデル: BAAI/bge-large-en-v1.5
Fold 1: 0.927023945267959
Fold 2: 0.9169530355097365
Fold 3: 0.9174548581255374
Fold 4: 0.9235832856325129
Fold 5: 0.9305118673148413
CVスコア: 0.9231053983701175

モデル: dunzhang/stella_en_400M_v5


  @torch.library.impl_abstract("xformers_flash::flash_fwd")
  @torch.library.impl_abstract("xformers_flash::flash_bwd")
Some weights of the model checkpoint at dunzhang/stella_en_400M_v5 were not used when initializing NewModel: ['new.pooler.dense.bias', 'new.pooler.dense.weight']
- This IS expected if you are initializing NewModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing NewModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Fold 1: 0.9586659064994298
Fold 2: 0.9596219931271478
Fold 3: 0.9535683576956148
Fold 4: 0.9582140812821981
Fold 5: 0.9616814412353446
CVスコア: 0.9583503559679469

モデル: Alibaba-NLP/gte-large-en-v1.5
Fold 1: 0.9384264538198404
Fold 2: 0.9338487972508591
Fold 3: 0.9392376038979651
Fold 4: 0.9404693760732684
Fold 5: 0.9453817557906777
CVスコア: 0.9394727973665221

モデル: jinaai/jina-embeddings-v3


flash_attn is not installed. Using PyTorch native attention implementation.
flash_attn is not installed. Using PyTorch native attention implementation.
flash_attn is not installed. Using PyTorch native attention implementation.
flash_attn is not installed. Using PyTorch native attention implementation.
flash_attn is not installed. Using PyTorch native attention implementation.
flash_attn is not installed. Using PyTorch native attention implementation.
flash_attn is not installed. Using PyTorch native attention implementation.
flash_attn is not installed. Using PyTorch native attention implementation.
flash_attn is not installed. Using PyTorch native attention implementation.
flash_attn is not installed. Using PyTorch native attention implementation.
flash_attn is not installed. Using PyTorch native attention implementation.
flash_attn is not installed. Using PyTorch native attention implementation.
flash_attn is not installed. Using PyTorch native attention implementation.
flash_attn i

Fold 1: 0.9484036488027366
Fold 2: 0.9478808705612829
Fold 3: 0.9438234451132129
Fold 4: 0.951345163136806
Fold 5: 0.9559622533600228
CVスコア: 0.9494830761948123

